models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/tui/components/core"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/exp/list"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30
 31	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 32	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 33)
 34
// ModelSelectedMsg is sent when a model is selected. The dialog emits it
// right after dialogs.CloseDialogMsg, either directly on selection (provider
// already configured) or after a verified API key has been saved.
type ModelSelectedMsg struct {
	// Model carries the chosen model ID, provider ID and the model's default
	// reasoning effort and max tokens.
	Model config.SelectedModel
	// ModelType says whether the selection is for the large or the small model.
	ModelType config.SelectedModelType
}
 40
// CloseModelDialogMsg is a message type for closing the model dialog.
//
// NOTE(review): nothing in this file sends or handles it; the dialog closes
// itself with dialogs.CloseDialogMsg. Confirm it is still used elsewhere.
type CloseModelDialogMsg struct{}
 43
// ModelDialog is the interface returned by NewModelDialogCmp for the model
// selection dialog. It currently adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 48
// ModelOption pairs a model with the provider that serves it. It is the
// value the model list reports as the current selection.
type ModelOption struct {
	// Provider is the provider offering Model.
	Provider catwalk.Provider
	// Model is the selectable model itself.
	Model catwalk.Model
}
 53
// modelDialogCmp is the model selection dialog. It has two modes: picking a
// model from the list and, when the chosen model's provider is not yet
// configured, entering and verifying an API key for that provider.
type modelDialogCmp struct {
	// width is the dialog's own width; NewModelDialogCmp sets it to defaultWidth.
	width int
	// wWidth and wHeight are the window dimensions from the last
	// tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	// modelList is the list of selectable models.
	modelList *ModelListComponent
	// keyMap holds the dialog's key bindings and is what the help view renders.
	keyMap KeyMap
	// help renders the key binding help line at the bottom of the dialog.
	help help.Model

	// API key state
	// needsAPIKey is true while the API key input is shown instead of the list.
	needsAPIKey bool
	// apiKeyInput is the input component used to enter the provider API key.
	apiKeyInput *APIKeyInput
	// selectedModel is the model picked before switching to the API key input;
	// it is cleared when the user goes back to the list.
	selectedModel *ModelOption
	// selectedModelType is the model type (large/small) that was active when
	// selectedModel was picked.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is set once the entered key passed the connection test.
	isAPIKeyValid bool
	// apiKeyValue is the key captured from the input when it was submitted.
	apiKeyValue string

	// cfg is the configuration used to look up providers and store API keys.
	cfg *config.Config
}
 73
 74func NewModelDialogCmp(cfg *config.Config) ModelDialog {
 75	keyMap := DefaultKeyMap()
 76
 77	listKeyMap := list.DefaultKeyMap()
 78	listKeyMap.Down.SetEnabled(false)
 79	listKeyMap.Up.SetEnabled(false)
 80	listKeyMap.DownOneItem = keyMap.Next
 81	listKeyMap.UpOneItem = keyMap.Previous
 82
 83	t := styles.CurrentTheme()
 84	modelList := NewModelListComponent(cfg, listKeyMap, largeModelInputPlaceholder, true)
 85	apiKeyInput := NewAPIKeyInput()
 86	apiKeyInput.SetShowTitle(false)
 87	help := help.New()
 88	help.Styles = t.S().Help
 89
 90	return &modelDialogCmp{
 91		modelList:   modelList,
 92		apiKeyInput: apiKeyInput,
 93		width:       defaultWidth,
 94		keyMap:      DefaultKeyMap(),
 95		help:        help,
 96		cfg:         cfg,
 97	}
 98}
 99
100func (m *modelDialogCmp) Init() tea.Cmd {
101	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
102}
103
104func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
105	switch msg := msg.(type) {
106	case tea.WindowSizeMsg:
107		m.wWidth = msg.Width
108		m.wHeight = msg.Height
109		m.apiKeyInput.SetWidth(m.width - 2)
110		m.help.Width = m.width - 2
111		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
112	case APIKeyStateChangeMsg:
113		u, cmd := m.apiKeyInput.Update(msg)
114		m.apiKeyInput = u.(*APIKeyInput)
115		return m, cmd
116	case tea.KeyPressMsg:
117		switch {
118		case key.Matches(msg, m.keyMap.Select):
119			if m.isAPIKeyValid {
120				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
121			}
122			if m.needsAPIKey {
123				// Handle API key submission
124				m.apiKeyValue = m.apiKeyInput.Value()
125				provider, err := m.getProvider(m.selectedModel.Provider.ID)
126				if err != nil || provider == nil {
127					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
128				}
129				providerConfig := config.ProviderConfig{
130					ID:      string(m.selectedModel.Provider.ID),
131					Name:    m.selectedModel.Provider.Name,
132					APIKey:  m.apiKeyValue,
133					Type:    provider.Type,
134					BaseURL: provider.APIEndpoint,
135				}
136				return m, tea.Sequence(
137					util.CmdHandler(APIKeyStateChangeMsg{
138						State: APIKeyInputStateVerifying,
139					}),
140					func() tea.Msg {
141						start := time.Now()
142						err := providerConfig.TestConnection(m.cfg.Resolver())
143						// intentionally wait for at least 750ms to make sure the user sees the spinner
144						elapsed := time.Since(start)
145						if elapsed < 750*time.Millisecond {
146							time.Sleep(750*time.Millisecond - elapsed)
147						}
148						if err == nil {
149							m.isAPIKeyValid = true
150							return APIKeyStateChangeMsg{
151								State: APIKeyInputStateVerified,
152							}
153						}
154						return APIKeyStateChangeMsg{
155							State: APIKeyInputStateError,
156						}
157					},
158				)
159			}
160			// Normal model selection
161			selectedItem := m.modelList.SelectedModel()
162
163			var modelType config.SelectedModelType
164			if m.modelList.GetModelType() == LargeModelType {
165				modelType = config.SelectedModelTypeLarge
166			} else {
167				modelType = config.SelectedModelTypeSmall
168			}
169
170			// Check if provider is configured
171			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
172				return m, tea.Sequence(
173					util.CmdHandler(dialogs.CloseDialogMsg{}),
174					util.CmdHandler(ModelSelectedMsg{
175						Model: config.SelectedModel{
176							Model:           selectedItem.Model.ID,
177							Provider:        string(selectedItem.Provider.ID),
178							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
179							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
180						},
181						ModelType: modelType,
182					}),
183				)
184			} else {
185				// Provider not configured, show API key input
186				m.needsAPIKey = true
187				m.selectedModel = selectedItem
188				m.selectedModelType = modelType
189				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
190				return m, nil
191			}
192		case key.Matches(msg, m.keyMap.Tab):
193			if m.needsAPIKey {
194				u, cmd := m.apiKeyInput.Update(msg)
195				m.apiKeyInput = u.(*APIKeyInput)
196				return m, cmd
197			}
198			if m.modelList.GetModelType() == LargeModelType {
199				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
200				return m, m.modelList.SetModelType(SmallModelType)
201			} else {
202				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
203				return m, m.modelList.SetModelType(LargeModelType)
204			}
205		case key.Matches(msg, m.keyMap.Close):
206			if m.needsAPIKey {
207				if m.isAPIKeyValid {
208					return m, nil
209				}
210				// Go back to model selection
211				m.needsAPIKey = false
212				m.selectedModel = nil
213				m.isAPIKeyValid = false
214				m.apiKeyValue = ""
215				m.apiKeyInput.Reset()
216				return m, nil
217			}
218			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
219		default:
220			if m.needsAPIKey {
221				u, cmd := m.apiKeyInput.Update(msg)
222				m.apiKeyInput = u.(*APIKeyInput)
223				return m, cmd
224			} else {
225				u, cmd := m.modelList.Update(msg)
226				m.modelList = u
227				return m, cmd
228			}
229		}
230	case tea.PasteMsg:
231		if m.needsAPIKey {
232			u, cmd := m.apiKeyInput.Update(msg)
233			m.apiKeyInput = u.(*APIKeyInput)
234			return m, cmd
235		} else {
236			var cmd tea.Cmd
237			m.modelList, cmd = m.modelList.Update(msg)
238			return m, cmd
239		}
240	case spinner.TickMsg:
241		u, cmd := m.apiKeyInput.Update(msg)
242		m.apiKeyInput = u.(*APIKeyInput)
243		return m, cmd
244	}
245	return m, nil
246}
247
248func (m *modelDialogCmp) View() string {
249	t := styles.CurrentTheme()
250
251	if m.needsAPIKey {
252		// Show API key input
253		m.keyMap.isAPIKeyHelp = true
254		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
255		apiKeyView := m.apiKeyInput.View()
256		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
257		content := lipgloss.JoinVertical(
258			lipgloss.Left,
259			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
260			apiKeyView,
261			"",
262			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
263		)
264		return m.style().Render(content)
265	}
266
267	// Show model selection
268	listView := m.modelList.View()
269	radio := m.modelTypeRadio()
270	content := lipgloss.JoinVertical(
271		lipgloss.Left,
272		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
273		listView,
274		"",
275		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
276	)
277	return m.style().Render(content)
278}
279
280func (m *modelDialogCmp) Cursor() *tea.Cursor {
281	if m.needsAPIKey {
282		cursor := m.apiKeyInput.Cursor()
283		if cursor != nil {
284			cursor = m.moveCursor(cursor)
285			return cursor
286		}
287	} else {
288		cursor := m.modelList.Cursor()
289		if cursor != nil {
290			cursor = m.moveCursor(cursor)
291			return cursor
292		}
293	}
294	return nil
295}
296
297func (m *modelDialogCmp) style() lipgloss.Style {
298	t := styles.CurrentTheme()
299	return t.S().Base.
300		Width(m.width).
301		Border(lipgloss.RoundedBorder()).
302		BorderForeground(t.BorderFocus)
303}
304
305func (m *modelDialogCmp) listWidth() int {
306	return m.width - 2
307}
308
309func (m *modelDialogCmp) listHeight() int {
310	return m.wHeight / 2
311}
312
313func (m *modelDialogCmp) Position() (int, int) {
314	row := m.wHeight/4 - 2 // just a bit above the center
315	col := m.wWidth / 2
316	col -= m.width / 2
317	return row, col
318}
319
320func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
321	row, col := m.Position()
322	if m.needsAPIKey {
323		offset := row + 3 // Border + title + API key input offset
324		cursor.Y += offset
325		cursor.X = cursor.X + col + 2
326	} else {
327		offset := row + 3 // Border + title
328		cursor.Y += offset
329		cursor.X = cursor.X + col + 2
330	}
331	return cursor
332}
333
334func (m *modelDialogCmp) ID() dialogs.DialogID {
335	return ModelsDialogID
336}
337
338func (m *modelDialogCmp) modelTypeRadio() string {
339	t := styles.CurrentTheme()
340	choices := []string{"Large Task", "Small Task"}
341	iconSelected := "◉"
342	iconUnselected := "○"
343	if m.modelList.GetModelType() == LargeModelType {
344		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
345	}
346	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
347}
348
349func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
350	if _, ok := m.cfg.Providers.Get(providerID); ok {
351		return true
352	}
353	return false
354}
355
356func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
357	providers, err := config.Providers(m.cfg)
358	if err != nil {
359		return nil, err
360	}
361	for _, p := range providers {
362		if p.ID == providerID {
363			return &p, nil
364		}
365	}
366	return nil, nil
367}
368
369func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
370	if m.selectedModel == nil {
371		return util.ReportError(fmt.Errorf("no model selected"))
372	}
373
374	err := m.cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
375	if err != nil {
376		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
377	}
378
379	// Reset API key state and continue with model selection
380	selectedModel := *m.selectedModel
381	return tea.Sequence(
382		util.CmdHandler(dialogs.CloseDialogMsg{}),
383		util.CmdHandler(ModelSelectedMsg{
384			Model: config.SelectedModel{
385				Model:           selectedModel.Model.ID,
386				Provider:        string(selectedModel.Provider.ID),
387				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
388				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
389			},
390			ModelType: m.selectedModelType,
391		}),
392	)
393}