models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/tui/components/core"
 17	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 20	"github.com/charmbracelet/crush/internal/tui/exp/list"
 21	"github.com/charmbracelet/crush/internal/tui/styles"
 22	"github.com/charmbracelet/crush/internal/tui/util"
 23)
 24
 25const (
 26	ModelsDialogID dialogs.DialogID = "models"
 27
 28	defaultWidth = 60
 29)
 30
 31const (
 32	LargeModelType int = iota
 33	SmallModelType
 34
 35	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 36	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 37)
 38
// ModelSelectedMsg is sent when a model is selected. The model dialog emits
// it right after a dialogs.CloseDialogMsg, either when the chosen model's
// provider is already configured or once its credentials have been saved.
type ModelSelectedMsg struct {
	// Model is the chosen model ID and provider, along with the model's
	// default reasoning effort and max tokens.
	Model     config.SelectedModel
	// ModelType says which slot (large or small) the model was chosen for.
	ModelType config.SelectedModelType
}
 44
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file emits or handles it; the dialog closes
// itself with dialogs.CloseDialogMsg. Confirm whether it is still used.
type CloseModelDialogMsg struct{}
 47
// ModelDialog is the interface of the model selection dialog. It adds nothing
// to dialogs.DialogModel; NewModelDialogCmp returns the concrete
// implementation behind it.
type ModelDialog interface {
	dialogs.DialogModel
}
 52
// ModelOption pairs a catwalk model with the provider that serves it. The
// model list's SelectedModel returns a pointer to one for the highlighted
// entry.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 57
// modelDialogCmp is the ModelDialog implementation. It renders one of four
// views: the model list, an API key prompt, or the Hyper or GitHub Copilot
// device-flow screen.
type modelDialogCmp struct {
	width   int // dialog width; NewModelDialogCmp sets it to defaultWidth
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap // View toggles its mode flags to choose the help shown
	help      help.Model

	// API key state
	needsAPIKey       bool // true while the API key prompt is showing
	apiKeyInput       *APIKeyInput
	selectedModel     *ModelOption // model awaiting credentials (API key or device-flow token); nil when none
	selectedModelType config.SelectedModelType
	isAPIKeyValid     bool   // the submitted key passed TestConnection
	apiKeyValue       string // the key most recently submitted from the prompt

	// Hyper device flow state
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Copilot device flow state
	copilotDeviceFlow     *copilot.DeviceFlow
	showCopilotDeviceFlow bool
}
 83
 84func NewModelDialogCmp() ModelDialog {
 85	keyMap := DefaultKeyMap()
 86
 87	listKeyMap := list.DefaultKeyMap()
 88	listKeyMap.Down.SetEnabled(false)
 89	listKeyMap.Up.SetEnabled(false)
 90	listKeyMap.DownOneItem = keyMap.Next
 91	listKeyMap.UpOneItem = keyMap.Previous
 92
 93	t := styles.CurrentTheme()
 94	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 95	apiKeyInput := NewAPIKeyInput()
 96	apiKeyInput.SetShowTitle(false)
 97	help := help.New()
 98	help.Styles = t.S().Help
 99
100	return &modelDialogCmp{
101		modelList:   modelList,
102		apiKeyInput: apiKeyInput,
103		width:       defaultWidth,
104		keyMap:      DefaultKeyMap(),
105		help:        help,
106	}
107}
108
109func (m *modelDialogCmp) Init() tea.Cmd {
110	return tea.Batch(
111		m.modelList.Init(),
112		m.apiKeyInput.Init(),
113	)
114}
115
116func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
117	switch msg := msg.(type) {
118	case tea.WindowSizeMsg:
119		m.wWidth = msg.Width
120		m.wHeight = msg.Height
121		m.apiKeyInput.SetWidth(m.width - 2)
122		m.help.SetWidth(m.width - 2)
123		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
124	case APIKeyStateChangeMsg:
125		u, cmd := m.apiKeyInput.Update(msg)
126		m.apiKeyInput = u.(*APIKeyInput)
127		return m, cmd
128	case hyper.DeviceFlowCompletedMsg:
129		return m, m.saveOauthTokenAndContinue(msg.Token, true)
130	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
131		if m.hyperDeviceFlow != nil {
132			u, cmd := m.hyperDeviceFlow.Update(msg)
133			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
134			return m, cmd
135		}
136		return m, nil
137	case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
138		if m.copilotDeviceFlow != nil {
139			u, cmd := m.copilotDeviceFlow.Update(msg)
140			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
141			return m, cmd
142		}
143		return m, nil
144	case copilot.DeviceFlowCompletedMsg:
145		return m, m.saveOauthTokenAndContinue(msg.Token, true)
146	case tea.KeyPressMsg:
147		switch {
148		// Handle Hyper device flow keys
149		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
150			return m, m.hyperDeviceFlow.CopyCode()
151		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
152			return m, m.copilotDeviceFlow.CopyCode()
153		case key.Matches(msg, m.keyMap.Select):
154			// If showing device flow, enter copies code and opens URL
155			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
156				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
157			}
158			if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
159				return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
160			}
161			selectedItem := m.modelList.SelectedModel()
162			if selectedItem == nil {
163				return m, nil
164			}
165
166			modelType := config.SelectedModelTypeLarge
167			if m.modelList.GetModelType() == SmallModelType {
168				modelType = config.SelectedModelTypeSmall
169			}
170
171			askForApiKey := func() {
172				m.keyMap.isAPIKeyHelp = true
173				m.showHyperDeviceFlow = false
174				m.showCopilotDeviceFlow = false
175				m.needsAPIKey = true
176				m.selectedModel = selectedItem
177				m.selectedModelType = modelType
178				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
179			}
180
181			if m.isAPIKeyValid {
182				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
183			}
184			if m.needsAPIKey {
185				// Handle API key submission
186				m.apiKeyValue = m.apiKeyInput.Value()
187				provider, err := m.getProvider(m.selectedModel.Provider.ID)
188				if err != nil || provider == nil {
189					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
190				}
191				providerConfig := config.ProviderConfig{
192					ID:      string(m.selectedModel.Provider.ID),
193					Name:    m.selectedModel.Provider.Name,
194					APIKey:  m.apiKeyValue,
195					Type:    provider.Type,
196					BaseURL: provider.APIEndpoint,
197				}
198				return m, tea.Sequence(
199					util.CmdHandler(APIKeyStateChangeMsg{
200						State: APIKeyInputStateVerifying,
201					}),
202					func() tea.Msg {
203						start := time.Now()
204						err := providerConfig.TestConnection(config.Get().Resolver())
205						// intentionally wait for at least 750ms to make sure the user sees the spinner
206						elapsed := time.Since(start)
207						if elapsed < 750*time.Millisecond {
208							time.Sleep(750*time.Millisecond - elapsed)
209						}
210						if err == nil {
211							m.isAPIKeyValid = true
212							return APIKeyStateChangeMsg{
213								State: APIKeyInputStateVerified,
214							}
215						}
216						return APIKeyStateChangeMsg{
217							State: APIKeyInputStateError,
218						}
219					},
220				)
221			}
222
223			// Check if provider is configured
224			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
225				return m, tea.Sequence(
226					util.CmdHandler(dialogs.CloseDialogMsg{}),
227					util.CmdHandler(ModelSelectedMsg{
228						Model: config.SelectedModel{
229							Model:           selectedItem.Model.ID,
230							Provider:        string(selectedItem.Provider.ID),
231							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
232							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
233						},
234						ModelType: modelType,
235					}),
236				)
237			}
238			switch selectedItem.Provider.ID {
239			case hyperp.Name:
240				m.showHyperDeviceFlow = true
241				m.selectedModel = selectedItem
242				m.selectedModelType = modelType
243				m.hyperDeviceFlow = hyper.NewDeviceFlow()
244				m.hyperDeviceFlow.SetWidth(m.width - 2)
245				return m, m.hyperDeviceFlow.Init()
246			case catwalk.InferenceProviderCopilot:
247				if token, ok := config.Get().ImportCopilot(); ok {
248					m.selectedModel = selectedItem
249					m.selectedModelType = modelType
250					return m, m.saveOauthTokenAndContinue(token, true)
251				}
252				m.showCopilotDeviceFlow = true
253				m.selectedModel = selectedItem
254				m.selectedModelType = modelType
255				m.copilotDeviceFlow = copilot.NewDeviceFlow()
256				m.copilotDeviceFlow.SetWidth(m.width - 2)
257				return m, m.copilotDeviceFlow.Init()
258			}
259			// For other providers, show API key input
260			askForApiKey()
261			return m, nil
262		case key.Matches(msg, m.keyMap.Tab):
263			switch {
264			case m.needsAPIKey:
265				u, cmd := m.apiKeyInput.Update(msg)
266				m.apiKeyInput = u.(*APIKeyInput)
267				return m, cmd
268			case m.modelList.GetModelType() == LargeModelType:
269				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
270				return m, m.modelList.SetModelType(SmallModelType)
271			default:
272				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
273				return m, m.modelList.SetModelType(LargeModelType)
274			}
275		case key.Matches(msg, m.keyMap.Close):
276			switch {
277			case m.showHyperDeviceFlow:
278				if m.hyperDeviceFlow != nil {
279					m.hyperDeviceFlow.Cancel()
280				}
281				m.showHyperDeviceFlow = false
282				m.selectedModel = nil
283			case m.showCopilotDeviceFlow:
284				if m.copilotDeviceFlow != nil {
285					m.copilotDeviceFlow.Cancel()
286				}
287				m.showCopilotDeviceFlow = false
288				m.selectedModel = nil
289			case m.needsAPIKey:
290				if m.isAPIKeyValid {
291					return m, nil
292				}
293				// Go back to model selection
294				m.needsAPIKey = false
295				m.selectedModel = nil
296				m.isAPIKeyValid = false
297				m.apiKeyValue = ""
298				m.apiKeyInput.Reset()
299				return m, nil
300			default:
301				return m, util.CmdHandler(dialogs.CloseDialogMsg{})
302			}
303		default:
304			switch {
305			case m.needsAPIKey:
306				u, cmd := m.apiKeyInput.Update(msg)
307				m.apiKeyInput = u.(*APIKeyInput)
308				return m, cmd
309			default:
310				u, cmd := m.modelList.Update(msg)
311				m.modelList = u
312				return m, cmd
313			}
314		}
315	case tea.PasteMsg:
316		switch {
317		case m.needsAPIKey:
318			u, cmd := m.apiKeyInput.Update(msg)
319			m.apiKeyInput = u.(*APIKeyInput)
320			return m, cmd
321		default:
322			var cmd tea.Cmd
323			m.modelList, cmd = m.modelList.Update(msg)
324			return m, cmd
325		}
326	case spinner.TickMsg:
327		u, cmd := m.apiKeyInput.Update(msg)
328		m.apiKeyInput = u.(*APIKeyInput)
329		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
330			u, cmd = m.hyperDeviceFlow.Update(msg)
331			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
332		}
333		if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
334			u, cmd = m.copilotDeviceFlow.Update(msg)
335			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
336		}
337		return m, cmd
338	default:
339		// Pass all other messages to the device flow for spinner animation
340		switch {
341		case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
342			u, cmd := m.hyperDeviceFlow.Update(msg)
343			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
344			return m, cmd
345		case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
346			u, cmd := m.copilotDeviceFlow.Update(msg)
347			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
348			return m, cmd
349		default:
350			u, cmd := m.apiKeyInput.Update(msg)
351			m.apiKeyInput = u.(*APIKeyInput)
352			return m, cmd
353		}
354	}
355	return m, nil
356}
357
358func (m *modelDialogCmp) View() string {
359	t := styles.CurrentTheme()
360
361	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
362		// Show Hyper device flow
363		m.keyMap.isHyperDeviceFlow = true
364		deviceFlowView := m.hyperDeviceFlow.View()
365		content := lipgloss.JoinVertical(
366			lipgloss.Left,
367			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
368			deviceFlowView,
369			"",
370			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
371		)
372		return m.style().Render(content)
373	}
374	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
375		// Show Hyper device flow
376		m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
377		m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
378		deviceFlowView := m.copilotDeviceFlow.View()
379		content := lipgloss.JoinVertical(
380			lipgloss.Left,
381			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
382			deviceFlowView,
383			"",
384			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
385		)
386		return m.style().Render(content)
387	}
388
389	// Reset the flags when not showing device flow
390	m.keyMap.isHyperDeviceFlow = false
391	m.keyMap.isCopilotDeviceFlow = false
392	m.keyMap.isCopilotUnavailable = false
393
394	switch {
395	case m.needsAPIKey:
396		// Show API key input
397		m.keyMap.isAPIKeyHelp = true
398		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
399		apiKeyView := m.apiKeyInput.View()
400		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
401		content := lipgloss.JoinVertical(
402			lipgloss.Left,
403			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
404			apiKeyView,
405			"",
406			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
407		)
408		return m.style().Render(content)
409	}
410
411	// Show model selection
412	listView := m.modelList.View()
413	radio := m.modelTypeRadio()
414	content := lipgloss.JoinVertical(
415		lipgloss.Left,
416		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
417		listView,
418		"",
419		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
420	)
421	return m.style().Render(content)
422}
423
424func (m *modelDialogCmp) Cursor() *tea.Cursor {
425	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
426		return m.hyperDeviceFlow.Cursor()
427	}
428	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
429		return m.copilotDeviceFlow.Cursor()
430	}
431	if m.needsAPIKey {
432		cursor := m.apiKeyInput.Cursor()
433		if cursor != nil {
434			cursor = m.moveCursor(cursor)
435			return cursor
436		}
437	} else {
438		cursor := m.modelList.Cursor()
439		if cursor != nil {
440			cursor = m.moveCursor(cursor)
441			return cursor
442		}
443	}
444	return nil
445}
446
447func (m *modelDialogCmp) style() lipgloss.Style {
448	t := styles.CurrentTheme()
449	return t.S().Base.
450		Width(m.width).
451		Border(lipgloss.RoundedBorder()).
452		BorderForeground(t.BorderFocus)
453}
454
455func (m *modelDialogCmp) listWidth() int {
456	return m.width - 2
457}
458
459func (m *modelDialogCmp) listHeight() int {
460	return m.wHeight / 2
461}
462
463func (m *modelDialogCmp) Position() (int, int) {
464	row := m.wHeight/4 - 2 // just a bit above the center
465	col := m.wWidth / 2
466	col -= m.width / 2
467	return row, col
468}
469
470func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
471	row, col := m.Position()
472	if m.needsAPIKey {
473		offset := row + 3 // Border + title + API key input offset
474		cursor.Y += offset
475		cursor.X = cursor.X + col + 2
476	} else {
477		offset := row + 3 // Border + title
478		cursor.Y += offset
479		cursor.X = cursor.X + col + 2
480	}
481	return cursor
482}
483
484func (m *modelDialogCmp) ID() dialogs.DialogID {
485	return ModelsDialogID
486}
487
488func (m *modelDialogCmp) modelTypeRadio() string {
489	t := styles.CurrentTheme()
490	choices := []string{"Large Task", "Small Task"}
491	iconSelected := "◉"
492	iconUnselected := "○"
493	if m.modelList.GetModelType() == LargeModelType {
494		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
495	}
496	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
497}
498
499func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
500	cfg := config.Get()
501	_, ok := cfg.Providers.Get(providerID)
502	return ok
503}
504
505func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
506	cfg := config.Get()
507	providers, err := config.Providers(cfg)
508	if err != nil {
509		return nil, err
510	}
511	for _, p := range providers {
512		if p.ID == providerID {
513			return &p, nil
514		}
515	}
516	return nil, nil
517}
518
519func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
520	if m.selectedModel == nil {
521		return util.ReportError(fmt.Errorf("no model selected"))
522	}
523
524	cfg := config.Get()
525	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
526	if err != nil {
527		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
528	}
529
530	// Reset API key state and continue with model selection
531	selectedModel := *m.selectedModel
532	var cmds []tea.Cmd
533	if close {
534		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
535	}
536	cmds = append(
537		cmds,
538		util.CmdHandler(ModelSelectedMsg{
539			Model: config.SelectedModel{
540				Model:           selectedModel.Model.ID,
541				Provider:        string(selectedModel.Provider.ID),
542				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
543				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
544			},
545			ModelType: m.selectedModelType,
546		}),
547	)
548	return tea.Sequence(cmds...)
549}