models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/llm/agent"
 14	"github.com/charmbracelet/crush/internal/llm/provider"
 15	"github.com/charmbracelet/crush/internal/tui/components/core"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/exp/list"
 18	"github.com/charmbracelet/crush/internal/tui/styles"
 19	"github.com/charmbracelet/crush/internal/tui/util"
 20	"github.com/charmbracelet/lipgloss/v2"
 21)
 22
 23const (
 24	ModelsDialogID dialogs.DialogID = "models"
 25
 26	defaultWidth = 60
 27)
 28
 29const (
 30	LargeModelType int = iota
 31	SmallModelType
 32
 33	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 34	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 35)
 36
 37// ModelSelectedMsg is sent when a model is selected
 38type ModelSelectedMsg struct {
 39	Model     agent.Model
 40	ModelType config.SelectedModelType
 41}
 42
 43// CloseModelDialogMsg is sent when a model is selected
 44type CloseModelDialogMsg struct{}
 45
 46// ModelDialog interface for the model selection dialog
 47type ModelDialog interface {
 48	dialogs.DialogModel
 49}
 50
 51type ModelOption struct {
 52	Provider catwalk.Provider
 53	Model    catwalk.Model
 54}
 55
 56type modelDialogCmp struct {
 57	width   int
 58	wWidth  int
 59	wHeight int
 60
 61	config    *config.Config
 62	modelList *ModelListComponent
 63	keyMap    KeyMap
 64	help      help.Model
 65
 66	// API key state
 67	needsAPIKey       bool
 68	apiKeyInput       *APIKeyInput
 69	selectedModel     *ModelOption
 70	selectedModelType config.SelectedModelType
 71	isAPIKeyValid     bool
 72	apiKeyValue       string
 73}
 74
 75func NewModelDialogCmp(cfg *config.Config) ModelDialog {
 76	keyMap := DefaultKeyMap()
 77
 78	listKeyMap := list.DefaultKeyMap()
 79	listKeyMap.Down.SetEnabled(false)
 80	listKeyMap.Up.SetEnabled(false)
 81	listKeyMap.DownOneItem = keyMap.Next
 82	listKeyMap.UpOneItem = keyMap.Previous
 83
 84	t := styles.CurrentTheme()
 85	modelList := NewModelListComponent(cfg, listKeyMap, "Choose a model for large, complex tasks", true)
 86	apiKeyInput := NewAPIKeyInput()
 87	apiKeyInput.SetShowTitle(false)
 88	help := help.New()
 89	help.Styles = t.S().Help
 90
 91	return &modelDialogCmp{
 92		modelList:   modelList,
 93		apiKeyInput: apiKeyInput,
 94		width:       defaultWidth,
 95		keyMap:      DefaultKeyMap(),
 96		help:        help,
 97		config:      cfg,
 98	}
 99}
100
101func (m *modelDialogCmp) Init() tea.Cmd {
102	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
103}
104
105func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
106	switch msg := msg.(type) {
107	case tea.WindowSizeMsg:
108		m.wWidth = msg.Width
109		m.wHeight = msg.Height
110		m.apiKeyInput.SetWidth(m.width - 2)
111		m.help.Width = m.width - 2
112		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
113	case APIKeyStateChangeMsg:
114		u, cmd := m.apiKeyInput.Update(msg)
115		m.apiKeyInput = u.(*APIKeyInput)
116		return m, cmd
117	case tea.KeyPressMsg:
118		switch {
119		case key.Matches(msg, m.keyMap.Select):
120			if m.isAPIKeyValid {
121				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
122			}
123			if m.needsAPIKey {
124				// Handle API key submission
125				m.apiKeyValue = m.apiKeyInput.Value()
126				selectedProvider, err := m.getProvider(m.selectedModel.Provider.ID)
127				if err != nil || selectedProvider == nil {
128					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
129				}
130				providerConfig := provider.Config{
131					ID:      string(m.selectedModel.Provider.ID),
132					Name:    m.selectedModel.Provider.Name,
133					APIKey:  m.apiKeyValue,
134					Type:    selectedProvider.Type,
135					BaseURL: selectedProvider.APIEndpoint,
136				}
137				return m, tea.Sequence(
138					util.CmdHandler(APIKeyStateChangeMsg{
139						State: APIKeyInputStateVerifying,
140					}),
141					func() tea.Msg {
142						start := time.Now()
143						err := providerConfig.TestConnection(m.config.Resolver())
144						// intentionally wait for at least 750ms to make sure the user sees the spinner
145						elapsed := time.Since(start)
146						if elapsed < 750*time.Millisecond {
147							time.Sleep(750*time.Millisecond - elapsed)
148						}
149						if err == nil {
150							m.isAPIKeyValid = true
151							return APIKeyStateChangeMsg{
152								State: APIKeyInputStateVerified,
153							}
154						}
155						return APIKeyStateChangeMsg{
156							State: APIKeyInputStateError,
157						}
158					},
159				)
160			}
161			// Normal model selection
162			selectedItem := m.modelList.SelectedModel()
163
164			var modelType config.SelectedModelType
165			if m.modelList.GetModelType() == LargeModelType {
166				modelType = config.SelectedModelTypeLarge
167			} else {
168				modelType = config.SelectedModelTypeSmall
169			}
170
171			// Check if provider is configured
172			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
173				return m, tea.Sequence(
174					util.CmdHandler(dialogs.CloseDialogMsg{}),
175					util.CmdHandler(ModelSelectedMsg{
176						Model: agent.Model{
177							Model:    selectedItem.Model.ID,
178							Provider: string(selectedItem.Provider.ID),
179						},
180						ModelType: modelType,
181					}),
182				)
183			} else {
184				// Provider not configured, show API key input
185				m.needsAPIKey = true
186				m.selectedModel = selectedItem
187				m.selectedModelType = modelType
188				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
189				return m, nil
190			}
191		case key.Matches(msg, m.keyMap.Tab):
192			if m.needsAPIKey {
193				u, cmd := m.apiKeyInput.Update(msg)
194				m.apiKeyInput = u.(*APIKeyInput)
195				return m, cmd
196			}
197			if m.modelList.GetModelType() == LargeModelType {
198				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
199				return m, m.modelList.SetModelType(SmallModelType)
200			} else {
201				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
202				return m, m.modelList.SetModelType(LargeModelType)
203			}
204		case key.Matches(msg, m.keyMap.Close):
205			if m.needsAPIKey {
206				if m.isAPIKeyValid {
207					return m, nil
208				}
209				// Go back to model selection
210				m.needsAPIKey = false
211				m.selectedModel = nil
212				m.isAPIKeyValid = false
213				m.apiKeyValue = ""
214				m.apiKeyInput.Reset()
215				return m, nil
216			}
217			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
218		default:
219			if m.needsAPIKey {
220				u, cmd := m.apiKeyInput.Update(msg)
221				m.apiKeyInput = u.(*APIKeyInput)
222				return m, cmd
223			} else {
224				u, cmd := m.modelList.Update(msg)
225				m.modelList = u
226				return m, cmd
227			}
228		}
229	case tea.PasteMsg:
230		if m.needsAPIKey {
231			u, cmd := m.apiKeyInput.Update(msg)
232			m.apiKeyInput = u.(*APIKeyInput)
233			return m, cmd
234		} else {
235			var cmd tea.Cmd
236			m.modelList, cmd = m.modelList.Update(msg)
237			return m, cmd
238		}
239	case spinner.TickMsg:
240		u, cmd := m.apiKeyInput.Update(msg)
241		m.apiKeyInput = u.(*APIKeyInput)
242		return m, cmd
243	}
244	return m, nil
245}
246
247func (m *modelDialogCmp) View() string {
248	t := styles.CurrentTheme()
249
250	if m.needsAPIKey {
251		// Show API key input
252		m.keyMap.isAPIKeyHelp = true
253		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
254		apiKeyView := m.apiKeyInput.View()
255		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
256		content := lipgloss.JoinVertical(
257			lipgloss.Left,
258			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
259			apiKeyView,
260			"",
261			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
262		)
263		return m.style().Render(content)
264	}
265
266	// Show model selection
267	listView := m.modelList.View()
268	radio := m.modelTypeRadio()
269	content := lipgloss.JoinVertical(
270		lipgloss.Left,
271		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
272		listView,
273		"",
274		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
275	)
276	return m.style().Render(content)
277}
278
279func (m *modelDialogCmp) Cursor() *tea.Cursor {
280	if m.needsAPIKey {
281		cursor := m.apiKeyInput.Cursor()
282		if cursor != nil {
283			cursor = m.moveCursor(cursor)
284			return cursor
285		}
286	} else {
287		cursor := m.modelList.Cursor()
288		if cursor != nil {
289			cursor = m.moveCursor(cursor)
290			return cursor
291		}
292	}
293	return nil
294}
295
296func (m *modelDialogCmp) style() lipgloss.Style {
297	t := styles.CurrentTheme()
298	return t.S().Base.
299		Width(m.width).
300		Border(lipgloss.RoundedBorder()).
301		BorderForeground(t.BorderFocus)
302}
303
304func (m *modelDialogCmp) listWidth() int {
305	return m.width - 2
306}
307
308func (m *modelDialogCmp) listHeight() int {
309	return m.wHeight / 2
310}
311
312func (m *modelDialogCmp) Position() (int, int) {
313	row := m.wHeight/4 - 2 // just a bit above the center
314	col := m.wWidth / 2
315	col -= m.width / 2
316	return row, col
317}
318
319func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
320	row, col := m.Position()
321	if m.needsAPIKey {
322		offset := row + 3 // Border + title + API key input offset
323		cursor.Y += offset
324		cursor.X = cursor.X + col + 2
325	} else {
326		offset := row + 3 // Border + title
327		cursor.Y += offset
328		cursor.X = cursor.X + col + 2
329	}
330	return cursor
331}
332
333func (m *modelDialogCmp) ID() dialogs.DialogID {
334	return ModelsDialogID
335}
336
337func (m *modelDialogCmp) modelTypeRadio() string {
338	t := styles.CurrentTheme()
339	choices := []string{"Large Task", "Small Task"}
340	iconSelected := "◉"
341	iconUnselected := "○"
342	if m.modelList.GetModelType() == LargeModelType {
343		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
344	}
345	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
346}
347
348func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
349	if _, ok := m.config.Providers.Get(providerID); ok {
350		return true
351	}
352	return false
353}
354
355func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
356	providers, err := config.Providers()
357	if err != nil {
358		return nil, err
359	}
360	for _, p := range providers {
361		if p.ID == providerID {
362			return &p, nil
363		}
364	}
365	return nil, nil
366}
367
368func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
369	if m.selectedModel == nil {
370		return util.ReportError(fmt.Errorf("no model selected"))
371	}
372
373	err := m.config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
374	if err != nil {
375		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
376	}
377
378	// Reset API key state and continue with model selection
379	selectedModel := *m.selectedModel
380	return tea.Sequence(
381		util.CmdHandler(dialogs.CloseDialogMsg{}),
382		util.CmdHandler(ModelSelectedMsg{
383			Model: agent.Model{
384				Model:    selectedModel.Model.ID,
385				Provider: string(selectedModel.Provider.ID),
386			},
387			ModelType: m.selectedModelType,
388		}),
389	)
390}