1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "charm.land/bubbletea/v2"
10 "git.secluded.site/crush/internal/config"
11 "git.secluded.site/crush/internal/tui/exp/list"
12 "git.secluded.site/crush/internal/tui/styles"
13 "git.secluded.site/crush/internal/tui/util"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15)
16
17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
18
19type ModelListComponent struct {
20 list listModel
21 modelType int
22 providers []catwalk.Provider
23}
24
// modelKey returns the "<providerID>:<modelID>" key used to index list items
// by provider and model. If either part is empty, it returns "".
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return strings.Join([]string{providerID, modelID}, ":")
	}
	return ""
}
31
32func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
33 t := styles.CurrentTheme()
34 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
35 options := []list.ListOption{
36 list.WithKeyMap(keyMap),
37 list.WithWrapNavigation(),
38 }
39 if shouldResize {
40 options = append(options, list.WithResizeByList())
41 }
42 modelList := list.NewFilterableGroupedList(
43 []list.Group[list.CompletionItem[ModelOption]]{},
44 list.WithFilterInputStyle(inputStyle),
45 list.WithFilterPlaceholder(inputPlaceholder),
46 list.WithFilterListOptions(
47 options...,
48 ),
49 )
50
51 return &ModelListComponent{
52 list: modelList,
53 modelType: LargeModelType,
54 }
55}
56
57func (m *ModelListComponent) Init() tea.Cmd {
58 var cmds []tea.Cmd
59 if len(m.providers) == 0 {
60 cfg := config.Get()
61 providers, err := config.Providers(cfg)
62 filteredProviders := []catwalk.Provider{}
63 for _, p := range providers {
64 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
65 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
66 filteredProviders = append(filteredProviders, p)
67 }
68 }
69
70 m.providers = filteredProviders
71 if err != nil {
72 cmds = append(cmds, util.ReportError(err))
73 }
74 }
75 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
76 return tea.Batch(cmds...)
77}
78
79func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
80 u, cmd := m.list.Update(msg)
81 m.list = u.(listModel)
82 return m, cmd
83}
84
85func (m *ModelListComponent) View() string {
86 return m.list.View()
87}
88
89func (m *ModelListComponent) Cursor() *tea.Cursor {
90 return m.list.Cursor()
91}
92
93func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
94 return m.list.SetSize(width, height)
95}
96
97func (m *ModelListComponent) SelectedModel() *ModelOption {
98 s := m.list.SelectedItem()
99 if s == nil {
100 return nil
101 }
102 sv := *s
103 model := sv.Value()
104 return &model
105}
106
107func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
108 t := styles.CurrentTheme()
109 m.modelType = modelType
110
111 var groups []list.Group[list.CompletionItem[ModelOption]]
112 // first none section
113 selectedItemID := ""
114 itemsByKey := make(map[string]list.CompletionItem[ModelOption])
115
116 cfg := config.Get()
117 var currentModel config.SelectedModel
118 selectedType := config.SelectedModelTypeLarge
119 if m.modelType == LargeModelType {
120 currentModel = cfg.Models[config.SelectedModelTypeLarge]
121 selectedType = config.SelectedModelTypeLarge
122 } else {
123 currentModel = cfg.Models[config.SelectedModelTypeSmall]
124 selectedType = config.SelectedModelTypeSmall
125 }
126 recentItems := cfg.RecentModels[selectedType]
127
128 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
129 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
130
131 // Create a map to track which providers we've already added
132 addedProviders := make(map[string]bool)
133
134 // First, add any configured providers that are not in the known providers list
135 // These should appear at the top of the list
136 knownProviders, err := config.Providers(cfg)
137 if err != nil {
138 return util.ReportError(err)
139 }
140 for providerID, providerConfig := range cfg.Providers.Seq2() {
141 if providerConfig.Disable {
142 continue
143 }
144
145 // Check if this provider is not in the known providers list
146 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
147 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
148 // Convert config provider to provider.Provider format
149 configProvider := catwalk.Provider{
150 Name: providerConfig.Name,
151 ID: catwalk.InferenceProvider(providerID),
152 Models: make([]catwalk.Model, len(providerConfig.Models)),
153 }
154
155 // Convert models
156 for i, model := range providerConfig.Models {
157 configProvider.Models[i] = catwalk.Model{
158 ID: model.ID,
159 Name: model.Name,
160 CostPer1MIn: model.CostPer1MIn,
161 CostPer1MOut: model.CostPer1MOut,
162 CostPer1MInCached: model.CostPer1MInCached,
163 CostPer1MOutCached: model.CostPer1MOutCached,
164 ContextWindow: model.ContextWindow,
165 DefaultMaxTokens: model.DefaultMaxTokens,
166 CanReason: model.CanReason,
167 ReasoningLevels: model.ReasoningLevels,
168 DefaultReasoningEffort: model.DefaultReasoningEffort,
169 SupportsImages: model.SupportsImages,
170 }
171 }
172
173 // Add this unknown provider to the list
174 name := configProvider.Name
175 if name == "" {
176 name = string(configProvider.ID)
177 }
178 section := list.NewItemSection(name)
179 section.SetInfo(configured)
180 group := list.Group[list.CompletionItem[ModelOption]]{
181 Section: section,
182 }
183 for _, model := range configProvider.Models {
184 modelOption := ModelOption{
185 Provider: configProvider,
186 Model: model,
187 }
188 key := modelKey(string(configProvider.ID), model.ID)
189 item := list.NewCompletionItem(
190 model.Name,
191 modelOption,
192 list.WithCompletionID(key),
193 )
194 itemsByKey[key] = item
195
196 group.Items = append(group.Items, item)
197 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
198 selectedItemID = item.ID()
199 }
200 }
201 groups = append(groups, group)
202
203 addedProviders[providerID] = true
204 }
205 }
206
207 // Then add the known providers from the predefined list
208 for _, provider := range m.providers {
209 // Skip if we already added this provider as an unknown provider
210 if addedProviders[string(provider.ID)] {
211 continue
212 }
213
214 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
215 if providerConfigured && providerConfig.Disable {
216 continue
217 }
218
219 displayProvider := provider
220 if providerConfigured {
221 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
222 modelIndex := make(map[string]int, len(displayProvider.Models))
223 for i, model := range displayProvider.Models {
224 modelIndex[model.ID] = i
225 }
226 for _, model := range providerConfig.Models {
227 if model.ID == "" {
228 continue
229 }
230 if idx, ok := modelIndex[model.ID]; ok {
231 if model.Name != "" {
232 displayProvider.Models[idx].Name = model.Name
233 }
234 continue
235 }
236 if model.Name == "" {
237 model.Name = model.ID
238 }
239 displayProvider.Models = append(displayProvider.Models, model)
240 modelIndex[model.ID] = len(displayProvider.Models) - 1
241 }
242 }
243
244 name := displayProvider.Name
245 if name == "" {
246 name = string(displayProvider.ID)
247 }
248
249 section := list.NewItemSection(name)
250 if providerConfigured {
251 section.SetInfo(configured)
252 }
253 group := list.Group[list.CompletionItem[ModelOption]]{
254 Section: section,
255 }
256 for _, model := range displayProvider.Models {
257 modelOption := ModelOption{
258 Provider: displayProvider,
259 Model: model,
260 }
261 key := modelKey(string(displayProvider.ID), model.ID)
262 item := list.NewCompletionItem(
263 model.Name,
264 modelOption,
265 list.WithCompletionID(key),
266 )
267 itemsByKey[key] = item
268 group.Items = append(group.Items, item)
269 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
270 selectedItemID = item.ID()
271 }
272 }
273 groups = append(groups, group)
274 }
275
276 if len(recentItems) > 0 {
277 recentSection := list.NewItemSection("Recently used")
278 recentGroup := list.Group[list.CompletionItem[ModelOption]]{
279 Section: recentSection,
280 }
281 var validRecentItems []config.SelectedModel
282 for _, recent := range recentItems {
283 key := modelKey(recent.Provider, recent.Model)
284 option, ok := itemsByKey[key]
285 if !ok {
286 continue
287 }
288 validRecentItems = append(validRecentItems, recent)
289 recentID := fmt.Sprintf("recent::%s", key)
290 modelOption := option.Value()
291 providerName := modelOption.Provider.Name
292 if providerName == "" {
293 providerName = string(modelOption.Provider.ID)
294 }
295 item := list.NewCompletionItem(
296 modelOption.Model.Name,
297 option.Value(),
298 list.WithCompletionID(recentID),
299 list.WithCompletionShortcut(providerName),
300 )
301 recentGroup.Items = append(recentGroup.Items, item)
302 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
303 selectedItemID = recentID
304 }
305 }
306
307 if len(validRecentItems) != len(recentItems) {
308 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
309 return util.ReportError(err)
310 }
311 }
312
313 if len(recentGroup.Items) > 0 {
314 groups = append([]list.Group[list.CompletionItem[ModelOption]]{recentGroup}, groups...)
315 }
316 }
317
318 var cmds []tea.Cmd
319
320 cmd := m.list.SetGroups(groups)
321
322 if cmd != nil {
323 cmds = append(cmds, cmd)
324 }
325 cmd = m.list.SetSelected(selectedItemID)
326 if cmd != nil {
327 cmds = append(cmds, cmd)
328 }
329
330 return tea.Sequence(cmds...)
331}
332
333// GetModelType returns the current model type
334func (m *ModelListComponent) GetModelType() int {
335 return m.modelType
336}
337
338func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
339 m.list.SetInputPlaceholder(placeholder)
340}