1package models
2
3import (
4 "fmt"
5 "time"
6
7 "charm.land/bubbles/v2/help"
8 "charm.land/bubbles/v2/key"
9 "charm.land/bubbles/v2/spinner"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/atotto/clipboard"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
18 "github.com/charmbracelet/crush/internal/tui/exp/list"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
23const (
24 ModelsDialogID dialogs.DialogID = "models"
25
26 defaultWidth = 60
27)
28
29const (
30 LargeModelType int = iota
31 SmallModelType
32
33 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
34 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
35)
36
37// ModelSelectedMsg is sent when a model is selected
38type ModelSelectedMsg struct {
39 Model config.SelectedModel
40 ModelType config.SelectedModelType
41}
42
43// CloseModelDialogMsg is sent when a model is selected
44type CloseModelDialogMsg struct{}
45
46// ModelDialog interface for the model selection dialog
47type ModelDialog interface {
48 dialogs.DialogModel
49}
50
51type ModelOption struct {
52 Provider catwalk.Provider
53 Model catwalk.Model
54}
55
56type modelDialogCmp struct {
57 width int
58 wWidth int
59 wHeight int
60
61 modelList *ModelListComponent
62 keyMap KeyMap
63 help help.Model
64
65 // API key state
66 needsAPIKey bool
67 apiKeyInput *APIKeyInput
68 selectedModel *ModelOption
69 selectedModelType config.SelectedModelType
70 isAPIKeyValid bool
71 apiKeyValue string
72
73 // Claude state
74 claudeAuthMethodChooser *claude.AuthMethodChooser
75 claudeOAuth2 *claude.OAuth2
76 showClaudeAuthMethodChooser bool
77 showClaudeOAuth2 bool
78}
79
80func NewModelDialogCmp() ModelDialog {
81 keyMap := DefaultKeyMap()
82
83 listKeyMap := list.DefaultKeyMap()
84 listKeyMap.Down.SetEnabled(false)
85 listKeyMap.Up.SetEnabled(false)
86 listKeyMap.DownOneItem = keyMap.Next
87 listKeyMap.UpOneItem = keyMap.Previous
88
89 t := styles.CurrentTheme()
90 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
91 apiKeyInput := NewAPIKeyInput()
92 apiKeyInput.SetShowTitle(false)
93 help := help.New()
94 help.Styles = t.S().Help
95
96 return &modelDialogCmp{
97 modelList: modelList,
98 apiKeyInput: apiKeyInput,
99 width: defaultWidth,
100 keyMap: DefaultKeyMap(),
101 help: help,
102
103 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104 claudeOAuth2: claude.NewOAuth2(),
105 }
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109 return tea.Batch(
110 m.modelList.Init(),
111 m.apiKeyInput.Init(),
112 m.claudeAuthMethodChooser.Init(),
113 m.claudeOAuth2.Init(),
114 )
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118 switch msg := msg.(type) {
119 case tea.WindowSizeMsg:
120 m.wWidth = msg.Width
121 m.wHeight = msg.Height
122 m.apiKeyInput.SetWidth(m.width - 2)
123 m.help.SetWidth(m.width - 2)
124 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126 case APIKeyStateChangeMsg:
127 u, cmd := m.apiKeyInput.Update(msg)
128 m.apiKeyInput = u.(*APIKeyInput)
129 return m, cmd
130 case claude.ValidationCompletedMsg:
131 var cmds []tea.Cmd
132 u, cmd := m.claudeOAuth2.Update(msg)
133 m.claudeOAuth2 = u.(*claude.OAuth2)
134 cmds = append(cmds, cmd)
135
136 if msg.State == claude.OAuthValidationStateValid {
137 cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138 m.keyMap.isClaudeOAuthHelpComplete = true
139 }
140
141 return m, tea.Batch(cmds...)
142 case claude.AuthenticationCompleteMsg:
143 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144 case tea.KeyPressMsg:
145 switch {
146 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147 return m, tea.Sequence(
148 tea.SetClipboard(m.claudeOAuth2.URL),
149 func() tea.Msg {
150 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
151 return nil
152 },
153 util.ReportInfo("URL copied to clipboard"),
154 )
155 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156 m.claudeAuthMethodChooser.ToggleChoice()
157 return m, nil
158 case key.Matches(msg, m.keyMap.Select):
159 selectedItem := m.modelList.SelectedModel()
160
161 modelType := config.SelectedModelTypeLarge
162 if m.modelList.GetModelType() == SmallModelType {
163 modelType = config.SelectedModelTypeSmall
164 }
165
166 askForApiKey := func() {
167 m.keyMap.isClaudeAuthChoiseHelp = false
168 m.keyMap.isClaudeOAuthHelp = false
169 m.keyMap.isAPIKeyHelp = true
170 m.showClaudeAuthMethodChooser = false
171 m.needsAPIKey = true
172 m.selectedModel = selectedItem
173 m.selectedModelType = modelType
174 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175 }
176
177 if m.showClaudeAuthMethodChooser {
178 switch m.claudeAuthMethodChooser.State {
179 case claude.AuthMethodAPIKey:
180 askForApiKey()
181 case claude.AuthMethodOAuth2:
182 m.selectedModel = selectedItem
183 m.selectedModelType = modelType
184 m.showClaudeAuthMethodChooser = false
185 m.showClaudeOAuth2 = true
186 m.keyMap.isClaudeAuthChoiseHelp = false
187 m.keyMap.isClaudeOAuthHelp = true
188 }
189 return m, nil
190 }
191 if m.showClaudeOAuth2 {
192 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193 m.claudeOAuth2 = m2.(*claude.OAuth2)
194 return m, cmd2
195 }
196 if m.isAPIKeyValid {
197 return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198 }
199 if m.needsAPIKey {
200 // Handle API key submission
201 m.apiKeyValue = m.apiKeyInput.Value()
202 provider, err := m.getProvider(m.selectedModel.Provider.ID)
203 if err != nil || provider == nil {
204 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205 }
206 providerConfig := config.ProviderConfig{
207 ID: string(m.selectedModel.Provider.ID),
208 Name: m.selectedModel.Provider.Name,
209 APIKey: m.apiKeyValue,
210 Type: provider.Type,
211 BaseURL: provider.APIEndpoint,
212 }
213 return m, tea.Sequence(
214 util.CmdHandler(APIKeyStateChangeMsg{
215 State: APIKeyInputStateVerifying,
216 }),
217 func() tea.Msg {
218 start := time.Now()
219 err := providerConfig.TestConnection(config.Get().Resolver())
220 // intentionally wait for at least 750ms to make sure the user sees the spinner
221 elapsed := time.Since(start)
222 if elapsed < 750*time.Millisecond {
223 time.Sleep(750*time.Millisecond - elapsed)
224 }
225 if err == nil {
226 m.isAPIKeyValid = true
227 return APIKeyStateChangeMsg{
228 State: APIKeyInputStateVerified,
229 }
230 }
231 return APIKeyStateChangeMsg{
232 State: APIKeyInputStateError,
233 }
234 },
235 )
236 }
237
238 // Check if provider is configured
239 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240 return m, tea.Sequence(
241 util.CmdHandler(dialogs.CloseDialogMsg{}),
242 util.CmdHandler(ModelSelectedMsg{
243 Model: config.SelectedModel{
244 Model: selectedItem.Model.ID,
245 Provider: string(selectedItem.Provider.ID),
246 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247 MaxTokens: selectedItem.Model.DefaultMaxTokens,
248 },
249 ModelType: modelType,
250 }),
251 )
252 } else {
253 if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254 m.showClaudeAuthMethodChooser = true
255 m.keyMap.isClaudeAuthChoiseHelp = true
256 return m, nil
257 }
258 askForApiKey()
259 return m, nil
260 }
261 case key.Matches(msg, m.keyMap.Tab):
262 switch {
263 case m.showClaudeAuthMethodChooser:
264 m.claudeAuthMethodChooser.ToggleChoice()
265 return m, nil
266 case m.needsAPIKey:
267 u, cmd := m.apiKeyInput.Update(msg)
268 m.apiKeyInput = u.(*APIKeyInput)
269 return m, cmd
270 case m.modelList.GetModelType() == LargeModelType:
271 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272 return m, m.modelList.SetModelType(SmallModelType)
273 default:
274 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275 return m, m.modelList.SetModelType(LargeModelType)
276 }
277 case key.Matches(msg, m.keyMap.Edit):
278 // Only handle edit key in the main model selection view
279 if !m.showClaudeAuthMethodChooser && !m.showClaudeOAuth2 && !m.needsAPIKey {
280 selectedItem := m.modelList.SelectedModel()
281 if selectedItem != nil {
282 // Check if provider is configured
283 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284 // Trigger API key editing flow
285 m.keyMap.isClaudeAuthChoiseHelp = false
286 m.keyMap.isClaudeOAuthHelp = false
287 m.keyMap.isAPIKeyHelp = true
288 m.showClaudeAuthMethodChooser = false
289 m.needsAPIKey = true
290 m.selectedModel = selectedItem
291 m.selectedModelType = config.SelectedModelTypeLarge
292 if m.modelList.GetModelType() == SmallModelType {
293 m.selectedModelType = config.SelectedModelTypeSmall
294 }
295 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
296 // Pre-fill with existing API key if available
297 if providerConfig, ok := config.Get().Providers.Get(string(selectedItem.Provider.ID)); ok && providerConfig.APIKey != "" {
298 m.apiKeyInput.input.SetValue(providerConfig.APIKey)
299 }
300 return m, nil
301 }
302 }
303 }
304 return m, nil
305 case key.Matches(msg, m.keyMap.Close):
306 if m.showClaudeAuthMethodChooser {
307 m.claudeAuthMethodChooser.SetDefaults()
308 m.showClaudeAuthMethodChooser = false
309 m.keyMap.isClaudeAuthChoiseHelp = false
310 m.keyMap.isClaudeOAuthHelp = false
311 return m, nil
312 }
313 if m.needsAPIKey {
314 if m.isAPIKeyValid {
315 return m, nil
316 }
317 // Go back to model selection
318 m.needsAPIKey = false
319 m.selectedModel = nil
320 m.isAPIKeyValid = false
321 m.apiKeyValue = ""
322 m.apiKeyInput.Reset()
323 return m, nil
324 }
325 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
326 default:
327 if m.showClaudeAuthMethodChooser {
328 u, cmd := m.claudeAuthMethodChooser.Update(msg)
329 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
330 return m, cmd
331 } else if m.showClaudeOAuth2 {
332 u, cmd := m.claudeOAuth2.Update(msg)
333 m.claudeOAuth2 = u.(*claude.OAuth2)
334 return m, cmd
335 } else if m.needsAPIKey {
336 u, cmd := m.apiKeyInput.Update(msg)
337 m.apiKeyInput = u.(*APIKeyInput)
338 return m, cmd
339 } else {
340 u, cmd := m.modelList.Update(msg)
341 m.modelList = u
342 return m, cmd
343 }
344 }
345 case tea.PasteMsg:
346 if m.showClaudeOAuth2 {
347 u, cmd := m.claudeOAuth2.Update(msg)
348 m.claudeOAuth2 = u.(*claude.OAuth2)
349 return m, cmd
350 } else if m.needsAPIKey {
351 u, cmd := m.apiKeyInput.Update(msg)
352 m.apiKeyInput = u.(*APIKeyInput)
353 return m, cmd
354 } else {
355 var cmd tea.Cmd
356 m.modelList, cmd = m.modelList.Update(msg)
357 return m, cmd
358 }
359 case spinner.TickMsg:
360 if m.showClaudeOAuth2 {
361 u, cmd := m.claudeOAuth2.Update(msg)
362 m.claudeOAuth2 = u.(*claude.OAuth2)
363 return m, cmd
364 } else {
365 u, cmd := m.apiKeyInput.Update(msg)
366 m.apiKeyInput = u.(*APIKeyInput)
367 return m, cmd
368 }
369 }
370 return m, nil
371}
372
373func (m *modelDialogCmp) View() string {
374 t := styles.CurrentTheme()
375
376 switch {
377 case m.showClaudeAuthMethodChooser:
378 chooserView := m.claudeAuthMethodChooser.View()
379 content := lipgloss.JoinVertical(
380 lipgloss.Left,
381 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
382 chooserView,
383 "",
384 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
385 )
386 return m.style().Render(content)
387 case m.showClaudeOAuth2:
388 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
389 oauth2View := m.claudeOAuth2.View()
390 content := lipgloss.JoinVertical(
391 lipgloss.Left,
392 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
393 oauth2View,
394 "",
395 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
396 )
397 return m.style().Render(content)
398 case m.needsAPIKey:
399 // Show API key input
400 m.keyMap.isAPIKeyHelp = true
401 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
402 apiKeyView := m.apiKeyInput.View()
403 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
404 content := lipgloss.JoinVertical(
405 lipgloss.Left,
406 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
407 apiKeyView,
408 "",
409 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
410 )
411 return m.style().Render(content)
412 }
413
414 // Show model selection
415 listView := m.modelList.View()
416 radio := m.modelTypeRadio()
417 content := lipgloss.JoinVertical(
418 lipgloss.Left,
419 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
420 listView,
421 "",
422 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
423 )
424 return m.style().Render(content)
425}
426
427func (m *modelDialogCmp) Cursor() *tea.Cursor {
428 if m.showClaudeAuthMethodChooser {
429 return nil
430 }
431 if m.showClaudeOAuth2 {
432 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
433 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
434 return m.moveCursor(cursor)
435 }
436 return nil
437 }
438 if m.needsAPIKey {
439 cursor := m.apiKeyInput.Cursor()
440 if cursor != nil {
441 cursor = m.moveCursor(cursor)
442 return cursor
443 }
444 } else {
445 cursor := m.modelList.Cursor()
446 if cursor != nil {
447 cursor = m.moveCursor(cursor)
448 return cursor
449 }
450 }
451 return nil
452}
453
454func (m *modelDialogCmp) style() lipgloss.Style {
455 t := styles.CurrentTheme()
456 return t.S().Base.
457 Width(m.width).
458 Border(lipgloss.RoundedBorder()).
459 BorderForeground(t.BorderFocus)
460}
461
462func (m *modelDialogCmp) listWidth() int {
463 return m.width - 2
464}
465
466func (m *modelDialogCmp) listHeight() int {
467 return m.wHeight / 2
468}
469
470func (m *modelDialogCmp) Position() (int, int) {
471 row := m.wHeight/4 - 2 // just a bit above the center
472 col := m.wWidth / 2
473 col -= m.width / 2
474 return row, col
475}
476
477func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
478 row, col := m.Position()
479 if m.needsAPIKey {
480 offset := row + 3 // Border + title + API key input offset
481 cursor.Y += offset
482 cursor.X = cursor.X + col + 2
483 } else {
484 offset := row + 3 // Border + title
485 cursor.Y += offset
486 cursor.X = cursor.X + col + 2
487 }
488 return cursor
489}
490
491func (m *modelDialogCmp) ID() dialogs.DialogID {
492 return ModelsDialogID
493}
494
495func (m *modelDialogCmp) modelTypeRadio() string {
496 t := styles.CurrentTheme()
497 choices := []string{"Large Task", "Small Task"}
498 iconSelected := "◉"
499 iconUnselected := "○"
500 if m.modelList.GetModelType() == LargeModelType {
501 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
502 }
503 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
504}
505
506func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
507 cfg := config.Get()
508 if _, ok := cfg.Providers.Get(providerID); ok {
509 return true
510 }
511 return false
512}
513
514func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
515 cfg := config.Get()
516 providers, err := config.Providers(cfg)
517 if err != nil {
518 return nil, err
519 }
520 for _, p := range providers {
521 if p.ID == providerID {
522 return &p, nil
523 }
524 }
525 return nil, nil
526}
527
528func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
529 if m.selectedModel == nil {
530 return util.ReportError(fmt.Errorf("no model selected"))
531 }
532
533 cfg := config.Get()
534 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
535 if err != nil {
536 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
537 }
538
539 // Reset API key state and continue with model selection
540 selectedModel := *m.selectedModel
541 var cmds []tea.Cmd
542 if close {
543 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
544 }
545 cmds = append(
546 cmds,
547 util.CmdHandler(ModelSelectedMsg{
548 Model: config.SelectedModel{
549 Model: selectedModel.Model.ID,
550 Provider: string(selectedModel.Provider.ID),
551 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
552 MaxTokens: selectedModel.Model.DefaultMaxTokens,
553 },
554 ModelType: m.selectedModelType,
555 }),
556 )
557 return tea.Sequence(cmds...)
558}