models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/help"
 10	"charm.land/bubbles/v2/key"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/uiutil"
 17	uv "github.com/charmbracelet/ultraviolet"
 18)
 19
 20// ModelType represents the type of model to select.
 21type ModelType int
 22
 23const (
 24	ModelTypeLarge ModelType = iota
 25	ModelTypeSmall
 26)
 27
 28// String returns the string representation of the [ModelType].
 29func (mt ModelType) String() string {
 30	switch mt {
 31	case ModelTypeLarge:
 32		return "Large Task"
 33	case ModelTypeSmall:
 34		return "Small Task"
 35	default:
 36		return "Unknown"
 37	}
 38}
 39
 40// Config returns the corresponding config model type.
 41func (mt ModelType) Config() config.SelectedModelType {
 42	switch mt {
 43	case ModelTypeLarge:
 44		return config.SelectedModelTypeLarge
 45	case ModelTypeSmall:
 46		return config.SelectedModelTypeSmall
 47	default:
 48		return ""
 49	}
 50}
 51
 52// Placeholder returns the input placeholder for the model type.
 53func (mt ModelType) Placeholder() string {
 54	switch mt {
 55	case ModelTypeLarge:
 56		return largeModelInputPlaceholder
 57	case ModelTypeSmall:
 58		return smallModelInputPlaceholder
 59	default:
 60		return ""
 61	}
 62}
 63
// Placeholder texts for the dialog's filter input, one per model type.
const (
	// largeModelInputPlaceholder is returned by Placeholder for
	// ModelTypeLarge and is the initial placeholder set by NewModels.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	// smallModelInputPlaceholder is returned by Placeholder for
	// ModelTypeSmall.
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 68
// ModelsID is the identifier for the model selection dialog. It is the value
// returned by (*Models).ID.
const ModelsID = "models"
 71
// Models represents a model selection dialog. It combines a filter input, a
// grouped list of models per provider, and a help line, and lets the user
// toggle between the large and small model types.
type Models struct {
	// com gives access to the shared styles and the configuration.
	com *common.Common

	// modelType is the model type currently being selected; it is toggled
	// by the Tab binding.
	modelType ModelType
	// providers holds the providers returned by getFilteredProviders when
	// the dialog was created.
	providers []catwalk.Provider

	// keyMap holds the key bindings of the dialog.
	keyMap struct {
		// Tab toggles between the large and small model type.
		Tab key.Binding
		// UpDown is only listed in the short help; navigation itself is
		// matched through Next and Previous.
		UpDown key.Binding
		// Select confirms the currently selected model.
		Select key.Binding
		// Next moves the selection down, wrapping to the first item.
		Next key.Binding
		// Previous moves the selection up, wrapping to the last item.
		Previous key.Binding
		// Close closes the dialog.
		Close key.Binding
	}
	// list is the grouped model list.
	list *ModelsList
	// input is the text input whose value is used as the list filter.
	input textinput.Model
	// help renders the key binding help for the dialog.
	help help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
 93
 94// NewModels creates a new Models dialog.
 95func NewModels(com *common.Common) (*Models, error) {
 96	t := com.Styles
 97	m := &Models{}
 98	m.com = com
 99	help := help.New()
100	help.Styles = t.DialogHelpStyles()
101
102	m.help = help
103	m.list = NewModelsList(t)
104	m.list.Focus()
105	m.list.SetSelected(0)
106
107	m.input = textinput.New()
108	m.input.SetVirtualCursor(false)
109	m.input.Placeholder = largeModelInputPlaceholder
110	m.input.SetStyles(com.Styles.TextInput)
111	m.input.Focus()
112
113	m.keyMap.Tab = key.NewBinding(
114		key.WithKeys("tab", "shift+tab"),
115		key.WithHelp("tab", "toggle type"),
116	)
117	m.keyMap.Select = key.NewBinding(
118		key.WithKeys("enter", "ctrl+y"),
119		key.WithHelp("enter", "confirm"),
120	)
121	m.keyMap.UpDown = key.NewBinding(
122		key.WithKeys("up", "down"),
123		key.WithHelp("↑/↓", "choose"),
124	)
125	m.keyMap.Next = key.NewBinding(
126		key.WithKeys("down", "ctrl+n"),
127		key.WithHelp("↓", "next item"),
128	)
129	m.keyMap.Previous = key.NewBinding(
130		key.WithKeys("up", "ctrl+p"),
131		key.WithHelp("↑", "previous item"),
132	)
133	m.keyMap.Close = CloseKey
134
135	providers, err := getFilteredProviders(com.Config())
136	if err != nil {
137		return nil, fmt.Errorf("failed to get providers: %w", err)
138	}
139
140	m.providers = providers
141	if err := m.setProviderItems(); err != nil {
142		return nil, fmt.Errorf("failed to set provider items: %w", err)
143	}
144
145	return m, nil
146}
147
// ID implements [Dialog]. It returns [ModelsID], the identifier of the model
// selection dialog.
func (m *Models) ID() string {
	return ModelsID
}
152
153// HandleMsg implements Dialog.
154func (m *Models) HandleMsg(msg tea.Msg) Action {
155	switch msg := msg.(type) {
156	case tea.KeyPressMsg:
157		switch {
158		case key.Matches(msg, m.keyMap.Close):
159			return ActionClose{}
160		case key.Matches(msg, m.keyMap.Previous):
161			m.list.Focus()
162			if m.list.IsSelectedFirst() {
163				m.list.SelectLast()
164				m.list.ScrollToBottom()
165				break
166			}
167			m.list.SelectPrev()
168			m.list.ScrollToSelected()
169		case key.Matches(msg, m.keyMap.Next):
170			m.list.Focus()
171			if m.list.IsSelectedLast() {
172				m.list.SelectFirst()
173				m.list.ScrollToTop()
174				break
175			}
176			m.list.SelectNext()
177			m.list.ScrollToSelected()
178		case key.Matches(msg, m.keyMap.Select):
179			selectedItem := m.list.SelectedItem()
180			if selectedItem == nil {
181				break
182			}
183
184			modelItem, ok := selectedItem.(*ModelItem)
185			if !ok {
186				break
187			}
188
189			return ActionSelectModel{
190				Provider:  modelItem.prov,
191				Model:     modelItem.SelectedModel(),
192				ModelType: modelItem.SelectedModelType(),
193			}
194		case key.Matches(msg, m.keyMap.Tab):
195			if m.modelType == ModelTypeLarge {
196				m.modelType = ModelTypeSmall
197			} else {
198				m.modelType = ModelTypeLarge
199			}
200			if err := m.setProviderItems(); err != nil {
201				return uiutil.ReportError(err)
202			}
203		default:
204			var cmd tea.Cmd
205			m.input, cmd = m.input.Update(msg)
206			value := m.input.Value()
207			m.list.SetFilter(value)
208			m.list.ScrollToSelected()
209			return ActionCmd{cmd}
210		}
211	}
212	return nil
213}
214
215// Cursor returns the cursor for the dialog.
216func (m *Models) Cursor() *tea.Cursor {
217	return InputCursor(m.com.Styles, m.input.Cursor())
218}
219
220// modelTypeRadioView returns the radio view for model type selection.
221func (m *Models) modelTypeRadioView() string {
222	t := m.com.Styles
223	textStyle := t.HalfMuted
224	largeRadioStyle := t.RadioOff
225	smallRadioStyle := t.RadioOff
226	if m.modelType == ModelTypeLarge {
227		largeRadioStyle = t.RadioOn
228	} else {
229		smallRadioStyle = t.RadioOn
230	}
231
232	largeRadio := largeRadioStyle.Padding(0, 1).Render()
233	smallRadio := smallRadioStyle.Padding(0, 1).Render()
234
235	return fmt.Sprintf("%s%s  %s%s",
236		largeRadio, textStyle.Render(ModelTypeLarge.String()),
237		smallRadio, textStyle.Render(ModelTypeSmall.String()))
238}
239
240// Draw implements [Dialog].
241func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
242	t := m.com.Styles
243	width := max(0, min(defaultDialogMaxWidth, area.Dx()))
244	height := max(0, min(defaultDialogHeight, area.Dy()))
245	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
246	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
247		t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
248		t.Dialog.HelpView.GetVerticalFrameSize() +
249		t.Dialog.View.GetVerticalFrameSize()
250	m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
251	m.list.SetSize(innerWidth, height-heightOffset)
252	m.help.SetWidth(innerWidth)
253
254	rc := NewRenderContext(t, width)
255	rc.Title = "Switch Model"
256	rc.TitleInfo = m.modelTypeRadioView()
257	inputView := t.Dialog.InputPrompt.Render(m.input.View())
258	rc.AddPart(inputView)
259	listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
260	rc.AddPart(listView)
261	rc.Help = m.help.View(m)
262
263	view := rc.Render()
264
265	cur := m.Cursor()
266	DrawCenterCursor(scr, area, view, cur)
267	return cur
268}
269
270// ShortHelp returns the short help view.
271func (m *Models) ShortHelp() []key.Binding {
272	return []key.Binding{
273		m.keyMap.UpDown,
274		m.keyMap.Tab,
275		m.keyMap.Select,
276		m.keyMap.Close,
277	}
278}
279
280// FullHelp returns the full help view.
281func (m *Models) FullHelp() [][]key.Binding {
282	return [][]key.Binding{
283		{
284			m.keyMap.Select,
285			m.keyMap.Next,
286			m.keyMap.Previous,
287			m.keyMap.Tab,
288		},
289		{
290			m.keyMap.Close,
291		},
292	}
293}
294
295// setProviderItems sets the provider items in the list.
296func (m *Models) setProviderItems() error {
297	t := m.com.Styles
298	cfg := m.com.Config()
299
300	var selectedItemID string
301	selectedType := m.modelType.Config()
302	currentModel := cfg.Models[selectedType]
303	recentItems := cfg.RecentModels[selectedType]
304
305	// Track providers already added to avoid duplicates
306	addedProviders := make(map[string]bool)
307
308	// Get a list of known providers to compare against
309	knownProviders, err := config.Providers(cfg)
310	if err != nil {
311		return fmt.Errorf("failed to get providers: %w", err)
312	}
313
314	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
315		return func(p catwalk.Provider) bool {
316			return p.ID == catwalk.InferenceProvider(id)
317		}
318	}
319
320	// itemsMap contains the keys of added model items.
321	itemsMap := make(map[string]*ModelItem)
322	groups := []ModelGroup{}
323	for id, p := range cfg.Providers.Seq2() {
324		if p.Disable {
325			continue
326		}
327
328		// Check if this provider is not in the known providers list
329		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
330			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
331			provider := p.ToProvider()
332
333			// Add this unknown provider to the list
334			name := cmp.Or(p.Name, id)
335
336			addedProviders[id] = true
337
338			group := NewModelGroup(t, name, true)
339			for _, model := range p.Models {
340				item := NewModelItem(t, provider, model, m.modelType, false)
341				group.AppendItems(item)
342				itemsMap[item.ID()] = item
343				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
344					selectedItemID = item.ID()
345				}
346			}
347			if len(group.Items) > 0 {
348				groups = append(groups, group)
349			}
350		}
351	}
352
353	// Now add known providers from the predefined list
354	for _, provider := range m.providers {
355		providerID := string(provider.ID)
356		if addedProviders[providerID] {
357			continue
358		}
359
360		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
361		if providerConfigured && providerConfig.Disable {
362			continue
363		}
364
365		displayProvider := provider
366		if providerConfigured {
367			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
368			modelIndex := make(map[string]int, len(displayProvider.Models))
369			for i, model := range displayProvider.Models {
370				modelIndex[model.ID] = i
371			}
372			for _, model := range providerConfig.Models {
373				if model.ID == "" {
374					continue
375				}
376				if idx, ok := modelIndex[model.ID]; ok {
377					if model.Name != "" {
378						displayProvider.Models[idx].Name = model.Name
379					}
380					continue
381				}
382				if model.Name == "" {
383					model.Name = model.ID
384				}
385				displayProvider.Models = append(displayProvider.Models, model)
386				modelIndex[model.ID] = len(displayProvider.Models) - 1
387			}
388		}
389
390		name := displayProvider.Name
391		if name == "" {
392			name = providerID
393		}
394
395		group := NewModelGroup(t, name, providerConfigured)
396		for _, model := range displayProvider.Models {
397			item := NewModelItem(t, provider, model, m.modelType, false)
398			group.AppendItems(item)
399			itemsMap[item.ID()] = item
400			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
401				selectedItemID = item.ID()
402			}
403		}
404
405		groups = append(groups, group)
406	}
407
408	if len(recentItems) > 0 {
409		recentGroup := NewModelGroup(t, "Recently used", false)
410
411		var validRecentItems []config.SelectedModel
412		for _, recent := range recentItems {
413			key := modelKey(recent.Provider, recent.Model)
414			item, ok := itemsMap[key]
415			if !ok {
416				continue
417			}
418
419			// Show provider for recent items
420			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
421			item.showProvider = true
422
423			validRecentItems = append(validRecentItems, recent)
424			recentGroup.AppendItems(item)
425			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
426				selectedItemID = item.ID()
427			}
428		}
429
430		if len(validRecentItems) != len(recentItems) {
431			// FIXME: Does this need to be here? Is it mutating the config during a read?
432			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
433				return fmt.Errorf("failed to update recent models: %w", err)
434			}
435		}
436
437		if len(recentGroup.Items) > 0 {
438			groups = append([]ModelGroup{recentGroup}, groups...)
439		}
440	}
441
442	// Set model groups in the list.
443	m.list.SetGroups(groups...)
444	m.list.SetSelectedItem(selectedItemID)
445
446	// Update placeholder based on model type
447	m.input.Placeholder = m.modelType.Placeholder()
448
449	return nil
450}
451
452func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
453	providers, err := config.Providers(cfg)
454	if err != nil {
455		return nil, fmt.Errorf("failed to get providers: %w", err)
456	}
457	filteredProviders := []catwalk.Provider{}
458	for _, p := range providers {
459		hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
460		if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
461			filteredProviders = append(filteredProviders, p)
462		}
463	}
464	return filteredProviders, nil
465}
466
// modelKey builds the "provider:model" lookup key for a model. It returns the
// empty string when either the provider ID or the model ID is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}