From ccb1a643e9b0959e19bd0263e493a8b7a8544972 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 14 Jan 2026 18:09:02 -0300 Subject: [PATCH] refactor: make oauth dialog generic and move provider logic to interface --- internal/ui/dialog/oauth.go | 95 +++++++------------------------ internal/ui/dialog/oauth_hyper.go | 90 +++++++++++++++++++++++++++++ internal/ui/model/ui.go | 12 ++-- 3 files changed, 117 insertions(+), 80 deletions(-) create mode 100644 internal/ui/dialog/oauth_hyper.go diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go index 68f22a76e037278d5884f44c9cfbdeea2f3549f9..f87a583be4160edf1d6a1e42de29f7dd01a513bc 100644 --- a/internal/ui/dialog/oauth.go +++ b/internal/ui/dialog/oauth.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "time" "charm.land/bubbles/v2/help" "charm.land/bubbles/v2/key" @@ -14,13 +13,19 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/oauth" - "github.com/charmbracelet/crush/internal/oauth/hyper" "github.com/charmbracelet/crush/internal/ui/common" "github.com/charmbracelet/crush/internal/uiutil" uv "github.com/charmbracelet/ultraviolet" "github.com/pkg/browser" ) +type OAuthProvider interface { + name() string + initiateAuth() tea.Msg + startPolling(deviceCode string, expiresIn int) tea.Cmd + stopPolling() tea.Msg +} + // OAuthState represents the current state of the device flow. type OAuthState int @@ -38,9 +43,10 @@ const OAuthID = "oauth" type OAuth struct { com *common.Common - provider catwalk.Provider - model config.SelectedModel - modelType config.SelectedModelType + provider catwalk.Provider + model config.SelectedModel + modelType config.SelectedModelType + oAuthProvider OAuthProvider State OAuthState @@ -63,8 +69,8 @@ type OAuth struct { var _ Dialog = (*OAuth)(nil) -// NewOAuth creates a new device flow component. -func NewOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) { +// newOAuth creates a new device flow component. +func newOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType, oAuthProvider OAuthProvider) (*OAuth, error) { t := com.Styles m := OAuth{} @@ -72,6 +78,7 @@ func NewOAuth(com *common.Common, provider catwalk.Provider, model config.Select m.provider = provider m.model = model m.modelType = modelType + m.oAuthProvider = oAuthProvider m.width = 60 m.State = OAuthStateInitializing @@ -103,7 +110,7 @@ func (m *OAuth) ID() string { // Init implements Dialog. func (m *OAuth) Init() tea.Cmd { - return tea.Batch(m.spinner.Tick, m.initiateDeviceAuth) + return tea.Batch(m.spinner.Tick, m.oAuthProvider.initiateAuth) } // HandleMsg handles messages and state transitions. @@ -151,16 +158,16 @@ func (m *OAuth) HandleMsg(msg tea.Msg) Action { m.expiresIn = msg.ExpiresIn m.verificationURL = msg.VerificationURL m.State = OAuthStateDisplay - return ActionCmd{m.startPolling(msg.DeviceCode)} + return ActionCmd{m.oAuthProvider.startPolling(msg.DeviceCode, msg.ExpiresIn)} case ActionCompleteOAuth: m.State = OAuthStateSuccess m.token = msg.Token - return ActionCmd{m.stopPolling} + return ActionCmd{m.oAuthProvider.stopPolling} case ActionOAuthErrored: m.State = OAuthStateError - cmd := tea.Batch(m.stopPolling, uiutil.ReportError(msg.Error)) + cmd := tea.Batch(m.oAuthProvider.stopPolling, uiutil.ReportError(msg.Error)) return ActionCmd{cmd} } return nil @@ -204,7 +211,7 @@ func (m *OAuth) headerContent() string { dialogStyle = t.Dialog.View.Width(m.width) headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize() ) - return common.DialogTitle(t, titleStyle.Render("Authenticate with Hyper"), m.width-headerOffset) + return common.DialogTitle(t, titleStyle.Render("Authenticate with "+m.oAuthProvider.name()), m.width-headerOffset) } func (m *OAuth) innerDialogContent() string { @@ -357,67 +364,3 @@ func (m *OAuth) saveKeyAndContinue() Action { ModelType: m.modelType, } } - -func (m *OAuth) initiateDeviceAuth() tea.Msg { - minimumWait := 750 * time.Millisecond - startTime := time.Now() - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - authResp, err := hyper.InitiateDeviceAuth(ctx) - - ellapsed := time.Since(startTime) - if ellapsed < minimumWait { - time.Sleep(minimumWait - ellapsed) - } - - if err != nil { - return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)} - } - - return ActionInitiateOAuth{ - DeviceCode: authResp.DeviceCode, - UserCode: authResp.UserCode, - ExpiresIn: authResp.ExpiresIn, - VerificationURL: authResp.VerificationURL, - } -} - -// startPolling starts polling for the device token. -func (m *OAuth) startPolling(deviceCode string) tea.Cmd { - return func() tea.Msg { - ctx, cancel := context.WithCancel(context.Background()) - m.cancelFunc = cancel - - refreshToken, err := hyper.PollForToken(ctx, deviceCode, m.expiresIn) - if err != nil { - if ctx.Err() != nil { - return nil - } - return ActionOAuthErrored{err} - } - - token, err := hyper.ExchangeToken(ctx, refreshToken) - if err != nil { - return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)} - } - - introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) - if err != nil { - return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)} - } - if !introspect.Active { - return ActionOAuthErrored{fmt.Errorf("access token is not active")} - } - - return ActionCompleteOAuth{token} - } -} - -func (m *OAuth) stopPolling() tea.Msg { - if m.cancelFunc != nil { - m.cancelFunc() - } - return nil -} diff --git a/internal/ui/dialog/oauth_hyper.go b/internal/ui/dialog/oauth_hyper.go new file mode 100644 index 0000000000000000000000000000000000000000..65d3e34e817384e8e123e4720f7cdd310213aa87 --- /dev/null +++ b/internal/ui/dialog/oauth_hyper.go @@ -0,0 +1,90 @@ +package dialog + +import ( + "context" + "fmt" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/charmbracelet/crush/internal/ui/common" +) + +func NewOAuthHyper(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) { + return newOAuth(com, provider, model, modelType, &OAuthHyper{}) +} + +type OAuthHyper struct { + cancelFunc func() +} + +var _ OAuthProvider = (*OAuthHyper)(nil) + +func (m *OAuthHyper) name() string { + return "Hyper" +} + +func (m *OAuthHyper) initiateAuth() tea.Msg { + minimumWait := 750 * time.Millisecond + startTime := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + authResp, err := hyper.InitiateDeviceAuth(ctx) + + ellapsed := time.Since(startTime) + if ellapsed < minimumWait { + time.Sleep(minimumWait - ellapsed) + } + + if err != nil { + return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)} + } + + return ActionInitiateOAuth{ + DeviceCode: authResp.DeviceCode, + UserCode: authResp.UserCode, + ExpiresIn: authResp.ExpiresIn, + VerificationURL: authResp.VerificationURL, + } +} + +func (m *OAuthHyper) startPolling(deviceCode string, expiresIn int) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithCancel(context.Background()) + m.cancelFunc = cancel + + refreshToken, err := hyper.PollForToken(ctx, deviceCode, expiresIn) + if err != nil { + if ctx.Err() != nil { + return nil + } + return ActionOAuthErrored{err} + } + + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)} + } + + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)} + } + if !introspect.Active { + return ActionOAuthErrored{fmt.Errorf("access token is not active")} + } + + return ActionCompleteOAuth{token} + } +} + +func (m *OAuthHyper) stopPolling() tea.Msg { + if m.cancelFunc != nil { + m.cancelFunc() + } + return nil +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index a61c3050ac7f982c0141b8525b304c92e379b5bc..e9bd2fd3cedb5dd93bb5e1bfe31e2fb2b6f65f72 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1020,9 +1020,9 @@ func substituteArgs(content string, args map[string]string) string { func (m *UI) openAuthenticationDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { switch provider.ID { case "hyper": - return m.openOAuthDialog(provider, model, modelType) + return m.openOAuthHyperDialog(provider, model, modelType) case catwalk.InferenceProviderCopilot: - return m.openOAuthDialog(provider, model, modelType) + return m.openOAuthCopilotDialog(provider, model, modelType) default: return m.openAPIKeyInputDialog(provider, model, modelType) } @@ -1042,13 +1042,13 @@ func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.Selec return nil } -func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { +func (m *UI) openOAuthHyperDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { if m.dialog.ContainsDialog(dialog.OAuthID) { m.dialog.BringToFront(dialog.OAuthID) return nil } - oAuthDialog, err := dialog.NewOAuth(m.com, provider, model, modelType) + oAuthDialog, err := dialog.NewOAuthHyper(m.com, provider, model, modelType) if err != nil { return uiutil.ReportError(err) } @@ -1057,6 +1057,10 @@ func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedMod return oAuthDialog.Init() } +func (m *UI) openOAuthCopilotDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { + panic("TODO") +} + func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { var cmds []tea.Cmd