1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/textinput"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/ui/common"
15 "github.com/charmbracelet/crush/internal/ui/util"
16 uv "github.com/charmbracelet/ultraviolet"
17 xslice "github.com/charmbracelet/x/exp/slice"
18)
19
// ModelType represents the type of model to select.
type ModelType int

// Supported model types; ModelTypeLarge is the zero value.
const (
	ModelTypeLarge ModelType = iota
	ModelTypeSmall
)

// String returns a human-readable label for the [ModelType], falling back to
// "Unknown" for values outside the declared constants.
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
// Placeholder texts for the filter input. The onboarding text is set when the
// dialog is created; outside onboarding it is replaced by the text matching
// the selected [ModelType].
const (
	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
	onboardingModelInputPlaceholder = "Find your fave"
)
69
const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth is the upper bound applied to the dialog
	// width when it is drawn.
	defaultModelsDialogMaxWidth = 73
)
74
// Models represents a model selection dialog.
//
// It combines a filter text input, a grouped list of provider models and a
// help bar, and is used to pick the model for either the large or the small
// task slot.
type Models struct {
	// com provides the shared styles and access to the loaded configuration.
	com *common.Common
	// isOnboarding enables first-run mode: the model type toggle is disabled
	// and the title/placeholder differ from the regular dialog.
	isOnboarding bool

	// modelType is the slot (large or small) currently being configured.
	modelType ModelType
	// providers is built from the configured providers in NewModels;
	// setProviderItems reorders it in place.
	providers []catwalk.Provider

	// keyMap holds the bindings handled by HandleMsg and listed in the help bar.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	list  *ModelsList
	input textinput.Model
	help  help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
98
99// NewModels creates a new Models dialog.
100func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
101 t := com.Styles
102 m := &Models{}
103 m.com = com
104 m.isOnboarding = isOnboarding
105
106 help := help.New()
107 help.Styles = t.DialogHelpStyles()
108
109 m.help = help
110 m.list = NewModelsList(t)
111 m.list.Focus()
112 m.list.SetSelected(0)
113
114 m.input = textinput.New()
115 m.input.SetVirtualCursor(false)
116 m.input.Placeholder = onboardingModelInputPlaceholder
117 m.input.SetStyles(com.Styles.TextInput)
118 m.input.Focus()
119
120 m.keyMap.Tab = key.NewBinding(
121 key.WithKeys("tab", "shift+tab"),
122 key.WithHelp("tab", "toggle type"),
123 )
124 m.keyMap.Select = key.NewBinding(
125 key.WithKeys("enter", "ctrl+y"),
126 key.WithHelp("enter", "confirm"),
127 )
128 m.keyMap.Edit = key.NewBinding(
129 key.WithKeys("ctrl+e"),
130 key.WithHelp("ctrl+e", "edit"),
131 )
132 m.keyMap.UpDown = key.NewBinding(
133 key.WithKeys("up", "down"),
134 key.WithHelp("↑/↓", "choose"),
135 )
136 m.keyMap.Next = key.NewBinding(
137 key.WithKeys("down", "ctrl+n"),
138 key.WithHelp("↓", "next item"),
139 )
140 m.keyMap.Previous = key.NewBinding(
141 key.WithKeys("up", "ctrl+p"),
142 key.WithHelp("↑", "previous item"),
143 )
144 m.keyMap.Close = CloseKey
145
146 m.providers = slices.Collect(
147 xslice.Map(
148 com.Config().Providers.Seq(),
149 func(pc config.ProviderConfig) catwalk.Provider {
150 return pc.ToProvider()
151 },
152 ),
153 )
154 if err := m.setProviderItems(); err != nil {
155 return nil, fmt.Errorf("failed to set provider items: %w", err)
156 }
157
158 return m, nil
159}
160
161// ID implements Dialog.
162func (m *Models) ID() string {
163 return ModelsID
164}
165
166// HandleMsg implements Dialog.
167func (m *Models) HandleMsg(msg tea.Msg) Action {
168 switch msg := msg.(type) {
169 case tea.KeyPressMsg:
170 switch {
171 case key.Matches(msg, m.keyMap.Close):
172 return ActionClose{}
173 case key.Matches(msg, m.keyMap.Previous):
174 m.list.Focus()
175 if m.list.IsSelectedFirst() {
176 m.list.SelectLast()
177 m.list.ScrollToBottom()
178 break
179 }
180 m.list.SelectPrev()
181 m.list.ScrollToSelected()
182 case key.Matches(msg, m.keyMap.Next):
183 m.list.Focus()
184 if m.list.IsSelectedLast() {
185 m.list.SelectFirst()
186 m.list.ScrollToTop()
187 break
188 }
189 m.list.SelectNext()
190 m.list.ScrollToSelected()
191 case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
192 selectedItem := m.list.SelectedItem()
193 if selectedItem == nil {
194 break
195 }
196
197 modelItem, ok := selectedItem.(*ModelItem)
198 if !ok {
199 break
200 }
201
202 isEdit := key.Matches(msg, m.keyMap.Edit)
203
204 return ActionSelectModel{
205 Provider: modelItem.prov,
206 Model: modelItem.SelectedModel(),
207 ModelType: modelItem.SelectedModelType(),
208 ReAuthenticate: isEdit,
209 }
210 case key.Matches(msg, m.keyMap.Tab):
211 if m.isOnboarding {
212 break
213 }
214 if m.modelType == ModelTypeLarge {
215 m.modelType = ModelTypeSmall
216 } else {
217 m.modelType = ModelTypeLarge
218 }
219 if err := m.setProviderItems(); err != nil {
220 return util.ReportError(err)
221 }
222 default:
223 var cmd tea.Cmd
224 m.input, cmd = m.input.Update(msg)
225 value := m.input.Value()
226 m.list.Focus()
227 m.list.SetFilter(value)
228 m.list.SelectFirst()
229 m.list.ScrollToTop()
230 return ActionCmd{cmd}
231 }
232 }
233 return nil
234}
235
236// Cursor returns the cursor for the dialog.
237func (m *Models) Cursor() *tea.Cursor {
238 return InputCursor(m.com.Styles, m.input.Cursor())
239}
240
241// modelTypeRadioView returns the radio view for model type selection.
242func (m *Models) modelTypeRadioView() string {
243 t := m.com.Styles
244 textStyle := t.HalfMuted
245 largeRadioStyle := t.RadioOff
246 smallRadioStyle := t.RadioOff
247 if m.modelType == ModelTypeLarge {
248 largeRadioStyle = t.RadioOn
249 } else {
250 smallRadioStyle = t.RadioOn
251 }
252
253 largeRadio := largeRadioStyle.Padding(0, 1).Render()
254 smallRadio := smallRadioStyle.Padding(0, 1).Render()
255
256 return fmt.Sprintf("%s%s %s%s",
257 largeRadio, textStyle.Render(ModelTypeLarge.String()),
258 smallRadio, textStyle.Render(ModelTypeSmall.String()))
259}
260
261// Draw implements [Dialog].
262func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
263 t := m.com.Styles
264 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
265 height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
266 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
267 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
268 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
269 t.Dialog.HelpView.GetVerticalFrameSize() +
270 t.Dialog.View.GetVerticalFrameSize()
271
272 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
273 m.list.SetSize(innerWidth, height-heightOffset)
274 m.help.SetWidth(innerWidth)
275
276 rc := NewRenderContext(t, width)
277 rc.Title = "Switch Model"
278 rc.TitleInfo = m.modelTypeRadioView()
279
280 if m.isOnboarding {
281 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
282 rc.AddPart(titleText)
283 }
284
285 inputView := t.Dialog.InputPrompt.Render(m.input.View())
286 rc.AddPart(inputView)
287
288 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
289 rc.AddPart(listView)
290
291 rc.Help = m.help.View(m)
292
293 cur := m.Cursor()
294
295 if m.isOnboarding {
296 rc.Title = ""
297 rc.TitleInfo = ""
298 rc.IsOnboarding = true
299 view := rc.Render()
300 DrawOnboardingCursor(scr, area, view, cur)
301
302 // FIXME(@andreynering): Figure it out how to properly fix this
303 if cur != nil {
304 cur.Y -= 1
305 cur.X -= 1
306 }
307 } else {
308 view := rc.Render()
309 DrawCenterCursor(scr, area, view, cur)
310 }
311 return cur
312}
313
314// ShortHelp returns the short help view.
315func (m *Models) ShortHelp() []key.Binding {
316 if m.isOnboarding {
317 return []key.Binding{
318 m.keyMap.UpDown,
319 m.keyMap.Select,
320 }
321 }
322 h := []key.Binding{
323 m.keyMap.UpDown,
324 m.keyMap.Tab,
325 m.keyMap.Select,
326 }
327 if m.isSelectedConfigured() {
328 h = append(h, m.keyMap.Edit)
329 }
330 h = append(h, m.keyMap.Close)
331 return h
332}
333
334// FullHelp returns the full help view.
335func (m *Models) FullHelp() [][]key.Binding {
336 return [][]key.Binding{m.ShortHelp()}
337}
338
339func (m *Models) isSelectedConfigured() bool {
340 selectedItem := m.list.SelectedItem()
341 if selectedItem == nil {
342 return false
343 }
344 modelItem, ok := selectedItem.(*ModelItem)
345 if !ok {
346 return false
347 }
348 providerID := string(modelItem.prov.ID)
349 _, isConfigured := m.com.Config().Providers.Get(providerID)
350 return isConfigured
351}
352
353// setProviderItems sets the provider items in the list.
354func (m *Models) setProviderItems() error {
355 t := m.com.Styles
356 cfg := m.com.Config()
357
358 var selectedItemID string
359 selectedType := m.modelType.Config()
360 currentModel := cfg.Models[selectedType]
361 recentItems := cfg.RecentModels[selectedType]
362
363 // Track providers already added to avoid duplicates
364 addedProviders := make(map[string]bool)
365
366 // Get a list of known providers to compare against
367 knownProviders, err := config.Providers(cfg)
368 if err != nil {
369 return fmt.Errorf("failed to get providers: %w", err)
370 }
371
372 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
373 return func(p catwalk.Provider) bool {
374 return p.ID == catwalk.InferenceProvider(id)
375 }
376 }
377
378 // itemsMap contains the keys of added model items.
379 itemsMap := make(map[string]*ModelItem)
380 groups := []ModelGroup{}
381 for id, p := range cfg.Providers.Seq2() {
382 if p.Disable {
383 continue
384 }
385
386 // Check if this provider is not in the known providers list
387 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
388 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
389 provider := p.ToProvider()
390
391 // Add this unknown provider to the list
392 name := cmp.Or(p.Name, id)
393
394 addedProviders[id] = true
395
396 group := NewModelGroup(t, name, true)
397 for _, model := range p.Models {
398 item := NewModelItem(t, provider, model, m.modelType, false)
399 group.AppendItems(item)
400 itemsMap[item.ID()] = item
401 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
402 selectedItemID = item.ID()
403 }
404 }
405 if len(group.Items) > 0 {
406 groups = append(groups, group)
407 }
408 }
409 }
410
411 // Move "Charm Hyper" to first position.
412 // (But still after recent models and custom providers).
413 slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
414 switch {
415 case a.ID == "hyper":
416 return -1
417 case b.ID == "hyper":
418 return 1
419 default:
420 return 0
421 }
422 })
423
424 // Now add known providers from the predefined list
425 for _, provider := range m.providers {
426 providerID := string(provider.ID)
427 if addedProviders[providerID] {
428 continue
429 }
430
431 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
432 if providerConfigured && providerConfig.Disable {
433 continue
434 }
435
436 displayProvider := provider
437 if providerConfigured {
438 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
439 modelIndex := make(map[string]int, len(displayProvider.Models))
440 for i, model := range displayProvider.Models {
441 modelIndex[model.ID] = i
442 }
443 for _, model := range providerConfig.Models {
444 if model.ID == "" {
445 continue
446 }
447 if idx, ok := modelIndex[model.ID]; ok {
448 if model.Name != "" {
449 displayProvider.Models[idx].Name = model.Name
450 }
451 continue
452 }
453 if model.Name == "" {
454 model.Name = model.ID
455 }
456 displayProvider.Models = append(displayProvider.Models, model)
457 modelIndex[model.ID] = len(displayProvider.Models) - 1
458 }
459 }
460
461 name := displayProvider.Name
462 if name == "" {
463 name = providerID
464 }
465
466 group := NewModelGroup(t, name, providerConfigured)
467 for _, model := range displayProvider.Models {
468 item := NewModelItem(t, provider, model, m.modelType, false)
469 group.AppendItems(item)
470 itemsMap[item.ID()] = item
471 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
472 selectedItemID = item.ID()
473 }
474 }
475
476 groups = append(groups, group)
477 }
478
479 if len(recentItems) > 0 {
480 recentGroup := NewModelGroup(t, "Recently used", false)
481
482 var validRecentItems []config.SelectedModel
483 for _, recent := range recentItems {
484 key := modelKey(recent.Provider, recent.Model)
485 item, ok := itemsMap[key]
486 if !ok {
487 continue
488 }
489
490 // Show provider for recent items
491 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
492 item.showProvider = true
493
494 validRecentItems = append(validRecentItems, recent)
495 recentGroup.AppendItems(item)
496 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
497 selectedItemID = item.ID()
498 }
499 }
500
501 if len(validRecentItems) != len(recentItems) {
502 // FIXME: Does this need to be here? Is it mutating the config during a read?
503 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
504 return fmt.Errorf("failed to update recent models: %w", err)
505 }
506 }
507
508 if len(recentGroup.Items) > 0 {
509 groups = append([]ModelGroup{recentGroup}, groups...)
510 }
511 }
512
513 // Set model groups in the list.
514 m.list.SetGroups(groups...)
515 m.list.SetSelectedItem(selectedItemID)
516 m.list.ScrollToTop()
517
518 // Update placeholder based on model type
519 if !m.isOnboarding {
520 m.input.Placeholder = m.modelType.Placeholder()
521 }
522
523 return nil
524}
525
// modelKey builds the "provider:model" key used to index model items. It
// returns "" when either part is empty, so incomplete entries never match.
func modelKey(providerID, modelID string) string {
	if len(providerID) > 0 && len(modelID) > 0 {
		return providerID + ":" + modelID
	}
	return ""
}