refactor: make oauth dialog generic and move provider logic to interface

Andrey Nering created

Change summary

internal/ui/dialog/oauth.go       | 95 ++++++--------------------------
internal/ui/dialog/oauth_hyper.go | 90 +++++++++++++++++++++++++++++++
internal/ui/model/ui.go           | 12 ++-
3 files changed, 117 insertions(+), 80 deletions(-)

Detailed changes

internal/ui/dialog/oauth.go 🔗

@@ -4,7 +4,6 @@ import (
 	"context"
 	"fmt"
 	"strings"
-	"time"
 
 	"charm.land/bubbles/v2/help"
 	"charm.land/bubbles/v2/key"
@@ -14,13 +13,19 @@ import (
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/oauth"
-	"github.com/charmbracelet/crush/internal/oauth/hyper"
 	"github.com/charmbracelet/crush/internal/ui/common"
 	"github.com/charmbracelet/crush/internal/uiutil"
 	uv "github.com/charmbracelet/ultraviolet"
 	"github.com/pkg/browser"
 )
 
+type OAuthProvider interface {
+	name() string
+	initiateAuth() tea.Msg
+	startPolling(deviceCode string, expiresIn int) tea.Cmd
+	stopPolling() tea.Msg
+}
+
 // OAuthState represents the current state of the device flow.
 type OAuthState int
 
@@ -38,9 +43,10 @@ const OAuthID = "oauth"
 type OAuth struct {
 	com *common.Common
 
-	provider  catwalk.Provider
-	model     config.SelectedModel
-	modelType config.SelectedModelType
+	provider      catwalk.Provider
+	model         config.SelectedModel
+	modelType     config.SelectedModelType
+	oAuthProvider OAuthProvider
 
 	State OAuthState
 
@@ -63,8 +69,8 @@ type OAuth struct {
 
 var _ Dialog = (*OAuth)(nil)
 
-// NewOAuth creates a new device flow component.
-func NewOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) {
+// newOAuth creates a new device flow component.
+func newOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType, oAuthProvider OAuthProvider) (*OAuth, error) {
 	t := com.Styles
 
 	m := OAuth{}
@@ -72,6 +78,7 @@ func NewOAuth(com *common.Common, provider catwalk.Provider, model config.Select
 	m.provider = provider
 	m.model = model
 	m.modelType = modelType
+	m.oAuthProvider = oAuthProvider
 	m.width = 60
 	m.State = OAuthStateInitializing
 
@@ -103,7 +110,7 @@ func (m *OAuth) ID() string {
 
 // Init implements Dialog.
 func (m *OAuth) Init() tea.Cmd {
-	return tea.Batch(m.spinner.Tick, m.initiateDeviceAuth)
+	return tea.Batch(m.spinner.Tick, m.oAuthProvider.initiateAuth)
 }
 
 // HandleMsg handles messages and state transitions.
@@ -151,16 +158,16 @@ func (m *OAuth) HandleMsg(msg tea.Msg) Action {
 		m.expiresIn = msg.ExpiresIn
 		m.verificationURL = msg.VerificationURL
 		m.State = OAuthStateDisplay
-		return ActionCmd{m.startPolling(msg.DeviceCode)}
+		return ActionCmd{m.oAuthProvider.startPolling(msg.DeviceCode, msg.ExpiresIn)}
 
 	case ActionCompleteOAuth:
 		m.State = OAuthStateSuccess
 		m.token = msg.Token
-		return ActionCmd{m.stopPolling}
+		return ActionCmd{m.oAuthProvider.stopPolling}
 
 	case ActionOAuthErrored:
 		m.State = OAuthStateError
-		cmd := tea.Batch(m.stopPolling, uiutil.ReportError(msg.Error))
+		cmd := tea.Batch(m.oAuthProvider.stopPolling, uiutil.ReportError(msg.Error))
 		return ActionCmd{cmd}
 	}
 	return nil
@@ -204,7 +211,7 @@ func (m *OAuth) headerContent() string {
 		dialogStyle  = t.Dialog.View.Width(m.width)
 		headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
 	)
-	return common.DialogTitle(t, titleStyle.Render("Authenticate with Hyper"), m.width-headerOffset)
+	return common.DialogTitle(t, titleStyle.Render("Authenticate with "+m.oAuthProvider.name()), m.width-headerOffset)
 }
 
 func (m *OAuth) innerDialogContent() string {
@@ -357,67 +364,3 @@ func (m *OAuth) saveKeyAndContinue() Action {
 		ModelType: m.modelType,
 	}
 }
-
-func (m *OAuth) initiateDeviceAuth() tea.Msg {
-	minimumWait := 750 * time.Millisecond
-	startTime := time.Now()
-
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-
-	authResp, err := hyper.InitiateDeviceAuth(ctx)
-
-	ellapsed := time.Since(startTime)
-	if ellapsed < minimumWait {
-		time.Sleep(minimumWait - ellapsed)
-	}
-
-	if err != nil {
-		return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)}
-	}
-
-	return ActionInitiateOAuth{
-		DeviceCode:      authResp.DeviceCode,
-		UserCode:        authResp.UserCode,
-		ExpiresIn:       authResp.ExpiresIn,
-		VerificationURL: authResp.VerificationURL,
-	}
-}
-
-// startPolling starts polling for the device token.
-func (m *OAuth) startPolling(deviceCode string) tea.Cmd {
-	return func() tea.Msg {
-		ctx, cancel := context.WithCancel(context.Background())
-		m.cancelFunc = cancel
-
-		refreshToken, err := hyper.PollForToken(ctx, deviceCode, m.expiresIn)
-		if err != nil {
-			if ctx.Err() != nil {
-				return nil
-			}
-			return ActionOAuthErrored{err}
-		}
-
-		token, err := hyper.ExchangeToken(ctx, refreshToken)
-		if err != nil {
-			return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)}
-		}
-
-		introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
-		if err != nil {
-			return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)}
-		}
-		if !introspect.Active {
-			return ActionOAuthErrored{fmt.Errorf("access token is not active")}
-		}
-
-		return ActionCompleteOAuth{token}
-	}
-}
-
-func (m *OAuth) stopPolling() tea.Msg {
-	if m.cancelFunc != nil {
-		m.cancelFunc()
-	}
-	return nil
-}

internal/ui/dialog/oauth_hyper.go 🔗

@@ -0,0 +1,90 @@
+package dialog
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	tea "charm.land/bubbletea/v2"
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/oauth/hyper"
+	"github.com/charmbracelet/crush/internal/ui/common"
+)
+
+func NewOAuthHyper(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) {
+	return newOAuth(com, provider, model, modelType, &OAuthHyper{})
+}
+
+type OAuthHyper struct {
+	cancelFunc func()
+}
+
+var _ OAuthProvider = (*OAuthHyper)(nil)
+
+func (m *OAuthHyper) name() string {
+	return "Hyper"
+}
+
+func (m *OAuthHyper) initiateAuth() tea.Msg {
+	minimumWait := 750 * time.Millisecond
+	startTime := time.Now()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	authResp, err := hyper.InitiateDeviceAuth(ctx)
+
+	ellapsed := time.Since(startTime)
+	if ellapsed < minimumWait {
+		time.Sleep(minimumWait - ellapsed)
+	}
+
+	if err != nil {
+		return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)}
+	}
+
+	return ActionInitiateOAuth{
+		DeviceCode:      authResp.DeviceCode,
+		UserCode:        authResp.UserCode,
+		ExpiresIn:       authResp.ExpiresIn,
+		VerificationURL: authResp.VerificationURL,
+	}
+}
+
+func (m *OAuthHyper) startPolling(deviceCode string, expiresIn int) tea.Cmd {
+	return func() tea.Msg {
+		ctx, cancel := context.WithCancel(context.Background())
+		m.cancelFunc = cancel
+
+		refreshToken, err := hyper.PollForToken(ctx, deviceCode, expiresIn)
+		if err != nil {
+			if ctx.Err() != nil {
+				return nil
+			}
+			return ActionOAuthErrored{err}
+		}
+
+		token, err := hyper.ExchangeToken(ctx, refreshToken)
+		if err != nil {
+			return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)}
+		}
+
+		introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+		if err != nil {
+			return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)}
+		}
+		if !introspect.Active {
+			return ActionOAuthErrored{fmt.Errorf("access token is not active")}
+		}
+
+		return ActionCompleteOAuth{token}
+	}
+}
+
+func (m *OAuthHyper) stopPolling() tea.Msg {
+	if m.cancelFunc != nil {
+		m.cancelFunc()
+	}
+	return nil
+}

internal/ui/model/ui.go 🔗

@@ -1020,9 +1020,9 @@ func substituteArgs(content string, args map[string]string) string {
 func (m *UI) openAuthenticationDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
 	switch provider.ID {
 	case "hyper":
-		return m.openOAuthDialog(provider, model, modelType)
+		return m.openOAuthHyperDialog(provider, model, modelType)
 	case catwalk.InferenceProviderCopilot:
-		return m.openOAuthDialog(provider, model, modelType)
+		return m.openOAuthCopilotDialog(provider, model, modelType)
 	default:
 		return m.openAPIKeyInputDialog(provider, model, modelType)
 	}
@@ -1042,13 +1042,13 @@ func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.Selec
 	return nil
 }
 
-func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+func (m *UI) openOAuthHyperDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
 	if m.dialog.ContainsDialog(dialog.OAuthID) {
 		m.dialog.BringToFront(dialog.OAuthID)
 		return nil
 	}
 
-	oAuthDialog, err := dialog.NewOAuth(m.com, provider, model, modelType)
+	oAuthDialog, err := dialog.NewOAuthHyper(m.com, provider, model, modelType)
 	if err != nil {
 		return uiutil.ReportError(err)
 	}
@@ -1057,6 +1057,10 @@ func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedMod
 	return oAuthDialog.Init()
 }
 
+func (m *UI) openOAuthCopilotDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+	panic("TODO")
+}
+
 func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 	var cmds []tea.Cmd