1package dialog
2
3import (
4 "cmp"
5 "fmt"
6 "slices"
7 "strings"
8
9 "charm.land/bubbles/v2/help"
10 "charm.land/bubbles/v2/key"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "charm.land/lipgloss/v2"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/ui/common"
17 "github.com/charmbracelet/crush/internal/uiutil"
18)
19
// ModelType identifies which model slot the dialog is editing: the model for
// large, complex tasks or the one for small, simple tasks.
type ModelType int

const (
	// ModelTypeLarge is the slot for the large-task model.
	ModelTypeLarge ModelType = iota
	// ModelTypeSmall is the slot for the small-task model.
	ModelTypeSmall
)

// String returns a human-readable label for the [ModelType]. Values other
// than the declared constants are reported as "Unknown".
func (mt ModelType) String() string {
	if mt == ModelTypeLarge {
		return "Large Task"
	}
	if mt == ModelTypeSmall {
		return "Small Task"
	}
	return "Unknown"
}
39
const (
	// largeModelInputPlaceholder is the filter-input placeholder shown while
	// the large-task model is being chosen.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	// smallModelInputPlaceholder is the filter-input placeholder shown while
	// the small-task model is being chosen.
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"

	// ModelsID is the identifier for the model selection dialog.
	ModelsID = "models"
)
47
48// Models represents a model selection dialog.
49type Models struct {
50 com *common.Common
51
52 modelType ModelType
53 providers []catwalk.Provider
54
55 width, height int
56
57 keyMap struct {
58 Tab key.Binding
59 UpDown key.Binding
60 Select key.Binding
61 Next key.Binding
62 Previous key.Binding
63 Close key.Binding
64 }
65 list *ModelsList
66 input textinput.Model
67 help help.Model
68}
69
70var _ Dialog = (*Models)(nil)
71
72// NewModels creates a new Models dialog.
73func NewModels(com *common.Common) (*Models, error) {
74 t := com.Styles
75 m := &Models{}
76 m.com = com
77 help := help.New()
78 help.Styles = t.DialogHelpStyles()
79
80 m.help = help
81 m.list = NewModelsList(t)
82 m.list.Focus()
83 m.list.SetSelected(0)
84
85 m.input = textinput.New()
86 m.input.SetVirtualCursor(false)
87 m.input.Placeholder = largeModelInputPlaceholder
88 m.input.SetStyles(com.Styles.TextInput)
89 m.input.Focus()
90
91 m.keyMap.Tab = key.NewBinding(
92 key.WithKeys("tab", "shift+tab"),
93 key.WithHelp("tab", "toggle type"),
94 )
95 m.keyMap.Select = key.NewBinding(
96 key.WithKeys("enter", "ctrl+y"),
97 key.WithHelp("enter", "confirm"),
98 )
99 m.keyMap.UpDown = key.NewBinding(
100 key.WithKeys("up", "down"),
101 key.WithHelp("↑/↓", "choose"),
102 )
103 m.keyMap.Next = key.NewBinding(
104 key.WithKeys("down", "ctrl+n"),
105 key.WithHelp("↓", "next item"),
106 )
107 m.keyMap.Previous = key.NewBinding(
108 key.WithKeys("up", "ctrl+p"),
109 key.WithHelp("↑", "previous item"),
110 )
111 m.keyMap.Close = CloseKey
112
113 providers, err := getFilteredProviders(com.Config())
114 if err != nil {
115 return nil, fmt.Errorf("failed to get providers: %w", err)
116 }
117
118 m.providers = providers
119 if err := m.setProviderItems(); err != nil {
120 return nil, fmt.Errorf("failed to set provider items: %w", err)
121 }
122
123 return m, nil
124}
125
126// SetSize sets the size of the dialog.
127func (m *Models) SetSize(width, height int) {
128 t := m.com.Styles
129 m.width = width
130 m.height = height
131 innerWidth := width - t.Dialog.View.GetHorizontalFrameSize()
132 heightOffset := t.Dialog.Title.GetVerticalFrameSize() + 1 + // (1) title content
133 t.Dialog.InputPrompt.GetVerticalFrameSize() + 1 + // (1) input content
134 t.Dialog.HelpView.GetVerticalFrameSize() +
135 t.Dialog.View.GetVerticalFrameSize()
136 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
137 m.list.SetSize(innerWidth, height-heightOffset)
138 m.help.SetWidth(width)
139}
140
141// ID implements Dialog.
142func (m *Models) ID() string {
143 return ModelsID
144}
145
146// Update implements Dialog.
147func (m *Models) Update(msg tea.Msg) tea.Msg {
148 switch msg := msg.(type) {
149 case tea.KeyPressMsg:
150 switch {
151 case key.Matches(msg, m.keyMap.Close):
152 return CloseMsg{}
153 case key.Matches(msg, m.keyMap.Previous):
154 m.list.Focus()
155 if m.list.IsSelectedFirst() {
156 m.list.SelectLast()
157 m.list.ScrollToBottom()
158 break
159 }
160 m.list.SelectPrev()
161 m.list.ScrollToSelected()
162 case key.Matches(msg, m.keyMap.Next):
163 m.list.Focus()
164 if m.list.IsSelectedLast() {
165 m.list.SelectFirst()
166 m.list.ScrollToTop()
167 break
168 }
169 m.list.SelectNext()
170 m.list.ScrollToSelected()
171 case key.Matches(msg, m.keyMap.Select):
172 if selectedItem := m.list.SelectedItem(); selectedItem != nil {
173 // TODO: Handle model selection confirmation.
174 }
175 case key.Matches(msg, m.keyMap.Tab):
176 if m.modelType == ModelTypeLarge {
177 m.modelType = ModelTypeSmall
178 } else {
179 m.modelType = ModelTypeLarge
180 }
181 if err := m.setProviderItems(); err != nil {
182 return uiutil.ReportError(err)
183 }
184 default:
185 var cmd tea.Cmd
186 m.input, cmd = m.input.Update(msg)
187 value := m.input.Value()
188 m.list.SetFilter(value)
189 m.list.ScrollToSelected()
190 return cmd
191 }
192 }
193 return nil
194}
195
196// Cursor returns the cursor for the dialog.
197func (m *Models) Cursor() *tea.Cursor {
198 return InputCursor(m.com.Styles, m.input.Cursor())
199}
200
201// modelTypeRadioView returns the radio view for model type selection.
202func (m *Models) modelTypeRadioView() string {
203 t := m.com.Styles
204 textStyle := t.HalfMuted
205 largeRadioStyle := t.RadioOff
206 smallRadioStyle := t.RadioOff
207 if m.modelType == ModelTypeLarge {
208 largeRadioStyle = t.RadioOn
209 } else {
210 smallRadioStyle = t.RadioOn
211 }
212
213 largeRadio := largeRadioStyle.Padding(0, 1).Render()
214 smallRadio := smallRadioStyle.Padding(0, 1).Render()
215
216 return fmt.Sprintf("%s%s %s%s",
217 largeRadio, textStyle.Render(ModelTypeLarge.String()),
218 smallRadio, textStyle.Render(ModelTypeSmall.String()))
219}
220
221// View implements Dialog.
222func (m *Models) View() string {
223 t := m.com.Styles
224 titleStyle := t.Dialog.Title
225 dialogStyle := t.Dialog.View
226
227 radios := m.modelTypeRadioView()
228
229 headerOffset := lipgloss.Width(radios) + titleStyle.GetHorizontalFrameSize() +
230 dialogStyle.GetHorizontalFrameSize()
231
232 header := common.DialogTitle(t, "Switch Model", m.width-headerOffset) + radios
233
234 return HeaderInputListHelpView(t, m.width, m.list.Height(), header,
235 m.input.View(), m.list.Render(), m.help.View(m))
236}
237
238// ShortHelp returns the short help view.
239func (m *Models) ShortHelp() []key.Binding {
240 return []key.Binding{
241 m.keyMap.UpDown,
242 m.keyMap.Tab,
243 m.keyMap.Select,
244 m.keyMap.Close,
245 }
246}
247
248// FullHelp returns the full help view.
249func (m *Models) FullHelp() [][]key.Binding {
250 return [][]key.Binding{
251 {
252 m.keyMap.Select,
253 m.keyMap.Next,
254 m.keyMap.Previous,
255 m.keyMap.Tab,
256 },
257 {
258 m.keyMap.Close,
259 },
260 }
261}
262
263// setProviderItems sets the provider items in the list.
264func (m *Models) setProviderItems() error {
265 t := m.com.Styles
266 cfg := m.com.Config()
267
268 selectedType := config.SelectedModelTypeLarge
269 if m.modelType == ModelTypeLarge {
270 selectedType = config.SelectedModelTypeLarge
271 } else {
272 selectedType = config.SelectedModelTypeSmall
273 }
274
275 var selectedItemID string
276 currentModel := cfg.Models[selectedType]
277 recentItems := cfg.RecentModels[selectedType]
278
279 // Track providers already added to avoid duplicates
280 addedProviders := make(map[string]bool)
281
282 // Get a list of known providers to compare against
283 knownProviders, err := config.Providers(cfg)
284 if err != nil {
285 return fmt.Errorf("failed to get providers: %w", err)
286 }
287
288 containsProviderFunc := func(id string) func(p catwalk.Provider) bool {
289 return func(p catwalk.Provider) bool {
290 return p.ID == catwalk.InferenceProvider(id)
291 }
292 }
293
294 // itemsMap contains the keys of added model items.
295 itemsMap := make(map[string]*ModelItem)
296 groups := []ModelGroup{}
297 for id, p := range cfg.Providers.Seq2() {
298 if p.Disable {
299 continue
300 }
301
302 // Check if this provider is not in the known providers list
303 if !slices.ContainsFunc(knownProviders, containsProviderFunc(id)) ||
304 !slices.ContainsFunc(m.providers, containsProviderFunc(id)) {
305 provider := p.ToProvider()
306
307 // Add this unknown provider to the list
308 name := p.Name
309 if name == "" {
310 name = id
311 }
312
313 addedProviders[id] = true
314
315 group := NewModelGroup(t, name, true)
316 for _, model := range p.Models {
317 item := NewModelItem(t, provider, model)
318 group.AppendItems(item)
319 itemsMap[item.ID()] = item
320 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
321 selectedItemID = item.ID()
322 }
323 }
324 }
325 }
326
327 // Now add known providers from the predefined list
328 for _, provider := range m.providers {
329 providerID := string(provider.ID)
330 if addedProviders[providerID] {
331 continue
332 }
333
334 providerConfig, providerConfigured := cfg.Providers.Get(providerID)
335 if providerConfigured && providerConfig.Disable {
336 continue
337 }
338
339 displayProvider := provider
340 if providerConfigured {
341 displayProvider.Name = cmp.Or(providerConfig.Name, displayProvider.Name)
342 modelIndex := make(map[string]int, len(displayProvider.Models))
343 for i, model := range displayProvider.Models {
344 modelIndex[model.ID] = i
345 }
346 for _, model := range providerConfig.Models {
347 if model.ID == "" {
348 continue
349 }
350 if idx, ok := modelIndex[model.ID]; ok {
351 if model.Name != "" {
352 displayProvider.Models[idx].Name = model.Name
353 }
354 continue
355 }
356 if model.Name == "" {
357 model.Name = model.ID
358 }
359 displayProvider.Models = append(displayProvider.Models, model)
360 modelIndex[model.ID] = len(displayProvider.Models) - 1
361 }
362 }
363
364 name := displayProvider.Name
365 if name == "" {
366 name = providerID
367 }
368
369 group := NewModelGroup(t, name, providerConfigured)
370 for _, model := range displayProvider.Models {
371 item := NewModelItem(t, provider, model)
372 group.AppendItems(item)
373 itemsMap[item.ID()] = item
374 if model.ID == currentModel.Model && string(provider.ID) == currentModel.Provider {
375 selectedItemID = item.ID()
376 }
377 }
378
379 groups = append(groups, group)
380 }
381
382 if len(recentItems) > 0 {
383 recentGroup := NewModelGroup(t, "Recently used", false)
384
385 var validRecentItems []config.SelectedModel
386 for _, recent := range recentItems {
387 key := modelKey(recent.Provider, recent.Model)
388 item, ok := itemsMap[key]
389 if !ok {
390 continue
391 }
392
393 validRecentItems = append(validRecentItems, recent)
394 recentGroup.AppendItems(item)
395 if recent.Model == currentModel.Model && recent.Provider == currentModel.Provider {
396 selectedItemID = item.ID()
397 }
398 }
399
400 if len(validRecentItems) != len(recentItems) {
401 // FIXME: Does this need to be here? Is it mutating the config during a read?
402 if err := cfg.SetConfigField(fmt.Sprintf("recent_models.%s", selectedType), validRecentItems); err != nil {
403 return fmt.Errorf("failed to update recent models: %w", err)
404 }
405 }
406
407 if len(recentGroup.Items) > 0 {
408 groups = append([]ModelGroup{recentGroup}, groups...)
409 }
410 }
411
412 // Set model groups in the list.
413 m.list.SetGroups(groups...)
414 m.list.SetSelectedItem(selectedItemID)
415 // Update placeholder based on model type
416 if m.modelType == ModelTypeLarge {
417 m.input.Placeholder = largeModelInputPlaceholder
418 } else {
419 m.input.Placeholder = smallModelInputPlaceholder
420 }
421
422 return nil
423}
424
425func getFilteredProviders(cfg *config.Config) ([]catwalk.Provider, error) {
426 providers, err := config.Providers(cfg)
427 if err != nil {
428 return nil, fmt.Errorf("failed to get providers: %w", err)
429 }
430 filteredProviders := []catwalk.Provider{}
431 for _, p := range providers {
432 hasAPIKeyEnv := strings.HasPrefix(p.APIKey, "$")
433 if hasAPIKeyEnv && p.ID != catwalk.InferenceProviderAzure {
434 filteredProviders = append(filteredProviders, p)
435 }
436 }
437 return filteredProviders, nil
438}
439
// modelKey builds the "<providerID>:<modelID>" key used to look up recently
// used models among the list items. It returns the empty string when either
// part is empty.
func modelKey(providerID, modelID string) string {
	switch {
	case providerID == "", modelID == "":
		return ""
	default:
		return strings.Join([]string{providerID, modelID}, ":")
	}
}