1package models
2
3import (
4 "fmt"
5 "slices"
6 "strings"
7
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/tui/exp/list"
12 "github.com/charmbracelet/crush/internal/tui/styles"
13 "github.com/charmbracelet/crush/internal/tui/util"
14)
15
16type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
17
18type ModelListComponent struct {
19 list listModel
20 modelType int
21 providers []catwalk.Provider
22}
23
24func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
25 t := styles.CurrentTheme()
26 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
27 options := []list.ListOption{
28 list.WithKeyMap(keyMap),
29 list.WithWrapNavigation(),
30 }
31 if shouldResize {
32 options = append(options, list.WithResizeByList())
33 }
34 modelList := list.NewFilterableGroupedList(
35 []list.Group[list.CompletionItem[ModelOption]]{},
36 list.WithFilterInputStyle(inputStyle),
37 list.WithFilterPlaceholder(inputPlaceholder),
38 list.WithFilterListOptions(
39 options...,
40 ),
41 )
42
43 return &ModelListComponent{
44 list: modelList,
45 modelType: LargeModelType,
46 }
47}
48
49func (m *ModelListComponent) Init() tea.Cmd {
50 var cmds []tea.Cmd
51 if len(m.providers) == 0 {
52 providers, err := config.Providers()
53 filteredProviders := []catwalk.Provider{}
54 for _, p := range providers {
55 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
56 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
57 filteredProviders = append(filteredProviders, p)
58 }
59 }
60
61 m.providers = filteredProviders
62 if err != nil {
63 cmds = append(cmds, util.ReportError(err))
64 }
65 }
66 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
67 return tea.Batch(cmds...)
68}
69
70func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
71 u, cmd := m.list.Update(msg)
72 m.list = u.(listModel)
73 return m, cmd
74}
75
76func (m *ModelListComponent) View() string {
77 return m.list.View()
78}
79
80func (m *ModelListComponent) Cursor() *tea.Cursor {
81 return m.list.Cursor()
82}
83
84func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
85 return m.list.SetSize(width, height)
86}
87
88func (m *ModelListComponent) SelectedModel() *ModelOption {
89 s := m.list.SelectedItem()
90 if s == nil {
91 return nil
92 }
93 sv := *s
94 model := sv.Value()
95 return &model
96}
97
98func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
99 t := styles.CurrentTheme()
100 m.modelType = modelType
101
102 var groups []list.Group[list.CompletionItem[ModelOption]]
103 // first none section
104 selectedItemID := ""
105
106 cfg := config.Get()
107 var currentModel config.SelectedModel
108 if m.modelType == LargeModelType {
109 currentModel = cfg.Models[config.SelectedModelTypeLarge]
110 } else {
111 currentModel = cfg.Models[config.SelectedModelTypeSmall]
112 }
113
114 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
115 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
116
117 // Create a map to track which providers we've already added
118 addedProviders := make(map[string]bool)
119
120 // First, add any configured providers that are not in the known providers list
121 // These should appear at the top of the list
122 knownProviders, err := config.Providers()
123 if err != nil {
124 return util.ReportError(err)
125 }
126 for providerID, providerConfig := range cfg.Providers.Seq2() {
127 if providerConfig.Disable {
128 continue
129 }
130
131 // Check if this provider is not in the known providers list
132 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
133 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
134 // Convert config provider to provider.Provider format
135 configProvider := catwalk.Provider{
136 Name: providerConfig.Name,
137 ID: catwalk.InferenceProvider(providerID),
138 Models: make([]catwalk.Model, len(providerConfig.Models)),
139 }
140
141 // Convert models
142 for i, model := range providerConfig.Models {
143 configProvider.Models[i] = catwalk.Model{
144 ID: model.ID,
145 Name: model.Name,
146 CostPer1MIn: model.CostPer1MIn,
147 CostPer1MOut: model.CostPer1MOut,
148 CostPer1MInCached: model.CostPer1MInCached,
149 CostPer1MOutCached: model.CostPer1MOutCached,
150 ContextWindow: model.ContextWindow,
151 DefaultMaxTokens: model.DefaultMaxTokens,
152 CanReason: model.CanReason,
153 HasReasoningEffort: model.HasReasoningEffort,
154 DefaultReasoningEffort: model.DefaultReasoningEffort,
155 SupportsImages: model.SupportsImages,
156 }
157 }
158
159 // Add this unknown provider to the list
160 name := configProvider.Name
161 if name == "" {
162 name = string(configProvider.ID)
163 }
164 section := list.NewItemSection(name)
165 section.SetInfo(configured)
166 group := list.Group[list.CompletionItem[ModelOption]]{
167 Section: section,
168 }
169 for _, model := range configProvider.Models {
170 item := list.NewCompletionItem(model.Name, ModelOption{
171 Provider: configProvider,
172 Model: model,
173 },
174 list.WithCompletionID(
175 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
176 ),
177 )
178
179 group.Items = append(group.Items, item)
180 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
181 selectedItemID = item.ID()
182 }
183 }
184 groups = append(groups, group)
185
186 addedProviders[providerID] = true
187 }
188 }
189
190 // Then add the known providers from the predefined list
191 for _, provider := range m.providers {
192 // Skip if we already added this provider as an unknown provider
193 if addedProviders[string(provider.ID)] {
194 continue
195 }
196
197 // Check if this provider is configured and not disabled
198 if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
199 continue
200 }
201
202 name := provider.Name
203 if name == "" {
204 name = string(provider.ID)
205 }
206
207 section := list.NewItemSection(name)
208 if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
209 section.SetInfo(configured)
210 }
211 group := list.Group[list.CompletionItem[ModelOption]]{
212 Section: section,
213 }
214 for _, model := range provider.Models {
215 item := list.NewCompletionItem(model.Name, ModelOption{
216 Provider: provider,
217 Model: model,
218 },
219 list.WithCompletionID(
220 fmt.Sprintf("%s:%s", provider.ID, model.ID),
221 ),
222 )
223 group.Items = append(group.Items, item)
224 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
225 selectedItemID = item.ID()
226 }
227 }
228 groups = append(groups, group)
229 }
230
231 var cmds []tea.Cmd
232
233 cmd := m.list.SetGroups(groups)
234
235 if cmd != nil {
236 cmds = append(cmds, cmd)
237 }
238 cmd = m.list.SetSelected(selectedItemID)
239 if cmd != nil {
240 cmds = append(cmds, cmd)
241 }
242
243 return tea.Sequence(cmds...)
244}
245
246// GetModelType returns the current model type
247func (m *ModelListComponent) GetModelType() int {
248 return m.modelType
249}
250
251func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
252 m.list.SetInputPlaceholder(placeholder)
253}