1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/textinput"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/ui/common"
15 "github.com/charmbracelet/crush/internal/ui/util"
16 uv "github.com/charmbracelet/ultraviolet"
17 xslice "github.com/charmbracelet/x/exp/slice"
18)
19
// ModelType identifies which model slot the dialog is choosing a model for:
// the one used for large, complex tasks or the one for small, simple tasks.
type ModelType int

// Supported model types. The zero value is ModelTypeLarge.
const (
	ModelTypeLarge ModelType = iota
	ModelTypeSmall
)

// String implements fmt.Stringer. It returns the human-readable label used
// in the dialog's type toggle, or "Unknown" for any unrecognized value.
func (mt ModelType) String() string {
	label := "Unknown"
	switch mt {
	case ModelTypeLarge:
		label = "Large Task"
	case ModelTypeSmall:
		label = "Small Task"
	}
	return label
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
// Placeholder texts for the dialog's filter input. The onboarding text is
// set when the dialog is created; outside onboarding it is replaced by the
// text matching the selected model type.
const (
	onboardingModelInputPlaceholder = "Find your fave"
	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
)

const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth is the upper bound on the dialog width
	// computed when drawing.
	defaultModelsDialogMaxWidth = 73
)
74
// Models represents a model selection dialog.
//
// It combines a filter text input with a grouped list of models. Outside
// onboarding, Tab toggles which model slot (large or small task) is being
// chosen; during onboarding that toggle and the dialog title are disabled.
type Models struct {
	com *common.Common
	// isOnboarding marks the first-run variant of the dialog: no type
	// toggle, no title, and the onboarding input placeholder is kept.
	isOnboarding bool

	// modelType is the model slot currently being chosen.
	modelType ModelType
	// providers is built from the user's configured providers in NewModels
	// and is iterated (and reordered) when the list items are rebuilt.
	providers []catwalk.Provider

	// keyMap holds the dialog's key bindings. UpDown is only displayed in
	// help; actual movement is handled by Next and Previous.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	list  *ModelsList
	input textinput.Model
	help  help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
98
99// NewModels creates a new Models dialog.
100func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
101 t := com.Styles
102 m := &Models{}
103 m.com = com
104 m.isOnboarding = isOnboarding
105
106 help := help.New()
107 help.Styles = t.DialogHelpStyles()
108
109 m.help = help
110 m.list = NewModelsList(t)
111 m.list.Focus()
112 m.list.SetSelected(0)
113
114 m.input = textinput.New()
115 m.input.SetVirtualCursor(false)
116 m.input.Placeholder = onboardingModelInputPlaceholder
117 m.input.SetStyles(com.Styles.TextInput)
118 m.input.Focus()
119
120 m.keyMap.Tab = key.NewBinding(
121 key.WithKeys("tab", "shift+tab"),
122 key.WithHelp("tab", "toggle type"),
123 )
124 m.keyMap.Select = key.NewBinding(
125 key.WithKeys("enter", "ctrl+y"),
126 key.WithHelp("enter", "confirm"),
127 )
128 m.keyMap.Edit = key.NewBinding(
129 key.WithKeys("ctrl+e"),
130 key.WithHelp("ctrl+e", "edit"),
131 )
132 m.keyMap.UpDown = key.NewBinding(
133 key.WithKeys("up", "down"),
134 key.WithHelp("↑/↓", "choose"),
135 )
136 m.keyMap.Next = key.NewBinding(
137 key.WithKeys("down", "ctrl+n"),
138 key.WithHelp("↓", "next item"),
139 )
140 m.keyMap.Previous = key.NewBinding(
141 key.WithKeys("up", "ctrl+p"),
142 key.WithHelp("↑", "previous item"),
143 )
144 m.keyMap.Close = CloseKey
145
146 m.providers = slices.Collect(
147 xslice.Map(
148 com.Config().Providers.Seq(),
149 func(pc config.ProviderConfig) catwalk.Provider {
150 return pc.ToProvider()
151 },
152 ),
153 )
154 if err := m.setProviderItems(); err != nil {
155 return nil, fmt.Errorf("failed to set provider items: %w", err)
156 }
157
158 return m, nil
159}
160
161// ID implements Dialog.
162func (m *Models) ID() string {
163 return ModelsID
164}
165
166// HandleMsg implements Dialog.
167func (m *Models) HandleMsg(msg tea.Msg) Action {
168 switch msg := msg.(type) {
169 case tea.KeyPressMsg:
170 switch {
171 case key.Matches(msg, m.keyMap.Close):
172 return ActionClose{}
173 case key.Matches(msg, m.keyMap.Previous):
174 m.list.Focus()
175 if m.list.IsSelectedFirst() {
176 m.list.SelectLast()
177 } else {
178 m.list.SelectPrev()
179 }
180 m.list.ScrollToSelected()
181 case key.Matches(msg, m.keyMap.Next):
182 m.list.Focus()
183 if m.list.IsSelectedLast() {
184 m.list.SelectFirst()
185 } else {
186 m.list.SelectNext()
187 }
188 m.list.ScrollToSelected()
189 case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
190 selectedItem := m.list.SelectedItem()
191 if selectedItem == nil {
192 break
193 }
194
195 modelItem, ok := selectedItem.(*ModelItem)
196 if !ok {
197 break
198 }
199
200 isEdit := key.Matches(msg, m.keyMap.Edit)
201
202 return ActionSelectModel{
203 Provider: modelItem.prov,
204 Model: modelItem.SelectedModel(),
205 ModelType: modelItem.SelectedModelType(),
206 ReAuthenticate: isEdit,
207 }
208 case key.Matches(msg, m.keyMap.Tab):
209 if m.isOnboarding {
210 break
211 }
212 if m.modelType == ModelTypeLarge {
213 m.modelType = ModelTypeSmall
214 } else {
215 m.modelType = ModelTypeLarge
216 }
217 if err := m.setProviderItems(); err != nil {
218 return util.ReportError(err)
219 }
220 default:
221 var cmd tea.Cmd
222 m.input, cmd = m.input.Update(msg)
223 value := m.input.Value()
224 m.list.Focus()
225 m.list.SetFilter(value)
226 m.list.SelectFirst()
227 m.list.ScrollToTop()
228 return ActionCmd{cmd}
229 }
230 }
231 return nil
232}
233
234// Cursor returns the cursor for the dialog.
235func (m *Models) Cursor() *tea.Cursor {
236 return InputCursor(m.com.Styles, m.input.Cursor())
237}
238
239// modelTypeRadioView returns the radio view for model type selection.
240func (m *Models) modelTypeRadioView() string {
241 t := m.com.Styles
242 textStyle := t.HalfMuted
243 largeRadioStyle := t.RadioOff
244 smallRadioStyle := t.RadioOff
245 if m.modelType == ModelTypeLarge {
246 largeRadioStyle = t.RadioOn
247 } else {
248 smallRadioStyle = t.RadioOn
249 }
250
251 largeRadio := largeRadioStyle.Padding(0, 1).Render()
252 smallRadio := smallRadioStyle.Padding(0, 1).Render()
253
254 return fmt.Sprintf("%s%s %s%s",
255 largeRadio, textStyle.Render(ModelTypeLarge.String()),
256 smallRadio, textStyle.Render(ModelTypeSmall.String()))
257}
258
259// Draw implements [Dialog].
260func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
261 t := m.com.Styles
262 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
263 height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
264 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
265 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
266 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
267 t.Dialog.HelpView.GetVerticalFrameSize() +
268 t.Dialog.View.GetVerticalFrameSize()
269
270 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
271 m.list.SetSize(innerWidth, height-heightOffset)
272 m.help.SetWidth(innerWidth)
273
274 rc := NewRenderContext(t, width)
275 rc.Title = "Switch Model"
276 rc.TitleInfo = m.modelTypeRadioView()
277
278 if m.isOnboarding {
279 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
280 rc.AddPart(titleText)
281 }
282
283 inputView := t.Dialog.InputPrompt.Render(m.input.View())
284 rc.AddPart(inputView)
285
286 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
287 rc.AddPart(listView)
288
289 rc.Help = m.help.View(m)
290
291 cur := m.Cursor()
292
293 if m.isOnboarding {
294 rc.Title = ""
295 rc.TitleInfo = ""
296 rc.IsOnboarding = true
297 view := rc.Render()
298 DrawOnboardingCursor(scr, area, view, cur)
299
300 // FIXME(@andreynering): Figure it out how to properly fix this
301 if cur != nil {
302 cur.Y -= 1
303 cur.X -= 1
304 }
305 } else {
306 view := rc.Render()
307 DrawCenterCursor(scr, area, view, cur)
308 }
309 return cur
310}
311
312// ShortHelp returns the short help view.
313func (m *Models) ShortHelp() []key.Binding {
314 if m.isOnboarding {
315 return []key.Binding{
316 m.keyMap.UpDown,
317 m.keyMap.Select,
318 }
319 }
320 h := []key.Binding{
321 m.keyMap.UpDown,
322 m.keyMap.Tab,
323 m.keyMap.Select,
324 }
325 if m.isSelectedConfigured() {
326 h = append(h, m.keyMap.Edit)
327 }
328 h = append(h, m.keyMap.Close)
329 return h
330}
331
332// FullHelp returns the full help view.
333func (m *Models) FullHelp() [][]key.Binding {
334 return [][]key.Binding{m.ShortHelp()}
335}
336
337func (m *Models) isSelectedConfigured() bool {
338 selectedItem := m.list.SelectedItem()
339 if selectedItem == nil {
340 return false
341 }
342 modelItem, ok := selectedItem.(*ModelItem)
343 if !ok {
344 return false
345 }
346 providerID := string(modelItem.prov.ID)
347 _, isConfigured := m.com.Config().Providers.Get(providerID)
348 return isConfigured
349}
350
351// setProviderItems sets the provider items in the list.
352func (m *Models) setProviderItems() error {
353 t := m.com.Styles
354 cfg := m.com.Config()
355
356 var selectedItemID string
357 selectedType := m.modelType.Config()
358 currentModel := cfg.Models[selectedType]
359 recentItems := cfg.RecentModels[selectedType]
360
361 // Track providers already added to avoid duplicates
362 addedProviders := make(map[string]bool)
363
364 // Get a list of known providers to compare against
365 knownProviders, err := config.Providers(cfg)
366 if err != nil {
367 return fmt.Errorf("failed to get providers: %w", err)
368 }
369
370 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
371 return func(p catwalk.Provider) bool {
372 return p.ID == catwalk.InferenceProvider(id)
373 }
374 }
375
376 // itemsMap contains the keys of added model items.
377 itemsMap := make(map[string]*ModelItem)
378 groups := []ModelGroup{}
379 for id, p := range cfg.Providers.Seq2() {
380 if p.Disable {
381 continue
382 }
383
384 // Check if this provider is not in the known providers list
385 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
386 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
387 provider := p.ToProvider()
388
389 // Add this unknown provider to the list
390 name := cmp.Or(p.Name, id)
391
392 addedProviders[id] = true
393
394 group := NewModelGroup(t, name, true)
395 for _, model := range p.Models {
396 item := NewModelItem(t, provider, model, m.modelType, false)
397 group.AppendItems(item)
398 itemsMap[item.ID()] = item
399 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
400 selectedItemID = item.ID()
401 }
402 }
403 if len(group.Items) > 0 {
404 groups = append(groups, group)
405 }
406 }
407 }
408
409 // Move "Charm Hyper" to first position.
410 // (But still after recent models and custom providers).
411 slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
412 switch {
413 case a.ID == "hyper":
414 return -1
415 case b.ID == "hyper":
416 return 1
417 default:
418 return 0
419 }
420 })
421
422 // Now add known providers from the predefined list
423 for _, provider := range m.providers {
424 providerID := string(provider.ID)
425 if addedProviders[providerID] {
426 continue
427 }
428
429 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
430 if providerConfigured && providerConfig.Disable {
431 continue
432 }
433
434 displayProvider := provider
435 if providerConfigured {
436 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
437 modelIndex := make(map[string]int, len(displayProvider.Models))
438 for i, model := range displayProvider.Models {
439 modelIndex[model.ID] = i
440 }
441 for _, model := range providerConfig.Models {
442 if model.ID == "" {
443 continue
444 }
445 if idx, ok := modelIndex[model.ID]; ok {
446 if model.Name != "" {
447 displayProvider.Models[idx].Name = model.Name
448 }
449 continue
450 }
451 if model.Name == "" {
452 model.Name = model.ID
453 }
454 displayProvider.Models = append(displayProvider.Models, model)
455 modelIndex[model.ID] = len(displayProvider.Models) - 1
456 }
457 }
458
459 name := displayProvider.Name
460 if name == "" {
461 name = providerID
462 }
463
464 group := NewModelGroup(t, name, providerConfigured)
465 for _, model := range displayProvider.Models {
466 item := NewModelItem(t, provider, model, m.modelType, false)
467 group.AppendItems(item)
468 itemsMap[item.ID()] = item
469 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
470 selectedItemID = item.ID()
471 }
472 }
473
474 groups = append(groups, group)
475 }
476
477 if len(recentItems) > 0 {
478 recentGroup := NewModelGroup(t, "Recently used", false)
479
480 var validRecentItems []config.SelectedModel
481 for _, recent := range recentItems {
482 key := modelKey(recent.Provider, recent.Model)
483 item, ok := itemsMap[key]
484 if !ok {
485 continue
486 }
487
488 // Show provider for recent items
489 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
490 item.showProvider = true
491
492 validRecentItems = append(validRecentItems, recent)
493 recentGroup.AppendItems(item)
494 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
495 selectedItemID = item.ID()
496 }
497 }
498
499 if len(validRecentItems) != len(recentItems) {
500 // FIXME: Does this need to be here? Is it mutating the config during a read?
501 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
502 return fmt.Errorf("failed to update recent models: %w", err)
503 }
504 }
505
506 if len(recentGroup.Items) > 0 {
507 groups = append([]ModelGroup{recentGroup}, groups...)
508 }
509 }
510
511 // Set model groups in the list.
512 m.list.SetGroups(groups...)
513 m.list.SetSelectedItem(selectedItemID)
514 m.list.ScrollToTop()
515
516 // Update placeholder based on model type
517 if !m.isOnboarding {
518 m.input.Placeholder = m.modelType.Placeholder()
519 }
520
521 return nil
522}
523
// modelKey joins a provider ID and a model ID into the "provider:model"
// key used to look up list items. It returns "" if either part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}