list.go

  1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "charm.land/bubbletea/v2"
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func modelKey(providerID, modelID string) string {
 26	if providerID == "" || modelID == "" {
 27		return ""
 28	}
 29	return providerID + ":" + modelID
 30}
 31
 32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 33	t := styles.CurrentTheme()
 34	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 35	options := []list.ListOption{
 36		list.WithKeyMap(keyMap),
 37		list.WithWrapNavigation(),
 38	}
 39	if shouldResize {
 40		options = append(options, list.WithResizeByList())
 41	}
 42	modelList := list.NewFilterableGroupedList(
 43		[]list.Group[list.CompletionItem[ModelOption]]{},
 44		list.WithFilterInputStyle(inputStyle),
 45		list.WithFilterPlaceholder(inputPlaceholder),
 46		list.WithFilterListOptions(
 47			options...,
 48		),
 49	)
 50
 51	return &ModelListComponent{
 52		list:      modelList,
 53		modelType: LargeModelType,
 54	}
 55}
 56
 57func (m *ModelListComponent) Init() tea.Cmd {
 58	var cmds []tea.Cmd
 59	if len(m.providers) == 0 {
 60		cfg := config.Get()
 61		providers, err := config.Providers(cfg)
 62		filteredProviders := []catwalk.Provider{}
 63		for _, p := range providers {
 64			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 65			isHyper := p.ID == "hyper"
 66			isCopilot := p.ID == catwalk.InferenceProviderCopilot
 67			if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
 68				filteredProviders = append(filteredProviders, p)
 69			}
 70		}
 71
 72		m.providers = filteredProviders
 73		if err != nil {
 74			cmds = append(cmds, util.ReportError(err))
 75		}
 76	}
 77	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 78	return tea.Batch(cmds...)
 79}
 80
 81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 82	u, cmd := m.list.Update(msg)
 83	m.list = u.(listModel)
 84	return m, cmd
 85}
 86
 87func (m *ModelListComponent) View() string {
 88	return m.list.View()
 89}
 90
 91func (m *ModelListComponent) Cursor() *tea.Cursor {
 92	return m.list.Cursor()
 93}
 94
 95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 96	return m.list.SetSize(width, height)
 97}
 98
 99func (m *ModelListComponent) SelectedModel() *ModelOption {
100	s := m.list.SelectedItem()
101	if s == nil {
102		return nil
103	}
104	sv := *s
105	model := sv.Value()
106	return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110	t := styles.CurrentTheme()
111	m.modelType = modelType
112
113	var groups []list.Group[list.CompletionItem[ModelOption]]
114	// first none section
115	selectedItemID := ""
116	itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118	cfg := config.Get()
119	var currentModel config.SelectedModel
120	selectedType := config.SelectedModelTypeLarge
121	if m.modelType == LargeModelType {
122		currentModel = cfg.Models[config.SelectedModelTypeLarge]
123		selectedType = config.SelectedModelTypeLarge
124	} else {
125		currentModel = cfg.Models[config.SelectedModelTypeSmall]
126		selectedType = config.SelectedModelTypeSmall
127	}
128	recentItems := cfg.RecentModels[selectedType]
129
130	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133	// Create a map to track which providers we've already added
134	addedProviders := make(map[string]bool)
135
136	// First, add any configured providers that are not in the known providers list
137	// These should appear at the top of the list
138	knownProviders, err := config.Providers(cfg)
139	if err != nil {
140		return util.ReportError(err)
141	}
142	for providerID, providerConfig := range cfg.Providers.Seq2() {
143		if providerConfig.Disable {
144			continue
145		}
146
147		// Check if this provider is not in the known providers list
148		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150			// Convert config provider to provider.Provider format
151			configProvider := providerConfig.ToProvider()
152
153			// Add this unknown provider to the list
154			name := configProvider.Name
155			if name == "" {
156				name = string(configProvider.ID)
157			}
158			section := list.NewItemSection(name)
159			section.SetInfo(configured)
160			group := list.Group[list.CompletionItem[ModelOption]]{
161				Section: section,
162			}
163			for _, model := range configProvider.Models {
164				modelOption := ModelOption{
165					Provider: configProvider,
166					Model:    model,
167				}
168				key := modelKey(string(configProvider.ID), model.ID)
169				item := list.NewCompletionItem(
170					model.Name,
171					modelOption,
172					list.WithCompletionID(key),
173				)
174				itemsByKey[key] = item
175
176				group.Items = append(group.Items, item)
177				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
178					selectedItemID = item.ID()
179				}
180			}
181			groups = append(groups, group)
182
183			addedProviders[providerID] = true
184		}
185	}
186
187	// Move "Charm Hyper" to first position
188	// (but still after recent models and custom providers).
189	sortedProviders := make([]catwalk.Provider, len(m.providers))
190	copy(sortedProviders, m.providers)
191	slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
192		switch {
193		case a.ID == "hyper":
194			return -1
195		case b.ID == "hyper":
196			return 1
197		default:
198			return 0
199		}
200	})
201
202	// Then add the known providers from the predefined list
203	for _, provider := range sortedProviders {
204		// Skip if we already added this provider as an unknown provider
205		if addedProviders[string(provider.ID)] {
206			continue
207		}
208
209		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
210		if providerConfigured && providerConfig.Disable {
211			continue
212		}
213
214		displayProvider := provider
215		if providerConfigured {
216			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
217			modelIndex := make(map[string]int, len(displayProvider.Models))
218			for i, model := range displayProvider.Models {
219				modelIndex[model.ID] = i
220			}
221			for _, model := range providerConfig.Models {
222				if model.ID == "" {
223					continue
224				}
225				if idx, ok := modelIndex[model.ID]; ok {
226					if model.Name != "" {
227						displayProvider.Models[idx].Name = model.Name
228					}
229					continue
230				}
231				if model.Name == "" {
232					model.Name = model.ID
233				}
234				displayProvider.Models = append(displayProvider.Models, model)
235				modelIndex[model.ID] = len(displayProvider.Models) - 1
236			}
237		}
238
239		name := displayProvider.Name
240		if name == "" {
241			name = string(displayProvider.ID)
242		}
243
244		section := list.NewItemSection(name)
245		if providerConfigured {
246			section.SetInfo(configured)
247		}
248		group := list.Group[list.CompletionItem[ModelOption]]{
249			Section: section,
250		}
251		for _, model := range displayProvider.Models {
252			modelOption := ModelOption{
253				Provider: displayProvider,
254				Model:    model,
255			}
256			key := modelKey(string(displayProvider.ID), model.ID)
257			item := list.NewCompletionItem(
258				model.Name,
259				modelOption,
260				list.WithCompletionID(key),
261			)
262			itemsByKey[key] = item
263			group.Items = append(group.Items, item)
264			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
265				selectedItemID = item.ID()
266			}
267		}
268		groups = append(groups, group)
269	}
270
271	if len(recentItems) > 0 {
272		recentSection := list.NewItemSection("Recently used")
273		recentGroup := list.Group[list.CompletionItem[ModelOption]]{
274			Section: recentSection,
275		}
276		var validRecentItems []config.SelectedModel
277		for _, recent := range recentItems {
278			key := modelKey(recent.Provider, recent.Model)
279			option, ok := itemsByKey[key]
280			if !ok {
281				continue
282			}
283			validRecentItems = append(validRecentItems, recent)
284			recentID := fmt.Sprintf("recent::%s", key)
285			modelOption := option.Value()
286			providerName := modelOption.Provider.Name
287			if providerName == "" {
288				providerName = string(modelOption.Provider.ID)
289			}
290			item := list.NewCompletionItem(
291				modelOption.Model.Name,
292				option.Value(),
293				list.WithCompletionID(recentID),
294				list.WithCompletionShortcut(providerName),
295			)
296			recentGroup.Items = append(recentGroup.Items, item)
297			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
298				selectedItemID = recentID
299			}
300		}
301
302		if len(validRecentItems) != len(recentItems) {
303			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
304				return util.ReportError(err)
305			}
306		}
307
308		if len(recentGroup.Items) > 0 {
309			groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
310		}
311	}
312
313	var cmds []tea.Cmd
314
315	cmd := m.list.SetGroups(groups)
316
317	if cmd != nil {
318		cmds = append(cmds, cmd)
319	}
320	cmd = m.list.SetSelected(selectedItemID)
321	if cmd != nil {
322		cmds = append(cmds, cmd)
323	}
324
325	return tea.Sequence(cmds...)
326}
327
328// GetModelType returns the current model type
329func (m *ModelListComponent) GetModelType() int {
330	return m.modelType
331}
332
333func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
334	m.list.SetInputPlaceholder(placeholder)
335}