1// Package models provides the model selection dialog for the TUI.
2package models
3
4import (
5 "fmt"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/lipgloss/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/tui/components/core"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
18 "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
20 "github.com/charmbracelet/crush/internal/tui/exp/list"
21 "github.com/charmbracelet/crush/internal/tui/styles"
22 "github.com/charmbracelet/crush/internal/tui/util"
23)
24
const (
	// ModelsDialogID is the identifier the model dialog reports from ID().
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the width NewModelDialogCmp gives a new dialog.
	defaultWidth = 60
)
30
const (
	// LargeModelType and SmallModelType are the model-type values exchanged
	// with the model list (GetModelType / SetModelType). The dialog translates
	// them to config.SelectedModelTypeLarge and config.SelectedModelTypeSmall
	// when a model is selected.
	LargeModelType int = iota
	SmallModelType

	// Placeholder texts handed to the model list input for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
38
// ModelSelectedMsg is emitted by the dialog once a model has been chosen and
// its provider is usable: either the provider was already configured, or its
// credentials were just saved.
type ModelSelectedMsg struct {
	// Model carries the chosen model ID, its provider ID and the model's
	// default reasoning effort and max tokens.
	Model config.SelectedModel
	// ModelType records whether the model was picked for the large or the
	// small slot.
	ModelType config.SelectedModelType
}
44
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in the visible part of this file sends or handles it
// (the dialog closes itself via dialogs.CloseDialogMsg) — confirm it is still
// used elsewhere.
type CloseModelDialogMsg struct{}
47
// ModelDialog is the interface returned by NewModelDialogCmp. It currently
// adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
52
// ModelOption pairs a model with the provider that offers it. The dialog keeps
// one as the selection that is waiting for credentials.
type ModelOption struct {
	// Provider is the provider offering Model.
	Provider catwalk.Provider
	// Model is the selectable model.
	Model catwalk.Model
}
57
// modelDialogCmp is the ModelDialog implementation returned by
// NewModelDialogCmp. At any time it renders one of four views: the model list,
// the API key prompt, the Hyper device flow or the Copilot device flow.
type modelDialogCmp struct {
	// width is the dialog width; the constructor sets it to defaultWidth.
	width int
	// wWidth and wHeight hold the window size from the last
	// tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	// modelList is the selectable model list shown in the default view.
	modelList *ModelListComponent
	// keyMap holds the dialog bindings plus the flags that steer the help
	// line for the current view.
	keyMap KeyMap
	// help renders the help line at the bottom of every view.
	help help.Model

	// API key state
	// needsAPIKey is true while the API key prompt is shown.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel is the model waiting for credentials; it is nil while no
	// credential flow is pending.
	selectedModel *ModelOption
	// selectedModelType is the slot (large/small) selectedModel was picked for.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true once the submitted key passed the connection
	// test; the next Select press then saves the key.
	isAPIKeyValid bool
	// apiKeyValue is the key text captured when the user submitted the prompt.
	apiKeyValue string

	// Hyper device flow state
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Copilot device flow state
	copilotDeviceFlow     *copilot.DeviceFlow
	showCopilotDeviceFlow bool
}
83
84func NewModelDialogCmp() ModelDialog {
85 keyMap := DefaultKeyMap()
86
87 listKeyMap := list.DefaultKeyMap()
88 listKeyMap.Down.SetEnabled(false)
89 listKeyMap.Up.SetEnabled(false)
90 listKeyMap.DownOneItem = keyMap.Next
91 listKeyMap.UpOneItem = keyMap.Previous
92
93 t := styles.CurrentTheme()
94 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
95 apiKeyInput := NewAPIKeyInput()
96 apiKeyInput.SetShowTitle(false)
97 help := help.New()
98 help.Styles = t.S().Help
99
100 return &modelDialogCmp{
101 modelList: modelList,
102 apiKeyInput: apiKeyInput,
103 width: defaultWidth,
104 keyMap: DefaultKeyMap(),
105 help: help,
106 }
107}
108
109func (m *modelDialogCmp) Init() tea.Cmd {
110 return tea.Batch(
111 m.modelList.Init(),
112 m.apiKeyInput.Init(),
113 )
114}
115
116func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
117 switch msg := msg.(type) {
118 case tea.WindowSizeMsg:
119 m.wWidth = msg.Width
120 m.wHeight = msg.Height
121 m.apiKeyInput.SetWidth(m.width - 2)
122 m.help.SetWidth(m.width - 2)
123 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
124 case APIKeyStateChangeMsg:
125 u, cmd := m.apiKeyInput.Update(msg)
126 m.apiKeyInput = u.(*APIKeyInput)
127 return m, cmd
128 case hyper.DeviceFlowCompletedMsg:
129 return m, m.saveOauthTokenAndContinue(msg.Token, true)
130 case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
131 if m.hyperDeviceFlow != nil {
132 u, cmd := m.hyperDeviceFlow.Update(msg)
133 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
134 return m, cmd
135 }
136 return m, nil
137 case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
138 if m.copilotDeviceFlow != nil {
139 u, cmd := m.copilotDeviceFlow.Update(msg)
140 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
141 return m, cmd
142 }
143 return m, nil
144 case copilot.DeviceFlowCompletedMsg:
145 return m, m.saveOauthTokenAndContinue(msg.Token, true)
146 case tea.KeyPressMsg:
147 switch {
148 // Handle Hyper device flow keys
149 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
150 return m, m.hyperDeviceFlow.CopyCode()
151 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
152 return m, m.copilotDeviceFlow.CopyCode()
153 case key.Matches(msg, m.keyMap.Select):
154 // If showing device flow, enter copies code and opens URL
155 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
156 return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
157 }
158 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
159 return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
160 }
161 selectedItem := m.modelList.SelectedModel()
162 if selectedItem == nil {
163 return m, nil
164 }
165
166 modelType := config.SelectedModelTypeLarge
167 if m.modelList.GetModelType() == SmallModelType {
168 modelType = config.SelectedModelTypeSmall
169 }
170
171 askForApiKey := func() {
172 m.keyMap.isAPIKeyHelp = true
173 m.showHyperDeviceFlow = false
174 m.showCopilotDeviceFlow = false
175 m.needsAPIKey = true
176 m.selectedModel = selectedItem
177 m.selectedModelType = modelType
178 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
179 }
180
181 if m.isAPIKeyValid {
182 return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
183 }
184 if m.needsAPIKey {
185 // Handle API key submission
186 m.apiKeyValue = m.apiKeyInput.Value()
187 provider, err := m.getProvider(m.selectedModel.Provider.ID)
188 if err != nil || provider == nil {
189 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
190 }
191 providerConfig := config.ProviderConfig{
192 ID: string(m.selectedModel.Provider.ID),
193 Name: m.selectedModel.Provider.Name,
194 APIKey: m.apiKeyValue,
195 Type: provider.Type,
196 BaseURL: provider.APIEndpoint,
197 }
198 return m, tea.Sequence(
199 util.CmdHandler(APIKeyStateChangeMsg{
200 State: APIKeyInputStateVerifying,
201 }),
202 func() tea.Msg {
203 start := time.Now()
204 err := providerConfig.TestConnection(config.Get().Resolver())
205 // intentionally wait for at least 750ms to make sure the user sees the spinner
206 elapsed := time.Since(start)
207 if elapsed < 750*time.Millisecond {
208 time.Sleep(750*time.Millisecond - elapsed)
209 }
210 if err == nil {
211 m.isAPIKeyValid = true
212 return APIKeyStateChangeMsg{
213 State: APIKeyInputStateVerified,
214 }
215 }
216 return APIKeyStateChangeMsg{
217 State: APIKeyInputStateError,
218 }
219 },
220 )
221 }
222
223 // Check if provider is configured
224 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
225 return m, tea.Sequence(
226 util.CmdHandler(dialogs.CloseDialogMsg{}),
227 util.CmdHandler(ModelSelectedMsg{
228 Model: config.SelectedModel{
229 Model: selectedItem.Model.ID,
230 Provider: string(selectedItem.Provider.ID),
231 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
232 MaxTokens: selectedItem.Model.DefaultMaxTokens,
233 },
234 ModelType: modelType,
235 }),
236 )
237 }
238 switch selectedItem.Provider.ID {
239 case hyperp.Name:
240 m.showHyperDeviceFlow = true
241 m.selectedModel = selectedItem
242 m.selectedModelType = modelType
243 m.hyperDeviceFlow = hyper.NewDeviceFlow()
244 m.hyperDeviceFlow.SetWidth(m.width - 2)
245 return m, m.hyperDeviceFlow.Init()
246 case catwalk.InferenceProviderCopilot:
247 if token, ok := config.Get().ImportCopilot(); ok {
248 m.selectedModel = selectedItem
249 m.selectedModelType = modelType
250 return m, m.saveOauthTokenAndContinue(token, true)
251 }
252 m.showCopilotDeviceFlow = true
253 m.selectedModel = selectedItem
254 m.selectedModelType = modelType
255 m.copilotDeviceFlow = copilot.NewDeviceFlow()
256 m.copilotDeviceFlow.SetWidth(m.width - 2)
257 return m, m.copilotDeviceFlow.Init()
258 }
259 // For other providers, show API key input
260 askForApiKey()
261 return m, nil
262 case key.Matches(msg, m.keyMap.Tab):
263 switch {
264 case m.needsAPIKey:
265 u, cmd := m.apiKeyInput.Update(msg)
266 m.apiKeyInput = u.(*APIKeyInput)
267 return m, cmd
268 case m.modelList.GetModelType() == LargeModelType:
269 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
270 return m, m.modelList.SetModelType(SmallModelType)
271 default:
272 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
273 return m, m.modelList.SetModelType(LargeModelType)
274 }
275 case key.Matches(msg, m.keyMap.Close):
276 switch {
277 case m.showHyperDeviceFlow:
278 if m.hyperDeviceFlow != nil {
279 m.hyperDeviceFlow.Cancel()
280 }
281 m.showHyperDeviceFlow = false
282 m.selectedModel = nil
283 case m.showCopilotDeviceFlow:
284 if m.copilotDeviceFlow != nil {
285 m.copilotDeviceFlow.Cancel()
286 }
287 m.showCopilotDeviceFlow = false
288 m.selectedModel = nil
289 case m.needsAPIKey:
290 if m.isAPIKeyValid {
291 return m, nil
292 }
293 // Go back to model selection
294 m.needsAPIKey = false
295 m.selectedModel = nil
296 m.isAPIKeyValid = false
297 m.apiKeyValue = ""
298 m.apiKeyInput.Reset()
299 return m, nil
300 default:
301 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
302 }
303 default:
304 switch {
305 case m.needsAPIKey:
306 u, cmd := m.apiKeyInput.Update(msg)
307 m.apiKeyInput = u.(*APIKeyInput)
308 return m, cmd
309 default:
310 u, cmd := m.modelList.Update(msg)
311 m.modelList = u
312 return m, cmd
313 }
314 }
315 case tea.PasteMsg:
316 switch {
317 case m.needsAPIKey:
318 u, cmd := m.apiKeyInput.Update(msg)
319 m.apiKeyInput = u.(*APIKeyInput)
320 return m, cmd
321 default:
322 var cmd tea.Cmd
323 m.modelList, cmd = m.modelList.Update(msg)
324 return m, cmd
325 }
326 case spinner.TickMsg:
327 u, cmd := m.apiKeyInput.Update(msg)
328 m.apiKeyInput = u.(*APIKeyInput)
329 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
330 u, cmd = m.hyperDeviceFlow.Update(msg)
331 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
332 }
333 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
334 u, cmd = m.copilotDeviceFlow.Update(msg)
335 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
336 }
337 return m, cmd
338 default:
339 // Pass all other messages to the device flow for spinner animation
340 switch {
341 case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
342 u, cmd := m.hyperDeviceFlow.Update(msg)
343 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
344 return m, cmd
345 case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
346 u, cmd := m.copilotDeviceFlow.Update(msg)
347 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
348 return m, cmd
349 default:
350 u, cmd := m.apiKeyInput.Update(msg)
351 m.apiKeyInput = u.(*APIKeyInput)
352 return m, cmd
353 }
354 }
355 return m, nil
356}
357
358func (m *modelDialogCmp) View() string {
359 t := styles.CurrentTheme()
360
361 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
362 // Show Hyper device flow
363 m.keyMap.isHyperDeviceFlow = true
364 deviceFlowView := m.hyperDeviceFlow.View()
365 content := lipgloss.JoinVertical(
366 lipgloss.Left,
367 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
368 deviceFlowView,
369 "",
370 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
371 )
372 return m.style().Render(content)
373 }
374 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
375 // Show Hyper device flow
376 m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
377 m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
378 deviceFlowView := m.copilotDeviceFlow.View()
379 content := lipgloss.JoinVertical(
380 lipgloss.Left,
381 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
382 deviceFlowView,
383 "",
384 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
385 )
386 return m.style().Render(content)
387 }
388
389 // Reset the flags when not showing device flow
390 m.keyMap.isHyperDeviceFlow = false
391 m.keyMap.isCopilotDeviceFlow = false
392 m.keyMap.isCopilotUnavailable = false
393
394 switch {
395 case m.needsAPIKey:
396 // Show API key input
397 m.keyMap.isAPIKeyHelp = true
398 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
399 apiKeyView := m.apiKeyInput.View()
400 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
401 content := lipgloss.JoinVertical(
402 lipgloss.Left,
403 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
404 apiKeyView,
405 "",
406 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
407 )
408 return m.style().Render(content)
409 }
410
411 // Show model selection
412 listView := m.modelList.View()
413 radio := m.modelTypeRadio()
414 content := lipgloss.JoinVertical(
415 lipgloss.Left,
416 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
417 listView,
418 "",
419 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
420 )
421 return m.style().Render(content)
422}
423
424func (m *modelDialogCmp) Cursor() *tea.Cursor {
425 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
426 return m.hyperDeviceFlow.Cursor()
427 }
428 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
429 return m.copilotDeviceFlow.Cursor()
430 }
431 if m.needsAPIKey {
432 cursor := m.apiKeyInput.Cursor()
433 if cursor != nil {
434 cursor = m.moveCursor(cursor)
435 return cursor
436 }
437 } else {
438 cursor := m.modelList.Cursor()
439 if cursor != nil {
440 cursor = m.moveCursor(cursor)
441 return cursor
442 }
443 }
444 return nil
445}
446
447func (m *modelDialogCmp) style() lipgloss.Style {
448 t := styles.CurrentTheme()
449 return t.S().Base.
450 Width(m.width).
451 Border(lipgloss.RoundedBorder()).
452 BorderForeground(t.BorderFocus)
453}
454
455func (m *modelDialogCmp) listWidth() int {
456 return m.width - 2
457}
458
459func (m *modelDialogCmp) listHeight() int {
460 return m.wHeight / 2
461}
462
463func (m *modelDialogCmp) Position() (int, int) {
464 row := m.wHeight/4 - 2 // just a bit above the center
465 col := m.wWidth / 2
466 col -= m.width / 2
467 return row, col
468}
469
470func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
471 row, col := m.Position()
472 if m.needsAPIKey {
473 offset := row + 3 // Border + title + API key input offset
474 cursor.Y += offset
475 cursor.X = cursor.X + col + 2
476 } else {
477 offset := row + 3 // Border + title
478 cursor.Y += offset
479 cursor.X = cursor.X + col + 2
480 }
481 return cursor
482}
483
484func (m *modelDialogCmp) ID() dialogs.DialogID {
485 return ModelsDialogID
486}
487
488func (m *modelDialogCmp) modelTypeRadio() string {
489 t := styles.CurrentTheme()
490 choices := []string{"Large Task", "Small Task"}
491 iconSelected := "◉"
492 iconUnselected := "○"
493 if m.modelList.GetModelType() == LargeModelType {
494 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
495 }
496 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
497}
498
499func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
500 cfg := config.Get()
501 _, ok := cfg.Providers.Get(providerID)
502 return ok
503}
504
505func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
506 cfg := config.Get()
507 providers, err := config.Providers(cfg)
508 if err != nil {
509 return nil, err
510 }
511 for _, p := range providers {
512 if p.ID == providerID {
513 return &p, nil
514 }
515 }
516 return nil, nil
517}
518
519func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
520 if m.selectedModel == nil {
521 return util.ReportError(fmt.Errorf("no model selected"))
522 }
523
524 cfg := config.Get()
525 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
526 if err != nil {
527 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
528 }
529
530 // Reset API key state and continue with model selection
531 selectedModel := *m.selectedModel
532 var cmds []tea.Cmd
533 if close {
534 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
535 }
536 cmds = append(
537 cmds,
538 util.CmdHandler(ModelSelectedMsg{
539 Model: config.SelectedModel{
540 Model: selectedModel.Model.ID,
541 Provider: string(selectedModel.Provider.ID),
542 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
543 MaxTokens: selectedModel.Model.DefaultMaxTokens,
544 },
545 ModelType: m.selectedModelType,
546 }),
547 )
548 return tea.Sequence(cmds...)
549}