Detailed changes
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
- "time"
"charm.land/bubbles/v2/help"
"charm.land/bubbles/v2/key"
@@ -14,13 +13,19 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/oauth"
- "github.com/charmbracelet/crush/internal/oauth/hyper"
"github.com/charmbracelet/crush/internal/ui/common"
"github.com/charmbracelet/crush/internal/uiutil"
uv "github.com/charmbracelet/ultraviolet"
"github.com/pkg/browser"
)
+type OAuthProvider interface {
+ name() string
+ initiateAuth() tea.Msg
+ startPolling(deviceCode string, expiresIn int) tea.Cmd
+ stopPolling() tea.Msg
+}
+
// OAuthState represents the current state of the device flow.
type OAuthState int
@@ -38,9 +43,10 @@ const OAuthID = "oauth"
type OAuth struct {
com *common.Common
- provider catwalk.Provider
- model config.SelectedModel
- modelType config.SelectedModelType
+ provider catwalk.Provider
+ model config.SelectedModel
+ modelType config.SelectedModelType
+ oAuthProvider OAuthProvider
State OAuthState
@@ -63,8 +69,8 @@ type OAuth struct {
var _ Dialog = (*OAuth)(nil)
-// NewOAuth creates a new device flow component.
-func NewOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) {
+// newOAuth creates a new device flow component.
+func newOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType, oAuthProvider OAuthProvider) (*OAuth, error) {
t := com.Styles
m := OAuth{}
@@ -72,6 +78,7 @@ func NewOAuth(com *common.Common, provider catwalk.Provider, model config.Select
m.provider = provider
m.model = model
m.modelType = modelType
+ m.oAuthProvider = oAuthProvider
m.width = 60
m.State = OAuthStateInitializing
@@ -103,7 +110,7 @@ func (m *OAuth) ID() string {
// Init implements Dialog.
func (m *OAuth) Init() tea.Cmd {
- return tea.Batch(m.spinner.Tick, m.initiateDeviceAuth)
+ return tea.Batch(m.spinner.Tick, m.oAuthProvider.initiateAuth)
}
// HandleMsg handles messages and state transitions.
@@ -151,16 +158,16 @@ func (m *OAuth) HandleMsg(msg tea.Msg) Action {
m.expiresIn = msg.ExpiresIn
m.verificationURL = msg.VerificationURL
m.State = OAuthStateDisplay
- return ActionCmd{m.startPolling(msg.DeviceCode)}
+ return ActionCmd{m.oAuthProvider.startPolling(msg.DeviceCode, msg.ExpiresIn)}
case ActionCompleteOAuth:
m.State = OAuthStateSuccess
m.token = msg.Token
- return ActionCmd{m.stopPolling}
+ return ActionCmd{m.oAuthProvider.stopPolling}
case ActionOAuthErrored:
m.State = OAuthStateError
- cmd := tea.Batch(m.stopPolling, uiutil.ReportError(msg.Error))
+ cmd := tea.Batch(m.oAuthProvider.stopPolling, uiutil.ReportError(msg.Error))
return ActionCmd{cmd}
}
return nil
@@ -204,7 +211,7 @@ func (m *OAuth) headerContent() string {
dialogStyle = t.Dialog.View.Width(m.width)
headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
)
- return common.DialogTitle(t, titleStyle.Render("Authenticate with Hyper"), m.width-headerOffset)
+ return common.DialogTitle(t, titleStyle.Render("Authenticate with "+m.oAuthProvider.name()), m.width-headerOffset)
}
func (m *OAuth) innerDialogContent() string {
@@ -357,67 +364,3 @@ func (m *OAuth) saveKeyAndContinue() Action {
ModelType: m.modelType,
}
}
-
-func (m *OAuth) initiateDeviceAuth() tea.Msg {
- minimumWait := 750 * time.Millisecond
- startTime := time.Now()
-
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
-
- authResp, err := hyper.InitiateDeviceAuth(ctx)
-
- ellapsed := time.Since(startTime)
- if ellapsed < minimumWait {
- time.Sleep(minimumWait - ellapsed)
- }
-
- if err != nil {
- return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)}
- }
-
- return ActionInitiateOAuth{
- DeviceCode: authResp.DeviceCode,
- UserCode: authResp.UserCode,
- ExpiresIn: authResp.ExpiresIn,
- VerificationURL: authResp.VerificationURL,
- }
-}
-
-// startPolling starts polling for the device token.
-func (m *OAuth) startPolling(deviceCode string) tea.Cmd {
- return func() tea.Msg {
- ctx, cancel := context.WithCancel(context.Background())
- m.cancelFunc = cancel
-
- refreshToken, err := hyper.PollForToken(ctx, deviceCode, m.expiresIn)
- if err != nil {
- if ctx.Err() != nil {
- return nil
- }
- return ActionOAuthErrored{err}
- }
-
- token, err := hyper.ExchangeToken(ctx, refreshToken)
- if err != nil {
- return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)}
- }
-
- introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
- if err != nil {
- return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)}
- }
- if !introspect.Active {
- return ActionOAuthErrored{fmt.Errorf("access token is not active")}
- }
-
- return ActionCompleteOAuth{token}
- }
-}
-
-func (m *OAuth) stopPolling() tea.Msg {
- if m.cancelFunc != nil {
- m.cancelFunc()
- }
- return nil
-}
@@ -0,0 +1,90 @@
+package dialog
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ tea "charm.land/bubbletea/v2"
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/oauth/hyper"
+ "github.com/charmbracelet/crush/internal/ui/common"
+)
+
+func NewOAuthHyper(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) {
+ return newOAuth(com, provider, model, modelType, &OAuthHyper{})
+}
+
+type OAuthHyper struct {
+ cancelFunc func()
+}
+
+var _ OAuthProvider = (*OAuthHyper)(nil)
+
+func (m *OAuthHyper) name() string {
+ return "Hyper"
+}
+
+func (m *OAuthHyper) initiateAuth() tea.Msg {
+ minimumWait := 750 * time.Millisecond
+ startTime := time.Now()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ authResp, err := hyper.InitiateDeviceAuth(ctx)
+
+ ellapsed := time.Since(startTime)
+ if ellapsed < minimumWait {
+ time.Sleep(minimumWait - ellapsed)
+ }
+
+ if err != nil {
+ return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)}
+ }
+
+ return ActionInitiateOAuth{
+ DeviceCode: authResp.DeviceCode,
+ UserCode: authResp.UserCode,
+ ExpiresIn: authResp.ExpiresIn,
+ VerificationURL: authResp.VerificationURL,
+ }
+}
+
+func (m *OAuthHyper) startPolling(deviceCode string, expiresIn int) tea.Cmd {
+ return func() tea.Msg {
+ ctx, cancel := context.WithCancel(context.Background())
+ m.cancelFunc = cancel
+
+ refreshToken, err := hyper.PollForToken(ctx, deviceCode, expiresIn)
+ if err != nil {
+ if ctx.Err() != nil {
+ return nil
+ }
+ return ActionOAuthErrored{err}
+ }
+
+ token, err := hyper.ExchangeToken(ctx, refreshToken)
+ if err != nil {
+ return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)}
+ }
+
+ introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+ if err != nil {
+ return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)}
+ }
+ if !introspect.Active {
+ return ActionOAuthErrored{fmt.Errorf("access token is not active")}
+ }
+
+ return ActionCompleteOAuth{token}
+ }
+}
+
+func (m *OAuthHyper) stopPolling() tea.Msg {
+ if m.cancelFunc != nil {
+ m.cancelFunc()
+ }
+ return nil
+}
@@ -1020,9 +1020,9 @@ func substituteArgs(content string, args map[string]string) string {
func (m *UI) openAuthenticationDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
switch provider.ID {
case "hyper":
- return m.openOAuthDialog(provider, model, modelType)
+ return m.openOAuthHyperDialog(provider, model, modelType)
case catwalk.InferenceProviderCopilot:
- return m.openOAuthDialog(provider, model, modelType)
+ return m.openOAuthCopilotDialog(provider, model, modelType)
default:
return m.openAPIKeyInputDialog(provider, model, modelType)
}
@@ -1042,13 +1042,13 @@ func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.Selec
return nil
}
-func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+func (m *UI) openOAuthHyperDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
if m.dialog.ContainsDialog(dialog.OAuthID) {
m.dialog.BringToFront(dialog.OAuthID)
return nil
}
- oAuthDialog, err := dialog.NewOAuth(m.com, provider, model, modelType)
+ oAuthDialog, err := dialog.NewOAuthHyper(m.com, provider, model, modelType)
if err != nil {
return uiutil.ReportError(err)
}
@@ -1057,6 +1057,10 @@ func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedMod
return oAuthDialog.Init()
}
+func (m *UI) openOAuthCopilotDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+ panic("TODO")
+}
+
func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
var cmds []tea.Cmd