1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/uiutil"
17 uv "github.com/charmbracelet/ultraviolet"
18)
19
// ModelType identifies which model slot the dialog is choosing a model for:
// the one used for large tasks or the one used for small tasks.
type ModelType int

const (
	// ModelTypeLarge is the slot for large, complex tasks (the zero value).
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall is the slot for small, simple tasks.
	ModelTypeSmall
)

// String returns the human-readable label of the [ModelType]. Values other
// than the declared constants are reported as "Unknown".
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
// Filter-input placeholder texts, one per [ModelType]; selected by
// ModelType.Placeholder.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
68
// ModelsID is the identifier for the model selection dialog; it is the value
// reported by Models.ID.
const ModelsID = "models"
71
// Models represents a model selection dialog. It shows provider models in
// groups, filters them with a text input, and lets the user toggle between
// choosing the large-task and the small-task model.
type Models struct {
	com *common.Common

	// modelType is the model slot currently being chosen; toggled by Tab.
	modelType ModelType
	// providers is the filtered known-provider list computed once in
	// NewModels by getFilteredProviders.
	providers []catwalk.Provider

	// keyMap holds the bindings matched in HandleMsg and listed in the help
	// views. UpDown is only shown in ShortHelp; HandleMsg matches Next and
	// Previous instead.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	list  *ModelsList
	input textinput.Model
	help  help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
93
94// NewModels creates a new Models dialog.
95func NewModels(com *common.Common) (*Models, error) {
96 t := com.Styles
97 m := &Models{}
98 m.com = com
99 help := help.New()
100 help.Styles = t.DialogHelpStyles()
101
102 m.help = help
103 m.list = NewModelsList(t)
104 m.list.Focus()
105 m.list.SetSelected(0)
106
107 m.input = textinput.New()
108 m.input.SetVirtualCursor(false)
109 m.input.Placeholder = largeModelInputPlaceholder
110 m.input.SetStyles(com.Styles.TextInput)
111 m.input.Focus()
112
113 m.keyMap.Tab = key.NewBinding(
114 key.WithKeys("tab", "shift+tab"),
115 key.WithHelp("tab", "toggle type"),
116 )
117 m.keyMap.Select = key.NewBinding(
118 key.WithKeys("enter", "ctrl+y"),
119 key.WithHelp("enter", "confirm"),
120 )
121 m.keyMap.UpDown = key.NewBinding(
122 key.WithKeys("up", "down"),
123 key.WithHelp("↑/↓", "choose"),
124 )
125 m.keyMap.Next = key.NewBinding(
126 key.WithKeys("down", "ctrl+n"),
127 key.WithHelp("↓", "next item"),
128 )
129 m.keyMap.Previous = key.NewBinding(
130 key.WithKeys("up", "ctrl+p"),
131 key.WithHelp("↑", "previous item"),
132 )
133 m.keyMap.Close = CloseKey
134
135 providers, err := getFilteredProviders(com.Config())
136 if err != nil {
137 return nil, fmt.Errorf("failed to get providers: %w", err)
138 }
139
140 m.providers = providers
141 if err := m.setProviderItems(); err != nil {
142 return nil, fmt.Errorf("failed to set provider items: %w", err)
143 }
144
145 return m, nil
146}
147
148// ID implements Dialog.
149func (m *Models) ID() string {
150 return ModelsID
151}
152
153// HandleMsg implements Dialog.
154func (m *Models) HandleMsg(msg tea.Msg) Action {
155 switch msg := msg.(type) {
156 case tea.KeyPressMsg:
157 switch {
158 case key.Matches(msg, m.keyMap.Close):
159 return ActionClose{}
160 case key.Matches(msg, m.keyMap.Previous):
161 m.list.Focus()
162 if m.list.IsSelectedFirst() {
163 m.list.SelectLast()
164 m.list.ScrollToBottom()
165 break
166 }
167 m.list.SelectPrev()
168 m.list.ScrollToSelected()
169 case key.Matches(msg, m.keyMap.Next):
170 m.list.Focus()
171 if m.list.IsSelectedLast() {
172 m.list.SelectFirst()
173 m.list.ScrollToTop()
174 break
175 }
176 m.list.SelectNext()
177 m.list.ScrollToSelected()
178 case key.Matches(msg, m.keyMap.Select):
179 selectedItem := m.list.SelectedItem()
180 if selectedItem == nil {
181 break
182 }
183
184 modelItem, ok := selectedItem.(*ModelItem)
185 if !ok {
186 break
187 }
188
189 return ActionSelectModel{
190 Provider: modelItem.prov,
191 Model: modelItem.SelectedModel(),
192 ModelType: modelItem.SelectedModelType(),
193 }
194 case key.Matches(msg, m.keyMap.Tab):
195 if m.modelType == ModelTypeLarge {
196 m.modelType = ModelTypeSmall
197 } else {
198 m.modelType = ModelTypeLarge
199 }
200 if err := m.setProviderItems(); err != nil {
201 return uiutil.ReportError(err)
202 }
203 default:
204 var cmd tea.Cmd
205 m.input, cmd = m.input.Update(msg)
206 value := m.input.Value()
207 m.list.SetFilter(value)
208 m.list.ScrollToSelected()
209 return ActionCmd{cmd}
210 }
211 }
212 return nil
213}
214
215// Cursor returns the cursor for the dialog.
216func (m *Models) Cursor() *tea.Cursor {
217 return InputCursor(m.com.Styles, m.input.Cursor())
218}
219
220// modelTypeRadioView returns the radio view for model type selection.
221func (m *Models) modelTypeRadioView() string {
222 t := m.com.Styles
223 textStyle := t.HalfMuted
224 largeRadioStyle := t.RadioOff
225 smallRadioStyle := t.RadioOff
226 if m.modelType == ModelTypeLarge {
227 largeRadioStyle = t.RadioOn
228 } else {
229 smallRadioStyle = t.RadioOn
230 }
231
232 largeRadio := largeRadioStyle.Padding(0, 1).Render()
233 smallRadio := smallRadioStyle.Padding(0, 1).Render()
234
235 return fmt.Sprintf("%s%s %s%s",
236 largeRadio, textStyle.Render(ModelTypeLarge.String()),
237 smallRadio, textStyle.Render(ModelTypeSmall.String()))
238}
239
240// Draw implements [Dialog].
241func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
242 t := m.com.Styles
243 width := max(0, min(defaultDialogMaxWidth, area.Dx()))
244 height := max(0, min(defaultDialogHeight, area.Dy()))
245 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
246 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
247 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
248 t.Dialog.HelpView.GetVerticalFrameSize() +
249 t.Dialog.View.GetVerticalFrameSize()
250 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
251 m.list.SetSize(innerWidth, height-heightOffset)
252 m.help.SetWidth(innerWidth)
253
254 rc := NewRenderContext(t, width)
255 rc.Title = "Switch Model"
256 rc.TitleInfo = m.modelTypeRadioView()
257 inputView := t.Dialog.InputPrompt.Render(m.input.View())
258 rc.AddPart(inputView)
259 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
260 rc.AddPart(listView)
261 rc.Help = m.help.View(m)
262
263 view := rc.Render()
264
265 cur := m.Cursor()
266 DrawCenterCursor(scr, area, view, cur)
267 return cur
268}
269
270// ShortHelp returns the short help view.
271func (m *Models) ShortHelp() []key.Binding {
272 return []key.Binding{
273 m.keyMap.UpDown,
274 m.keyMap.Tab,
275 m.keyMap.Select,
276 m.keyMap.Close,
277 }
278}
279
280// FullHelp returns the full help view.
281func (m *Models) FullHelp() [][]key.Binding {
282 return [][]key.Binding{
283 {
284 m.keyMap.Select,
285 m.keyMap.Next,
286 m.keyMap.Previous,
287 m.keyMap.Tab,
288 },
289 {
290 m.keyMap.Close,
291 },
292 }
293}
294
295// setProviderItems sets the provider items in the list.
296func (m *Models) setProviderItems() error {
297 t := m.com.Styles
298 cfg := m.com.Config()
299
300 var selectedItemID string
301 selectedType := m.modelType.Config()
302 currentModel := cfg.Models[selectedType]
303 recentItems := cfg.RecentModels[selectedType]
304
305 // Track providers already added to avoid duplicates
306 addedProviders := make(map[string]bool)
307
308 // Get a list of known providers to compare against
309 knownProviders, err := config.Providers(cfg)
310 if err != nil {
311 return fmt.Errorf("failed to get providers: %w", err)
312 }
313
314 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
315 return func(p catwalk.Provider) bool {
316 return p.ID == catwalk.InferenceProvider(id)
317 }
318 }
319
320 // itemsMap contains the keys of added model items.
321 itemsMap := make(map[string]*ModelItem)
322 groups := []ModelGroup{}
323 for id, p := range cfg.Providers.Seq2() {
324 if p.Disable {
325 continue
326 }
327
328 // Check if this provider is not in the known providers list
329 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
330 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
331 provider := p.ToProvider()
332
333 // Add this unknown provider to the list
334 name := cmp.Or(p.Name, id)
335
336 addedProviders[id] = true
337
338 group := NewModelGroup(t, name, true)
339 for _, model := range p.Models {
340 item := NewModelItem(t, provider, model, m.modelType, false)
341 group.AppendItems(item)
342 itemsMap[item.ID()] = item
343 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
344 selectedItemID = item.ID()
345 }
346 }
347 if len(group.Items) > 0 {
348 groups = append(groups, group)
349 }
350 }
351 }
352
353 // Now add known providers from the predefined list
354 for _, provider := range m.providers {
355 providerID := string(provider.ID)
356 if addedProviders[providerID] {
357 continue
358 }
359
360 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
361 if providerConfigured && providerConfig.Disable {
362 continue
363 }
364
365 displayProvider := provider
366 if providerConfigured {
367 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
368 modelIndex := make(map[string]int, len(displayProvider.Models))
369 for i, model := range displayProvider.Models {
370 modelIndex[model.ID] = i
371 }
372 for _, model := range providerConfig.Models {
373 if model.ID == "" {
374 continue
375 }
376 if idx, ok := modelIndex[model.ID]; ok {
377 if model.Name != "" {
378 displayProvider.Models[idx].Name = model.Name
379 }
380 continue
381 }
382 if model.Name == "" {
383 model.Name = model.ID
384 }
385 displayProvider.Models = append(displayProvider.Models, model)
386 modelIndex[model.ID] = len(displayProvider.Models) - 1
387 }
388 }
389
390 name := displayProvider.Name
391 if name == "" {
392 name = providerID
393 }
394
395 group := NewModelGroup(t, name, providerConfigured)
396 for _, model := range displayProvider.Models {
397 item := NewModelItem(t, provider, model, m.modelType, false)
398 group.AppendItems(item)
399 itemsMap[item.ID()] = item
400 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
401 selectedItemID = item.ID()
402 }
403 }
404
405 groups = append(groups, group)
406 }
407
408 if len(recentItems) > 0 {
409 recentGroup := NewModelGroup(t, "Recently used", false)
410
411 var validRecentItems []config.SelectedModel
412 for _, recent := range recentItems {
413 key := modelKey(recent.Provider, recent.Model)
414 item, ok := itemsMap[key]
415 if !ok {
416 continue
417 }
418
419 // Show provider for recent items
420 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
421 item.showProvider = true
422
423 validRecentItems = append(validRecentItems, recent)
424 recentGroup.AppendItems(item)
425 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
426 selectedItemID = item.ID()
427 }
428 }
429
430 if len(validRecentItems) != len(recentItems) {
431 // FIXME: Does this need to be here? Is it mutating the config during a read?
432 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
433 return fmt.Errorf("failed to update recent models: %w", err)
434 }
435 }
436
437 if len(recentGroup.Items) > 0 {
438 groups = append([]ModelGroup{recentGroup}, groups...)
439 }
440 }
441
442 // Set model groups in the list.
443 m.list.SetGroups(groups...)
444 m.list.SetSelectedItem(selectedItemID)
445
446 // Update placeholder based on model type
447 m.input.Placeholder = m.modelType.Placeholder()
448
449 return nil
450}
451
452func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
453 providers, err := config.Providers(cfg)
454 if err != nil {
455 return nil, fmt.Errorf("failed to get providers: %w", err)
456 }
457 filteredProviders := []catwalk.Provider{}
458 for _, p := range providers {
459 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
460 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
461 filteredProviders = append(filteredProviders, p)
462 }
463 }
464 return filteredProviders, nil
465}
466
// modelKey builds the "provider:model" key used by setProviderItems to look
// up model items. It returns "" when either part is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}