list.go

  1package models
  2
  3import (
  4	"cmp"
  5	"fmt"
  6	"slices"
  7	"strings"
  8
  9	tea "charm.land/bubbletea/v2"
 10	"git.secluded.site/crush/internal/config"
 11	"git.secluded.site/crush/internal/tui/exp/list"
 12	"git.secluded.site/crush/internal/tui/styles"
 13	"git.secluded.site/crush/internal/tui/util"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15)
 16
 17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
 18
 19type ModelListComponent struct {
 20	list      listModel
 21	modelType int
 22	providers []catwalk.Provider
 23}
 24
 25func modelKey(providerID, modelID string) string {
 26	if providerID == "" || modelID == "" {
 27		return ""
 28	}
 29	return providerID + ":" + modelID
 30}
 31
 32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
 33	t := styles.CurrentTheme()
 34	inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
 35	options := []list.ListOption{
 36		list.WithKeyMap(keyMap),
 37		list.WithWrapNavigation(),
 38	}
 39	if shouldResize {
 40		options = append(options, list.WithResizeByList())
 41	}
 42	modelList := list.NewFilterableGroupedList(
 43		[]list.Group[list.CompletionItem[ModelOption]]{},
 44		list.WithFilterInputStyle(inputStyle),
 45		list.WithFilterPlaceholder(inputPlaceholder),
 46		list.WithFilterListOptions(
 47			options...,
 48		),
 49	)
 50
 51	return &ModelListComponent{
 52		list:      modelList,
 53		modelType: LargeModelType,
 54	}
 55}
 56
 57func (m *ModelListComponent) Init() tea.Cmd {
 58	var cmds []tea.Cmd
 59	if len(m.providers) == 0 {
 60		cfg := config.Get()
 61		providers, err := config.Providers(cfg)
 62		filteredProviders := []catwalk.Provider{}
 63		for _, p := range providers {
 64			hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
 65			if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
 66				filteredProviders = append(filteredProviders, p)
 67			}
 68		}
 69
 70		m.providers = filteredProviders
 71		if err != nil {
 72			cmds = append(cmds, util.ReportError(err))
 73		}
 74	}
 75	cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
 76	return tea.Batch(cmds...)
 77}
 78
 79func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
 80	u, cmd := m.list.Update(msg)
 81	m.list = u.(listModel)
 82	return m, cmd
 83}
 84
 85func (m *ModelListComponent) View() string {
 86	return m.list.View()
 87}
 88
 89func (m *ModelListComponent) Cursor() *tea.Cursor {
 90	return m.list.Cursor()
 91}
 92
 93func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
 94	return m.list.SetSize(width, height)
 95}
 96
 97func (m *ModelListComponent) SelectedModel() *ModelOption {
 98	s := m.list.SelectedItem()
 99	if s == nil {
100		return nil
101	}
102	sv := *s
103	model := sv.Value()
104	return &model
105}
106
107func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
108	t := styles.CurrentTheme()
109	m.modelType = modelType
110
111	var groups []list.Group[list.CompletionItem[ModelOption]]
112	// first none section
113	selectedItemID := ""
114	itemsByKey := make(map[string]list.CompletionItem[ModelOption])
115
116	cfg := config.Get()
117	var currentModel config.SelectedModel
118	selectedType := config.SelectedModelTypeLarge
119	if m.modelType == LargeModelType {
120		currentModel = cfg.Models[config.SelectedModelTypeLarge]
121		selectedType = config.SelectedModelTypeLarge
122	} else {
123		currentModel = cfg.Models[config.SelectedModelTypeSmall]
124		selectedType = config.SelectedModelTypeSmall
125	}
126	recentItems := cfg.RecentModels[selectedType]
127
128	configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
129	configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
130
131	// Create a map to track which providers we've already added
132	addedProviders := make(map[string]bool)
133
134	// First, add any configured providers that are not in the known providers list
135	// These should appear at the top of the list
136	knownProviders, err := config.Providers(cfg)
137	if err != nil {
138		return util.ReportError(err)
139	}
140	for providerID, providerConfig := range cfg.Providers.Seq2() {
141		if providerConfig.Disable {
142			continue
143		}
144
145		// Check if this provider is not in the known providers list
146		if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
147			!slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
148			// Convert config provider to provider.Provider format
149			configProvider := catwalk.Provider{
150				Name:   providerConfig.Name,
151				ID:     catwalk.InferenceProvider(providerID),
152				Models: make([]catwalk.Model, len(providerConfig.Models)),
153			}
154
155			// Convert models
156			for i, model := range providerConfig.Models {
157				configProvider.Models[i] = catwalk.Model{
158					ID:                     model.ID,
159					Name:                   model.Name,
160					CostPer1MIn:            model.CostPer1MIn,
161					CostPer1MOut:           model.CostPer1MOut,
162					CostPer1MInCached:      model.CostPer1MInCached,
163					CostPer1MOutCached:     model.CostPer1MOutCached,
164					ContextWindow:          model.ContextWindow,
165					DefaultMaxTokens:       model.DefaultMaxTokens,
166					CanReason:              model.CanReason,
167					ReasoningLevels:        model.ReasoningLevels,
168					DefaultReasoningEffort: model.DefaultReasoningEffort,
169					SupportsImages:         model.SupportsImages,
170				}
171			}
172
173			// Add this unknown provider to the list
174			name := configProvider.Name
175			if name == "" {
176				name = string(configProvider.ID)
177			}
178			section := list.NewItemSection(name)
179			section.SetInfo(configured)
180			group := list.Group[list.CompletionItem[ModelOption]]{
181				Section: section,
182			}
183			for _, model := range configProvider.Models {
184				modelOption := ModelOption{
185					Provider: configProvider,
186					Model:    model,
187				}
188				key := modelKey(string(configProvider.ID), model.ID)
189				item := list.NewCompletionItem(
190					model.Name,
191					modelOption,
192					list.WithCompletionID(key),
193				)
194				itemsByKey[key] = item
195
196				group.Items = append(group.Items, item)
197				if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
198					selectedItemID = item.ID()
199				}
200			}
201			groups = append(groups, group)
202
203			addedProviders[providerID] = true
204		}
205	}
206
207	// Then add the known providers from the predefined list
208	for _, provider := range m.providers {
209		// Skip if we already added this provider as an unknown provider
210		if addedProviders[string(provider.ID)] {
211			continue
212		}
213
214		providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
215		if providerConfigured && providerConfig.Disable {
216			continue
217		}
218
219		displayProvider := provider
220		if providerConfigured {
221			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
222			modelIndex := make(map[string]int, len(displayProvider.Models))
223			for i, model := range displayProvider.Models {
224				modelIndex[model.ID] = i
225			}
226			for _, model := range providerConfig.Models {
227				if model.ID == "" {
228					continue
229				}
230				if idx, ok := modelIndex[model.ID]; ok {
231					if model.Name != "" {
232						displayProvider.Models[idx].Name = model.Name
233					}
234					continue
235				}
236				if model.Name == "" {
237					model.Name = model.ID
238				}
239				displayProvider.Models = append(displayProvider.Models, model)
240				modelIndex[model.ID] = len(displayProvider.Models) - 1
241			}
242		}
243
244		name := displayProvider.Name
245		if name == "" {
246			name = string(displayProvider.ID)
247		}
248
249		section := list.NewItemSection(name)
250		if providerConfigured {
251			section.SetInfo(configured)
252		}
253		group := list.Group[list.CompletionItem[ModelOption]]{
254			Section: section,
255		}
256		for _, model := range displayProvider.Models {
257			modelOption := ModelOption{
258				Provider: displayProvider,
259				Model:    model,
260			}
261			key := modelKey(string(displayProvider.ID), model.ID)
262			item := list.NewCompletionItem(
263				model.Name,
264				modelOption,
265				list.WithCompletionID(key),
266			)
267			itemsByKey[key] = item
268			group.Items = append(group.Items, item)
269			if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
270				selectedItemID = item.ID()
271			}
272		}
273		groups = append(groups, group)
274	}
275
276	if len(recentItems) > 0 {
277		recentSection := list.NewItemSection("Recently used")
278		recentGroup := list.Group[list.CompletionItem[ModelOption]]{
279			Section: recentSection,
280		}
281		var validRecentItems []config.SelectedModel
282		for _, recent := range recentItems {
283			key := modelKey(recent.Provider, recent.Model)
284			option, ok := itemsByKey[key]
285			if !ok {
286				continue
287			}
288			validRecentItems = append(validRecentItems, recent)
289			recentID := fmt.Sprintf("recent::%s", key)
290			modelOption := option.Value()
291			providerName := modelOption.Provider.Name
292			if providerName == "" {
293				providerName = string(modelOption.Provider.ID)
294			}
295			item := list.NewCompletionItem(
296				modelOption.Model.Name,
297				option.Value(),
298				list.WithCompletionID(recentID),
299				list.WithCompletionShortcut(providerName),
300			)
301			recentGroup.Items = append(recentGroup.Items, item)
302			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
303				selectedItemID = recentID
304			}
305		}
306
307		if len(validRecentItems) != len(recentItems) {
308			if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
309				return util.ReportError(err)
310			}
311		}
312
313		if len(recentGroup.Items) > 0 {
314			groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
315		}
316	}
317
318	var cmds []tea.Cmd
319
320	cmd := m.list.SetGroups(groups)
321
322	if cmd != nil {
323		cmds = append(cmds, cmd)
324	}
325	cmd = m.list.SetSelected(selectedItemID)
326	if cmd != nil {
327		cmds = append(cmds, cmd)
328	}
329
330	return tea.Sequence(cmds...)
331}
332
333// GetModelType returns the current model type
334func (m *ModelListComponent) GetModelType() int {
335	return m.modelType
336}
337
338func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
339	m.list.SetInputPlaceholder(placeholder)
340}