models.go

  1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/fur/provider"
 13	"github.com/charmbracelet/crush/internal/tui/components/completions"
 14	"github.com/charmbracelet/crush/internal/tui/components/core"
 15	"github.com/charmbracelet/crush/internal/tui/components/core/list"
 16	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 17	"github.com/charmbracelet/crush/internal/tui/styles"
 18	"github.com/charmbracelet/crush/internal/tui/util"
 19	"github.com/charmbracelet/lipgloss/v2"
 20)
 21
 22const (
 23	ModelsDialogID dialogs.DialogID = "models"
 24
 25	defaultWidth = 60
 26)
 27
 28const (
 29	LargeModelType int = iota
 30	SmallModelType
 31
 32	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 33	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 34)
 35
 36// ModelSelectedMsg is sent when a model is selected
 37type ModelSelectedMsg struct {
 38	Model     config.SelectedModel
 39	ModelType config.SelectedModelType
 40}
 41
 42// CloseModelDialogMsg is sent when a model is selected
 43type CloseModelDialogMsg struct{}
 44
 45// ModelDialog interface for the model selection dialog
 46type ModelDialog interface {
 47	dialogs.DialogModel
 48}
 49
 50type ModelOption struct {
 51	Provider provider.Provider
 52	Model    provider.Model
 53}
 54
 55type modelDialogCmp struct {
 56	width   int
 57	wWidth  int
 58	wHeight int
 59
 60	modelList *ModelListComponent
 61	keyMap    KeyMap
 62	help      help.Model
 63
 64	// API key state
 65	needsAPIKey       bool
 66	apiKeyInput       *APIKeyInput
 67	selectedModel     *ModelOption
 68	selectedModelType config.SelectedModelType
 69	isAPIKeyValid     bool
 70	apiKeyValue       string
 71}
 72
 73func NewModelDialogCmp() ModelDialog {
 74	listKeyMap := list.DefaultKeyMap()
 75	keyMap := DefaultKeyMap()
 76
 77	listKeyMap.Down.SetEnabled(false)
 78	listKeyMap.Up.SetEnabled(false)
 79	listKeyMap.HalfPageDown.SetEnabled(false)
 80	listKeyMap.HalfPageUp.SetEnabled(false)
 81	listKeyMap.Home.SetEnabled(false)
 82	listKeyMap.End.SetEnabled(false)
 83
 84	listKeyMap.DownOneItem = keyMap.Next
 85	listKeyMap.UpOneItem = keyMap.Previous
 86
 87	t := styles.CurrentTheme()
 88	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 89	modelList := NewModelListComponent(listKeyMap, inputStyle, "Choose a model for large, complex tasks")
 90	apiKeyInput := NewAPIKeyInput()
 91	apiKeyInput.SetShowTitle(false)
 92	help := help.New()
 93	help.Styles = t.S().Help
 94
 95	return &modelDialogCmp{
 96		modelList:   modelList,
 97		apiKeyInput: apiKeyInput,
 98		width:       defaultWidth,
 99		keyMap:      DefaultKeyMap(),
100		help:        help,
101	}
102}
103
104func (m *modelDialogCmp) Init() tea.Cmd {
105	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
106}
107
108func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
109	switch msg := msg.(type) {
110	case tea.WindowSizeMsg:
111		m.wWidth = msg.Width
112		m.wHeight = msg.Height
113		m.apiKeyInput.SetWidth(m.width - 2)
114		m.help.Width = m.width - 2
115		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
116	case APIKeyStateChangeMsg:
117		u, cmd := m.apiKeyInput.Update(msg)
118		m.apiKeyInput = u.(*APIKeyInput)
119		return m, cmd
120	case tea.KeyPressMsg:
121		switch {
122		case key.Matches(msg, m.keyMap.Select):
123			if m.isAPIKeyValid {
124				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
125			}
126			if m.needsAPIKey {
127				// Handle API key submission
128				m.apiKeyValue = m.apiKeyInput.Value()
129				provider, err := m.getProvider(m.selectedModel.Provider.ID)
130				if err != nil || provider == nil {
131					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
132				}
133				providerConfig := config.ProviderConfig{
134					ID:      string(m.selectedModel.Provider.ID),
135					Name:    m.selectedModel.Provider.Name,
136					APIKey:  m.apiKeyValue,
137					Type:    provider.Type,
138					BaseURL: provider.APIEndpoint,
139				}
140				return m, tea.Sequence(
141					util.CmdHandler(APIKeyStateChangeMsg{
142						State: APIKeyInputStateVerifying,
143					}),
144					func() tea.Msg {
145						start := time.Now()
146						err := providerConfig.TestConnection(config.Get().Resolver())
147						// intentionally wait for at least 750ms to make sure the user sees the spinner
148						elapsed := time.Since(start)
149						if elapsed < 750*time.Millisecond {
150							time.Sleep(750*time.Millisecond - elapsed)
151						}
152						if err == nil {
153							m.isAPIKeyValid = true
154							return APIKeyStateChangeMsg{
155								State: APIKeyInputStateVerified,
156							}
157						}
158						return APIKeyStateChangeMsg{
159							State: APIKeyInputStateError,
160						}
161					},
162				)
163			}
164			// Normal model selection
165			selectedItemInx := m.modelList.SelectedIndex()
166			if selectedItemInx == list.NoSelection {
167				return m, nil
168			}
169			items := m.modelList.Items()
170			selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
171
172			var modelType config.SelectedModelType
173			if m.modelList.GetModelType() == LargeModelType {
174				modelType = config.SelectedModelTypeLarge
175			} else {
176				modelType = config.SelectedModelTypeSmall
177			}
178
179			// Check if provider is configured
180			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
181				return m, tea.Sequence(
182					util.CmdHandler(dialogs.CloseDialogMsg{}),
183					util.CmdHandler(ModelSelectedMsg{
184						Model: config.SelectedModel{
185							Model:    selectedItem.Model.ID,
186							Provider: string(selectedItem.Provider.ID),
187						},
188						ModelType: modelType,
189					}),
190				)
191			} else {
192				// Provider not configured, show API key input
193				m.needsAPIKey = true
194				m.selectedModel = &selectedItem
195				m.selectedModelType = modelType
196				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
197				return m, nil
198			}
199		case key.Matches(msg, m.keyMap.Tab):
200			if m.needsAPIKey {
201				u, cmd := m.apiKeyInput.Update(msg)
202				m.apiKeyInput = u.(*APIKeyInput)
203				return m, cmd
204			}
205			if m.modelList.GetModelType() == LargeModelType {
206				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
207				return m, m.modelList.SetModelType(SmallModelType)
208			} else {
209				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
210				return m, m.modelList.SetModelType(LargeModelType)
211			}
212		case key.Matches(msg, m.keyMap.Close):
213			if m.needsAPIKey {
214				if m.isAPIKeyValid {
215					return m, nil
216				}
217				// Go back to model selection
218				m.needsAPIKey = false
219				m.selectedModel = nil
220				m.isAPIKeyValid = false
221				m.apiKeyValue = ""
222				m.apiKeyInput.Reset()
223				return m, nil
224			}
225			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
226		default:
227			if m.needsAPIKey {
228				u, cmd := m.apiKeyInput.Update(msg)
229				m.apiKeyInput = u.(*APIKeyInput)
230				return m, cmd
231			} else {
232				u, cmd := m.modelList.Update(msg)
233				m.modelList = u
234				return m, cmd
235			}
236		}
237	case tea.PasteMsg:
238		if m.needsAPIKey {
239			u, cmd := m.apiKeyInput.Update(msg)
240			m.apiKeyInput = u.(*APIKeyInput)
241			return m, cmd
242		} else {
243			var cmd tea.Cmd
244			m.modelList, cmd = m.modelList.Update(msg)
245			return m, cmd
246		}
247	case spinner.TickMsg:
248		u, cmd := m.apiKeyInput.Update(msg)
249		m.apiKeyInput = u.(*APIKeyInput)
250		return m, cmd
251	}
252	return m, nil
253}
254
255func (m *modelDialogCmp) View() string {
256	t := styles.CurrentTheme()
257
258	if m.needsAPIKey {
259		// Show API key input
260		m.keyMap.isAPIKeyHelp = true
261		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
262		apiKeyView := m.apiKeyInput.View()
263		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
264		content := lipgloss.JoinVertical(
265			lipgloss.Left,
266			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
267			apiKeyView,
268			"",
269			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
270		)
271		return m.style().Render(content)
272	}
273
274	// Show model selection
275	listView := m.modelList.View()
276	radio := m.modelTypeRadio()
277	content := lipgloss.JoinVertical(
278		lipgloss.Left,
279		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
280		listView,
281		"",
282		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
283	)
284	return m.style().Render(content)
285}
286
287func (m *modelDialogCmp) Cursor() *tea.Cursor {
288	if m.needsAPIKey {
289		cursor := m.apiKeyInput.Cursor()
290		if cursor != nil {
291			cursor = m.moveCursor(cursor)
292			return cursor
293		}
294	} else {
295		cursor := m.modelList.Cursor()
296		if cursor != nil {
297			cursor = m.moveCursor(cursor)
298			return cursor
299		}
300	}
301	return nil
302}
303
304func (m *modelDialogCmp) style() lipgloss.Style {
305	t := styles.CurrentTheme()
306	return t.S().Base.
307		Width(m.width).
308		Border(lipgloss.RoundedBorder()).
309		BorderForeground(t.BorderFocus)
310}
311
312func (m *modelDialogCmp) listWidth() int {
313	return defaultWidth - 2 // 4 for padding
314}
315
316func (m *modelDialogCmp) listHeight() int {
317	items := m.modelList.Items()
318	listHeigh := len(items) + 2 + 4
319	return min(listHeigh, m.wHeight/2)
320}
321
322func (m *modelDialogCmp) Position() (int, int) {
323	row := m.wHeight/4 - 2 // just a bit above the center
324	col := m.wWidth / 2
325	col -= m.width / 2
326	return row, col
327}
328
329func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
330	row, col := m.Position()
331	if m.needsAPIKey {
332		offset := row + 3 // Border + title + API key input offset
333		cursor.Y += offset
334		cursor.X = cursor.X + col + 2
335	} else {
336		offset := row + 3 // Border + title
337		cursor.Y += offset
338		cursor.X = cursor.X + col + 2
339	}
340	return cursor
341}
342
343func (m *modelDialogCmp) ID() dialogs.DialogID {
344	return ModelsDialogID
345}
346
347func (m *modelDialogCmp) modelTypeRadio() string {
348	t := styles.CurrentTheme()
349	choices := []string{"Large Task", "Small Task"}
350	iconSelected := "◉"
351	iconUnselected := "○"
352	if m.modelList.GetModelType() == LargeModelType {
353		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
354	}
355	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
356}
357
358func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
359	cfg := config.Get()
360	if _, ok := cfg.Providers.Get(providerID); ok {
361		return true
362	}
363	return false
364}
365
366func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
367	providers, err := config.Providers()
368	if err != nil {
369		return nil, err
370	}
371	for _, p := range providers {
372		if p.ID == providerID {
373			return &p, nil
374		}
375	}
376	return nil, nil
377}
378
379func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
380	if m.selectedModel == nil {
381		return util.ReportError(fmt.Errorf("no model selected"))
382	}
383
384	cfg := config.Get()
385	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
386	if err != nil {
387		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
388	}
389
390	// Reset API key state and continue with model selection
391	selectedModel := *m.selectedModel
392	return tea.Sequence(
393		util.CmdHandler(dialogs.CloseDialogMsg{}),
394		util.CmdHandler(ModelSelectedMsg{
395			Model: config.SelectedModel{
396				Model:    selectedModel.Model.ID,
397				Provider: string(selectedModel.Provider.ID),
398			},
399			ModelType: m.selectedModelType,
400		}),
401	)
402}