models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"strings"
  6	"time"
  7
  8	"github.com/charmbracelet/bubbles/v2/help"
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	"github.com/charmbracelet/bubbles/v2/spinner"
 11	tea "github.com/charmbracelet/bubbletea/v2"
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/tui/components/core"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 16	"github.com/charmbracelet/crush/internal/tui/exp/list"
 17	"github.com/charmbracelet/crush/internal/tui/styles"
 18	"github.com/charmbracelet/crush/internal/tui/util"
 19	"github.com/charmbracelet/lipgloss/v2"
 20)
 21
const (
	// ModelsDialogID is the dialog ID reported by the model selection dialog.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the width the dialog is created with; it is the value
	// handed to the dialog's lipgloss style.
	defaultWidth = 60
)
 27
const (
	// LargeModelType and SmallModelType are the model type values exchanged
	// with the model list (GetModelType / SetModelType).
	LargeModelType int = iota
	SmallModelType

	// Placeholders shown in the model list input for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 35
// ModelSelectedMsg is sent when a model is selected. It is emitted right
// after dialogs.CloseDialogMsg, either directly from the model list or after
// an API key has been saved for the model's provider.
type ModelSelectedMsg struct {
	// Model holds the ID of the chosen model and of its provider.
	Model config.SelectedModel
	// ModelType tells whether the model was chosen for the large or the small
	// task slot.
	ModelType config.SelectedModelType
}
 41
// CloseModelDialogMsg is a message requesting that the model dialog be closed.
// NOTE(review): nothing in the visible part of this file sends or handles it;
// the dialog closes itself with dialogs.CloseDialogMsg. Confirm it is still used.
type CloseModelDialogMsg struct{}
 44
// ModelDialog is the interface of the model selection dialog. It currently
// adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 49
// ModelOption pairs a model with the provider that offers it. It is the value
// the model list reports as the current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 54
// modelDialogCmp is the model selection dialog. It shows either the model
// list or, when the chosen model's provider is not configured yet, an API key
// input for that provider.
type modelDialogCmp struct {
	width   int // dialog width, initialised to defaultWidth
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool                     // true while the API key input is shown instead of the list
	apiKeyInput       *APIKeyInput             // input component for the provider API key
	selectedModel     *ModelOption             // model waiting for an API key; nil on the list screen
	selectedModelType config.SelectedModelType // model type that was active when selectedModel was chosen
	isAPIKeyValid     bool                     // true once the entered key passed the connection test
	apiKeyValue       string                   // key text captured when it was submitted for verification
}
 72
 73func NewModelDialogCmp() ModelDialog {
 74	keyMap := DefaultKeyMap()
 75
 76	listKeyMap := list.DefaultKeyMap()
 77	listKeyMap.Down.SetEnabled(false)
 78	listKeyMap.Up.SetEnabled(false)
 79	listKeyMap.DownOneItem = keyMap.Next
 80	listKeyMap.UpOneItem = keyMap.Previous
 81
 82	t := styles.CurrentTheme()
 83	modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
 84	apiKeyInput := NewAPIKeyInput()
 85	apiKeyInput.SetShowTitle(false)
 86	help := help.New()
 87	help.Styles = t.S().Help
 88
 89	return &modelDialogCmp{
 90		modelList:   modelList,
 91		apiKeyInput: apiKeyInput,
 92		width:       defaultWidth,
 93		keyMap:      DefaultKeyMap(),
 94		help:        help,
 95	}
 96}
 97
 98func (m *modelDialogCmp) Init() tea.Cmd {
 99	providers, err := config.Providers()
100	if err == nil {
101		filteredProviders := []catwalk.Provider{}
102		for _, p := range providers {
103			if strings.HasPrefix(p.APIKey, "$") && p.ID != catwalk.InferenceProviderAzure {
104				filteredProviders = append(filteredProviders, p)
105			}
106		}
107		m.modelList.SetProviders(filteredProviders)
108	}
109	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
110}
111
112func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
113	switch msg := msg.(type) {
114	case tea.WindowSizeMsg:
115		m.wWidth = msg.Width
116		m.wHeight = msg.Height
117		m.apiKeyInput.SetWidth(m.width - 2)
118		m.help.Width = m.width - 2
119		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
120	case APIKeyStateChangeMsg:
121		u, cmd := m.apiKeyInput.Update(msg)
122		m.apiKeyInput = u.(*APIKeyInput)
123		return m, cmd
124	case tea.KeyPressMsg:
125		switch {
126		case key.Matches(msg, m.keyMap.Select):
127			if m.isAPIKeyValid {
128				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
129			}
130			if m.needsAPIKey {
131				// Handle API key submission
132				m.apiKeyValue = m.apiKeyInput.Value()
133				provider, err := m.getProvider(m.selectedModel.Provider.ID)
134				if err != nil || provider == nil {
135					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
136				}
137				providerConfig := config.ProviderConfig{
138					ID:      string(m.selectedModel.Provider.ID),
139					Name:    m.selectedModel.Provider.Name,
140					APIKey:  m.apiKeyValue,
141					Type:    provider.Type,
142					BaseURL: provider.APIEndpoint,
143				}
144				return m, tea.Sequence(
145					util.CmdHandler(APIKeyStateChangeMsg{
146						State: APIKeyInputStateVerifying,
147					}),
148					func() tea.Msg {
149						start := time.Now()
150						err := providerConfig.TestConnection(config.Get().Resolver())
151						// intentionally wait for at least 750ms to make sure the user sees the spinner
152						elapsed := time.Since(start)
153						if elapsed < 750*time.Millisecond {
154							time.Sleep(750*time.Millisecond - elapsed)
155						}
156						if err == nil {
157							m.isAPIKeyValid = true
158							return APIKeyStateChangeMsg{
159								State: APIKeyInputStateVerified,
160							}
161						}
162						return APIKeyStateChangeMsg{
163							State: APIKeyInputStateError,
164						}
165					},
166				)
167			}
168			// Normal model selection
169			selectedItem := m.modelList.SelectedModel()
170
171			var modelType config.SelectedModelType
172			if m.modelList.GetModelType() == LargeModelType {
173				modelType = config.SelectedModelTypeLarge
174			} else {
175				modelType = config.SelectedModelTypeSmall
176			}
177
178			// Check if provider is configured
179			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
180				return m, tea.Sequence(
181					util.CmdHandler(dialogs.CloseDialogMsg{}),
182					util.CmdHandler(ModelSelectedMsg{
183						Model: config.SelectedModel{
184							Model:    selectedItem.Model.ID,
185							Provider: string(selectedItem.Provider.ID),
186						},
187						ModelType: modelType,
188					}),
189				)
190			} else {
191				// Provider not configured, show API key input
192				m.needsAPIKey = true
193				m.selectedModel = selectedItem
194				m.selectedModelType = modelType
195				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
196				return m, nil
197			}
198		case key.Matches(msg, m.keyMap.Tab):
199			if m.needsAPIKey {
200				u, cmd := m.apiKeyInput.Update(msg)
201				m.apiKeyInput = u.(*APIKeyInput)
202				return m, cmd
203			}
204			if m.modelList.GetModelType() == LargeModelType {
205				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
206				return m, m.modelList.SetModelType(SmallModelType)
207			} else {
208				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
209				return m, m.modelList.SetModelType(LargeModelType)
210			}
211		case key.Matches(msg, m.keyMap.Close):
212			if m.needsAPIKey {
213				if m.isAPIKeyValid {
214					return m, nil
215				}
216				// Go back to model selection
217				m.needsAPIKey = false
218				m.selectedModel = nil
219				m.isAPIKeyValid = false
220				m.apiKeyValue = ""
221				m.apiKeyInput.Reset()
222				return m, nil
223			}
224			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
225		default:
226			if m.needsAPIKey {
227				u, cmd := m.apiKeyInput.Update(msg)
228				m.apiKeyInput = u.(*APIKeyInput)
229				return m, cmd
230			} else {
231				u, cmd := m.modelList.Update(msg)
232				m.modelList = u
233				return m, cmd
234			}
235		}
236	case tea.PasteMsg:
237		if m.needsAPIKey {
238			u, cmd := m.apiKeyInput.Update(msg)
239			m.apiKeyInput = u.(*APIKeyInput)
240			return m, cmd
241		} else {
242			var cmd tea.Cmd
243			m.modelList, cmd = m.modelList.Update(msg)
244			return m, cmd
245		}
246	case spinner.TickMsg:
247		u, cmd := m.apiKeyInput.Update(msg)
248		m.apiKeyInput = u.(*APIKeyInput)
249		return m, cmd
250	}
251	return m, nil
252}
253
254func (m *modelDialogCmp) View() string {
255	t := styles.CurrentTheme()
256
257	if m.needsAPIKey {
258		// Show API key input
259		m.keyMap.isAPIKeyHelp = true
260		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
261		apiKeyView := m.apiKeyInput.View()
262		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
263		content := lipgloss.JoinVertical(
264			lipgloss.Left,
265			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
266			apiKeyView,
267			"",
268			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
269		)
270		return m.style().Render(content)
271	}
272
273	// Show model selection
274	listView := m.modelList.View()
275	radio := m.modelTypeRadio()
276	content := lipgloss.JoinVertical(
277		lipgloss.Left,
278		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
279		listView,
280		"",
281		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
282	)
283	return m.style().Render(content)
284}
285
286func (m *modelDialogCmp) Cursor() *tea.Cursor {
287	if m.needsAPIKey {
288		cursor := m.apiKeyInput.Cursor()
289		if cursor != nil {
290			cursor = m.moveCursor(cursor)
291			return cursor
292		}
293	} else {
294		cursor := m.modelList.Cursor()
295		if cursor != nil {
296			cursor = m.moveCursor(cursor)
297			return cursor
298		}
299	}
300	return nil
301}
302
303func (m *modelDialogCmp) style() lipgloss.Style {
304	t := styles.CurrentTheme()
305	return t.S().Base.
306		Width(m.width).
307		Border(lipgloss.RoundedBorder()).
308		BorderForeground(t.BorderFocus)
309}
310
311func (m *modelDialogCmp) listWidth() int {
312	return m.width - 2
313}
314
315func (m *modelDialogCmp) listHeight() int {
316	return m.wHeight / 2
317}
318
319func (m *modelDialogCmp) Position() (int, int) {
320	row := m.wHeight/4 - 2 // just a bit above the center
321	col := m.wWidth / 2
322	col -= m.width / 2
323	return row, col
324}
325
326func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
327	row, col := m.Position()
328	if m.needsAPIKey {
329		offset := row + 3 // Border + title + API key input offset
330		cursor.Y += offset
331		cursor.X = cursor.X + col + 2
332	} else {
333		offset := row + 3 // Border + title
334		cursor.Y += offset
335		cursor.X = cursor.X + col + 2
336	}
337	return cursor
338}
339
340func (m *modelDialogCmp) ID() dialogs.DialogID {
341	return ModelsDialogID
342}
343
344func (m *modelDialogCmp) modelTypeRadio() string {
345	t := styles.CurrentTheme()
346	choices := []string{"Large Task", "Small Task"}
347	iconSelected := "◉"
348	iconUnselected := "○"
349	if m.modelList.GetModelType() == LargeModelType {
350		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
351	}
352	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
353}
354
355func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
356	cfg := config.Get()
357	if _, ok := cfg.Providers.Get(providerID); ok {
358		return true
359	}
360	return false
361}
362
363func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
364	providers, err := config.Providers()
365	if err != nil {
366		return nil, err
367	}
368	for _, p := range providers {
369		if p.ID == providerID {
370			return &p, nil
371		}
372	}
373	return nil, nil
374}
375
376func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
377	if m.selectedModel == nil {
378		return util.ReportError(fmt.Errorf("no model selected"))
379	}
380
381	cfg := config.Get()
382	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
383	if err != nil {
384		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
385	}
386
387	// Reset API key state and continue with model selection
388	selectedModel := *m.selectedModel
389	return tea.Sequence(
390		util.CmdHandler(dialogs.CloseDialogMsg{}),
391		util.CmdHandler(ModelSelectedMsg{
392			Model: config.SelectedModel{
393				Model:    selectedModel.Model.ID,
394				Provider: string(selectedModel.Provider.ID),
395			},
396			ModelType: m.selectedModelType,
397		}),
398	)
399}