1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/tui/exp/list"
11 "github.com/charmbracelet/crush/internal/tui/styles"
12 "github.com/charmbracelet/crush/internal/tui/util"
13)
14
// listModel is shorthand for the filterable, grouped list of completion items
// whose values are ModelOption entries.
type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
16
// ModelListComponent wraps a filterable grouped list of models, organised
// into one group per provider.
type ModelListComponent struct {
	// list is the underlying list that handles filtering, navigation and rendering.
	list listModel
	// modelType is the model type currently shown; the constructor sets it to
	// LargeModelType.
	modelType int
	// providers is the provider catalogue used to build the groups. It is set
	// through SetProviders or loaded from config.Providers in Init when empty.
	providers []catwalk.Provider
}
22
23func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
24 t := styles.CurrentTheme()
25 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
26 options := []list.ListOption{
27 list.WithKeyMap(keyMap),
28 list.WithWrapNavigation(),
29 }
30 if shouldResize {
31 options = append(options, list.WithResizeByList())
32 }
33 modelList := list.NewFilterableGroupedList(
34 []list.Group[list.CompletionItem[ModelOption]]{},
35 list.WithFilterInputStyle(inputStyle),
36 list.WithFilterPlaceholder(inputPlaceholder),
37 list.WithFilterListOptions(
38 options...,
39 ),
40 )
41
42 return &ModelListComponent{
43 list: modelList,
44 modelType: LargeModelType,
45 }
46}
47
48func (m *ModelListComponent) Init() tea.Cmd {
49 var cmds []tea.Cmd
50 if len(m.providers) == 0 {
51 providers, err := config.Providers()
52 m.providers = providers
53 if err != nil {
54 cmds = append(cmds, util.ReportError(err))
55 }
56 }
57 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
58 return tea.Batch(cmds...)
59}
60
61func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
62 u, cmd := m.list.Update(msg)
63 m.list = u.(listModel)
64 return m, cmd
65}
66
// View returns the rendered output of the underlying list.
func (m *ModelListComponent) View() string {
	return m.list.View()
}
70
// Cursor returns the cursor reported by the underlying list.
func (m *ModelListComponent) Cursor() *tea.Cursor {
	return m.list.Cursor()
}
74
// SetSize forwards the given width and height to the underlying list and
// returns the command the list produces.
func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
	return m.list.SetSize(width, height)
}
78
79func (m *ModelListComponent) SelectedModel() *ModelOption {
80 s := m.list.SelectedItem()
81 if s == nil {
82 return nil
83 }
84 sv := *s
85 model := sv.Value()
86 return &model
87}
88
89func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
90 t := styles.CurrentTheme()
91 m.modelType = modelType
92
93 var groups []list.Group[list.CompletionItem[ModelOption]]
94 // first none section
95 selectedItemID := ""
96
97 cfg := config.Get()
98 var currentModel config.SelectedModel
99 if m.modelType == LargeModelType {
100 currentModel = cfg.Models[config.SelectedModelTypeLarge]
101 } else {
102 currentModel = cfg.Models[config.SelectedModelTypeSmall]
103 }
104
105 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
106 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
107
108 // Create a map to track which providers we've already added
109 addedProviders := make(map[string]bool)
110
111 // First, add any configured providers that are not in the known providers list
112 // These should appear at the top of the list
113 knownProviders, err := config.Providers()
114 if err != nil {
115 return util.ReportError(err)
116 }
117 for providerID, providerConfig := range cfg.Providers.Seq2() {
118 if providerConfig.Disable {
119 continue
120 }
121
122 // Check if this provider is not in the known providers list
123 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) ||
124 !slices.ContainsFunc(m.providers, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
125 // Convert config provider to provider.Provider format
126 configProvider := catwalk.Provider{
127 Name: providerConfig.Name,
128 ID: catwalk.InferenceProvider(providerID),
129 Models: make([]catwalk.Model, len(providerConfig.Models)),
130 }
131
132 // Convert models
133 for i, model := range providerConfig.Models {
134 configProvider.Models[i] = catwalk.Model{
135 ID: model.ID,
136 Name: model.Name,
137 CostPer1MIn: model.CostPer1MIn,
138 CostPer1MOut: model.CostPer1MOut,
139 CostPer1MInCached: model.CostPer1MInCached,
140 CostPer1MOutCached: model.CostPer1MOutCached,
141 ContextWindow: model.ContextWindow,
142 DefaultMaxTokens: model.DefaultMaxTokens,
143 CanReason: model.CanReason,
144 HasReasoningEffort: model.HasReasoningEffort,
145 DefaultReasoningEffort: model.DefaultReasoningEffort,
146 SupportsImages: model.SupportsImages,
147 }
148 }
149
150 // Add this unknown provider to the list
151 name := configProvider.Name
152 if name == "" {
153 name = string(configProvider.ID)
154 }
155 section := list.NewItemSection(name)
156 section.SetInfo(configured)
157 group := list.Group[list.CompletionItem[ModelOption]]{
158 Section: section,
159 }
160 for _, model := range configProvider.Models {
161 item := list.NewCompletionItem(model.Name, ModelOption{
162 Provider: configProvider,
163 Model: model,
164 },
165 list.WithCompletionID(
166 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
167 ),
168 )
169
170 group.Items = append(group.Items, item)
171 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
172 selectedItemID = item.ID()
173 }
174 }
175 groups = append(groups, group)
176
177 addedProviders[providerID] = true
178 }
179 }
180
181 // Then add the known providers from the predefined list
182 for _, provider := range m.providers {
183 // Skip if we already added this provider as an unknown provider
184 if addedProviders[string(provider.ID)] {
185 continue
186 }
187
188 // Check if this provider is configured and not disabled
189 if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
190 continue
191 }
192
193 name := provider.Name
194 if name == "" {
195 name = string(provider.ID)
196 }
197
198 section := list.NewItemSection(name)
199 if _, ok := cfg.Providers.Get(string(provider.ID)); ok {
200 section.SetInfo(configured)
201 }
202 group := list.Group[list.CompletionItem[ModelOption]]{
203 Section: section,
204 }
205 for _, model := range provider.Models {
206 item := list.NewCompletionItem(model.Name, ModelOption{
207 Provider: provider,
208 Model: model,
209 },
210 list.WithCompletionID(
211 fmt.Sprintf("%s:%s", provider.ID, model.ID),
212 ),
213 )
214 group.Items = append(group.Items, item)
215 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
216 selectedItemID = item.ID()
217 }
218 }
219 groups = append(groups, group)
220 }
221
222 var cmds []tea.Cmd
223
224 cmd := m.list.SetGroups(groups)
225
226 if cmd != nil {
227 cmds = append(cmds, cmd)
228 }
229 cmd = m.list.SetSelected(selectedItemID)
230 if cmd != nil {
231 cmds = append(cmds, cmd)
232 }
233
234 return tea.Sequence(cmds...)
235}
236
// GetModelType returns the model type most recently stored by the constructor
// or SetModelType.
func (m *ModelListComponent) GetModelType() int {
	return m.modelType
}
241
// SetInputPlaceholder forwards placeholder to the underlying list's input.
func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
	m.list.SetInputPlaceholder(placeholder)
}
245
// SetProviders replaces the provider catalogue used to build the list groups.
// The slice is stored as given (not copied), and the list itself is not
// rebuilt here.
func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
	m.providers = providers
}