1package models
2
3import (
4 "fmt"
5 "time"
6
7 "charm.land/bubbles/v2/help"
8 "charm.land/bubbles/v2/key"
9 "charm.land/bubbles/v2/spinner"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/atotto/clipboard"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
18 "github.com/charmbracelet/crush/internal/tui/exp/list"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
const (
	// ModelsDialogID is the dialog identifier returned by the model
	// dialog's ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the width assigned to a newly constructed model dialog.
	defaultWidth = 60
)
28
const (
	// LargeModelType and SmallModelType are the values the dialog compares
	// against the model list's current model type.
	LargeModelType int = iota
	SmallModelType

	// Placeholder texts set on the model list's input; the dialog swaps them
	// when the model type is toggled.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
36
// ModelSelectedMsg is sent when a model is selected. It carries the chosen
// model together with the slot (large or small) the choice applies to.
type ModelSelectedMsg struct {
	// Model is the selected model, including its provider and default
	// reasoning effort / max tokens.
	Model config.SelectedModel
	// ModelType tells whether the selection is for the large or small slot.
	ModelType config.SelectedModelType
}
42
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file sends or handles it; the dialog closes
// itself with dialogs.CloseDialogMsg instead. Confirm whether it is still used.
type CloseModelDialogMsg struct{}
45
// ModelDialog is the interface of the model selection dialog. It currently
// adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
50
// ModelOption pairs a model with the provider that offers it. It is the value
// the dialog gets back from the model list's SelectedModel.
type ModelOption struct {
	// Provider is the provider offering Model.
	Provider catwalk.Provider
	// Model is the selectable model itself.
	Model catwalk.Model
}
55
// modelDialogCmp is the model selection dialog. Besides the model list it
// hosts three alternative sub-views, chosen by the boolean flags below: the
// Claude auth method chooser, the Claude OAuth2 flow and the API key input.
type modelDialogCmp struct {
	// width is the dialog's own width; wWidth and wHeight hold the last
	// window size received through tea.WindowSizeMsg.
	width   int
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool                     // the API key input view is active
	apiKeyInput       *APIKeyInput             // input component for the key
	selectedModel     *ModelOption             // model being configured; nil when none
	selectedModelType config.SelectedModelType // slot the selected model applies to
	isAPIKeyValid     bool                     // the submitted key passed the connection test
	apiKeyValue       string                   // key value read from the input on submission

	// Claude state
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool // the auth method chooser view is active
	showClaudeOAuth2            bool // the OAuth2 view is active
}
79
80func NewModelDialogCmp() ModelDialog {
81 keyMap := DefaultKeyMap()
82
83 listKeyMap := list.DefaultKeyMap()
84 listKeyMap.Down.SetEnabled(false)
85 listKeyMap.Up.SetEnabled(false)
86 listKeyMap.DownOneItem = keyMap.Next
87 listKeyMap.UpOneItem = keyMap.Previous
88
89 t := styles.CurrentTheme()
90 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
91 apiKeyInput := NewAPIKeyInput()
92 apiKeyInput.SetShowTitle(false)
93 help := help.New()
94 help.Styles = t.S().Help
95
96 return &modelDialogCmp{
97 modelList: modelList,
98 apiKeyInput: apiKeyInput,
99 width: defaultWidth,
100 keyMap: DefaultKeyMap(),
101 help: help,
102
103 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104 claudeOAuth2: claude.NewOAuth2(),
105 }
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109 return tea.Batch(
110 m.modelList.Init(),
111 m.apiKeyInput.Init(),
112 m.claudeAuthMethodChooser.Init(),
113 m.claudeOAuth2.Init(),
114 )
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118 switch msg := msg.(type) {
119 case tea.WindowSizeMsg:
120 m.wWidth = msg.Width
121 m.wHeight = msg.Height
122 m.apiKeyInput.SetWidth(m.width - 2)
123 m.help.SetWidth(m.width - 2)
124 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126 case APIKeyStateChangeMsg:
127 u, cmd := m.apiKeyInput.Update(msg)
128 m.apiKeyInput = u.(*APIKeyInput)
129 return m, cmd
130 case claude.ValidationCompletedMsg:
131 var cmds []tea.Cmd
132 u, cmd := m.claudeOAuth2.Update(msg)
133 m.claudeOAuth2 = u.(*claude.OAuth2)
134 cmds = append(cmds, cmd)
135
136 if msg.State == claude.OAuthValidationStateValid {
137 cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138 m.keyMap.isClaudeOAuthHelpComplete = true
139 }
140
141 return m, tea.Batch(cmds...)
142 case claude.AuthenticationCompleteMsg:
143 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144 case tea.KeyPressMsg:
145 switch {
146 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147 return m, tea.Sequence(
148 tea.SetClipboard(m.claudeOAuth2.URL),
149 func() tea.Msg {
150 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
151 return nil
152 },
153 util.ReportInfo("URL copied to clipboard"),
154 )
155 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156 m.claudeAuthMethodChooser.ToggleChoice()
157 return m, nil
158 case key.Matches(msg, m.keyMap.Select):
159 selectedItem := m.modelList.SelectedModel()
160
161 modelType := config.SelectedModelTypeLarge
162 if m.modelList.GetModelType() == SmallModelType {
163 modelType = config.SelectedModelTypeSmall
164 }
165
166 askForApiKey := func() {
167 m.keyMap.isClaudeAuthChoiseHelp = false
168 m.keyMap.isClaudeOAuthHelp = false
169 m.keyMap.isAPIKeyHelp = true
170 m.showClaudeAuthMethodChooser = false
171 m.needsAPIKey = true
172 m.selectedModel = selectedItem
173 m.selectedModelType = modelType
174 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175 }
176
177 if m.showClaudeAuthMethodChooser {
178 switch m.claudeAuthMethodChooser.State {
179 case claude.AuthMethodAPIKey:
180 askForApiKey()
181 case claude.AuthMethodOAuth2:
182 m.selectedModel = selectedItem
183 m.selectedModelType = modelType
184 m.showClaudeAuthMethodChooser = false
185 m.showClaudeOAuth2 = true
186 m.keyMap.isClaudeAuthChoiseHelp = false
187 m.keyMap.isClaudeOAuthHelp = true
188 }
189 return m, nil
190 }
191 if m.showClaudeOAuth2 {
192 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193 m.claudeOAuth2 = m2.(*claude.OAuth2)
194 return m, cmd2
195 }
196 if m.isAPIKeyValid {
197 return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198 }
199 if m.needsAPIKey {
200 // Handle API key submission
201 m.apiKeyValue = m.apiKeyInput.Value()
202 provider, err := m.getProvider(m.selectedModel.Provider.ID)
203 if err != nil || provider == nil {
204 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205 }
206 providerConfig := config.ProviderConfig{
207 ID: string(m.selectedModel.Provider.ID),
208 Name: m.selectedModel.Provider.Name,
209 APIKey: m.apiKeyValue,
210 Type: provider.Type,
211 BaseURL: provider.APIEndpoint,
212 }
213 return m, tea.Sequence(
214 util.CmdHandler(APIKeyStateChangeMsg{
215 State: APIKeyInputStateVerifying,
216 }),
217 func() tea.Msg {
218 start := time.Now()
219 err := providerConfig.TestConnection(config.Get().Resolver())
220 // intentionally wait for at least 750ms to make sure the user sees the spinner
221 elapsed := time.Since(start)
222 if elapsed < 750*time.Millisecond {
223 time.Sleep(750*time.Millisecond - elapsed)
224 }
225 if err == nil {
226 m.isAPIKeyValid = true
227 return APIKeyStateChangeMsg{
228 State: APIKeyInputStateVerified,
229 }
230 }
231 return APIKeyStateChangeMsg{
232 State: APIKeyInputStateError,
233 }
234 },
235 )
236 }
237
238 // Check if provider is configured
239 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240 return m, tea.Sequence(
241 util.CmdHandler(dialogs.CloseDialogMsg{}),
242 util.CmdHandler(ModelSelectedMsg{
243 Model: config.SelectedModel{
244 Model: selectedItem.Model.ID,
245 Provider: string(selectedItem.Provider.ID),
246 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247 MaxTokens: selectedItem.Model.DefaultMaxTokens,
248 },
249 ModelType: modelType,
250 }),
251 )
252 } else {
253 if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254 m.showClaudeAuthMethodChooser = true
255 m.keyMap.isClaudeAuthChoiseHelp = true
256 return m, nil
257 }
258 askForApiKey()
259 return m, nil
260 }
261 case key.Matches(msg, m.keyMap.Tab):
262 switch {
263 case m.showClaudeAuthMethodChooser:
264 m.claudeAuthMethodChooser.ToggleChoice()
265 return m, nil
266 case m.needsAPIKey:
267 u, cmd := m.apiKeyInput.Update(msg)
268 m.apiKeyInput = u.(*APIKeyInput)
269 return m, cmd
270 case m.modelList.GetModelType() == LargeModelType:
271 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272 return m, m.modelList.SetModelType(SmallModelType)
273 default:
274 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275 return m, m.modelList.SetModelType(LargeModelType)
276 }
277 case key.Matches(msg, m.keyMap.Edit):
278 // Only handle edit key in the main model selection view
279 if !m.showClaudeAuthMethodChooser && !m.showClaudeOAuth2 && !m.needsAPIKey {
280 selectedItem := m.modelList.SelectedModel()
281 if selectedItem != nil {
282 // Check if provider is configured
283 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284 // Trigger API key editing flow
285 m.keyMap.isClaudeAuthChoiseHelp = false
286 m.keyMap.isClaudeOAuthHelp = false
287 m.keyMap.isAPIKeyHelp = true
288 m.showClaudeAuthMethodChooser = false
289 m.needsAPIKey = true
290 m.selectedModel = selectedItem
291 m.selectedModelType = config.SelectedModelTypeLarge
292 if m.modelList.GetModelType() == SmallModelType {
293 m.selectedModelType = config.SelectedModelTypeSmall
294 }
295 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
296 // Pre-fill with existing API key if available
297 if providerConfig, ok := config.Get().Providers.Get(string(selectedItem.Provider.ID)); ok && providerConfig.APIKey != "" {
298 m.apiKeyInput.input.SetValue(providerConfig.APIKey)
299 }
300 return m, nil
301 }
302 }
303 }
304 return m, nil
305 case key.Matches(msg, m.keyMap.Close):
306 if m.showClaudeAuthMethodChooser {
307 m.claudeAuthMethodChooser.SetDefaults()
308 m.showClaudeAuthMethodChooser = false
309 m.keyMap.isClaudeAuthChoiseHelp = false
310 m.keyMap.isClaudeOAuthHelp = false
311 return m, nil
312 }
313 if m.needsAPIKey {
314 if m.isAPIKeyValid {
315 return m, nil
316 }
317 // Go back to model selection
318 m.needsAPIKey = false
319 m.selectedModel = nil
320 m.isAPIKeyValid = false
321 m.apiKeyValue = ""
322 m.apiKeyInput.Reset()
323 m.keyMap.isAPIKeyHelp = false
324 return m, nil
325 }
326 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
327 default:
328 if m.showClaudeAuthMethodChooser {
329 u, cmd := m.claudeAuthMethodChooser.Update(msg)
330 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
331 return m, cmd
332 } else if m.showClaudeOAuth2 {
333 u, cmd := m.claudeOAuth2.Update(msg)
334 m.claudeOAuth2 = u.(*claude.OAuth2)
335 return m, cmd
336 } else if m.needsAPIKey {
337 u, cmd := m.apiKeyInput.Update(msg)
338 m.apiKeyInput = u.(*APIKeyInput)
339 return m, cmd
340 } else {
341 u, cmd := m.modelList.Update(msg)
342 m.modelList = u
343 return m, cmd
344 }
345 }
346 case tea.PasteMsg:
347 if m.showClaudeOAuth2 {
348 u, cmd := m.claudeOAuth2.Update(msg)
349 m.claudeOAuth2 = u.(*claude.OAuth2)
350 return m, cmd
351 } else if m.needsAPIKey {
352 u, cmd := m.apiKeyInput.Update(msg)
353 m.apiKeyInput = u.(*APIKeyInput)
354 return m, cmd
355 } else {
356 var cmd tea.Cmd
357 m.modelList, cmd = m.modelList.Update(msg)
358 return m, cmd
359 }
360 case spinner.TickMsg:
361 if m.showClaudeOAuth2 {
362 u, cmd := m.claudeOAuth2.Update(msg)
363 m.claudeOAuth2 = u.(*claude.OAuth2)
364 return m, cmd
365 } else {
366 u, cmd := m.apiKeyInput.Update(msg)
367 m.apiKeyInput = u.(*APIKeyInput)
368 return m, cmd
369 }
370 }
371 return m, nil
372}
373
374func (m *modelDialogCmp) View() string {
375 t := styles.CurrentTheme()
376
377 switch {
378 case m.showClaudeAuthMethodChooser:
379 chooserView := m.claudeAuthMethodChooser.View()
380 content := lipgloss.JoinVertical(
381 lipgloss.Left,
382 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
383 chooserView,
384 "",
385 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
386 )
387 return m.style().Render(content)
388 case m.showClaudeOAuth2:
389 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
390 oauth2View := m.claudeOAuth2.View()
391 content := lipgloss.JoinVertical(
392 lipgloss.Left,
393 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
394 oauth2View,
395 "",
396 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
397 )
398 return m.style().Render(content)
399 case m.needsAPIKey:
400 // Show API key input
401 m.keyMap.isAPIKeyHelp = true
402 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
403 apiKeyView := m.apiKeyInput.View()
404 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
405 content := lipgloss.JoinVertical(
406 lipgloss.Left,
407 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
408 apiKeyView,
409 "",
410 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
411 )
412 return m.style().Render(content)
413 }
414
415 // Show model selection
416 listView := m.modelList.View()
417 radio := m.modelTypeRadio()
418 content := lipgloss.JoinVertical(
419 lipgloss.Left,
420 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
421 listView,
422 "",
423 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
424 )
425 return m.style().Render(content)
426}
427
428func (m *modelDialogCmp) Cursor() *tea.Cursor {
429 if m.showClaudeAuthMethodChooser {
430 return nil
431 }
432 if m.showClaudeOAuth2 {
433 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
434 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
435 return m.moveCursor(cursor)
436 }
437 return nil
438 }
439 if m.needsAPIKey {
440 cursor := m.apiKeyInput.Cursor()
441 if cursor != nil {
442 cursor = m.moveCursor(cursor)
443 return cursor
444 }
445 } else {
446 cursor := m.modelList.Cursor()
447 if cursor != nil {
448 cursor = m.moveCursor(cursor)
449 return cursor
450 }
451 }
452 return nil
453}
454
455func (m *modelDialogCmp) style() lipgloss.Style {
456 t := styles.CurrentTheme()
457 return t.S().Base.
458 Width(m.width).
459 Border(lipgloss.RoundedBorder()).
460 BorderForeground(t.BorderFocus)
461}
462
463func (m *modelDialogCmp) listWidth() int {
464 return m.width - 2
465}
466
467func (m *modelDialogCmp) listHeight() int {
468 return m.wHeight / 2
469}
470
471func (m *modelDialogCmp) Position() (int, int) {
472 row := m.wHeight/4 - 2 // just a bit above the center
473 col := m.wWidth / 2
474 col -= m.width / 2
475 return row, col
476}
477
478func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
479 row, col := m.Position()
480 if m.needsAPIKey {
481 offset := row + 3 // Border + title + API key input offset
482 cursor.Y += offset
483 cursor.X = cursor.X + col + 2
484 } else {
485 offset := row + 3 // Border + title
486 cursor.Y += offset
487 cursor.X = cursor.X + col + 2
488 }
489 return cursor
490}
491
// ID returns the dialog identifier of the model dialog, ModelsDialogID.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
495
496func (m *modelDialogCmp) modelTypeRadio() string {
497 t := styles.CurrentTheme()
498 choices := []string{"Large Task", "Small Task"}
499 iconSelected := "◉"
500 iconUnselected := "○"
501 if m.modelList.GetModelType() == LargeModelType {
502 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
503 }
504 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
505}
506
507func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
508 cfg := config.Get()
509 if _, ok := cfg.Providers.Get(providerID); ok {
510 return true
511 }
512 return false
513}
514
515func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
516 cfg := config.Get()
517 providers, err := config.Providers(cfg)
518 if err != nil {
519 return nil, err
520 }
521 for _, p := range providers {
522 if p.ID == providerID {
523 return &p, nil
524 }
525 }
526 return nil, nil
527}
528
529func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
530 if m.selectedModel == nil {
531 return util.ReportError(fmt.Errorf("no model selected"))
532 }
533
534 cfg := config.Get()
535 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
536 if err != nil {
537 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
538 }
539
540 // Reset API key state and continue with model selection
541 selectedModel := *m.selectedModel
542 var cmds []tea.Cmd
543 if close {
544 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
545 }
546 cmds = append(
547 cmds,
548 util.CmdHandler(ModelSelectedMsg{
549 Model: config.SelectedModel{
550 Model: selectedModel.Model.ID,
551 Provider: string(selectedModel.Provider.ID),
552 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
553 MaxTokens: selectedModel.Model.DefaultMaxTokens,
554 },
555 ModelType: m.selectedModelType,
556 }),
557 )
558 return tea.Sequence(cmds...)
559}