list.go

  1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "charm.land/bubbletea/v2"
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/tui/exp/list"
 13	"github.com/charmbracelet/crush/internal/tui/styles"
 14	"github.com/charmbracelet/crush/internal/tui/util"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func modelKey(providerID, modelID string) string {
 26	if providerID == "" || modelID == "" {
 27		return ""
 28	}
 29	return providerID + ":" + modelID
 30}
 31
 32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 33	t := styles.CurrentTheme()
 34	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 35	options := []list.ListOption{
 36		list.WithKeyMap(keyMap),
 37		list.WithWrapNavigation(),
 38	}
 39	if shouldResize {
 40		options = append(options, list.WithResizeByList())
 41	}
 42	modelList := list.NewFilterableGroupedList(
 43		[]list.Group[list.CompletionItem[ModelOption]]{},
 44		list.WithFilterInputStyle(inputStyle),
 45		list.WithFilterPlaceholder(inputPlaceholder),
 46		list.WithFilterListOptions(
 47			options...,
 48		),
 49	)
 50
 51	return &ModelListComponent{
 52		list:      modelList,
 53		modelType: LargeModelType,
 54	}
 55}
 56
 57func (m *ModelListComponent) Init() tea.Cmd {
 58	var cmds []tea.Cmd
 59	if len(m.providers) == 0 {
 60		cfg := config.Get()
 61		providers, err := config.Providers(cfg)
 62		filteredProviders := []catwalk.Provider{}
 63		for _, p := range providers {
 64			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 65			isHyper := p.ID == "hyper"
 66			isCopilot := p.ID == catwalk.InferenceProviderCopilot
 67			if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
 68				filteredProviders = append(filteredProviders, p)
 69			}
 70		}
 71
 72		m.providers = filteredProviders
 73		if err != nil {
 74			cmds = append(cmds, util.ReportError(err))
 75		}
 76	}
 77	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 78	return tea.Batch(cmds...)
 79}
 80
 81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 82	u, cmd := m.list.Update(msg)
 83	m.list = u.(listModel)
 84	return m, cmd
 85}
 86
 87func (m *ModelListComponent) View() string {
 88	return m.list.View()
 89}
 90
 91func (m *ModelListComponent) Cursor() *tea.Cursor {
 92	return m.list.Cursor()
 93}
 94
 95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 96	return m.list.SetSize(width, height)
 97}
 98
 99func (m *ModelListComponent) SelectedModel() *ModelOption {
100	s := m.list.SelectedItem()
101	if s == nil {
102		return nil
103	}
104	sv := *s
105	model := sv.Value()
106	return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110	t := styles.CurrentTheme()
111	m.modelType = modelType
112
113	var groups []list.Group[list.CompletionItem[ModelOption]]
114	// first none section
115	selectedItemID := ""
116	itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118	cfg := config.Get()
119	var currentModel config.SelectedModel
120	selectedType := config.SelectedModelTypeLarge
121	if m.modelType == LargeModelType {
122		currentModel = cfg.Models[config.SelectedModelTypeLarge]
123		selectedType = config.SelectedModelTypeLarge
124	} else {
125		currentModel = cfg.Models[config.SelectedModelTypeSmall]
126		selectedType = config.SelectedModelTypeSmall
127	}
128	recentItems := cfg.RecentModels[selectedType]
129
130	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133	// Create a map to track which providers we've already added
134	addedProviders := make(map[string]bool)
135
136	// First, add any configured providers that are not in the known providers list
137	// These should appear at the top of the list
138	knownProviders, err := config.Providers(cfg)
139	if err != nil {
140		return util.ReportError(err)
141	}
142	for providerID, providerConfig := range cfg.Providers.Seq2() {
143		if providerConfig.Disable {
144			continue
145		}
146
147		// Check if this provider is not in the known providers list
148		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150			// Convert config provider to provider.Provider format
151			configProvider := providerConfig.ToProvider()
152
153			// Add this unknown provider to the list
154			name := configProvider.Name
155			if name == "" {
156				name = string(configProvider.ID)
157			}
158			section := list.NewItemSection(name)
159			section.SetInfo(configured)
160			group := list.Group[list.CompletionItem[ModelOption]]{
161				Section: section,
162			}
163			for _, model := range configProvider.Models {
164				modelOption := ModelOption{
165					Provider: configProvider,
166					Model:    model,
167				}
168				key := modelKey(string(configProvider.ID), model.ID)
169				item := list.NewCompletionItem(
170					model.Name,
171					modelOption,
172					list.WithCompletionID(key),
173				)
174				itemsByKey[key] = item
175
176				group.Items = append(group.Items, item)
177				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
178					selectedItemID = item.ID()
179				}
180			}
181			groups = append(groups, group)
182
183			addedProviders[providerID] = true
184		}
185	}
186
187	// Move "Charm Hyper" to first position
188	// (but still after recent models and custom providers).
189	slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
190		switch {
191		case a.ID == "hyper":
192			return -1
193		case b.ID == "hyper":
194			return 1
195		default:
196			return 0
197		}
198	})
199
200	// Then add the known providers from the predefined list
201	for _, provider := range m.providers {
202		// Skip if we already added this provider as an unknown provider
203		if addedProviders[string(provider.ID)] {
204			continue
205		}
206
207		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
208		if providerConfigured && providerConfig.Disable {
209			continue
210		}
211
212		displayProvider := provider
213		if providerConfigured {
214			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
215			modelIndex := make(map[string]int, len(displayProvider.Models))
216			for i, model := range displayProvider.Models {
217				modelIndex[model.ID] = i
218			}
219			for _, model := range providerConfig.Models {
220				if model.ID == "" {
221					continue
222				}
223				if idx, ok := modelIndex[model.ID]; ok {
224					if model.Name != "" {
225						displayProvider.Models[idx].Name = model.Name
226					}
227					continue
228				}
229				if model.Name == "" {
230					model.Name = model.ID
231				}
232				displayProvider.Models = append(displayProvider.Models, model)
233				modelIndex[model.ID] = len(displayProvider.Models) - 1
234			}
235		}
236
237		name := displayProvider.Name
238		if name == "" {
239			name = string(displayProvider.ID)
240		}
241
242		section := list.NewItemSection(name)
243		if providerConfigured {
244			section.SetInfo(configured)
245		}
246		group := list.Group[list.CompletionItem[ModelOption]]{
247			Section: section,
248		}
249		for _, model := range displayProvider.Models {
250			modelOption := ModelOption{
251				Provider: displayProvider,
252				Model:    model,
253			}
254			key := modelKey(string(displayProvider.ID), model.ID)
255			item := list.NewCompletionItem(
256				model.Name,
257				modelOption,
258				list.WithCompletionID(key),
259			)
260			itemsByKey[key] = item
261			group.Items = append(group.Items, item)
262			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
263				selectedItemID = item.ID()
264			}
265		}
266		groups = append(groups, group)
267	}
268
269	if len(recentItems) > 0 {
270		recentSection := list.NewItemSection("Recently used")
271		recentGroup := list.Group[list.CompletionItem[ModelOption]]{
272			Section: recentSection,
273		}
274		var validRecentItems []config.SelectedModel
275		for _, recent := range recentItems {
276			key := modelKey(recent.Provider, recent.Model)
277			option, ok := itemsByKey[key]
278			if !ok {
279				continue
280			}
281			validRecentItems = append(validRecentItems, recent)
282			recentID := fmt.Sprintf("recent::%s", key)
283			modelOption := option.Value()
284			providerName := modelOption.Provider.Name
285			if providerName == "" {
286				providerName = string(modelOption.Provider.ID)
287			}
288			item := list.NewCompletionItem(
289				modelOption.Model.Name,
290				option.Value(),
291				list.WithCompletionID(recentID),
292				list.WithCompletionShortcut(providerName),
293			)
294			recentGroup.Items = append(recentGroup.Items, item)
295			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
296				selectedItemID = recentID
297			}
298		}
299
300		if len(validRecentItems) != len(recentItems) {
301			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
302				return util.ReportError(err)
303			}
304		}
305
306		if len(recentGroup.Items) > 0 {
307			groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
308		}
309	}
310
311	var cmds []tea.Cmd
312
313	cmd := m.list.SetGroups(groups)
314
315	if cmd != nil {
316		cmds = append(cmds, cmd)
317	}
318	cmd = m.list.SetSelected(selectedItemID)
319	if cmd != nil {
320		cmds = append(cmds, cmd)
321	}
322
323	return tea.Sequence(cmds...)
324}
325
326// GetModelType returns the current model type
327func (m *ModelListComponent) GetModelType() int {
328	return m.modelType
329}
330
331func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
332	m.list.SetInputPlaceholder(placeholder)
333}