models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/proto"
 14	"github.com/charmbracelet/crush/internal/tui/components/core"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 16	"github.com/charmbracelet/crush/internal/tui/exp/list"
 17	"github.com/charmbracelet/crush/internal/tui/styles"
 18	"github.com/charmbracelet/crush/internal/tui/util"
 19	"github.com/charmbracelet/lipgloss/v2"
 20)
 21
 22const (
 23	ModelsDialogID dialogs.DialogID = "models"
 24
 25	defaultWidth = 60
 26)
 27
// Model types understood by the model list, plus the input placeholder shown
// for each. Update's Tab handling toggles between the two types and swaps the
// placeholder to match.
const (
	// LargeModelType (0) selects a model for large, complex tasks.
	LargeModelType int = iota
	// SmallModelType (1) selects a model for small, simple tasks.
	SmallModelType

	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 35
// ModelSelectedMsg is sent when a model is selected. The dialog emits it right
// after dialogs.CloseDialogMsg.
type ModelSelectedMsg struct {
	// Model is the chosen model, filled in with that model's default reasoning
	// effort and max tokens.
	Model config.SelectedModel
	// ModelType is whether the model was chosen for large or small tasks.
	ModelType config.SelectedModelType
}
 41
 42// CloseModelDialogMsg is sent when a model is selected
 43type CloseModelDialogMsg struct{}
 44
// ModelDialog is the interface for the model selection dialog. It adds no
// methods beyond dialogs.DialogModel; NewModelDialogCmp returns the concrete
// implementation.
type ModelDialog interface {
	dialogs.DialogModel
}
 49
// ModelOption pairs a catwalk model with the provider that offers it. The
// dialog receives one from the model list's current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 54
// modelDialogCmp is the ModelDialog implementation returned by
// NewModelDialogCmp. It has two modes:
//   - choosing a model from modelList;
//   - entering an API key in apiKeyInput, used when the chosen model's
//     provider is not configured.
type modelDialogCmp struct {
	// width is the dialog's own width. wWidth and wHeight are the window size
	// from the last tea.WindowSizeMsg; they drive list height and positioning.
	width   int
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	// needsAPIKey switches Update, View and Cursor over to the API-key prompt.
	needsAPIKey       bool
	apiKeyInput       *APIKeyInput
	// selectedModel and selectedModelType record the choice that opened the
	// prompt; they are used once the key has been saved.
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true after apiKeyValue passes the provider
	// connection test. apiKeyValue is the key captured on submit.
	isAPIKeyValid     bool
	apiKeyValue       string

	// ins supplies the configuration and the shell resolver used by the dialog.
	ins *proto.Instance
}
 74
 75func NewModelDialogCmp(ins *proto.Instance) ModelDialog {
 76	keyMap := DefaultKeyMap()
 77
 78	listKeyMap := list.DefaultKeyMap()
 79	listKeyMap.Down.SetEnabled(false)
 80	listKeyMap.Up.SetEnabled(false)
 81	listKeyMap.DownOneItem = keyMap.Next
 82	listKeyMap.UpOneItem = keyMap.Previous
 83
 84	t := styles.CurrentTheme()
 85	modelList := NewModelListComponent(ins.Config, listKeyMap, largeModelInputPlaceholder, true)
 86	apiKeyInput := NewAPIKeyInput()
 87	apiKeyInput.SetShowTitle(false)
 88	help := help.New()
 89	help.Styles = t.S().Help
 90
 91	return &modelDialogCmp{
 92		modelList:   modelList,
 93		apiKeyInput: apiKeyInput,
 94		width:       defaultWidth,
 95		keyMap:      DefaultKeyMap(),
 96		help:        help,
 97		ins:         ins,
 98	}
 99}
100
101func (m *modelDialogCmp) Init() tea.Cmd {
102	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
103}
104
105func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
106	switch msg := msg.(type) {
107	case tea.WindowSizeMsg:
108		m.wWidth = msg.Width
109		m.wHeight = msg.Height
110		m.apiKeyInput.SetWidth(m.width - 2)
111		m.help.Width = m.width - 2
112		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
113	case APIKeyStateChangeMsg:
114		u, cmd := m.apiKeyInput.Update(msg)
115		m.apiKeyInput = u.(*APIKeyInput)
116		return m, cmd
117	case tea.KeyPressMsg:
118		switch {
119		case key.Matches(msg, m.keyMap.Select):
120			if m.isAPIKeyValid {
121				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
122			}
123			if m.needsAPIKey {
124				// Handle API key submission
125				m.apiKeyValue = m.apiKeyInput.Value()
126				provider, err := m.getProvider(m.selectedModel.Provider.ID)
127				if err != nil || provider == nil {
128					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
129				}
130				providerConfig := config.ProviderConfig{
131					ID:      string(m.selectedModel.Provider.ID),
132					Name:    m.selectedModel.Provider.Name,
133					APIKey:  m.apiKeyValue,
134					Type:    provider.Type,
135					BaseURL: provider.APIEndpoint,
136				}
137				return m, tea.Sequence(
138					util.CmdHandler(APIKeyStateChangeMsg{
139						State: APIKeyInputStateVerifying,
140					}),
141					func() tea.Msg {
142						start := time.Now()
143						err := providerConfig.TestConnection(m.ins.ShellResolver())
144						// intentionally wait for at least 750ms to make sure the user sees the spinner
145						elapsed := time.Since(start)
146						if elapsed < 750*time.Millisecond {
147							time.Sleep(750*time.Millisecond - elapsed)
148						}
149						if err == nil {
150							m.isAPIKeyValid = true
151							return APIKeyStateChangeMsg{
152								State: APIKeyInputStateVerified,
153							}
154						}
155						return APIKeyStateChangeMsg{
156							State: APIKeyInputStateError,
157						}
158					},
159				)
160			}
161			// Normal model selection
162			selectedItem := m.modelList.SelectedModel()
163
164			var modelType config.SelectedModelType
165			if m.modelList.GetModelType() == LargeModelType {
166				modelType = config.SelectedModelTypeLarge
167			} else {
168				modelType = config.SelectedModelTypeSmall
169			}
170
171			// Check if provider is configured
172			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
173				return m, tea.Sequence(
174					util.CmdHandler(dialogs.CloseDialogMsg{}),
175					util.CmdHandler(ModelSelectedMsg{
176						Model: config.SelectedModel{
177							Model:           selectedItem.Model.ID,
178							Provider:        string(selectedItem.Provider.ID),
179							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
180							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
181						},
182						ModelType: modelType,
183					}),
184				)
185			} else {
186				// Provider not configured, show API key input
187				m.needsAPIKey = true
188				m.selectedModel = selectedItem
189				m.selectedModelType = modelType
190				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
191				return m, nil
192			}
193		case key.Matches(msg, m.keyMap.Tab):
194			if m.needsAPIKey {
195				u, cmd := m.apiKeyInput.Update(msg)
196				m.apiKeyInput = u.(*APIKeyInput)
197				return m, cmd
198			}
199			if m.modelList.GetModelType() == LargeModelType {
200				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
201				return m, m.modelList.SetModelType(SmallModelType)
202			} else {
203				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
204				return m, m.modelList.SetModelType(LargeModelType)
205			}
206		case key.Matches(msg, m.keyMap.Close):
207			if m.needsAPIKey {
208				if m.isAPIKeyValid {
209					return m, nil
210				}
211				// Go back to model selection
212				m.needsAPIKey = false
213				m.selectedModel = nil
214				m.isAPIKeyValid = false
215				m.apiKeyValue = ""
216				m.apiKeyInput.Reset()
217				return m, nil
218			}
219			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
220		default:
221			if m.needsAPIKey {
222				u, cmd := m.apiKeyInput.Update(msg)
223				m.apiKeyInput = u.(*APIKeyInput)
224				return m, cmd
225			} else {
226				u, cmd := m.modelList.Update(msg)
227				m.modelList = u
228				return m, cmd
229			}
230		}
231	case tea.PasteMsg:
232		if m.needsAPIKey {
233			u, cmd := m.apiKeyInput.Update(msg)
234			m.apiKeyInput = u.(*APIKeyInput)
235			return m, cmd
236		} else {
237			var cmd tea.Cmd
238			m.modelList, cmd = m.modelList.Update(msg)
239			return m, cmd
240		}
241	case spinner.TickMsg:
242		u, cmd := m.apiKeyInput.Update(msg)
243		m.apiKeyInput = u.(*APIKeyInput)
244		return m, cmd
245	}
246	return m, nil
247}
248
249func (m *modelDialogCmp) View() string {
250	t := styles.CurrentTheme()
251
252	if m.needsAPIKey {
253		// Show API key input
254		m.keyMap.isAPIKeyHelp = true
255		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
256		apiKeyView := m.apiKeyInput.View()
257		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
258		content := lipgloss.JoinVertical(
259			lipgloss.Left,
260			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
261			apiKeyView,
262			"",
263			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
264		)
265		return m.style().Render(content)
266	}
267
268	// Show model selection
269	listView := m.modelList.View()
270	radio := m.modelTypeRadio()
271	content := lipgloss.JoinVertical(
272		lipgloss.Left,
273		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
274		listView,
275		"",
276		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
277	)
278	return m.style().Render(content)
279}
280
281func (m *modelDialogCmp) Cursor() *tea.Cursor {
282	if m.needsAPIKey {
283		cursor := m.apiKeyInput.Cursor()
284		if cursor != nil {
285			cursor = m.moveCursor(cursor)
286			return cursor
287		}
288	} else {
289		cursor := m.modelList.Cursor()
290		if cursor != nil {
291			cursor = m.moveCursor(cursor)
292			return cursor
293		}
294	}
295	return nil
296}
297
298func (m *modelDialogCmp) style() lipgloss.Style {
299	t := styles.CurrentTheme()
300	return t.S().Base.
301		Width(m.width).
302		Border(lipgloss.RoundedBorder()).
303		BorderForeground(t.BorderFocus)
304}
305
306func (m *modelDialogCmp) listWidth() int {
307	return m.width - 2
308}
309
310func (m *modelDialogCmp) listHeight() int {
311	return m.wHeight / 2
312}
313
314func (m *modelDialogCmp) Position() (int, int) {
315	row := m.wHeight/4 - 2 // just a bit above the center
316	col := m.wWidth / 2
317	col -= m.width / 2
318	return row, col
319}
320
321func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
322	row, col := m.Position()
323	if m.needsAPIKey {
324		offset := row + 3 // Border + title + API key input offset
325		cursor.Y += offset
326		cursor.X = cursor.X + col + 2
327	} else {
328		offset := row + 3 // Border + title
329		cursor.Y += offset
330		cursor.X = cursor.X + col + 2
331	}
332	return cursor
333}
334
335func (m *modelDialogCmp) ID() dialogs.DialogID {
336	return ModelsDialogID
337}
338
339func (m *modelDialogCmp) modelTypeRadio() string {
340	t := styles.CurrentTheme()
341	choices := []string{"Large Task", "Small Task"}
342	iconSelected := "◉"
343	iconUnselected := "○"
344	if m.modelList.GetModelType() == LargeModelType {
345		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
346	}
347	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
348}
349
350func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
351	if _, ok := m.ins.Config.Providers.Get(providerID); ok {
352		return true
353	}
354	return false
355}
356
357func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
358	providers, err := config.Providers(m.ins.Config)
359	if err != nil {
360		return nil, err
361	}
362	for _, p := range providers {
363		if p.ID == providerID {
364			return &p, nil
365		}
366	}
367	return nil, nil
368}
369
370func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
371	if m.selectedModel == nil {
372		return util.ReportError(fmt.Errorf("no model selected"))
373	}
374
375	err := m.ins.Config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
376	if err != nil {
377		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
378	}
379
380	// Reset API key state and continue with model selection
381	selectedModel := *m.selectedModel
382	return tea.Sequence(
383		util.CmdHandler(dialogs.CloseDialogMsg{}),
384		util.CmdHandler(ModelSelectedMsg{
385			Model: config.SelectedModel{
386				Model:           selectedModel.Model.ID,
387				Provider:        string(selectedModel.Provider.ID),
388				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
389				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
390			},
391			ModelType: m.selectedModelType,
392		}),
393	)
394}