1package models
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 tea "github.com/charmbracelet/bubbletea/v2"
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
17type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
18
19type ModelListComponent struct {
20 list listModel
21 modelType int
22 providers []catwalk.Provider
23}
24
25func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
26 t := styles.CurrentTheme()
27 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
28 options := []list.ListOption{
29 list.WithKeyMap(keyMap),
30 list.WithWrapNavigation(),
31 }
32 if shouldResize {
33 options = append(options, list.WithResizeByList())
34 }
35 modelList := list.NewFilterableGroupedList(
36 []list.Group[list.CompletionItem[ModelOption]]{},
37 list.WithFilterInputStyle(inputStyle),
38 list.WithFilterPlaceholder(inputPlaceholder),
39 list.WithFilterListOptions(
40 options...,
41 ),
42 )
43
44 return &ModelListComponent{
45 list: modelList,
46 modelType: LargeModelType,
47 }
48}
49
50func (m *ModelListComponent) Init() tea.Cmd {
51 var cmds []tea.Cmd
52 if len(m.providers) == 0 {
53 cfg := config.Get()
54 providers, err := config.Providers(cfg)
55 filteredProviders := []catwalk.Provider{}
56 for _, p := range providers {
57 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
58 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
59 filteredProviders = append(filteredProviders, p)
60 }
61 }
62
63 m.providers = filteredProviders
64 if err != nil {
65 cmds = append(cmds, util.ReportError(err))
66 }
67 }
68 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
69 return tea.Batch(cmds...)
70}
71
72func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
73 u, cmd := m.list.Update(msg)
74 m.list = u.(listModel)
75 return m, cmd
76}
77
78func (m *ModelListComponent) View() string {
79 return m.list.View()
80}
81
82func (m *ModelListComponent) Cursor() *tea.Cursor {
83 return m.list.Cursor()
84}
85
86func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
87 return m.list.SetSize(width, height)
88}
89
90func (m *ModelListComponent) SelectedModel() *ModelOption {
91 s := m.list.SelectedItem()
92 if s == nil {
93 return nil
94 }
95 sv := *s
96 model := sv.Value()
97 return &model
98}
99
100func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
101 t := styles.CurrentTheme()
102 m.modelType = modelType
103
104 var groups []list.Group[list.CompletionItem[ModelOption]]
105 // first none section
106 selectedItemID := ""
107
108 cfg := config.Get()
109 var currentModel config.SelectedModel
110 if m.modelType == LargeModelType {
111 currentModel = cfg.Models[config.SelectedModelTypeLarge]
112 } else {
113 currentModel = cfg.Models[config.SelectedModelTypeSmall]
114 }
115
116 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
117 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
118
119 // Create a map to track which providers we've already added
120 addedProviders := make(map[string]bool)
121
122 // First, add any configured providers that are not in the known providers list
123 // These should appear at the top of the list
124 knownProviders, err := config.Providers(cfg)
125 if err != nil {
126 return util.ReportError(err)
127 }
128 for providerID, providerConfig := range cfg.Providers.Seq2() {
129 if providerConfig.Disable {
130 continue
131 }
132
133 // Check if this provider is not in the known providers list
134 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
135 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
136 // Convert config provider to provider.Provider format
137 configProvider := catwalk.Provider{
138 Name: providerConfig.Name,
139 ID: catwalk.InferenceProvider(providerID),
140 Models: make([]catwalk.Model, len(providerConfig.Models)),
141 }
142
143 // Convert models
144 for i, model := range providerConfig.Models {
145 configProvider.Models[i] = catwalk.Model{
146 ID: model.ID,
147 Name: model.Name,
148 CostPer1MIn: model.CostPer1MIn,
149 CostPer1MOut: model.CostPer1MOut,
150 CostPer1MInCached: model.CostPer1MInCached,
151 CostPer1MOutCached: model.CostPer1MOutCached,
152 ContextWindow: model.ContextWindow,
153 DefaultMaxTokens: model.DefaultMaxTokens,
154 CanReason: model.CanReason,
155 ReasoningLevels: model.ReasoningLevels,
156 DefaultReasoningEffort: model.DefaultReasoningEffort,
157 SupportsImages: model.SupportsImages,
158 }
159 }
160
161 // Add this unknown provider to the list
162 name := configProvider.Name
163 if name == "" {
164 name = string(configProvider.ID)
165 }
166 section := list.NewItemSection(name)
167 section.SetInfo(configured)
168 group := list.Group[list.CompletionItem[ModelOption]]{
169 Section: section,
170 }
171 for _, model := range configProvider.Models {
172 item := list.NewCompletionItem(model.Name, ModelOption{
173 Provider: configProvider,
174 Model: model,
175 },
176 list.WithCompletionID(
177 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
178 ),
179 )
180
181 group.Items = append(group.Items, item)
182 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
183 selectedItemID = item.ID()
184 }
185 }
186 groups = append(groups, group)
187
188 addedProviders[providerID] = true
189 }
190 }
191
192 // Then add the known providers from the predefined list
193 for _, provider := range m.providers {
194 // Skip if we already added this provider as an unknown provider
195 if addedProviders[string(provider.ID)] {
196 continue
197 }
198
199 providerConfig, providerConfigured := cfg.Providers.Get(string(provider.ID))
200 if providerConfigured && providerConfig.Disable {
201 continue
202 }
203
204 displayProvider := provider
205 if providerConfigured {
206 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
207 modelIndex := make(map[string]int, len(displayProvider.Models))
208 for i, model := range displayProvider.Models {
209 modelIndex[model.ID] = i
210 }
211 for _, model := range providerConfig.Models {
212 if model.ID == "" {
213 continue
214 }
215 if idx, ok := modelIndex[model.ID]; ok {
216 if model.Name != "" {
217 displayProvider.Models[idx].Name = model.Name
218 }
219 continue
220 }
221 if model.Name == "" {
222 model.Name = model.ID
223 }
224 displayProvider.Models = append(displayProvider.Models, model)
225 modelIndex[model.ID] = len(displayProvider.Models) - 1
226 }
227 }
228
229 name := displayProvider.Name
230 if name == "" {
231 name = string(displayProvider.ID)
232 }
233
234 section := list.NewItemSection(name)
235 if providerConfigured {
236 section.SetInfo(configured)
237 }
238 group := list.Group[list.CompletionItem[ModelOption]]{
239 Section: section,
240 }
241 for _, model := range displayProvider.Models {
242 item := list.NewCompletionItem(model.Name, ModelOption{
243 Provider: displayProvider,
244 Model: model,
245 },
246 list.WithCompletionID(
247 fmt.Sprintf("%s:%s", displayProvider.ID, model.ID),
248 ),
249 )
250 group.Items = append(group.Items, item)
251 if model.ID == currentModel.Model && string(displayProvider.ID) == currentModel.Provider {
252 selectedItemID = item.ID()
253 }
254 }
255 groups = append(groups, group)
256 }
257
258 var cmds []tea.Cmd
259
260 cmd := m.list.SetGroups(groups)
261
262 if cmd != nil {
263 cmds = append(cmds, cmd)
264 }
265 cmd = m.list.SetSelected(selectedItemID)
266 if cmd != nil {
267 cmds = append(cmds, cmd)
268 }
269
270 return tea.Sequence(cmds...)
271}
272
273// GetModelType returns the current model type
274func (m *ModelListComponent) GetModelType() int {
275 return m.modelType
276}
277
// SetInputPlaceholder replaces the placeholder text of the list's filter
// input, which NewModelListComponent initially sets from inputPlaceholder.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetInputPlaceholder(placeholder)
}