models.go

  1package dialog
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	"charm.land/bubbles/v2/help"
 10	"charm.land/bubbles/v2/key"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/uiutil"
 17	uv "github.com/charmbracelet/ultraviolet"
 18)
 19
 20// ModelType represents the type of model to select.
 21type ModelType int
 22
 23const (
 24	ModelTypeLarge ModelType = iota
 25	ModelTypeSmall
 26)
 27
 28// String returns the string representation of the [ModelType].
 29func (mt ModelType) String() string {
 30	switch mt {
 31	case ModelTypeLarge:
 32		return "Large Task"
 33	case ModelTypeSmall:
 34		return "Small Task"
 35	default:
 36		return "Unknown"
 37	}
 38}
 39
 40// Config returns the corresponding config model type.
 41func (mt ModelType) Config() config.SelectedModelType {
 42	switch mt {
 43	case ModelTypeLarge:
 44		return config.SelectedModelTypeLarge
 45	case ModelTypeSmall:
 46		return config.SelectedModelTypeSmall
 47	default:
 48		return ""
 49	}
 50}
 51
 52// Placeholder returns the input placeholder for the model type.
 53func (mt ModelType) Placeholder() string {
 54	switch mt {
 55	case ModelTypeLarge:
 56		return largeModelInputPlaceholder
 57	case ModelTypeSmall:
 58		return smallModelInputPlaceholder
 59	default:
 60		return ""
 61	}
 62}
 63
 64const (
 65	onboardingModelInputPlaceholder = "Find your fave"
 66	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
 67	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
 68)
 69
 70// ModelsID is the identifier for the model selection dialog.
 71const ModelsID = "models"
 72
 73const defaultModelsDialogMaxWidth = 70
 74
 75// Models represents a model selection dialog.
 76type Models struct {
 77	com          *common.Common
 78	isOnboarding bool
 79
 80	modelType ModelType
 81	providers []catwalk.Provider
 82
 83	keyMap struct {
 84		Tab      key.Binding
 85		UpDown   key.Binding
 86		Select   key.Binding
 87		Next     key.Binding
 88		Previous key.Binding
 89		Close    key.Binding
 90	}
 91	list  *ModelsList
 92	input textinput.Model
 93	help  help.Model
 94}
 95
 96var _ Dialog = (*Models)(nil)
 97
 98// NewModels creates a new Models dialog.
 99func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
100	t := com.Styles
101	m := &Models{}
102	m.com = com
103	m.isOnboarding = isOnboarding
104
105	help := help.New()
106	help.Styles = t.DialogHelpStyles()
107
108	m.help = help
109	m.list = NewModelsList(t)
110	m.list.Focus()
111	m.list.SetSelected(0)
112
113	m.input = textinput.New()
114	m.input.SetVirtualCursor(false)
115	m.input.Placeholder = onboardingModelInputPlaceholder
116	m.input.SetStyles(com.Styles.TextInput)
117	m.input.Focus()
118
119	m.keyMap.Tab = key.NewBinding(
120		key.WithKeys("tab", "shift+tab"),
121		key.WithHelp("tab", "toggle type"),
122	)
123	m.keyMap.Select = key.NewBinding(
124		key.WithKeys("enter", "ctrl+y"),
125		key.WithHelp("enter", "confirm"),
126	)
127	m.keyMap.UpDown = key.NewBinding(
128		key.WithKeys("up", "down"),
129		key.WithHelp("↑/↓", "choose"),
130	)
131	m.keyMap.Next = key.NewBinding(
132		key.WithKeys("down", "ctrl+n"),
133		key.WithHelp("↓", "next item"),
134	)
135	m.keyMap.Previous = key.NewBinding(
136		key.WithKeys("up", "ctrl+p"),
137		key.WithHelp("↑", "previous item"),
138	)
139	m.keyMap.Close = CloseKey
140
141	providers, err := getFilteredProviders(com.Config())
142	if err != nil {
143		return nil, fmt.Errorf("failed to get providers: %w", err)
144	}
145
146	m.providers = providers
147	if err := m.setProviderItems(); err != nil {
148		return nil, fmt.Errorf("failed to set provider items: %w", err)
149	}
150
151	return m, nil
152}
153
154// ID implements Dialog.
155func (m *Models) ID() string {
156	return ModelsID
157}
158
159// HandleMsg implements Dialog.
160func (m *Models) HandleMsg(msg tea.Msg) Action {
161	switch msg := msg.(type) {
162	case tea.KeyPressMsg:
163		switch {
164		case key.Matches(msg, m.keyMap.Close):
165			return ActionClose{}
166		case key.Matches(msg, m.keyMap.Previous):
167			m.list.Focus()
168			if m.list.IsSelectedFirst() {
169				m.list.SelectLast()
170				m.list.ScrollToBottom()
171				break
172			}
173			m.list.SelectPrev()
174			m.list.ScrollToSelected()
175		case key.Matches(msg, m.keyMap.Next):
176			m.list.Focus()
177			if m.list.IsSelectedLast() {
178				m.list.SelectFirst()
179				m.list.ScrollToTop()
180				break
181			}
182			m.list.SelectNext()
183			m.list.ScrollToSelected()
184		case key.Matches(msg, m.keyMap.Select):
185			selectedItem := m.list.SelectedItem()
186			if selectedItem == nil {
187				break
188			}
189
190			modelItem, ok := selectedItem.(*ModelItem)
191			if !ok {
192				break
193			}
194
195			return ActionSelectModel{
196				Provider:  modelItem.prov,
197				Model:     modelItem.SelectedModel(),
198				ModelType: modelItem.SelectedModelType(),
199			}
200		case key.Matches(msg, m.keyMap.Tab):
201			if m.isOnboarding {
202				break
203			}
204			if m.modelType == ModelTypeLarge {
205				m.modelType = ModelTypeSmall
206			} else {
207				m.modelType = ModelTypeLarge
208			}
209			if err := m.setProviderItems(); err != nil {
210				return uiutil.ReportError(err)
211			}
212		default:
213			var cmd tea.Cmd
214			m.input, cmd = m.input.Update(msg)
215			value := m.input.Value()
216			m.list.Focus()
217			m.list.SetFilter(value)
218			m.list.SelectFirst()
219			m.list.ScrollToTop()
220			return ActionCmd{cmd}
221		}
222	}
223	return nil
224}
225
226// Cursor returns the cursor for the dialog.
227func (m *Models) Cursor() *tea.Cursor {
228	return InputCursor(m.com.Styles, m.input.Cursor())
229}
230
231// modelTypeRadioView returns the radio view for model type selection.
232func (m *Models) modelTypeRadioView() string {
233	t := m.com.Styles
234	textStyle := t.HalfMuted
235	largeRadioStyle := t.RadioOff
236	smallRadioStyle := t.RadioOff
237	if m.modelType == ModelTypeLarge {
238		largeRadioStyle = t.RadioOn
239	} else {
240		smallRadioStyle = t.RadioOn
241	}
242
243	largeRadio := largeRadioStyle.Padding(0, 1).Render()
244	smallRadio := smallRadioStyle.Padding(0, 1).Render()
245
246	return fmt.Sprintf("%s%s  %s%s",
247		largeRadio, textStyle.Render(ModelTypeLarge.String()),
248		smallRadio, textStyle.Render(ModelTypeSmall.String()))
249}
250
251// Draw implements [Dialog].
252func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
253	t := m.com.Styles
254	width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()))
255	height := max(0, min(defaultDialogHeight, area.Dy()))
256	innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
257	heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
258		t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
259		t.Dialog.HelpView.GetVerticalFrameSize() +
260		t.Dialog.View.GetVerticalFrameSize()
261
262	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
263	m.list.SetSize(innerWidth, height-heightOffset)
264	m.help.SetWidth(innerWidth)
265
266	rc := NewRenderContext(t, width)
267	rc.Title = "Switch Model"
268	rc.TitleInfo = m.modelTypeRadioView()
269
270	if m.isOnboarding {
271		titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
272		rc.AddPart(titleText)
273	}
274
275	inputView := t.Dialog.InputPrompt.Render(m.input.View())
276	rc.AddPart(inputView)
277
278	listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
279	rc.AddPart(listView)
280
281	rc.Help = m.help.View(m)
282
283	cur := m.Cursor()
284
285	if m.isOnboarding {
286		rc.Title = ""
287		rc.TitleInfo = ""
288		rc.IsOnboarding = true
289		view := rc.Render()
290		DrawOnboardingCursor(scr, area, view, cur)
291
292		// FIXME(@andreynering): Figure it out how to properly fix this
293		if cur != nil {
294			cur.Y -= 1
295			cur.X -= 1
296		}
297	} else {
298		view := rc.Render()
299		DrawCenterCursor(scr, area, view, cur)
300	}
301	return cur
302}
303
304// ShortHelp returns the short help view.
305func (m *Models) ShortHelp() []key.Binding {
306	if m.isOnboarding {
307		return []key.Binding{
308			m.keyMap.UpDown,
309			m.keyMap.Select,
310		}
311	}
312	return []key.Binding{
313		m.keyMap.UpDown,
314		m.keyMap.Tab,
315		m.keyMap.Select,
316		m.keyMap.Close,
317	}
318}
319
320// FullHelp returns the full help view.
321func (m *Models) FullHelp() [][]key.Binding {
322	return [][]key.Binding{
323		{
324			m.keyMap.Select,
325			m.keyMap.Next,
326			m.keyMap.Previous,
327			m.keyMap.Tab,
328		},
329		{
330			m.keyMap.Close,
331		},
332	}
333}
334
335// setProviderItems sets the provider items in the list.
336func (m *Models) setProviderItems() error {
337	t := m.com.Styles
338	cfg := m.com.Config()
339
340	var selectedItemID string
341	selectedType := m.modelType.Config()
342	currentModel := cfg.Models[selectedType]
343	recentItems := cfg.RecentModels[selectedType]
344
345	// Track providers already added to avoid duplicates
346	addedProviders := make(map[string]bool)
347
348	// Get a list of known providers to compare against
349	knownProviders, err := config.Providers(cfg)
350	if err != nil {
351		return fmt.Errorf("failed to get providers: %w", err)
352	}
353
354	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
355		return func(p catwalk.Provider) bool {
356			return p.ID == catwalk.InferenceProvider(id)
357		}
358	}
359
360	// itemsMap contains the keys of added model items.
361	itemsMap := make(map[string]*ModelItem)
362	groups := []ModelGroup{}
363	for id, p := range cfg.Providers.Seq2() {
364		if p.Disable {
365			continue
366		}
367
368		// Check if this provider is not in the known providers list
369		if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
370			!slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
371			provider := p.ToProvider()
372
373			// Add this unknown provider to the list
374			name := cmp.Or(p.Name, id)
375
376			addedProviders[id] = true
377
378			group := NewModelGroup(t, name, true)
379			for _, model := range p.Models {
380				item := NewModelItem(t, provider, model, m.modelType, false)
381				group.AppendItems(item)
382				itemsMap[item.ID()] = item
383				if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
384					selectedItemID = item.ID()
385				}
386			}
387			if len(group.Items) > 0 {
388				groups = append(groups, group)
389			}
390		}
391	}
392
393	// Move "Charm Hyper" to first position.
394	// (But still after recent models and custom providers).
395	slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
396		switch {
397		case a.ID == "hyper":
398			return -1
399		case b.ID == "hyper":
400			return 1
401		default:
402			return 0
403		}
404	})
405
406	// Now add known providers from the predefined list
407	for _, provider := range m.providers {
408		providerID := string(provider.ID)
409		if addedProviders[providerID] {
410			continue
411		}
412
413		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
414		if providerConfigured && providerConfig.Disable {
415			continue
416		}
417
418		displayProvider := provider
419		if providerConfigured {
420			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
421			modelIndex := make(map[string]int, len(displayProvider.Models))
422			for i, model := range displayProvider.Models {
423				modelIndex[model.ID] = i
424			}
425			for _, model := range providerConfig.Models {
426				if model.ID == "" {
427					continue
428				}
429				if idx, ok := modelIndex[model.ID]; ok {
430					if model.Name != "" {
431						displayProvider.Models[idx].Name = model.Name
432					}
433					continue
434				}
435				if model.Name == "" {
436					model.Name = model.ID
437				}
438				displayProvider.Models = append(displayProvider.Models, model)
439				modelIndex[model.ID] = len(displayProvider.Models) - 1
440			}
441		}
442
443		name := displayProvider.Name
444		if name == "" {
445			name = providerID
446		}
447
448		group := NewModelGroup(t, name, providerConfigured)
449		for _, model := range displayProvider.Models {
450			item := NewModelItem(t, provider, model, m.modelType, false)
451			group.AppendItems(item)
452			itemsMap[item.ID()] = item
453			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
454				selectedItemID = item.ID()
455			}
456		}
457
458		groups = append(groups, group)
459	}
460
461	if len(recentItems) > 0 {
462		recentGroup := NewModelGroup(t, "Recently used", false)
463
464		var validRecentItems []config.SelectedModel
465		for _, recent := range recentItems {
466			key := modelKey(recent.Provider, recent.Model)
467			item, ok := itemsMap[key]
468			if !ok {
469				continue
470			}
471
472			// Show provider for recent items
473			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
474			item.showProvider = true
475
476			validRecentItems = append(validRecentItems, recent)
477			recentGroup.AppendItems(item)
478			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
479				selectedItemID = item.ID()
480			}
481		}
482
483		if len(validRecentItems) != len(recentItems) {
484			// FIXME: Does this need to be here? Is it mutating the config during a read?
485			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
486				return fmt.Errorf("failed to update recent models: %w", err)
487			}
488		}
489
490		if len(recentGroup.Items) > 0 {
491			groups = append([]ModelGroup{recentGroup}, groups...)
492		}
493	}
494
495	// Set model groups in the list.
496	m.list.SetGroups(groups...)
497	m.list.SetSelectedItem(selectedItemID)
498	m.list.ScrollToTop()
499
500	// Update placeholder based on model type
501	if !m.isOnboarding {
502		m.input.Placeholder = m.modelType.Placeholder()
503	}
504
505	return nil
506}
507
508func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
509	providers, err := config.Providers(cfg)
510	if err != nil {
511		return nil, fmt.Errorf("failed to get providers: %w", err)
512	}
513	var filteredProviders []catwalk.Provider
514	for _, p := range providers {
515		var (
516			isAzure         = p.ID == catwalk.InferenceProviderAzure
517			isCopilot       = p.ID == catwalk.InferenceProviderCopilot
518			isHyper         = string(p.ID) == "hyper"
519			hasAPIKeyEnv    = strings.HasPrefix(p.APIKey, "$")
520			_, isConfigured = cfg.Providers.Get(string(p.ID))
521		)
522		if isAzure || isCopilot || isHyper || hasAPIKeyEnv || isConfigured {
523			filteredProviders = append(filteredProviders, p)
524		}
525	}
526	return filteredProviders, nil
527}
528
// modelKey builds the lookup key "<providerID>:<modelID>" used to match recent
// models against list items. It returns "" when either part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}