models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/atotto/clipboard"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/tui/components/core"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 20	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 21	"github.com/charmbracelet/crush/internal/tui/exp/list"
 22	"github.com/charmbracelet/crush/internal/tui/styles"
 23	"github.com/charmbracelet/crush/internal/tui/util"
 24)
 25
 26const (
 27	ModelsDialogID dialogs.DialogID = "models"
 28
 29	defaultWidth = 60
 30)
 31
 32const (
 33	LargeModelType int = iota
 34	SmallModelType
 35
 36	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 37	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 38)
 39
// ModelSelectedMsg is sent when a model is selected. The dialog emits it both
// when the chosen provider is already configured and after credentials for a
// new provider have been saved.
type ModelSelectedMsg struct {
	// Model holds the selected model ID, its provider ID and the model's
	// default reasoning effort and max tokens.
	Model     config.SelectedModel
	// ModelType tells whether the model was chosen for the large or the small
	// slot.
	ModelType config.SelectedModelType
}
 45
// CloseModelDialogMsg is presumably a request to close the model dialog.
// NOTE(review): nothing in this file sends or handles it; the dialog closes
// itself with dialogs.CloseDialogMsg instead. Confirm it is still used.
type CloseModelDialogMsg struct{}
 48
// ModelDialog is the interface of the model selection dialog. It adds nothing
// to dialogs.DialogModel; NewModelDialogCmp returns its implementation.
type ModelDialog interface {
	dialogs.DialogModel
}
 53
// ModelOption pairs a model with the provider that offers it. It is the value
// the model list reports as the current selection.
type ModelOption struct {
	// Provider is the provider offering Model.
	Provider catwalk.Provider
	// Model is the model itself.
	Model    catwalk.Model
}
 58
// modelDialogCmp is the ModelDialog implementation. Besides the model list it
// hosts several authentication sub-views (API key input, Hyper device flow,
// Claude auth method chooser, Claude OAuth2); the boolean fields below select
// which one is currently active.
type modelDialogCmp struct {
	// width is the dialog width; NewModelDialogCmp sets it to defaultWidth.
	width   int
	// wWidth and wHeight are the window dimensions from the most recent
	// tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	// modelList is the list of selectable models.
	modelList *ModelListComponent
	// keyMap holds the dialog's key bindings and the flags that steer which
	// help entries are shown.
	keyMap    KeyMap
	// help renders the help line at the bottom of the dialog.
	help      help.Model

	// API key state
	// needsAPIKey is true while the API key input is the active view.
	needsAPIKey       bool
	apiKeyInput       *APIKeyInput
	// selectedModel is the model awaiting credentials; nil when none is
	// pending.
	selectedModel     *ModelOption
	// selectedModelType is the slot selectedModel was chosen for.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is true once the connection test with the entered key
	// succeeded.
	isAPIKeyValid     bool
	// apiKeyValue is the key text captured from the input on submission.
	apiKeyValue       string

	// Hyper device flow state
	// hyperDeviceFlow is created when the Hyper provider is selected while
	// not yet configured; it stays nil until then.
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Claude state
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
 86
 87func NewModelDialogCmp() ModelDialog {
 88	keyMap := DefaultKeyMap()
 89
 90	listKeyMap := list.DefaultKeyMap()
 91	listKeyMap.Down.SetEnabled(false)
 92	listKeyMap.Up.SetEnabled(false)
 93	listKeyMap.DownOneItem = keyMap.Next
 94	listKeyMap.UpOneItem = keyMap.Previous
 95
 96	t := styles.CurrentTheme()
 97	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 98	apiKeyInput := NewAPIKeyInput()
 99	apiKeyInput.SetShowTitle(false)
100	help := help.New()
101	help.Styles = t.S().Help
102
103	return &modelDialogCmp{
104		modelList:   modelList,
105		apiKeyInput: apiKeyInput,
106		width:       defaultWidth,
107		keyMap:      DefaultKeyMap(),
108		help:        help,
109
110		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
111		claudeOAuth2:            claude.NewOAuth2(),
112	}
113}
114
115func (m *modelDialogCmp) Init() tea.Cmd {
116	return tea.Batch(
117		m.modelList.Init(),
118		m.apiKeyInput.Init(),
119		m.claudeAuthMethodChooser.Init(),
120		m.claudeOAuth2.Init(),
121	)
122}
123
124func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
125	switch msg := msg.(type) {
126	case tea.WindowSizeMsg:
127		m.wWidth = msg.Width
128		m.wHeight = msg.Height
129		m.apiKeyInput.SetWidth(m.width - 2)
130		m.help.SetWidth(m.width - 2)
131		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
132		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
133	case APIKeyStateChangeMsg:
134		u, cmd := m.apiKeyInput.Update(msg)
135		m.apiKeyInput = u.(*APIKeyInput)
136		return m, cmd
137	case hyper.DeviceFlowCompletedMsg:
138		return m, m.saveOauthTokenAndContinue(msg.Token, true)
139	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
140		if m.hyperDeviceFlow != nil {
141			u, cmd := m.hyperDeviceFlow.Update(msg)
142			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
143			return m, cmd
144		}
145		return m, nil
146	case claude.ValidationCompletedMsg:
147		var cmds []tea.Cmd
148		u, cmd := m.claudeOAuth2.Update(msg)
149		m.claudeOAuth2 = u.(*claude.OAuth2)
150		cmds = append(cmds, cmd)
151
152		if msg.State == claude.OAuthValidationStateValid {
153			cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
154			m.keyMap.isClaudeOAuthHelpComplete = true
155		}
156
157		return m, tea.Batch(cmds...)
158	case claude.AuthenticationCompleteMsg:
159		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
160	case tea.KeyPressMsg:
161		switch {
162		// Handle Hyper device flow keys
163		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
164			if m.hyperDeviceFlow != nil {
165				return m, m.hyperDeviceFlow.CopyCode()
166			}
167		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
168			return m, tea.Sequence(
169				tea.SetClipboard(m.claudeOAuth2.URL),
170				func() tea.Msg {
171					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
172					return nil
173				},
174				util.ReportInfo("URL copied to clipboard"),
175			)
176		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
177			m.claudeAuthMethodChooser.ToggleChoice()
178			return m, nil
179		case key.Matches(msg, m.keyMap.Select):
180			// If showing device flow, enter copies code and opens URL
181			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
182				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
183			}
184			selectedItem := m.modelList.SelectedModel()
185
186			modelType := config.SelectedModelTypeLarge
187			if m.modelList.GetModelType() == SmallModelType {
188				modelType = config.SelectedModelTypeSmall
189			}
190
191			askForApiKey := func() {
192				m.keyMap.isClaudeAuthChoiceHelp = false
193				m.keyMap.isClaudeOAuthHelp = false
194				m.keyMap.isAPIKeyHelp = true
195				m.showHyperDeviceFlow = false
196				m.showClaudeAuthMethodChooser = false
197				m.needsAPIKey = true
198				m.selectedModel = selectedItem
199				m.selectedModelType = modelType
200				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
201			}
202
203			if m.showClaudeAuthMethodChooser {
204				switch m.claudeAuthMethodChooser.State {
205				case claude.AuthMethodAPIKey:
206					askForApiKey()
207				case claude.AuthMethodOAuth2:
208					m.selectedModel = selectedItem
209					m.selectedModelType = modelType
210					m.showClaudeAuthMethodChooser = false
211					m.showClaudeOAuth2 = true
212					m.keyMap.isClaudeAuthChoiceHelp = false
213					m.keyMap.isClaudeOAuthHelp = true
214				}
215				return m, nil
216			}
217			if m.showClaudeOAuth2 {
218				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
219				m.claudeOAuth2 = m2.(*claude.OAuth2)
220				return m, cmd2
221			}
222			if m.isAPIKeyValid {
223				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
224			}
225			if m.needsAPIKey {
226				// Handle API key submission
227				m.apiKeyValue = m.apiKeyInput.Value()
228				provider, err := m.getProvider(m.selectedModel.Provider.ID)
229				if err != nil || provider == nil {
230					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
231				}
232				providerConfig := config.ProviderConfig{
233					ID:      string(m.selectedModel.Provider.ID),
234					Name:    m.selectedModel.Provider.Name,
235					APIKey:  m.apiKeyValue,
236					Type:    provider.Type,
237					BaseURL: provider.APIEndpoint,
238				}
239				return m, tea.Sequence(
240					util.CmdHandler(APIKeyStateChangeMsg{
241						State: APIKeyInputStateVerifying,
242					}),
243					func() tea.Msg {
244						start := time.Now()
245						err := providerConfig.TestConnection(config.Get().Resolver())
246						// intentionally wait for at least 750ms to make sure the user sees the spinner
247						elapsed := time.Since(start)
248						if elapsed < 750*time.Millisecond {
249							time.Sleep(750*time.Millisecond - elapsed)
250						}
251						if err == nil {
252							m.isAPIKeyValid = true
253							return APIKeyStateChangeMsg{
254								State: APIKeyInputStateVerified,
255							}
256						}
257						return APIKeyStateChangeMsg{
258							State: APIKeyInputStateError,
259						}
260					},
261				)
262			}
263
264			// Check if provider is configured
265			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
266				return m, tea.Sequence(
267					util.CmdHandler(dialogs.CloseDialogMsg{}),
268					util.CmdHandler(ModelSelectedMsg{
269						Model: config.SelectedModel{
270							Model:           selectedItem.Model.ID,
271							Provider:        string(selectedItem.Provider.ID),
272							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
273							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
274						},
275						ModelType: modelType,
276					}),
277				)
278			}
279			switch selectedItem.Provider.ID {
280			case catwalk.InferenceProviderAnthropic:
281				m.showClaudeAuthMethodChooser = true
282				m.keyMap.isClaudeAuthChoiceHelp = true
283				return m, nil
284			case hyperp.Name:
285				m.showHyperDeviceFlow = true
286				m.selectedModel = selectedItem
287				m.selectedModelType = modelType
288				m.hyperDeviceFlow = hyper.NewDeviceFlow()
289				m.hyperDeviceFlow.SetWidth(m.width - 2)
290				return m, m.hyperDeviceFlow.Init()
291			}
292			// For other providers, show API key input
293			askForApiKey()
294			return m, nil
295		case key.Matches(msg, m.keyMap.Tab):
296			switch {
297			case m.showClaudeAuthMethodChooser:
298				m.claudeAuthMethodChooser.ToggleChoice()
299				return m, nil
300			case m.needsAPIKey:
301				u, cmd := m.apiKeyInput.Update(msg)
302				m.apiKeyInput = u.(*APIKeyInput)
303				return m, cmd
304			case m.modelList.GetModelType() == LargeModelType:
305				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
306				return m, m.modelList.SetModelType(SmallModelType)
307			default:
308				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
309				return m, m.modelList.SetModelType(LargeModelType)
310			}
311		case key.Matches(msg, m.keyMap.Close):
312			if m.showHyperDeviceFlow {
313				// Cancel device flow and go back to model selection
314				if m.hyperDeviceFlow != nil {
315					m.hyperDeviceFlow.Cancel()
316				}
317				m.showHyperDeviceFlow = false
318				m.selectedModel = nil
319			}
320			if m.showClaudeAuthMethodChooser {
321				m.claudeAuthMethodChooser.SetDefaults()
322				m.showClaudeAuthMethodChooser = false
323				m.keyMap.isClaudeAuthChoiceHelp = false
324				m.keyMap.isClaudeOAuthHelp = false
325				return m, nil
326			}
327			if m.needsAPIKey {
328				if m.isAPIKeyValid {
329					return m, nil
330				}
331				// Go back to model selection
332				m.needsAPIKey = false
333				m.selectedModel = nil
334				m.isAPIKeyValid = false
335				m.apiKeyValue = ""
336				m.apiKeyInput.Reset()
337				return m, nil
338			}
339			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
340		default:
341			if m.showClaudeAuthMethodChooser {
342				u, cmd := m.claudeAuthMethodChooser.Update(msg)
343				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
344				return m, cmd
345			} else if m.showClaudeOAuth2 {
346				u, cmd := m.claudeOAuth2.Update(msg)
347				m.claudeOAuth2 = u.(*claude.OAuth2)
348				return m, cmd
349			} else if m.needsAPIKey {
350				u, cmd := m.apiKeyInput.Update(msg)
351				m.apiKeyInput = u.(*APIKeyInput)
352				return m, cmd
353			} else {
354				u, cmd := m.modelList.Update(msg)
355				m.modelList = u
356				return m, cmd
357			}
358		}
359	case tea.PasteMsg:
360		if m.showClaudeOAuth2 {
361			u, cmd := m.claudeOAuth2.Update(msg)
362			m.claudeOAuth2 = u.(*claude.OAuth2)
363			return m, cmd
364		} else if m.needsAPIKey {
365			u, cmd := m.apiKeyInput.Update(msg)
366			m.apiKeyInput = u.(*APIKeyInput)
367			return m, cmd
368		} else {
369			var cmd tea.Cmd
370			m.modelList, cmd = m.modelList.Update(msg)
371			return m, cmd
372		}
373	case spinner.TickMsg:
374		u, cmd := m.apiKeyInput.Update(msg)
375		m.apiKeyInput = u.(*APIKeyInput)
376		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
377			u, cmd = m.hyperDeviceFlow.Update(msg)
378			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
379		}
380		return m, cmd
381	default:
382		// Pass all other messages to the device flow for spinner animation
383		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
384			u, cmd := m.hyperDeviceFlow.Update(msg)
385			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
386			return m, cmd
387		} else if m.showClaudeOAuth2 {
388			u, cmd := m.claudeOAuth2.Update(msg)
389			m.claudeOAuth2 = u.(*claude.OAuth2)
390			return m, cmd
391		} else {
392			u, cmd := m.apiKeyInput.Update(msg)
393			m.apiKeyInput = u.(*APIKeyInput)
394			return m, cmd
395		}
396	}
397	return m, nil
398}
399
400func (m *modelDialogCmp) View() string {
401	t := styles.CurrentTheme()
402
403	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
404		// Show Hyper device flow
405		m.keyMap.isHyperDeviceFlow = true
406		deviceFlowView := m.hyperDeviceFlow.View()
407		content := lipgloss.JoinVertical(
408			lipgloss.Left,
409			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
410			deviceFlowView,
411			"",
412			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
413		)
414		return m.style().Render(content)
415	}
416
417	// Reset the flags when not showing device flow
418	m.keyMap.isHyperDeviceFlow = false
419
420	switch {
421	case m.showClaudeAuthMethodChooser:
422		chooserView := m.claudeAuthMethodChooser.View()
423		content := lipgloss.JoinVertical(
424			lipgloss.Left,
425			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
426			chooserView,
427			"",
428			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
429		)
430		return m.style().Render(content)
431	case m.showClaudeOAuth2:
432		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
433		oauth2View := m.claudeOAuth2.View()
434		content := lipgloss.JoinVertical(
435			lipgloss.Left,
436			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
437			oauth2View,
438			"",
439			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
440		)
441		return m.style().Render(content)
442	case m.needsAPIKey:
443		// Show API key input
444		m.keyMap.isAPIKeyHelp = true
445		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
446		apiKeyView := m.apiKeyInput.View()
447		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
448		content := lipgloss.JoinVertical(
449			lipgloss.Left,
450			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
451			apiKeyView,
452			"",
453			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
454		)
455		return m.style().Render(content)
456	}
457
458	// Show model selection
459	listView := m.modelList.View()
460	radio := m.modelTypeRadio()
461	content := lipgloss.JoinVertical(
462		lipgloss.Left,
463		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
464		listView,
465		"",
466		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
467	)
468	return m.style().Render(content)
469}
470
471func (m *modelDialogCmp) Cursor() *tea.Cursor {
472	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
473		return m.hyperDeviceFlow.Cursor()
474	}
475	if m.showClaudeAuthMethodChooser {
476		return nil
477	}
478	if m.showClaudeOAuth2 {
479		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
480			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
481			return m.moveCursor(cursor)
482		}
483		return nil
484	}
485	if m.needsAPIKey {
486		cursor := m.apiKeyInput.Cursor()
487		if cursor != nil {
488			cursor = m.moveCursor(cursor)
489			return cursor
490		}
491	} else {
492		cursor := m.modelList.Cursor()
493		if cursor != nil {
494			cursor = m.moveCursor(cursor)
495			return cursor
496		}
497	}
498	return nil
499}
500
501func (m *modelDialogCmp) style() lipgloss.Style {
502	t := styles.CurrentTheme()
503	return t.S().Base.
504		Width(m.width).
505		Border(lipgloss.RoundedBorder()).
506		BorderForeground(t.BorderFocus)
507}
508
509func (m *modelDialogCmp) listWidth() int {
510	return m.width - 2
511}
512
513func (m *modelDialogCmp) listHeight() int {
514	return m.wHeight / 2
515}
516
517func (m *modelDialogCmp) Position() (int, int) {
518	row := m.wHeight/4 - 2 // just a bit above the center
519	col := m.wWidth / 2
520	col -= m.width / 2
521	return row, col
522}
523
524func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
525	row, col := m.Position()
526	if m.needsAPIKey {
527		offset := row + 3 // Border + title + API key input offset
528		cursor.Y += offset
529		cursor.X = cursor.X + col + 2
530	} else {
531		offset := row + 3 // Border + title
532		cursor.Y += offset
533		cursor.X = cursor.X + col + 2
534	}
535	return cursor
536}
537
// ID returns the dialog's identifier, ModelsDialogID.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
541
542func (m *modelDialogCmp) modelTypeRadio() string {
543	t := styles.CurrentTheme()
544	choices := []string{"Large Task", "Small Task"}
545	iconSelected := "◉"
546	iconUnselected := "○"
547	if m.modelList.GetModelType() == LargeModelType {
548		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
549	}
550	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
551}
552
553func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
554	cfg := config.Get()
555	_, ok := cfg.Providers.Get(providerID)
556	return ok
557}
558
559func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
560	cfg := config.Get()
561	providers, err := config.Providers(cfg)
562	if err != nil {
563		return nil, err
564	}
565	for _, p := range providers {
566		if p.ID == providerID {
567			return &p, nil
568		}
569	}
570	return nil, nil
571}
572
573func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
574	if m.selectedModel == nil {
575		return util.ReportError(fmt.Errorf("no model selected"))
576	}
577
578	cfg := config.Get()
579	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
580	if err != nil {
581		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
582	}
583
584	// Reset API key state and continue with model selection
585	selectedModel := *m.selectedModel
586	var cmds []tea.Cmd
587	if close {
588		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
589	}
590	cmds = append(
591		cmds,
592		util.CmdHandler(ModelSelectedMsg{
593			Model: config.SelectedModel{
594				Model:           selectedModel.Model.ID,
595				Provider:        string(selectedModel.Provider.ID),
596				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
597				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
598			},
599			ModelType: m.selectedModelType,
600		}),
601	)
602	return tea.Sequence(cmds...)
603}