models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"charm.land/bubbles/v2/help"
  8	"charm.land/bubbles/v2/key"
  9	"charm.land/bubbles/v2/spinner"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/atotto/clipboard"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 18	"github.com/charmbracelet/crush/internal/tui/exp/list"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21)
 22
 23const (
 24	ModelsDialogID dialogs.DialogID = "models"
 25
 26	defaultWidth = 60
 27)
 28
 29const (
 30	LargeModelType int = iota
 31	SmallModelType
 32
 33	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 34	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 35)
 36
 37// ModelSelectedMsg is sent when a model is selected
 38type ModelSelectedMsg struct {
 39	Model     config.SelectedModel
 40	ModelType config.SelectedModelType
 41}
 42
 43// CloseModelDialogMsg is sent when a model is selected
 44type CloseModelDialogMsg struct{}
 45
 46// ModelDialog interface for the model selection dialog
 47type ModelDialog interface {
 48	dialogs.DialogModel
 49}
 50
 51type ModelOption struct {
 52	Provider catwalk.Provider
 53	Model    catwalk.Model
 54}
 55
 56type modelDialogCmp struct {
 57	width   int
 58	wWidth  int
 59	wHeight int
 60
 61	modelList *ModelListComponent
 62	keyMap    KeyMap
 63	help      help.Model
 64
 65	// API key state
 66	needsAPIKey       bool
 67	apiKeyInput       *APIKeyInput
 68	selectedModel     *ModelOption
 69	selectedModelType config.SelectedModelType
 70	isAPIKeyValid     bool
 71	apiKeyValue       string
 72
 73	// Claude state
 74	claudeAuthMethodChooser     *claude.AuthMethodChooser
 75	claudeOAuth2                *claude.OAuth2
 76	showClaudeAuthMethodChooser bool
 77	showClaudeOAuth2            bool
 78}
 79
 80func NewModelDialogCmp() ModelDialog {
 81	keyMap := DefaultKeyMap()
 82
 83	listKeyMap := list.DefaultKeyMap()
 84	listKeyMap.Down.SetEnabled(false)
 85	listKeyMap.Up.SetEnabled(false)
 86	listKeyMap.DownOneItem = keyMap.Next
 87	listKeyMap.UpOneItem = keyMap.Previous
 88
 89	t := styles.CurrentTheme()
 90	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 91	apiKeyInput := NewAPIKeyInput()
 92	apiKeyInput.SetShowTitle(false)
 93	help := help.New()
 94	help.Styles = t.S().Help
 95
 96	return &modelDialogCmp{
 97		modelList:   modelList,
 98		apiKeyInput: apiKeyInput,
 99		width:       defaultWidth,
100		keyMap:      DefaultKeyMap(),
101		help:        help,
102
103		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104		claudeOAuth2:            claude.NewOAuth2(),
105	}
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109	return tea.Batch(
110		m.modelList.Init(),
111		m.apiKeyInput.Init(),
112		m.claudeAuthMethodChooser.Init(),
113		m.claudeOAuth2.Init(),
114	)
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118	switch msg := msg.(type) {
119	case tea.WindowSizeMsg:
120		m.wWidth = msg.Width
121		m.wHeight = msg.Height
122		m.apiKeyInput.SetWidth(m.width - 2)
123		m.help.SetWidth(m.width - 2)
124		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126	case APIKeyStateChangeMsg:
127		u, cmd := m.apiKeyInput.Update(msg)
128		m.apiKeyInput = u.(*APIKeyInput)
129		return m, cmd
130	case claude.ValidationCompletedMsg:
131		var cmds []tea.Cmd
132		u, cmd := m.claudeOAuth2.Update(msg)
133		m.claudeOAuth2 = u.(*claude.OAuth2)
134		cmds = append(cmds, cmd)
135
136		if msg.State == claude.OAuthValidationStateValid {
137			cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138			m.keyMap.isClaudeOAuthHelpComplete = true
139		}
140
141		return m, tea.Batch(cmds...)
142	case claude.AuthenticationCompleteMsg:
143		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144	case tea.KeyPressMsg:
145		switch {
146		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147			return m, tea.Sequence(
148				tea.SetClipboard(m.claudeOAuth2.URL),
149				func() tea.Msg {
150					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
151					return nil
152				},
153				util.ReportInfo("URL copied to clipboard"),
154			)
155		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156			m.claudeAuthMethodChooser.ToggleChoice()
157			return m, nil
158		case key.Matches(msg, m.keyMap.Select):
159			selectedItem := m.modelList.SelectedModel()
160
161			modelType := config.SelectedModelTypeLarge
162			if m.modelList.GetModelType() == SmallModelType {
163				modelType = config.SelectedModelTypeSmall
164			}
165
166			askForApiKey := func() {
167				m.keyMap.isClaudeAuthChoiseHelp = false
168				m.keyMap.isClaudeOAuthHelp = false
169				m.keyMap.isAPIKeyHelp = true
170				m.showClaudeAuthMethodChooser = false
171				m.needsAPIKey = true
172				m.selectedModel = selectedItem
173				m.selectedModelType = modelType
174				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175			}
176
177			if m.showClaudeAuthMethodChooser {
178				switch m.claudeAuthMethodChooser.State {
179				case claude.AuthMethodAPIKey:
180					askForApiKey()
181				case claude.AuthMethodOAuth2:
182					m.selectedModel = selectedItem
183					m.selectedModelType = modelType
184					m.showClaudeAuthMethodChooser = false
185					m.showClaudeOAuth2 = true
186					m.keyMap.isClaudeAuthChoiseHelp = false
187					m.keyMap.isClaudeOAuthHelp = true
188				}
189				return m, nil
190			}
191			if m.showClaudeOAuth2 {
192				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193				m.claudeOAuth2 = m2.(*claude.OAuth2)
194				return m, cmd2
195			}
196			if m.isAPIKeyValid {
197				return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198			}
199			if m.needsAPIKey {
200				// Handle API key submission
201				m.apiKeyValue = m.apiKeyInput.Value()
202				provider, err := m.getProvider(m.selectedModel.Provider.ID)
203				if err != nil || provider == nil {
204					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205				}
206				providerConfig := config.ProviderConfig{
207					ID:      string(m.selectedModel.Provider.ID),
208					Name:    m.selectedModel.Provider.Name,
209					APIKey:  m.apiKeyValue,
210					Type:    provider.Type,
211					BaseURL: provider.APIEndpoint,
212				}
213				return m, tea.Sequence(
214					util.CmdHandler(APIKeyStateChangeMsg{
215						State: APIKeyInputStateVerifying,
216					}),
217					func() tea.Msg {
218						start := time.Now()
219						err := providerConfig.TestConnection(config.Get().Resolver())
220						// intentionally wait for at least 750ms to make sure the user sees the spinner
221						elapsed := time.Since(start)
222						if elapsed < 750*time.Millisecond {
223							time.Sleep(750*time.Millisecond - elapsed)
224						}
225						if err == nil {
226							m.isAPIKeyValid = true
227							return APIKeyStateChangeMsg{
228								State: APIKeyInputStateVerified,
229							}
230						}
231						return APIKeyStateChangeMsg{
232							State: APIKeyInputStateError,
233						}
234					},
235				)
236			}
237
238			// Check if provider is configured
239			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240				return m, tea.Sequence(
241					util.CmdHandler(dialogs.CloseDialogMsg{}),
242					util.CmdHandler(ModelSelectedMsg{
243						Model: config.SelectedModel{
244							Model:           selectedItem.Model.ID,
245							Provider:        string(selectedItem.Provider.ID),
246							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
248						},
249						ModelType: modelType,
250					}),
251				)
252			} else {
253				if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254					m.showClaudeAuthMethodChooser = true
255					m.keyMap.isClaudeAuthChoiseHelp = true
256					return m, nil
257				}
258				askForApiKey()
259				return m, nil
260			}
261		case key.Matches(msg, m.keyMap.Tab):
262			switch {
263			case m.showClaudeAuthMethodChooser:
264				m.claudeAuthMethodChooser.ToggleChoice()
265				return m, nil
266			case m.needsAPIKey:
267				u, cmd := m.apiKeyInput.Update(msg)
268				m.apiKeyInput = u.(*APIKeyInput)
269				return m, cmd
270			case m.modelList.GetModelType() == LargeModelType:
271				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272				return m, m.modelList.SetModelType(SmallModelType)
273			default:
274				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275				return m, m.modelList.SetModelType(LargeModelType)
276			}
277		case key.Matches(msg, m.keyMap.Edit):
278			// Only handle edit key in the main model selection view
279			if !m.showClaudeAuthMethodChooser && !m.showClaudeOAuth2 && !m.needsAPIKey {
280				selectedItem := m.modelList.SelectedModel()
281				if selectedItem != nil {
282					// Check if provider is configured
283					if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284						// Trigger API key editing flow
285						m.keyMap.isClaudeAuthChoiseHelp = false
286						m.keyMap.isClaudeOAuthHelp = false
287						m.keyMap.isAPIKeyHelp = true
288						m.showClaudeAuthMethodChooser = false
289						m.needsAPIKey = true
290						m.selectedModel = selectedItem
291						m.selectedModelType = config.SelectedModelTypeLarge
292						if m.modelList.GetModelType() == SmallModelType {
293							m.selectedModelType = config.SelectedModelTypeSmall
294						}
295						m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
296						// Pre-fill with existing API key if available
297						if providerConfig, ok := config.Get().Providers.Get(string(selectedItem.Provider.ID)); ok && providerConfig.APIKey != "" {
298							m.apiKeyInput.input.SetValue(providerConfig.APIKey)
299						}
300						return m, nil
301					}
302				}
303			}
304			return m, nil
305		case key.Matches(msg, m.keyMap.Close):
306			if m.showClaudeAuthMethodChooser {
307				m.claudeAuthMethodChooser.SetDefaults()
308				m.showClaudeAuthMethodChooser = false
309				m.keyMap.isClaudeAuthChoiseHelp = false
310				m.keyMap.isClaudeOAuthHelp = false
311				return m, nil
312			}
313			if m.needsAPIKey {
314				if m.isAPIKeyValid {
315					return m, nil
316				}
317				// Go back to model selection
318				m.needsAPIKey = false
319				m.selectedModel = nil
320				m.isAPIKeyValid = false
321				m.apiKeyValue = ""
322				m.apiKeyInput.Reset()
323				m.keyMap.isAPIKeyHelp = false
324				return m, nil
325			}
326			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
327		default:
328			if m.showClaudeAuthMethodChooser {
329				u, cmd := m.claudeAuthMethodChooser.Update(msg)
330				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
331				return m, cmd
332			} else if m.showClaudeOAuth2 {
333				u, cmd := m.claudeOAuth2.Update(msg)
334				m.claudeOAuth2 = u.(*claude.OAuth2)
335				return m, cmd
336			} else if m.needsAPIKey {
337				u, cmd := m.apiKeyInput.Update(msg)
338				m.apiKeyInput = u.(*APIKeyInput)
339				return m, cmd
340			} else {
341				u, cmd := m.modelList.Update(msg)
342				m.modelList = u
343				return m, cmd
344			}
345		}
346	case tea.PasteMsg:
347		if m.showClaudeOAuth2 {
348			u, cmd := m.claudeOAuth2.Update(msg)
349			m.claudeOAuth2 = u.(*claude.OAuth2)
350			return m, cmd
351		} else if m.needsAPIKey {
352			u, cmd := m.apiKeyInput.Update(msg)
353			m.apiKeyInput = u.(*APIKeyInput)
354			return m, cmd
355		} else {
356			var cmd tea.Cmd
357			m.modelList, cmd = m.modelList.Update(msg)
358			return m, cmd
359		}
360	case spinner.TickMsg:
361		if m.showClaudeOAuth2 {
362			u, cmd := m.claudeOAuth2.Update(msg)
363			m.claudeOAuth2 = u.(*claude.OAuth2)
364			return m, cmd
365		} else {
366			u, cmd := m.apiKeyInput.Update(msg)
367			m.apiKeyInput = u.(*APIKeyInput)
368			return m, cmd
369		}
370	}
371	return m, nil
372}
373
374func (m *modelDialogCmp) View() string {
375	t := styles.CurrentTheme()
376
377	switch {
378	case m.showClaudeAuthMethodChooser:
379		chooserView := m.claudeAuthMethodChooser.View()
380		content := lipgloss.JoinVertical(
381			lipgloss.Left,
382			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
383			chooserView,
384			"",
385			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
386		)
387		return m.style().Render(content)
388	case m.showClaudeOAuth2:
389		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
390		oauth2View := m.claudeOAuth2.View()
391		content := lipgloss.JoinVertical(
392			lipgloss.Left,
393			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
394			oauth2View,
395			"",
396			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
397		)
398		return m.style().Render(content)
399	case m.needsAPIKey:
400		// Show API key input
401		m.keyMap.isAPIKeyHelp = true
402		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
403		apiKeyView := m.apiKeyInput.View()
404		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
405		content := lipgloss.JoinVertical(
406			lipgloss.Left,
407			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
408			apiKeyView,
409			"",
410			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
411		)
412		return m.style().Render(content)
413	}
414
415	// Show model selection
416	listView := m.modelList.View()
417	radio := m.modelTypeRadio()
418	content := lipgloss.JoinVertical(
419		lipgloss.Left,
420		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
421		listView,
422		"",
423		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
424	)
425	return m.style().Render(content)
426}
427
428func (m *modelDialogCmp) Cursor() *tea.Cursor {
429	if m.showClaudeAuthMethodChooser {
430		return nil
431	}
432	if m.showClaudeOAuth2 {
433		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
434			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
435			return m.moveCursor(cursor)
436		}
437		return nil
438	}
439	if m.needsAPIKey {
440		cursor := m.apiKeyInput.Cursor()
441		if cursor != nil {
442			cursor = m.moveCursor(cursor)
443			return cursor
444		}
445	} else {
446		cursor := m.modelList.Cursor()
447		if cursor != nil {
448			cursor = m.moveCursor(cursor)
449			return cursor
450		}
451	}
452	return nil
453}
454
455func (m *modelDialogCmp) style() lipgloss.Style {
456	t := styles.CurrentTheme()
457	return t.S().Base.
458		Width(m.width).
459		Border(lipgloss.RoundedBorder()).
460		BorderForeground(t.BorderFocus)
461}
462
463func (m *modelDialogCmp) listWidth() int {
464	return m.width - 2
465}
466
467func (m *modelDialogCmp) listHeight() int {
468	return m.wHeight / 2
469}
470
471func (m *modelDialogCmp) Position() (int, int) {
472	row := m.wHeight/4 - 2 // just a bit above the center
473	col := m.wWidth / 2
474	col -= m.width / 2
475	return row, col
476}
477
478func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
479	row, col := m.Position()
480	if m.needsAPIKey {
481		offset := row + 3 // Border + title + API key input offset
482		cursor.Y += offset
483		cursor.X = cursor.X + col + 2
484	} else {
485		offset := row + 3 // Border + title
486		cursor.Y += offset
487		cursor.X = cursor.X + col + 2
488	}
489	return cursor
490}
491
492func (m *modelDialogCmp) ID() dialogs.DialogID {
493	return ModelsDialogID
494}
495
496func (m *modelDialogCmp) modelTypeRadio() string {
497	t := styles.CurrentTheme()
498	choices := []string{"Large Task", "Small Task"}
499	iconSelected := "◉"
500	iconUnselected := "○"
501	if m.modelList.GetModelType() == LargeModelType {
502		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
503	}
504	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
505}
506
507func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
508	cfg := config.Get()
509	if _, ok := cfg.Providers.Get(providerID); ok {
510		return true
511	}
512	return false
513}
514
515func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
516	cfg := config.Get()
517	providers, err := config.Providers(cfg)
518	if err != nil {
519		return nil, err
520	}
521	for _, p := range providers {
522		if p.ID == providerID {
523			return &p, nil
524		}
525	}
526	return nil, nil
527}
528
529func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
530	if m.selectedModel == nil {
531		return util.ReportError(fmt.Errorf("no model selected"))
532	}
533
534	cfg := config.Get()
535	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
536	if err != nil {
537		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
538	}
539
540	// Reset API key state and continue with model selection
541	selectedModel := *m.selectedModel
542	var cmds []tea.Cmd
543	if close {
544		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
545	}
546	cmds = append(
547		cmds,
548		util.CmdHandler(ModelSelectedMsg{
549			Model: config.SelectedModel{
550				Model:           selectedModel.Model.ID,
551				Provider:        string(selectedModel.Provider.ID),
552				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
553				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
554			},
555			ModelType: m.selectedModelType,
556		}),
557	)
558	return tea.Sequence(cmds...)
559}