models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/help"
 10	"charm.land/bubbles/v2/key"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"charm.land/lipgloss/v2"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/ui/common"
 17	"github.com/charmbracelet/crush/internal/uiutil"
 18	uv "github.com/charmbracelet/ultraviolet"
 19	"github.com/charmbracelet/x/ansi"
 20)
 21
 22// ModelType represents the type of model to select.
 23type ModelType int
 24
 25const (
 26	ModelTypeLarge ModelType = iota
 27	ModelTypeSmall
 28)
 29
 30// String returns the string representation of the [ModelType].
 31func (mt ModelType) String() string {
 32	switch mt {
 33	case ModelTypeLarge:
 34		return "Large Task"
 35	case ModelTypeSmall:
 36		return "Small Task"
 37	default:
 38		return "Unknown"
 39	}
 40}
 41
 42// Config returns the corresponding config model type.
 43func (mt ModelType) Config() config.SelectedModelType {
 44	switch mt {
 45	case ModelTypeLarge:
 46		return config.SelectedModelTypeLarge
 47	case ModelTypeSmall:
 48		return config.SelectedModelTypeSmall
 49	default:
 50		return ""
 51	}
 52}
 53
 54// Placeholder returns the input placeholder for the model type.
 55func (mt ModelType) Placeholder() string {
 56	switch mt {
 57	case ModelTypeLarge:
 58		return largeModelInputPlaceholder
 59	case ModelTypeSmall:
 60		return smallModelInputPlaceholder
 61	default:
 62		return ""
 63	}
 64}
 65
// Input placeholders shown for each model type; see [ModelType.Placeholder].
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 70
// ModelsID is the identifier for the model selection dialog, as returned by
// [Models.ID].
const ModelsID = "models"
 73
// Models represents a model selection dialog.
type Models struct {
	// com provides the shared styles and config used by the dialog.
	com *common.Common

	// modelType is the model type currently being selected (the Tab binding
	// toggles it); providers holds the result of getFilteredProviders.
	modelType ModelType
	providers []catwalk.Provider

	// keyMap holds the key bindings handled by HandleMsg and listed in the
	// help views.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list is the grouped model list, input is the text input whose value is
	// applied as the list filter, and help renders the key binding help.
	list  *ModelsList
	input textinput.Model
	help  help.Model
}
 93
// Compile-time assertion that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
 95
 96// NewModels creates a new Models dialog.
 97func NewModels(com *common.Common) (*Models, error) {
 98	t := com.Styles
 99	m := &Models{}
100	m.com = com
101	help := help.New()
102	help.Styles = t.DialogHelpStyles()
103
104	m.help = help
105	m.list = NewModelsList(t)
106	m.list.Focus()
107	m.list.SetSelected(0)
108
109	m.input = textinput.New()
110	m.input.SetVirtualCursor(false)
111	m.input.Placeholder = largeModelInputPlaceholder
112	m.input.SetStyles(com.Styles.TextInput)
113	m.input.Focus()
114
115	m.keyMap.Tab = key.NewBinding(
116		key.WithKeys("tab", "shift+tab"),
117		key.WithHelp("tab", "toggle type"),
118	)
119	m.keyMap.Select = key.NewBinding(
120		key.WithKeys("enter", "ctrl+y"),
121		key.WithHelp("enter", "confirm"),
122	)
123	m.keyMap.UpDown = key.NewBinding(
124		key.WithKeys("up", "down"),
125		key.WithHelp("↑/↓", "choose"),
126	)
127	m.keyMap.Next = key.NewBinding(
128		key.WithKeys("down", "ctrl+n"),
129		key.WithHelp("↓", "next item"),
130	)
131	m.keyMap.Previous = key.NewBinding(
132		key.WithKeys("up", "ctrl+p"),
133		key.WithHelp("↑", "previous item"),
134	)
135	m.keyMap.Close = CloseKey
136
137	providers, err := getFilteredProviders(com.Config())
138	if err != nil {
139		return nil, fmt.Errorf("failed to get providers: %w", err)
140	}
141
142	m.providers = providers
143	if err := m.setProviderItems(); err != nil {
144		return nil, fmt.Errorf("failed to set provider items: %w", err)
145	}
146
147	return m, nil
148}
149
150// ID implements Dialog.
151func (m *Models) ID() string {
152	return ModelsID
153}
154
155// HandleMsg implements Dialog.
156func (m *Models) HandleMsg(msg tea.Msg) Action {
157	switch msg := msg.(type) {
158	case tea.KeyPressMsg:
159		switch {
160		case key.Matches(msg, m.keyMap.Close):
161			return ActionClose{}
162		case key.Matches(msg, m.keyMap.Previous):
163			m.list.Focus()
164			if m.list.IsSelectedFirst() {
165				m.list.SelectLast()
166				m.list.ScrollToBottom()
167				break
168			}
169			m.list.SelectPrev()
170			m.list.ScrollToSelected()
171		case key.Matches(msg, m.keyMap.Next):
172			m.list.Focus()
173			if m.list.IsSelectedLast() {
174				m.list.SelectFirst()
175				m.list.ScrollToTop()
176				break
177			}
178			m.list.SelectNext()
179			m.list.ScrollToSelected()
180		case key.Matches(msg, m.keyMap.Select):
181			selectedItem := m.list.SelectedItem()
182			if selectedItem == nil {
183				break
184			}
185
186			modelItem, ok := selectedItem.(*ModelItem)
187			if !ok {
188				break
189			}
190
191			return ActionSelectModel{
192				Model:     modelItem.SelectedModel(),
193				ModelType: modelItem.SelectedModelType(),
194			}
195		case key.Matches(msg, m.keyMap.Tab):
196			if m.modelType == ModelTypeLarge {
197				m.modelType = ModelTypeSmall
198			} else {
199				m.modelType = ModelTypeLarge
200			}
201			if err := m.setProviderItems(); err != nil {
202				return uiutil.ReportError(err)
203			}
204		default:
205			var cmd tea.Cmd
206			m.input, cmd = m.input.Update(msg)
207			value := m.input.Value()
208			m.list.SetFilter(value)
209			m.list.ScrollToSelected()
210			return ActionCmd{cmd}
211		}
212	}
213	return nil
214}
215
216// Cursor returns the cursor for the dialog.
217func (m *Models) Cursor() *tea.Cursor {
218	return InputCursor(m.com.Styles, m.input.Cursor())
219}
220
221// modelTypeRadioView returns the radio view for model type selection.
222func (m *Models) modelTypeRadioView() string {
223	t := m.com.Styles
224	textStyle := t.HalfMuted
225	largeRadioStyle := t.RadioOff
226	smallRadioStyle := t.RadioOff
227	if m.modelType == ModelTypeLarge {
228		largeRadioStyle = t.RadioOn
229	} else {
230		smallRadioStyle = t.RadioOn
231	}
232
233	largeRadio := largeRadioStyle.Padding(0, 1).Render()
234	smallRadio := smallRadioStyle.Padding(0, 1).Render()
235
236	return fmt.Sprintf("%s%s  %s%s",
237		largeRadio, textStyle.Render(ModelTypeLarge.String()),
238		smallRadio, textStyle.Render(ModelTypeSmall.String()))
239}
240
241// Draw implements [Dialog].
242func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
243	t := m.com.Styles
244	width := max(0, min(60, area.Dx()))
245	height := max(0, min(30, area.Dy()))
246	// TODO: Why do we need this 2?
247	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize() - 2
248	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
249		t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
250		t.Dialog.HelpView.GetVerticalFrameSize() +
251		// TODO: Why do we need this 2?
252		t.Dialog.View.GetVerticalFrameSize() + 2
253	m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
254	m.list.SetSize(innerWidth, height-heightOffset)
255	m.help.SetWidth(innerWidth)
256
257	titleStyle := t.Dialog.Title
258	dialogStyle := t.Dialog.View
259
260	radios := m.modelTypeRadioView()
261
262	headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
263		dialogStyle.GetHorizontalFrameSize()
264
265	header := common.DialogTitle(t, "Switch Model", width-headerOffset) + radios
266
267	helpView := ansi.Truncate(m.help.View(m), innerWidth, "")
268	view := HeaderInputListHelpView(t, width, m.list.Height(), header,
269		m.input.View(), m.list.Render(), helpView)
270
271	cur := m.Cursor()
272	DrawCenterCursor(scr, area, view, cur)
273	return cur
274}
275
276// ShortHelp returns the short help view.
277func (m *Models) ShortHelp() []key.Binding {
278	return []key.Binding{
279		m.keyMap.UpDown,
280		m.keyMap.Tab,
281		m.keyMap.Select,
282		m.keyMap.Close,
283	}
284}
285
286// FullHelp returns the full help view.
287func (m *Models) FullHelp() [][]key.Binding {
288	return [][]key.Binding{
289		{
290			m.keyMap.Select,
291			m.keyMap.Next,
292			m.keyMap.Previous,
293			m.keyMap.Tab,
294		},
295		{
296			m.keyMap.Close,
297		},
298	}
299}
300
301// setProviderItems sets the provider items in the list.
302func (m *Models) setProviderItems() error {
303	t := m.com.Styles
304	cfg := m.com.Config()
305
306	var selectedItemID string
307	selectedType := m.modelType.Config()
308	currentModel := cfg.Models[selectedType]
309	recentItems := cfg.RecentModels[selectedType]
310
311	// Track providers already added to avoid duplicates
312	addedProviders := make(map[string]bool)
313
314	// Get a list of known providers to compare against
315	knownProviders, err := config.Providers(cfg)
316	if err != nil {
317		return fmt.Errorf("failed to get providers: %w", err)
318	}
319
320	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
321		return func(p catwalk.Provider) bool {
322			return p.ID == catwalk.InferenceProvider(id)
323		}
324	}
325
326	// itemsMap contains the keys of added model items.
327	itemsMap := make(map[string]*ModelItem)
328	groups := []ModelGroup{}
329	for id, p := range cfg.Providers.Seq2() {
330		if p.Disable {
331			continue
332		}
333
334		// Check if this provider is not in the known providers list
335		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
336			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
337			provider := p.ToProvider()
338
339			// Add this unknown provider to the list
340			name := cmp.Or(p.Name, id)
341
342			addedProviders[id] = true
343
344			group := NewModelGroup(t, name, true)
345			for _, model := range p.Models {
346				item := NewModelItem(t, provider, model, m.modelType, false)
347				group.AppendItems(item)
348				itemsMap[item.ID()] = item
349				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
350					selectedItemID = item.ID()
351				}
352			}
353			if len(group.Items) > 0 {
354				groups = append(groups, group)
355			}
356		}
357	}
358
359	// Now add known providers from the predefined list
360	for _, provider := range m.providers {
361		providerID := string(provider.ID)
362		if addedProviders[providerID] {
363			continue
364		}
365
366		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
367		if providerConfigured && providerConfig.Disable {
368			continue
369		}
370
371		displayProvider := provider
372		if providerConfigured {
373			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
374			modelIndex := make(map[string]int, len(displayProvider.Models))
375			for i, model := range displayProvider.Models {
376				modelIndex[model.ID] = i
377			}
378			for _, model := range providerConfig.Models {
379				if model.ID == "" {
380					continue
381				}
382				if idx, ok := modelIndex[model.ID]; ok {
383					if model.Name != "" {
384						displayProvider.Models[idx].Name = model.Name
385					}
386					continue
387				}
388				if model.Name == "" {
389					model.Name = model.ID
390				}
391				displayProvider.Models = append(displayProvider.Models, model)
392				modelIndex[model.ID] = len(displayProvider.Models) - 1
393			}
394		}
395
396		name := displayProvider.Name
397		if name == "" {
398			name = providerID
399		}
400
401		group := NewModelGroup(t, name, providerConfigured)
402		for _, model := range displayProvider.Models {
403			item := NewModelItem(t, provider, model, m.modelType, false)
404			group.AppendItems(item)
405			itemsMap[item.ID()] = item
406			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
407				selectedItemID = item.ID()
408			}
409		}
410
411		groups = append(groups, group)
412	}
413
414	if len(recentItems) > 0 {
415		recentGroup := NewModelGroup(t, "Recently used", false)
416
417		var validRecentItems []config.SelectedModel
418		for _, recent := range recentItems {
419			key := modelKey(recent.Provider, recent.Model)
420			item, ok := itemsMap[key]
421			if !ok {
422				continue
423			}
424
425			// Show provider for recent items
426			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
427			item.showProvider = true
428
429			validRecentItems = append(validRecentItems, recent)
430			recentGroup.AppendItems(item)
431			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
432				selectedItemID = item.ID()
433			}
434		}
435
436		if len(validRecentItems) != len(recentItems) {
437			// FIXME: Does this need to be here? Is it mutating the config during a read?
438			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
439				return fmt.Errorf("failed to update recent models: %w", err)
440			}
441		}
442
443		if len(recentGroup.Items) > 0 {
444			groups = append([]ModelGroup{recentGroup}, groups...)
445		}
446	}
447
448	// Set model groups in the list.
449	m.list.SetGroups(groups...)
450	m.list.SetSelectedItem(selectedItemID)
451
452	// Update placeholder based on model type
453	m.input.Placeholder = m.modelType.Placeholder()
454
455	return nil
456}
457
458func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
459	providers, err := config.Providers(cfg)
460	if err != nil {
461		return nil, fmt.Errorf("failed to get providers: %w", err)
462	}
463	filteredProviders := []catwalk.Provider{}
464	for _, p := range providers {
465		hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
466		if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
467			filteredProviders = append(filteredProviders, p)
468		}
469	}
470	return filteredProviders, nil
471}
472
// modelKey builds the "provider:model" key for a model. It returns the empty
// string when either part is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}