1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "charm.land/lipgloss/v2"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/ui/common"
17 "github.com/charmbracelet/crush/internal/uiutil"
18 uv "github.com/charmbracelet/ultraviolet"
19 "github.com/charmbracelet/x/ansi"
20)
21
// ModelType identifies which model slot the dialog is choosing a model for.
type ModelType int

// Model types, in the order the header's radio selector lists them.
const (
	ModelTypeLarge ModelType = iota // model for large, complex tasks
	ModelTypeSmall                  // model for small, simple tasks
)

// String returns the human-readable label of the [ModelType], or "Unknown"
// for any value outside the defined constants.
func (mt ModelType) String() string {
	labels := [...]string{
		ModelTypeLarge: "Large Task",
		ModelTypeSmall: "Small Task",
	}
	if mt < 0 || int(mt) >= len(labels) {
		return "Unknown"
	}
	return labels[mt]
}
41
42// Config returns the corresponding config model type.
43func (mt ModelType) Config() config.SelectedModelType {
44 switch mt {
45 case ModelTypeLarge:
46 return config.SelectedModelTypeLarge
47 case ModelTypeSmall:
48 return config.SelectedModelTypeSmall
49 default:
50 return ""
51 }
52}
53
54// Placeholder returns the input placeholder for the model type.
55func (mt ModelType) Placeholder() string {
56 switch mt {
57 case ModelTypeLarge:
58 return largeModelInputPlaceholder
59 case ModelTypeSmall:
60 return smallModelInputPlaceholder
61 default:
62 return ""
63 }
64}
65
66const (
67 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
68 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
69)
70
// ModelsID is the identifier for the model selection dialog.
const ModelsID = "models"

// Models represents a model selection dialog.
//
// It combines a filter text input with a grouped list of models and a
// radio selector that toggles between the large- and small-task model types.
type Models struct {
	com *common.Common

	// modelType is the slot currently being chosen; Tab toggles it.
	modelType ModelType
	// providers is the list returned by getFilteredProviders when the dialog
	// was created; setProviderItems builds the catalog groups from it.
	providers []catwalk.Provider

	// keyMap holds the dialog's key bindings. In this file UpDown is only
	// used for help output; navigation itself is matched via Next/Previous.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}
	// list shows the model groups, input holds the filter text applied to the
	// list, and help renders the key binding hints.
	list  *ModelsList
	input textinput.Model
	help  help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
95
96// NewModels creates a new Models dialog.
97func NewModels(com *common.Common) (*Models, error) {
98 t := com.Styles
99 m := &Models{}
100 m.com = com
101 help := help.New()
102 help.Styles = t.DialogHelpStyles()
103
104 m.help = help
105 m.list = NewModelsList(t)
106 m.list.Focus()
107 m.list.SetSelected(0)
108
109 m.input = textinput.New()
110 m.input.SetVirtualCursor(false)
111 m.input.Placeholder = largeModelInputPlaceholder
112 m.input.SetStyles(com.Styles.TextInput)
113 m.input.Focus()
114
115 m.keyMap.Tab = key.NewBinding(
116 key.WithKeys("tab", "shift+tab"),
117 key.WithHelp("tab", "toggle type"),
118 )
119 m.keyMap.Select = key.NewBinding(
120 key.WithKeys("enter", "ctrl+y"),
121 key.WithHelp("enter", "confirm"),
122 )
123 m.keyMap.UpDown = key.NewBinding(
124 key.WithKeys("up", "down"),
125 key.WithHelp("↑/↓", "choose"),
126 )
127 m.keyMap.Next = key.NewBinding(
128 key.WithKeys("down", "ctrl+n"),
129 key.WithHelp("↓", "next item"),
130 )
131 m.keyMap.Previous = key.NewBinding(
132 key.WithKeys("up", "ctrl+p"),
133 key.WithHelp("↑", "previous item"),
134 )
135 m.keyMap.Close = CloseKey
136
137 providers, err := getFilteredProviders(com.Config())
138 if err != nil {
139 return nil, fmt.Errorf("failed to get providers: %w", err)
140 }
141
142 m.providers = providers
143 if err := m.setProviderItems(); err != nil {
144 return nil, fmt.Errorf("failed to set provider items: %w", err)
145 }
146
147 return m, nil
148}
149
150// ID implements Dialog.
151func (m *Models) ID() string {
152 return ModelsID
153}
154
155// HandleMsg implements Dialog.
156func (m *Models) HandleMsg(msg tea.Msg) Action {
157 switch msg := msg.(type) {
158 case tea.KeyPressMsg:
159 switch {
160 case key.Matches(msg, m.keyMap.Close):
161 return ActionClose{}
162 case key.Matches(msg, m.keyMap.Previous):
163 m.list.Focus()
164 if m.list.IsSelectedFirst() {
165 m.list.SelectLast()
166 m.list.ScrollToBottom()
167 break
168 }
169 m.list.SelectPrev()
170 m.list.ScrollToSelected()
171 case key.Matches(msg, m.keyMap.Next):
172 m.list.Focus()
173 if m.list.IsSelectedLast() {
174 m.list.SelectFirst()
175 m.list.ScrollToTop()
176 break
177 }
178 m.list.SelectNext()
179 m.list.ScrollToSelected()
180 case key.Matches(msg, m.keyMap.Select):
181 selectedItem := m.list.SelectedItem()
182 if selectedItem == nil {
183 break
184 }
185
186 modelItem, ok := selectedItem.(*ModelItem)
187 if !ok {
188 break
189 }
190
191 return ActionSelectModel{
192 Model: modelItem.SelectedModel(),
193 ModelType: modelItem.SelectedModelType(),
194 }
195 case key.Matches(msg, m.keyMap.Tab):
196 if m.modelType == ModelTypeLarge {
197 m.modelType = ModelTypeSmall
198 } else {
199 m.modelType = ModelTypeLarge
200 }
201 if err := m.setProviderItems(); err != nil {
202 return uiutil.ReportError(err)
203 }
204 default:
205 var cmd tea.Cmd
206 m.input, cmd = m.input.Update(msg)
207 value := m.input.Value()
208 m.list.SetFilter(value)
209 m.list.ScrollToSelected()
210 return ActionCmd{cmd}
211 }
212 }
213 return nil
214}
215
216// Cursor returns the cursor for the dialog.
217func (m *Models) Cursor() *tea.Cursor {
218 return InputCursor(m.com.Styles, m.input.Cursor())
219}
220
221// modelTypeRadioView returns the radio view for model type selection.
222func (m *Models) modelTypeRadioView() string {
223 t := m.com.Styles
224 textStyle := t.HalfMuted
225 largeRadioStyle := t.RadioOff
226 smallRadioStyle := t.RadioOff
227 if m.modelType == ModelTypeLarge {
228 largeRadioStyle = t.RadioOn
229 } else {
230 smallRadioStyle = t.RadioOn
231 }
232
233 largeRadio := largeRadioStyle.Padding(0, 1).Render()
234 smallRadio := smallRadioStyle.Padding(0, 1).Render()
235
236 return fmt.Sprintf("%s%s %s%s",
237 largeRadio, textStyle.Render(ModelTypeLarge.String()),
238 smallRadio, textStyle.Render(ModelTypeSmall.String()))
239}
240
241// Draw implements [Dialog].
242func (m *Models) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
243 t := m.com.Styles
244 width := max(0, min(60, area.Dx()))
245 height := max(0, min(30, area.Dy()))
246 // TODO: Why do we need this 2?
247 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize() - 2
248 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
249 t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
250 t.Dialog.HelpView.GetVerticalFrameSize() +
251 // TODO: Why do we need this 2?
252 t.Dialog.View.GetVerticalFrameSize() + 2
253 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
254 m.list.SetSize(innerWidth, height-heightOffset)
255 m.help.SetWidth(innerWidth)
256
257 titleStyle := t.Dialog.Title
258 dialogStyle := t.Dialog.View
259
260 radios := m.modelTypeRadioView()
261
262 headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
263 dialogStyle.GetHorizontalFrameSize()
264
265 header := common.DialogTitle(t, "Switch Model", width-headerOffset) + radios
266
267 helpView := ansi.Truncate(m.help.View(m), innerWidth, "")
268 view := HeaderInputListHelpView(t, width, m.list.Height(), header,
269 m.input.View(), m.list.Render(), helpView)
270
271 cur := m.Cursor()
272 DrawCenterCursor(scr, area, view, cur)
273 return cur
274}
275
276// ShortHelp returns the short help view.
277func (m *Models) ShortHelp() []key.Binding {
278 return []key.Binding{
279 m.keyMap.UpDown,
280 m.keyMap.Tab,
281 m.keyMap.Select,
282 m.keyMap.Close,
283 }
284}
285
286// FullHelp returns the full help view.
287func (m *Models) FullHelp() [][]key.Binding {
288 return [][]key.Binding{
289 {
290 m.keyMap.Select,
291 m.keyMap.Next,
292 m.keyMap.Previous,
293 m.keyMap.Tab,
294 },
295 {
296 m.keyMap.Close,
297 },
298 }
299}
300
301// setProviderItems sets the provider items in the list.
302func (m *Models) setProviderItems() error {
303 t := m.com.Styles
304 cfg := m.com.Config()
305
306 var selectedItemID string
307 selectedType := m.modelType.Config()
308 currentModel := cfg.Models[selectedType]
309 recentItems := cfg.RecentModels[selectedType]
310
311 // Track providers already added to avoid duplicates
312 addedProviders := make(map[string]bool)
313
314 // Get a list of known providers to compare against
315 knownProviders, err := config.Providers(cfg)
316 if err != nil {
317 return fmt.Errorf("failed to get providers: %w", err)
318 }
319
320 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
321 return func(p catwalk.Provider) bool {
322 return p.ID == catwalk.InferenceProvider(id)
323 }
324 }
325
326 // itemsMap contains the keys of added model items.
327 itemsMap := make(map[string]*ModelItem)
328 groups := []ModelGroup{}
329 for id, p := range cfg.Providers.Seq2() {
330 if p.Disable {
331 continue
332 }
333
334 // Check if this provider is not in the known providers list
335 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
336 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
337 provider := p.ToProvider()
338
339 // Add this unknown provider to the list
340 name := cmp.Or(p.Name, id)
341
342 addedProviders[id] = true
343
344 group := NewModelGroup(t, name, true)
345 for _, model := range p.Models {
346 item := NewModelItem(t, provider, model, m.modelType, false)
347 group.AppendItems(item)
348 itemsMap[item.ID()] = item
349 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
350 selectedItemID = item.ID()
351 }
352 }
353 if len(group.Items) > 0 {
354 groups = append(groups, group)
355 }
356 }
357 }
358
359 // Now add known providers from the predefined list
360 for _, provider := range m.providers {
361 providerID := string(provider.ID)
362 if addedProviders[providerID] {
363 continue
364 }
365
366 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
367 if providerConfigured && providerConfig.Disable {
368 continue
369 }
370
371 displayProvider := provider
372 if providerConfigured {
373 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
374 modelIndex := make(map[string]int, len(displayProvider.Models))
375 for i, model := range displayProvider.Models {
376 modelIndex[model.ID] = i
377 }
378 for _, model := range providerConfig.Models {
379 if model.ID == "" {
380 continue
381 }
382 if idx, ok := modelIndex[model.ID]; ok {
383 if model.Name != "" {
384 displayProvider.Models[idx].Name = model.Name
385 }
386 continue
387 }
388 if model.Name == "" {
389 model.Name = model.ID
390 }
391 displayProvider.Models = append(displayProvider.Models, model)
392 modelIndex[model.ID] = len(displayProvider.Models) - 1
393 }
394 }
395
396 name := displayProvider.Name
397 if name == "" {
398 name = providerID
399 }
400
401 group := NewModelGroup(t, name, providerConfigured)
402 for _, model := range displayProvider.Models {
403 item := NewModelItem(t, provider, model, m.modelType, false)
404 group.AppendItems(item)
405 itemsMap[item.ID()] = item
406 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
407 selectedItemID = item.ID()
408 }
409 }
410
411 groups = append(groups, group)
412 }
413
414 if len(recentItems) > 0 {
415 recentGroup := NewModelGroup(t, "Recently used", false)
416
417 var validRecentItems []config.SelectedModel
418 for _, recent := range recentItems {
419 key := modelKey(recent.Provider, recent.Model)
420 item, ok := itemsMap[key]
421 if !ok {
422 continue
423 }
424
425 // Show provider for recent items
426 item = NewModelItem(t, item.prov, item.model, m.modelType, true)
427 item.showProvider = true
428
429 validRecentItems = append(validRecentItems, recent)
430 recentGroup.AppendItems(item)
431 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
432 selectedItemID = item.ID()
433 }
434 }
435
436 if len(validRecentItems) != len(recentItems) {
437 // FIXME: Does this need to be here? Is it mutating the config during a read?
438 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
439 return fmt.Errorf("failed to update recent models: %w", err)
440 }
441 }
442
443 if len(recentGroup.Items) > 0 {
444 groups = append([]ModelGroup{recentGroup}, groups...)
445 }
446 }
447
448 // Set model groups in the list.
449 m.list.SetGroups(groups...)
450 m.list.SetSelectedItem(selectedItemID)
451
452 // Update placeholder based on model type
453 m.input.Placeholder = m.modelType.Placeholder()
454
455 return nil
456}
457
458func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
459 providers, err := config.Providers(cfg)
460 if err != nil {
461 return nil, fmt.Errorf("failed to get providers: %w", err)
462 }
463 filteredProviders := []catwalk.Provider{}
464 for _, p := range providers {
465 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
466 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
467 filteredProviders = append(filteredProviders, p)
468 }
469 }
470 return filteredProviders, nil
471}
472
// modelKey builds the "provider:model" key that setProviderItems uses to match
// recently used models against list items. It returns "" when either part is
// empty.
func modelKey(providerID, modelID string) string {
	if providerID != "" && modelID != "" {
		return providerID + ":" + modelID
	}
	return ""
}