1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "charm.land/bubbletea/v2"
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
// listModel is the filterable, grouped list type backing ModelListComponent;
// each item wraps a ModelOption.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]

// ModelListComponent is a filterable list of models grouped by provider, with
// an optional "Recently used" group, used to pick a model for one model slot.
type ModelListComponent struct {
	// list renders the groups and owns filtering and selection.
	list listModel
	// modelType selects which model slot is edited; SetModelType compares it
	// against LargeModelType and treats any other value as the small slot.
	modelType int
	// providers is the selectable provider set loaded lazily by Init (only
	// when empty) and iterated by SetModelType.
	providers []catwalk.Provider
}
24
// modelKey builds the "provider:model" identifier used for list item IDs and
// for matching recent selections. It yields "" unless both parts are set.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}
31
32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
33 t := styles.CurrentTheme()
34 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
35 options := []list.ListOption{
36 list.WithKeyMap(keyMap),
37 list.WithWrapNavigation(),
38 }
39 if shouldResize {
40 options = append(options, list.WithResizeByList())
41 }
42 modelList := list.NewFilterableGroupedList(
43 []list.Group[list.CompletionItem[ModelOption]]{},
44 list.WithFilterInputStyle(inputStyle),
45 list.WithFilterPlaceholder(inputPlaceholder),
46 list.WithFilterListOptions(
47 options...,
48 ),
49 )
50
51 return &ModelListComponent{
52 list: modelList,
53 modelType: LargeModelType,
54 }
55}
56
57func (m *ModelListComponent) Init() tea.Cmd {
58 var cmds []tea.Cmd
59 if len(m.providers) == 0 {
60 cfg := config.Get()
61 providers, err := config.Providers(cfg)
62 filteredProviders := []catwalk.Provider{}
63 for _, p := range providers {
64 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
65 isHyper := p.ID == "hyper"
66 isCopilot := p.ID == catwalk.InferenceProviderCopilot
67 if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
68 filteredProviders = append(filteredProviders, p)
69 }
70 }
71
72 m.providers = filteredProviders
73 if err != nil {
74 cmds = append(cmds, util.ReportError(err))
75 }
76 }
77 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
78 return tea.Batch(cmds...)
79}
80
81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
82 u, cmd := m.list.Update(msg)
83 m.list = u.(listModel)
84 return m, cmd
85}
86
87func (m *ModelListComponent) View() string {
88 return m.list.View()
89}
90
91func (m *ModelListComponent) Cursor() *tea.Cursor {
92 return m.list.Cursor()
93}
94
95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
96 return m.list.SetSize(width, height)
97}
98
99func (m *ModelListComponent) SelectedModel() *ModelOption {
100 s := m.list.SelectedItem()
101 if s == nil {
102 return nil
103 }
104 sv := *s
105 model := sv.Value()
106 return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110 t := styles.CurrentTheme()
111 m.modelType = modelType
112
113 var groups []list.Group[list.CompletionItem[ModelOption]]
114 // first none section
115 selectedItemID := ""
116 itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118 cfg := config.Get()
119 var currentModel config.SelectedModel
120 selectedType := config.SelectedModelTypeLarge
121 if m.modelType == LargeModelType {
122 currentModel = cfg.Models[config.SelectedModelTypeLarge]
123 selectedType = config.SelectedModelTypeLarge
124 } else {
125 currentModel = cfg.Models[config.SelectedModelTypeSmall]
126 selectedType = config.SelectedModelTypeSmall
127 }
128 recentItems := cfg.RecentModels[selectedType]
129
130 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133 // Create a map to track which providers we've already added
134 addedProviders := make(map[string]bool)
135
136 // First, add any configured providers that are not in the known providers list
137 // These should appear at the top of the list
138 knownProviders, err := config.Providers(cfg)
139 if err != nil {
140 return util.ReportError(err)
141 }
142 for providerID, providerConfig := range cfg.Providers.Seq2() {
143 if providerConfig.Disable {
144 continue
145 }
146
147 // Check if this provider is not in the known providers list
148 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150 // Convert config provider to provider.Provider format
151 configProvider := catwalk.Provider{
152 Name: providerConfig.Name,
153 ID: catwalk.InferenceProvider(providerID),
154 Models: make([]catwalk.Model, len(providerConfig.Models)),
155 }
156
157 // Convert models
158 for i, model := range providerConfig.Models {
159 configProvider.Models[i] = catwalk.Model{
160 ID: model.ID,
161 Name: model.Name,
162 CostPer1MIn: model.CostPer1MIn,
163 CostPer1MOut: model.CostPer1MOut,
164 CostPer1MInCached: model.CostPer1MInCached,
165 CostPer1MOutCached: model.CostPer1MOutCached,
166 ContextWindow: model.ContextWindow,
167 DefaultMaxTokens: model.DefaultMaxTokens,
168 CanReason: model.CanReason,
169 ReasoningLevels: model.ReasoningLevels,
170 DefaultReasoningEffort: model.DefaultReasoningEffort,
171 SupportsImages: model.SupportsImages,
172 }
173 }
174
175 // Add this unknown provider to the list
176 name := configProvider.Name
177 if name == "" {
178 name = string(configProvider.ID)
179 }
180 section := list.NewItemSection(name)
181 section.SetInfo(configured)
182 group := list.Group[list.CompletionItem[ModelOption]]{
183 Section: section,
184 }
185 for _, model := range configProvider.Models {
186 modelOption := ModelOption{
187 Provider: configProvider,
188 Model: model,
189 }
190 key := modelKey(string(configProvider.ID), model.ID)
191 item := list.NewCompletionItem(
192 model.Name,
193 modelOption,
194 list.WithCompletionID(key),
195 )
196 itemsByKey[key] = item
197
198 group.Items = append(group.Items, item)
199 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
200 selectedItemID = item.ID()
201 }
202 }
203 groups = append(groups, group)
204
205 addedProviders[providerID] = true
206 }
207 }
208
209 // Move "Charm Hyper" to first position
210 // (but still after recent models and custom providers).
211 sortedProviders := make([]catwalk.Provider, len(m.providers))
212 copy(sortedProviders, m.providers)
213 slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
214 switch {
215 case a.ID == "hyper":
216 return -1
217 case b.ID == "hyper":
218 return 1
219 default:
220 return 0
221 }
222 })
223
224 // Then add the known providers from the predefined list
225 for _, provider := range sortedProviders {
226 // Skip if we already added this provider as an unknown provider
227 if addedProviders[string(provider.ID)] {
228 continue
229 }
230
231 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
232 if providerConfigured && providerConfig.Disable {
233 continue
234 }
235
236 displayProvider := provider
237 if providerConfigured {
238 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
239 modelIndex := make(map[string]int, len(displayProvider.Models))
240 for i, model := range displayProvider.Models {
241 modelIndex[model.ID] = i
242 }
243 for _, model := range providerConfig.Models {
244 if model.ID == "" {
245 continue
246 }
247 if idx, ok := modelIndex[model.ID]; ok {
248 if model.Name != "" {
249 displayProvider.Models[idx].Name = model.Name
250 }
251 continue
252 }
253 if model.Name == "" {
254 model.Name = model.ID
255 }
256 displayProvider.Models = append(displayProvider.Models, model)
257 modelIndex[model.ID] = len(displayProvider.Models) - 1
258 }
259 }
260
261 name := displayProvider.Name
262 if name == "" {
263 name = string(displayProvider.ID)
264 }
265
266 section := list.NewItemSection(name)
267 if providerConfigured {
268 section.SetInfo(configured)
269 }
270 group := list.Group[list.CompletionItem[ModelOption]]{
271 Section: section,
272 }
273 for _, model := range displayProvider.Models {
274 modelOption := ModelOption{
275 Provider: displayProvider,
276 Model: model,
277 }
278 key := modelKey(string(displayProvider.ID), model.ID)
279 item := list.NewCompletionItem(
280 model.Name,
281 modelOption,
282 list.WithCompletionID(key),
283 )
284 itemsByKey[key] = item
285 group.Items = append(group.Items, item)
286 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
287 selectedItemID = item.ID()
288 }
289 }
290 groups = append(groups, group)
291 }
292
293 if len(recentItems) > 0 {
294 recentSection := list.NewItemSection("Recently used")
295 recentGroup := list.Group[list.CompletionItem[ModelOption]]{
296 Section: recentSection,
297 }
298 var validRecentItems []config.SelectedModel
299 for _, recent := range recentItems {
300 key := modelKey(recent.Provider, recent.Model)
301 option, ok := itemsByKey[key]
302 if !ok {
303 continue
304 }
305 validRecentItems = append(validRecentItems, recent)
306 recentID := fmt.Sprintf("recent::%s", key)
307 modelOption := option.Value()
308 providerName := modelOption.Provider.Name
309 if providerName == "" {
310 providerName = string(modelOption.Provider.ID)
311 }
312 item := list.NewCompletionItem(
313 modelOption.Model.Name,
314 option.Value(),
315 list.WithCompletionID(recentID),
316 list.WithCompletionShortcut(providerName),
317 )
318 recentGroup.Items = append(recentGroup.Items, item)
319 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
320 selectedItemID = recentID
321 }
322 }
323
324 if len(validRecentItems) != len(recentItems) {
325 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
326 return util.ReportError(err)
327 }
328 }
329
330 if len(recentGroup.Items) > 0 {
331 groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
332 }
333 }
334
335 var cmds []tea.Cmd
336
337 cmd := m.list.SetGroups(groups)
338
339 if cmd != nil {
340 cmds = append(cmds, cmd)
341 }
342 cmd = m.list.SetSelected(selectedItemID)
343 if cmd != nil {
344 cmds = append(cmds, cmd)
345 }
346
347 return tea.Sequence(cmds...)
348}
349
350// GetModelType returns the current model type
351func (m *ModelListComponent) GetModelType() int {
352 return m.modelType
353}
354
// SetInputPlaceholder replaces the placeholder text of the list's filter
// input by delegating to the underlying list.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetInputPlaceholder(placeholder)
}