models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/help"
 10	"charm.land/bubbles/v2/key"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"charm.land/lipgloss/v2"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/ui/common"
 17	"github.com/charmbracelet/crush/internal/uiutil"
 18)
 19
 20// ModelType represents the type of model to select.
 21type ModelType int
 22
 23const (
 24	ModelTypeLarge ModelType = iota
 25	ModelTypeSmall
 26)
 27
 28// String returns the string representation of the [ModelType].
 29func (mt ModelType) String() string {
 30	switch mt {
 31	case ModelTypeLarge:
 32		return "Large Task"
 33	case ModelTypeSmall:
 34		return "Small Task"
 35	default:
 36		return "Unknown"
 37	}
 38}
 39
 40// Config returns the corresponding config model type.
 41func (mt ModelType) Config() config.SelectedModelType {
 42	switch mt {
 43	case ModelTypeLarge:
 44		return config.SelectedModelTypeLarge
 45	case ModelTypeSmall:
 46		return config.SelectedModelTypeSmall
 47	default:
 48		return ""
 49	}
 50}
 51
 52// Placeholder returns the input placeholder for the model type.
 53func (mt ModelType) Placeholder() string {
 54	switch mt {
 55	case ModelTypeLarge:
 56		return largeModelInputPlaceholder
 57	case ModelTypeSmall:
 58		return smallModelInputPlaceholder
 59	default:
 60		return ""
 61	}
 62}
 63
// Placeholder texts for the dialog's text input, one per model type.
const (
	// largeModelInputPlaceholder is returned by Placeholder for ModelTypeLarge.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	// smallModelInputPlaceholder is returned by Placeholder for ModelTypeSmall.
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)

// ModelsID is the identifier for the model selection dialog; it is the value
// returned by (*Models).ID.
const ModelsID = "models"
 71
// Models represents a model selection dialog. It combines a text input used
// as a filter, a grouped list of models, and a help line.
type Models struct {
	// com is the shared UI context; the dialog reads its styles and config
	// from it.
	com *common.Common

	// modelType is the model type currently being selected; the Tab binding
	// toggles it between ModelTypeLarge and ModelTypeSmall.
	modelType ModelType
	// providers holds the result of getFilteredProviders, captured once in
	// NewModels.
	providers []catwalk.Provider

	// width and height are the dimensions last passed to SetSize.
	width, height int

	// keyMap holds the key bindings used by Update and the help views.
	keyMap struct {
		// Tab toggles the model type.
		Tab      key.Binding
		// UpDown is only listed in ShortHelp; Update matches Next and
		// Previous instead.
		UpDown   key.Binding
		// Select confirms the currently selected list item.
		Select   key.Binding
		// Next moves the selection down, wrapping to the first item.
		Next     key.Binding
		// Previous moves the selection up, wrapping to the last item.
		Previous key.Binding
		// Close dismisses the dialog.
		Close    key.Binding
	}
	// list is the grouped model list rendered below the input.
	list  *ModelsList
	// input is the text input whose value is applied as the list filter.
	input textinput.Model
	// help renders the key binding help returned by ShortHelp and FullHelp.
	help  help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
 95
 96// NewModels creates a new Models dialog.
 97func NewModels(com *common.Common) (*Models, error) {
 98	t := com.Styles
 99	m := &Models{}
100	m.com = com
101	help := help.New()
102	help.Styles = t.DialogHelpStyles()
103
104	m.help = help
105	m.list = NewModelsList(t)
106	m.list.Focus()
107	m.list.SetSelected(0)
108
109	m.input = textinput.New()
110	m.input.SetVirtualCursor(false)
111	m.input.Placeholder = largeModelInputPlaceholder
112	m.input.SetStyles(com.Styles.TextInput)
113	m.input.Focus()
114
115	m.keyMap.Tab = key.NewBinding(
116		key.WithKeys("tab", "shift+tab"),
117		key.WithHelp("tab", "toggle type"),
118	)
119	m.keyMap.Select = key.NewBinding(
120		key.WithKeys("enter", "ctrl+y"),
121		key.WithHelp("enter", "confirm"),
122	)
123	m.keyMap.UpDown = key.NewBinding(
124		key.WithKeys("up", "down"),
125		key.WithHelp("↑/↓", "choose"),
126	)
127	m.keyMap.Next = key.NewBinding(
128		key.WithKeys("down", "ctrl+n"),
129		key.WithHelp("↓", "next item"),
130	)
131	m.keyMap.Previous = key.NewBinding(
132		key.WithKeys("up", "ctrl+p"),
133		key.WithHelp("↑", "previous item"),
134	)
135	m.keyMap.Close = CloseKey
136
137	providers, err := getFilteredProviders(com.Config())
138	if err != nil {
139		return nil, fmt.Errorf("failed to get providers: %w", err)
140	}
141
142	m.providers = providers
143	if err := m.setProviderItems(); err != nil {
144		return nil, fmt.Errorf("failed to set provider items: %w", err)
145	}
146
147	return m, nil
148}
149
150// SetSize sets the size of the dialog.
151func (m *Models) SetSize(width, height int) {
152	t := m.com.Styles
153	m.width = width
154	m.height = height
155	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
156	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
157		t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
158		t.Dialog.HelpView.GetVerticalFrameSize() +
159		t.Dialog.View.GetVerticalFrameSize()
160	m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
161	m.list.SetSize(innerWidth, height-heightOffset)
162	m.help.SetWidth(width)
163}
164
// ID implements Dialog. It always returns ModelsID.
func (m *Models) ID() string {
	return ModelsID
}
169
170// Update implements Dialog.
171func (m *Models) Update(msg tea.Msg) tea.Msg {
172	switch msg := msg.(type) {
173	case tea.KeyPressMsg:
174		switch {
175		case key.Matches(msg, m.keyMap.Close):
176			return CloseMsg{}
177		case key.Matches(msg, m.keyMap.Previous):
178			m.list.Focus()
179			if m.list.IsSelectedFirst() {
180				m.list.SelectLast()
181				m.list.ScrollToBottom()
182				break
183			}
184			m.list.SelectPrev()
185			m.list.ScrollToSelected()
186		case key.Matches(msg, m.keyMap.Next):
187			m.list.Focus()
188			if m.list.IsSelectedLast() {
189				m.list.SelectFirst()
190				m.list.ScrollToTop()
191				break
192			}
193			m.list.SelectNext()
194			m.list.ScrollToSelected()
195		case key.Matches(msg, m.keyMap.Select):
196			selectedItem := m.list.SelectedItem()
197			if selectedItem == nil {
198				break
199			}
200
201			modelItem, ok := selectedItem.(*ModelItem)
202			if !ok {
203				break
204			}
205
206			return ModelSelectedMsg{
207				Model:     modelItem.SelectedModel(),
208				ModelType: modelItem.SelectedModelType(),
209			}
210		case key.Matches(msg, m.keyMap.Tab):
211			if m.modelType == ModelTypeLarge {
212				m.modelType = ModelTypeSmall
213			} else {
214				m.modelType = ModelTypeLarge
215			}
216			if err := m.setProviderItems(); err != nil {
217				return uiutil.ReportError(err)
218			}
219		default:
220			var cmd tea.Cmd
221			m.input, cmd = m.input.Update(msg)
222			value := m.input.Value()
223			m.list.SetFilter(value)
224			m.list.ScrollToSelected()
225			return cmd
226		}
227	}
228	return nil
229}
230
231// Cursor returns the cursor for the dialog.
232func (m *Models) Cursor() *tea.Cursor {
233	return InputCursor(m.com.Styles, m.input.Cursor())
234}
235
236// modelTypeRadioView returns the radio view for model type selection.
237func (m *Models) modelTypeRadioView() string {
238	t := m.com.Styles
239	textStyle := t.HalfMuted
240	largeRadioStyle := t.RadioOff
241	smallRadioStyle := t.RadioOff
242	if m.modelType == ModelTypeLarge {
243		largeRadioStyle = t.RadioOn
244	} else {
245		smallRadioStyle = t.RadioOn
246	}
247
248	largeRadio := largeRadioStyle.Padding(0, 1).Render()
249	smallRadio := smallRadioStyle.Padding(0, 1).Render()
250
251	return fmt.Sprintf("%s%s  %s%s",
252		largeRadio, textStyle.Render(ModelTypeLarge.String()),
253		smallRadio, textStyle.Render(ModelTypeSmall.String()))
254}
255
256// View implements Dialog.
257func (m *Models) View() string {
258	t := m.com.Styles
259	titleStyle := t.Dialog.Title
260	dialogStyle := t.Dialog.View
261
262	radios := m.modelTypeRadioView()
263
264	headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
265		dialogStyle.GetHorizontalFrameSize()
266
267	header := common.DialogTitle(t, "Switch Model", m.width-headerOffset) + radios
268
269	return HeaderInputListHelpView(t, m.width, m.list.Height(), header,
270		m.input.View(), m.list.Render(), m.help.View(m))
271}
272
273// ShortHelp returns the short help view.
274func (m *Models) ShortHelp() []key.Binding {
275	return []key.Binding{
276		m.keyMap.UpDown,
277		m.keyMap.Tab,
278		m.keyMap.Select,
279		m.keyMap.Close,
280	}
281}
282
283// FullHelp returns the full help view.
284func (m *Models) FullHelp() [][]key.Binding {
285	return [][]key.Binding{
286		{
287			m.keyMap.Select,
288			m.keyMap.Next,
289			m.keyMap.Previous,
290			m.keyMap.Tab,
291		},
292		{
293			m.keyMap.Close,
294		},
295	}
296}
297
298// setProviderItems sets the provider items in the list.
299func (m *Models) setProviderItems() error {
300	t := m.com.Styles
301	cfg := m.com.Config()
302
303	var selectedItemID string
304	selectedType := m.modelType.Config()
305	currentModel := cfg.Models[selectedType]
306	recentItems := cfg.RecentModels[selectedType]
307
308	// Track providers already added to avoid duplicates
309	addedProviders := make(map[string]bool)
310
311	// Get a list of known providers to compare against
312	knownProviders, err := config.Providers(cfg)
313	if err != nil {
314		return fmt.Errorf("failed to get providers: %w", err)
315	}
316
317	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
318		return func(p catwalk.Provider) bool {
319			return p.ID == catwalk.InferenceProvider(id)
320		}
321	}
322
323	// itemsMap contains the keys of added model items.
324	itemsMap := make(map[string]*ModelItem)
325	groups := []ModelGroup{}
326	for id, p := range cfg.Providers.Seq2() {
327		if p.Disable {
328			continue
329		}
330
331		// Check if this provider is not in the known providers list
332		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
333			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
334			provider := p.ToProvider()
335
336			// Add this unknown provider to the list
337			name := p.Name
338			if name == "" {
339				name = id
340			}
341
342			addedProviders[id] = true
343
344			group := NewModelGroup(t, name, true)
345			for _, model := range p.Models {
346				item := NewModelItem(t, provider, model, m.modelType, false)
347				group.AppendItems(item)
348				itemsMap[item.ID()] = item
349				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
350					selectedItemID = item.ID()
351				}
352			}
353		}
354	}
355
356	// Now add known providers from the predefined list
357	for _, provider := range m.providers {
358		providerID := string(provider.ID)
359		if addedProviders[providerID] {
360			continue
361		}
362
363		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
364		if providerConfigured && providerConfig.Disable {
365			continue
366		}
367
368		displayProvider := provider
369		if providerConfigured {
370			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
371			modelIndex := make(map[string]int, len(displayProvider.Models))
372			for i, model := range displayProvider.Models {
373				modelIndex[model.ID] = i
374			}
375			for _, model := range providerConfig.Models {
376				if model.ID == "" {
377					continue
378				}
379				if idx, ok := modelIndex[model.ID]; ok {
380					if model.Name != "" {
381						displayProvider.Models[idx].Name = model.Name
382					}
383					continue
384				}
385				if model.Name == "" {
386					model.Name = model.ID
387				}
388				displayProvider.Models = append(displayProvider.Models, model)
389				modelIndex[model.ID] = len(displayProvider.Models) - 1
390			}
391		}
392
393		name := displayProvider.Name
394		if name == "" {
395			name = providerID
396		}
397
398		group := NewModelGroup(t, name, providerConfigured)
399		for _, model := range displayProvider.Models {
400			item := NewModelItem(t, provider, model, m.modelType, false)
401			group.AppendItems(item)
402			itemsMap[item.ID()] = item
403			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
404				selectedItemID = item.ID()
405			}
406		}
407
408		groups = append(groups, group)
409	}
410
411	if len(recentItems) > 0 {
412		recentGroup := NewModelGroup(t, "Recently used", false)
413
414		var validRecentItems []config.SelectedModel
415		for _, recent := range recentItems {
416			key := modelKey(recent.Provider, recent.Model)
417			item, ok := itemsMap[key]
418			if !ok {
419				continue
420			}
421
422			// Show provider for recent items
423			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
424			item.showProvider = true
425
426			validRecentItems = append(validRecentItems, recent)
427			recentGroup.AppendItems(item)
428			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
429				selectedItemID = item.ID()
430			}
431		}
432
433		if len(validRecentItems) != len(recentItems) {
434			// FIXME: Does this need to be here? Is it mutating the config during a read?
435			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
436				return fmt.Errorf("failed to update recent models: %w", err)
437			}
438		}
439
440		if len(recentGroup.Items) > 0 {
441			groups = append([]ModelGroup{recentGroup}, groups...)
442		}
443	}
444
445	// Set model groups in the list.
446	m.list.SetGroups(groups...)
447	m.list.SetSelectedItem(selectedItemID)
448
449	// Update placeholder based on model type
450	m.input.Placeholder = m.modelType.Placeholder()
451
452	return nil
453}
454
455func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
456	providers, err := config.Providers(cfg)
457	if err != nil {
458		return nil, fmt.Errorf("failed to get providers: %w", err)
459	}
460	filteredProviders := []catwalk.Provider{}
461	for _, p := range providers {
462		hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
463		if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
464			filteredProviders = append(filteredProviders, p)
465		}
466	}
467	return filteredProviders, nil
468}
469
// modelKey builds the "provider:model" lookup key for a model. It returns
// the empty string when either identifier is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}