oauth.go

  1package dialog
  2
  3import (
  4	"context"
  5	"fmt"
  6	"strings"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/oauth"
 16	"github.com/charmbracelet/crush/internal/ui/common"
 17	"github.com/charmbracelet/crush/internal/uiutil"
 18	uv "github.com/charmbracelet/ultraviolet"
 19	"github.com/pkg/browser"
 20)
 21
 22type OAuthProvider interface {
 23	name() string
 24	initiateAuth() tea.Msg
 25	startPolling(deviceCode string, expiresIn int) tea.Cmd
 26	stopPolling() tea.Msg
 27}
 28
 29// OAuthState represents the current state of the device flow.
 30type OAuthState int
 31
 32const (
 33	OAuthStateInitializing OAuthState = iota
 34	OAuthStateDisplay
 35	OAuthStateSuccess
 36	OAuthStateError
 37)
 38
 39// OAuthID is the identifier for the model selection dialog.
 40const OAuthID = "oauth"
 41
// OAuth handles the OAuth flow authentication.
//
// It is a dialog that asks its OAuthProvider to initiate a device-style flow,
// displays the returned user code and verification URL, waits for the
// provider to report completion, and then saves the received token as the
// provider's API key before handing the original model selection back via
// ActionSelectModel.
type OAuth struct {
	// com provides the shared styles and configuration.
	com *common.Common

	// The model selection that opened this dialog (forwarded in
	// ActionSelectModel after the token is saved) and the provider-specific
	// flow implementation.
	provider      catwalk.Provider
	model         config.SelectedModel
	modelType     config.SelectedModelType
	oAuthProvider OAuthProvider

	// State is the current step of the flow; it selects what is rendered and
	// how key presses are interpreted.
	State OAuthState

	// UI widgets and key bindings.
	spinner spinner.Model
	help    help.Model
	keyMap  struct {
		Copy   key.Binding
		Submit key.Binding
		Close  key.Binding
	}

	// width is the dialog width used for layout (set to 60 in newOAuth).
	// The device-flow fields are copied from ActionInitiateOAuth, and token
	// from ActionCompleteOAuth.
	// NOTE(review): deviceCode, expiresIn, interval and cancelFunc are never
	// read in this file, and cancelFunc is never assigned here — confirm they
	// are still needed.
	width           int
	deviceCode      string
	userCode        string
	verificationURL string
	expiresIn       int
	interval        int
	token           *oauth.Token
	cancelFunc      context.CancelFunc
}

// Compile-time check that *OAuth satisfies the Dialog interface.
var _ Dialog = (*OAuth)(nil)
 72
 73// newOAuth creates a new device flow component.
 74func newOAuth(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType, oAuthProvider OAuthProvider) (*OAuth, tea.Cmd) {
 75	t := com.Styles
 76
 77	m := OAuth{}
 78	m.com = com
 79	m.provider = provider
 80	m.model = model
 81	m.modelType = modelType
 82	m.oAuthProvider = oAuthProvider
 83	m.width = 60
 84	m.State = OAuthStateInitializing
 85
 86	m.spinner = spinner.New(
 87		spinner.WithSpinner(spinner.Dot),
 88		spinner.WithStyle(t.Base.Foreground(t.GreenLight)),
 89	)
 90
 91	m.help = help.New()
 92	m.help.Styles = t.DialogHelpStyles()
 93
 94	m.keyMap.Copy = key.NewBinding(
 95		key.WithKeys("c"),
 96		key.WithHelp("c", "copy code"),
 97	)
 98	m.keyMap.Submit = key.NewBinding(
 99		key.WithKeys("enter", "ctrl+y"),
100		key.WithHelp("enter", "copy & open"),
101	)
102	m.keyMap.Close = CloseKey
103
104	return &m, tea.Batch(m.spinner.Tick, m.oAuthProvider.initiateAuth)
105}
106
107// ID implements Dialog.
108func (m *OAuth) ID() string {
109	return OAuthID
110}
111
112// HandleMsg handles messages and state transitions.
113func (m *OAuth) HandleMsg(msg tea.Msg) Action {
114	switch msg := msg.(type) {
115	case spinner.TickMsg:
116		switch m.State {
117		case OAuthStateInitializing, OAuthStateDisplay:
118			var cmd tea.Cmd
119			m.spinner, cmd = m.spinner.Update(msg)
120			if cmd != nil {
121				return ActionCmd{cmd}
122			}
123		}
124
125	case tea.KeyPressMsg:
126		switch {
127		case key.Matches(msg, m.keyMap.Copy):
128			cmd := m.copyCode()
129			return ActionCmd{cmd}
130
131		case key.Matches(msg, m.keyMap.Submit):
132			switch m.State {
133			case OAuthStateSuccess:
134				return m.saveKeyAndContinue()
135
136			default:
137				cmd := m.copyCodeAndOpenURL()
138				return ActionCmd{cmd}
139			}
140
141		case key.Matches(msg, m.keyMap.Close):
142			switch m.State {
143			case OAuthStateSuccess:
144				return m.saveKeyAndContinue()
145
146			default:
147				return ActionClose{}
148			}
149		}
150
151	case ActionInitiateOAuth:
152		m.deviceCode = msg.DeviceCode
153		m.userCode = msg.UserCode
154		m.expiresIn = msg.ExpiresIn
155		m.verificationURL = msg.VerificationURL
156		m.interval = msg.Interval
157		m.State = OAuthStateDisplay
158		return ActionCmd{m.oAuthProvider.startPolling(msg.DeviceCode, msg.ExpiresIn)}
159
160	case ActionCompleteOAuth:
161		m.State = OAuthStateSuccess
162		m.token = msg.Token
163		return ActionCmd{m.oAuthProvider.stopPolling}
164
165	case ActionOAuthErrored:
166		m.State = OAuthStateError
167		cmd := tea.Batch(m.oAuthProvider.stopPolling, uiutil.ReportError(msg.Error))
168		return ActionCmd{cmd}
169	}
170	return nil
171}
172
173// View renders the device flow dialog.
174func (m *OAuth) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
175	var (
176		t           = m.com.Styles
177		dialogStyle = t.Dialog.View.Width(m.width)
178		view        = dialogStyle.Render(m.dialogContent())
179	)
180	DrawCenterCursor(scr, area, view, nil)
181	return nil
182}
183
184func (m *OAuth) dialogContent() string {
185	var (
186		t         = m.com.Styles
187		helpStyle = t.Dialog.HelpView
188	)
189
190	switch m.State {
191	case OAuthStateInitializing:
192		return m.innerDialogContent()
193
194	default:
195		elements := []string{
196			m.headerContent(),
197			m.innerDialogContent(),
198			helpStyle.Render(m.help.View(m)),
199		}
200		return strings.Join(elements, "\n")
201	}
202}
203
204func (m *OAuth) headerContent() string {
205	var (
206		t            = m.com.Styles
207		titleStyle   = t.Dialog.Title
208		dialogStyle  = t.Dialog.View.Width(m.width)
209		headerOffset = titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
210	)
211	return common.DialogTitle(t, titleStyle.Render("Authenticate with "+m.oAuthProvider.name()), m.width-headerOffset)
212}
213
214func (m *OAuth) innerDialogContent() string {
215	var (
216		t            = m.com.Styles
217		whiteStyle   = lipgloss.NewStyle().Foreground(t.White)
218		primaryStyle = lipgloss.NewStyle().Foreground(t.Primary)
219		greenStyle   = lipgloss.NewStyle().Foreground(t.GreenLight)
220		linkStyle    = lipgloss.NewStyle().Foreground(t.GreenDark).Underline(true)
221		errorStyle   = lipgloss.NewStyle().Foreground(t.Error)
222		mutedStyle   = lipgloss.NewStyle().Foreground(t.FgMuted)
223	)
224
225	switch m.State {
226	case OAuthStateInitializing:
227		return lipgloss.NewStyle().
228			Margin(1, 1).
229			Width(m.width - 2).
230			Align(lipgloss.Center).
231			Render(
232				greenStyle.Render(m.spinner.View()) +
233					mutedStyle.Render("Initializing..."),
234			)
235
236	case OAuthStateDisplay:
237		instructions := lipgloss.NewStyle().
238			Margin(0, 1).
239			Width(m.width - 2).
240			Render(
241				whiteStyle.Render("Press ") +
242					primaryStyle.Render("enter") +
243					whiteStyle.Render(" to copy the code below and open the browser."),
244			)
245
246		codeBox := lipgloss.NewStyle().
247			Width(m.width-2).
248			Height(7).
249			Align(lipgloss.Center, lipgloss.Center).
250			Background(t.BgBaseLighter).
251			Margin(0, 1).
252			Render(
253				lipgloss.NewStyle().
254					Bold(true).
255					Foreground(t.White).
256					Render(m.userCode),
257			)
258
259		link := linkStyle.Hyperlink(m.verificationURL, "id=oauth-verify").Render(m.verificationURL)
260		url := mutedStyle.
261			Margin(0, 1).
262			Width(m.width - 2).
263			Render("Browser not opening? Refer to\n" + link)
264
265		waiting := lipgloss.NewStyle().
266			Margin(0, 1).
267			Width(m.width - 2).
268			Render(
269				greenStyle.Render(m.spinner.View()) + mutedStyle.Render("Verifying..."),
270			)
271
272		return lipgloss.JoinVertical(
273			lipgloss.Left,
274			"",
275			instructions,
276			"",
277			codeBox,
278			"",
279			url,
280			"",
281			waiting,
282			"",
283		)
284
285	case OAuthStateSuccess:
286		return greenStyle.
287			Margin(1).
288			Width(m.width - 2).
289			Render("Authentication successful!")
290
291	case OAuthStateError:
292		return lipgloss.NewStyle().
293			Margin(1).
294			Width(m.width - 2).
295			Render(errorStyle.Render("Authentication failed."))
296
297	default:
298		return ""
299	}
300}
301
302// FullHelp returns the full help view.
303func (m *OAuth) FullHelp() [][]key.Binding {
304	return [][]key.Binding{m.ShortHelp()}
305}
306
307// ShortHelp returns the full help view.
308func (m *OAuth) ShortHelp() []key.Binding {
309	switch m.State {
310	case OAuthStateError:
311		return []key.Binding{m.keyMap.Close}
312
313	case OAuthStateSuccess:
314		return []key.Binding{
315			key.NewBinding(
316				key.WithKeys("finish", "ctrl+y", "esc"),
317				key.WithHelp("enter", "finish"),
318			),
319		}
320
321	default:
322		return []key.Binding{
323			m.keyMap.Copy,
324			m.keyMap.Submit,
325			m.keyMap.Close,
326		}
327	}
328}
329
330func (d *OAuth) copyCode() tea.Cmd {
331	if d.State != OAuthStateDisplay {
332		return nil
333	}
334	return tea.Sequence(
335		tea.SetClipboard(d.userCode),
336		uiutil.ReportInfo("Code copied to clipboard"),
337	)
338}
339
340func (d *OAuth) copyCodeAndOpenURL() tea.Cmd {
341	if d.State != OAuthStateDisplay {
342		return nil
343	}
344	return tea.Sequence(
345		tea.SetClipboard(d.userCode),
346		func() tea.Msg {
347			if err := browser.OpenURL(d.verificationURL); err != nil {
348				return ActionOAuthErrored{fmt.Errorf("failed to open browser: %w", err)}
349			}
350			return nil
351		},
352		uiutil.ReportInfo("Code copied and URL opened"),
353	)
354}
355
356func (m *OAuth) saveKeyAndContinue() Action {
357	cfg := m.com.Config()
358
359	err := cfg.SetProviderAPIKey(string(m.provider.ID), m.token)
360	if err != nil {
361		return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))}
362	}
363
364	return ActionSelectModel{
365		Provider:  m.provider,
366		Model:     m.model,
367		ModelType: m.modelType,
368	}
369}