1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "charm.land/lipgloss/v2"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/ui/common"
17 "github.com/charmbracelet/crush/internal/uiutil"
18)
19
// ModelType identifies which model slot the dialog is choosing a model for.
type ModelType int

// The model slots a user can pick a model for. ModelTypeLarge is the zero
// value.
const (
	ModelTypeLarge ModelType = iota
	ModelTypeSmall
)

// String returns the human-readable label of the [ModelType]. Values other
// than the declared constants yield "Unknown".
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
39
// Placeholder texts for the filter input. The dialog shows the one matching
// the model type that is currently being selected.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
44
// ModelsID is the identifier for the model selection dialog. It is the value
// returned by (*Models).ID.
const ModelsID = "models"
47
// Models represents a model selection dialog.
//
// It is made of a text input used to filter, a grouped list of models and a
// help footer. The dialog can switch between choosing the large-task and the
// small-task model.
type Models struct {
	// com is the shared UI context; the dialog reads its Styles and Config.
	com *common.Common

	// modelType is the model slot currently being chosen. Its zero value is
	// ModelTypeLarge.
	modelType ModelType

	// providers is the provider list produced by getFilteredProviders when
	// the dialog is created.
	providers []catwalk.Provider

	// width and height are the dimensions last passed to SetSize.
	width, height int

	// keyMap holds the key bindings of the dialog. UpDown is only listed in
	// the short help; Update matches navigation keys through Next and
	// Previous.
	keyMap struct {
		Tab      key.Binding
		UpDown   key.Binding
		Select   key.Binding
		Next     key.Binding
		Previous key.Binding
		Close    key.Binding
	}

	// list is the grouped list of selectable models.
	list *ModelsList

	// input is the text input whose value is applied as the list filter.
	input textinput.Model

	// help renders the key binding help shown by View.
	help help.Model
}

// Compile-time check that *Models satisfies the Dialog interface.
var _ Dialog = (*Models)(nil)
71
72// NewModels creates a new Models dialog.
73func NewModels(com *common.Common) (*Models, error) {
74 t := com.Styles
75 m := &Models{}
76 m.com = com
77 help := help.New()
78 help.Styles = t.DialogHelpStyles()
79
80 m.help = help
81 m.list = NewModelsList(t)
82 m.list.Focus()
83 m.list.SetSelected(0)
84
85 m.input = textinput.New()
86 m.input.SetVirtualCursor(false)
87 m.input.Placeholder = largeModelInputPlaceholder
88 m.input.SetStyles(com.Styles.TextInput)
89 m.input.Focus()
90
91 m.keyMap.Tab = key.NewBinding(
92 key.WithKeys("tab", "shift+tab"),
93 key.WithHelp("tab", "toggle type"),
94 )
95 m.keyMap.Select = key.NewBinding(
96 key.WithKeys("enter", "ctrl+y"),
97 key.WithHelp("enter", "confirm"),
98 )
99 m.keyMap.UpDown = key.NewBinding(
100 key.WithKeys("up", "down"),
101 key.WithHelp("↑/↓", "choose"),
102 )
103 m.keyMap.Next = key.NewBinding(
104 key.WithKeys("down", "ctrl+n"),
105 key.WithHelp("↓", "next item"),
106 )
107 m.keyMap.Previous = key.NewBinding(
108 key.WithKeys("up", "ctrl+p"),
109 key.WithHelp("↑", "previous item"),
110 )
111 m.keyMap.Close = CloseKey
112
113 providers, err := getFilteredProviders(com.Config())
114 if err != nil {
115 return nil, fmt.Errorf("failed to get providers: %w", err)
116 }
117
118 m.providers = providers
119 if err := m.setProviderItems(); err != nil {
120 return nil, fmt.Errorf("failed to set provider items: %w", err)
121 }
122
123 return m, nil
124}
125
126// SetSize sets the size of the dialog.
127func (m *Models) SetSize(width, height int) {
128 t := m.com.Styles
129 m.width = width
130 m.height = height
131 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
132 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
133 t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
134 t.Dialog.HelpView.GetVerticalFrameSize() +
135 t.Dialog.View.GetVerticalFrameSize()
136 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
137 m.list.SetSize(innerWidth, height-heightOffset)
138 m.help.SetWidth(width)
139}
140
// ID implements Dialog. It returns ModelsID, the identifier of the model
// selection dialog.
func (m *Models) ID() string {
	return ModelsID
}
145
146// Update implements Dialog.
147func (m *Models) Update(msg tea.Msg) tea.Msg {
148 switch msg := msg.(type) {
149 case tea.KeyPressMsg:
150 switch {
151 case key.Matches(msg, m.keyMap.Close):
152 return CloseMsg{}
153 case key.Matches(msg, m.keyMap.Previous):
154 m.list.Focus()
155 if m.list.IsSelectedFirst() {
156 m.list.SelectLast()
157 m.list.ScrollToBottom()
158 break
159 }
160 m.list.SelectPrev()
161 m.list.ScrollToSelected()
162 case key.Matches(msg, m.keyMap.Next):
163 m.list.Focus()
164 if m.list.IsSelectedLast() {
165 m.list.SelectFirst()
166 m.list.ScrollToTop()
167 break
168 }
169 m.list.SelectNext()
170 m.list.ScrollToSelected()
171 case key.Matches(msg, m.keyMap.Select):
172 selectedItem := m.list.SelectedItem()
173 if selectedItem == nil {
174 break
175 }
176
177 modelItem, ok := selectedItem.(*ModelItem)
178 if !ok {
179 break
180 }
181
182 return ModelSelectedMsg{
183 Provider: modelItem.prov,
184 Model: modelItem.model,
185 }
186 case key.Matches(msg, m.keyMap.Tab):
187 if m.modelType == ModelTypeLarge {
188 m.modelType = ModelTypeSmall
189 } else {
190 m.modelType = ModelTypeLarge
191 }
192 if err := m.setProviderItems(); err != nil {
193 return uiutil.ReportError(err)
194 }
195 default:
196 var cmd tea.Cmd
197 m.input, cmd = m.input.Update(msg)
198 value := m.input.Value()
199 m.list.SetFilter(value)
200 m.list.ScrollToSelected()
201 return cmd
202 }
203 }
204 return nil
205}
206
207// Cursor returns the cursor for the dialog.
208func (m *Models) Cursor() *tea.Cursor {
209 return InputCursor(m.com.Styles, m.input.Cursor())
210}
211
212// modelTypeRadioView returns the radio view for model type selection.
213func (m *Models) modelTypeRadioView() string {
214 t := m.com.Styles
215 textStyle := t.HalfMuted
216 largeRadioStyle := t.RadioOff
217 smallRadioStyle := t.RadioOff
218 if m.modelType == ModelTypeLarge {
219 largeRadioStyle = t.RadioOn
220 } else {
221 smallRadioStyle = t.RadioOn
222 }
223
224 largeRadio := largeRadioStyle.Padding(0, 1).Render()
225 smallRadio := smallRadioStyle.Padding(0, 1).Render()
226
227 return fmt.Sprintf("%s%s %s%s",
228 largeRadio, textStyle.Render(ModelTypeLarge.String()),
229 smallRadio, textStyle.Render(ModelTypeSmall.String()))
230}
231
232// View implements Dialog.
233func (m *Models) View() string {
234 t := m.com.Styles
235 titleStyle := t.Dialog.Title
236 dialogStyle := t.Dialog.View
237
238 radios := m.modelTypeRadioView()
239
240 headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
241 dialogStyle.GetHorizontalFrameSize()
242
243 header := common.DialogTitle(t, "Switch Model", m.width-headerOffset) + radios
244
245 return HeaderInputListHelpView(t, m.width, m.list.Height(), header,
246 m.input.View(), m.list.Render(), m.help.View(m))
247}
248
249// ShortHelp returns the short help view.
250func (m *Models) ShortHelp() []key.Binding {
251 return []key.Binding{
252 m.keyMap.UpDown,
253 m.keyMap.Tab,
254 m.keyMap.Select,
255 m.keyMap.Close,
256 }
257}
258
259// FullHelp returns the full help view.
260func (m *Models) FullHelp() [][]key.Binding {
261 return [][]key.Binding{
262 {
263 m.keyMap.Select,
264 m.keyMap.Next,
265 m.keyMap.Previous,
266 m.keyMap.Tab,
267 },
268 {
269 m.keyMap.Close,
270 },
271 }
272}
273
274// setProviderItems sets the provider items in the list.
275func (m *Models) setProviderItems() error {
276 t := m.com.Styles
277 cfg := m.com.Config()
278
279 selectedType := config.SelectedModelTypeLarge
280 if m.modelType == ModelTypeLarge {
281 selectedType = config.SelectedModelTypeLarge
282 } else {
283 selectedType = config.SelectedModelTypeSmall
284 }
285
286 var selectedItemID string
287 currentModel := cfg.Models[selectedType]
288 recentItems := cfg.RecentModels[selectedType]
289
290 // Track providers already added to avoid duplicates
291 addedProviders := make(map[string]bool)
292
293 // Get a list of known providers to compare against
294 knownProviders, err := config.Providers(cfg)
295 if err != nil {
296 return fmt.Errorf("failed to get providers: %w", err)
297 }
298
299 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
300 return func(p catwalk.Provider) bool {
301 return p.ID == catwalk.InferenceProvider(id)
302 }
303 }
304
305 // itemsMap contains the keys of added model items.
306 itemsMap := make(map[string]*ModelItem)
307 groups := []ModelGroup{}
308 for id, p := range cfg.Providers.Seq2() {
309 if p.Disable {
310 continue
311 }
312
313 // Check if this provider is not in the known providers list
314 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
315 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
316 provider := p.ToProvider()
317
318 // Add this unknown provider to the list
319 name := p.Name
320 if name == "" {
321 name = id
322 }
323
324 addedProviders[id] = true
325
326 group := NewModelGroup(t, name, true)
327 for _, model := range p.Models {
328 item := NewModelItem(t, provider, model, false)
329 group.AppendItems(item)
330 itemsMap[item.ID()] = item
331 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
332 selectedItemID = item.ID()
333 }
334 }
335 }
336 }
337
338 // Now add known providers from the predefined list
339 for _, provider := range m.providers {
340 providerID := string(provider.ID)
341 if addedProviders[providerID] {
342 continue
343 }
344
345 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
346 if providerConfigured && providerConfig.Disable {
347 continue
348 }
349
350 displayProvider := provider
351 if providerConfigured {
352 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
353 modelIndex := make(map[string]int, len(displayProvider.Models))
354 for i, model := range displayProvider.Models {
355 modelIndex[model.ID] = i
356 }
357 for _, model := range providerConfig.Models {
358 if model.ID == "" {
359 continue
360 }
361 if idx, ok := modelIndex[model.ID]; ok {
362 if model.Name != "" {
363 displayProvider.Models[idx].Name = model.Name
364 }
365 continue
366 }
367 if model.Name == "" {
368 model.Name = model.ID
369 }
370 displayProvider.Models = append(displayProvider.Models, model)
371 modelIndex[model.ID] = len(displayProvider.Models) - 1
372 }
373 }
374
375 name := displayProvider.Name
376 if name == "" {
377 name = providerID
378 }
379
380 group := NewModelGroup(t, name, providerConfigured)
381 for _, model := range displayProvider.Models {
382 item := NewModelItem(t, provider, model, false)
383 group.AppendItems(item)
384 itemsMap[item.ID()] = item
385 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
386 selectedItemID = item.ID()
387 }
388 }
389
390 groups = append(groups, group)
391 }
392
393 if len(recentItems) > 0 {
394 recentGroup := NewModelGroup(t, "Recently used", false)
395
396 var validRecentItems []config.SelectedModel
397 for _, recent := range recentItems {
398 key := modelKey(recent.Provider, recent.Model)
399 item, ok := itemsMap[key]
400 if !ok {
401 continue
402 }
403
404 // Show provider for recent items
405 item = NewModelItem(t, item.prov, item.model, true)
406 item.showProvider = true
407
408 validRecentItems = append(validRecentItems, recent)
409 recentGroup.AppendItems(item)
410 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
411 selectedItemID = item.ID()
412 }
413 }
414
415 if len(validRecentItems) != len(recentItems) {
416 // FIXME: Does this need to be here? Is it mutating the config during a read?
417 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
418 return fmt.Errorf("failed to update recent models: %w", err)
419 }
420 }
421
422 if len(recentGroup.Items) > 0 {
423 groups = append([]ModelGroup{recentGroup}, groups...)
424 }
425 }
426
427 // Set model groups in the list.
428 m.list.SetGroups(groups...)
429 m.list.SetSelectedItem(selectedItemID)
430
431 // Update placeholder based on model type
432 if m.modelType == ModelTypeLarge {
433 m.input.Placeholder = largeModelInputPlaceholder
434 } else {
435 m.input.Placeholder = smallModelInputPlaceholder
436 }
437
438 return nil
439}
440
441func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
442 providers, err := config.Providers(cfg)
443 if err != nil {
444 return nil, fmt.Errorf("failed to get providers: %w", err)
445 }
446 filteredProviders := []catwalk.Provider{}
447 for _, p := range providers {
448 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
449 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
450 filteredProviders = append(filteredProviders, p)
451 }
452 }
453 return filteredProviders, nil
454}
455
// modelKey builds the lookup key "<providerID>:<modelID>" for a model. It
// returns the empty string when either part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return providerID + ":" + modelID
	}
}