list.go

  1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "charm.land/bubbletea/v2"
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func modelKey(providerID, modelID string) string {
 26	if providerID == "" || modelID == "" {
 27		return ""
 28	}
 29	return providerID + ":" + modelID
 30}
 31
 32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 33	t := styles.CurrentTheme()
 34	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 35	options := []list.ListOption{
 36		list.WithKeyMap(keyMap),
 37		list.WithWrapNavigation(),
 38	}
 39	if shouldResize {
 40		options = append(options, list.WithResizeByList())
 41	}
 42	modelList := list.NewFilterableGroupedList(
 43		[]list.Group[list.CompletionItem[ModelOption]]{},
 44		list.WithFilterInputStyle(inputStyle),
 45		list.WithFilterPlaceholder(inputPlaceholder),
 46		list.WithFilterListOptions(
 47			options...,
 48		),
 49	)
 50
 51	return &ModelListComponent{
 52		list:      modelList,
 53		modelType: LargeModelType,
 54	}
 55}
 56
 57func (m *ModelListComponent) Init() tea.Cmd {
 58	var cmds []tea.Cmd
 59	if len(m.providers) == 0 {
 60		cfg := config.Get()
 61		providers, err := config.Providers(cfg)
 62		filteredProviders := []catwalk.Provider{}
 63		for _, p := range providers {
 64			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 65			isHyper := p.ID == "hyper"
 66			isCopilot := p.ID == catwalk.InferenceProviderCopilot
 67			if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
 68				filteredProviders = append(filteredProviders, p)
 69			}
 70		}
 71
 72		m.providers = filteredProviders
 73		if err != nil {
 74			cmds = append(cmds, util.ReportError(err))
 75		}
 76	}
 77	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 78	return tea.Batch(cmds...)
 79}
 80
 81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 82	u, cmd := m.list.Update(msg)
 83	m.list = u.(listModel)
 84	return m, cmd
 85}
 86
 87func (m *ModelListComponent) View() string {
 88	return m.list.View()
 89}
 90
 91func (m *ModelListComponent) Cursor() *tea.Cursor {
 92	return m.list.Cursor()
 93}
 94
 95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 96	return m.list.SetSize(width, height)
 97}
 98
 99func (m *ModelListComponent) SelectedModel() *ModelOption {
100	s := m.list.SelectedItem()
101	if s == nil {
102		return nil
103	}
104	sv := *s
105	model := sv.Value()
106	return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110	t := styles.CurrentTheme()
111	m.modelType = modelType
112
113	var groups []list.Group[list.CompletionItem[ModelOption]]
114	// first none section
115	selectedItemID := ""
116	itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118	cfg := config.Get()
119	var currentModel config.SelectedModel
120	selectedType := config.SelectedModelTypeLarge
121	if m.modelType == LargeModelType {
122		currentModel = cfg.Models[config.SelectedModelTypeLarge]
123		selectedType = config.SelectedModelTypeLarge
124	} else {
125		currentModel = cfg.Models[config.SelectedModelTypeSmall]
126		selectedType = config.SelectedModelTypeSmall
127	}
128	recentItems := cfg.RecentModels[selectedType]
129
130	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133	// Create a map to track which providers we've already added
134	addedProviders := make(map[string]bool)
135
136	// First, add any configured providers that are not in the known providers list
137	// These should appear at the top of the list
138	knownProviders, err := config.Providers(cfg)
139	if err != nil {
140		return util.ReportError(err)
141	}
142	for providerID, providerConfig := range cfg.Providers.Seq2() {
143		if providerConfig.Disable {
144			continue
145		}
146
147		// Check if this provider is not in the known providers list
148		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150			// Convert config provider to provider.Provider format
151			configProvider := catwalk.Provider{
152				Name:   providerConfig.Name,
153				ID:     catwalk.InferenceProvider(providerID),
154				Models: make([]catwalk.Model, len(providerConfig.Models)),
155			}
156
157			// Convert models
158			for i, model := range providerConfig.Models {
159				configProvider.Models[i] = catwalk.Model{
160					ID:                     model.ID,
161					Name:                   model.Name,
162					CostPer1MIn:            model.CostPer1MIn,
163					CostPer1MOut:           model.CostPer1MOut,
164					CostPer1MInCached:      model.CostPer1MInCached,
165					CostPer1MOutCached:     model.CostPer1MOutCached,
166					ContextWindow:          model.ContextWindow,
167					DefaultMaxTokens:       model.DefaultMaxTokens,
168					CanReason:              model.CanReason,
169					ReasoningLevels:        model.ReasoningLevels,
170					DefaultReasoningEffort: model.DefaultReasoningEffort,
171					SupportsImages:         model.SupportsImages,
172				}
173			}
174
175			// Add this unknown provider to the list
176			name := configProvider.Name
177			if name == "" {
178				name = string(configProvider.ID)
179			}
180			section := list.NewItemSection(name)
181			section.SetInfo(configured)
182			group := list.Group[list.CompletionItem[ModelOption]]{
183				Section: section,
184			}
185			for _, model := range configProvider.Models {
186				modelOption := ModelOption{
187					Provider: configProvider,
188					Model:    model,
189				}
190				key := modelKey(string(configProvider.ID), model.ID)
191				item := list.NewCompletionItem(
192					model.Name,
193					modelOption,
194					list.WithCompletionID(key),
195				)
196				itemsByKey[key] = item
197
198				group.Items = append(group.Items, item)
199				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
200					selectedItemID = item.ID()
201				}
202			}
203			groups = append(groups, group)
204
205			addedProviders[providerID] = true
206		}
207	}
208
209	// Move "Charm Hyper" to first position
210	// (but still after recent models and custom providers).
211	sortedProviders := make([]catwalk.Provider, len(m.providers))
212	copy(sortedProviders, m.providers)
213	slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
214		switch {
215		case a.ID == "hyper":
216			return -1
217		case b.ID == "hyper":
218			return 1
219		default:
220			return 0
221		}
222	})
223
224	// Then add the known providers from the predefined list
225	for _, provider := range sortedProviders {
226		// Skip if we already added this provider as an unknown provider
227		if addedProviders[string(provider.ID)] {
228			continue
229		}
230
231		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
232		if providerConfigured && providerConfig.Disable {
233			continue
234		}
235
236		displayProvider := provider
237		if providerConfigured {
238			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
239			modelIndex := make(map[string]int, len(displayProvider.Models))
240			for i, model := range displayProvider.Models {
241				modelIndex[model.ID] = i
242			}
243			for _, model := range providerConfig.Models {
244				if model.ID == "" {
245					continue
246				}
247				if idx, ok := modelIndex[model.ID]; ok {
248					if model.Name != "" {
249						displayProvider.Models[idx].Name = model.Name
250					}
251					continue
252				}
253				if model.Name == "" {
254					model.Name = model.ID
255				}
256				displayProvider.Models = append(displayProvider.Models, model)
257				modelIndex[model.ID] = len(displayProvider.Models) - 1
258			}
259		}
260
261		name := displayProvider.Name
262		if name == "" {
263			name = string(displayProvider.ID)
264		}
265
266		section := list.NewItemSection(name)
267		if providerConfigured {
268			section.SetInfo(configured)
269		}
270		group := list.Group[list.CompletionItem[ModelOption]]{
271			Section: section,
272		}
273		for _, model := range displayProvider.Models {
274			modelOption := ModelOption{
275				Provider: displayProvider,
276				Model:    model,
277			}
278			key := modelKey(string(displayProvider.ID), model.ID)
279			item := list.NewCompletionItem(
280				model.Name,
281				modelOption,
282				list.WithCompletionID(key),
283			)
284			itemsByKey[key] = item
285			group.Items = append(group.Items, item)
286			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
287				selectedItemID = item.ID()
288			}
289		}
290		groups = append(groups, group)
291	}
292
293	if len(recentItems) > 0 {
294		recentSection := list.NewItemSection("Recently used")
295		recentGroup := list.Group[list.CompletionItem[ModelOption]]{
296			Section: recentSection,
297		}
298		var validRecentItems []config.SelectedModel
299		for _, recent := range recentItems {
300			key := modelKey(recent.Provider, recent.Model)
301			option, ok := itemsByKey[key]
302			if !ok {
303				continue
304			}
305			validRecentItems = append(validRecentItems, recent)
306			recentID := fmt.Sprintf("recent::%s", key)
307			modelOption := option.Value()
308			providerName := modelOption.Provider.Name
309			if providerName == "" {
310				providerName = string(modelOption.Provider.ID)
311			}
312			item := list.NewCompletionItem(
313				modelOption.Model.Name,
314				option.Value(),
315				list.WithCompletionID(recentID),
316				list.WithCompletionShortcut(providerName),
317			)
318			recentGroup.Items = append(recentGroup.Items, item)
319			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
320				selectedItemID = recentID
321			}
322		}
323
324		if len(validRecentItems) != len(recentItems) {
325			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
326				return util.ReportError(err)
327			}
328		}
329
330		if len(recentGroup.Items) > 0 {
331			groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
332		}
333	}
334
335	var cmds []tea.Cmd
336
337	cmd := m.list.SetGroups(groups)
338
339	if cmd != nil {
340		cmds = append(cmds, cmd)
341	}
342	cmd = m.list.SetSelected(selectedItemID)
343	if cmd != nil {
344		cmds = append(cmds, cmd)
345	}
346
347	return tea.Sequence(cmds...)
348}
349
350// GetModelType returns the current model type
351func (m *ModelListComponent) GetModelType() int {
352	return m.modelType
353}
354
355func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
356	m.list.SetInputPlaceholder(placeholder)
357}