models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/help"
 10	"charm.land/bubbles/v2/key"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"charm.land/lipgloss/v2"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/ui/common"
 17	"github.com/charmbracelet/crush/internal/uiutil"
 18)
 19
 20// ModelType represents the type of model to select.
 21type ModelType int
 22
 23const (
 24	ModelTypeLarge ModelType = iota
 25	ModelTypeSmall
 26)
 27
 28// String returns the string representation of the [ModelType].
 29func (mt ModelType) String() string {
 30	switch mt {
 31	case ModelTypeLarge:
 32		return "Large Task"
 33	case ModelTypeSmall:
 34		return "Small Task"
 35	default:
 36		return "Unknown"
 37	}
 38}
 39
 40const (
 41	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 42	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 43)
 44
 45// ModelsID is the identifier for the model selection dialog.
 46const ModelsID = "models"
 47
 48// Models represents a model selection dialog.
 49type Models struct {
 50	com *common.Common
 51
 52	modelType ModelType
 53	providers []catwalk.Provider
 54
 55	width, height int
 56
 57	keyMap struct {
 58		Tab      key.Binding
 59		UpDown   key.Binding
 60		Select   key.Binding
 61		Next     key.Binding
 62		Previous key.Binding
 63		Close    key.Binding
 64	}
 65	list  *ModelsList
 66	input textinput.Model
 67	help  help.Model
 68}
 69
 70var _ Dialog = (*Models)(nil)
 71
 72// NewModels creates a new Models dialog.
 73func NewModels(com *common.Common) (*Models, error) {
 74	t := com.Styles
 75	m := &Models{}
 76	m.com = com
 77	help := help.New()
 78	help.Styles = t.DialogHelpStyles()
 79
 80	m.help = help
 81	m.list = NewModelsList(t)
 82	m.list.Focus()
 83	m.list.SetSelected(0)
 84
 85	m.input = textinput.New()
 86	m.input.SetVirtualCursor(false)
 87	m.input.Placeholder = largeModelInputPlaceholder
 88	m.input.SetStyles(com.Styles.TextInput)
 89	m.input.Focus()
 90
 91	m.keyMap.Tab = key.NewBinding(
 92		key.WithKeys("tab", "shift+tab"),
 93		key.WithHelp("tab", "toggle type"),
 94	)
 95	m.keyMap.Select = key.NewBinding(
 96		key.WithKeys("enter", "ctrl+y"),
 97		key.WithHelp("enter", "confirm"),
 98	)
 99	m.keyMap.UpDown = key.NewBinding(
100		key.WithKeys("up", "down"),
101		key.WithHelp("↑/↓", "choose"),
102	)
103	m.keyMap.Next = key.NewBinding(
104		key.WithKeys("down", "ctrl+n"),
105		key.WithHelp("↓", "next item"),
106	)
107	m.keyMap.Previous = key.NewBinding(
108		key.WithKeys("up", "ctrl+p"),
109		key.WithHelp("↑", "previous item"),
110	)
111	m.keyMap.Close = CloseKey
112
113	providers, err := getFilteredProviders(com.Config())
114	if err != nil {
115		return nil, fmt.Errorf("failed to get providers: %w", err)
116	}
117
118	m.providers = providers
119	if err := m.setProviderItems(); err != nil {
120		return nil, fmt.Errorf("failed to set provider items: %w", err)
121	}
122
123	return m, nil
124}
125
126// SetSize sets the size of the dialog.
127func (m *Models) SetSize(width, height int) {
128	t := m.com.Styles
129	m.width = width
130	m.height = height
131	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
132	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
133		t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
134		t.Dialog.HelpView.GetVerticalFrameSize() +
135		t.Dialog.View.GetVerticalFrameSize()
136	m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
137	m.list.SetSize(innerWidth, height-heightOffset)
138	m.help.SetWidth(width)
139}
140
141// ID implements Dialog.
142func (m *Models) ID() string {
143	return ModelsID
144}
145
146// Update implements Dialog.
147func (m *Models) Update(msg tea.Msg) tea.Msg {
148	switch msg := msg.(type) {
149	case tea.KeyPressMsg:
150		switch {
151		case key.Matches(msg, m.keyMap.Close):
152			return CloseMsg{}
153		case key.Matches(msg, m.keyMap.Previous):
154			m.list.Focus()
155			if m.list.IsSelectedFirst() {
156				m.list.SelectLast()
157				m.list.ScrollToBottom()
158				break
159			}
160			m.list.SelectPrev()
161			m.list.ScrollToSelected()
162		case key.Matches(msg, m.keyMap.Next):
163			m.list.Focus()
164			if m.list.IsSelectedLast() {
165				m.list.SelectFirst()
166				m.list.ScrollToTop()
167				break
168			}
169			m.list.SelectNext()
170			m.list.ScrollToSelected()
171		case key.Matches(msg, m.keyMap.Select):
172			selectedItem := m.list.SelectedItem()
173			if selectedItem == nil {
174				break
175			}
176
177			modelItem, ok := selectedItem.(*ModelItem)
178			if !ok {
179				break
180			}
181
182			return ModelSelectedMsg{
183				Provider: modelItem.prov,
184				Model:    modelItem.model,
185			}
186		case key.Matches(msg, m.keyMap.Tab):
187			if m.modelType == ModelTypeLarge {
188				m.modelType = ModelTypeSmall
189			} else {
190				m.modelType = ModelTypeLarge
191			}
192			if err := m.setProviderItems(); err != nil {
193				return uiutil.ReportError(err)
194			}
195		default:
196			var cmd tea.Cmd
197			m.input, cmd = m.input.Update(msg)
198			value := m.input.Value()
199			m.list.SetFilter(value)
200			m.list.ScrollToSelected()
201			return cmd
202		}
203	}
204	return nil
205}
206
207// Cursor returns the cursor for the dialog.
208func (m *Models) Cursor() *tea.Cursor {
209	return InputCursor(m.com.Styles, m.input.Cursor())
210}
211
212// modelTypeRadioView returns the radio view for model type selection.
213func (m *Models) modelTypeRadioView() string {
214	t := m.com.Styles
215	textStyle := t.HalfMuted
216	largeRadioStyle := t.RadioOff
217	smallRadioStyle := t.RadioOff
218	if m.modelType == ModelTypeLarge {
219		largeRadioStyle = t.RadioOn
220	} else {
221		smallRadioStyle = t.RadioOn
222	}
223
224	largeRadio := largeRadioStyle.Padding(0, 1).Render()
225	smallRadio := smallRadioStyle.Padding(0, 1).Render()
226
227	return fmt.Sprintf("%s%s  %s%s",
228		largeRadio, textStyle.Render(ModelTypeLarge.String()),
229		smallRadio, textStyle.Render(ModelTypeSmall.String()))
230}
231
232// View implements Dialog.
233func (m *Models) View() string {
234	t := m.com.Styles
235	titleStyle := t.Dialog.Title
236	dialogStyle := t.Dialog.View
237
238	radios := m.modelTypeRadioView()
239
240	headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
241		dialogStyle.GetHorizontalFrameSize()
242
243	header := common.DialogTitle(t, "Switch Model", m.width-headerOffset) + radios
244
245	return HeaderInputListHelpView(t, m.width, m.list.Height(), header,
246		m.input.View(), m.list.Render(), m.help.View(m))
247}
248
249// ShortHelp returns the short help view.
250func (m *Models) ShortHelp() []key.Binding {
251	return []key.Binding{
252		m.keyMap.UpDown,
253		m.keyMap.Tab,
254		m.keyMap.Select,
255		m.keyMap.Close,
256	}
257}
258
259// FullHelp returns the full help view.
260func (m *Models) FullHelp() [][]key.Binding {
261	return [][]key.Binding{
262		{
263			m.keyMap.Select,
264			m.keyMap.Next,
265			m.keyMap.Previous,
266			m.keyMap.Tab,
267		},
268		{
269			m.keyMap.Close,
270		},
271	}
272}
273
274// setProviderItems sets the provider items in the list.
275func (m *Models) setProviderItems() error {
276	t := m.com.Styles
277	cfg := m.com.Config()
278
279	selectedType := config.SelectedModelTypeLarge
280	if m.modelType == ModelTypeLarge {
281		selectedType = config.SelectedModelTypeLarge
282	} else {
283		selectedType = config.SelectedModelTypeSmall
284	}
285
286	var selectedItemID string
287	currentModel := cfg.Models[selectedType]
288	recentItems := cfg.RecentModels[selectedType]
289
290	// Track providers already added to avoid duplicates
291	addedProviders := make(map[string]bool)
292
293	// Get a list of known providers to compare against
294	knownProviders, err := config.Providers(cfg)
295	if err != nil {
296		return fmt.Errorf("failed to get providers: %w", err)
297	}
298
299	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
300		return func(p catwalk.Provider) bool {
301			return p.ID == catwalk.InferenceProvider(id)
302		}
303	}
304
305	// itemsMap contains the keys of added model items.
306	itemsMap := make(map[string]*ModelItem)
307	groups := []ModelGroup{}
308	for id, p := range cfg.Providers.Seq2() {
309		if p.Disable {
310			continue
311		}
312
313		// Check if this provider is not in the known providers list
314		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
315			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
316			provider := p.ToProvider()
317
318			// Add this unknown provider to the list
319			name := p.Name
320			if name == "" {
321				name = id
322			}
323
324			addedProviders[id] = true
325
326			group := NewModelGroup(t, name, true)
327			for _, model := range p.Models {
328				item := NewModelItem(t, provider, model, false)
329				group.AppendItems(item)
330				itemsMap[item.ID()] = item
331				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
332					selectedItemID = item.ID()
333				}
334			}
335		}
336	}
337
338	// Now add known providers from the predefined list
339	for _, provider := range m.providers {
340		providerID := string(provider.ID)
341		if addedProviders[providerID] {
342			continue
343		}
344
345		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
346		if providerConfigured && providerConfig.Disable {
347			continue
348		}
349
350		displayProvider := provider
351		if providerConfigured {
352			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
353			modelIndex := make(map[string]int, len(displayProvider.Models))
354			for i, model := range displayProvider.Models {
355				modelIndex[model.ID] = i
356			}
357			for _, model := range providerConfig.Models {
358				if model.ID == "" {
359					continue
360				}
361				if idx, ok := modelIndex[model.ID]; ok {
362					if model.Name != "" {
363						displayProvider.Models[idx].Name = model.Name
364					}
365					continue
366				}
367				if model.Name == "" {
368					model.Name = model.ID
369				}
370				displayProvider.Models = append(displayProvider.Models, model)
371				modelIndex[model.ID] = len(displayProvider.Models) - 1
372			}
373		}
374
375		name := displayProvider.Name
376		if name == "" {
377			name = providerID
378		}
379
380		group := NewModelGroup(t, name, providerConfigured)
381		for _, model := range displayProvider.Models {
382			item := NewModelItem(t, provider, model, false)
383			group.AppendItems(item)
384			itemsMap[item.ID()] = item
385			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
386				selectedItemID = item.ID()
387			}
388		}
389
390		groups = append(groups, group)
391	}
392
393	if len(recentItems) > 0 {
394		recentGroup := NewModelGroup(t, "Recently used", false)
395
396		var validRecentItems []config.SelectedModel
397		for _, recent := range recentItems {
398			key := modelKey(recent.Provider, recent.Model)
399			item, ok := itemsMap[key]
400			if !ok {
401				continue
402			}
403
404			// Show provider for recent items
405			item = NewModelItem(t, item.prov, item.model, true)
406			item.showProvider = true
407
408			validRecentItems = append(validRecentItems, recent)
409			recentGroup.AppendItems(item)
410			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
411				selectedItemID = item.ID()
412			}
413		}
414
415		if len(validRecentItems) != len(recentItems) {
416			// FIXME: Does this need to be here? Is it mutating the config during a read?
417			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
418				return fmt.Errorf("failed to update recent models: %w", err)
419			}
420		}
421
422		if len(recentGroup.Items) > 0 {
423			groups = append([]ModelGroup{recentGroup}, groups...)
424		}
425	}
426
427	// Set model groups in the list.
428	m.list.SetGroups(groups...)
429	m.list.SetSelectedItem(selectedItemID)
430
431	// Update placeholder based on model type
432	if m.modelType == ModelTypeLarge {
433		m.input.Placeholder = largeModelInputPlaceholder
434	} else {
435		m.input.Placeholder = smallModelInputPlaceholder
436	}
437
438	return nil
439}
440
441func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
442	providers, err := config.Providers(cfg)
443	if err != nil {
444		return nil, fmt.Errorf("failed to get providers: %w", err)
445	}
446	filteredProviders := []catwalk.Provider{}
447	for _, p := range providers {
448		hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
449		if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
450			filteredProviders = append(filteredProviders, p)
451		}
452	}
453	return filteredProviders, nil
454}
455
// modelKey joins a provider ID and a model ID into the "provider:model" key
// that setProviderItems uses to look up listed items. It returns "" when
// either part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}