1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "charm.land/bubbletea/v2"
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
18
19type ModelListComponent struct {
20 list listModel
21 modelType int
22 providers []catwalk.Provider
23}
24
// modelKey builds the "provider:model" lookup key for a model. It returns the
// empty string unless both providerID and modelID are non-empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}
31
32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
33 t := styles.CurrentTheme()
34 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
35 options := []list.ListOption{
36 list.WithKeyMap(keyMap),
37 list.WithWrapNavigation(),
38 }
39 if shouldResize {
40 options = append(options, list.WithResizeByList())
41 }
42 modelList := list.NewFilterableGroupedList(
43 []list.Group[list.CompletionItem[ModelOption]]{},
44 list.WithFilterInputStyle(inputStyle),
45 list.WithFilterPlaceholder(inputPlaceholder),
46 list.WithFilterListOptions(
47 options...,
48 ),
49 )
50
51 return &ModelListComponent{
52 list: modelList,
53 modelType: LargeModelType,
54 }
55}
56
57func (m *ModelListComponent) Init() tea.Cmd {
58 var cmds []tea.Cmd
59 if len(m.providers) == 0 {
60 cfg := config.Get()
61 providers, err := config.Providers(cfg)
62 filteredProviders := []catwalk.Provider{}
63 for _, p := range providers {
64 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
65 isHyper := p.ID == "hyper"
66 isCopilot := p.ID == catwalk.InferenceProviderCopilot
67 if (hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure) || isHyper || isCopilot {
68 filteredProviders = append(filteredProviders, p)
69 }
70 }
71
72 m.providers = filteredProviders
73 if err != nil {
74 cmds = append(cmds, util.ReportError(err))
75 }
76 }
77 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
78 return tea.Batch(cmds...)
79}
80
81func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
82 u, cmd := m.list.Update(msg)
83 m.list = u.(listModel)
84 return m, cmd
85}
86
87func (m *ModelListComponent) View() string {
88 return m.list.View()
89}
90
91func (m *ModelListComponent) Cursor() *tea.Cursor {
92 return m.list.Cursor()
93}
94
95func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
96 return m.list.SetSize(width, height)
97}
98
99func (m *ModelListComponent) SelectedModel() *ModelOption {
100 s := m.list.SelectedItem()
101 if s == nil {
102 return nil
103 }
104 sv := *s
105 model := sv.Value()
106 return &model
107}
108
109func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
110 t := styles.CurrentTheme()
111 m.modelType = modelType
112
113 var groups []list.Group[list.CompletionItem[ModelOption]]
114 // first none section
115 selectedItemID := ""
116 itemsByKey := make(map[string]list.CompletionItem[ModelOption])
117
118 cfg := config.Get()
119 var currentModel config.SelectedModel
120 selectedType := config.SelectedModelTypeLarge
121 if m.modelType == LargeModelType {
122 currentModel = cfg.Models[config.SelectedModelTypeLarge]
123 selectedType = config.SelectedModelTypeLarge
124 } else {
125 currentModel = cfg.Models[config.SelectedModelTypeSmall]
126 selectedType = config.SelectedModelTypeSmall
127 }
128 recentItems := cfg.RecentModels[selectedType]
129
130 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
131 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
132
133 // Create a map to track which providers we've already added
134 addedProviders := make(map[string]bool)
135
136 // First, add any configured providers that are not in the known providers list
137 // These should appear at the top of the list
138 knownProviders, err := config.Providers(cfg)
139 if err != nil {
140 return util.ReportError(err)
141 }
142 for providerID, providerConfig := range cfg.Providers.Seq2() {
143 if providerConfig.Disable {
144 continue
145 }
146
147 // Check if this provider is not in the known providers list
148 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
149 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
150 // Convert config provider to provider.Provider format
151 configProvider := providerConfig.ToProvider()
152
153 // Add this unknown provider to the list
154 name := configProvider.Name
155 if name == "" {
156 name = string(configProvider.ID)
157 }
158 section := list.NewItemSection(name)
159 section.SetInfo(configured)
160 group := list.Group[list.CompletionItem[ModelOption]]{
161 Section: section,
162 }
163 for _, model := range configProvider.Models {
164 modelOption := ModelOption{
165 Provider: configProvider,
166 Model: model,
167 }
168 key := modelKey(string(configProvider.ID), model.ID)
169 item := list.NewCompletionItem(
170 model.Name,
171 modelOption,
172 list.WithCompletionID(key),
173 )
174 itemsByKey[key] = item
175
176 group.Items = append(group.Items, item)
177 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
178 selectedItemID = item.ID()
179 }
180 }
181 groups = append(groups, group)
182
183 addedProviders[providerID] = true
184 }
185 }
186
187 // Move "Charm Hyper" to first position
188 // (but still after recent models and custom providers).
189 sortedProviders := make([]catwalk.Provider, len(m.providers))
190 copy(sortedProviders, m.providers)
191 slices.SortStableFunc(sortedProviders, func(a, b catwalk.Provider) int {
192 switch {
193 case a.ID == "hyper":
194 return -1
195 case b.ID == "hyper":
196 return 1
197 default:
198 return 0
199 }
200 })
201
202 // Then add the known providers from the predefined list
203 for _, provider := range sortedProviders {
204 // Skip if we already added this provider as an unknown provider
205 if addedProviders[string(provider.ID)] {
206 continue
207 }
208
209 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
210 if providerConfigured && providerConfig.Disable {
211 continue
212 }
213
214 displayProvider := provider
215 if providerConfigured {
216 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
217 modelIndex := make(map[string]int, len(displayProvider.Models))
218 for i, model := range displayProvider.Models {
219 modelIndex[model.ID] = i
220 }
221 for _, model := range providerConfig.Models {
222 if model.ID == "" {
223 continue
224 }
225 if idx, ok := modelIndex[model.ID]; ok {
226 if model.Name != "" {
227 displayProvider.Models[idx].Name = model.Name
228 }
229 continue
230 }
231 if model.Name == "" {
232 model.Name = model.ID
233 }
234 displayProvider.Models = append(displayProvider.Models, model)
235 modelIndex[model.ID] = len(displayProvider.Models) - 1
236 }
237 }
238
239 name := displayProvider.Name
240 if name == "" {
241 name = string(displayProvider.ID)
242 }
243
244 section := list.NewItemSection(name)
245 if providerConfigured {
246 section.SetInfo(configured)
247 }
248 group := list.Group[list.CompletionItem[ModelOption]]{
249 Section: section,
250 }
251 for _, model := range displayProvider.Models {
252 modelOption := ModelOption{
253 Provider: displayProvider,
254 Model: model,
255 }
256 key := modelKey(string(displayProvider.ID), model.ID)
257 item := list.NewCompletionItem(
258 model.Name,
259 modelOption,
260 list.WithCompletionID(key),
261 )
262 itemsByKey[key] = item
263 group.Items = append(group.Items, item)
264 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
265 selectedItemID = item.ID()
266 }
267 }
268 groups = append(groups, group)
269 }
270
271 if len(recentItems) > 0 {
272 recentSection := list.NewItemSection("Recently used")
273 recentGroup := list.Group[list.CompletionItem[ModelOption]]{
274 Section: recentSection,
275 }
276 var validRecentItems []config.SelectedModel
277 for _, recent := range recentItems {
278 key := modelKey(recent.Provider, recent.Model)
279 option, ok := itemsByKey[key]
280 if !ok {
281 continue
282 }
283 validRecentItems = append(validRecentItems, recent)
284 recentID := fmt.Sprintf("recent::%s", key)
285 modelOption := option.Value()
286 providerName := modelOption.Provider.Name
287 if providerName == "" {
288 providerName = string(modelOption.Provider.ID)
289 }
290 item := list.NewCompletionItem(
291 modelOption.Model.Name,
292 option.Value(),
293 list.WithCompletionID(recentID),
294 list.WithCompletionShortcut(providerName),
295 )
296 recentGroup.Items = append(recentGroup.Items, item)
297 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
298 selectedItemID = recentID
299 }
300 }
301
302 if len(validRecentItems) != len(recentItems) {
303 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
304 return util.ReportError(err)
305 }
306 }
307
308 if len(recentGroup.Items) > 0 {
309 groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
310 }
311 }
312
313 var cmds []tea.Cmd
314
315 cmd := m.list.SetGroups(groups)
316
317 if cmd != nil {
318 cmds = append(cmds, cmd)
319 }
320 cmd = m.list.SetSelected(selectedItemID)
321 if cmd != nil {
322 cmds = append(cmds, cmd)
323 }
324
325 return tea.Sequence(cmds...)
326}
327
328// GetModelType returns the current model type
329func (m *ModelListComponent) GetModelType() int {
330 return m.modelType
331}
332
333func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
334 m.list.SetInputPlaceholder(placeholder)
335}