models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"charm.land/bubbles/v2/help"
  8	"charm.land/bubbles/v2/key"
  9	"charm.land/bubbles/v2/spinner"
 10	tea "charm.land/bubbletea/v2"
 11	"charm.land/lipgloss/v2"
 12	"github.com/atotto/clipboard"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 18	"github.com/charmbracelet/crush/internal/tui/exp/list"
 19	"github.com/charmbracelet/crush/internal/tui/styles"
 20	"github.com/charmbracelet/crush/internal/tui/util"
 21)
 22
 23const (
 24	ModelsDialogID dialogs.DialogID = "models"
 25
 26	defaultWidth = 60
 27)
 28
 29const (
 30	LargeModelType int = iota
 31	SmallModelType
 32
 33	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 34	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 35)
 36
 37// ModelSelectedMsg is sent when a model is selected
 38type ModelSelectedMsg struct {
 39	Model     config.SelectedModel
 40	ModelType config.SelectedModelType
 41}
 42
 43// CloseModelDialogMsg is sent when a model is selected
 44type CloseModelDialogMsg struct{}
 45
 46// ModelDialog interface for the model selection dialog
 47type ModelDialog interface {
 48	dialogs.DialogModel
 49}
 50
 51type ModelOption struct {
 52	Provider catwalk.Provider
 53	Model    catwalk.Model
 54}
 55
 56type modelDialogCmp struct {
 57	width   int
 58	wWidth  int
 59	wHeight int
 60
 61	modelList *ModelListComponent
 62	keyMap    KeyMap
 63	help      help.Model
 64
 65	// API key state
 66	needsAPIKey       bool
 67	apiKeyInput       *APIKeyInput
 68	selectedModel     *ModelOption
 69	selectedModelType config.SelectedModelType
 70	isAPIKeyValid     bool
 71	apiKeyValue       string
 72
 73	// Claude state
 74	claudeAuthMethodChooser     *claude.AuthMethodChooser
 75	claudeOAuth2                *claude.OAuth2
 76	showClaudeAuthMethodChooser bool
 77	showClaudeOAuth2            bool
 78}
 79
 80func NewModelDialogCmp() ModelDialog {
 81	keyMap := DefaultKeyMap()
 82
 83	listKeyMap := list.DefaultKeyMap()
 84	listKeyMap.Down.SetEnabled(false)
 85	listKeyMap.Up.SetEnabled(false)
 86	listKeyMap.DownOneItem = keyMap.Next
 87	listKeyMap.UpOneItem = keyMap.Previous
 88
 89	t := styles.CurrentTheme()
 90	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 91	apiKeyInput := NewAPIKeyInput()
 92	apiKeyInput.SetShowTitle(false)
 93	help := help.New()
 94	help.Styles = t.S().Help
 95
 96	return &modelDialogCmp{
 97		modelList:   modelList,
 98		apiKeyInput: apiKeyInput,
 99		width:       defaultWidth,
100		keyMap:      DefaultKeyMap(),
101		help:        help,
102
103		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104		claudeOAuth2:            claude.NewOAuth2(),
105	}
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109	return tea.Batch(
110		m.modelList.Init(),
111		m.apiKeyInput.Init(),
112		m.claudeAuthMethodChooser.Init(),
113		m.claudeOAuth2.Init(),
114	)
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118	switch msg := msg.(type) {
119	case tea.WindowSizeMsg:
120		m.wWidth = msg.Width
121		m.wHeight = msg.Height
122		m.apiKeyInput.SetWidth(m.width - 2)
123		m.help.SetWidth(m.width - 2)
124		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126	case APIKeyStateChangeMsg:
127		u, cmd := m.apiKeyInput.Update(msg)
128		m.apiKeyInput = u.(*APIKeyInput)
129		return m, cmd
130	case claude.ValidationCompletedMsg:
131		var cmds []tea.Cmd
132		u, cmd := m.claudeOAuth2.Update(msg)
133		m.claudeOAuth2 = u.(*claude.OAuth2)
134		cmds = append(cmds, cmd)
135
136		if msg.State == claude.OAuthValidationStateValid {
137			cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138			m.keyMap.isClaudeOAuthHelpComplete = true
139		}
140
141		return m, tea.Batch(cmds...)
142	case claude.AuthenticationCompleteMsg:
143		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144	case tea.KeyPressMsg:
145		switch {
146		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147			return m, tea.Sequence(
148				tea.SetClipboard(m.claudeOAuth2.URL),
149				func() tea.Msg {
150					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
151					return nil
152				},
153				util.ReportInfo("URL copied to clipboard"),
154			)
155		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156			m.claudeAuthMethodChooser.ToggleChoice()
157			return m, nil
158		case key.Matches(msg, m.keyMap.Select):
159			selectedItem := m.modelList.SelectedModel()
160
161			modelType := config.SelectedModelTypeLarge
162			if m.modelList.GetModelType() == SmallModelType {
163				modelType = config.SelectedModelTypeSmall
164			}
165
166			askForApiKey := func() {
167				m.keyMap.isClaudeAuthChoiseHelp = false
168				m.keyMap.isClaudeOAuthHelp = false
169				m.keyMap.isAPIKeyHelp = true
170				m.showClaudeAuthMethodChooser = false
171				m.needsAPIKey = true
172				m.selectedModel = selectedItem
173				m.selectedModelType = modelType
174				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175			}
176
177			if m.showClaudeAuthMethodChooser {
178				switch m.claudeAuthMethodChooser.State {
179				case claude.AuthMethodAPIKey:
180					askForApiKey()
181				case claude.AuthMethodOAuth2:
182					m.selectedModel = selectedItem
183					m.selectedModelType = modelType
184					m.showClaudeAuthMethodChooser = false
185					m.showClaudeOAuth2 = true
186					m.keyMap.isClaudeAuthChoiseHelp = false
187					m.keyMap.isClaudeOAuthHelp = true
188				}
189				return m, nil
190			}
191			if m.showClaudeOAuth2 {
192				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193				m.claudeOAuth2 = m2.(*claude.OAuth2)
194				return m, cmd2
195			}
196			if m.isAPIKeyValid {
197				return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198			}
199			if m.needsAPIKey {
200				// Handle API key submission
201				m.apiKeyValue = m.apiKeyInput.Value()
202				provider, err := m.getProvider(m.selectedModel.Provider.ID)
203				if err != nil || provider == nil {
204					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205				}
206				providerConfig := config.ProviderConfig{
207					ID:      string(m.selectedModel.Provider.ID),
208					Name:    m.selectedModel.Provider.Name,
209					APIKey:  m.apiKeyValue,
210					Type:    provider.Type,
211					BaseURL: provider.APIEndpoint,
212				}
213				return m, tea.Sequence(
214					util.CmdHandler(APIKeyStateChangeMsg{
215						State: APIKeyInputStateVerifying,
216					}),
217					func() tea.Msg {
218						start := time.Now()
219						err := providerConfig.TestConnection(config.Get().Resolver())
220						// intentionally wait for at least 750ms to make sure the user sees the spinner
221						elapsed := time.Since(start)
222						if elapsed < 750*time.Millisecond {
223							time.Sleep(750*time.Millisecond - elapsed)
224						}
225						if err == nil {
226							m.isAPIKeyValid = true
227							return APIKeyStateChangeMsg{
228								State: APIKeyInputStateVerified,
229							}
230						}
231						return APIKeyStateChangeMsg{
232							State: APIKeyInputStateError,
233						}
234					},
235				)
236			}
237
238			// Check if provider is configured
239			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240				return m, tea.Sequence(
241					util.CmdHandler(dialogs.CloseDialogMsg{}),
242					util.CmdHandler(ModelSelectedMsg{
243						Model: config.SelectedModel{
244							Model:           selectedItem.Model.ID,
245							Provider:        string(selectedItem.Provider.ID),
246							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
248						},
249						ModelType: modelType,
250					}),
251				)
252			} else {
253				if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254					m.showClaudeAuthMethodChooser = true
255					m.keyMap.isClaudeAuthChoiseHelp = true
256					return m, nil
257				}
258				askForApiKey()
259				return m, nil
260			}
261		case key.Matches(msg, m.keyMap.Tab):
262			switch {
263			case m.showClaudeAuthMethodChooser:
264				m.claudeAuthMethodChooser.ToggleChoice()
265				return m, nil
266			case m.needsAPIKey:
267				u, cmd := m.apiKeyInput.Update(msg)
268				m.apiKeyInput = u.(*APIKeyInput)
269				return m, cmd
270			case m.modelList.GetModelType() == LargeModelType:
271				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272				return m, m.modelList.SetModelType(SmallModelType)
273			default:
274				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275				return m, m.modelList.SetModelType(LargeModelType)
276			}
277		case key.Matches(msg, m.keyMap.Close):
278			if m.showClaudeAuthMethodChooser {
279				m.claudeAuthMethodChooser.SetDefaults()
280				m.showClaudeAuthMethodChooser = false
281				m.keyMap.isClaudeAuthChoiseHelp = false
282				m.keyMap.isClaudeOAuthHelp = false
283				return m, nil
284			}
285			if m.needsAPIKey {
286				if m.isAPIKeyValid {
287					return m, nil
288				}
289				// Go back to model selection
290				m.needsAPIKey = false
291				m.selectedModel = nil
292				m.isAPIKeyValid = false
293				m.apiKeyValue = ""
294				m.apiKeyInput.Reset()
295				return m, nil
296			}
297			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
298		default:
299			if m.showClaudeAuthMethodChooser {
300				u, cmd := m.claudeAuthMethodChooser.Update(msg)
301				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
302				return m, cmd
303			} else if m.showClaudeOAuth2 {
304				u, cmd := m.claudeOAuth2.Update(msg)
305				m.claudeOAuth2 = u.(*claude.OAuth2)
306				return m, cmd
307			} else if m.needsAPIKey {
308				u, cmd := m.apiKeyInput.Update(msg)
309				m.apiKeyInput = u.(*APIKeyInput)
310				return m, cmd
311			} else {
312				u, cmd := m.modelList.Update(msg)
313				m.modelList = u
314				return m, cmd
315			}
316		}
317	case tea.PasteMsg:
318		if m.showClaudeOAuth2 {
319			u, cmd := m.claudeOAuth2.Update(msg)
320			m.claudeOAuth2 = u.(*claude.OAuth2)
321			return m, cmd
322		} else if m.needsAPIKey {
323			u, cmd := m.apiKeyInput.Update(msg)
324			m.apiKeyInput = u.(*APIKeyInput)
325			return m, cmd
326		} else {
327			var cmd tea.Cmd
328			m.modelList, cmd = m.modelList.Update(msg)
329			return m, cmd
330		}
331	case spinner.TickMsg:
332		if m.showClaudeOAuth2 {
333			u, cmd := m.claudeOAuth2.Update(msg)
334			m.claudeOAuth2 = u.(*claude.OAuth2)
335			return m, cmd
336		} else {
337			u, cmd := m.apiKeyInput.Update(msg)
338			m.apiKeyInput = u.(*APIKeyInput)
339			return m, cmd
340		}
341	}
342	return m, nil
343}
344
345func (m *modelDialogCmp) View() string {
346	t := styles.CurrentTheme()
347
348	switch {
349	case m.showClaudeAuthMethodChooser:
350		chooserView := m.claudeAuthMethodChooser.View()
351		content := lipgloss.JoinVertical(
352			lipgloss.Left,
353			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
354			chooserView,
355			"",
356			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
357		)
358		return m.style().Render(content)
359	case m.showClaudeOAuth2:
360		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
361		oauth2View := m.claudeOAuth2.View()
362		content := lipgloss.JoinVertical(
363			lipgloss.Left,
364			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
365			oauth2View,
366			"",
367			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
368		)
369		return m.style().Render(content)
370	case m.needsAPIKey:
371		// Show API key input
372		m.keyMap.isAPIKeyHelp = true
373		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
374		apiKeyView := m.apiKeyInput.View()
375		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
376		content := lipgloss.JoinVertical(
377			lipgloss.Left,
378			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
379			apiKeyView,
380			"",
381			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
382		)
383		return m.style().Render(content)
384	}
385
386	// Show model selection
387	listView := m.modelList.View()
388	radio := m.modelTypeRadio()
389	content := lipgloss.JoinVertical(
390		lipgloss.Left,
391		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
392		listView,
393		"",
394		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
395	)
396	return m.style().Render(content)
397}
398
399func (m *modelDialogCmp) Cursor() *tea.Cursor {
400	if m.showClaudeAuthMethodChooser {
401		return nil
402	}
403	if m.showClaudeOAuth2 {
404		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
405			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
406			return m.moveCursor(cursor)
407		}
408		return nil
409	}
410	if m.needsAPIKey {
411		cursor := m.apiKeyInput.Cursor()
412		if cursor != nil {
413			cursor = m.moveCursor(cursor)
414			return cursor
415		}
416	} else {
417		cursor := m.modelList.Cursor()
418		if cursor != nil {
419			cursor = m.moveCursor(cursor)
420			return cursor
421		}
422	}
423	return nil
424}
425
426func (m *modelDialogCmp) style() lipgloss.Style {
427	t := styles.CurrentTheme()
428	return t.S().Base.
429		Width(m.width).
430		Border(lipgloss.RoundedBorder()).
431		BorderForeground(t.BorderFocus)
432}
433
434func (m *modelDialogCmp) listWidth() int {
435	return m.width - 2
436}
437
438func (m *modelDialogCmp) listHeight() int {
439	return m.wHeight / 2
440}
441
442func (m *modelDialogCmp) Position() (int, int) {
443	row := m.wHeight/4 - 2 // just a bit above the center
444	col := m.wWidth / 2
445	col -= m.width / 2
446	return row, col
447}
448
449func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
450	row, col := m.Position()
451	if m.needsAPIKey {
452		offset := row + 3 // Border + title + API key input offset
453		cursor.Y += offset
454		cursor.X = cursor.X + col + 2
455	} else {
456		offset := row + 3 // Border + title
457		cursor.Y += offset
458		cursor.X = cursor.X + col + 2
459	}
460	return cursor
461}
462
463func (m *modelDialogCmp) ID() dialogs.DialogID {
464	return ModelsDialogID
465}
466
467func (m *modelDialogCmp) modelTypeRadio() string {
468	t := styles.CurrentTheme()
469	choices := []string{"Large Task", "Small Task"}
470	iconSelected := "◉"
471	iconUnselected := "○"
472	if m.modelList.GetModelType() == LargeModelType {
473		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
474	}
475	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
476}
477
478func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
479	cfg := config.Get()
480	if _, ok := cfg.Providers.Get(providerID); ok {
481		return true
482	}
483	return false
484}
485
486func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
487	cfg := config.Get()
488	providers, err := config.Providers(cfg)
489	if err != nil {
490		return nil, err
491	}
492	for _, p := range providers {
493		if p.ID == providerID {
494			return &p, nil
495		}
496	}
497	return nil, nil
498}
499
500func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
501	if m.selectedModel == nil {
502		return util.ReportError(fmt.Errorf("no model selected"))
503	}
504
505	cfg := config.Get()
506	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
507	if err != nil {
508		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
509	}
510
511	// Reset API key state and continue with model selection
512	selectedModel := *m.selectedModel
513	var cmds []tea.Cmd
514	if close {
515		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
516	}
517	cmds = append(
518		cmds,
519		util.CmdHandler(ModelSelectedMsg{
520			Model: config.SelectedModel{
521				Model:           selectedModel.Model.ID,
522				Provider:        string(selectedModel.Provider.ID),
523				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
524				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
525			},
526			ModelType: m.selectedModelType,
527		}),
528	)
529	return tea.Sequence(cmds...)
530}