1package models
2
3import (
4 "fmt"
5 "slices"
6
7 tea "github.com/charmbracelet/bubbletea/v2"
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/tui/exp/list"
11 "github.com/charmbracelet/crush/internal/tui/styles"
12 "github.com/charmbracelet/crush/internal/tui/util"
13)
14
15type listModel = list.FilterableGroupList[list.CompletionItem[ModelOption]]
16
17type ModelListComponent struct {
18 list listModel
19 modelType int
20 providers []catwalk.Provider
21}
22
23func NewModelListComponent(keyMap list.KeyMap, inputPlaceholder string, shouldResize bool) *ModelListComponent {
24 t := styles.CurrentTheme()
25 inputStyle := t.S().Base.PaddingLeft(1).PaddingBottom(1)
26 options := []list.ListOption{
27 list.WithKeyMap(keyMap),
28 list.WithWrapNavigation(),
29 }
30 if shouldResize {
31 options = append(options, list.WithResizeByList())
32 }
33 modelList := list.NewFilterableGroupedList(
34 []list.Group[list.CompletionItem[ModelOption]]{},
35 list.WithFilterInputStyle(inputStyle),
36 list.WithFilterPlaceholder(inputPlaceholder),
37 list.WithFilterListOptions(
38 options...,
39 ),
40 )
41
42 return &ModelListComponent{
43 list: modelList,
44 modelType: LargeModelType,
45 }
46}
47
48func (m *ModelListComponent) Init() tea.Cmd {
49 var cmds []tea.Cmd
50 if len(m.providers) == 0 {
51 providers, err := config.Providers()
52 m.providers = providers
53 if err != nil {
54 cmds = append(cmds, util.ReportError(err))
55 }
56 }
57 cmds = append(cmds, m.list.Init(), m.SetModelType(m.modelType))
58 return tea.Batch(cmds...)
59}
60
61func (m *ModelListComponent) Update(msg tea.Msg) (*ModelListComponent, tea.Cmd) {
62 u, cmd := m.list.Update(msg)
63 m.list = u.(listModel)
64 return m, cmd
65}
66
67func (m *ModelListComponent) View() string {
68 return m.list.View()
69}
70
71func (m *ModelListComponent) Cursor() *tea.Cursor {
72 return m.list.Cursor()
73}
74
75func (m *ModelListComponent) SetSize(width, height int) tea.Cmd {
76 return m.list.SetSize(width, height)
77}
78
79func (m *ModelListComponent) SelectedModel() *ModelOption {
80 s := m.list.SelectedItem()
81 if s == nil {
82 return nil
83 }
84 sv := *s
85 model := sv.Value()
86 return &model
87}
88
89func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
90 t := styles.CurrentTheme()
91 m.modelType = modelType
92
93 var groups []list.Group[list.CompletionItem[ModelOption]]
94 // first none section
95 selectedItemID := ""
96
97 cfg := config.Get()
98 var currentModel config.SelectedModel
99 if m.modelType == LargeModelType {
100 currentModel = cfg.Models[config.SelectedModelTypeLarge]
101 } else {
102 currentModel = cfg.Models[config.SelectedModelTypeSmall]
103 }
104
105 configuredIcon := t.S().Base.Foreground(t.Success).Render(styles.CheckIcon)
106 configured := fmt.Sprintf("%s %s", configuredIcon, t.S().Subtle.Render("Configured"))
107
108 // Create a map to track which providers we've already added
109 addedProviders := make(map[string]bool)
110
111 // First, add any configured providers that are not in the known providers list
112 // These should appear at the top of the list
113 knownProviders, err := config.Providers()
114 if err != nil {
115 return util.ReportError(err)
116 }
117 for providerID, providerConfig := range cfg.Providers.Seq2() {
118 if providerConfig.Disable {
119 continue
120 }
121
122 // Check if this provider is not in the known providers list
123 if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
124 // Convert config provider to provider.Provider format
125 configProvider := catwalk.Provider{
126 Name: providerConfig.Name,
127 ID: catwalk.InferenceProvider(providerID),
128 Models: make([]catwalk.Model, len(providerConfig.Models)),
129 }
130
131 // Convert models
132 for i, model := range providerConfig.Models {
133 configProvider.Models[i] = catwalk.Model{
134 ID: model.ID,
135 Name: model.Name,
136 CostPer1MIn: model.CostPer1MIn,
137 CostPer1MOut: model.CostPer1MOut,
138 CostPer1MInCached: model.CostPer1MInCached,
139 CostPer1MOutCached: model.CostPer1MOutCached,
140 ContextWindow: model.ContextWindow,
141 DefaultMaxTokens: model.DefaultMaxTokens,
142 CanReason: model.CanReason,
143 HasReasoningEffort: model.HasReasoningEffort,
144 DefaultReasoningEffort: model.DefaultReasoningEffort,
145 SupportsImages: model.SupportsImages,
146 }
147 }
148
149 // Add this unknown provider to the list
150 name := configProvider.Name
151 if name == "" {
152 name = string(configProvider.ID)
153 }
154 section := list.NewItemSection(name)
155 section.SetInfo(configured)
156 group := list.Group[list.CompletionItem[ModelOption]]{
157 Section: section,
158 }
159 for _, model := range configProvider.Models {
160 item := list.NewCompletionItem(model.Model, ModelOption{
161 Provider: configProvider,
162 Model: model,
163 },
164 list.WithCompletionID(
165 fmt.Sprintf("%s:%s", providerConfig.ID, model.ID),
166 ),
167 )
168
169 group.Items = append(group.Items, item)
170 if model.ID == currentModel.Model && string(configProvider.ID) == currentModel.Provider {
171 selectedItemID = item.ID()
172 }
173 }
174 groups = append(groups, group)
175
176 addedProviders[providerID] = true
177 }
178 }
179
180 // Then add the known providers from the predefined list
181 for _, provider := range m.providers {
182 // Skip if we already added this provider as an unknown provider
183 if addedProviders[string(provider.ID)] {
184 continue
185 }
186
187 // Check if this provider is configured and not disabled
188 if providerConfig, exists := cfg.Providers.Get(string(provider.ID)); exists && providerConfig.Disable {
189 continue
190 }
191
192 name := provider.Name
193 if name == "" {
194 name = string(provider.ID)
195 }
196
197 section := list.NewItemSection(name)
198 if _, ok := cfg.Providers[string(provider.ID)]; ok {
199 section.SetInfo(configured)
200 }
201 group := list.Group[list.CompletionItem[ModelOption]]{
202 Section: section,
203 }
204 for _, model := range provider.Models {
205 item := list.NewCompletionItem(model.Model, ModelOption{
206 Provider: provider,
207 Model: model,
208 },
209 list.WithCompletionID(
210 fmt.Sprintf("%s:%s", provider.ID, model.ID),
211 ),
212 )
213 group.Items = append(group.Items, item)
214 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
215 selectedItemID = item.ID()
216 }
217 }
218 groups = append(groups, group)
219 }
220
221 var cmds []tea.Cmd
222
223 cmd := m.list.SetGroups(groups)
224
225 if cmd != nil {
226 cmds = append(cmds, cmd)
227 }
228 cmd = m.list.SetSelected(selectedItemID)
229 if cmd != nil {
230 cmds = append(cmds, cmd)
231 }
232
233 return tea.Sequence(cmds...)
234}
235
236// GetModelType returns the current model type
237func (m *ModelListComponent) GetModelType() int {
238 return m.modelType
239}
240
241func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
242 m.list.SetInputPlaceholder(placeholder)
243}
244
245func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
246 m.providers = providers
247}