list.go

  1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "charm.land/bubbletea/v2"
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func modelKey(providerID, modelID string) string {
 26	if providerID == "" || modelID == "" {
 27		return ""
 28	}
 29	return providerID + ":" + modelID
 30}
 31
 32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 33	t := styles.CurrentTheme()
 34	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 35	options := []list.ListOption{
 36		list.WithKeyMap(keyMap),
 37		list.WithWrapNavigation(),
 38	}
 39	if shouldResize {
 40		options = append(options, list.WithResizeByList())
 41	}
 42	modelList := list.NewFilterableGroupedList(
 43		[]list.Group[list.CompletionItem[ModelOption]]{},
 44		list.WithFilterInputStyle(inputStyle),
 45		list.WithFilterPlaceholder(inputPlaceholder),
 46		list.WithFilterListOptions(
 47			options...,
 48		),
 49	)
 50
 51	return &ModelListComponent{
 52		list:      modelList,
 53		modelType: LargeModelType,
 54	}
 55}
 56
 57func (m *ModelListComponent) Init() tea.Cmd {
 58	var cmds []tea.Cmd
 59	if len(m.providers) == 0 {
 60		cfg := config.Get()
 61		providers, err := config.Providers(cfg)
 62		filteredProviders := []catwalk.Provider{}
 63		for _, p := range providers {
 64			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 65			isHyper := p.ID == "hyper"
 66			if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper {
 67				filteredProviders = append(filteredProviders, p)
 68			}
 69		}
 70
 71		m.providers = filteredProviders
 72		if err != nil {
 73			cmds = append(cmds, util.ReportError(err))
 74		}
 75	}
 76	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 77	return tea.Batch(cmds...)
 78}
 79
 80func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 81	u, cmd := m.list.Update(msg)
 82	m.list = u.(listModel)
 83	return m, cmd
 84}
 85
 86func (m *ModelListComponent) View() string {
 87	return m.list.View()
 88}
 89
 90func (m *ModelListComponent) Cursor() *tea.Cursor {
 91	return m.list.Cursor()
 92}
 93
 94func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 95	return m.list.SetSize(width, height)
 96}
 97
 98func (m *ModelListComponent) SelectedModel() *ModelOption {
 99	s := m.list.SelectedItem()
100	if s == nil {
101		return nil
102	}
103	sv := *s
104	model := sv.Value()
105	return &model
106}
107
108func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
109	t := styles.CurrentTheme()
110	m.modelType = modelType
111
112	var groups []list.Group[list.CompletionItem[ModelOption]]
113	// first none section
114	selectedItemID := ""
115	itemsByKey := make(map[string]list.CompletionItem[ModelOption])
116
117	cfg := config.Get()
118	var currentModel config.SelectedModel
119	selectedType := config.SelectedModelTypeLarge
120	if m.modelType == LargeModelType {
121		currentModel = cfg.Models[config.SelectedModelTypeLarge]
122		selectedType = config.SelectedModelTypeLarge
123	} else {
124		currentModel = cfg.Models[config.SelectedModelTypeSmall]
125		selectedType = config.SelectedModelTypeSmall
126	}
127	recentItems := cfg.RecentModels[selectedType]
128
129	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
130	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
131
132	// Create a map to track which providers we've already added
133	addedProviders := make(map[string]bool)
134
135	// First, add any configured providers that are not in the known providers list
136	// These should appear at the top of the list
137	knownProviders, err := config.Providers(cfg)
138	if err != nil {
139		return util.ReportError(err)
140	}
141	for providerID, providerConfig := range cfg.Providers.Seq2() {
142		if providerConfig.Disable {
143			continue
144		}
145
146		// Check if this provider is not in the known providers list
147		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
148			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
149			// Convert config provider to provider.Provider format
150			configProvider := catwalk.Provider{
151				Name:   providerConfig.Name,
152				ID:     catwalk.InferenceProvider(providerID),
153				Models: make([]catwalk.Model, len(providerConfig.Models)),
154			}
155
156			// Convert models
157			for i, model := range providerConfig.Models {
158				configProvider.Models[i] = catwalk.Model{
159					ID:                     model.ID,
160					Name:                   model.Name,
161					CostPer1MIn:            model.CostPer1MIn,
162					CostPer1MOut:           model.CostPer1MOut,
163					CostPer1MInCached:      model.CostPer1MInCached,
164					CostPer1MOutCached:     model.CostPer1MOutCached,
165					ContextWindow:          model.ContextWindow,
166					DefaultMaxTokens:       model.DefaultMaxTokens,
167					CanReason:              model.CanReason,
168					ReasoningLevels:        model.ReasoningLevels,
169					DefaultReasoningEffort: model.DefaultReasoningEffort,
170					SupportsImages:         model.SupportsImages,
171				}
172			}
173
174			// Add this unknown provider to the list
175			name := configProvider.Name
176			if name == "" {
177				name = string(configProvider.ID)
178			}
179			section := list.NewItemSection(name)
180			section.SetInfo(configured)
181			group := list.Group[list.CompletionItem[ModelOption]]{
182				Section: section,
183			}
184			for _, model := range configProvider.Models {
185				modelOption := ModelOption{
186					Provider: configProvider,
187					Model:    model,
188				}
189				key := modelKey(string(configProvider.ID), model.ID)
190				item := list.NewCompletionItem(
191					model.Name,
192					modelOption,
193					list.WithCompletionID(key),
194				)
195				itemsByKey[key] = item
196
197				group.Items = append(group.Items, item)
198				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
199					selectedItemID = item.ID()
200				}
201			}
202			groups = append(groups, group)
203
204			addedProviders[providerID] = true
205		}
206	}
207
208	// Move "Charm Hyper" to first position
209	// (but still after recent models and custom providers).
210	sortedProviders := make([]catwalk.Provider, len(m.providers))
211	copy(sortedProviders, m.providers)
212	slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
213		switch {
214		case a.ID == "hyper":
215			return -1
216		case b.ID == "hyper":
217			return 1
218		default:
219			return 0
220		}
221	})
222
223	// Then add the known providers from the predefined list
224	for _, provider := range sortedProviders {
225		// Skip if we already added this provider as an unknown provider
226		if addedProviders[string(provider.ID)] {
227			continue
228		}
229
230		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
231		if providerConfigured && providerConfig.Disable {
232			continue
233		}
234
235		displayProvider := provider
236		if providerConfigured {
237			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
238			modelIndex := make(map[string]int, len(displayProvider.Models))
239			for i, model := range displayProvider.Models {
240				modelIndex[model.ID] = i
241			}
242			for _, model := range providerConfig.Models {
243				if model.ID == "" {
244					continue
245				}
246				if idx, ok := modelIndex[model.ID]; ok {
247					if model.Name != "" {
248						displayProvider.Models[idx].Name = model.Name
249					}
250					continue
251				}
252				if model.Name == "" {
253					model.Name = model.ID
254				}
255				displayProvider.Models = append(displayProvider.Models, model)
256				modelIndex[model.ID] = len(displayProvider.Models) - 1
257			}
258		}
259
260		name := displayProvider.Name
261		if name == "" {
262			name = string(displayProvider.ID)
263		}
264
265		section := list.NewItemSection(name)
266		if providerConfigured {
267			section.SetInfo(configured)
268		}
269		group := list.Group[list.CompletionItem[ModelOption]]{
270			Section: section,
271		}
272		for _, model := range displayProvider.Models {
273			modelOption := ModelOption{
274				Provider: displayProvider,
275				Model:    model,
276			}
277			key := modelKey(string(displayProvider.ID), model.ID)
278			item := list.NewCompletionItem(
279				model.Name,
280				modelOption,
281				list.WithCompletionID(key),
282			)
283			itemsByKey[key] = item
284			group.Items = append(group.Items, item)
285			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
286				selectedItemID = item.ID()
287			}
288		}
289		groups = append(groups, group)
290	}
291
292	if len(recentItems) > 0 {
293		recentSection := list.NewItemSection("Recently used")
294		recentGroup := list.Group[list.CompletionItem[ModelOption]]{
295			Section: recentSection,
296		}
297		var validRecentItems []config.SelectedModel
298		for _, recent := range recentItems {
299			key := modelKey(recent.Provider, recent.Model)
300			option, ok := itemsByKey[key]
301			if !ok {
302				continue
303			}
304			validRecentItems = append(validRecentItems, recent)
305			recentID := fmt.Sprintf("recent::%s", key)
306			modelOption := option.Value()
307			providerName := modelOption.Provider.Name
308			if providerName == "" {
309				providerName = string(modelOption.Provider.ID)
310			}
311			item := list.NewCompletionItem(
312				modelOption.Model.Name,
313				option.Value(),
314				list.WithCompletionID(recentID),
315				list.WithCompletionShortcut(providerName),
316			)
317			recentGroup.Items = append(recentGroup.Items, item)
318			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
319				selectedItemID = recentID
320			}
321		}
322
323		if len(validRecentItems) != len(recentItems) {
324			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
325				return util.ReportError(err)
326			}
327		}
328
329		if len(recentGroup.Items) > 0 {
330			groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
331		}
332	}
333
334	var cmds []tea.Cmd
335
336	cmd := m.list.SetGroups(groups)
337
338	if cmd != nil {
339		cmds = append(cmds, cmd)
340	}
341	cmd = m.list.SetSelected(selectedItemID)
342	if cmd != nil {
343		cmds = append(cmds, cmd)
344	}
345
346	return tea.Sequence(cmds...)
347}
348
349// GetModelType returns the current model type
350func (m *ModelListComponent) GetModelType() int {
351	return m.modelType
352}
353
354func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
355	m.list.SetInputPlaceholder(placeholder)
356}