feat: implement hyper oauth flow in the new ui codebase

Andrey Nering created

Change summary

internal/ui/dialog/actions.go |  23 ++
internal/ui/dialog/oauth.go   | 423 +++++++++++++++++++++++++++++++++++++
internal/ui/model/ui.go       |  30 ++
internal/ui/styles/styles.go  |  10 
4 files changed, 482 insertions(+), 4 deletions(-)

Detailed changes

internal/ui/dialog/actions.go 🔗

@@ -5,6 +5,7 @@ import (
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/commands"
 	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/charmbracelet/crush/internal/permission"
 	"github.com/charmbracelet/crush/internal/session"
 )
@@ -73,6 +74,28 @@ type (
 	}
 )
 
+// Messages for OAuth2 device flow dialog.
+type (
+	// ActionInitiateOAuth is sent when the device auth is initiated
+	// successfully.
+	ActionInitiateOAuth struct {
+		DeviceCode      string
+		UserCode        string
+		ExpiresIn       int
+		VerificationURL string
+	}
+
+	// ActionCompleteOAuth is sent when the device flow completes successfully.
+	ActionCompleteOAuth struct {
+		Token *oauth.Token
+	}
+
+	// ActionOAuthErrored is sent when the device flow encounters an error.
+	ActionOAuthErrored struct {
+		Error error
+	}
+)
+
 // ActionCmd represents an action that carries a [tea.Cmd] to be passed to the
 // Bubble Tea program loop.
 type ActionCmd struct {

internal/ui/dialog/oauth.go 🔗

@@ -0,0 +1,423 @@
+package dialog
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"charm.land/bubbles/v2/help"
+	"charm.land/bubbles/v2/key"
+	"charm.land/bubbles/v2/spinner"
+	tea "charm.land/bubbletea/v2"
+	"charm.land/lipgloss/v2"
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/oauth"
+	"github.com/charmbracelet/crush/internal/oauth/hyper"
+	"github.com/charmbracelet/crush/internal/ui/common"
+	"github.com/charmbracelet/crush/internal/uiutil"
+	uv "github.com/charmbracelet/ultraviolet"
+	"github.com/pkg/browser"
+)
+
+// OAuthState represents the current state of the device flow.
+type OAuthState int
+
+const (
+	OAuthStateInitializing OAuthState = iota
+	OAuthStateDisplay
+	OAuthStateSuccess
+	OAuthStateError
+)
+
+// OAuthID is the identifier for the model selection dialog.
+const OAuthID = "oauth"
+
+// OAuth handles the OAuth flow authentication.
+type OAuth struct {
+	com *common.Common
+
+	provider  catwalk.Provider
+	model     config.SelectedModel
+	modelType config.SelectedModelType
+
+	State OAuthState
+
+	spinner spinner.Model
+	help    help.Model
+	keyMap  struct {
+		Copy   key.Binding
+		Submit key.Binding
+		Close  key.Binding
+	}
+
+	width           int
+	deviceCode      string
+	userCode        string
+	verificationURL string
+	expiresIn       int
+	token           *oauth.Token
+	cancelFunc      context.CancelFunc
+}
+
+var _ Dialog = (*OAuth)(nil)
+
+// NewOAuth creates a new device flow component.
+func NewOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*OAuth, error) {
+	t := com.Styles
+
+	m := OAuth{}
+	m.com = com
+	m.provider = provider
+	m.model = model
+	m.modelType = modelType
+	m.width = 60
+	m.State = OAuthStateInitializing
+
+	m.spinner = spinner.New(
+		spinner.WithSpinner(spinner.Dot),
+		spinner.WithStyle(t.Base.Foreground(t.GreenLight)),
+	)
+
+	m.help = help.New()
+	m.help.Styles = t.DialogHelpStyles()
+
+	m.keyMap.Copy = key.NewBinding(
+		key.WithKeys("c"),
+		key.WithHelp("c", "copy code"),
+	)
+	m.keyMap.Submit = key.NewBinding(
+		key.WithKeys("enter", "ctrl+y"),
+		key.WithHelp("enter", "copy & open"),
+	)
+	m.keyMap.Close = CloseKey
+
+	return &m, nil
+}
+
+// ID implements Dialog.
+func (m *OAuth) ID() string {
+	return OAuthID
+}
+
+// Init implements Dialog.
+func (m *OAuth) Init() tea.Cmd {
+	return tea.Batch(m.spinner.Tick, m.initiateDeviceAuth)
+}
+
+// HandleMsg handles messages and state transitions.
+func (m *OAuth) HandleMsg(msg tea.Msg) Action {
+	switch msg := msg.(type) {
+	case spinner.TickMsg:
+		switch m.State {
+		case OAuthStateInitializing, OAuthStateDisplay:
+			var cmd tea.Cmd
+			m.spinner, cmd = m.spinner.Update(msg)
+			if cmd != nil {
+				return ActionCmd{cmd}
+			}
+		}
+
+	case tea.KeyPressMsg:
+		switch {
+		case key.Matches(msg, m.keyMap.Copy):
+			cmd := m.copyCode()
+			return ActionCmd{cmd}
+
+		case key.Matches(msg, m.keyMap.Submit):
+			switch m.State {
+			case OAuthStateSuccess:
+				return m.saveKeyAndContinue()
+
+			default:
+				cmd := m.copyCodeAndOpenURL()
+				return ActionCmd{cmd}
+			}
+
+		case key.Matches(msg, m.keyMap.Close):
+			switch m.State {
+			case OAuthStateSuccess:
+				return m.saveKeyAndContinue()
+
+			default:
+				return ActionClose{}
+			}
+		}
+
+	case ActionInitiateOAuth:
+		m.deviceCode = msg.DeviceCode
+		m.userCode = msg.UserCode
+		m.expiresIn = msg.ExpiresIn
+		m.verificationURL = msg.VerificationURL
+		m.State = OAuthStateDisplay
+		return ActionCmd{m.startPolling(msg.DeviceCode)}
+
+	case ActionCompleteOAuth:
+		m.State = OAuthStateSuccess
+		m.token = msg.Token
+		return ActionCmd{m.stopPolling}
+
+	case ActionOAuthErrored:
+		m.State = OAuthStateError
+		cmd := tea.Batch(m.stopPolling, uiutil.ReportError(msg.Error))
+		return ActionCmd{cmd}
+	}
+	return nil
+}
+
+// View renders the device flow dialog.
+func (m *OAuth) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
+	var (
+		t           = m.com.Styles
+		dialogStyle = t.Dialog.View.Width(m.width)
+		view        = dialogStyle.Render(m.dialogContent())
+	)
+	DrawCenterCursor(scr, area, view, nil)
+	return nil
+}
+
+func (m *OAuth) dialogContent() string {
+	var (
+		t         = m.com.Styles
+		helpStyle = t.Dialog.HelpView
+	)
+
+	switch m.State {
+	case OAuthStateInitializing:
+		return m.innerDialogContent()
+
+	default:
+		elements := []string{
+			m.headerContent(),
+			m.innerDialogContent(),
+			helpStyle.Render(m.help.View(m)),
+		}
+		return strings.Join(elements, "\n")
+	}
+}
+
+func (m *OAuth) headerContent() string {
+	var (
+		t            = m.com.Styles
+		titleStyle   = t.Dialog.Title
+		dialogStyle  = t.Dialog.View.Width(m.width)
+		headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
+	)
+	return common.DialogTitle(t, titleStyle.Render("Authenticate with Hyper"), m.width-headerOffset)
+}
+
+func (m *OAuth) innerDialogContent() string {
+	var (
+		t            = m.com.Styles
+		whiteStyle   = lipgloss.NewStyle().Foreground(t.White)
+		primaryStyle = lipgloss.NewStyle().Foreground(t.Primary)
+		greenStyle   = lipgloss.NewStyle().Foreground(t.GreenLight)
+		linkStyle    = lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true)
+		errorStyle   = lipgloss.NewStyle().Foreground(t.Error)
+		mutedStyle   = lipgloss.NewStyle().Foreground(t.FgMuted)
+	)
+
+	switch m.State {
+	case OAuthStateInitializing:
+		return lipgloss.NewStyle().
+			Margin(1, 1).
+			Width(m.width - 2).
+			Align(lipgloss.Center).
+			Render(
+				greenStyle.Render(m.spinner.View()) +
+					mutedStyle.Render("Initializing..."),
+			)
+
+	case OAuthStateDisplay:
+		instructions := lipgloss.NewStyle().
+			Margin(1).
+			Width(m.width - 2).
+			Render(
+				whiteStyle.Render("Press ") +
+					primaryStyle.Render("enter") +
+					whiteStyle.Render(" to copy the code below and open the browser."),
+			)
+
+		codeBox := lipgloss.NewStyle().
+			Width(m.width-2).
+			Height(7).
+			Align(lipgloss.Center, lipgloss.Center).
+			Background(t.BgBaseLighter).
+			Margin(1).
+			Render(
+				lipgloss.NewStyle().
+					Bold(true).
+					Foreground(t.White).
+					Render(m.userCode),
+			)
+
+		link := linkStyle.Hyperlink(m.verificationURL, "id=oauth-verify").Render(m.verificationURL)
+		url := mutedStyle.
+			Margin(0, 1).
+			Width(m.width - 2).
+			Render("Browser not opening? Refer to\n" + link)
+
+		waiting := greenStyle.
+			Width(m.width - 2).
+			Margin(1).
+			Render(m.spinner.View() + "Verifying...")
+
+		return lipgloss.JoinVertical(
+			lipgloss.Left,
+			instructions,
+			codeBox,
+			url,
+			waiting,
+		)
+
+	case OAuthStateSuccess:
+		return greenStyle.
+			Margin(1).
+			Width(m.width - 2).
+			Align(lipgloss.Center).
+			Render("Authentication successful!")
+
+	case OAuthStateError:
+		return lipgloss.NewStyle().
+			Margin(1).
+			Width(m.width - 2).
+			Render(errorStyle.Render("Authentication failed."))
+
+	default:
+		return ""
+	}
+}
+
+// FullHelp returns the full help view.
+func (m *OAuth) FullHelp() [][]key.Binding {
+	return [][]key.Binding{m.ShortHelp()}
+}
+
+// ShortHelp returns the full help view.
+func (m *OAuth) ShortHelp() []key.Binding {
+	switch m.State {
+	case OAuthStateError:
+		return []key.Binding{m.keyMap.Close}
+
+	case OAuthStateSuccess:
+		return []key.Binding{
+			key.NewBinding(
+				key.WithKeys("finish", "ctrl+y", "esc"),
+				key.WithHelp("enter", "finish"),
+			),
+		}
+
+	default:
+		return []key.Binding{
+			m.keyMap.Copy,
+			m.keyMap.Submit,
+			m.keyMap.Close,
+		}
+	}
+}
+
+func (d *OAuth) copyCode() tea.Cmd {
+	if d.State != OAuthStateDisplay {
+		return nil
+	}
+	return tea.Sequence(
+		tea.SetClipboard(d.userCode),
+		uiutil.ReportInfo("Code copied to clipboard"),
+	)
+}
+
+func (d *OAuth) copyCodeAndOpenURL() tea.Cmd {
+	if d.State != OAuthStateDisplay {
+		return nil
+	}
+	return tea.Sequence(
+		tea.SetClipboard(d.userCode),
+		func() tea.Msg {
+			if err := browser.OpenURL(d.verificationURL); err != nil {
+				return ActionOAuthErrored{fmt.Errorf("failed to open browser: %w", err)}
+			}
+			return nil
+		},
+		uiutil.ReportInfo("Code copied and URL opened"),
+	)
+}
+
+func (m *OAuth) saveKeyAndContinue() Action {
+	cfg := m.com.Config()
+
+	err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token)
+	if err != nil {
+		return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))}
+	}
+
+	return ActionSelectModel{
+		Provider:  m.provider,
+		Model:     m.model,
+		ModelType: m.modelType,
+	}
+}
+
+func (m *OAuth) initiateDeviceAuth() tea.Msg {
+	minimumWait := 750 * time.Millisecond
+	startTime := time.Now()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	authResp, err := hyper.InitiateDeviceAuth(ctx)
+
+	ellapsed := time.Since(startTime)
+	if ellapsed < minimumWait {
+		time.Sleep(minimumWait - ellapsed)
+	}
+
+	if err != nil {
+		return ActionOAuthErrored{fmt.Errorf("failed to initiate device auth: %w", err)}
+	}
+
+	return ActionInitiateOAuth{
+		DeviceCode:      authResp.DeviceCode,
+		UserCode:        authResp.UserCode,
+		ExpiresIn:       authResp.ExpiresIn,
+		VerificationURL: authResp.VerificationURL,
+	}
+}
+
+// startPolling starts polling for the device token.
+func (m *OAuth) startPolling(deviceCode string) tea.Cmd {
+	return func() tea.Msg {
+		ctx, cancel := context.WithCancel(context.Background())
+		m.cancelFunc = cancel
+
+		refreshToken, err := hyper.PollForToken(ctx, deviceCode, m.expiresIn)
+		if err != nil {
+			if ctx.Err() != nil {
+				return nil
+			}
+			return ActionOAuthErrored{err}
+		}
+
+		token, err := hyper.ExchangeToken(ctx, refreshToken)
+		if err != nil {
+			return ActionOAuthErrored{fmt.Errorf("token exchange failed: %w", err)}
+		}
+
+		introspect, err := hyper.IntrospectToken(ctx, token.AccessToken)
+		if err != nil {
+			return ActionOAuthErrored{fmt.Errorf("token introspection failed: %w", err)}
+		}
+		if !introspect.Active {
+			return ActionOAuthErrored{fmt.Errorf("access token is not active")}
+		}
+
+		return ActionCompleteOAuth{token}
+	}
+}
+
+func (m *OAuth) stopPolling() tea.Msg {
+	if m.cancelFunc != nil {
+		m.cancelFunc()
+	}
+	return nil
+}

internal/ui/model/ui.go 🔗

@@ -934,7 +934,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 		_, isProviderConfigured := cfg.Providers.Get(msg.Model.Provider)
 		if !isProviderConfigured {
 			m.dialog.CloseDialog(dialog.ModelsID)
-			if cmd := m.openAPIKeyInputDialog(msg.Provider, msg.Model, msg.ModelType); cmd != nil {
+			if cmd := m.openAuthenticationDialog(msg.Provider, msg.Model, msg.ModelType); cmd != nil {
 				cmds = append(cmds, cmd)
 			}
 			break
@@ -950,6 +950,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 		modelMsg := fmt.Sprintf("%s model changed to %s", msg.ModelType, msg.Model.Model)
 		cmds = append(cmds, uiutil.ReportInfo(modelMsg))
 		m.dialog.CloseDialog(dialog.APIKeyInputID)
+		m.dialog.CloseDialog(dialog.OAuthID)
 		m.dialog.CloseDialog(dialog.ModelsID)
 		// TODO CHANGE
 	case dialog.ActionPermissionResponse:
@@ -1016,7 +1017,17 @@ func substituteArgs(content string, args map[string]string) string {
 	return content
 }
 
-// openAPIKeyInputDialog opens the API key input dialog.
+func (m *UI) openAuthenticationDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+	switch provider.ID {
+	case "hyper":
+		return m.openOAuthDialog(provider, model, modelType)
+	case catwalk.InferenceProviderCopilot:
+		return m.openOAuthDialog(provider, model, modelType)
+	default:
+		return m.openAPIKeyInputDialog(provider, model, modelType)
+	}
+}
+
 func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
 	if m.dialog.ContainsDialog(dialog.APIKeyInputID) {
 		m.dialog.BringToFront(dialog.APIKeyInputID)
@@ -1031,6 +1042,21 @@ func (m *UI) openAPIKeyInputDialog(provider catwalk.Provider, model config.Selec
 	return nil
 }
 
+func (m *UI) openOAuthDialog(provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) tea.Cmd {
+	if m.dialog.ContainsDialog(dialog.OAuthID) {
+		m.dialog.BringToFront(dialog.OAuthID)
+		return nil
+	}
+
+	oAuthDialog, err := dialog.NewOAuth(m.com, provider, model, modelType)
+	if err != nil {
+		return uiutil.ReportError(err)
+	}
+	m.dialog.OpenDialog(oAuthDialog)
+
+	return oAuthDialog.Init()
+}
+
 func (m *UI) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
 	var cmds []tea.Cmd
 

internal/ui/styles/styles.go 🔗

@@ -173,12 +173,14 @@ type Styles struct {
 	FgSubtle      color.Color
 	Border        color.Color
 	BorderColor   color.Color // Border focus color
+	Error         color.Color
 	Warning       color.Color
 	Info          color.Color
 	White         color.Color
 	BlueLight     color.Color
 	Blue          color.Color
 	BlueDark      color.Color
+	GreenLight    color.Color
 	Green         color.Color
 	GreenDark     color.Color
 	Red           color.Color
@@ -459,6 +461,7 @@ func DefaultStyles() Styles {
 		borderFocus = charmtone.Charple
 
 		// Status
+		error   = charmtone.Sriracha
 		warning = charmtone.Zest
 		info    = charmtone.Malibu
 
@@ -473,8 +476,9 @@ func DefaultStyles() Styles {
 		yellow = charmtone.Mustard
 		// citron = charmtone.Citron
 
-		green     = charmtone.Julep
-		greenDark = charmtone.Guac
+		greenLight = charmtone.Bok
+		green      = charmtone.Julep
+		greenDark  = charmtone.Guac
 		// greenLight = charmtone.Bok
 
 		red     = charmtone.Coral
@@ -505,12 +509,14 @@ func DefaultStyles() Styles {
 	s.FgSubtle = fgSubtle
 	s.Border = border
 	s.BorderColor = borderFocus
+	s.Error = error
 	s.Warning = warning
 	s.Info = info
 	s.White = white
 	s.BlueLight = blueLight
 	s.Blue = blue
 	s.BlueDark = blueDark
+	s.GreenLight = greenLight
 	s.Green = green
 	s.GreenDark = greenDark
 	s.Red = red