1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/agent"
11 "github.com/charmbracelet/crush/internal/tui/exp/list"
12 "github.com/charmbracelet/crush/internal/tui/styles"
13 "github.com/charmbracelet/crush/internal/tui/util"
14)
15
// listModel is the concrete list type backing ModelListComponent: a filterable,
// grouped list whose items are completion items carrying a ModelOption.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]

// ModelListComponent shows the available models grouped by provider and lets
// the user filter and select one for a given model slot.
type ModelListComponent struct {
	// list is the underlying filterable, grouped list widget.
	list listModel
	// modelType selects the slot being edited: LargeModelType picks the large
	// model slot, any other value picks the small model slot (see SetModelType).
	modelType int
	// providers are the known providers listed after custom ones; Init loads
	// them via config.Providers() when empty, and SetProviders replaces them.
	providers []catwalk.Provider
	// config supplies the configured providers and the currently selected models.
	config *config.Config
}
24
25func NewModelListComponent(cfg *config.Config, keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
26 t := styles.CurrentTheme()
27 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
28 options := []list.ListOption{
29 list.WithKeyMap(keyMap),
30 list.WithWrapNavigation(),
31 }
32 if shouldResize {
33 options = append(options, list.WithResizeByList())
34 }
35 modelList := list.NewFilterableGroupedList(
36 []list.Group[list.CompletionItem[ModelOption]]{},
37 list.WithFilterInputStyle(inputStyle),
38 list.WithFilterPlaceholder(inputPlaceholder),
39 list.WithFilterListOptions(
40 options...,
41 ),
42 )
43
44 return &ModelListComponent{
45 list: modelList,
46 modelType: LargeModelType,
47 config: cfg,
48 }
49}
50
51func (m *ModelListComponent) Init() tea.Cmd {
52 var cmds []tea.Cmd
53 if len(m.providers) == 0 {
54 providers, err := config.Providers()
55 m.providers = providers
56 if err != nil {
57 cmds = append(cmds, util.ReportError(err))
58 }
59 }
60 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
61 return tea.Batch(cmds...)
62}
63
64func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
65 u, cmd := m.list.Update(msg)
66 m.list = u.(listModel)
67 return m, cmd
68}
69
70func (m *ModelListComponent) View() string {
71 return m.list.View()
72}
73
74func (m *ModelListComponent) Cursor() *tea.Cursor {
75 return m.list.Cursor()
76}
77
78func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
79 return m.list.SetSize(width, height)
80}
81
82func (m *ModelListComponent) SelectedModel() *ModelOption {
83 s := m.list.SelectedItem()
84 if s == nil {
85 return nil
86 }
87 sv := *s
88 model := sv.Value()
89 return &model
90}
91
92func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
93 t := styles.CurrentTheme()
94 m.modelType = modelType
95
96 var groups []list.Group[list.CompletionItem[ModelOption]]
97 // first none section
98 selectedItemID := ""
99
100 var currentModel agent.Model
101 if m.modelType == LargeModelType {
102 currentModel = m.config.Models[config.SelectedModelTypeLarge]
103 } else {
104 currentModel = m.config.Models[config.SelectedModelTypeSmall]
105 }
106
107 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
108 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
109
110 // Create a map to track which providers we've already added
111 addedProviders := make(map[string]bool)
112
113 // First, add any configured providers that are not in the known providers list
114 // These should appear at the top of the list
115 knownProviders, err := config.Providers()
116 if err != nil {
117 return util.ReportError(err)
118 }
119 for providerID, providerConfig := range m.config.Providers.Seq2() {
120 if providerConfig.Disable {
121 continue
122 }
123
124 // Check if this provider is not in the known providers list
125 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
126 // Convert config provider to provider.Provider format
127 configProvider := catwalk.Provider{
128 Name: providerConfig.Name,
129 ID: catwalk.InferenceProvider(providerID),
130 Models: make([]catwalk.Model, len(providerConfig.Models)),
131 }
132
133 // Convert models
134 for i, model := range providerConfig.Models {
135 configProvider.Models[i] = catwalk.Model{
136 ID: model.ID,
137 Name: model.Name,
138 CostPer1MIn: model.CostPer1MIn,
139 CostPer1MOut: model.CostPer1MOut,
140 CostPer1MInCached: model.CostPer1MInCached,
141 CostPer1MOutCached: model.CostPer1MOutCached,
142 ContextWindow: model.ContextWindow,
143 DefaultMaxTokens: model.DefaultMaxTokens,
144 CanReason: model.CanReason,
145 HasReasoningEffort: model.HasReasoningEffort,
146 DefaultReasoningEffort: model.DefaultReasoningEffort,
147 SupportsImages: model.SupportsImages,
148 }
149 }
150
151 // Add this unknown provider to the list
152 name := configProvider.Name
153 if name == "" {
154 name = string(configProvider.ID)
155 }
156 section := list.NewItemSection(name)
157 section.SetInfo(configured)
158 group := list.Group[list.CompletionItem[ModelOption]]{
159 Section: section,
160 }
161 for _, model := range configProvider.Models {
162 item := list.NewCompletionItem(model.Name, ModelOption{
163 Provider: configProvider,
164 Model: model,
165 },
166 list.WithCompletionID(
167 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
168 ),
169 )
170
171 group.Items = append(group.Items, item)
172 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
173 selectedItemID = item.ID()
174 }
175 }
176 groups = append(groups, group)
177
178 addedProviders[providerID] = true
179 }
180 }
181
182 // Then add the known providers from the predefined list
183 for _, provider := range m.providers {
184 // Skip if we already added this provider as an unknown provider
185 if addedProviders[string(provider.ID)] {
186 continue
187 }
188
189 // Check if this provider is configured and not disabled
190 if providerConfig, exists := m.config.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
191 continue
192 }
193
194 name := provider.Name
195 if name == "" {
196 name = string(provider.ID)
197 }
198
199 section := list.NewItemSection(name)
200 if _, ok := m.config.Providers.Get(string(provider.ID)); ok {
201 section.SetInfo(configured)
202 }
203 group := list.Group[list.CompletionItem[ModelOption]]{
204 Section: section,
205 }
206 for _, model := range provider.Models {
207 item := list.NewCompletionItem(model.Name, ModelOption{
208 Provider: provider,
209 Model: model,
210 },
211 list.WithCompletionID(
212 fmt.Sprintf("%s:%s", provider.ID, model.ID),
213 ),
214 )
215 group.Items = append(group.Items, item)
216 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
217 selectedItemID = item.ID()
218 }
219 }
220 groups = append(groups, group)
221 }
222
223 var cmds []tea.Cmd
224
225 cmd := m.list.SetGroups(groups)
226
227 if cmd != nil {
228 cmds = append(cmds, cmd)
229 }
230 cmd = m.list.SetSelected(selectedItemID)
231 if cmd != nil {
232 cmds = append(cmds, cmd)
233 }
234
235 return tea.Sequence(cmds...)
236}
237
238// GetModelType returns the current model type
239func (m *ModelListComponent) GetModelType() int {
240 return m.modelType
241}
242
243func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
244 m.list.SetInputPlaceholder(placeholder)
245}
246
247func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
248 m.providers = providers
249}