1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "charm.land/lipgloss/v2"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/ui/common"
17 "github.com/charmbracelet/crush/internal/uiutil"
18)
19
// ModelType represents the type of model to select.
//
// The dialog toggles between the two values below with its Tab binding, and
// each value maps onto a config.SelectedModelType via [ModelType.Config].
type ModelType int

const (
	// ModelTypeLarge selects the model used for large, complex tasks. It is
	// the zero value, so a fresh dialog starts on it.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall selects the model used for small, simple tasks.
	ModelTypeSmall
)
27
28// String returns the string representation of the [ModelType].
29func (mt ModelType) String() string {
30 switch mt {
31 case ModelTypeLarge:
32 return "Large Task"
33 case ModelTypeSmall:
34 return "Small Task"
35 default:
36 return "Unknown"
37 }
38}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
// Placeholder texts for the dialog's filter input, one per [ModelType];
// selected by [ModelType.Placeholder].
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
68
// ModelsID is the identifier for the model selection dialog; it is the value
// returned by [Models.ID].
const ModelsID = "models"
71
// Models represents a model selection dialog.
//
// It renders a header with model-type radio buttons, a filter input, a
// grouped list of models and a help view. The user picks the model for the
// currently selected [ModelType] and can toggle the type with Tab.
type Models struct {
	com *common.Common

	// modelType is the model slot currently being chosen; Tab toggles it.
	modelType ModelType
	// providers is the provider list computed in NewModels by
	// getFilteredProviders and reused every time the list is rebuilt.
	providers []catwalk.Provider

	// width and height are the dimensions last passed to SetSize.
	width, height int

	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding // only advertised in ShortHelp; Next/Previous do the moving
		Select   key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list holds the grouped model items and is filtered by input's value.
	list *ModelsList
	// input is the filter text input shown above the list.
	input textinput.Model
	// help renders key binding help from ShortHelp/FullHelp.
	help help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
95
96// NewModels creates a new Models dialog.
97func NewModels(com *common.Common) (*Models, error) {
98 t := com.Styles
99 m := &Models{}
100 m.com = com
101 help := help.New()
102 help.Styles = t.DialogHelpStyles()
103
104 m.help = help
105 m.list = NewModelsList(t)
106 m.list.Focus()
107 m.list.SetSelected(0)
108
109 m.input = textinput.New()
110 m.input.SetVirtualCursor(false)
111 m.input.Placeholder = largeModelInputPlaceholder
112 m.input.SetStyles(com.Styles.TextInput)
113 m.input.Focus()
114
115 m.keyMap.Tab = key.NewBinding(
116 key.WithKeys("tab", "shift+tab"),
117 key.WithHelp("tab", "toggle type"),
118 )
119 m.keyMap.Select = key.NewBinding(
120 key.WithKeys("enter", "ctrl+y"),
121 key.WithHelp("enter", "confirm"),
122 )
123 m.keyMap.UpDown = key.NewBinding(
124 key.WithKeys("up", "down"),
125 key.WithHelp("↑/↓", "choose"),
126 )
127 m.keyMap.Next = key.NewBinding(
128 key.WithKeys("down", "ctrl+n"),
129 key.WithHelp("↓", "next item"),
130 )
131 m.keyMap.Previous = key.NewBinding(
132 key.WithKeys("up", "ctrl+p"),
133 key.WithHelp("↑", "previous item"),
134 )
135 m.keyMap.Close = CloseKey
136
137 providers, err := getFilteredProviders(com.Config())
138 if err != nil {
139 return nil, fmt.Errorf("failed to get providers: %w", err)
140 }
141
142 m.providers = providers
143 if err := m.setProviderItems(); err != nil {
144 return nil, fmt.Errorf("failed to set provider items: %w", err)
145 }
146
147 return m, nil
148}
149
150// SetSize sets the size of the dialog.
151func (m *Models) SetSize(width, height int) {
152 t := m.com.Styles
153 m.width = width
154 m.height = height
155 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
156 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
157 t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
158 t.Dialog.HelpView.GetVerticalFrameSize() +
159 t.Dialog.View.GetVerticalFrameSize()
160 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
161 m.list.SetSize(innerWidth, height-heightOffset)
162 m.help.SetWidth(width)
163}
164
165// ID implements Dialog.
166func (m *Models) ID() string {
167 return ModelsID
168}
169
170// Update implements Dialog.
171func (m *Models) Update(msg tea.Msg) tea.Msg {
172 switch msg := msg.(type) {
173 case tea.KeyPressMsg:
174 switch {
175 case key.Matches(msg, m.keyMap.Close):
176 return CloseMsg{}
177 case key.Matches(msg, m.keyMap.Previous):
178 m.list.Focus()
179 if m.list.IsSelectedFirst() {
180 m.list.SelectLast()
181 m.list.ScrollToBottom()
182 break
183 }
184 m.list.SelectPrev()
185 m.list.ScrollToSelected()
186 case key.Matches(msg, m.keyMap.Next):
187 m.list.Focus()
188 if m.list.IsSelectedLast() {
189 m.list.SelectFirst()
190 m.list.ScrollToTop()
191 break
192 }
193 m.list.SelectNext()
194 m.list.ScrollToSelected()
195 case key.Matches(msg, m.keyMap.Select):
196 selectedItem := m.list.SelectedItem()
197 if selectedItem == nil {
198 break
199 }
200
201 modelItem, ok := selectedItem.(*ModelItem)
202 if !ok {
203 break
204 }
205
206 return ModelSelectedMsg{
207 Model: modelItem.SelectedModel(),
208 ModelType: modelItem.SelectedModelType(),
209 }
210 case key.Matches(msg, m.keyMap.Tab):
211 if m.modelType == ModelTypeLarge {
212 m.modelType = ModelTypeSmall
213 } else {
214 m.modelType = ModelTypeLarge
215 }
216 if err := m.setProviderItems(); err != nil {
217 return uiutil.ReportError(err)
218 }
219 default:
220 var cmd tea.Cmd
221 m.input, cmd = m.input.Update(msg)
222 value := m.input.Value()
223 m.list.SetFilter(value)
224 m.list.ScrollToSelected()
225 return cmd
226 }
227 }
228 return nil
229}
230
231// Cursor returns the cursor for the dialog.
232func (m *Models) Cursor() *tea.Cursor {
233 return InputCursor(m.com.Styles, m.input.Cursor())
234}
235
236// modelTypeRadioView returns the radio view for model type selection.
237func (m *Models) modelTypeRadioView() string {
238 t := m.com.Styles
239 textStyle := t.HalfMuted
240 largeRadioStyle := t.RadioOff
241 smallRadioStyle := t.RadioOff
242 if m.modelType == ModelTypeLarge {
243 largeRadioStyle = t.RadioOn
244 } else {
245 smallRadioStyle = t.RadioOn
246 }
247
248 largeRadio := largeRadioStyle.Padding(0, 1).Render()
249 smallRadio := smallRadioStyle.Padding(0, 1).Render()
250
251 return fmt.Sprintf("%s%s %s%s",
252 largeRadio, textStyle.Render(ModelTypeLarge.String()),
253 smallRadio, textStyle.Render(ModelTypeSmall.String()))
254}
255
256// View implements Dialog.
257func (m *Models) View() string {
258 t := m.com.Styles
259 titleStyle := t.Dialog.Title
260 dialogStyle := t.Dialog.View
261
262 radios := m.modelTypeRadioView()
263
264 headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
265 dialogStyle.GetHorizontalFrameSize()
266
267 header := common.DialogTitle(t, "Switch Model", m.width-headerOffset) + radios
268
269 return HeaderInputListHelpView(t, m.width, m.list.Height(), header,
270 m.input.View(), m.list.Render(), m.help.View(m))
271}
272
273// ShortHelp returns the short help view.
274func (m *Models) ShortHelp() []key.Binding {
275 return []key.Binding{
276 m.keyMap.UpDown,
277 m.keyMap.Tab,
278 m.keyMap.Select,
279 m.keyMap.Close,
280 }
281}
282
283// FullHelp returns the full help view.
284func (m *Models) FullHelp() [][]key.Binding {
285 return [][]key.Binding{
286 {
287 m.keyMap.Select,
288 m.keyMap.Next,
289 m.keyMap.Previous,
290 m.keyMap.Tab,
291 },
292 {
293 m.keyMap.Close,
294 },
295 }
296}
297
298// setProviderItems sets the provider items in the list.
299func (m *Models) setProviderItems() error {
300 t := m.com.Styles
301 cfg := m.com.Config()
302
303 var selectedItemID string
304 selectedType := m.modelType.Config()
305 currentModel := cfg.Models[selectedType]
306 recentItems := cfg.RecentModels[selectedType]
307
308 // Track providers already added to avoid duplicates
309 addedProviders := make(map[string]bool)
310
311 // Get a list of known providers to compare against
312 knownProviders, err := config.Providers(cfg)
313 if err != nil {
314 return fmt.Errorf("failed to get providers: %w", err)
315 }
316
317 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
318 return func(p catwalk.Provider) bool {
319 return p.ID == catwalk.InferenceProvider(id)
320 }
321 }
322
323 // itemsMap contains the keys of added model items.
324 itemsMap := make(map[string]*ModelItem)
325 groups := []ModelGroup{}
326 for id, p := range cfg.Providers.Seq2() {
327 if p.Disable {
328 continue
329 }
330
331 // Check if this provider is not in the known providers list
332 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
333 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
334 provider := p.ToProvider()
335
336 // Add this unknown provider to the list
337 name := p.Name
338 if name == "" {
339 name = id
340 }
341
342 addedProviders[id] = true
343
344 group := NewModelGroup(t, name, true)
345 for _, model := range p.Models {
346 item := NewModelItem(t, provider, model, m.modelType, false)
347 group.AppendItems(item)
348 itemsMap[item.ID()] = item
349 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
350 selectedItemID = item.ID()
351 }
352 }
353 }
354 }
355
356 // Now add known providers from the predefined list
357 for _, provider := range m.providers {
358 providerID := string(provider.ID)
359 if addedProviders[providerID] {
360 continue
361 }
362
363 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
364 if providerConfigured && providerConfig.Disable {
365 continue
366 }
367
368 displayProvider := provider
369 if providerConfigured {
370 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
371 modelIndex := make(map[string]int, len(displayProvider.Models))
372 for i, model := range displayProvider.Models {
373 modelIndex[model.ID] = i
374 }
375 for _, model := range providerConfig.Models {
376 if model.ID == "" {
377 continue
378 }
379 if idx, ok := modelIndex[model.ID]; ok {
380 if model.Name != "" {
381 displayProvider.Models[idx].Name = model.Name
382 }
383 continue
384 }
385 if model.Name == "" {
386 model.Name = model.ID
387 }
388 displayProvider.Models = append(displayProvider.Models, model)
389 modelIndex[model.ID] = len(displayProvider.Models) - 1
390 }
391 }
392
393 name := displayProvider.Name
394 if name == "" {
395 name = providerID
396 }
397
398 group := NewModelGroup(t, name, providerConfigured)
399 for _, model := range displayProvider.Models {
400 item := NewModelItem(t, provider, model, m.modelType, false)
401 group.AppendItems(item)
402 itemsMap[item.ID()] = item
403 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
404 selectedItemID = item.ID()
405 }
406 }
407
408 groups = append(groups, group)
409 }
410
411 if len(recentItems) > 0 {
412 recentGroup := NewModelGroup(t, "Recently used", false)
413
414 var validRecentItems []config.SelectedModel
415 for _, recent := range recentItems {
416 key := modelKey(recent.Provider, recent.Model)
417 item, ok := itemsMap[key]
418 if !ok {
419 continue
420 }
421
422 // Show provider for recent items
423 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
424 item.showProvider = true
425
426 validRecentItems = append(validRecentItems, recent)
427 recentGroup.AppendItems(item)
428 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
429 selectedItemID = item.ID()
430 }
431 }
432
433 if len(validRecentItems) != len(recentItems) {
434 // FIXME: Does this need to be here? Is it mutating the config during a read?
435 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
436 return fmt.Errorf("failed to update recent models: %w", err)
437 }
438 }
439
440 if len(recentGroup.Items) > 0 {
441 groups = append([]ModelGroup{recentGroup}, groups...)
442 }
443 }
444
445 // Set model groups in the list.
446 m.list.SetGroups(groups...)
447 m.list.SetSelectedItem(selectedItemID)
448
449 // Update placeholder based on model type
450 m.input.Placeholder = m.modelType.Placeholder()
451
452 return nil
453}
454
455func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
456 providers, err := config.Providers(cfg)
457 if err != nil {
458 return nil, fmt.Errorf("failed to get providers: %w", err)
459 }
460 filteredProviders := []catwalk.Provider{}
461 for _, p := range providers {
462 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
463 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
464 filteredProviders = append(filteredProviders, p)
465 }
466 }
467 return filteredProviders, nil
468}
469
// modelKey builds the "<provider>:<model>" key that setProviderItems uses to
// look up model items by ID. It returns "" when either part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}