1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/textinput"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/ui/common"
15 "github.com/charmbracelet/crush/internal/ui/util"
16 uv "github.com/charmbracelet/ultraviolet"
17)
18
// ModelType identifies which model slot the dialog is currently configuring.
type ModelType int

const (
	// ModelTypeLarge is the slot for the model used on large, complex tasks.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall is the slot for the model used on small, simple tasks.
	ModelTypeSmall
)

// String returns the display label of the [ModelType]. Values outside the
// declared constants are reported as "Unknown".
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
38
39// Config returns the corresponding config model type.
40func (mt ModelType) Config() config.SelectedModelType {
41 switch mt {
42 case ModelTypeLarge:
43 return config.SelectedModelTypeLarge
44 case ModelTypeSmall:
45 return config.SelectedModelTypeSmall
46 default:
47 return ""
48 }
49}
50
51// Placeholder returns the input placeholder for the model type.
52func (mt ModelType) Placeholder() string {
53 switch mt {
54 case ModelTypeLarge:
55 return largeModelInputPlaceholder
56 case ModelTypeSmall:
57 return smallModelInputPlaceholder
58 default:
59 return ""
60 }
61}
62
// Placeholder texts for the model search input.
const (
	// largeModelInputPlaceholder is shown while choosing the large-task model.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	// smallModelInputPlaceholder is shown while choosing the small-task model.
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
	// onboardingModelInputPlaceholder is the initial placeholder; it is kept
	// during onboarding.
	onboardingModelInputPlaceholder = "Find your fave"
)
68
const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth caps the dialog width; it is compared
	// against the width of the drawing area in Draw.
	defaultModelsDialogMaxWidth = 73
)
73
// Models represents a model selection dialog. It lists the models of the
// enabled providers (plus a "Recently used" group), filters them with a text
// input, and lets the user pick the model for the current [ModelType] slot.
type Models struct {
	com *common.Common
	// isOnboarding switches the dialog to its first-run form: the title and
	// model-type toggle are hidden, Tab is ignored, and the help is reduced.
	isOnboarding bool

	// modelType is the slot (large or small) currently being configured.
	modelType ModelType
	// providers holds the known providers, populated in NewModels from
	// config.Providers.
	providers []catwalk.Provider

	// keyMap holds the dialog's key bindings.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	list  *ModelsList
	input textinput.Model
	help  help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
97
98// NewModels creates a new Models dialog.
99func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
100 t := com.Styles
101 m := &Models{}
102 m.com = com
103 m.isOnboarding = isOnboarding
104
105 help := help.New()
106 help.Styles = t.DialogHelpStyles()
107
108 m.help = help
109 m.list = NewModelsList(t)
110 m.list.Focus()
111 m.list.SetSelected(0)
112
113 m.input = textinput.New()
114 m.input.SetVirtualCursor(false)
115 m.input.Placeholder = onboardingModelInputPlaceholder
116 m.input.SetStyles(com.Styles.TextInput)
117 m.input.Focus()
118
119 m.keyMap.Tab = key.NewBinding(
120 key.WithKeys("tab", "shift+tab"),
121 key.WithHelp("tab", "toggle type"),
122 )
123 m.keyMap.Select = key.NewBinding(
124 key.WithKeys("enter", "ctrl+y"),
125 key.WithHelp("enter", "confirm"),
126 )
127 m.keyMap.Edit = key.NewBinding(
128 key.WithKeys("ctrl+e"),
129 key.WithHelp("ctrl+e", "edit"),
130 )
131 m.keyMap.UpDown = key.NewBinding(
132 key.WithKeys("up", "down"),
133 key.WithHelp("↑/↓", "choose"),
134 )
135 m.keyMap.Next = key.NewBinding(
136 key.WithKeys("down", "ctrl+n"),
137 key.WithHelp("↓", "next item"),
138 )
139 m.keyMap.Previous = key.NewBinding(
140 key.WithKeys("up", "ctrl+p"),
141 key.WithHelp("↑", "previous item"),
142 )
143 m.keyMap.Close = CloseKey
144
145 var err error
146 m.providers, err = config.Providers(m.com.Config())
147 if err != nil {
148 return nil, fmt.Errorf("failed to get providers: %w", err)
149 }
150
151 if err := m.setProviderItems(); err != nil {
152 return nil, fmt.Errorf("failed to set provider items: %w", err)
153 }
154
155 return m, nil
156}
157
158// ID implements Dialog.
159func (m *Models) ID() string {
160 return ModelsID
161}
162
163// HandleMsg implements Dialog.
164func (m *Models) HandleMsg(msg tea.Msg) Action {
165 switch msg := msg.(type) {
166 case tea.KeyPressMsg:
167 switch {
168 case key.Matches(msg, m.keyMap.Close):
169 return ActionClose{}
170 case key.Matches(msg, m.keyMap.Previous):
171 m.list.Focus()
172 if m.list.IsSelectedFirst() {
173 m.list.SelectLast()
174 } else {
175 m.list.SelectPrev()
176 }
177 m.list.ScrollToSelected()
178 case key.Matches(msg, m.keyMap.Next):
179 m.list.Focus()
180 if m.list.IsSelectedLast() {
181 m.list.SelectFirst()
182 } else {
183 m.list.SelectNext()
184 }
185 m.list.ScrollToSelected()
186 case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
187 selectedItem := m.list.SelectedItem()
188 if selectedItem == nil {
189 break
190 }
191
192 modelItem, ok := selectedItem.(*ModelItem)
193 if !ok {
194 break
195 }
196
197 isEdit := key.Matches(msg, m.keyMap.Edit)
198
199 return ActionSelectModel{
200 Provider: modelItem.prov,
201 Model: modelItem.SelectedModel(),
202 ModelType: modelItem.SelectedModelType(),
203 ReAuthenticate: isEdit,
204 }
205 case key.Matches(msg, m.keyMap.Tab):
206 if m.isOnboarding {
207 break
208 }
209 if m.modelType == ModelTypeLarge {
210 m.modelType = ModelTypeSmall
211 } else {
212 m.modelType = ModelTypeLarge
213 }
214 if err := m.setProviderItems(); err != nil {
215 return util.ReportError(err)
216 }
217 default:
218 var cmd tea.Cmd
219 m.input, cmd = m.input.Update(msg)
220 value := m.input.Value()
221 m.list.Focus()
222 m.list.SetFilter(value)
223 m.list.SelectFirst()
224 m.list.ScrollToTop()
225 return ActionCmd{cmd}
226 }
227 }
228 return nil
229}
230
231// Cursor returns the cursor for the dialog.
232func (m *Models) Cursor() *tea.Cursor {
233 return InputCursor(m.com.Styles, m.input.Cursor())
234}
235
236// modelTypeRadioView returns the radio view for model type selection.
237func (m *Models) modelTypeRadioView() string {
238 t := m.com.Styles
239 textStyle := t.HalfMuted
240 largeRadioStyle := t.RadioOff
241 smallRadioStyle := t.RadioOff
242 if m.modelType == ModelTypeLarge {
243 largeRadioStyle = t.RadioOn
244 } else {
245 smallRadioStyle = t.RadioOn
246 }
247
248 largeRadio := largeRadioStyle.Padding(0, 1).Render()
249 smallRadio := smallRadioStyle.Padding(0, 1).Render()
250
251 return fmt.Sprintf("%s%s %s%s",
252 largeRadio, textStyle.Render(ModelTypeLarge.String()),
253 smallRadio, textStyle.Render(ModelTypeSmall.String()))
254}
255
256// Draw implements [Dialog].
257func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
258 t := m.com.Styles
259 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
260 height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
261 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
262 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
263 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
264 t.Dialog.HelpView.GetVerticalFrameSize() +
265 t.Dialog.View.GetVerticalFrameSize()
266
267 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
268 m.list.SetSize(innerWidth, height-heightOffset)
269 m.help.SetWidth(innerWidth)
270
271 rc := NewRenderContext(t, width)
272 rc.Title = "Switch Model"
273 rc.TitleInfo = m.modelTypeRadioView()
274
275 if m.isOnboarding {
276 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
277 rc.AddPart(titleText)
278 }
279
280 inputView := t.Dialog.InputPrompt.Render(m.input.View())
281 rc.AddPart(inputView)
282
283 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
284 rc.AddPart(listView)
285
286 rc.Help = m.help.View(m)
287
288 cur := m.Cursor()
289
290 if m.isOnboarding {
291 rc.Title = ""
292 rc.TitleInfo = ""
293 rc.IsOnboarding = true
294 view := rc.Render()
295 cur = adjustOnboardingInputCursor(t, cur)
296 DrawOnboardingCursor(scr, area, view, cur)
297 } else {
298 view := rc.Render()
299 DrawCenterCursor(scr, area, view, cur)
300 }
301 return cur
302}
303
304// ShortHelp returns the short help view.
305func (m *Models) ShortHelp() []key.Binding {
306 if m.isOnboarding {
307 return []key.Binding{
308 m.keyMap.UpDown,
309 m.keyMap.Select,
310 }
311 }
312 h := []key.Binding{
313 m.keyMap.UpDown,
314 m.keyMap.Tab,
315 m.keyMap.Select,
316 }
317 if m.isSelectedConfigured() {
318 h = append(h, m.keyMap.Edit)
319 }
320 h = append(h, m.keyMap.Close)
321 return h
322}
323
324// FullHelp returns the full help view.
325func (m *Models) FullHelp() [][]key.Binding {
326 return [][]key.Binding{m.ShortHelp()}
327}
328
329func (m *Models) isSelectedConfigured() bool {
330 selectedItem := m.list.SelectedItem()
331 if selectedItem == nil {
332 return false
333 }
334 modelItem, ok := selectedItem.(*ModelItem)
335 if !ok {
336 return false
337 }
338 providerID := string(modelItem.prov.ID)
339 _, isConfigured := m.com.Config().Providers.Get(providerID)
340 return isConfigured
341}
342
343// setProviderItems sets the provider items in the list.
344func (m *Models) setProviderItems() error {
345 t := m.com.Styles
346 cfg := m.com.Config()
347
348 var selectedItemID string
349 selectedType := m.modelType.Config()
350 currentModel := cfg.Models[selectedType]
351 recentItems := cfg.RecentModels[selectedType]
352
353 // Track providers already added to avoid duplicates
354 addedProviders := make(map[string]bool)
355
356 // Get a list of known providers to compare against
357 knownProviders, err := config.Providers(cfg)
358 if err != nil {
359 return fmt.Errorf("failed to get providers: %w", err)
360 }
361
362 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
363 return func(p catwalk.Provider) bool {
364 return p.ID == catwalk.InferenceProvider(id)
365 }
366 }
367
368 // itemsMap contains the keys of added model items.
369 itemsMap := make(map[string]*ModelItem)
370 groups := []ModelGroup{}
371 for id, p := range cfg.Providers.Seq2() {
372 if p.Disable {
373 continue
374 }
375
376 // Check if this provider is not in the known providers list
377 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
378 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
379 provider := p.ToProvider()
380
381 // Add this unknown provider to the list
382 name := cmp.Or(p.Name, id)
383
384 addedProviders[id] = true
385
386 group := NewModelGroup(t, name, true)
387 for _, model := range p.Models {
388 item := NewModelItem(t, provider, model, m.modelType, false)
389 group.AppendItems(item)
390 itemsMap[item.ID()] = item
391 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
392 selectedItemID = item.ID()
393 }
394 }
395 if len(group.Items) > 0 {
396 groups = append(groups, group)
397 }
398 }
399 }
400
401 // Move "Charm Hyper" to first position.
402 // (But still after recent models and custom providers).
403 slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
404 switch {
405 case a.ID == "hyper":
406 return -1
407 case b.ID == "hyper":
408 return 1
409 default:
410 return 0
411 }
412 })
413
414 // Now add known providers from the predefined list
415 for _, provider := range m.providers {
416 providerID := string(provider.ID)
417 if addedProviders[providerID] {
418 continue
419 }
420
421 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
422 if providerConfigured && providerConfig.Disable {
423 continue
424 }
425
426 displayProvider := provider
427 if providerConfigured {
428 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
429 modelIndex := make(map[string]int, len(displayProvider.Models))
430 for i, model := range displayProvider.Models {
431 modelIndex[model.ID] = i
432 }
433 for _, model := range providerConfig.Models {
434 if model.ID == "" {
435 continue
436 }
437 if idx, ok := modelIndex[model.ID]; ok {
438 if model.Name != "" {
439 displayProvider.Models[idx].Name = model.Name
440 }
441 continue
442 }
443 model.Name = cmp.Or(model.Name, model.ID)
444 displayProvider.Models = append(displayProvider.Models, model)
445 modelIndex[model.ID] = len(displayProvider.Models) - 1
446 }
447 }
448
449 name := cmp.Or(displayProvider.Name, providerID)
450
451 group := NewModelGroup(t, name, providerConfigured)
452 for _, model := range displayProvider.Models {
453 item := NewModelItem(t, provider, model, m.modelType, false)
454 group.AppendItems(item)
455 itemsMap[item.ID()] = item
456 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
457 selectedItemID = item.ID()
458 }
459 }
460
461 groups = append(groups, group)
462 }
463
464 if len(recentItems) > 0 {
465 recentGroup := NewModelGroup(t, "Recently used", false)
466
467 var validRecentItems []config.SelectedModel
468 for _, recent := range recentItems {
469 key := modelKey(recent.Provider, recent.Model)
470 item, ok := itemsMap[key]
471 if !ok {
472 continue
473 }
474
475 // Show provider for recent items
476 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
477 item.showProvider = true
478
479 validRecentItems = append(validRecentItems, recent)
480 recentGroup.AppendItems(item)
481 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
482 selectedItemID = item.ID()
483 }
484 }
485
486 if len(validRecentItems) != len(recentItems) {
487 // FIXME: Does this need to be here? Is it mutating the config during a read?
488 if err := m.com.Workspace.SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
489 return fmt.Errorf("failed to update recent models: %w", err)
490 }
491 }
492
493 if len(recentGroup.Items) > 0 {
494 groups = append([]ModelGroup{recentGroup}, groups...)
495 }
496 }
497
498 // Set model groups in the list.
499 m.list.SetGroups(groups...)
500 m.list.SetSelectedItem(selectedItemID)
501 m.list.ScrollToTop()
502
503 // Update placeholder based on model type
504 if !m.isOnboarding {
505 m.input.Placeholder = m.modelType.Placeholder()
506 }
507
508 return nil
509}
510
// modelKey builds the "<providerID>:<modelID>" key that setProviderItems uses
// to look up recent models among the list items. It returns "" when either
// part is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}