models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/textinput"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/ui/common"
 15	"github.com/charmbracelet/crush/internal/ui/util"
 16	uv "github.com/charmbracelet/ultraviolet"
 17)
 18
 19// ModelType represents the type of model to select.
 20type ModelType int
 21
 22const (
 23	ModelTypeLarge ModelType = iota
 24	ModelTypeSmall
 25)
 26
 27// String returns the string representation of the [ModelType].
 28func (mt ModelType) String() string {
 29	switch mt {
 30	case ModelTypeLarge:
 31		return "Large Task"
 32	case ModelTypeSmall:
 33		return "Small Task"
 34	default:
 35		return "Unknown"
 36	}
 37}
 38
 39// Config returns the corresponding config model type.
 40func (mt ModelType) Config() config.SelectedModelType {
 41	switch mt {
 42	case ModelTypeLarge:
 43		return config.SelectedModelTypeLarge
 44	case ModelTypeSmall:
 45		return config.SelectedModelTypeSmall
 46	default:
 47		return ""
 48	}
 49}
 50
 51// Placeholder returns the input placeholder for the model type.
 52func (mt ModelType) Placeholder() string {
 53	switch mt {
 54	case ModelTypeLarge:
 55		return largeModelInputPlaceholder
 56	case ModelTypeSmall:
 57		return smallModelInputPlaceholder
 58	default:
 59		return ""
 60	}
 61}
 62
 63const (
 64	onboardingModelInputPlaceholder = "Find your fave"
 65	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
 66	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
 67)
 68
 69// ModelsID is the identifier for the model selection dialog.
 70const ModelsID = "models"
 71
 72const defaultModelsDialogMaxWidth = 73
 73
 74// Models represents a model selection dialog.
 75type Models struct {
 76	com          *common.Common
 77	isOnboarding bool
 78
 79	modelType ModelType
 80	providers []catwalk.Provider
 81
 82	keyMap struct {
 83		Tab      key.Binding
 84		UpDown   key.Binding
 85		Select   key.Binding
 86		Edit     key.Binding
 87		Next     key.Binding
 88		Previous key.Binding
 89		Close    key.Binding
 90	}
 91	list  *ModelsList
 92	input textinput.Model
 93	help  help.Model
 94}
 95
 96var _ Dialog = (*Models)(nil)
 97
 98// NewModels creates a new Models dialog.
 99func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
100	t := com.Styles
101	m := &Models{}
102	m.com = com
103	m.isOnboarding = isOnboarding
104
105	help := help.New()
106	help.Styles = t.DialogHelpStyles()
107
108	m.help = help
109	m.list = NewModelsList(t)
110	m.list.Focus()
111	m.list.SetSelected(0)
112
113	m.input = textinput.New()
114	m.input.SetVirtualCursor(false)
115	m.input.Placeholder = onboardingModelInputPlaceholder
116	m.input.SetStyles(com.Styles.TextInput)
117	m.input.Focus()
118
119	m.keyMap.Tab = key.NewBinding(
120		key.WithKeys("tab", "shift+tab"),
121		key.WithHelp("tab", "toggle type"),
122	)
123	m.keyMap.Select = key.NewBinding(
124		key.WithKeys("enter", "ctrl+y"),
125		key.WithHelp("enter", "confirm"),
126	)
127	m.keyMap.Edit = key.NewBinding(
128		key.WithKeys("ctrl+e"),
129		key.WithHelp("ctrl+e", "edit"),
130	)
131	m.keyMap.UpDown = key.NewBinding(
132		key.WithKeys("up", "down"),
133		key.WithHelp("↑/↓", "choose"),
134	)
135	m.keyMap.Next = key.NewBinding(
136		key.WithKeys("down", "ctrl+n"),
137		key.WithHelp("↓", "next item"),
138	)
139	m.keyMap.Previous = key.NewBinding(
140		key.WithKeys("up", "ctrl+p"),
141		key.WithHelp("↑", "previous item"),
142	)
143	m.keyMap.Close = CloseKey
144
145	var err error
146	m.providers, err = config.Providers(m.com.Config())
147	if err != nil {
148		return nil, fmt.Errorf("failed to get providers: %w", err)
149	}
150
151	if err := m.setProviderItems(); err != nil {
152		return nil, fmt.Errorf("failed to set provider items: %w", err)
153	}
154
155	return m, nil
156}
157
158// ID implements Dialog.
159func (m *Models) ID() string {
160	return ModelsID
161}
162
163// HandleMsg implements Dialog.
164func (m *Models) HandleMsg(msg tea.Msg) Action {
165	switch msg := msg.(type) {
166	case tea.KeyPressMsg:
167		switch {
168		case key.Matches(msg, m.keyMap.Close):
169			return ActionClose{}
170		case key.Matches(msg, m.keyMap.Previous):
171			m.list.Focus()
172			if m.list.IsSelectedFirst() {
173				m.list.SelectLast()
174			} else {
175				m.list.SelectPrev()
176			}
177			m.list.ScrollToSelected()
178		case key.Matches(msg, m.keyMap.Next):
179			m.list.Focus()
180			if m.list.IsSelectedLast() {
181				m.list.SelectFirst()
182			} else {
183				m.list.SelectNext()
184			}
185			m.list.ScrollToSelected()
186		case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
187			selectedItem := m.list.SelectedItem()
188			if selectedItem == nil {
189				break
190			}
191
192			modelItem, ok := selectedItem.(*ModelItem)
193			if !ok {
194				break
195			}
196
197			isEdit := key.Matches(msg, m.keyMap.Edit)
198
199			return ActionSelectModel{
200				Provider:       modelItem.prov,
201				Model:          modelItem.SelectedModel(),
202				ModelType:      modelItem.SelectedModelType(),
203				ReAuthenticate: isEdit,
204			}
205		case key.Matches(msg, m.keyMap.Tab):
206			if m.isOnboarding {
207				break
208			}
209			if m.modelType == ModelTypeLarge {
210				m.modelType = ModelTypeSmall
211			} else {
212				m.modelType = ModelTypeLarge
213			}
214			if err := m.setProviderItems(); err != nil {
215				return util.ReportError(err)
216			}
217		default:
218			var cmd tea.Cmd
219			m.input, cmd = m.input.Update(msg)
220			value := m.input.Value()
221			m.list.Focus()
222			m.list.SetFilter(value)
223			m.list.SelectFirst()
224			m.list.ScrollToTop()
225			return ActionCmd{cmd}
226		}
227	}
228	return nil
229}
230
231// Cursor returns the cursor for the dialog.
232func (m *Models) Cursor() *tea.Cursor {
233	return InputCursor(m.com.Styles, m.input.Cursor())
234}
235
236// modelTypeRadioView returns the radio view for model type selection.
237func (m *Models) modelTypeRadioView() string {
238	t := m.com.Styles
239	textStyle := t.HalfMuted
240	largeRadioStyle := t.RadioOff
241	smallRadioStyle := t.RadioOff
242	if m.modelType == ModelTypeLarge {
243		largeRadioStyle = t.RadioOn
244	} else {
245		smallRadioStyle = t.RadioOn
246	}
247
248	largeRadio := largeRadioStyle.Padding(0, 1).Render()
249	smallRadio := smallRadioStyle.Padding(0, 1).Render()
250
251	return fmt.Sprintf("%s%s  %s%s",
252		largeRadio, textStyle.Render(ModelTypeLarge.String()),
253		smallRadio, textStyle.Render(ModelTypeSmall.String()))
254}
255
256// Draw implements [Dialog].
257func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
258	t := m.com.Styles
259	width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
260	height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
261	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
262	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
263		t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
264		t.Dialog.HelpView.GetVerticalFrameSize() +
265		t.Dialog.View.GetVerticalFrameSize()
266
267	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
268	m.list.SetSize(innerWidth, height-heightOffset)
269	m.help.SetWidth(innerWidth)
270
271	rc := NewRenderContext(t, width)
272	rc.Title = "Switch Model"
273	rc.TitleInfo = m.modelTypeRadioView()
274
275	if m.isOnboarding {
276		titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
277		rc.AddPart(titleText)
278	}
279
280	inputView := t.Dialog.InputPrompt.Render(m.input.View())
281	rc.AddPart(inputView)
282
283	listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
284	rc.AddPart(listView)
285
286	rc.Help = m.help.View(m)
287
288	cur := m.Cursor()
289
290	if m.isOnboarding {
291		rc.Title = ""
292		rc.TitleInfo = ""
293		rc.IsOnboarding = true
294		view := rc.Render()
295		DrawOnboardingCursor(scr, area, view, cur)
296
297		// FIXME(@andreynering): Figure it out how to properly fix this
298		if cur != nil {
299			cur.Y -= 1
300			cur.X -= 1
301		}
302	} else {
303		view := rc.Render()
304		DrawCenterCursor(scr, area, view, cur)
305	}
306	return cur
307}
308
309// ShortHelp returns the short help view.
310func (m *Models) ShortHelp() []key.Binding {
311	if m.isOnboarding {
312		return []key.Binding{
313			m.keyMap.UpDown,
314			m.keyMap.Select,
315		}
316	}
317	h := []key.Binding{
318		m.keyMap.UpDown,
319		m.keyMap.Tab,
320		m.keyMap.Select,
321	}
322	if m.isSelectedConfigured() {
323		h = append(h, m.keyMap.Edit)
324	}
325	h = append(h, m.keyMap.Close)
326	return h
327}
328
329// FullHelp returns the full help view.
330func (m *Models) FullHelp() [][]key.Binding {
331	return [][]key.Binding{m.ShortHelp()}
332}
333
334func (m *Models) isSelectedConfigured() bool {
335	selectedItem := m.list.SelectedItem()
336	if selectedItem == nil {
337		return false
338	}
339	modelItem, ok := selectedItem.(*ModelItem)
340	if !ok {
341		return false
342	}
343	providerID := string(modelItem.prov.ID)
344	_, isConfigured := m.com.Config().Providers.Get(providerID)
345	return isConfigured
346}
347
348// setProviderItems sets the provider items in the list.
349func (m *Models) setProviderItems() error {
350	t := m.com.Styles
351	cfg := m.com.Config()
352
353	var selectedItemID string
354	selectedType := m.modelType.Config()
355	currentModel := cfg.Models[selectedType]
356	recentItems := cfg.RecentModels[selectedType]
357
358	// Track providers already added to avoid duplicates
359	addedProviders := make(map[string]bool)
360
361	// Get a list of known providers to compare against
362	knownProviders, err := config.Providers(cfg)
363	if err != nil {
364		return fmt.Errorf("failed to get providers: %w", err)
365	}
366
367	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
368		return func(p catwalk.Provider) bool {
369			return p.ID == catwalk.InferenceProvider(id)
370		}
371	}
372
373	// itemsMap contains the keys of added model items.
374	itemsMap := make(map[string]*ModelItem)
375	groups := []ModelGroup{}
376	for id, p := range cfg.Providers.Seq2() {
377		if p.Disable {
378			continue
379		}
380
381		// Check if this provider is not in the known providers list
382		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
383			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
384			provider := p.ToProvider()
385
386			// Add this unknown provider to the list
387			name := cmp.Or(p.Name, id)
388
389			addedProviders[id] = true
390
391			group := NewModelGroup(t, name, true)
392			for _, model := range p.Models {
393				item := NewModelItem(t, provider, model, m.modelType, false)
394				group.AppendItems(item)
395				itemsMap[item.ID()] = item
396				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
397					selectedItemID = item.ID()
398				}
399			}
400			if len(group.Items) > 0 {
401				groups = append(groups, group)
402			}
403		}
404	}
405
406	// Move "Charm Hyper" to first position.
407	// (But still after recent models and custom providers).
408	slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
409		switch {
410		case a.ID == "hyper":
411			return -1
412		case b.ID == "hyper":
413			return 1
414		default:
415			return 0
416		}
417	})
418
419	// Now add known providers from the predefined list
420	for _, provider := range m.providers {
421		providerID := string(provider.ID)
422		if addedProviders[providerID] {
423			continue
424		}
425
426		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
427		if providerConfigured && providerConfig.Disable {
428			continue
429		}
430
431		displayProvider := provider
432		if providerConfigured {
433			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
434			modelIndex := make(map[string]int, len(displayProvider.Models))
435			for i, model := range displayProvider.Models {
436				modelIndex[model.ID] = i
437			}
438			for _, model := range providerConfig.Models {
439				if model.ID == "" {
440					continue
441				}
442				if idx, ok := modelIndex[model.ID]; ok {
443					if model.Name != "" {
444						displayProvider.Models[idx].Name = model.Name
445					}
446					continue
447				}
448				model.Name = cmp.Or(model.Name, model.ID)
449				displayProvider.Models = append(displayProvider.Models, model)
450				modelIndex[model.ID] = len(displayProvider.Models) - 1
451			}
452		}
453
454		name := cmp.Or(displayProvider.Name, providerID)
455
456		group := NewModelGroup(t, name, providerConfigured)
457		for _, model := range displayProvider.Models {
458			item := NewModelItem(t, provider, model, m.modelType, false)
459			group.AppendItems(item)
460			itemsMap[item.ID()] = item
461			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
462				selectedItemID = item.ID()
463			}
464		}
465
466		groups = append(groups, group)
467	}
468
469	if len(recentItems) > 0 {
470		recentGroup := NewModelGroup(t, "Recently used", false)
471
472		var validRecentItems []config.SelectedModel
473		for _, recent := range recentItems {
474			key := modelKey(recent.Provider, recent.Model)
475			item, ok := itemsMap[key]
476			if !ok {
477				continue
478			}
479
480			// Show provider for recent items
481			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
482			item.showProvider = true
483
484			validRecentItems = append(validRecentItems, recent)
485			recentGroup.AppendItems(item)
486			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
487				selectedItemID = item.ID()
488			}
489		}
490
491		if len(validRecentItems) != len(recentItems) {
492			// FIXME: Does this need to be here? Is it mutating the config during a read?
493			if err := m.com.Store().SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
494				return fmt.Errorf("failed to update recent models: %w", err)
495			}
496		}
497
498		if len(recentGroup.Items) > 0 {
499			groups = append([]ModelGroup{recentGroup}, groups...)
500		}
501	}
502
503	// Set model groups in the list.
504	m.list.SetGroups(groups...)
505	m.list.SetSelectedItem(selectedItemID)
506	m.list.ScrollToTop()
507
508	// Update placeholder based on model type
509	if !m.isOnboarding {
510		m.input.Placeholder = m.modelType.Placeholder()
511	}
512
513	return nil
514}
515
// modelKey builds the "provider:model" lookup key for a model item. It
// returns "" when either part is empty, so incomplete references never
// match a real item.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}