1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/uiutil"
17 uv "github.com/charmbracelet/ultraviolet"
18)
19
// ModelType identifies which model slot (large or small task) the dialog is
// choosing a model for.
type ModelType int

// Supported model types. The zero value is ModelTypeLarge.
const (
	ModelTypeLarge ModelType = iota
	ModelTypeSmall
)

// String returns a human-readable label for the [ModelType]; values outside
// the defined set are reported as "Unknown".
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
64const (
65 onboardingModelInputPlaceholder = "Find your fave"
66 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
67 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
68)
69
const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth is the maximum width of the dialog; narrower
	// areas shrink it further.
	defaultModelsDialogMaxWidth = 70
)
74
75// Models represents a model selection dialog.
76type Models struct {
77 com *common.Common
78 isOnboarding bool
79
80 modelType ModelType
81 providers []catwalk.Provider
82
83 keyMap struct {
84 Tab key.Binding
85 UpDown key.Binding
86 Select key.Binding
87 Next key.Binding
88 Previous key.Binding
89 Close key.Binding
90 }
91 list *ModelsList
92 input textinput.Model
93 help help.Model
94}
95
96var _ Dialog = (*Models)(nil)
97
98// NewModels creates a new Models dialog.
99func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
100 t := com.Styles
101 m := &Models{}
102 m.com = com
103 m.isOnboarding = isOnboarding
104
105 help := help.New()
106 help.Styles = t.DialogHelpStyles()
107
108 m.help = help
109 m.list = NewModelsList(t)
110 m.list.Focus()
111 m.list.SetSelected(0)
112
113 m.input = textinput.New()
114 m.input.SetVirtualCursor(false)
115 m.input.Placeholder = onboardingModelInputPlaceholder
116 m.input.SetStyles(com.Styles.TextInput)
117 m.input.Focus()
118
119 m.keyMap.Tab = key.NewBinding(
120 key.WithKeys("tab", "shift+tab"),
121 key.WithHelp("tab", "toggle type"),
122 )
123 m.keyMap.Select = key.NewBinding(
124 key.WithKeys("enter", "ctrl+y"),
125 key.WithHelp("enter", "confirm"),
126 )
127 m.keyMap.UpDown = key.NewBinding(
128 key.WithKeys("up", "down"),
129 key.WithHelp("↑/↓", "choose"),
130 )
131 m.keyMap.Next = key.NewBinding(
132 key.WithKeys("down", "ctrl+n"),
133 key.WithHelp("↓", "next item"),
134 )
135 m.keyMap.Previous = key.NewBinding(
136 key.WithKeys("up", "ctrl+p"),
137 key.WithHelp("↑", "previous item"),
138 )
139 m.keyMap.Close = CloseKey
140
141 providers, err := getFilteredProviders(com.Config())
142 if err != nil {
143 return nil, fmt.Errorf("failed to get providers: %w", err)
144 }
145
146 m.providers = providers
147 if err := m.setProviderItems(); err != nil {
148 return nil, fmt.Errorf("failed to set provider items: %w", err)
149 }
150
151 return m, nil
152}
153
154// ID implements Dialog.
155func (m *Models) ID() string {
156 return ModelsID
157}
158
159// HandleMsg implements Dialog.
160func (m *Models) HandleMsg(msg tea.Msg) Action {
161 switch msg := msg.(type) {
162 case tea.KeyPressMsg:
163 switch {
164 case key.Matches(msg, m.keyMap.Close):
165 return ActionClose{}
166 case key.Matches(msg, m.keyMap.Previous):
167 m.list.Focus()
168 if m.list.IsSelectedFirst() {
169 m.list.SelectLast()
170 m.list.ScrollToBottom()
171 break
172 }
173 m.list.SelectPrev()
174 m.list.ScrollToSelected()
175 case key.Matches(msg, m.keyMap.Next):
176 m.list.Focus()
177 if m.list.IsSelectedLast() {
178 m.list.SelectFirst()
179 m.list.ScrollToTop()
180 break
181 }
182 m.list.SelectNext()
183 m.list.ScrollToSelected()
184 case key.Matches(msg, m.keyMap.Select):
185 selectedItem := m.list.SelectedItem()
186 if selectedItem == nil {
187 break
188 }
189
190 modelItem, ok := selectedItem.(*ModelItem)
191 if !ok {
192 break
193 }
194
195 return ActionSelectModel{
196 Provider: modelItem.prov,
197 Model: modelItem.SelectedModel(),
198 ModelType: modelItem.SelectedModelType(),
199 }
200 case key.Matches(msg, m.keyMap.Tab):
201 if m.isOnboarding {
202 break
203 }
204 if m.modelType == ModelTypeLarge {
205 m.modelType = ModelTypeSmall
206 } else {
207 m.modelType = ModelTypeLarge
208 }
209 if err := m.setProviderItems(); err != nil {
210 return uiutil.ReportError(err)
211 }
212 default:
213 var cmd tea.Cmd
214 m.input, cmd = m.input.Update(msg)
215 value := m.input.Value()
216 m.list.Focus()
217 m.list.SetFilter(value)
218 m.list.SelectFirst()
219 m.list.ScrollToTop()
220 return ActionCmd{cmd}
221 }
222 }
223 return nil
224}
225
226// Cursor returns the cursor for the dialog.
227func (m *Models) Cursor() *tea.Cursor {
228 return InputCursor(m.com.Styles, m.input.Cursor())
229}
230
231// modelTypeRadioView returns the radio view for model type selection.
232func (m *Models) modelTypeRadioView() string {
233 t := m.com.Styles
234 textStyle := t.HalfMuted
235 largeRadioStyle := t.RadioOff
236 smallRadioStyle := t.RadioOff
237 if m.modelType == ModelTypeLarge {
238 largeRadioStyle = t.RadioOn
239 } else {
240 smallRadioStyle = t.RadioOn
241 }
242
243 largeRadio := largeRadioStyle.Padding(0, 1).Render()
244 smallRadio := smallRadioStyle.Padding(0, 1).Render()
245
246 return fmt.Sprintf("%s%s %s%s",
247 largeRadio, textStyle.Render(ModelTypeLarge.String()),
248 smallRadio, textStyle.Render(ModelTypeSmall.String()))
249}
250
251// Draw implements [Dialog].
252func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
253 t := m.com.Styles
254 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()))
255 height := max(0, min(defaultDialogHeight, area.Dy()))
256 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
257 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
258 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
259 t.Dialog.HelpView.GetVerticalFrameSize() +
260 t.Dialog.View.GetVerticalFrameSize()
261
262 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
263 m.list.SetSize(innerWidth, height-heightOffset)
264 m.help.SetWidth(innerWidth)
265
266 rc := NewRenderContext(t, width)
267 rc.Title = "Switch Model"
268 rc.TitleInfo = m.modelTypeRadioView()
269
270 if m.isOnboarding {
271 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
272 rc.AddPart(titleText)
273 }
274
275 inputView := t.Dialog.InputPrompt.Render(m.input.View())
276 rc.AddPart(inputView)
277
278 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
279 rc.AddPart(listView)
280
281 rc.Help = m.help.View(m)
282
283 cur := m.Cursor()
284
285 if m.isOnboarding {
286 rc.Title = ""
287 rc.TitleInfo = ""
288 rc.IsOnboarding = true
289 view := rc.Render()
290 DrawOnboardingCursor(scr, area, view, cur)
291
292 // FIXME(@andreynering): Figure it out how to properly fix this
293 if cur != nil {
294 cur.Y -= 1
295 cur.X -= 1
296 }
297 } else {
298 view := rc.Render()
299 DrawCenterCursor(scr, area, view, cur)
300 }
301 return cur
302}
303
304// ShortHelp returns the short help view.
305func (m *Models) ShortHelp() []key.Binding {
306 if m.isOnboarding {
307 return []key.Binding{
308 m.keyMap.UpDown,
309 m.keyMap.Select,
310 }
311 }
312 return []key.Binding{
313 m.keyMap.UpDown,
314 m.keyMap.Tab,
315 m.keyMap.Select,
316 m.keyMap.Close,
317 }
318}
319
320// FullHelp returns the full help view.
321func (m *Models) FullHelp() [][]key.Binding {
322 return [][]key.Binding{
323 {
324 m.keyMap.Select,
325 m.keyMap.Next,
326 m.keyMap.Previous,
327 m.keyMap.Tab,
328 },
329 {
330 m.keyMap.Close,
331 },
332 }
333}
334
335// setProviderItems sets the provider items in the list.
336func (m *Models) setProviderItems() error {
337 t := m.com.Styles
338 cfg := m.com.Config()
339
340 var selectedItemID string
341 selectedType := m.modelType.Config()
342 currentModel := cfg.Models[selectedType]
343 recentItems := cfg.RecentModels[selectedType]
344
345 // Track providers already added to avoid duplicates
346 addedProviders := make(map[string]bool)
347
348 // Get a list of known providers to compare against
349 knownProviders, err := config.Providers(cfg)
350 if err != nil {
351 return fmt.Errorf("failed to get providers: %w", err)
352 }
353
354 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
355 return func(p catwalk.Provider) bool {
356 return p.ID == catwalk.InferenceProvider(id)
357 }
358 }
359
360 // itemsMap contains the keys of added model items.
361 itemsMap := make(map[string]*ModelItem)
362 groups := []ModelGroup{}
363 for id, p := range cfg.Providers.Seq2() {
364 if p.Disable {
365 continue
366 }
367
368 // Check if this provider is not in the known providers list
369 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
370 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
371 provider := p.ToProvider()
372
373 // Add this unknown provider to the list
374 name := cmp.Or(p.Name, id)
375
376 addedProviders[id] = true
377
378 group := NewModelGroup(t, name, true)
379 for _, model := range p.Models {
380 item := NewModelItem(t, provider, model, m.modelType, false)
381 group.AppendItems(item)
382 itemsMap[item.ID()] = item
383 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
384 selectedItemID = item.ID()
385 }
386 }
387 if len(group.Items) > 0 {
388 groups = append(groups, group)
389 }
390 }
391 }
392
393 // Move "Charm Hyper" to first position.
394 // (But still after recent models and custom providers).
395 slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
396 switch {
397 case a.ID == "hyper":
398 return -1
399 case b.ID == "hyper":
400 return 1
401 default:
402 return 0
403 }
404 })
405
406 // Now add known providers from the predefined list
407 for _, provider := range m.providers {
408 providerID := string(provider.ID)
409 if addedProviders[providerID] {
410 continue
411 }
412
413 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
414 if providerConfigured && providerConfig.Disable {
415 continue
416 }
417
418 displayProvider := provider
419 if providerConfigured {
420 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
421 modelIndex := make(map[string]int, len(displayProvider.Models))
422 for i, model := range displayProvider.Models {
423 modelIndex[model.ID] = i
424 }
425 for _, model := range providerConfig.Models {
426 if model.ID == "" {
427 continue
428 }
429 if idx, ok := modelIndex[model.ID]; ok {
430 if model.Name != "" {
431 displayProvider.Models[idx].Name = model.Name
432 }
433 continue
434 }
435 if model.Name == "" {
436 model.Name = model.ID
437 }
438 displayProvider.Models = append(displayProvider.Models, model)
439 modelIndex[model.ID] = len(displayProvider.Models) - 1
440 }
441 }
442
443 name := displayProvider.Name
444 if name == "" {
445 name = providerID
446 }
447
448 group := NewModelGroup(t, name, providerConfigured)
449 for _, model := range displayProvider.Models {
450 item := NewModelItem(t, provider, model, m.modelType, false)
451 group.AppendItems(item)
452 itemsMap[item.ID()] = item
453 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
454 selectedItemID = item.ID()
455 }
456 }
457
458 groups = append(groups, group)
459 }
460
461 if len(recentItems) > 0 {
462 recentGroup := NewModelGroup(t, "Recently used", false)
463
464 var validRecentItems []config.SelectedModel
465 for _, recent := range recentItems {
466 key := modelKey(recent.Provider, recent.Model)
467 item, ok := itemsMap[key]
468 if !ok {
469 continue
470 }
471
472 // Show provider for recent items
473 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
474 item.showProvider = true
475
476 validRecentItems = append(validRecentItems, recent)
477 recentGroup.AppendItems(item)
478 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
479 selectedItemID = item.ID()
480 }
481 }
482
483 if len(validRecentItems) != len(recentItems) {
484 // FIXME: Does this need to be here? Is it mutating the config during a read?
485 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
486 return fmt.Errorf("failed to update recent models: %w", err)
487 }
488 }
489
490 if len(recentGroup.Items) > 0 {
491 groups = append([]ModelGroup{recentGroup}, groups...)
492 }
493 }
494
495 // Set model groups in the list.
496 m.list.SetGroups(groups...)
497 m.list.SetSelectedItem(selectedItemID)
498 m.list.ScrollToTop()
499
500 // Update placeholder based on model type
501 if !m.isOnboarding {
502 m.input.Placeholder = m.modelType.Placeholder()
503 }
504
505 return nil
506}
507
508func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
509 providers, err := config.Providers(cfg)
510 if err != nil {
511 return nil, fmt.Errorf("failed to get providers: %w", err)
512 }
513 var filteredProviders []catwalk.Provider
514 for _, p := range providers {
515 var (
516 isAzure = p.ID == catwalk.InferenceProviderAzure
517 isCopilot = p.ID == catwalk.InferenceProviderCopilot
518 isHyper = string(p.ID) == "hyper"
519 hasAPIKeyEnv = strings.HasPrefix(p.APIKey, "$")
520 _, isConfigured = cfg.Providers.Get(string(p.ID))
521 )
522 if isAzure || isCopilot || isHyper || hasAPIKeyEnv || isConfigured {
523 filteredProviders = append(filteredProviders, p)
524 }
525 }
526 return filteredProviders, nil
527}
528
// modelKey builds the "<provider>:<model>" key used to look up model items.
// It returns an empty string unless both parts are non-empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}