diff --git a/internal/ui/dialog/actions.go b/internal/ui/dialog/actions.go index 81911f9919be6c94ac158052b4a4e9b2236342a0..0048c4ac06540e22175d4663e1ee9123e1eba211 100644 --- a/internal/ui/dialog/actions.go +++ b/internal/ui/dialog/actions.go @@ -5,6 +5,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/commands" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/session" ) @@ -73,6 +74,28 @@ type ( } ) +// Messages for OAuth2 device flow dialog. +type ( + // ActionInitiateOAuth is sent when the device auth is initiated + // successfully. + ActionInitiateOAuth struct { + DeviceCode string + UserCode string + ExpiresIn int + VerificationURL string + } + + // ActionCompleteOAuth is sent when the device flow completes successfully. + ActionCompleteOAuth struct { + Token *oauth.Token + } + + // ActionOAuthErrored is sent when the device flow encounters an error. + ActionOAuthErrored struct { + Error error + } +) + // ActionCmd represents an action that carries a [tea.Cmd] to be passed to the // Bubble Tea program loop. type ActionCmd struct { diff --git a/internal/ui/dialog/oauth.go b/internal/ui/dialog/oauth.go new file mode 100644 index 0000000000000000000000000000000000000000..68f22a76e037278d5884f44c9cfbdeea2f3549f9 --- /dev/null +++ b/internal/ui/dialog/oauth.go @@ -0,0 +1,423 @@ +package dialog + +import ( + "context" + "fmt" + "strings" + "time" + + "charm.land/bubbles/v2/help" + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/hyper" + "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/uiutil" + uv "github.com/charmbracelet/ultraviolet" + "github.com/pkg/browser" +) + +// OAuthState represents the current state of the device flow. +type OAuthState int + +const ( + OAuthStateInitializing OAuthState = iota + OAuthStateDisplay + OAuthStateSuccess + OAuthStateError +) + +// OAuthID is the identifier for the model selection dialog. +const OAuthID = "oauth" + +// OAuth handles the OAuth flow authentication. +type OAuth struct { + com *common.Common + + provider catwalk.Provider + model config.SelectedModel + modelType config.SelectedModelType + + State OAuthState + + spinner spinner.Model + help help.Model + keyMap struct { + Copy key.Binding + Submit key.Binding + Close key.Binding + } + + width int + deviceCode string + userCode string + verificationURL string + expiresIn int + token *oauth.Token + cancelFunc context.CancelFunc +} + +var _ Dialog = (*OAuth)(nil) + +// NewOAuth creates a new device flow component. +func NewOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) { + t := com.Styles + + m := OAuth{} + m.com = com + m.provider = provider + m.model = model + m.modelType = modelType + m.width = 60 + m.State = OAuthStateInitializing + + m.spinner = spinner.New( + spinner.WithSpinner(spinner.Dot), + spinner.WithStyle(t.Base.Foreground(t.GreenLight)), + ) + + m.help = help.New() + m.help.Styles = t.DialogHelpStyles() + + m.keyMap.Copy = key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "copy code"), + ) + m.keyMap.Submit = key.NewBinding( + key.WithKeys("enter", "ctrl+y"), + key.WithHelp("enter", "copy & open"), + ) + m.keyMap.Close = CloseKey + + return &m, nil +} + +// ID implements Dialog. +func (m *OAuth) ID() string { + return OAuthID +} + +// Init implements Dialog. +func (m *OAuth) Init() tea.Cmd { + return tea.Batch(m.spinner.Tick, m.initiateDeviceAuth) +} + +// HandleMsg handles messages and state transitions. +func (m *OAuth) HandleMsg(msg tea.Msg) Action { + switch msg := msg.(type) { + case spinner.TickMsg: + switch m.State { + case OAuthStateInitializing, OAuthStateDisplay: + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + if cmd != nil { + return ActionCmd{cmd} + } + } + + case tea.KeyPressMsg: + switch { + case key.Matches(msg, m.keyMap.Copy): + cmd := m.copyCode() + return ActionCmd{cmd} + + case key.Matches(msg, m.keyMap.Submit): + switch m.State { + case OAuthStateSuccess: + return m.saveKeyAndContinue() + + default: + cmd := m.copyCodeAndOpenURL() + return ActionCmd{cmd} + } + + case key.Matches(msg, m.keyMap.Close): + switch m.State { + case OAuthStateSuccess: + return m.saveKeyAndContinue() + + default: + return ActionClose{} + } + } + + case ActionInitiateOAuth: + m.deviceCode = msg.DeviceCode + m.userCode = msg.UserCode + m.expiresIn = msg.ExpiresIn + m.verificationURL = msg.VerificationURL + m.State = OAuthStateDisplay + return ActionCmd{m.startPolling(msg.DeviceCode)} + + case ActionCompleteOAuth: + m.State = OAuthStateSuccess + m.token = msg.Token + return ActionCmd{m.stopPolling} + + case ActionOAuthErrored: + m.State = OAuthStateError + cmd := tea.Batch(m.stopPolling, uiutil.ReportError(msg.Error)) + return ActionCmd{cmd} + } + return nil +} + +// View renders the device flow dialog. +func (m *OAuth) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor { + var ( + t = m.com.Styles + dialogStyle = t.Dialog.View.Width(m.width) + view = dialogStyle.Render(m.dialogContent()) + ) + DrawCenterCursor(scr, area, view, nil) + return nil +} + +func (m *OAuth) dialogContent() string { + var ( + t = m.com.Styles + helpStyle = t.Dialog.HelpView + ) + + switch m.State { + case OAuthStateInitializing: + return m.innerDialogContent() + + default: + elements := []string{ + m.headerContent(), + m.innerDialogContent(), + helpStyle.Render(m.help.View(m)), + } + return strings.Join(elements, "\n") + } +} + +func (m *OAuth) headerContent() string { + var ( + t = m.com.Styles + titleStyle = t.Dialog.Title + dialogStyle = t.Dialog.View.Width(m.width) + headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize() + ) + return common.DialogTitle(t, titleStyle.Render("Authenticate with Hyper"), m.width-headerOffset) +} + +func (m *OAuth) innerDialogContent() string { + var ( + t = m.com.Styles + whiteStyle = lipgloss.NewStyle().Foreground(t.White) + primaryStyle = lipgloss.NewStyle().Foreground(t.Primary) + greenStyle = lipgloss.NewStyle().Foreground(t.GreenLight) + linkStyle = lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true) + errorStyle = lipgloss.NewStyle().Foreground(t.Error) + mutedStyle = lipgloss.NewStyle().Foreground(t.FgMuted) + ) + + switch m.State { + case OAuthStateInitializing: + return lipgloss.NewStyle(). + Margin(1, 1). + Width(m.width - 2). + Align(lipgloss.Center). + Render( + greenStyle.Render(m.spinner.View()) + + mutedStyle.Render("Initializing..."), + ) + + case OAuthStateDisplay: + instructions := lipgloss.NewStyle(). + Margin(1). + Width(m.width - 2). + Render( + whiteStyle.Render("Press ") + + primaryStyle.Render("enter") + + whiteStyle.Render(" to copy the code below and open the browser."), + ) + + codeBox := lipgloss.NewStyle(). + Width(m.width-2). + Height(7). + Align(lipgloss.Center, lipgloss.Center). + Background(t.BgBaseLighter). + Margin(1). + Render( + lipgloss.NewStyle(). + Bold(true). + Foreground(t.White). + Render(m.userCode), + ) + + link := linkStyle.Hyperlink(m.verificationURL, "id=oauth-verify").Render(m.verificationURL) + url := mutedStyle. + Margin(0, 1). + Width(m.width - 2). + Render("Browser not opening? Refer to\n" + link) + + waiting := greenStyle. + Width(m.width - 2). + Margin(1). + Render(m.spinner.View() + "Verifying...") + + return lipgloss.JoinVertical( + lipgloss.Left, + instructions, + codeBox, + url, + waiting, + ) + + case OAuthStateSuccess: + return greenStyle. + Margin(1). + Width(m.width - 2). + Align(lipgloss.Center). + Render("Authentication successful!") + + case OAuthStateError: + return lipgloss.NewStyle(). + Margin(1). + Width(m.width - 2). + Render(errorStyle.Render("Authentication failed.")) + + default: + return "" + } +} + +// FullHelp returns the full help view. +func (m *OAuth) FullHelp() [][]key.Binding { + return [][]key.Binding{m.ShortHelp()} +} + +// ShortHelp returns the full help view. +func (m *OAuth) ShortHelp() []key.Binding { + switch m.State { + case OAuthStateError: + return []key.Binding{m.keyMap.Close} + + case OAuthStateSuccess: + return []key.Binding{ + key.NewBinding( + key.WithKeys("finish", "ctrl+y", "esc"), + key.WithHelp("enter", "finish"), + ), + } + + default: + return []key.Binding{ + m.keyMap.Copy, + m.keyMap.Submit, + m.keyMap.Close, + } + } +} + +func (d *OAuth) copyCode() tea.Cmd { + if d.State != OAuthStateDisplay { + return nil + } + return tea.Sequence( + tea.SetClipboard(d.userCode), + uiutil.ReportInfo("Code copied to clipboard"), + ) +} + +func (d *OAuth) copyCodeAndOpenURL() tea.Cmd { + if d.State != OAuthStateDisplay { + return nil + } + return tea.Sequence( + tea.SetClipboard(d.userCode), + func() tea.Msg { + if err := browser.OpenURL(d.verificationURL); err != nil { + return ActionOAuthErrored{fmt.Errorf("failed to open browser: %w", err)} + } + return nil + }, + uiutil.ReportInfo("Code copied and URL opened"), + ) +} + +func (m *OAuth) saveKeyAndContinue() Action { + cfg := m.com.Config() + + err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token) + if err != nil { + return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))} + } + + return ActionSelectModel{ + Provider: m.provider, + Model: m.model, + ModelType: m.modelType, + } +} + +func (m *OAuth) initiateDeviceAuth() tea.Msg { + minimumWait := 750 * time.Millisecond + startTime := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + authResp, err := hyper.InitiateDeviceAuth(ctx) + + ellapsed := time.Since(startTime) + if ellapsed < minimumWait { + time.Sleep(minimumWait - ellapsed) + } + + if err != nil { + return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)} + } + + return ActionInitiateOAuth{ + DeviceCode: authResp.DeviceCode, + UserCode: authResp.UserCode, + ExpiresIn: authResp.ExpiresIn, + VerificationURL: authResp.VerificationURL, + } +} + +// startPolling starts polling for the device token. +func (m *OAuth) startPolling(deviceCode string) tea.Cmd { + return func() tea.Msg { + ctx, cancel := context.WithCancel(context.Background()) + m.cancelFunc = cancel + + refreshToken, err := hyper.PollForToken(ctx, deviceCode, m.expiresIn) + if err != nil { + if ctx.Err() != nil { + return nil + } + return ActionOAuthErrored{err} + } + + token, err := hyper.ExchangeToken(ctx, refreshToken) + if err != nil { + return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)} + } + + introspect, err := hyper.IntrospectToken(ctx, token.AccessToken) + if err != nil { + return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)} + } + if !introspect.Active { + return ActionOAuthErrored{fmt.Errorf("access token is not active")} + } + + return ActionCompleteOAuth{token} + } +} + +func (m *OAuth) stopPolling() tea.Msg { + if m.cancelFunc != nil { + m.cancelFunc() + } + return nil +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index b580c4d792eb60a571af5919ee582426279e0c87..a61c3050ac7f982c0141b8525b304c92e379b5bc 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -934,7 +934,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { _, isProviderConfigured := cfg.Providers.Get(msg.Model.Provider) if !isProviderConfigured { m.dialog.CloseDialog(dialog.ModelsID) - if cmd := m.openAPIKeyInputDialog(msg.Provider, msg.Model, msg.ModelType); cmd != nil { + if cmd := m.openAuthenticationDialog(msg.Provider, msg.Model, msg.ModelType); cmd != nil { cmds = append(cmds, cmd) } break @@ -950,6 +950,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { modelMsg := fmt.Sprintf("%s model changed to %s", msg.ModelType, msg.Model.Model) cmds = append(cmds, uiutil.ReportInfo(modelMsg)) m.dialog.CloseDialog(dialog.APIKeyInputID) + m.dialog.CloseDialog(dialog.OAuthID) m.dialog.CloseDialog(dialog.ModelsID) // TODO CHANGE case dialog.ActionPermissionResponse: @@ -1016,7 +1017,17 @@ func substituteArgs(content string, args map[string]string) string { return content } -// openAPIKeyInputDialog opens the API key input dialog. +func (m *UI) openAuthenticationDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { + switch provider.ID { + case "hyper": + return m.openOAuthDialog(provider, model, modelType) + case catwalk.InferenceProviderCopilot: + return m.openOAuthDialog(provider, model, modelType) + default: + return m.openAPIKeyInputDialog(provider, model, modelType) + } +} + func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { if m.dialog.ContainsDialog(dialog.APIKeyInputID) { m.dialog.BringToFront(dialog.APIKeyInputID) @@ -1031,6 +1042,21 @@ func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.Selec return nil } +func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd { + if m.dialog.ContainsDialog(dialog.OAuthID) { + m.dialog.BringToFront(dialog.OAuthID) + return nil + } + + oAuthDialog, err := dialog.NewOAuth(m.com, provider, model, modelType) + if err != nil { + return uiutil.ReportError(err) + } + m.dialog.OpenDialog(oAuthDialog) + + return oAuthDialog.Init() +} + func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { var cmds []tea.Cmd diff --git a/internal/ui/styles/styles.go b/internal/ui/styles/styles.go index 878ed83eaf7c0eaaa490dc11546a72f0a9a8a539..bb39cc0a583cbaa834c4c59e139d97cc72a2de76 100644 --- a/internal/ui/styles/styles.go +++ b/internal/ui/styles/styles.go @@ -173,12 +173,14 @@ type Styles struct { FgSubtle color.Color Border color.Color BorderColor color.Color // Border focus color + Error color.Color Warning color.Color Info color.Color White color.Color BlueLight color.Color Blue color.Color BlueDark color.Color + GreenLight color.Color Green color.Color GreenDark color.Color Red color.Color @@ -459,6 +461,7 @@ func DefaultStyles() Styles { borderFocus = charmtone.Charple // Status + error = charmtone.Sriracha warning = charmtone.Zest info = charmtone.Malibu @@ -473,8 +476,9 @@ func DefaultStyles() Styles { yellow = charmtone.Mustard // citron = charmtone.Citron - green = charmtone.Julep - greenDark = charmtone.Guac + greenLight = charmtone.Bok + green = charmtone.Julep + greenDark = charmtone.Guac // greenLight = charmtone.Bok red = charmtone.Coral @@ -505,12 +509,14 @@ func DefaultStyles() Styles { s.FgSubtle = fgSubtle s.Border = border s.BorderColor = borderFocus + s.Error = error s.Warning = warning s.Info = info s.White = white s.BlueLight = blueLight s.Blue = blue s.BlueDark = blueDark + s.GreenLight = greenLight s.Green = green s.GreenDark = greenDark s.Red = red