models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/textinput"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/ui/common"
 15	"github.com/charmbracelet/crush/internal/ui/util"
 16	uv "github.com/charmbracelet/ultraviolet"
 17	xslice "github.com/charmbracelet/x/exp/slice"
 18)
 19
 20// ModelType represents the type of model to select.
 21type ModelType int
 22
 23const (
 24	ModelTypeLarge ModelType = iota
 25	ModelTypeSmall
 26)
 27
 28// String returns the string representation of the [ModelType].
 29func (mt ModelType) String() string {
 30	switch mt {
 31	case ModelTypeLarge:
 32		return "Large Task"
 33	case ModelTypeSmall:
 34		return "Small Task"
 35	default:
 36		return "Unknown"
 37	}
 38}
 39
 40// Config returns the corresponding config model type.
 41func (mt ModelType) Config() config.SelectedModelType {
 42	switch mt {
 43	case ModelTypeLarge:
 44		return config.SelectedModelTypeLarge
 45	case ModelTypeSmall:
 46		return config.SelectedModelTypeSmall
 47	default:
 48		return ""
 49	}
 50}
 51
 52// Placeholder returns the input placeholder for the model type.
 53func (mt ModelType) Placeholder() string {
 54	switch mt {
 55	case ModelTypeLarge:
 56		return largeModelInputPlaceholder
 57	case ModelTypeSmall:
 58		return smallModelInputPlaceholder
 59	default:
 60		return ""
 61	}
 62}
 63
 64const (
 65	onboardingModelInputPlaceholder = "Find your fave"
 66	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
 67	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
 68)
 69
 70// ModelsID is the identifier for the model selection dialog.
 71const ModelsID = "models"
 72
 73const defaultModelsDialogMaxWidth = 73
 74
// Models represents a model selection dialog.
//
// It renders a filter input above a grouped list of provider models and
// returns an ActionSelectModel when the user confirms a choice.
type Models struct {
	// com provides shared UI state: styles (com.Styles) and the loaded
	// configuration (com.Config()).
	com          *common.Common
	// isOnboarding is true when the dialog is shown during onboarding; in
	// that mode the model-type toggle is ignored and the title is hidden.
	isOnboarding bool

	// modelType is the model slot (large or small task) currently being
	// chosen; Tab toggles it outside onboarding.
	modelType ModelType
	// providers is built from the configured providers in NewModels and is
	// re-sorted in place ("hyper" first) by setProviderItems.
	providers []catwalk.Provider

	// keyMap holds the bindings handled by HandleMsg and listed in the help
	// view. UpDown is only displayed in help; navigation itself is matched by
	// Next and Previous.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list is the grouped, filterable list of model items.
	list  *ModelsList
	// input is the filter text field; its value filters list.
	input textinput.Model
	// help renders the key binding hints.
	help  help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
 98
 99// NewModels creates a new Models dialog.
100func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
101	t := com.Styles
102	m := &Models{}
103	m.com = com
104	m.isOnboarding = isOnboarding
105
106	help := help.New()
107	help.Styles = t.DialogHelpStyles()
108
109	m.help = help
110	m.list = NewModelsList(t)
111	m.list.Focus()
112	m.list.SetSelected(0)
113
114	m.input = textinput.New()
115	m.input.SetVirtualCursor(false)
116	m.input.Placeholder = onboardingModelInputPlaceholder
117	m.input.SetStyles(com.Styles.TextInput)
118	m.input.Focus()
119
120	m.keyMap.Tab = key.NewBinding(
121		key.WithKeys("tab", "shift+tab"),
122		key.WithHelp("tab", "toggle type"),
123	)
124	m.keyMap.Select = key.NewBinding(
125		key.WithKeys("enter", "ctrl+y"),
126		key.WithHelp("enter", "confirm"),
127	)
128	m.keyMap.Edit = key.NewBinding(
129		key.WithKeys("ctrl+e"),
130		key.WithHelp("ctrl+e", "edit"),
131	)
132	m.keyMap.UpDown = key.NewBinding(
133		key.WithKeys("up", "down"),
134		key.WithHelp("↑/↓", "choose"),
135	)
136	m.keyMap.Next = key.NewBinding(
137		key.WithKeys("down", "ctrl+n"),
138		key.WithHelp("↓", "next item"),
139	)
140	m.keyMap.Previous = key.NewBinding(
141		key.WithKeys("up", "ctrl+p"),
142		key.WithHelp("↑", "previous item"),
143	)
144	m.keyMap.Close = CloseKey
145
146	m.providers = slices.Collect(
147		xslice.Map(
148			com.Config().Providers.Seq(),
149			func(pc config.ProviderConfig) catwalk.Provider {
150				return pc.ToProvider()
151			},
152		),
153	)
154	if err := m.setProviderItems(); err != nil {
155		return nil, fmt.Errorf("failed to set provider items: %w", err)
156	}
157
158	return m, nil
159}
160
161// ID implements Dialog.
162func (m *Models) ID() string {
163	return ModelsID
164}
165
166// HandleMsg implements Dialog.
167func (m *Models) HandleMsg(msg tea.Msg) Action {
168	switch msg := msg.(type) {
169	case tea.KeyPressMsg:
170		switch {
171		case key.Matches(msg, m.keyMap.Close):
172			return ActionClose{}
173		case key.Matches(msg, m.keyMap.Previous):
174			m.list.Focus()
175			if m.list.IsSelectedFirst() {
176				m.list.SelectLast()
177				m.list.ScrollToBottom()
178				break
179			}
180			m.list.SelectPrev()
181			m.list.ScrollToSelected()
182		case key.Matches(msg, m.keyMap.Next):
183			m.list.Focus()
184			if m.list.IsSelectedLast() {
185				m.list.SelectFirst()
186				m.list.ScrollToTop()
187				break
188			}
189			m.list.SelectNext()
190			m.list.ScrollToSelected()
191		case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
192			selectedItem := m.list.SelectedItem()
193			if selectedItem == nil {
194				break
195			}
196
197			modelItem, ok := selectedItem.(*ModelItem)
198			if !ok {
199				break
200			}
201
202			isEdit := key.Matches(msg, m.keyMap.Edit)
203
204			return ActionSelectModel{
205				Provider:       modelItem.prov,
206				Model:          modelItem.SelectedModel(),
207				ModelType:      modelItem.SelectedModelType(),
208				ReAuthenticate: isEdit,
209			}
210		case key.Matches(msg, m.keyMap.Tab):
211			if m.isOnboarding {
212				break
213			}
214			if m.modelType == ModelTypeLarge {
215				m.modelType = ModelTypeSmall
216			} else {
217				m.modelType = ModelTypeLarge
218			}
219			if err := m.setProviderItems(); err != nil {
220				return util.ReportError(err)
221			}
222		default:
223			var cmd tea.Cmd
224			m.input, cmd = m.input.Update(msg)
225			value := m.input.Value()
226			m.list.Focus()
227			m.list.SetFilter(value)
228			m.list.SelectFirst()
229			m.list.ScrollToTop()
230			return ActionCmd{cmd}
231		}
232	}
233	return nil
234}
235
236// Cursor returns the cursor for the dialog.
237func (m *Models) Cursor() *tea.Cursor {
238	return InputCursor(m.com.Styles, m.input.Cursor())
239}
240
241// modelTypeRadioView returns the radio view for model type selection.
242func (m *Models) modelTypeRadioView() string {
243	t := m.com.Styles
244	textStyle := t.HalfMuted
245	largeRadioStyle := t.RadioOff
246	smallRadioStyle := t.RadioOff
247	if m.modelType == ModelTypeLarge {
248		largeRadioStyle = t.RadioOn
249	} else {
250		smallRadioStyle = t.RadioOn
251	}
252
253	largeRadio := largeRadioStyle.Padding(0, 1).Render()
254	smallRadio := smallRadioStyle.Padding(0, 1).Render()
255
256	return fmt.Sprintf("%s%s  %s%s",
257		largeRadio, textStyle.Render(ModelTypeLarge.String()),
258		smallRadio, textStyle.Render(ModelTypeSmall.String()))
259}
260
261// Draw implements [Dialog].
262func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
263	t := m.com.Styles
264	width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
265	height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
266	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
267	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
268		t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
269		t.Dialog.HelpView.GetVerticalFrameSize() +
270		t.Dialog.View.GetVerticalFrameSize()
271
272	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
273	m.list.SetSize(innerWidth, height-heightOffset)
274	m.help.SetWidth(innerWidth)
275
276	rc := NewRenderContext(t, width)
277	rc.Title = "Switch Model"
278	rc.TitleInfo = m.modelTypeRadioView()
279
280	if m.isOnboarding {
281		titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
282		rc.AddPart(titleText)
283	}
284
285	inputView := t.Dialog.InputPrompt.Render(m.input.View())
286	rc.AddPart(inputView)
287
288	listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
289	rc.AddPart(listView)
290
291	rc.Help = m.help.View(m)
292
293	cur := m.Cursor()
294
295	if m.isOnboarding {
296		rc.Title = ""
297		rc.TitleInfo = ""
298		rc.IsOnboarding = true
299		view := rc.Render()
300		DrawOnboardingCursor(scr, area, view, cur)
301
302		// FIXME(@andreynering): Figure it out how to properly fix this
303		if cur != nil {
304			cur.Y -= 1
305			cur.X -= 1
306		}
307	} else {
308		view := rc.Render()
309		DrawCenterCursor(scr, area, view, cur)
310	}
311	return cur
312}
313
314// ShortHelp returns the short help view.
315func (m *Models) ShortHelp() []key.Binding {
316	if m.isOnboarding {
317		return []key.Binding{
318			m.keyMap.UpDown,
319			m.keyMap.Select,
320		}
321	}
322	h := []key.Binding{
323		m.keyMap.UpDown,
324		m.keyMap.Tab,
325		m.keyMap.Select,
326	}
327	if m.isSelectedConfigured() {
328		h = append(h, m.keyMap.Edit)
329	}
330	h = append(h, m.keyMap.Close)
331	return h
332}
333
334// FullHelp returns the full help view.
335func (m *Models) FullHelp() [][]key.Binding {
336	return [][]key.Binding{m.ShortHelp()}
337}
338
339func (m *Models) isSelectedConfigured() bool {
340	selectedItem := m.list.SelectedItem()
341	if selectedItem == nil {
342		return false
343	}
344	modelItem, ok := selectedItem.(*ModelItem)
345	if !ok {
346		return false
347	}
348	providerID := string(modelItem.prov.ID)
349	_, isConfigured := m.com.Config().Providers.Get(providerID)
350	return isConfigured
351}
352
353// setProviderItems sets the provider items in the list.
354func (m *Models) setProviderItems() error {
355	t := m.com.Styles
356	cfg := m.com.Config()
357
358	var selectedItemID string
359	selectedType := m.modelType.Config()
360	currentModel := cfg.Models[selectedType]
361	recentItems := cfg.RecentModels[selectedType]
362
363	// Track providers already added to avoid duplicates
364	addedProviders := make(map[string]bool)
365
366	// Get a list of known providers to compare against
367	knownProviders, err := config.Providers(cfg)
368	if err != nil {
369		return fmt.Errorf("failed to get providers: %w", err)
370	}
371
372	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
373		return func(p catwalk.Provider) bool {
374			return p.ID == catwalk.InferenceProvider(id)
375		}
376	}
377
378	// itemsMap contains the keys of added model items.
379	itemsMap := make(map[string]*ModelItem)
380	groups := []ModelGroup{}
381	for id, p := range cfg.Providers.Seq2() {
382		if p.Disable {
383			continue
384		}
385
386		// Check if this provider is not in the known providers list
387		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
388			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
389			provider := p.ToProvider()
390
391			// Add this unknown provider to the list
392			name := cmp.Or(p.Name, id)
393
394			addedProviders[id] = true
395
396			group := NewModelGroup(t, name, true)
397			for _, model := range p.Models {
398				item := NewModelItem(t, provider, model, m.modelType, false)
399				group.AppendItems(item)
400				itemsMap[item.ID()] = item
401				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
402					selectedItemID = item.ID()
403				}
404			}
405			if len(group.Items) > 0 {
406				groups = append(groups, group)
407			}
408		}
409	}
410
411	// Move "Charm Hyper" to first position.
412	// (But still after recent models and custom providers).
413	slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
414		switch {
415		case a.ID == "hyper":
416			return -1
417		case b.ID == "hyper":
418			return 1
419		default:
420			return 0
421		}
422	})
423
424	// Now add known providers from the predefined list
425	for _, provider := range m.providers {
426		providerID := string(provider.ID)
427		if addedProviders[providerID] {
428			continue
429		}
430
431		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
432		if providerConfigured && providerConfig.Disable {
433			continue
434		}
435
436		displayProvider := provider
437		if providerConfigured {
438			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
439			modelIndex := make(map[string]int, len(displayProvider.Models))
440			for i, model := range displayProvider.Models {
441				modelIndex[model.ID] = i
442			}
443			for _, model := range providerConfig.Models {
444				if model.ID == "" {
445					continue
446				}
447				if idx, ok := modelIndex[model.ID]; ok {
448					if model.Name != "" {
449						displayProvider.Models[idx].Name = model.Name
450					}
451					continue
452				}
453				if model.Name == "" {
454					model.Name = model.ID
455				}
456				displayProvider.Models = append(displayProvider.Models, model)
457				modelIndex[model.ID] = len(displayProvider.Models) - 1
458			}
459		}
460
461		name := displayProvider.Name
462		if name == "" {
463			name = providerID
464		}
465
466		group := NewModelGroup(t, name, providerConfigured)
467		for _, model := range displayProvider.Models {
468			item := NewModelItem(t, provider, model, m.modelType, false)
469			group.AppendItems(item)
470			itemsMap[item.ID()] = item
471			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
472				selectedItemID = item.ID()
473			}
474		}
475
476		groups = append(groups, group)
477	}
478
479	if len(recentItems) > 0 {
480		recentGroup := NewModelGroup(t, "Recently used", false)
481
482		var validRecentItems []config.SelectedModel
483		for _, recent := range recentItems {
484			key := modelKey(recent.Provider, recent.Model)
485			item, ok := itemsMap[key]
486			if !ok {
487				continue
488			}
489
490			// Show provider for recent items
491			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
492			item.showProvider = true
493
494			validRecentItems = append(validRecentItems, recent)
495			recentGroup.AppendItems(item)
496			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
497				selectedItemID = item.ID()
498			}
499		}
500
501		if len(validRecentItems) != len(recentItems) {
502			// FIXME: Does this need to be here? Is it mutating the config during a read?
503			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
504				return fmt.Errorf("failed to update recent models: %w", err)
505			}
506		}
507
508		if len(recentGroup.Items) > 0 {
509			groups = append([]ModelGroup{recentGroup}, groups...)
510		}
511	}
512
513	// Set model groups in the list.
514	m.list.SetGroups(groups...)
515	m.list.SetSelectedItem(selectedItemID)
516	m.list.ScrollToTop()
517
518	// Update placeholder based on model type
519	if !m.isOnboarding {
520		m.input.Placeholder = m.modelType.Placeholder()
521	}
522
523	return nil
524}
525
// modelKey builds the "provider:model" key used to look up model items.
// It returns "" if either part is empty, so incomplete references never match.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}