1package models
2
3import (
4 "fmt"
5 "slices"
6 "strings"
7
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/env"
12 "github.com/charmbracelet/crush/internal/tui/exp/list"
13 "github.com/charmbracelet/crush/internal/tui/styles"
14 "github.com/charmbracelet/crush/internal/tui/util"
15)
16
// listModel is a type alias for the list backing ModelListComponent: a
// filterable, grouped list whose items are completion items carrying a
// ModelOption.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]

// ModelListComponent is a TUI component that lists models grouped by
// provider and lets the user pick one.
type ModelListComponent struct {
	list      listModel          // underlying filterable, grouped list
	modelType int                // which model type is listed (LargeModelType or small); see SetModelType
	providers []catwalk.Provider // catalog providers loaded and filtered by Init
}
24
25func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
26 t := styles.CurrentTheme()
27 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
28 options := []list.ListOption{
29 list.WithKeyMap(keyMap),
30 list.WithWrapNavigation(),
31 }
32 if shouldResize {
33 options = append(options, list.WithResizeByList())
34 }
35 modelList := list.NewFilterableGroupedList(
36 []list.Group[list.CompletionItem[ModelOption]]{},
37 list.WithFilterInputStyle(inputStyle),
38 list.WithFilterPlaceholder(inputPlaceholder),
39 list.WithFilterListOptions(
40 options...,
41 ),
42 )
43
44 return &ModelListComponent{
45 list: modelList,
46 modelType: LargeModelType,
47 }
48}
49
50func (m *ModelListComponent) Init() tea.Cmd {
51 var cmds []tea.Cmd
52 if len(m.providers) == 0 {
53 providers, err := config.Providers()
54 filteredProviders := []catwalk.Provider{}
55 for _, p := range providers {
56 hasApiKeyEnv := strings.HasPrefix(p.APIKey, "$")
57 resolver := config.NewEnvironmentVariableResolver(env.New())
58 endpoint, _ := resolver.ResolveValue(p.APIEndpoint)
59 if endpoint != "" && hasApiKeyEnv {
60 filteredProviders = append(filteredProviders, p)
61 }
62 }
63
64 m.providers = filteredProviders
65 if err != nil {
66 cmds = append(cmds, util.ReportError(err))
67 }
68 }
69 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
70 return tea.Batch(cmds...)
71}
72
73func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
74 u, cmd := m.list.Update(msg)
75 m.list = u.(listModel)
76 return m, cmd
77}
78
79func (m *ModelListComponent) View() string {
80 return m.list.View()
81}
82
83func (m *ModelListComponent) Cursor() *tea.Cursor {
84 return m.list.Cursor()
85}
86
87func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
88 return m.list.SetSize(width, height)
89}
90
91func (m *ModelListComponent) SelectedModel() *ModelOption {
92 s := m.list.SelectedItem()
93 if s == nil {
94 return nil
95 }
96 sv := *s
97 model := sv.Value()
98 return &model
99}
100
101func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
102 t := styles.CurrentTheme()
103 m.modelType = modelType
104
105 var groups []list.Group[list.CompletionItem[ModelOption]]
106 // first none section
107 selectedItemID := ""
108
109 cfg := config.Get()
110 var currentModel config.SelectedModel
111 if m.modelType == LargeModelType {
112 currentModel = cfg.Models[config.SelectedModelTypeLarge]
113 } else {
114 currentModel = cfg.Models[config.SelectedModelTypeSmall]
115 }
116
117 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
118 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
119
120 // Create a map to track which providers we've already added
121 addedProviders := make(map[string]bool)
122
123 // First, add any configured providers that are not in the known providers list
124 // These should appear at the top of the list
125 knownProviders, err := config.Providers()
126 if err != nil {
127 return util.ReportError(err)
128 }
129 for providerID, providerConfig := range cfg.Providers.Seq2() {
130 if providerConfig.Disable {
131 continue
132 }
133
134 // Check if this provider is not in the known providers list
135 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
136 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
137 // Convert config provider to provider.Provider format
138 configProvider := catwalk.Provider{
139 Name: providerConfig.Name,
140 ID: catwalk.InferenceProvider(providerID),
141 Models: make([]catwalk.Model, len(providerConfig.Models)),
142 }
143
144 // Convert models
145 for i, model := range providerConfig.Models {
146 configProvider.Models[i] = catwalk.Model{
147 ID: model.ID,
148 Name: model.Name,
149 CostPer1MIn: model.CostPer1MIn,
150 CostPer1MOut: model.CostPer1MOut,
151 CostPer1MInCached: model.CostPer1MInCached,
152 CostPer1MOutCached: model.CostPer1MOutCached,
153 ContextWindow: model.ContextWindow,
154 DefaultMaxTokens: model.DefaultMaxTokens,
155 CanReason: model.CanReason,
156 HasReasoningEffort: model.HasReasoningEffort,
157 DefaultReasoningEffort: model.DefaultReasoningEffort,
158 SupportsImages: model.SupportsImages,
159 }
160 }
161
162 // Add this unknown provider to the list
163 name := configProvider.Name
164 if name == "" {
165 name = string(configProvider.ID)
166 }
167 section := list.NewItemSection(name)
168 section.SetInfo(configured)
169 group := list.Group[list.CompletionItem[ModelOption]]{
170 Section: section,
171 }
172 for _, model := range configProvider.Models {
173 item := list.NewCompletionItem(model.Name, ModelOption{
174 Provider: configProvider,
175 Model: model,
176 },
177 list.WithCompletionID(
178 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
179 ),
180 )
181
182 group.Items = append(group.Items, item)
183 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
184 selectedItemID = item.ID()
185 }
186 }
187 groups = append(groups, group)
188
189 addedProviders[providerID] = true
190 }
191 }
192
193 // Then add the known providers from the predefined list
194 for _, provider := range m.providers {
195 // Skip if we already added this provider as an unknown provider
196 if addedProviders[string(provider.ID)] {
197 continue
198 }
199
200 // Check if this provider is configured and not disabled
201 if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
202 continue
203 }
204
205 name := provider.Name
206 if name == "" {
207 name = string(provider.ID)
208 }
209
210 section := list.NewItemSection(name)
211 if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
212 section.SetInfo(configured)
213 }
214 group := list.Group[list.CompletionItem[ModelOption]]{
215 Section: section,
216 }
217 for _, model := range provider.Models {
218 item := list.NewCompletionItem(model.Name, ModelOption{
219 Provider: provider,
220 Model: model,
221 },
222 list.WithCompletionID(
223 fmt.Sprintf("%s:%s", provider.ID, model.ID),
224 ),
225 )
226 group.Items = append(group.Items, item)
227 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
228 selectedItemID = item.ID()
229 }
230 }
231 groups = append(groups, group)
232 }
233
234 var cmds []tea.Cmd
235
236 cmd := m.list.SetGroups(groups)
237
238 if cmd != nil {
239 cmds = append(cmds, cmd)
240 }
241 cmd = m.list.SetSelected(selectedItemID)
242 if cmd != nil {
243 cmds = append(cmds, cmd)
244 }
245
246 return tea.Sequence(cmds...)
247}
248
249// GetModelType returns the current model type
250func (m *ModelListComponent) GetModelType() int {
251 return m.modelType
252}
253
254func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
255 m.list.SetInputPlaceholder(placeholder)
256}