1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/textinput"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/catwalk/pkg/catwalk"
13 "charm.land/lipgloss/v2"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/ui/util"
17 uv "github.com/charmbracelet/ultraviolet"
18)
19
// ModelType identifies which model slot (large or small task) the dialog is
// choosing a model for.
type ModelType int

const (
	// ModelTypeLarge selects the model used for large, complex tasks.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall selects the model used for small, simple tasks.
	ModelTypeSmall
)

// String returns a human-readable label for the [ModelType]. Values outside the
// declared constants are reported as "Unknown".
func (mt ModelType) String() string {
	label := "Unknown"
	switch mt {
	case ModelTypeLarge:
		label = "Large Task"
	case ModelTypeSmall:
		label = "Small Task"
	}
	return label
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
64const (
65 onboardingModelInputPlaceholder = "Find your fave"
66 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
67 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
68)
69
const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth is the upper bound Draw applies to the
	// dialog width before fitting it into the available area.
	defaultModelsDialogMaxWidth = 73
)
74
// Models represents a model selection dialog.
//
// It shows a filter input above a grouped list of provider models. Outside
// onboarding it also lets the user toggle between the large and small model
// types.
type Models struct {
	// com provides the shared styles, configuration and workspace handle.
	com *common.Common
	// isOnboarding puts the dialog into first-run mode: no title or type
	// toggle, an introductory line, and the onboarding layout.
	isOnboarding bool

	// modelType is the model slot currently being chosen; Tab toggles it.
	modelType ModelType
	// providers is the known provider list loaded by NewModels via
	// config.Providers.
	providers []catwalk.Provider

	// keyMap holds the dialog's key bindings; NewModels assigns the keys.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Edit     key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list renders the grouped, filterable model items.
	list *ModelsList
	// input is the text field whose value filters list.
	input textinput.Model
	// help renders the key binding hints at the bottom of the dialog.
	help help.Model
}

// Compile-time check that *Models implements Dialog.
var _ Dialog = (*Models)(nil)
98
99// NewModels creates a new Models dialog.
100func NewModels(com *common.Common, isOnboarding bool) (*Models, error) {
101 t := com.Styles
102 m := &Models{}
103 m.com = com
104 m.isOnboarding = isOnboarding
105
106 help := help.New()
107 help.Styles = t.DialogHelpStyles()
108
109 m.help = help
110 m.list = NewModelsList(t)
111 m.list.Focus()
112 m.list.SetSelected(0)
113
114 m.input = textinput.New()
115 m.input.SetVirtualCursor(false)
116 m.input.Placeholder = onboardingModelInputPlaceholder
117 m.input.SetStyles(com.Styles.TextInput)
118 m.input.Focus()
119
120 m.keyMap.Tab = key.NewBinding(
121 key.WithKeys("tab", "shift+tab"),
122 key.WithHelp("tab", "toggle type"),
123 )
124 m.keyMap.Select = key.NewBinding(
125 key.WithKeys("enter", "ctrl+y"),
126 key.WithHelp("enter", "confirm"),
127 )
128 m.keyMap.Edit = key.NewBinding(
129 key.WithKeys("ctrl+e"),
130 key.WithHelp("ctrl+e", "edit"),
131 )
132 m.keyMap.UpDown = key.NewBinding(
133 key.WithKeys("up", "down"),
134 key.WithHelp("↑/↓", "choose"),
135 )
136 m.keyMap.Next = key.NewBinding(
137 key.WithKeys("down", "ctrl+n"),
138 key.WithHelp("↓", "next item"),
139 )
140 m.keyMap.Previous = key.NewBinding(
141 key.WithKeys("up", "ctrl+p"),
142 key.WithHelp("↑", "previous item"),
143 )
144 m.keyMap.Close = CloseKey
145
146 var err error
147 m.providers, err = config.Providers(m.com.Config())
148 if err != nil {
149 return nil, fmt.Errorf("failed to get providers: %w", err)
150 }
151
152 if err := m.setProviderItems(); err != nil {
153 return nil, fmt.Errorf("failed to set provider items: %w", err)
154 }
155
156 return m, nil
157}
158
159// ID implements Dialog.
160func (m *Models) ID() string {
161 return ModelsID
162}
163
164// HandleMsg implements Dialog.
165func (m *Models) HandleMsg(msg tea.Msg) Action {
166 switch msg := msg.(type) {
167 case tea.KeyPressMsg:
168 switch {
169 case key.Matches(msg, m.keyMap.Close):
170 return ActionClose{}
171 case key.Matches(msg, m.keyMap.Previous):
172 m.list.Focus()
173 if m.list.IsSelectedFirst() {
174 m.list.SelectLast()
175 } else {
176 m.list.SelectPrev()
177 }
178 m.list.ScrollToSelected()
179 case key.Matches(msg, m.keyMap.Next):
180 m.list.Focus()
181 if m.list.IsSelectedLast() {
182 m.list.SelectFirst()
183 } else {
184 m.list.SelectNext()
185 }
186 m.list.ScrollToSelected()
187 case key.Matches(msg, m.keyMap.Select, m.keyMap.Edit):
188 selectedItem := m.list.SelectedItem()
189 if selectedItem == nil {
190 break
191 }
192
193 modelItem, ok := selectedItem.(*ModelItem)
194 if !ok {
195 break
196 }
197
198 isEdit := key.Matches(msg, m.keyMap.Edit)
199
200 return ActionSelectModel{
201 Provider: modelItem.prov,
202 Model: modelItem.SelectedModel(),
203 ModelType: modelItem.SelectedModelType(),
204 ReAuthenticate: isEdit,
205 }
206 case key.Matches(msg, m.keyMap.Tab):
207 if m.isOnboarding {
208 break
209 }
210 if m.modelType == ModelTypeLarge {
211 m.modelType = ModelTypeSmall
212 } else {
213 m.modelType = ModelTypeLarge
214 }
215 if err := m.setProviderItems(); err != nil {
216 return util.ReportError(err)
217 }
218 default:
219 var cmd tea.Cmd
220 m.input, cmd = m.input.Update(msg)
221 value := m.input.Value()
222 m.list.Focus()
223 m.list.SetFilter(value)
224 m.list.SelectFirst()
225 m.list.ScrollToTop()
226 return ActionCmd{cmd}
227 }
228 }
229 return nil
230}
231
232// Cursor returns the cursor for the dialog.
233func (m *Models) Cursor() *tea.Cursor {
234 return InputCursor(m.com.Styles, m.input.Cursor())
235}
236
237// modelTypeRadioView returns the radio view for model type selection.
238func (m *Models) modelTypeRadioView() string {
239 t := m.com.Styles
240 textStyle := t.Radio.Label
241 largeRadioStyle := t.Radio.Off
242 smallRadioStyle := t.Radio.Off
243 if m.modelType == ModelTypeLarge {
244 largeRadioStyle = t.Radio.On
245 } else {
246 smallRadioStyle = t.Radio.On
247 }
248
249 largeRadio := largeRadioStyle.Padding(0, 1).Render()
250 smallRadio := smallRadioStyle.Padding(0, 1).Render()
251
252 return fmt.Sprintf("%s%s %s%s",
253 largeRadio, textStyle.Render(ModelTypeLarge.String()),
254 smallRadio, textStyle.Render(ModelTypeSmall.String()))
255}
256
257// Draw implements [Dialog].
258func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
259 t := m.com.Styles
260 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()-t.Dialog.View.GetHorizontalBorderSize()))
261 height := max(0, min(defaultDialogHeight, area.Dy()-t.Dialog.View.GetVerticalBorderSize()))
262 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
263 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
264 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
265 t.Dialog.HelpView.GetVerticalFrameSize() +
266 t.Dialog.View.GetVerticalFrameSize()
267
268 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
269 m.help.SetWidth(innerWidth)
270
271 listHeight := height - heightOffset
272 m.list.SetSize(innerWidth, listHeight)
273 listTotalHeight := m.list.TotalHeight()
274 listWidth := max(0, innerWidth-3) // Reserve space for scrollbar.
275 m.list.SetSize(listWidth, listHeight)
276
277 rc := NewRenderContext(t, width)
278 rc.Title = "Switch Model"
279 rc.TitleInfo = m.modelTypeRadioView()
280
281 if m.isOnboarding {
282 titleText := t.Dialog.PrimaryText.Render("To start, let's choose a provider and model.")
283 rc.AddPart(titleText)
284 }
285
286 inputView := t.Dialog.InputPrompt.Render(m.input.View())
287 rc.AddPart(inputView)
288
289 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
290 scrollbar := common.Scrollbar(t, listHeight, listTotalHeight, listHeight, m.list.Offset())
291 if scrollbar != "" {
292 listView = lipgloss.JoinHorizontal(lipgloss.Top, listView, scrollbar)
293 }
294 rc.AddPart(listView)
295
296 rc.Help = m.help.View(m)
297
298 cur := m.Cursor()
299
300 if m.isOnboarding {
301 rc.Title = ""
302 rc.TitleInfo = ""
303 rc.IsOnboarding = true
304 view := rc.Render()
305 cur = adjustOnboardingInputCursor(t, cur)
306 DrawOnboardingCursor(scr, area, view, cur)
307 } else {
308 view := rc.Render()
309 DrawCenterCursor(scr, area, view, cur)
310 }
311 return cur
312}
313
314// ShortHelp returns the short help view.
315func (m *Models) ShortHelp() []key.Binding {
316 if m.isOnboarding {
317 return []key.Binding{
318 m.keyMap.UpDown,
319 m.keyMap.Select,
320 }
321 }
322 h := []key.Binding{
323 m.keyMap.UpDown,
324 m.keyMap.Tab,
325 m.keyMap.Select,
326 }
327 if m.isSelectedConfigured() {
328 h = append(h, m.keyMap.Edit)
329 }
330 h = append(h, m.keyMap.Close)
331 return h
332}
333
334// FullHelp returns the full help view.
335func (m *Models) FullHelp() [][]key.Binding {
336 return [][]key.Binding{m.ShortHelp()}
337}
338
339func (m *Models) isSelectedConfigured() bool {
340 selectedItem := m.list.SelectedItem()
341 if selectedItem == nil {
342 return false
343 }
344 modelItem, ok := selectedItem.(*ModelItem)
345 if !ok {
346 return false
347 }
348 providerID := string(modelItem.prov.ID)
349 _, isConfigured := m.com.Config().Providers.Get(providerID)
350 return isConfigured
351}
352
353// setProviderItems sets the provider items in the list.
354func (m *Models) setProviderItems() error {
355 t := m.com.Styles
356 cfg := m.com.Config()
357
358 var selectedItemID string
359 selectedType := m.modelType.Config()
360 currentModel := cfg.Models[selectedType]
361 recentItems := cfg.RecentModels[selectedType]
362
363 // Track providers already added to avoid duplicates
364 addedProviders := make(map[string]bool)
365
366 // Get a list of known providers to compare against
367 knownProviders, err := config.Providers(cfg)
368 if err != nil {
369 return fmt.Errorf("failed to get providers: %w", err)
370 }
371
372 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
373 return func(p catwalk.Provider) bool {
374 return p.ID == catwalk.InferenceProvider(id)
375 }
376 }
377
378 // itemsMap contains the keys of added model items.
379 itemsMap := make(map[string]*ModelItem)
380 groups := []ModelGroup{}
381 for id, p := range cfg.Providers.Seq2() {
382 if p.Disable {
383 continue
384 }
385
386 // Check if this provider is not in the known providers list
387 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
388 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
389 provider := p.ToProvider()
390
391 // Add this unknown provider to the list
392 name := cmp.Or(p.Name, id)
393
394 addedProviders[id] = true
395
396 group := NewModelGroup(t, name, true)
397 for _, model := range p.Models {
398 item := NewModelItem(t, provider, model, m.modelType, false)
399 group.AppendItems(item)
400 itemsMap[item.ID()] = item
401 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
402 selectedItemID = item.ID()
403 }
404 }
405 if len(group.Items) > 0 {
406 groups = append(groups, group)
407 }
408 }
409 }
410
411 // Now add known providers from the predefined list.
412 // Providers already has Hyper at the front of the list.
413 for _, provider := range m.providers {
414 providerID := string(provider.ID)
415 if addedProviders[providerID] {
416 continue
417 }
418
419 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
420 if providerConfigured && providerConfig.Disable {
421 continue
422 }
423
424 displayProvider := provider
425 if providerConfigured {
426 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
427 modelIndex := make(map[string]int, len(displayProvider.Models))
428 for i, model := range displayProvider.Models {
429 modelIndex[model.ID] = i
430 }
431 for _, model := range providerConfig.Models {
432 if model.ID == "" {
433 continue
434 }
435 if idx, ok := modelIndex[model.ID]; ok {
436 if model.Name != "" {
437 displayProvider.Models[idx].Name = model.Name
438 }
439 continue
440 }
441 model.Name = cmp.Or(model.Name, model.ID)
442 displayProvider.Models = append(displayProvider.Models, model)
443 modelIndex[model.ID] = len(displayProvider.Models) - 1
444 }
445 }
446
447 name := cmp.Or(displayProvider.Name, providerID)
448
449 group := NewModelGroup(t, name, providerConfigured)
450 for _, model := range displayProvider.Models {
451 item := NewModelItem(t, provider, model, m.modelType, false)
452 group.AppendItems(item)
453 itemsMap[item.ID()] = item
454 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
455 selectedItemID = item.ID()
456 }
457 }
458
459 groups = append(groups, group)
460 }
461
462 if len(recentItems) > 0 {
463 recentGroup := NewModelGroup(t, "Recently used", false)
464
465 var validRecentItems []config.SelectedModel
466 for _, recent := range recentItems {
467 key := modelKey(recent.Provider, recent.Model)
468 item, ok := itemsMap[key]
469 if !ok {
470 continue
471 }
472
473 // Show provider for recent items
474 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
475 item.showProvider = true
476
477 validRecentItems = append(validRecentItems, recent)
478 recentGroup.AppendItems(item)
479 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
480 selectedItemID = item.ID()
481 }
482 }
483
484 if len(validRecentItems) != len(recentItems) {
485 // FIXME: Does this need to be here? Is it mutating the config during a read?
486 if err := m.com.Workspace.SetConfigField(config.ScopeGlobal, fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
487 return fmt.Errorf("failed to update recent models: %w", err)
488 }
489 }
490
491 if len(recentGroup.Items) > 0 {
492 groups = append([]ModelGroup{recentGroup}, groups...)
493 }
494 }
495
496 // Set model groups in the list.
497 m.list.SetGroups(groups...)
498 m.list.SetSelectedItem(selectedItemID)
499 if selectedItemID != "" {
500 m.list.ScrollToSelected()
501 } else {
502 m.list.ScrollToTop()
503 }
504
505 // Update placeholder based on model type
506 if !m.isOnboarding {
507 m.input.Placeholder = m.modelType.Placeholder()
508 }
509
510 return nil
511}
512
// modelKey builds the "provider:model" lookup key for a model item. It returns
// "" when either part is empty, so incomplete entries never match.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}