1// Package models provides the model selection dialog for the TUI.
2package models
3
4import (
5 "fmt"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/lipgloss/v2"
13 "github.com/atotto/clipboard"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/tui/components/core"
18 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
20 "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
21 "github.com/charmbracelet/crush/internal/tui/exp/list"
22 "github.com/charmbracelet/crush/internal/tui/styles"
23 "github.com/charmbracelet/crush/internal/tui/util"
24)
25
26const (
27 ModelsDialogID dialogs.DialogID = "models"
28
29 defaultWidth = 60
30)
31
// Model type indices used by the model list's large/small toggle.
const (
	// LargeModelType selects the model for large, complex tasks.
	LargeModelType int = iota
	// SmallModelType selects the model for small, simple tasks.
	SmallModelType
)

// Placeholder text for the model list's input, one per model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
39
// ModelSelectedMsg is sent when a model is selected. Model describes the
// chosen provider/model (with the model's default reasoning effort and max
// tokens), and ModelType says whether it was chosen for the large or the
// small model slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
45
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file sends it; the dialog closes itself via
// dialogs.CloseDialogMsg. Confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
48
// ModelDialog is the interface for the model selection dialog. It adds no
// methods of its own beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
53
// ModelOption is a selectable entry in the model list: a catwalk model
// together with the provider that offers it.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
58
// modelDialogCmp is the concrete model selection dialog. Besides the model
// list it owns the sub-screens used to authenticate a provider (API key
// prompt, Hyper device flow, Anthropic auth chooser and OAuth2); the
// needs*/show* flags record which of them is currently displayed.
type modelDialogCmp struct {
	// width is the dialog's own width (defaultWidth from NewModelDialogCmp).
	width int
	// wWidth and wHeight are the window size from the last tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	// needsAPIKey is set while the API key prompt is shown.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel is the model whose provider is being authenticated.
	selectedModel *ModelOption
	// selectedModelType is the large/small slot selectedModel was chosen for.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is set once the entered key has passed verification.
	isAPIKeyValid bool
	// apiKeyValue is the key captured when the prompt was submitted.
	apiKeyValue string

	// Hyper device flow state
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Claude state
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
86
87func NewModelDialogCmp() ModelDialog {
88 keyMap := DefaultKeyMap()
89
90 listKeyMap := list.DefaultKeyMap()
91 listKeyMap.Down.SetEnabled(false)
92 listKeyMap.Up.SetEnabled(false)
93 listKeyMap.DownOneItem = keyMap.Next
94 listKeyMap.UpOneItem = keyMap.Previous
95
96 t := styles.CurrentTheme()
97 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
98 apiKeyInput := NewAPIKeyInput()
99 apiKeyInput.SetShowTitle(false)
100 help := help.New()
101 help.Styles = t.S().Help
102
103 return &modelDialogCmp{
104 modelList: modelList,
105 apiKeyInput: apiKeyInput,
106 width: defaultWidth,
107 keyMap: DefaultKeyMap(),
108 help: help,
109
110 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
111 claudeOAuth2: claude.NewOAuth2(),
112 }
113}
114
115func (m *modelDialogCmp) Init() tea.Cmd {
116 return tea.Batch(
117 m.modelList.Init(),
118 m.apiKeyInput.Init(),
119 m.claudeAuthMethodChooser.Init(),
120 m.claudeOAuth2.Init(),
121 )
122}
123
124func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
125 switch msg := msg.(type) {
126 case tea.WindowSizeMsg:
127 m.wWidth = msg.Width
128 m.wHeight = msg.Height
129 m.apiKeyInput.SetWidth(m.width - 2)
130 m.help.SetWidth(m.width - 2)
131 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
132 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
133 case APIKeyStateChangeMsg:
134 u, cmd := m.apiKeyInput.Update(msg)
135 m.apiKeyInput = u.(*APIKeyInput)
136 return m, cmd
137 case hyper.DeviceFlowCompletedMsg:
138 return m, m.saveOauthTokenAndContinue(msg.Token, true)
139 case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
140 if m.hyperDeviceFlow != nil {
141 u, cmd := m.hyperDeviceFlow.Update(msg)
142 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
143 return m, cmd
144 }
145 return m, nil
146 case claude.ValidationCompletedMsg:
147 var cmds []tea.Cmd
148 u, cmd := m.claudeOAuth2.Update(msg)
149 m.claudeOAuth2 = u.(*claude.OAuth2)
150 cmds = append(cmds, cmd)
151
152 if msg.State == claude.OAuthValidationStateValid {
153 cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
154 m.keyMap.isClaudeOAuthHelpComplete = true
155 }
156
157 return m, tea.Batch(cmds...)
158 case claude.AuthenticationCompleteMsg:
159 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
160 case tea.KeyPressMsg:
161 switch {
162 // Handle Hyper device flow keys
163 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
164 if m.hyperDeviceFlow != nil {
165 return m, m.hyperDeviceFlow.CopyCode()
166 }
167 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
168 return m, tea.Sequence(
169 tea.SetClipboard(m.claudeOAuth2.URL),
170 func() tea.Msg {
171 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
172 return nil
173 },
174 util.ReportInfo("URL copied to clipboard"),
175 )
176 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
177 m.claudeAuthMethodChooser.ToggleChoice()
178 return m, nil
179 case key.Matches(msg, m.keyMap.Select):
180 // If showing device flow, enter copies code and opens URL
181 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
182 return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
183 }
184 selectedItem := m.modelList.SelectedModel()
185
186 modelType := config.SelectedModelTypeLarge
187 if m.modelList.GetModelType() == SmallModelType {
188 modelType = config.SelectedModelTypeSmall
189 }
190
191 askForApiKey := func() {
192 m.keyMap.isClaudeAuthChoiceHelp = false
193 m.keyMap.isClaudeOAuthHelp = false
194 m.keyMap.isAPIKeyHelp = true
195 m.showHyperDeviceFlow = false
196 m.showClaudeAuthMethodChooser = false
197 m.needsAPIKey = true
198 m.selectedModel = selectedItem
199 m.selectedModelType = modelType
200 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
201 }
202
203 if m.showClaudeAuthMethodChooser {
204 switch m.claudeAuthMethodChooser.State {
205 case claude.AuthMethodAPIKey:
206 askForApiKey()
207 case claude.AuthMethodOAuth2:
208 m.selectedModel = selectedItem
209 m.selectedModelType = modelType
210 m.showClaudeAuthMethodChooser = false
211 m.showClaudeOAuth2 = true
212 m.keyMap.isClaudeAuthChoiceHelp = false
213 m.keyMap.isClaudeOAuthHelp = true
214 }
215 return m, nil
216 }
217 if m.showClaudeOAuth2 {
218 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
219 m.claudeOAuth2 = m2.(*claude.OAuth2)
220 return m, cmd2
221 }
222 if m.isAPIKeyValid {
223 return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
224 }
225 if m.needsAPIKey {
226 // Handle API key submission
227 m.apiKeyValue = m.apiKeyInput.Value()
228 provider, err := m.getProvider(m.selectedModel.Provider.ID)
229 if err != nil || provider == nil {
230 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
231 }
232 providerConfig := config.ProviderConfig{
233 ID: string(m.selectedModel.Provider.ID),
234 Name: m.selectedModel.Provider.Name,
235 APIKey: m.apiKeyValue,
236 Type: provider.Type,
237 BaseURL: provider.APIEndpoint,
238 }
239 return m, tea.Sequence(
240 util.CmdHandler(APIKeyStateChangeMsg{
241 State: APIKeyInputStateVerifying,
242 }),
243 func() tea.Msg {
244 start := time.Now()
245 err := providerConfig.TestConnection(config.Get().Resolver())
246 // intentionally wait for at least 750ms to make sure the user sees the spinner
247 elapsed := time.Since(start)
248 if elapsed < 750*time.Millisecond {
249 time.Sleep(750*time.Millisecond - elapsed)
250 }
251 if err == nil {
252 m.isAPIKeyValid = true
253 return APIKeyStateChangeMsg{
254 State: APIKeyInputStateVerified,
255 }
256 }
257 return APIKeyStateChangeMsg{
258 State: APIKeyInputStateError,
259 }
260 },
261 )
262 }
263
264 // Check if provider is configured
265 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
266 return m, tea.Sequence(
267 util.CmdHandler(dialogs.CloseDialogMsg{}),
268 util.CmdHandler(ModelSelectedMsg{
269 Model: config.SelectedModel{
270 Model: selectedItem.Model.ID,
271 Provider: string(selectedItem.Provider.ID),
272 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
273 MaxTokens: selectedItem.Model.DefaultMaxTokens,
274 },
275 ModelType: modelType,
276 }),
277 )
278 }
279 switch selectedItem.Provider.ID {
280 case catwalk.InferenceProviderAnthropic:
281 m.showClaudeAuthMethodChooser = true
282 m.keyMap.isClaudeAuthChoiceHelp = true
283 return m, nil
284 case hyperp.Name:
285 m.showHyperDeviceFlow = true
286 m.selectedModel = selectedItem
287 m.selectedModelType = modelType
288 m.hyperDeviceFlow = hyper.NewDeviceFlow()
289 m.hyperDeviceFlow.SetWidth(m.width - 2)
290 return m, m.hyperDeviceFlow.Init()
291 }
292 // For other providers, show API key input
293 askForApiKey()
294 return m, nil
295 case key.Matches(msg, m.keyMap.Tab):
296 switch {
297 case m.showClaudeAuthMethodChooser:
298 m.claudeAuthMethodChooser.ToggleChoice()
299 return m, nil
300 case m.needsAPIKey:
301 u, cmd := m.apiKeyInput.Update(msg)
302 m.apiKeyInput = u.(*APIKeyInput)
303 return m, cmd
304 case m.modelList.GetModelType() == LargeModelType:
305 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
306 return m, m.modelList.SetModelType(SmallModelType)
307 default:
308 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
309 return m, m.modelList.SetModelType(LargeModelType)
310 }
311 case key.Matches(msg, m.keyMap.Close):
312 if m.showHyperDeviceFlow {
313 // Cancel device flow and go back to model selection
314 if m.hyperDeviceFlow != nil {
315 m.hyperDeviceFlow.Cancel()
316 }
317 m.showHyperDeviceFlow = false
318 m.selectedModel = nil
319 }
320 if m.showClaudeAuthMethodChooser {
321 m.claudeAuthMethodChooser.SetDefaults()
322 m.showClaudeAuthMethodChooser = false
323 m.keyMap.isClaudeAuthChoiceHelp = false
324 m.keyMap.isClaudeOAuthHelp = false
325 return m, nil
326 }
327 if m.needsAPIKey {
328 if m.isAPIKeyValid {
329 return m, nil
330 }
331 // Go back to model selection
332 m.needsAPIKey = false
333 m.selectedModel = nil
334 m.isAPIKeyValid = false
335 m.apiKeyValue = ""
336 m.apiKeyInput.Reset()
337 return m, nil
338 }
339 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
340 default:
341 if m.showClaudeAuthMethodChooser {
342 u, cmd := m.claudeAuthMethodChooser.Update(msg)
343 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
344 return m, cmd
345 } else if m.showClaudeOAuth2 {
346 u, cmd := m.claudeOAuth2.Update(msg)
347 m.claudeOAuth2 = u.(*claude.OAuth2)
348 return m, cmd
349 } else if m.needsAPIKey {
350 u, cmd := m.apiKeyInput.Update(msg)
351 m.apiKeyInput = u.(*APIKeyInput)
352 return m, cmd
353 } else {
354 u, cmd := m.modelList.Update(msg)
355 m.modelList = u
356 return m, cmd
357 }
358 }
359 case tea.PasteMsg:
360 if m.showClaudeOAuth2 {
361 u, cmd := m.claudeOAuth2.Update(msg)
362 m.claudeOAuth2 = u.(*claude.OAuth2)
363 return m, cmd
364 } else if m.needsAPIKey {
365 u, cmd := m.apiKeyInput.Update(msg)
366 m.apiKeyInput = u.(*APIKeyInput)
367 return m, cmd
368 } else {
369 var cmd tea.Cmd
370 m.modelList, cmd = m.modelList.Update(msg)
371 return m, cmd
372 }
373 case spinner.TickMsg:
374 u, cmd := m.apiKeyInput.Update(msg)
375 m.apiKeyInput = u.(*APIKeyInput)
376 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
377 u, cmd = m.hyperDeviceFlow.Update(msg)
378 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
379 }
380 return m, cmd
381 default:
382 // Pass all other messages to the device flow for spinner animation
383 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
384 u, cmd := m.hyperDeviceFlow.Update(msg)
385 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
386 return m, cmd
387 } else if m.showClaudeOAuth2 {
388 u, cmd := m.claudeOAuth2.Update(msg)
389 m.claudeOAuth2 = u.(*claude.OAuth2)
390 return m, cmd
391 } else {
392 u, cmd := m.apiKeyInput.Update(msg)
393 m.apiKeyInput = u.(*APIKeyInput)
394 return m, cmd
395 }
396 }
397 return m, nil
398}
399
400func (m *modelDialogCmp) View() string {
401 t := styles.CurrentTheme()
402
403 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
404 // Show Hyper device flow
405 m.keyMap.isHyperDeviceFlow = true
406 deviceFlowView := m.hyperDeviceFlow.View()
407 content := lipgloss.JoinVertical(
408 lipgloss.Left,
409 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
410 deviceFlowView,
411 "",
412 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
413 )
414 return m.style().Render(content)
415 }
416
417 // Reset the flags when not showing device flow
418 m.keyMap.isHyperDeviceFlow = false
419
420 switch {
421 case m.showClaudeAuthMethodChooser:
422 chooserView := m.claudeAuthMethodChooser.View()
423 content := lipgloss.JoinVertical(
424 lipgloss.Left,
425 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
426 chooserView,
427 "",
428 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
429 )
430 return m.style().Render(content)
431 case m.showClaudeOAuth2:
432 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
433 oauth2View := m.claudeOAuth2.View()
434 content := lipgloss.JoinVertical(
435 lipgloss.Left,
436 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
437 oauth2View,
438 "",
439 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
440 )
441 return m.style().Render(content)
442 case m.needsAPIKey:
443 // Show API key input
444 m.keyMap.isAPIKeyHelp = true
445 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
446 apiKeyView := m.apiKeyInput.View()
447 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
448 content := lipgloss.JoinVertical(
449 lipgloss.Left,
450 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
451 apiKeyView,
452 "",
453 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
454 )
455 return m.style().Render(content)
456 }
457
458 // Show model selection
459 listView := m.modelList.View()
460 radio := m.modelTypeRadio()
461 content := lipgloss.JoinVertical(
462 lipgloss.Left,
463 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
464 listView,
465 "",
466 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
467 )
468 return m.style().Render(content)
469}
470
471func (m *modelDialogCmp) Cursor() *tea.Cursor {
472 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
473 return m.hyperDeviceFlow.Cursor()
474 }
475 if m.showClaudeAuthMethodChooser {
476 return nil
477 }
478 if m.showClaudeOAuth2 {
479 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
480 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
481 return m.moveCursor(cursor)
482 }
483 return nil
484 }
485 if m.needsAPIKey {
486 cursor := m.apiKeyInput.Cursor()
487 if cursor != nil {
488 cursor = m.moveCursor(cursor)
489 return cursor
490 }
491 } else {
492 cursor := m.modelList.Cursor()
493 if cursor != nil {
494 cursor = m.moveCursor(cursor)
495 return cursor
496 }
497 }
498 return nil
499}
500
501func (m *modelDialogCmp) style() lipgloss.Style {
502 t := styles.CurrentTheme()
503 return t.S().Base.
504 Width(m.width).
505 Border(lipgloss.RoundedBorder()).
506 BorderForeground(t.BorderFocus)
507}
508
509func (m *modelDialogCmp) listWidth() int {
510 return m.width - 2
511}
512
513func (m *modelDialogCmp) listHeight() int {
514 return m.wHeight / 2
515}
516
517func (m *modelDialogCmp) Position() (int, int) {
518 row := m.wHeight/4 - 2 // just a bit above the center
519 col := m.wWidth / 2
520 col -= m.width / 2
521 return row, col
522}
523
524func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
525 row, col := m.Position()
526 if m.needsAPIKey {
527 offset := row + 3 // Border + title + API key input offset
528 cursor.Y += offset
529 cursor.X = cursor.X + col + 2
530 } else {
531 offset := row + 3 // Border + title
532 cursor.Y += offset
533 cursor.X = cursor.X + col + 2
534 }
535 return cursor
536}
537
538func (m *modelDialogCmp) ID() dialogs.DialogID {
539 return ModelsDialogID
540}
541
542func (m *modelDialogCmp) modelTypeRadio() string {
543 t := styles.CurrentTheme()
544 choices := []string{"Large Task", "Small Task"}
545 iconSelected := "◉"
546 iconUnselected := "○"
547 if m.modelList.GetModelType() == LargeModelType {
548 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
549 }
550 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
551}
552
553func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
554 cfg := config.Get()
555 _, ok := cfg.Providers.Get(providerID)
556 return ok
557}
558
559func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
560 cfg := config.Get()
561 providers, err := config.Providers(cfg)
562 if err != nil {
563 return nil, err
564 }
565 for _, p := range providers {
566 if p.ID == providerID {
567 return &p, nil
568 }
569 }
570 return nil, nil
571}
572
573func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
574 if m.selectedModel == nil {
575 return util.ReportError(fmt.Errorf("no model selected"))
576 }
577
578 cfg := config.Get()
579 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
580 if err != nil {
581 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
582 }
583
584 // Reset API key state and continue with model selection
585 selectedModel := *m.selectedModel
586 var cmds []tea.Cmd
587 if close {
588 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
589 }
590 cmds = append(
591 cmds,
592 util.CmdHandler(ModelSelectedMsg{
593 Model: config.SelectedModel{
594 Model: selectedModel.Model.ID,
595 Provider: string(selectedModel.Provider.ID),
596 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
597 MaxTokens: selectedModel.Model.DefaultMaxTokens,
598 },
599 ModelType: m.selectedModelType,
600 }),
601 )
602 return tea.Sequence(cmds...)
603}