models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"charm.land/bubbles/v2/help"
  8	"charm.land/bubbles/v2/key"
  9	"charm.land/bubbles/v2/spinner"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/atotto/clipboard"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 18	"github.com/charmbracelet/crush/internal/tui/exp/list"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21)
 22
const (
	// ModelsDialogID is the dialog identifier returned by the model dialog's
	// ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the width NewModelDialogCmp assigns to the dialog.
	defaultWidth = 60
)
 28
const (
	// LargeModelType and SmallModelType are the two model-type values the
	// model list can be switched between (see the Tab handling in Update).
	// Update maps them to config.SelectedModelTypeLarge and
	// config.SelectedModelTypeSmall when a model is selected.
	LargeModelType int = iota
	SmallModelType

	// Placeholder texts passed to the model list's input for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 36
// ModelSelectedMsg is emitted by the model dialog once the user has picked a
// model. Model carries the chosen model/provider IDs together with the
// model's default reasoning effort and max tokens; ModelType says whether the
// choice applies to the large or the small model slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
 42
// CloseModelDialogMsg presumably requests that the model dialog be closed.
// NOTE(review): nothing in this file sends or handles it; the dialog closes
// itself via dialogs.CloseDialogMsg. Confirm whether it is still used.
type CloseModelDialogMsg struct{}
 45
// ModelDialog is the interface returned by NewModelDialogCmp. It currently
// adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 50
// ModelOption pairs a model with the provider that offers it. It is the
// value the model list reports as the current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 55
// modelDialogCmp is the ModelDialog implementation built by
// NewModelDialogCmp. Besides the model list it hosts three follow-up screens
// (Claude auth-method chooser, Claude OAuth2 flow, API key input) that are
// switched on and off through the boolean flags below.
type modelDialogCmp struct {
	// width is the dialog's own width (defaultWidth); wWidth and wHeight are
	// the most recent window size received via tea.WindowSizeMsg.
	width   int
	wWidth  int
	wHeight int

	// modelList is the selectable model list; keyMap holds the dialog's key
	// bindings plus the help-state flags that Update and View toggle; help
	// renders the help line from keyMap.
	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	//
	// needsAPIKey is true while the API key input screen is shown.
	// selectedModel / selectedModelType remember the pending choice while
	// credentials are collected (selectedModel is reset to nil when the user
	// backs out). apiKeyValue is the key text captured on submit and
	// isAPIKeyValid is set once the connection test for it succeeds.
	needsAPIKey       bool
	apiKeyInput       *APIKeyInput
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	isAPIKeyValid     bool
	apiKeyValue       string

	// Claude state
	//
	// showClaudeAuthMethodChooser and showClaudeOAuth2 select which of the
	// two Claude-specific screens (if any) is currently displayed.
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
 79
 80func NewModelDialogCmp() ModelDialog {
 81	keyMap := DefaultKeyMap()
 82
 83	listKeyMap := list.DefaultKeyMap()
 84	listKeyMap.Down.SetEnabled(false)
 85	listKeyMap.Up.SetEnabled(false)
 86	listKeyMap.DownOneItem = keyMap.Next
 87	listKeyMap.UpOneItem = keyMap.Previous
 88
 89	t := styles.CurrentTheme()
 90	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 91	apiKeyInput := NewAPIKeyInput()
 92	apiKeyInput.SetShowTitle(false)
 93	help := help.New()
 94	help.Styles = t.S().Help
 95
 96	return &modelDialogCmp{
 97		modelList:   modelList,
 98		apiKeyInput: apiKeyInput,
 99		width:       defaultWidth,
100		keyMap:      DefaultKeyMap(),
101		help:        help,
102
103		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104		claudeOAuth2:            claude.NewOAuth2(),
105	}
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109	return tea.Batch(
110		m.modelList.Init(),
111		m.apiKeyInput.Init(),
112		m.claudeAuthMethodChooser.Init(),
113		m.claudeOAuth2.Init(),
114	)
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118	switch msg := msg.(type) {
119	case tea.WindowSizeMsg:
120		m.wWidth = msg.Width
121		m.wHeight = msg.Height
122		m.apiKeyInput.SetWidth(m.width - 2)
123		m.help.SetWidth(m.width - 2)
124		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126	case APIKeyStateChangeMsg:
127		u, cmd := m.apiKeyInput.Update(msg)
128		m.apiKeyInput = u.(*APIKeyInput)
129		return m, cmd
130	case claude.ValidationCompletedMsg:
131		var cmds []tea.Cmd
132		u, cmd := m.claudeOAuth2.Update(msg)
133		m.claudeOAuth2 = u.(*claude.OAuth2)
134		cmds = append(cmds, cmd)
135
136		if msg.State == claude.OAuthValidationStateValid {
137			cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138			m.keyMap.isClaudeOAuthHelpComplete = true
139		}
140
141		return m, tea.Batch(cmds...)
142	case claude.AuthenticationCompleteMsg:
143		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144	case tea.KeyPressMsg:
145		switch {
146		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))):
147			if m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL {
148				return m, tea.Sequence(
149					tea.SetClipboard(m.claudeOAuth2.URL),
150					func() tea.Msg {
151						_ = clipboard.WriteAll(m.claudeOAuth2.URL)
152						return nil
153					},
154					util.ReportInfo("URL copied to clipboard"),
155				)
156			}
157		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
158			m.claudeAuthMethodChooser.ToggleChoice()
159			return m, nil
160		case key.Matches(msg, m.keyMap.Select):
161			selectedItem := m.modelList.SelectedModel()
162
163			modelType := config.SelectedModelTypeLarge
164			if m.modelList.GetModelType() == SmallModelType {
165				modelType = config.SelectedModelTypeSmall
166			}
167
168			askForApiKey := func() {
169				m.keyMap.isClaudeAuthChoiseHelp = false
170				m.keyMap.isClaudeOAuthHelp = false
171				m.keyMap.isAPIKeyHelp = true
172				m.showClaudeAuthMethodChooser = false
173				m.needsAPIKey = true
174				m.selectedModel = selectedItem
175				m.selectedModelType = modelType
176				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
177			}
178
179			if m.showClaudeAuthMethodChooser {
180				switch m.claudeAuthMethodChooser.State {
181				case claude.AuthMethodAPIKey:
182					askForApiKey()
183				case claude.AuthMethodOAuth2:
184					m.selectedModel = selectedItem
185					m.selectedModelType = modelType
186					m.showClaudeAuthMethodChooser = false
187					m.showClaudeOAuth2 = true
188					m.keyMap.isClaudeAuthChoiseHelp = false
189					m.keyMap.isClaudeOAuthHelp = true
190				}
191				return m, nil
192			}
193			if m.showClaudeOAuth2 {
194				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
195				m.claudeOAuth2 = m2.(*claude.OAuth2)
196				return m, cmd2
197			}
198			if m.isAPIKeyValid {
199				return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
200			}
201			if m.needsAPIKey {
202				// Handle API key submission
203				m.apiKeyValue = m.apiKeyInput.Value()
204				provider, err := m.getProvider(m.selectedModel.Provider.ID)
205				if err != nil || provider == nil {
206					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
207				}
208				providerConfig := config.ProviderConfig{
209					ID:      string(m.selectedModel.Provider.ID),
210					Name:    m.selectedModel.Provider.Name,
211					APIKey:  m.apiKeyValue,
212					Type:    provider.Type,
213					BaseURL: provider.APIEndpoint,
214				}
215				return m, tea.Sequence(
216					util.CmdHandler(APIKeyStateChangeMsg{
217						State: APIKeyInputStateVerifying,
218					}),
219					func() tea.Msg {
220						start := time.Now()
221						err := providerConfig.TestConnection(config.Get().Resolver())
222						// intentionally wait for at least 750ms to make sure the user sees the spinner
223						elapsed := time.Since(start)
224						if elapsed < 750*time.Millisecond {
225							time.Sleep(750*time.Millisecond - elapsed)
226						}
227						if err == nil {
228							m.isAPIKeyValid = true
229							return APIKeyStateChangeMsg{
230								State: APIKeyInputStateVerified,
231							}
232						}
233						return APIKeyStateChangeMsg{
234							State: APIKeyInputStateError,
235						}
236					},
237				)
238			}
239
240			// Check if provider is configured
241			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
242				return m, tea.Sequence(
243					util.CmdHandler(dialogs.CloseDialogMsg{}),
244					util.CmdHandler(ModelSelectedMsg{
245						Model: config.SelectedModel{
246							Model:           selectedItem.Model.ID,
247							Provider:        string(selectedItem.Provider.ID),
248							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
249							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
250						},
251						ModelType: modelType,
252					}),
253				)
254			} else {
255				if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
256					m.showClaudeAuthMethodChooser = true
257					m.keyMap.isClaudeAuthChoiseHelp = true
258					return m, nil
259				}
260				askForApiKey()
261				return m, nil
262			}
263		case key.Matches(msg, m.keyMap.Tab):
264			switch {
265			case m.showClaudeAuthMethodChooser:
266				m.claudeAuthMethodChooser.ToggleChoice()
267				return m, nil
268			case m.needsAPIKey:
269				u, cmd := m.apiKeyInput.Update(msg)
270				m.apiKeyInput = u.(*APIKeyInput)
271				return m, cmd
272			case m.modelList.GetModelType() == LargeModelType:
273				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
274				return m, m.modelList.SetModelType(SmallModelType)
275			default:
276				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
277				return m, m.modelList.SetModelType(LargeModelType)
278			}
279		case key.Matches(msg, m.keyMap.Close):
280			if m.showClaudeAuthMethodChooser {
281				m.claudeAuthMethodChooser.SetDefaults()
282				m.showClaudeAuthMethodChooser = false
283				m.keyMap.isClaudeAuthChoiseHelp = false
284				m.keyMap.isClaudeOAuthHelp = false
285				return m, nil
286			}
287			if m.needsAPIKey {
288				if m.isAPIKeyValid {
289					return m, nil
290				}
291				// Go back to model selection
292				m.needsAPIKey = false
293				m.selectedModel = nil
294				m.isAPIKeyValid = false
295				m.apiKeyValue = ""
296				m.apiKeyInput.Reset()
297				return m, nil
298			}
299			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
300		default:
301			if m.showClaudeAuthMethodChooser {
302				u, cmd := m.claudeAuthMethodChooser.Update(msg)
303				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
304				return m, cmd
305			} else if m.showClaudeOAuth2 {
306				u, cmd := m.claudeOAuth2.Update(msg)
307				m.claudeOAuth2 = u.(*claude.OAuth2)
308				return m, cmd
309			} else if m.needsAPIKey {
310				u, cmd := m.apiKeyInput.Update(msg)
311				m.apiKeyInput = u.(*APIKeyInput)
312				return m, cmd
313			} else {
314				u, cmd := m.modelList.Update(msg)
315				m.modelList = u
316				return m, cmd
317			}
318		}
319	case tea.PasteMsg:
320		if m.showClaudeOAuth2 {
321			u, cmd := m.claudeOAuth2.Update(msg)
322			m.claudeOAuth2 = u.(*claude.OAuth2)
323			return m, cmd
324		} else if m.needsAPIKey {
325			u, cmd := m.apiKeyInput.Update(msg)
326			m.apiKeyInput = u.(*APIKeyInput)
327			return m, cmd
328		} else {
329			var cmd tea.Cmd
330			m.modelList, cmd = m.modelList.Update(msg)
331			return m, cmd
332		}
333	case spinner.TickMsg:
334		if m.showClaudeOAuth2 {
335			u, cmd := m.claudeOAuth2.Update(msg)
336			m.claudeOAuth2 = u.(*claude.OAuth2)
337			return m, cmd
338		} else {
339			u, cmd := m.apiKeyInput.Update(msg)
340			m.apiKeyInput = u.(*APIKeyInput)
341			return m, cmd
342		}
343	}
344	return m, nil
345}
346
347func (m *modelDialogCmp) View() string {
348	t := styles.CurrentTheme()
349
350	switch {
351	case m.showClaudeAuthMethodChooser:
352		chooserView := m.claudeAuthMethodChooser.View()
353		content := lipgloss.JoinVertical(
354			lipgloss.Left,
355			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
356			chooserView,
357			"",
358			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
359		)
360		return m.style().Render(content)
361	case m.showClaudeOAuth2:
362		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
363		oauth2View := m.claudeOAuth2.View()
364		content := lipgloss.JoinVertical(
365			lipgloss.Left,
366			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
367			oauth2View,
368			"",
369			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
370		)
371		return m.style().Render(content)
372	case m.needsAPIKey:
373		// Show API key input
374		m.keyMap.isAPIKeyHelp = true
375		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
376		apiKeyView := m.apiKeyInput.View()
377		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
378		content := lipgloss.JoinVertical(
379			lipgloss.Left,
380			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
381			apiKeyView,
382			"",
383			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
384		)
385		return m.style().Render(content)
386	}
387
388	// Show model selection
389	listView := m.modelList.View()
390	radio := m.modelTypeRadio()
391	content := lipgloss.JoinVertical(
392		lipgloss.Left,
393		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
394		listView,
395		"",
396		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
397	)
398	return m.style().Render(content)
399}
400
401func (m *modelDialogCmp) Cursor() *tea.Cursor {
402	if m.showClaudeAuthMethodChooser {
403		return nil
404	}
405	if m.showClaudeOAuth2 {
406		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
407			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
408			return m.moveCursor(cursor)
409		}
410		return nil
411	}
412	if m.needsAPIKey {
413		cursor := m.apiKeyInput.Cursor()
414		if cursor != nil {
415			cursor = m.moveCursor(cursor)
416			return cursor
417		}
418	} else {
419		cursor := m.modelList.Cursor()
420		if cursor != nil {
421			cursor = m.moveCursor(cursor)
422			return cursor
423		}
424	}
425	return nil
426}
427
428func (m *modelDialogCmp) style() lipgloss.Style {
429	t := styles.CurrentTheme()
430	return t.S().Base.
431		Width(m.width).
432		Border(lipgloss.RoundedBorder()).
433		BorderForeground(t.BorderFocus)
434}
435
436func (m *modelDialogCmp) listWidth() int {
437	return m.width - 2
438}
439
440func (m *modelDialogCmp) listHeight() int {
441	return m.wHeight / 2
442}
443
444func (m *modelDialogCmp) Position() (int, int) {
445	row := m.wHeight/4 - 2 // just a bit above the center
446	col := m.wWidth / 2
447	col -= m.width / 2
448	return row, col
449}
450
451func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
452	row, col := m.Position()
453	if m.needsAPIKey {
454		offset := row + 3 // Border + title + API key input offset
455		cursor.Y += offset
456		cursor.X = cursor.X + col + 2
457	} else {
458		offset := row + 3 // Border + title
459		cursor.Y += offset
460		cursor.X = cursor.X + col + 2
461	}
462	return cursor
463}
464
465func (m *modelDialogCmp) ID() dialogs.DialogID {
466	return ModelsDialogID
467}
468
469func (m *modelDialogCmp) modelTypeRadio() string {
470	t := styles.CurrentTheme()
471	choices := []string{"Large Task", "Small Task"}
472	iconSelected := "◉"
473	iconUnselected := "○"
474	if m.modelList.GetModelType() == LargeModelType {
475		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
476	}
477	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
478}
479
480func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
481	cfg := config.Get()
482	if _, ok := cfg.Providers.Get(providerID); ok {
483		return true
484	}
485	return false
486}
487
488func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
489	cfg := config.Get()
490	providers, err := config.Providers(cfg)
491	if err != nil {
492		return nil, err
493	}
494	for _, p := range providers {
495		if p.ID == providerID {
496			return &p, nil
497		}
498	}
499	return nil, nil
500}
501
502func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
503	if m.selectedModel == nil {
504		return util.ReportError(fmt.Errorf("no model selected"))
505	}
506
507	cfg := config.Get()
508	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
509	if err != nil {
510		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
511	}
512
513	// Reset API key state and continue with model selection
514	selectedModel := *m.selectedModel
515	var cmds []tea.Cmd
516	if close {
517		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
518	}
519	cmds = append(
520		cmds,
521		util.CmdHandler(ModelSelectedMsg{
522			Model: config.SelectedModel{
523				Model:           selectedModel.Model.ID,
524				Provider:        string(selectedModel.Provider.ID),
525				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
526				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
527			},
528			ModelType: m.selectedModelType,
529		}),
530	)
531	return tea.Sequence(cmds...)
532}