models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"charm.land/bubbles/v2/help"
  8	"charm.land/bubbles/v2/key"
  9	"charm.land/bubbles/v2/spinner"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/atotto/clipboard"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 18	"github.com/charmbracelet/crush/internal/tui/exp/list"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21)
 22
 23const (
 24	ModelsDialogID dialogs.DialogID = "models"
 25
 26	defaultWidth = 60
 27)
 28
 29const (
 30	LargeModelType int = iota
 31	SmallModelType
 32
 33	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 34	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 35)
 36
 37// ModelSelectedMsg is sent when a model is selected
 38type ModelSelectedMsg struct {
 39	Model     config.SelectedModel
 40	ModelType config.SelectedModelType
 41}
 42
 43// CloseModelDialogMsg is sent when a model is selected
 44type CloseModelDialogMsg struct{}
 45
 46// ModelDialog interface for the model selection dialog
 47type ModelDialog interface {
 48	dialogs.DialogModel
 49}
 50
 51type ModelOption struct {
 52	Provider catwalk.Provider
 53	Model    catwalk.Model
 54}
 55
 56type modelDialogCmp struct {
 57	width   int
 58	wWidth  int
 59	wHeight int
 60
 61	modelList *ModelListComponent
 62	keyMap    KeyMap
 63	help      help.Model
 64
 65	// API key state
 66	needsAPIKey       bool
 67	apiKeyInput       *APIKeyInput
 68	selectedModel     *ModelOption
 69	selectedModelType config.SelectedModelType
 70	isAPIKeyValid     bool
 71	apiKeyValue       string
 72
 73	// Claude state
 74	claudeAuthMethodChooser     *claude.AuthMethodChooser
 75	claudeOAuth2                *claude.OAuth2
 76	showClaudeAuthMethodChooser bool
 77	showClaudeOAuth2            bool
 78}
 79
 80func NewModelDialogCmp() ModelDialog {
 81	keyMap := DefaultKeyMap()
 82
 83	listKeyMap := list.DefaultKeyMap()
 84	listKeyMap.Down.SetEnabled(false)
 85	listKeyMap.Up.SetEnabled(false)
 86	listKeyMap.DownOneItem = keyMap.Next
 87	listKeyMap.UpOneItem = keyMap.Previous
 88
 89	t := styles.CurrentTheme()
 90	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 91	apiKeyInput := NewAPIKeyInput()
 92	apiKeyInput.SetShowTitle(false)
 93	help := help.New()
 94	help.Styles = t.S().Help
 95
 96	return &modelDialogCmp{
 97		modelList:   modelList,
 98		apiKeyInput: apiKeyInput,
 99		width:       defaultWidth,
100		keyMap:      DefaultKeyMap(),
101		help:        help,
102
103		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104		claudeOAuth2:            claude.NewOAuth2(),
105	}
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109	return tea.Batch(
110		m.modelList.Init(),
111		m.apiKeyInput.Init(),
112		m.claudeAuthMethodChooser.Init(),
113		m.claudeOAuth2.Init(),
114	)
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118	switch msg := msg.(type) {
119	case tea.WindowSizeMsg:
120		m.wWidth = msg.Width
121		m.wHeight = msg.Height
122		m.apiKeyInput.SetWidth(m.width - 2)
123		m.help.SetWidth(m.width - 2)
124		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126	case APIKeyStateChangeMsg:
127		u, cmd := m.apiKeyInput.Update(msg)
128		m.apiKeyInput = u.(*APIKeyInput)
129		return m, cmd
130	case claude.ValidationCompletedMsg:
131		var cmds []tea.Cmd
132		u, cmd := m.claudeOAuth2.Update(msg)
133		m.claudeOAuth2 = u.(*claude.OAuth2)
134		cmds = append(cmds, cmd)
135
136		if msg.State == claude.OAuthValidationStateValid {
137			cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138			m.keyMap.isClaudeOAuthHelpComplete = true
139		}
140
141		return m, tea.Batch(cmds...)
142	case claude.AuthenticationCompleteMsg:
143		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144	case tea.KeyPressMsg:
145		switch {
146		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147			return m, tea.Sequence(
148				tea.SetClipboard(m.claudeOAuth2.URL),
149				func() tea.Msg {
150					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
151					return nil
152				},
153				util.ReportInfo("URL copied to clipboard"),
154			)
155		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156			m.claudeAuthMethodChooser.ToggleChoice()
157			return m, nil
158		case key.Matches(msg, m.keyMap.Select):
159			selectedItem := m.modelList.SelectedModel()
160
161			modelType := config.SelectedModelTypeLarge
162			if m.modelList.GetModelType() == SmallModelType {
163				modelType = config.SelectedModelTypeSmall
164			}
165
166			askForApiKey := func() {
167				m.keyMap.isClaudeAuthChoiseHelp = false
168				m.keyMap.isClaudeOAuthHelp = false
169				m.keyMap.isAPIKeyHelp = true
170				m.showClaudeAuthMethodChooser = false
171				m.needsAPIKey = true
172				m.selectedModel = selectedItem
173				m.selectedModelType = modelType
174				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175			}
176
177			if m.showClaudeAuthMethodChooser {
178				switch m.claudeAuthMethodChooser.State {
179				case claude.AuthMethodAPIKey:
180					askForApiKey()
181				case claude.AuthMethodOAuth2:
182					m.selectedModel = selectedItem
183					m.selectedModelType = modelType
184					m.showClaudeAuthMethodChooser = false
185					m.showClaudeOAuth2 = true
186					m.keyMap.isClaudeAuthChoiseHelp = false
187					m.keyMap.isClaudeOAuthHelp = true
188				}
189				return m, nil
190			}
191			if m.showClaudeOAuth2 {
192				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193				m.claudeOAuth2 = m2.(*claude.OAuth2)
194				return m, cmd2
195			}
196			if m.isAPIKeyValid {
197				return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198			}
199			if m.needsAPIKey {
200				// Handle API key submission
201				m.apiKeyValue = m.apiKeyInput.Value()
202				provider, err := m.getProvider(m.selectedModel.Provider.ID)
203				if err != nil || provider == nil {
204					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205				}
206				providerConfig := config.ProviderConfig{
207					ID:      string(m.selectedModel.Provider.ID),
208					Name:    m.selectedModel.Provider.Name,
209					APIKey:  m.apiKeyValue,
210					Type:    provider.Type,
211					BaseURL: provider.APIEndpoint,
212				}
213				return m, tea.Sequence(
214					util.CmdHandler(APIKeyStateChangeMsg{
215						State: APIKeyInputStateVerifying,
216					}),
217					func() tea.Msg {
218						start := time.Now()
219						err := providerConfig.TestConnection(config.Get().Resolver())
220						// intentionally wait for at least 750ms to make sure the user sees the spinner
221						elapsed := time.Since(start)
222						if elapsed < 750*time.Millisecond {
223							time.Sleep(750*time.Millisecond - elapsed)
224						}
225						if err == nil {
226							m.isAPIKeyValid = true
227							return APIKeyStateChangeMsg{
228								State: APIKeyInputStateVerified,
229							}
230						}
231						return APIKeyStateChangeMsg{
232							State: APIKeyInputStateError,
233						}
234					},
235				)
236			}
237
238			// Check if provider is configured
239			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240				return m, tea.Sequence(
241					util.CmdHandler(dialogs.CloseDialogMsg{}),
242					util.CmdHandler(ModelSelectedMsg{
243						Model: config.SelectedModel{
244							Model:           selectedItem.Model.ID,
245							Provider:        string(selectedItem.Provider.ID),
246							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
248						},
249						ModelType: modelType,
250					}),
251				)
252			} else {
253				if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254					m.showClaudeAuthMethodChooser = true
255					m.keyMap.isClaudeAuthChoiseHelp = true
256					return m, nil
257				}
258				askForApiKey()
259				return m, nil
260			}
261		case key.Matches(msg, m.keyMap.Tab):
262			switch {
263			case m.showClaudeAuthMethodChooser:
264				m.claudeAuthMethodChooser.ToggleChoice()
265				return m, nil
266			case m.needsAPIKey:
267				u, cmd := m.apiKeyInput.Update(msg)
268				m.apiKeyInput = u.(*APIKeyInput)
269				return m, cmd
270			case m.modelList.GetModelType() == LargeModelType:
271				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272				return m, m.modelList.SetModelType(SmallModelType)
273			default:
274				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275				return m, m.modelList.SetModelType(LargeModelType)
276			}
277		case key.Matches(msg, m.keyMap.Edit):
278			// Only handle edit key in the main model selection view
279			if !m.showClaudeAuthMethodChooser && !m.showClaudeOAuth2 && !m.needsAPIKey {
280				selectedItem := m.modelList.SelectedModel()
281				if selectedItem != nil {
282					// Check if provider is configured
283					if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284						// Trigger API key editing flow
285						m.keyMap.isClaudeAuthChoiseHelp = false
286						m.keyMap.isClaudeOAuthHelp = false
287						m.keyMap.isAPIKeyHelp = true
288						m.showClaudeAuthMethodChooser = false
289						m.needsAPIKey = true
290						m.selectedModel = selectedItem
291						m.selectedModelType = config.SelectedModelTypeLarge
292						if m.modelList.GetModelType() == SmallModelType {
293							m.selectedModelType = config.SelectedModelTypeSmall
294						}
295						m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
296						// Pre-fill with existing API key if available
297						if providerConfig, ok := config.Get().Providers.Get(string(selectedItem.Provider.ID)); ok && providerConfig.APIKey != "" {
298							m.apiKeyInput.input.SetValue(providerConfig.APIKey)
299						}
300						return m, nil
301					}
302				}
303			}
304			return m, nil
305		case key.Matches(msg, m.keyMap.Close):
306			if m.showClaudeAuthMethodChooser {
307				m.claudeAuthMethodChooser.SetDefaults()
308				m.showClaudeAuthMethodChooser = false
309				m.keyMap.isClaudeAuthChoiseHelp = false
310				m.keyMap.isClaudeOAuthHelp = false
311				return m, nil
312			}
313			if m.needsAPIKey {
314				if m.isAPIKeyValid {
315					return m, nil
316				}
317				// Go back to model selection
318				m.needsAPIKey = false
319				m.selectedModel = nil
320				m.isAPIKeyValid = false
321				m.apiKeyValue = ""
322				m.apiKeyInput.Reset()
323				return m, nil
324			}
325			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
326		default:
327			if m.showClaudeAuthMethodChooser {
328				u, cmd := m.claudeAuthMethodChooser.Update(msg)
329				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
330				return m, cmd
331			} else if m.showClaudeOAuth2 {
332				u, cmd := m.claudeOAuth2.Update(msg)
333				m.claudeOAuth2 = u.(*claude.OAuth2)
334				return m, cmd
335			} else if m.needsAPIKey {
336				u, cmd := m.apiKeyInput.Update(msg)
337				m.apiKeyInput = u.(*APIKeyInput)
338				return m, cmd
339			} else {
340				u, cmd := m.modelList.Update(msg)
341				m.modelList = u
342				return m, cmd
343			}
344		}
345	case tea.PasteMsg:
346		if m.showClaudeOAuth2 {
347			u, cmd := m.claudeOAuth2.Update(msg)
348			m.claudeOAuth2 = u.(*claude.OAuth2)
349			return m, cmd
350		} else if m.needsAPIKey {
351			u, cmd := m.apiKeyInput.Update(msg)
352			m.apiKeyInput = u.(*APIKeyInput)
353			return m, cmd
354		} else {
355			var cmd tea.Cmd
356			m.modelList, cmd = m.modelList.Update(msg)
357			return m, cmd
358		}
359	case spinner.TickMsg:
360		if m.showClaudeOAuth2 {
361			u, cmd := m.claudeOAuth2.Update(msg)
362			m.claudeOAuth2 = u.(*claude.OAuth2)
363			return m, cmd
364		} else {
365			u, cmd := m.apiKeyInput.Update(msg)
366			m.apiKeyInput = u.(*APIKeyInput)
367			return m, cmd
368		}
369	}
370	return m, nil
371}
372
373func (m *modelDialogCmp) View() string {
374	t := styles.CurrentTheme()
375
376	switch {
377	case m.showClaudeAuthMethodChooser:
378		chooserView := m.claudeAuthMethodChooser.View()
379		content := lipgloss.JoinVertical(
380			lipgloss.Left,
381			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
382			chooserView,
383			"",
384			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
385		)
386		return m.style().Render(content)
387	case m.showClaudeOAuth2:
388		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
389		oauth2View := m.claudeOAuth2.View()
390		content := lipgloss.JoinVertical(
391			lipgloss.Left,
392			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
393			oauth2View,
394			"",
395			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
396		)
397		return m.style().Render(content)
398	case m.needsAPIKey:
399		// Show API key input
400		m.keyMap.isAPIKeyHelp = true
401		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
402		apiKeyView := m.apiKeyInput.View()
403		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
404		content := lipgloss.JoinVertical(
405			lipgloss.Left,
406			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
407			apiKeyView,
408			"",
409			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
410		)
411		return m.style().Render(content)
412	}
413
414	// Show model selection
415	listView := m.modelList.View()
416	radio := m.modelTypeRadio()
417	content := lipgloss.JoinVertical(
418		lipgloss.Left,
419		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
420		listView,
421		"",
422		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
423	)
424	return m.style().Render(content)
425}
426
427func (m *modelDialogCmp) Cursor() *tea.Cursor {
428	if m.showClaudeAuthMethodChooser {
429		return nil
430	}
431	if m.showClaudeOAuth2 {
432		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
433			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
434			return m.moveCursor(cursor)
435		}
436		return nil
437	}
438	if m.needsAPIKey {
439		cursor := m.apiKeyInput.Cursor()
440		if cursor != nil {
441			cursor = m.moveCursor(cursor)
442			return cursor
443		}
444	} else {
445		cursor := m.modelList.Cursor()
446		if cursor != nil {
447			cursor = m.moveCursor(cursor)
448			return cursor
449		}
450	}
451	return nil
452}
453
454func (m *modelDialogCmp) style() lipgloss.Style {
455	t := styles.CurrentTheme()
456	return t.S().Base.
457		Width(m.width).
458		Border(lipgloss.RoundedBorder()).
459		BorderForeground(t.BorderFocus)
460}
461
462func (m *modelDialogCmp) listWidth() int {
463	return m.width - 2
464}
465
466func (m *modelDialogCmp) listHeight() int {
467	return m.wHeight / 2
468}
469
470func (m *modelDialogCmp) Position() (int, int) {
471	row := m.wHeight/4 - 2 // just a bit above the center
472	col := m.wWidth / 2
473	col -= m.width / 2
474	return row, col
475}
476
477func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
478	row, col := m.Position()
479	if m.needsAPIKey {
480		offset := row + 3 // Border + title + API key input offset
481		cursor.Y += offset
482		cursor.X = cursor.X + col + 2
483	} else {
484		offset := row + 3 // Border + title
485		cursor.Y += offset
486		cursor.X = cursor.X + col + 2
487	}
488	return cursor
489}
490
491func (m *modelDialogCmp) ID() dialogs.DialogID {
492	return ModelsDialogID
493}
494
495func (m *modelDialogCmp) modelTypeRadio() string {
496	t := styles.CurrentTheme()
497	choices := []string{"Large Task", "Small Task"}
498	iconSelected := "◉"
499	iconUnselected := "○"
500	if m.modelList.GetModelType() == LargeModelType {
501		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
502	}
503	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
504}
505
506func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
507	cfg := config.Get()
508	if _, ok := cfg.Providers.Get(providerID); ok {
509		return true
510	}
511	return false
512}
513
514func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
515	cfg := config.Get()
516	providers, err := config.Providers(cfg)
517	if err != nil {
518		return nil, err
519	}
520	for _, p := range providers {
521		if p.ID == providerID {
522			return &p, nil
523		}
524	}
525	return nil, nil
526}
527
528func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
529	if m.selectedModel == nil {
530		return util.ReportError(fmt.Errorf("no model selected"))
531	}
532
533	cfg := config.Get()
534	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
535	if err != nil {
536		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
537	}
538
539	// Reset API key state and continue with model selection
540	selectedModel := *m.selectedModel
541	var cmds []tea.Cmd
542	if close {
543		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
544	}
545	cmds = append(
546		cmds,
547		util.CmdHandler(ModelSelectedMsg{
548			Model: config.SelectedModel{
549				Model:           selectedModel.Model.ID,
550				Provider:        string(selectedModel.Provider.ID),
551				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
552				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
553			},
554			ModelType: m.selectedModelType,
555		}),
556	)
557	return tea.Sequence(cmds...)
558}