1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/textinput"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/ui/common"
15 "github.com/charmbracelet/crush/internal/ui/util"
16 uv "github.com/charmbracelet/ultraviolet"
17)
18
// ModelType selects which model slot the dialog is configuring.
type ModelType int

const (
	// ModelTypeLarge is the slot used for large, complex tasks.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall is the slot used for small, simple tasks.
	ModelTypeSmall
)

// String returns a human-readable label for mt, or "Unknown" for values
// outside the declared constants.
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
38
39// Config returns the corresponding config model type.
40func (mt ModelType) Config() config.SelectedModelType {
41 switch mt {
42 case ModelTypeLarge:
43 return config.SelectedModelTypeLarge
44 case ModelTypeSmall:
45 return config.SelectedModelTypeSmall
46 default:
47 return ""
48 }
49}
50
51// Placeholder returns the input placeholder for the model type.
52func (mt ModelType) Placeholder() string {
53 switch mt {
54 case ModelTypeLarge:
55 return largeModelInputPlaceholder
56 case ModelTypeSmall:
57 return smallModelInputPlaceholder
58 default:
59 return ""
60 }
61}
62
// Placeholder texts for the model filter input. The onboarding text is set
// by NewModels; the large/small texts are returned by ModelType.Placeholder
// and applied by setProviderItems when the dialog is not onboarding.
const (
	onboardingModelInputPlaceholder = "Find your fave"
	largeModelInputPlaceholder      = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder      = "Choose a model for small, simple tasks"
)
68
// ModelsID is the identifier for the model selection dialog; it is what
// (*Models).ID returns.
const ModelsID = "models"

// defaultModelsDialogMaxWidth is the upper bound on the dialog's width.
// Draw shrinks it further when the available area is narrower.
const defaultModelsDialogMaxWidth = 73
73
// Models represents a model selection dialog: a filter input above a list of
// models grouped by provider, used to pick the model for the large or small
// task slot.
type Models struct {
	com *common.Common
	// isOnboarding switches the dialog into onboarding mode: the title and
	// model-type toggle are hidden, Tab is ignored, and the help line is
	// reduced to navigation and confirm.
	isOnboarding bool

	// modelType is the slot currently being configured; Tab toggles it.
	modelType ModelType
	// providers is the provider catalog loaded by NewModels via
	// config.Providers.
	providers []catwalk.Provider

	// keyMap holds the key bindings handled by HandleMsg and listed in the
	// help view.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list shows the grouped, filterable model items.
	list *ModelsList
	// input is the filter text input; its value is fed to list.SetFilter.
	input textinput.Model
	// help renders the key binding hints at the bottom of the dialog.
	help help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
97
98// NewModels creates a new Models dialog.
99func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
100 t := com.Styles
101 m := &Models{}
102 m.com = com
103 m.isOnboarding = isOnboarding
104
105 help := help.New()
106 help.Styles = t.DialogHelpStyles()
107
108 m.help = help
109 m.list = NewModelsList(t)
110 m.list.Focus()
111 m.list.SetSelected(0)
112
113 m.input = textinput.New()
114 m.input.SetVirtualCursor(false)
115 m.input.Placeholder = onboardingModelInputPlaceholder
116 m.input.SetStyles(com.Styles.TextInput)
117 m.input.Focus()
118
119 m.keyMap.Tab = key.NewBinding(
120 key.WithKeys("tab", "shift+tab"),
121 key.WithHelp("tab", "toggle type"),
122 )
123 m.keyMap.Select = key.NewBinding(
124 key.WithKeys("enter", "ctrl+y"),
125 key.WithHelp("enter", "confirm"),
126 )
127 m.keyMap.Edit = key.NewBinding(
128 key.WithKeys("ctrl+e"),
129 key.WithHelp("ctrl+e", "edit"),
130 )
131 m.keyMap.UpDown = key.NewBinding(
132 key.WithKeys("up", "down"),
133 key.WithHelp("↑/↓", "choose"),
134 )
135 m.keyMap.Next = key.NewBinding(
136 key.WithKeys("down", "ctrl+n"),
137 key.WithHelp("↓", "next item"),
138 )
139 m.keyMap.Previous = key.NewBinding(
140 key.WithKeys("up", "ctrl+p"),
141 key.WithHelp("↑", "previous item"),
142 )
143 m.keyMap.Close = CloseKey
144
145 var err error
146 m.providers, err = config.Providers(m.com.Config())
147 if err != nil {
148 return nil, fmt.Errorf("failed to get providers: %w", err)
149 }
150
151 if err := m.setProviderItems(); err != nil {
152 return nil, fmt.Errorf("failed to set provider items: %w", err)
153 }
154
155 return m, nil
156}
157
158// ID implements Dialog.
159func (m *Models) ID() string {
160 return ModelsID
161}
162
163// HandleMsg implements Dialog.
164func (m *Models) HandleMsg(msg tea.Msg) Action {
165 switch msg := msg.(type) {
166 case tea.KeyPressMsg:
167 switch {
168 case key.Matches(msg, m.keyMap.Close):
169 return ActionClose{}
170 case key.Matches(msg, m.keyMap.Previous):
171 m.list.Focus()
172 if m.list.IsSelectedFirst() {
173 m.list.SelectLast()
174 } else {
175 m.list.SelectPrev()
176 }
177 m.list.ScrollToSelected()
178 case key.Matches(msg, m.keyMap.Next):
179 m.list.Focus()
180 if m.list.IsSelectedLast() {
181 m.list.SelectFirst()
182 } else {
183 m.list.SelectNext()
184 }
185 m.list.ScrollToSelected()
186 case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
187 selectedItem := m.list.SelectedItem()
188 if selectedItem == nil {
189 break
190 }
191
192 modelItem, ok := selectedItem.(*ModelItem)
193 if !ok {
194 break
195 }
196
197 isEdit := key.Matches(msg, m.keyMap.Edit)
198
199 return ActionSelectModel{
200 Provider: modelItem.prov,
201 Model: modelItem.SelectedModel(),
202 ModelType: modelItem.SelectedModelType(),
203 ReAuthenticate: isEdit,
204 }
205 case key.Matches(msg, m.keyMap.Tab):
206 if m.isOnboarding {
207 break
208 }
209 if m.modelType == ModelTypeLarge {
210 m.modelType = ModelTypeSmall
211 } else {
212 m.modelType = ModelTypeLarge
213 }
214 if err := m.setProviderItems(); err != nil {
215 return util.ReportError(err)
216 }
217 default:
218 var cmd tea.Cmd
219 m.input, cmd = m.input.Update(msg)
220 value := m.input.Value()
221 m.list.Focus()
222 m.list.SetFilter(value)
223 m.list.SelectFirst()
224 m.list.ScrollToTop()
225 return ActionCmd{cmd}
226 }
227 }
228 return nil
229}
230
231// Cursor returns the cursor for the dialog.
232func (m *Models) Cursor() *tea.Cursor {
233 return InputCursor(m.com.Styles, m.input.Cursor())
234}
235
// modelTypeRadioView renders the "( ) Large Task ( ) Small Task" toggle shown
// in the title, with the radio for the current m.modelType switched on.
func (m *Models) modelTypeRadioView() string {
	styles := m.com.Styles
	radio := func(on bool) string {
		style := styles.RadioOff
		if on {
			style = styles.RadioOn
		}
		return style.Padding(0, 1).Render()
	}

	label := styles.HalfMuted
	largeOn := m.modelType == ModelTypeLarge
	return fmt.Sprintf("%s%s %s%s",
		radio(largeOn), label.Render(ModelTypeLarge.String()),
		radio(!largeOn), label.Render(ModelTypeSmall.String()))
}
255
256// Draw implements [Dialog].
257func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
258 t := m.com.Styles
259 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
260 height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
261 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
262 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
263 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
264 t.Dialog.HelpView.GetVerticalFrameSize() +
265 t.Dialog.View.GetVerticalFrameSize()
266
267 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
268 m.list.SetSize(innerWidth, height-heightOffset)
269 m.help.SetWidth(innerWidth)
270
271 rc := NewRenderContext(t, width)
272 rc.Title = "Switch Model"
273 rc.TitleInfo = m.modelTypeRadioView()
274
275 if m.isOnboarding {
276 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
277 rc.AddPart(titleText)
278 }
279
280 inputView := t.Dialog.InputPrompt.Render(m.input.View())
281 rc.AddPart(inputView)
282
283 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
284 rc.AddPart(listView)
285
286 rc.Help = m.help.View(m)
287
288 cur := m.Cursor()
289
290 if m.isOnboarding {
291 rc.Title = ""
292 rc.TitleInfo = ""
293 rc.IsOnboarding = true
294 view := rc.Render()
295 DrawOnboardingCursor(scr, area, view, cur)
296
297 // FIXME(@andreynering): Figure it out how to properly fix this
298 if cur != nil {
299 cur.Y -= 1
300 cur.X -= 1
301 }
302 } else {
303 view := rc.Render()
304 DrawCenterCursor(scr, area, view, cur)
305 }
306 return cur
307}
308
309// ShortHelp returns the short help view.
310func (m *Models) ShortHelp() []key.Binding {
311 if m.isOnboarding {
312 return []key.Binding{
313 m.keyMap.UpDown,
314 m.keyMap.Select,
315 }
316 }
317 h := []key.Binding{
318 m.keyMap.UpDown,
319 m.keyMap.Tab,
320 m.keyMap.Select,
321 }
322 if m.isSelectedConfigured() {
323 h = append(h, m.keyMap.Edit)
324 }
325 h = append(h, m.keyMap.Close)
326 return h
327}
328
329// FullHelp returns the full help view.
330func (m *Models) FullHelp() [][]key.Binding {
331 return [][]key.Binding{m.ShortHelp()}
332}
333
334func (m *Models) isSelectedConfigured() bool {
335 selectedItem := m.list.SelectedItem()
336 if selectedItem == nil {
337 return false
338 }
339 modelItem, ok := selectedItem.(*ModelItem)
340 if !ok {
341 return false
342 }
343 providerID := string(modelItem.prov.ID)
344 _, isConfigured := m.com.Config().Providers.Get(providerID)
345 return isConfigured
346}
347
// setProviderItems rebuilds the list's model groups for the current
// m.modelType and selects the model currently configured for that type.
//
// Groups are assembled in this order (disabled providers are skipped):
//   - providers present in the user's config but missing from the known
//     provider catalogs, built from their config entries;
//   - providers from m.providers, with "hyper" moved to the front. A matching
//     config entry overrides the provider and model display names and adds
//     extra models.
//   - a "Recently used" group, prepended in front of all of the above,
//     holding the recent models for this type that still resolve to a listed
//     model.
//
// Groups without items are omitted. If some recent models no longer resolve,
// the pruned list is written back to the global config. Outside onboarding,
// the input placeholder is updated to match the model type.
func (m *Models) setProviderItems() error {
	t := m.com.Styles
	cfg := m.com.Config()

	var selectedItemID string
	selectedType := m.modelType.Config()
	currentModel := cfg.Models[selectedType]
	recentItems := cfg.RecentModels[selectedType]

	// Track providers already added to avoid duplicates.
	addedProviders := make(map[string]bool)

	// Get a list of known providers to compare against.
	knownProviders, err := config.Providers(cfg)
	if err != nil {
		return fmt.Errorf("failed to get providers: %w", err)
	}

	containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
		return func(p catwalk.Provider) bool {
			return p.ID == catwalk.InferenceProvider(id)
		}
	}

	// itemsMap indexes every added model item by ID so recent entries can be
	// resolved against what is actually listed.
	itemsMap := make(map[string]*ModelItem)
	groups := []ModelGroup{}
	for id, p := range cfg.Providers.Seq2() {
		if p.Disable {
			continue
		}

		// Only providers unknown to either catalog are handled here; known
		// providers are added from m.providers below.
		if slices.ContainsFunc(knownProviders, containsProviderFunc(id)) &&
			slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
			continue
		}

		provider := p.ToProvider()
		name := cmp.Or(p.Name, id)
		addedProviders[id] = true

		group := NewModelGroup(t, name, true)
		for _, model := range p.Models {
			item := NewModelItem(t, provider, model, m.modelType, false)
			group.AppendItems(item)
			itemsMap[item.ID()] = item
			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
				selectedItemID = item.ID()
			}
		}
		if len(group.Items) > 0 {
			groups = append(groups, group)
		}
	}

	// Move "Charm Hyper" to first position.
	// (But still after recent models and custom providers).
	// The comparator returns 0 when both or neither are "hyper", so the stable
	// sort keeps the remaining order and the comparison stays consistent.
	slices.SortStableFunc(m.providers, func(a, b catwalk.Provider) int {
		aHyper, bHyper := a.ID == "hyper", b.ID == "hyper"
		switch {
		case aHyper == bHyper:
			return 0
		case aHyper:
			return -1
		default:
			return 1
		}
	})

	// Now add known providers from the predefined list.
	for _, provider := range m.providers {
		providerID := string(provider.ID)
		if addedProviders[providerID] {
			continue
		}

		providerConfig, providerConfigured := cfg.Providers.Get(providerID)
		if providerConfigured && providerConfig.Disable {
			continue
		}

		displayProvider := provider
		if providerConfigured {
			displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
			// displayProvider is a shallow copy, so its Models slice shares a
			// backing array with m.providers (and with whatever
			// config.Providers returned). Clone it before renaming entries or
			// appending, so overrides don't leak back into the shared catalog.
			displayProvider.Models = slices.Clone(provider.Models)
			modelIndex := make(map[string]int, len(displayProvider.Models))
			for i, model := range displayProvider.Models {
				modelIndex[model.ID] = i
			}
			for _, model := range providerConfig.Models {
				if model.ID == "" {
					continue
				}
				if idx, ok := modelIndex[model.ID]; ok {
					if model.Name != "" {
						displayProvider.Models[idx].Name = model.Name
					}
					continue
				}
				model.Name = cmp.Or(model.Name, model.ID)
				displayProvider.Models = append(displayProvider.Models, model)
				modelIndex[model.ID] = len(displayProvider.Models) - 1
			}
		}

		name := cmp.Or(displayProvider.Name, providerID)

		group := NewModelGroup(t, name, providerConfigured)
		for _, model := range displayProvider.Models {
			item := NewModelItem(t, provider, model, m.modelType, false)
			group.AppendItems(item)
			itemsMap[item.ID()] = item
			if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
				selectedItemID = item.ID()
			}
		}

		// Skip empty groups, as for custom providers above: a header with no
		// selectable items is dead UI.
		if len(group.Items) > 0 {
			groups = append(groups, group)
		}
	}

	if len(recentItems) > 0 {
		recentGroup := NewModelGroup(t, "Recently used", false)

		// Kept non-nil so that pruning every entry persists an empty list.
		// A nil slice would presumably be written as null if the config
		// writer serializes to JSON.
		validRecentItems := make([]config.SelectedModel, 0, len(recentItems))
		for _, recent := range recentItems {
			item, ok := itemsMap[modelKey(recent.Provider, recent.Model)]
			if !ok {
				continue
			}

			// Show provider for recent items.
			item = NewModelItem(t, item.prov, item.model, m.modelType, true)
			item.showProvider = true

			validRecentItems = append(validRecentItems, recent)
			recentGroup.AppendItems(item)
			if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
				selectedItemID = item.ID()
			}
		}

		if len(validRecentItems) != len(recentItems) {
			// FIXME: Does this need to be here? Is it mutating the config during a read?
			if err := m.com.Workspace.SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
				return fmt.Errorf("failed to update recent models: %w", err)
			}
		}

		if len(recentGroup.Items) > 0 {
			groups = append([]ModelGroup{recentGroup}, groups...)
		}
	}

	// Set model groups in the list.
	m.list.SetGroups(groups...)
	m.list.SetSelectedItem(selectedItemID)
	m.list.ScrollToTop()

	// Update placeholder based on model type.
	if !m.isOnboarding {
		m.input.Placeholder = m.modelType.Placeholder()
	}

	return nil
}
515
// modelKey builds the "provider:model" lookup key for a model item. It
// returns "" when either part is empty, so incomplete selections never match
// a listed item.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}