1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "charm.land/bubbletea/v2"
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
// listModel aliases the filterable, grouped list type that backs
// ModelListComponent; every entry is a completion item carrying a ModelOption.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
18
// ModelListComponent is a filterable list of models grouped by provider
// (plus a "Recently used" group), used to pick the model for one model type.
type ModelListComponent struct {
	// list renders the groups and the filter input.
	list listModel
	// modelType selects which configured model is shown as current;
	// SetModelType treats anything other than LargeModelType as the small model.
	modelType int
	// providers is the filtered provider catalog; Init fills it when it is empty.
	providers []catwalk.Provider
}
24
// modelKey builds the "provider:model" identifier used for list items.
// It returns "" when either part is missing, signalling an unusable key.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}
31
32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
33 t := styles.CurrentTheme()
34 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
35 options := []list.ListOption{
36 list.WithKeyMap(keyMap),
37 list.WithWrapNavigation(),
38 }
39 if shouldResize {
40 options = append(options, list.WithResizeByList())
41 }
42 modelList := list.NewFilterableGroupedList(
43 []list.Group[list.CompletionItem[ModelOption]]{},
44 list.WithFilterInputStyle(inputStyle),
45 list.WithFilterPlaceholder(inputPlaceholder),
46 list.WithFilterListOptions(
47 options...,
48 ),
49 )
50
51 return &ModelListComponent{
52 list: modelList,
53 modelType: LargeModelType,
54 }
55}
56
57func (m *ModelListComponent) Init() tea.Cmd {
58 var cmds []tea.Cmd
59 if len(m.providers) == 0 {
60 cfg := config.Get()
61 providers, err := config.Providers(cfg)
62 filteredProviders := []catwalk.Provider{}
63 for _, p := range providers {
64 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
65 isHyper := p.ID == "hyper"
66 if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper {
67 filteredProviders = append(filteredProviders, p)
68 }
69 }
70
71 m.providers = filteredProviders
72 if err != nil {
73 cmds = append(cmds, util.ReportError(err))
74 }
75 }
76 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
77 return tea.Batch(cmds...)
78}
79
80func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
81 u, cmd := m.list.Update(msg)
82 m.list = u.(listModel)
83 return m, cmd
84}
85
86func (m *ModelListComponent) View() string {
87 return m.list.View()
88}
89
90func (m *ModelListComponent) Cursor() *tea.Cursor {
91 return m.list.Cursor()
92}
93
94func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
95 return m.list.SetSize(width, height)
96}
97
98func (m *ModelListComponent) SelectedModel() *ModelOption {
99 s := m.list.SelectedItem()
100 if s == nil {
101 return nil
102 }
103 sv := *s
104 model := sv.Value()
105 return &model
106}
107
108func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
109 t := styles.CurrentTheme()
110 m.modelType = modelType
111
112 var groups []list.Group[list.CompletionItem[ModelOption]]
113 // first none section
114 selectedItemID := ""
115 itemsByKey := make(map[string]list.CompletionItem[ModelOption])
116
117 cfg := config.Get()
118 var currentModel config.SelectedModel
119 selectedType := config.SelectedModelTypeLarge
120 if m.modelType == LargeModelType {
121 currentModel = cfg.Models[config.SelectedModelTypeLarge]
122 selectedType = config.SelectedModelTypeLarge
123 } else {
124 currentModel = cfg.Models[config.SelectedModelTypeSmall]
125 selectedType = config.SelectedModelTypeSmall
126 }
127 recentItems := cfg.RecentModels[selectedType]
128
129 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
130 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
131
132 // Create a map to track which providers we've already added
133 addedProviders := make(map[string]bool)
134
135 // First, add any configured providers that are not in the known providers list
136 // These should appear at the top of the list
137 knownProviders, err := config.Providers(cfg)
138 if err != nil {
139 return util.ReportError(err)
140 }
141 for providerID, providerConfig := range cfg.Providers.Seq2() {
142 if providerConfig.Disable {
143 continue
144 }
145
146 // Check if this provider is not in the known providers list
147 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
148 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
149 // Convert config provider to provider.Provider format
150 configProvider := catwalk.Provider{
151 Name: providerConfig.Name,
152 ID: catwalk.InferenceProvider(providerID),
153 Models: make([]catwalk.Model, len(providerConfig.Models)),
154 }
155
156 // Convert models
157 for i, model := range providerConfig.Models {
158 configProvider.Models[i] = catwalk.Model{
159 ID: model.ID,
160 Name: model.Name,
161 CostPer1MIn: model.CostPer1MIn,
162 CostPer1MOut: model.CostPer1MOut,
163 CostPer1MInCached: model.CostPer1MInCached,
164 CostPer1MOutCached: model.CostPer1MOutCached,
165 ContextWindow: model.ContextWindow,
166 DefaultMaxTokens: model.DefaultMaxTokens,
167 CanReason: model.CanReason,
168 ReasoningLevels: model.ReasoningLevels,
169 DefaultReasoningEffort: model.DefaultReasoningEffort,
170 SupportsImages: model.SupportsImages,
171 }
172 }
173
174 // Add this unknown provider to the list
175 name := configProvider.Name
176 if name == "" {
177 name = string(configProvider.ID)
178 }
179 section := list.NewItemSection(name)
180 section.SetInfo(configured)
181 group := list.Group[list.CompletionItem[ModelOption]]{
182 Section: section,
183 }
184 for _, model := range configProvider.Models {
185 modelOption := ModelOption{
186 Provider: configProvider,
187 Model: model,
188 }
189 key := modelKey(string(configProvider.ID), model.ID)
190 item := list.NewCompletionItem(
191 model.Name,
192 modelOption,
193 list.WithCompletionID(key),
194 )
195 itemsByKey[key] = item
196
197 group.Items = append(group.Items, item)
198 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
199 selectedItemID = item.ID()
200 }
201 }
202 groups = append(groups, group)
203
204 addedProviders[providerID] = true
205 }
206 }
207
208 // Move "Charm Hyper" to first position
209 // (but still after recent models and custom providers).
210 sortedProviders := make([]catwalk.Provider, len(m.providers))
211 copy(sortedProviders, m.providers)
212 slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
213 switch {
214 case a.ID == "hyper":
215 return -1
216 case b.ID == "hyper":
217 return 1
218 default:
219 return 0
220 }
221 })
222
223 // Then add the known providers from the predefined list
224 for _, provider := range sortedProviders {
225 // Skip if we already added this provider as an unknown provider
226 if addedProviders[string(provider.ID)] {
227 continue
228 }
229
230 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
231 if providerConfigured && providerConfig.Disable {
232 continue
233 }
234
235 displayProvider := provider
236 if providerConfigured {
237 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
238 modelIndex := make(map[string]int, len(displayProvider.Models))
239 for i, model := range displayProvider.Models {
240 modelIndex[model.ID] = i
241 }
242 for _, model := range providerConfig.Models {
243 if model.ID == "" {
244 continue
245 }
246 if idx, ok := modelIndex[model.ID]; ok {
247 if model.Name != "" {
248 displayProvider.Models[idx].Name = model.Name
249 }
250 continue
251 }
252 if model.Name == "" {
253 model.Name = model.ID
254 }
255 displayProvider.Models = append(displayProvider.Models, model)
256 modelIndex[model.ID] = len(displayProvider.Models) - 1
257 }
258 }
259
260 name := displayProvider.Name
261 if name == "" {
262 name = string(displayProvider.ID)
263 }
264
265 section := list.NewItemSection(name)
266 if providerConfigured {
267 section.SetInfo(configured)
268 }
269 group := list.Group[list.CompletionItem[ModelOption]]{
270 Section: section,
271 }
272 for _, model := range displayProvider.Models {
273 modelOption := ModelOption{
274 Provider: displayProvider,
275 Model: model,
276 }
277 key := modelKey(string(displayProvider.ID), model.ID)
278 item := list.NewCompletionItem(
279 model.Name,
280 modelOption,
281 list.WithCompletionID(key),
282 )
283 itemsByKey[key] = item
284 group.Items = append(group.Items, item)
285 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
286 selectedItemID = item.ID()
287 }
288 }
289 groups = append(groups, group)
290 }
291
292 if len(recentItems) > 0 {
293 recentSection := list.NewItemSection("Recently used")
294 recentGroup := list.Group[list.CompletionItem[ModelOption]]{
295 Section: recentSection,
296 }
297 var validRecentItems []config.SelectedModel
298 for _, recent := range recentItems {
299 key := modelKey(recent.Provider, recent.Model)
300 option, ok := itemsByKey[key]
301 if !ok {
302 continue
303 }
304 validRecentItems = append(validRecentItems, recent)
305 recentID := fmt.Sprintf("recent::%s", key)
306 modelOption := option.Value()
307 providerName := modelOption.Provider.Name
308 if providerName == "" {
309 providerName = string(modelOption.Provider.ID)
310 }
311 item := list.NewCompletionItem(
312 modelOption.Model.Name,
313 option.Value(),
314 list.WithCompletionID(recentID),
315 list.WithCompletionShortcut(providerName),
316 )
317 recentGroup.Items = append(recentGroup.Items, item)
318 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
319 selectedItemID = recentID
320 }
321 }
322
323 if len(validRecentItems) != len(recentItems) {
324 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
325 return util.ReportError(err)
326 }
327 }
328
329 if len(recentGroup.Items) > 0 {
330 groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
331 }
332 }
333
334 var cmds []tea.Cmd
335
336 cmd := m.list.SetGroups(groups)
337
338 if cmd != nil {
339 cmds = append(cmds, cmd)
340 }
341 cmd = m.list.SetSelected(selectedItemID)
342 if cmd != nil {
343 cmds = append(cmds, cmd)
344 }
345
346 return tea.Sequence(cmds...)
347}
348
349// GetModelType returns the current model type
350func (m *ModelListComponent) GetModelType() int {
351 return m.modelType
352}
353
354func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
355 m.list.SetInputPlaceholder(placeholder)
356}