1package models
2
3import (
4 "fmt"
5 "slices"
6 "strings"
7
8 tea "github.com/charmbracelet/bubbletea/v2"
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/tui/exp/list"
12 "github.com/charmbracelet/crush/internal/tui/styles"
13 "github.com/charmbracelet/crush/internal/tui/util"
14)
15
// listModel is the concrete list type backing ModelListComponent: a
// filterable, grouped list whose entries are completion items that each carry
// a ModelOption as their value.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
17
// ModelListComponent is the model picker: it wraps a filterable grouped list
// of models, one group per provider, and remembers which model type the list
// is currently showing.
type ModelListComponent struct {
	// list is the underlying list; Update, View, Cursor and SetSize delegate
	// to it.
	list listModel
	// modelType is the model type the list was last built for. It starts as
	// LargeModelType and is replaced by SetModelType.
	modelType int
	// providers is populated by Init with the providers returned by
	// config.Providers whose APIKey starts with "$", excluding Azure.
	providers []catwalk.Provider
	// cfg is the configuration used to look up providers and the currently
	// selected models.
	cfg *config.Config
}
24
25func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
26 t := styles.CurrentTheme()
27 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
28 options := []list.ListOption{
29 list.WithKeyMap(keyMap),
30 list.WithWrapNavigation(),
31 }
32 if shouldResize {
33 options = append(options, list.WithResizeByList())
34 }
35 modelList := list.NewFilterableGroupedList(
36 []list.Group[list.CompletionItem[ModelOption]]{},
37 list.WithFilterInputStyle(inputStyle),
38 list.WithFilterPlaceholder(inputPlaceholder),
39 list.WithFilterListOptions(
40 options...,
41 ),
42 )
43
44 return &ModelListComponent{
45 list: modelList,
46 modelType: LargeModelType,
47 cfg: cfg,
48 }
49}
50
51func (m *ModelListComponent) Init() tea.Cmd {
52 var cmds []tea.Cmd
53 if len(m.providers) == 0 {
54 providers, err := config.Providers(m.cfg)
55 filteredProviders := []catwalk.Provider{}
56 for _, p := range providers {
57 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
58 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
59 filteredProviders = append(filteredProviders, p)
60 }
61 }
62
63 m.providers = filteredProviders
64 if err != nil {
65 cmds = append(cmds, util.ReportError(err))
66 }
67 }
68 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
69 return tea.Batch(cmds...)
70}
71
72func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
73 u, cmd := m.list.Update(msg)
74 m.list = u.(listModel)
75 return m, cmd
76}
77
78func (m *ModelListComponent) View() string {
79 return m.list.View()
80}
81
82func (m *ModelListComponent) Cursor() *tea.Cursor {
83 return m.list.Cursor()
84}
85
86func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
87 return m.list.SetSize(width, height)
88}
89
90func (m *ModelListComponent) SelectedModel() *ModelOption {
91 s := m.list.SelectedItem()
92 if s == nil {
93 return nil
94 }
95 sv := *s
96 model := sv.Value()
97 return &model
98}
99
100func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
101 t := styles.CurrentTheme()
102 m.modelType = modelType
103
104 var groups []list.Group[list.CompletionItem[ModelOption]]
105 // first none section
106 selectedItemID := ""
107
108 var currentModel config.SelectedModel
109 if m.modelType == LargeModelType {
110 currentModel = m.cfg.Models[config.SelectedModelTypeLarge]
111 } else {
112 currentModel = m.cfg.Models[config.SelectedModelTypeSmall]
113 }
114
115 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
116 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
117
118 // Create a map to track which providers we've already added
119 addedProviders := make(map[string]bool)
120
121 // First, add any configured providers that are not in the known providers list
122 // These should appear at the top of the list
123 knownProviders, err := config.Providers(m.cfg)
124 if err != nil {
125 return util.ReportError(err)
126 }
127 for providerID, providerConfig := range m.cfg.Providers.Seq2() {
128 if providerConfig.Disable {
129 continue
130 }
131
132 // Check if this provider is not in the known providers list
133 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
134 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
135 // Convert config provider to provider.Provider format
136 configProvider := catwalk.Provider{
137 Name: providerConfig.Name,
138 ID: catwalk.InferenceProvider(providerID),
139 Models: make([]catwalk.Model, len(providerConfig.Models)),
140 }
141
142 // Convert models
143 for i, model := range providerConfig.Models {
144 configProvider.Models[i] = catwalk.Model{
145 ID: model.ID,
146 Name: model.Name,
147 CostPer1MIn: model.CostPer1MIn,
148 CostPer1MOut: model.CostPer1MOut,
149 CostPer1MInCached: model.CostPer1MInCached,
150 CostPer1MOutCached: model.CostPer1MOutCached,
151 ContextWindow: model.ContextWindow,
152 DefaultMaxTokens: model.DefaultMaxTokens,
153 CanReason: model.CanReason,
154 HasReasoningEffort: model.HasReasoningEffort,
155 DefaultReasoningEffort: model.DefaultReasoningEffort,
156 SupportsImages: model.SupportsImages,
157 }
158 }
159
160 // Add this unknown provider to the list
161 name := configProvider.Name
162 if name == "" {
163 name = string(configProvider.ID)
164 }
165 section := list.NewItemSection(name)
166 section.SetInfo(configured)
167 group := list.Group[list.CompletionItem[ModelOption]]{
168 Section: section,
169 }
170 for _, model := range configProvider.Models {
171 item := list.NewCompletionItem(model.Name, ModelOption{
172 Provider: configProvider,
173 Model: model,
174 },
175 list.WithCompletionID(
176 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
177 ),
178 )
179
180 group.Items = append(group.Items, item)
181 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
182 selectedItemID = item.ID()
183 }
184 }
185 groups = append(groups, group)
186
187 addedProviders[providerID] = true
188 }
189 }
190
191 // Then add the known providers from the predefined list
192 for _, provider := range m.providers {
193 // Skip if we already added this provider as an unknown provider
194 if addedProviders[string(provider.ID)] {
195 continue
196 }
197
198 // Check if this provider is configured and not disabled
199 if providerConfig, exists := m.cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
200 continue
201 }
202
203 name := provider.Name
204 if name == "" {
205 name = string(provider.ID)
206 }
207
208 section := list.NewItemSection(name)
209 if _, ok := m.cfg.Providers.Get(string(provider.ID)); ok {
210 section.SetInfo(configured)
211 }
212 group := list.Group[list.CompletionItem[ModelOption]]{
213 Section: section,
214 }
215 for _, model := range provider.Models {
216 item := list.NewCompletionItem(model.Name, ModelOption{
217 Provider: provider,
218 Model: model,
219 },
220 list.WithCompletionID(
221 fmt.Sprintf("%s:%s", provider.ID, model.ID),
222 ),
223 )
224 group.Items = append(group.Items, item)
225 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
226 selectedItemID = item.ID()
227 }
228 }
229 groups = append(groups, group)
230 }
231
232 var cmds []tea.Cmd
233
234 cmd := m.list.SetGroups(groups)
235
236 if cmd != nil {
237 cmds = append(cmds, cmd)
238 }
239 cmd = m.list.SetSelected(selectedItemID)
240 if cmd != nil {
241 cmds = append(cmds, cmd)
242 }
243
244 return tea.Sequence(cmds...)
245}
246
247// GetModelType returns the current model type
248func (m *ModelListComponent) GetModelType() int {
249 return m.modelType
250}
251
252func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
253 m.list.SetInputPlaceholder(placeholder)
254}