models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"charm.land/bubbles/v2/help"
  8	"charm.land/bubbles/v2/key"
  9	"charm.land/bubbles/v2/spinner"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/atotto/clipboard"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 18	"github.com/charmbracelet/crush/internal/tui/exp/list"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21)
 22
 23const (
 24	ModelsDialogID dialogs.DialogID = "models"
 25
 26	defaultWidth = 60
 27)
 28
 29const (
 30	LargeModelType int = iota
 31	SmallModelType
 32
 33	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 34	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 35)
 36
 37// ModelSelectedMsg is sent when a model is selected
 38type ModelSelectedMsg struct {
 39	Model     config.SelectedModel
 40	ModelType config.SelectedModelType
 41}
 42
 43// CloseModelDialogMsg is sent when a model is selected
 44type CloseModelDialogMsg struct{}
 45
 46// ModelDialog interface for the model selection dialog
 47type ModelDialog interface {
 48	dialogs.DialogModel
 49}
 50
 51type ModelOption struct {
 52	Provider catwalk.Provider
 53	Model    catwalk.Model
 54}
 55
 56type modelDialogCmp struct {
 57	width   int
 58	wWidth  int
 59	wHeight int
 60
 61	modelList *ModelListComponent
 62	keyMap    KeyMap
 63	help      help.Model
 64
 65	// API key state
 66	needsAPIKey       bool
 67	apiKeyInput       *APIKeyInput
 68	selectedModel     *ModelOption
 69	selectedModelType config.SelectedModelType
 70	isAPIKeyValid     bool
 71	apiKeyValue       string
 72
 73	// Claude state
 74	claudeAuthMethodChooser     *claude.AuthMethodChooser
 75	claudeOAuth2                *claude.OAuth2
 76	showClaudeAuthMethodChooser bool
 77	showClaudeOAuth2            bool
 78}
 79
 80func NewModelDialogCmp() ModelDialog {
 81	keyMap := DefaultKeyMap()
 82
 83	listKeyMap := list.DefaultKeyMap()
 84	listKeyMap.Down.SetEnabled(false)
 85	listKeyMap.Up.SetEnabled(false)
 86	listKeyMap.DownOneItem = keyMap.Next
 87	listKeyMap.UpOneItem = keyMap.Previous
 88
 89	t := styles.CurrentTheme()
 90	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 91	apiKeyInput := NewAPIKeyInput()
 92	apiKeyInput.SetShowTitle(false)
 93	help := help.New()
 94	help.Styles = t.S().Help
 95
 96	return &modelDialogCmp{
 97		modelList:   modelList,
 98		apiKeyInput: apiKeyInput,
 99		width:       defaultWidth,
100		keyMap:      DefaultKeyMap(),
101		help:        help,
102
103		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104		claudeOAuth2:            claude.NewOAuth2(),
105	}
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109	return tea.Batch(
110		m.modelList.Init(),
111		m.apiKeyInput.Init(),
112		m.claudeAuthMethodChooser.Init(),
113		m.claudeOAuth2.Init(),
114	)
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118	switch msg := msg.(type) {
119	case tea.WindowSizeMsg:
120		m.wWidth = msg.Width
121		m.wHeight = msg.Height
122		m.apiKeyInput.SetWidth(m.width - 2)
123		m.help.SetWidth(m.width - 2)
124		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126	case APIKeyStateChangeMsg:
127		u, cmd := m.apiKeyInput.Update(msg)
128		m.apiKeyInput = u.(*APIKeyInput)
129		return m, cmd
130	case claude.ValidationCompletedMsg:
131		var cmds []tea.Cmd
132		u, cmd := m.claudeOAuth2.Update(msg)
133		m.claudeOAuth2 = u.(*claude.OAuth2)
134		cmds = append(cmds, cmd)
135
136		if msg.State == claude.OAuthValidationStateValid {
137			cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138			m.keyMap.isClaudeOAuthHelpComplete = true
139		}
140
141		return m, tea.Batch(cmds...)
142	case claude.AuthenticationCompleteMsg:
143		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144	case tea.KeyPressMsg:
145		switch {
146		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))):
147			if m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL {
148				return m, tea.Sequence(
149					tea.SetClipboard(m.claudeOAuth2.URL),
150					func() tea.Msg {
151						_ = clipboard.WriteAll(m.claudeOAuth2.URL)
152						return nil
153					},
154					util.ReportInfo("URL copied to clipboard"),
155				)
156			}
157		case key.Matches(msg, m.keyMap.Choose):
158			if m.showClaudeAuthMethodChooser {
159				m.claudeAuthMethodChooser.ToggleChoice()
160				return m, nil
161			}
162		case key.Matches(msg, m.keyMap.Select):
163			selectedItem := m.modelList.SelectedModel()
164
165			modelType := config.SelectedModelTypeLarge
166			if m.modelList.GetModelType() == SmallModelType {
167				modelType = config.SelectedModelTypeSmall
168			}
169
170			askForApiKey := func() {
171				m.keyMap.isClaudeAuthChoiseHelp = false
172				m.keyMap.isClaudeOAuthHelp = false
173				m.keyMap.isAPIKeyHelp = true
174				m.showClaudeAuthMethodChooser = false
175				m.needsAPIKey = true
176				m.selectedModel = selectedItem
177				m.selectedModelType = modelType
178				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
179			}
180
181			if m.showClaudeAuthMethodChooser {
182				switch m.claudeAuthMethodChooser.State {
183				case claude.AuthMethodAPIKey:
184					askForApiKey()
185				case claude.AuthMethodOAuth2:
186					m.selectedModel = selectedItem
187					m.selectedModelType = modelType
188					m.showClaudeAuthMethodChooser = false
189					m.showClaudeOAuth2 = true
190					m.keyMap.isClaudeAuthChoiseHelp = false
191					m.keyMap.isClaudeOAuthHelp = true
192				}
193				return m, nil
194			}
195			if m.showClaudeOAuth2 {
196				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
197				m.claudeOAuth2 = m2.(*claude.OAuth2)
198				return m, cmd2
199			}
200			if m.isAPIKeyValid {
201				return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
202			}
203			if m.needsAPIKey {
204				// Handle API key submission
205				m.apiKeyValue = m.apiKeyInput.Value()
206				provider, err := m.getProvider(m.selectedModel.Provider.ID)
207				if err != nil || provider == nil {
208					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
209				}
210				providerConfig := config.ProviderConfig{
211					ID:      string(m.selectedModel.Provider.ID),
212					Name:    m.selectedModel.Provider.Name,
213					APIKey:  m.apiKeyValue,
214					Type:    provider.Type,
215					BaseURL: provider.APIEndpoint,
216				}
217				return m, tea.Sequence(
218					util.CmdHandler(APIKeyStateChangeMsg{
219						State: APIKeyInputStateVerifying,
220					}),
221					func() tea.Msg {
222						start := time.Now()
223						err := providerConfig.TestConnection(config.Get().Resolver())
224						// intentionally wait for at least 750ms to make sure the user sees the spinner
225						elapsed := time.Since(start)
226						if elapsed < 750*time.Millisecond {
227							time.Sleep(750*time.Millisecond - elapsed)
228						}
229						if err == nil {
230							m.isAPIKeyValid = true
231							return APIKeyStateChangeMsg{
232								State: APIKeyInputStateVerified,
233							}
234						}
235						return APIKeyStateChangeMsg{
236							State: APIKeyInputStateError,
237						}
238					},
239				)
240			}
241
242			// Check if provider is configured
243			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
244				return m, tea.Sequence(
245					util.CmdHandler(dialogs.CloseDialogMsg{}),
246					util.CmdHandler(ModelSelectedMsg{
247						Model: config.SelectedModel{
248							Model:           selectedItem.Model.ID,
249							Provider:        string(selectedItem.Provider.ID),
250							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
251							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
252						},
253						ModelType: modelType,
254					}),
255				)
256			} else {
257				if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
258					m.showClaudeAuthMethodChooser = true
259					m.keyMap.isClaudeAuthChoiseHelp = true
260					return m, nil
261				}
262				askForApiKey()
263				return m, nil
264			}
265		case key.Matches(msg, m.keyMap.Tab):
266			switch {
267			case m.showClaudeAuthMethodChooser:
268				m.claudeAuthMethodChooser.ToggleChoice()
269				return m, nil
270			case m.needsAPIKey:
271				u, cmd := m.apiKeyInput.Update(msg)
272				m.apiKeyInput = u.(*APIKeyInput)
273				return m, cmd
274			case m.modelList.GetModelType() == LargeModelType:
275				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
276				return m, m.modelList.SetModelType(SmallModelType)
277			default:
278				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
279				return m, m.modelList.SetModelType(LargeModelType)
280			}
281		case key.Matches(msg, m.keyMap.Close):
282			if m.showClaudeAuthMethodChooser {
283				m.claudeAuthMethodChooser.SetDefaults()
284				m.showClaudeAuthMethodChooser = false
285				m.keyMap.isClaudeAuthChoiseHelp = false
286				m.keyMap.isClaudeOAuthHelp = false
287				return m, nil
288			}
289			if m.needsAPIKey {
290				if m.isAPIKeyValid {
291					return m, nil
292				}
293				// Go back to model selection
294				m.needsAPIKey = false
295				m.selectedModel = nil
296				m.isAPIKeyValid = false
297				m.apiKeyValue = ""
298				m.apiKeyInput.Reset()
299				return m, nil
300			}
301			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
302		default:
303			if m.showClaudeAuthMethodChooser {
304				u, cmd := m.claudeAuthMethodChooser.Update(msg)
305				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
306				return m, cmd
307			} else if m.showClaudeOAuth2 {
308				u, cmd := m.claudeOAuth2.Update(msg)
309				m.claudeOAuth2 = u.(*claude.OAuth2)
310				return m, cmd
311			} else if m.needsAPIKey {
312				u, cmd := m.apiKeyInput.Update(msg)
313				m.apiKeyInput = u.(*APIKeyInput)
314				return m, cmd
315			} else {
316				u, cmd := m.modelList.Update(msg)
317				m.modelList = u
318				return m, cmd
319			}
320		}
321	case tea.PasteMsg:
322		if m.showClaudeOAuth2 {
323			u, cmd := m.claudeOAuth2.Update(msg)
324			m.claudeOAuth2 = u.(*claude.OAuth2)
325			return m, cmd
326		} else if m.needsAPIKey {
327			u, cmd := m.apiKeyInput.Update(msg)
328			m.apiKeyInput = u.(*APIKeyInput)
329			return m, cmd
330		} else {
331			var cmd tea.Cmd
332			m.modelList, cmd = m.modelList.Update(msg)
333			return m, cmd
334		}
335	case spinner.TickMsg:
336		if m.showClaudeOAuth2 {
337			u, cmd := m.claudeOAuth2.Update(msg)
338			m.claudeOAuth2 = u.(*claude.OAuth2)
339			return m, cmd
340		} else {
341			u, cmd := m.apiKeyInput.Update(msg)
342			m.apiKeyInput = u.(*APIKeyInput)
343			return m, cmd
344		}
345	}
346	return m, nil
347}
348
349func (m *modelDialogCmp) View() string {
350	t := styles.CurrentTheme()
351
352	switch {
353	case m.showClaudeAuthMethodChooser:
354		chooserView := m.claudeAuthMethodChooser.View()
355		content := lipgloss.JoinVertical(
356			lipgloss.Left,
357			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
358			chooserView,
359			"",
360			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
361		)
362		return m.style().Render(content)
363	case m.showClaudeOAuth2:
364		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
365		oauth2View := m.claudeOAuth2.View()
366		content := lipgloss.JoinVertical(
367			lipgloss.Left,
368			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
369			oauth2View,
370			"",
371			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
372		)
373		return m.style().Render(content)
374	case m.needsAPIKey:
375		// Show API key input
376		m.keyMap.isAPIKeyHelp = true
377		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
378		apiKeyView := m.apiKeyInput.View()
379		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
380		content := lipgloss.JoinVertical(
381			lipgloss.Left,
382			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
383			apiKeyView,
384			"",
385			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
386		)
387		return m.style().Render(content)
388	}
389
390	// Show model selection
391	listView := m.modelList.View()
392	radio := m.modelTypeRadio()
393	content := lipgloss.JoinVertical(
394		lipgloss.Left,
395		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
396		listView,
397		"",
398		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
399	)
400	return m.style().Render(content)
401}
402
403func (m *modelDialogCmp) Cursor() *tea.Cursor {
404	if m.showClaudeAuthMethodChooser {
405		return nil
406	}
407	if m.showClaudeOAuth2 {
408		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
409			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
410			return m.moveCursor(cursor)
411		}
412		return nil
413	}
414	if m.needsAPIKey {
415		cursor := m.apiKeyInput.Cursor()
416		if cursor != nil {
417			cursor = m.moveCursor(cursor)
418			return cursor
419		}
420	} else {
421		cursor := m.modelList.Cursor()
422		if cursor != nil {
423			cursor = m.moveCursor(cursor)
424			return cursor
425		}
426	}
427	return nil
428}
429
430func (m *modelDialogCmp) style() lipgloss.Style {
431	t := styles.CurrentTheme()
432	return t.S().Base.
433		Width(m.width).
434		Border(lipgloss.RoundedBorder()).
435		BorderForeground(t.BorderFocus)
436}
437
438func (m *modelDialogCmp) listWidth() int {
439	return m.width - 2
440}
441
442func (m *modelDialogCmp) listHeight() int {
443	return m.wHeight / 2
444}
445
446func (m *modelDialogCmp) Position() (int, int) {
447	row := m.wHeight/4 - 2 // just a bit above the center
448	col := m.wWidth / 2
449	col -= m.width / 2
450	return row, col
451}
452
453func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
454	row, col := m.Position()
455	if m.needsAPIKey {
456		offset := row + 3 // Border + title + API key input offset
457		cursor.Y += offset
458		cursor.X = cursor.X + col + 2
459	} else {
460		offset := row + 3 // Border + title
461		cursor.Y += offset
462		cursor.X = cursor.X + col + 2
463	}
464	return cursor
465}
466
467func (m *modelDialogCmp) ID() dialogs.DialogID {
468	return ModelsDialogID
469}
470
471func (m *modelDialogCmp) modelTypeRadio() string {
472	t := styles.CurrentTheme()
473	choices := []string{"Large Task", "Small Task"}
474	iconSelected := "◉"
475	iconUnselected := "○"
476	if m.modelList.GetModelType() == LargeModelType {
477		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
478	}
479	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
480}
481
482func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
483	cfg := config.Get()
484	if _, ok := cfg.Providers.Get(providerID); ok {
485		return true
486	}
487	return false
488}
489
490func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
491	cfg := config.Get()
492	providers, err := config.Providers(cfg)
493	if err != nil {
494		return nil, err
495	}
496	for _, p := range providers {
497		if p.ID == providerID {
498			return &p, nil
499		}
500	}
501	return nil, nil
502}
503
504func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
505	if m.selectedModel == nil {
506		return util.ReportError(fmt.Errorf("no model selected"))
507	}
508
509	cfg := config.Get()
510	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
511	if err != nil {
512		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
513	}
514
515	// Reset API key state and continue with model selection
516	selectedModel := *m.selectedModel
517	var cmds []tea.Cmd
518	if close {
519		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
520	}
521	cmds = append(
522		cmds,
523		util.CmdHandler(ModelSelectedMsg{
524			Model: config.SelectedModel{
525				Model:           selectedModel.Model.ID,
526				Provider:        string(selectedModel.Provider.ID),
527				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
528				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
529			},
530			ModelType: m.selectedModelType,
531		}),
532	)
533	return tea.Sequence(cmds...)
534}