1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "charm.land/bubbletea/v2"
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
18
19type ModelListComponent struct {
20 list listModel
21 modelType int
22 providers []catwalk.Provider
23}
24
// modelKey builds the "provider:model" lookup key that identifies a model
// across the list. It yields "" unless both IDs are non-empty, so incomplete
// references never produce a usable key.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}
31
32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
33 t := styles.CurrentTheme()
34 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
35 options := []list.ListOption{
36 list.WithKeyMap(keyMap),
37 list.WithWrapNavigation(),
38 }
39 if shouldResize {
40 options = append(options, list.WithResizeByList())
41 }
42 modelList := list.NewFilterableGroupedList(
43 []list.Group[list.CompletionItem[ModelOption]]{},
44 list.WithFilterInputStyle(inputStyle),
45 list.WithFilterPlaceholder(inputPlaceholder),
46 list.WithFilterListOptions(
47 options...,
48 ),
49 )
50
51 return &ModelListComponent{
52 list: modelList,
53 modelType: LargeModelType,
54 }
55}
56
57func (m *ModelListComponent) Init() tea.Cmd {
58 var cmds []tea.Cmd
59 if len(m.providers) == 0 {
60 cfg := config.Get()
61 providers, err := config.Providers(cfg)
62 filteredProviders := []catwalk.Provider{}
63 for _, p := range providers {
64 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
65 isHyper := p.ID == "hyper"
66 isCopilot := p.ID == catwalk.InferenceProviderCopilot
67 if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
68 filteredProviders = append(filteredProviders, p)
69 }
70 }
71
72 m.providers = filteredProviders
73 if err != nil {
74 cmds = append(cmds, util.ReportError(err))
75 }
76 }
77 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
78 return tea.Batch(cmds...)
79}
80
81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
82 u, cmd := m.list.Update(msg)
83 m.list = u.(listModel)
84 return m, cmd
85}
86
87func (m *ModelListComponent) View() string {
88 return m.list.View()
89}
90
91func (m *ModelListComponent) Cursor() *tea.Cursor {
92 return m.list.Cursor()
93}
94
95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
96 return m.list.SetSize(width, height)
97}
98
99func (m *ModelListComponent) SelectedModel() *ModelOption {
100 s := m.list.SelectedItem()
101 if s == nil {
102 return nil
103 }
104 sv := *s
105 model := sv.Value()
106 return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110 t := styles.CurrentTheme()
111 m.modelType = modelType
112
113 var groups []list.Group[list.CompletionItem[ModelOption]]
114 // first none section
115 selectedItemID := ""
116 itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118 cfg := config.Get()
119 var currentModel config.SelectedModel
120 selectedType := config.SelectedModelTypeLarge
121 if m.modelType == LargeModelType {
122 currentModel = cfg.Models[config.SelectedModelTypeLarge]
123 selectedType = config.SelectedModelTypeLarge
124 } else {
125 currentModel = cfg.Models[config.SelectedModelTypeSmall]
126 selectedType = config.SelectedModelTypeSmall
127 }
128 recentItems := cfg.RecentModels[selectedType]
129
130 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133 // Create a map to track which providers we've already added
134 addedProviders := make(map[string]bool)
135
136 // First, add any configured providers that are not in the known providers list
137 // These should appear at the top of the list
138 knownProviders, err := config.Providers(cfg)
139 if err != nil {
140 return util.ReportError(err)
141 }
142 for providerID, providerConfig := range cfg.Providers.Seq2() {
143 if providerConfig.Disable {
144 continue
145 }
146
147 // Check if this provider is not in the known providers list
148 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150 // Convert config provider to provider.Provider format
151 configProvider := providerConfig.ToProvider()
152
153 // Add this unknown provider to the list
154 name := configProvider.Name
155 if name == "" {
156 name = string(configProvider.ID)
157 }
158 section := list.NewItemSection(name)
159 section.SetInfo(configured)
160 group := list.Group[list.CompletionItem[ModelOption]]{
161 Section: section,
162 }
163 for _, model := range configProvider.Models {
164 modelOption := ModelOption{
165 Provider: configProvider,
166 Model: model,
167 }
168 key := modelKey(string(configProvider.ID), model.ID)
169 item := list.NewCompletionItem(
170 model.Name,
171 modelOption,
172 list.WithCompletionID(key),
173 )
174 itemsByKey[key] = item
175
176 group.Items = append(group.Items, item)
177 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
178 selectedItemID = item.ID()
179 }
180 }
181 groups = append(groups, group)
182
183 addedProviders[providerID] = true
184 }
185 }
186
187 // Move "Charm Hyper" to first position
188 // (but still after recent models and custom providers).
189 slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
190 switch {
191 case a.ID == "hyper":
192 return -1
193 case b.ID == "hyper":
194 return 1
195 default:
196 return 0
197 }
198 })
199
200 // Then add the known providers from the predefined list
201 for _, provider := range m.providers {
202 // Skip if we already added this provider as an unknown provider
203 if addedProviders[string(provider.ID)] {
204 continue
205 }
206
207 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
208 if providerConfigured && providerConfig.Disable {
209 continue
210 }
211
212 displayProvider := provider
213 if providerConfigured {
214 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
215 modelIndex := make(map[string]int, len(displayProvider.Models))
216 for i, model := range displayProvider.Models {
217 modelIndex[model.ID] = i
218 }
219 for _, model := range providerConfig.Models {
220 if model.ID == "" {
221 continue
222 }
223 if idx, ok := modelIndex[model.ID]; ok {
224 if model.Name != "" {
225 displayProvider.Models[idx].Name = model.Name
226 }
227 continue
228 }
229 if model.Name == "" {
230 model.Name = model.ID
231 }
232 displayProvider.Models = append(displayProvider.Models, model)
233 modelIndex[model.ID] = len(displayProvider.Models) - 1
234 }
235 }
236
237 name := displayProvider.Name
238 if name == "" {
239 name = string(displayProvider.ID)
240 }
241
242 section := list.NewItemSection(name)
243 if providerConfigured {
244 section.SetInfo(configured)
245 }
246 group := list.Group[list.CompletionItem[ModelOption]]{
247 Section: section,
248 }
249 for _, model := range displayProvider.Models {
250 modelOption := ModelOption{
251 Provider: displayProvider,
252 Model: model,
253 }
254 key := modelKey(string(displayProvider.ID), model.ID)
255 item := list.NewCompletionItem(
256 model.Name,
257 modelOption,
258 list.WithCompletionID(key),
259 )
260 itemsByKey[key] = item
261 group.Items = append(group.Items, item)
262 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
263 selectedItemID = item.ID()
264 }
265 }
266 groups = append(groups, group)
267 }
268
269 if len(recentItems) > 0 {
270 recentSection := list.NewItemSection("Recently used")
271 recentGroup := list.Group[list.CompletionItem[ModelOption]]{
272 Section: recentSection,
273 }
274 var validRecentItems []config.SelectedModel
275 for _, recent := range recentItems {
276 key := modelKey(recent.Provider, recent.Model)
277 option, ok := itemsByKey[key]
278 if !ok {
279 continue
280 }
281 validRecentItems = append(validRecentItems, recent)
282 recentID := fmt.Sprintf("recent::%s", key)
283 modelOption := option.Value()
284 providerName := modelOption.Provider.Name
285 if providerName == "" {
286 providerName = string(modelOption.Provider.ID)
287 }
288 item := list.NewCompletionItem(
289 modelOption.Model.Name,
290 option.Value(),
291 list.WithCompletionID(recentID),
292 list.WithCompletionShortcut(providerName),
293 )
294 recentGroup.Items = append(recentGroup.Items, item)
295 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
296 selectedItemID = recentID
297 }
298 }
299
300 if len(validRecentItems) != len(recentItems) {
301 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
302 return util.ReportError(err)
303 }
304 }
305
306 if len(recentGroup.Items) > 0 {
307 groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
308 }
309 }
310
311 var cmds []tea.Cmd
312
313 cmd := m.list.SetGroups(groups)
314
315 if cmd != nil {
316 cmds = append(cmds, cmd)
317 }
318 cmd = m.list.SetSelected(selectedItemID)
319 if cmd != nil {
320 cmds = append(cmds, cmd)
321 }
322
323 return tea.Sequence(cmds...)
324}
325
326// GetModelType returns the current model type
327func (m *ModelListComponent) GetModelType() int {
328 return m.modelType
329}
330
331func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
332 m.list.SetInputPlaceholder(placeholder)
333}