1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/uiutil"
17 uv "github.com/charmbracelet/ultraviolet"
18)
19
// ModelType identifies which model slot (large-task or small-task) the
// model selection dialog is currently configuring.
type ModelType int

const (
	// ModelTypeLarge is the slot for large, complex tasks. It is the zero
	// value of ModelType.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall is the slot for small, simple tasks.
	ModelTypeSmall
)

// String returns the human-readable label of mt: "Large Task",
// "Small Task", or "Unknown" for any value outside the defined constants.
func (mt ModelType) String() string {
	switch mt {
	case ModelTypeSmall:
		return "Small Task"
	case ModelTypeLarge:
		return "Large Task"
	}
	return "Unknown"
}
39
40// Config returns the corresponding config model type.
41func (mt ModelType) Config() config.SelectedModelType {
42 switch mt {
43 case ModelTypeLarge:
44 return config.SelectedModelTypeLarge
45 case ModelTypeSmall:
46 return config.SelectedModelTypeSmall
47 default:
48 return ""
49 }
50}
51
52// Placeholder returns the input placeholder for the model type.
53func (mt ModelType) Placeholder() string {
54 switch mt {
55 case ModelTypeLarge:
56 return largeModelInputPlaceholder
57 case ModelTypeSmall:
58 return smallModelInputPlaceholder
59 default:
60 return ""
61 }
62}
63
64const (
65 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
66 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
67)
68
const (
	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"

	// defaultModelsDialogMaxWidth is the widest the dialog is drawn, in
	// columns; Draw uses the smaller of this and the available area width.
	defaultModelsDialogMaxWidth = 70
)
73
74// Models represents a model selection dialog.
75type Models struct {
76 com *common.Common
77
78 modelType ModelType
79 providers []catwalk.Provider
80
81 keyMap struct {
82 Tab key.Binding
83 UpDown key.Binding
84 Select key.Binding
85 Next key.Binding
86 Previous key.Binding
87 Close key.Binding
88 }
89 list *ModelsList
90 input textinput.Model
91 help help.Model
92}
93
94var _ Dialog = (*Models)(nil)
95
96// NewModels creates a new Models dialog.
97func NewModels(com *common.Common) (*Models, error) {
98 t := com.Styles
99 m := &Models{}
100 m.com = com
101 help := help.New()
102 help.Styles = t.DialogHelpStyles()
103
104 m.help = help
105 m.list = NewModelsList(t)
106 m.list.Focus()
107 m.list.SetSelected(0)
108
109 m.input = textinput.New()
110 m.input.SetVirtualCursor(false)
111 m.input.Placeholder = largeModelInputPlaceholder
112 m.input.SetStyles(com.Styles.TextInput)
113 m.input.Focus()
114
115 m.keyMap.Tab = key.NewBinding(
116 key.WithKeys("tab", "shift+tab"),
117 key.WithHelp("tab", "toggle type"),
118 )
119 m.keyMap.Select = key.NewBinding(
120 key.WithKeys("enter", "ctrl+y"),
121 key.WithHelp("enter", "confirm"),
122 )
123 m.keyMap.UpDown = key.NewBinding(
124 key.WithKeys("up", "down"),
125 key.WithHelp("↑/↓", "choose"),
126 )
127 m.keyMap.Next = key.NewBinding(
128 key.WithKeys("down", "ctrl+n"),
129 key.WithHelp("↓", "next item"),
130 )
131 m.keyMap.Previous = key.NewBinding(
132 key.WithKeys("up", "ctrl+p"),
133 key.WithHelp("↑", "previous item"),
134 )
135 m.keyMap.Close = CloseKey
136
137 providers, err := getFilteredProviders(com.Config())
138 if err != nil {
139 return nil, fmt.Errorf("failed to get providers: %w", err)
140 }
141
142 m.providers = providers
143 if err := m.setProviderItems(); err != nil {
144 return nil, fmt.Errorf("failed to set provider items: %w", err)
145 }
146
147 return m, nil
148}
149
150// ID implements Dialog.
151func (m *Models) ID() string {
152 return ModelsID
153}
154
155// HandleMsg implements Dialog.
156func (m *Models) HandleMsg(msg tea.Msg) Action {
157 switch msg := msg.(type) {
158 case tea.KeyPressMsg:
159 switch {
160 case key.Matches(msg, m.keyMap.Close):
161 return ActionClose{}
162 case key.Matches(msg, m.keyMap.Previous):
163 m.list.Focus()
164 if m.list.IsSelectedFirst() {
165 m.list.SelectLast()
166 m.list.ScrollToBottom()
167 break
168 }
169 m.list.SelectPrev()
170 m.list.ScrollToSelected()
171 case key.Matches(msg, m.keyMap.Next):
172 m.list.Focus()
173 if m.list.IsSelectedLast() {
174 m.list.SelectFirst()
175 m.list.ScrollToTop()
176 break
177 }
178 m.list.SelectNext()
179 m.list.ScrollToSelected()
180 case key.Matches(msg, m.keyMap.Select):
181 selectedItem := m.list.SelectedItem()
182 if selectedItem == nil {
183 break
184 }
185
186 modelItem, ok := selectedItem.(*ModelItem)
187 if !ok {
188 break
189 }
190
191 return ActionSelectModel{
192 Provider: modelItem.prov,
193 Model: modelItem.SelectedModel(),
194 ModelType: modelItem.SelectedModelType(),
195 }
196 case key.Matches(msg, m.keyMap.Tab):
197 if m.modelType == ModelTypeLarge {
198 m.modelType = ModelTypeSmall
199 } else {
200 m.modelType = ModelTypeLarge
201 }
202 if err := m.setProviderItems(); err != nil {
203 return uiutil.ReportError(err)
204 }
205 default:
206 var cmd tea.Cmd
207 m.input, cmd = m.input.Update(msg)
208 value := m.input.Value()
209 m.list.SetFilter(value)
210 m.list.ScrollToSelected()
211 return ActionCmd{cmd}
212 }
213 }
214 return nil
215}
216
217// Cursor returns the cursor for the dialog.
218func (m *Models) Cursor() *tea.Cursor {
219 return InputCursor(m.com.Styles, m.input.Cursor())
220}
221
222// modelTypeRadioView returns the radio view for model type selection.
223func (m *Models) modelTypeRadioView() string {
224 t := m.com.Styles
225 textStyle := t.HalfMuted
226 largeRadioStyle := t.RadioOff
227 smallRadioStyle := t.RadioOff
228 if m.modelType == ModelTypeLarge {
229 largeRadioStyle = t.RadioOn
230 } else {
231 smallRadioStyle = t.RadioOn
232 }
233
234 largeRadio := largeRadioStyle.Padding(0, 1).Render()
235 smallRadio := smallRadioStyle.Padding(0, 1).Render()
236
237 return fmt.Sprintf("%s%s %s%s",
238 largeRadio, textStyle.Render(ModelTypeLarge.String()),
239 smallRadio, textStyle.Render(ModelTypeSmall.String()))
240}
241
242// Draw implements [Dialog].
243func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
244 t := m.com.Styles
245 width := max(0, min(defaultModelsDialogMaxWidth, area.Dx()))
246 height := max(0, min(defaultDialogHeight, area.Dy()))
247 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
248 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + titleContentHeight +
249 t.Dialog.InputPrompt.GetVerticalFrameSize() + inputContentHeight +
250 t.Dialog.HelpView.GetVerticalFrameSize() +
251 t.Dialog.View.GetVerticalFrameSize()
252 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
253 m.list.SetSize(innerWidth, height-heightOffset)
254 m.help.SetWidth(innerWidth)
255
256 rc := NewRenderContext(t, width)
257 rc.Title = "Switch Model"
258 rc.TitleInfo = m.modelTypeRadioView()
259 inputView := t.Dialog.InputPrompt.Render(m.input.View())
260 rc.AddPart(inputView)
261 listView := t.Dialog.List.Height(m.list.Height()).Render(m.list.Render())
262 rc.AddPart(listView)
263 rc.Help = m.help.View(m)
264
265 view := rc.Render()
266
267 cur := m.Cursor()
268 DrawCenterCursor(scr, area, view, cur)
269 return cur
270}
271
272// ShortHelp returns the short help view.
273func (m *Models) ShortHelp() []key.Binding {
274 return []key.Binding{
275 m.keyMap.UpDown,
276 m.keyMap.Tab,
277 m.keyMap.Select,
278 m.keyMap.Close,
279 }
280}
281
282// FullHelp returns the full help view.
283func (m *Models) FullHelp() [][]key.Binding {
284 return [][]key.Binding{
285 {
286 m.keyMap.Select,
287 m.keyMap.Next,
288 m.keyMap.Previous,
289 m.keyMap.Tab,
290 },
291 {
292 m.keyMap.Close,
293 },
294 }
295}
296
297// setProviderItems sets the provider items in the list.
298func (m *Models) setProviderItems() error {
299 t := m.com.Styles
300 cfg := m.com.Config()
301
302 var selectedItemID string
303 selectedType := m.modelType.Config()
304 currentModel := cfg.Models[selectedType]
305 recentItems := cfg.RecentModels[selectedType]
306
307 // Track providers already added to avoid duplicates
308 addedProviders := make(map[string]bool)
309
310 // Get a list of known providers to compare against
311 knownProviders, err := config.Providers(cfg)
312 if err != nil {
313 return fmt.Errorf("failed to get providers: %w", err)
314 }
315
316 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
317 return func(p catwalk.Provider) bool {
318 return p.ID == catwalk.InferenceProvider(id)
319 }
320 }
321
322 // itemsMap contains the keys of added model items.
323 itemsMap := make(map[string]*ModelItem)
324 groups := []ModelGroup{}
325 for id, p := range cfg.Providers.Seq2() {
326 if p.Disable {
327 continue
328 }
329
330 // Check if this provider is not in the known providers list
331 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
332 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
333 provider := p.ToProvider()
334
335 // Add this unknown provider to the list
336 name := cmp.Or(p.Name, id)
337
338 addedProviders[id] = true
339
340 group := NewModelGroup(t, name, true)
341 for _, model := range p.Models {
342 item := NewModelItem(t, provider, model, m.modelType, false)
343 group.AppendItems(item)
344 itemsMap[item.ID()] = item
345 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
346 selectedItemID = item.ID()
347 }
348 }
349 if len(group.Items) > 0 {
350 groups = append(groups, group)
351 }
352 }
353 }
354
355 // Now add known providers from the predefined list
356 for _, provider := range m.providers {
357 providerID := string(provider.ID)
358 if addedProviders[providerID] {
359 continue
360 }
361
362 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
363 if providerConfigured && providerConfig.Disable {
364 continue
365 }
366
367 displayProvider := provider
368 if providerConfigured {
369 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
370 modelIndex := make(map[string]int, len(displayProvider.Models))
371 for i, model := range displayProvider.Models {
372 modelIndex[model.ID] = i
373 }
374 for _, model := range providerConfig.Models {
375 if model.ID == "" {
376 continue
377 }
378 if idx, ok := modelIndex[model.ID]; ok {
379 if model.Name != "" {
380 displayProvider.Models[idx].Name = model.Name
381 }
382 continue
383 }
384 if model.Name == "" {
385 model.Name = model.ID
386 }
387 displayProvider.Models = append(displayProvider.Models, model)
388 modelIndex[model.ID] = len(displayProvider.Models) - 1
389 }
390 }
391
392 name := displayProvider.Name
393 if name == "" {
394 name = providerID
395 }
396
397 group := NewModelGroup(t, name, providerConfigured)
398 for _, model := range displayProvider.Models {
399 item := NewModelItem(t, provider, model, m.modelType, false)
400 group.AppendItems(item)
401 itemsMap[item.ID()] = item
402 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
403 selectedItemID = item.ID()
404 }
405 }
406
407 groups = append(groups, group)
408 }
409
410 if len(recentItems) > 0 {
411 recentGroup := NewModelGroup(t, "Recently used", false)
412
413 var validRecentItems []config.SelectedModel
414 for _, recent := range recentItems {
415 key := modelKey(recent.Provider, recent.Model)
416 item, ok := itemsMap[key]
417 if !ok {
418 continue
419 }
420
421 // Show provider for recent items
422 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
423 item.showProvider = true
424
425 validRecentItems = append(validRecentItems, recent)
426 recentGroup.AppendItems(item)
427 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
428 selectedItemID = item.ID()
429 }
430 }
431
432 if len(validRecentItems) != len(recentItems) {
433 // FIXME: Does this need to be here? Is it mutating the config during a read?
434 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
435 return fmt.Errorf("failed to update recent models: %w", err)
436 }
437 }
438
439 if len(recentGroup.Items) > 0 {
440 groups = append([]ModelGroup{recentGroup}, groups...)
441 }
442 }
443
444 // Set model groups in the list.
445 m.list.SetGroups(groups...)
446 m.list.SetSelectedItem(selectedItemID)
447 m.list.ScrollToSelected()
448
449 // Update placeholder based on model type
450 m.input.Placeholder = m.modelType.Placeholder()
451
452 return nil
453}
454
455func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
456 providers, err := config.Providers(cfg)
457 if err != nil {
458 return nil, fmt.Errorf("failed to get providers: %w", err)
459 }
460 var filteredProviders []catwalk.Provider
461 for _, p := range providers {
462 var (
463 isAzure = p.ID == catwalk.InferenceProviderAzure
464 isCopilot = p.ID == catwalk.InferenceProviderCopilot
465 isHyper = string(p.ID) == "hyper"
466 hasAPIKeyEnv = strings.HasPrefix(p.APIKey, "$")
467 _, isConfigured = cfg.Providers.Get(string(p.ID))
468 )
469 if isAzure || isCopilot || isHyper || hasAPIKeyEnv || isConfigured {
470 filteredProviders = append(filteredProviders, p)
471 }
472 }
473 return filteredProviders, nil
474}
475
// modelKey builds the "<provider>:<model>" key used to match recent model
// entries against listed items. It returns "" when either ID is empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}