1package models
2
3import (
4 "fmt"
5 "time"
6
7 "charm.land/bubbles/v2/help"
8 "charm.land/bubbles/v2/key"
9 "charm.land/bubbles/v2/spinner"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/atotto/clipboard"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
18 "github.com/charmbracelet/crush/internal/tui/exp/list"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
const (
	// ModelsDialogID is the dialog identifier returned by the model dialog's
	// ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the initial value of the dialog's width field, set by
	// NewModelDialogCmp.
	defaultWidth = 60
)
28
const (
	// LargeModelType and SmallModelType are the model-type values exchanged
	// with the model list through GetModelType and SetModelType.
	LargeModelType int = iota
	SmallModelType

	// Input placeholders handed to the model list for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
36
// ModelSelectedMsg is sent when a model is selected. Model carries the chosen
// model ID, provider ID, and the model's default reasoning effort and max
// tokens; ModelType says whether it was picked as the large or small model.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
42
// CloseModelDialogMsg signals that the model dialog should be closed.
//
// NOTE(review): nothing in this file emits it (the dialog closes itself with
// dialogs.CloseDialogMsg) — confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
45
// ModelDialog is the interface of the model selection dialog. It currently
// adds nothing on top of dialogs.DialogModel and is the type returned by
// NewModelDialogCmp.
type ModelDialog interface {
	dialogs.DialogModel
}
50
// ModelOption pairs a model with the provider that serves it. It is the value
// the model list reports as the current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
55
// modelDialogCmp is the concrete model dialog returned by NewModelDialogCmp.
// Besides the model list it hosts three alternative sub-views — the API key
// input, the Claude auth-method chooser and the Claude OAuth2 flow — and the
// boolean flags below decide which one is shown and receives input.
type modelDialogCmp struct {
	width   int // dialog width; initialised to defaultWidth
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool                     // true while the API key input is shown
	apiKeyInput       *APIKeyInput             // API key entry sub-view
	selectedModel     *ModelOption             // model awaiting credentials; nil when none is pending
	selectedModelType config.SelectedModelType // large/small slot the pending model was chosen for
	isAPIKeyValid     bool                     // true once the entered key passed the connection test
	apiKeyValue       string                   // key text captured when it was submitted

	// Claude state
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool // true while the auth-method chooser is shown
	showClaudeOAuth2            bool // true while the OAuth2 flow is shown
}
79
80func NewModelDialogCmp() ModelDialog {
81 keyMap := DefaultKeyMap()
82
83 listKeyMap := list.DefaultKeyMap()
84 listKeyMap.Down.SetEnabled(false)
85 listKeyMap.Up.SetEnabled(false)
86 listKeyMap.DownOneItem = keyMap.Next
87 listKeyMap.UpOneItem = keyMap.Previous
88
89 t := styles.CurrentTheme()
90 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
91 apiKeyInput := NewAPIKeyInput()
92 apiKeyInput.SetShowTitle(false)
93 help := help.New()
94 help.Styles = t.S().Help
95
96 return &modelDialogCmp{
97 modelList: modelList,
98 apiKeyInput: apiKeyInput,
99 width: defaultWidth,
100 keyMap: DefaultKeyMap(),
101 help: help,
102
103 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104 claudeOAuth2: claude.NewOAuth2(),
105 }
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109 return tea.Batch(
110 m.modelList.Init(),
111 m.apiKeyInput.Init(),
112 m.claudeAuthMethodChooser.Init(),
113 m.claudeOAuth2.Init(),
114 )
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118 switch msg := msg.(type) {
119 case tea.WindowSizeMsg:
120 m.wWidth = msg.Width
121 m.wHeight = msg.Height
122 m.apiKeyInput.SetWidth(m.width - 2)
123 m.help.SetWidth(m.width - 2)
124 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126 case APIKeyStateChangeMsg:
127 u, cmd := m.apiKeyInput.Update(msg)
128 m.apiKeyInput = u.(*APIKeyInput)
129 return m, cmd
130 case claude.ValidationCompletedMsg:
131 var cmds []tea.Cmd
132 u, cmd := m.claudeOAuth2.Update(msg)
133 m.claudeOAuth2 = u.(*claude.OAuth2)
134 cmds = append(cmds, cmd)
135
136 if msg.State == claude.OAuthValidationStateValid {
137 cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138 m.keyMap.isClaudeOAuthHelpComplete = true
139 }
140
141 return m, tea.Batch(cmds...)
142 case claude.AuthenticationCompleteMsg:
143 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144 case tea.KeyPressMsg:
145 switch {
146 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
147 return m, tea.Sequence(
148 tea.SetClipboard(m.claudeOAuth2.URL),
149 func() tea.Msg {
150 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
151 return nil
152 },
153 util.ReportInfo("URL copied to clipboard"),
154 )
155 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
156 m.claudeAuthMethodChooser.ToggleChoice()
157 return m, nil
158 case key.Matches(msg, m.keyMap.Select):
159 selectedItem := m.modelList.SelectedModel()
160
161 modelType := config.SelectedModelTypeLarge
162 if m.modelList.GetModelType() == SmallModelType {
163 modelType = config.SelectedModelTypeSmall
164 }
165
166 askForApiKey := func() {
167 m.keyMap.isClaudeAuthChoiseHelp = false
168 m.keyMap.isClaudeOAuthHelp = false
169 m.keyMap.isAPIKeyHelp = true
170 m.showClaudeAuthMethodChooser = false
171 m.needsAPIKey = true
172 m.selectedModel = selectedItem
173 m.selectedModelType = modelType
174 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
175 }
176
177 if m.showClaudeAuthMethodChooser {
178 switch m.claudeAuthMethodChooser.State {
179 case claude.AuthMethodAPIKey:
180 askForApiKey()
181 case claude.AuthMethodOAuth2:
182 m.selectedModel = selectedItem
183 m.selectedModelType = modelType
184 m.showClaudeAuthMethodChooser = false
185 m.showClaudeOAuth2 = true
186 m.keyMap.isClaudeAuthChoiseHelp = false
187 m.keyMap.isClaudeOAuthHelp = true
188 }
189 return m, nil
190 }
191 if m.showClaudeOAuth2 {
192 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
193 m.claudeOAuth2 = m2.(*claude.OAuth2)
194 return m, cmd2
195 }
196 if m.isAPIKeyValid {
197 return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
198 }
199 if m.needsAPIKey {
200 // Handle API key submission
201 m.apiKeyValue = m.apiKeyInput.Value()
202 provider, err := m.getProvider(m.selectedModel.Provider.ID)
203 if err != nil || provider == nil {
204 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
205 }
206 providerConfig := config.ProviderConfig{
207 ID: string(m.selectedModel.Provider.ID),
208 Name: m.selectedModel.Provider.Name,
209 APIKey: m.apiKeyValue,
210 Type: provider.Type,
211 BaseURL: provider.APIEndpoint,
212 }
213 return m, tea.Sequence(
214 util.CmdHandler(APIKeyStateChangeMsg{
215 State: APIKeyInputStateVerifying,
216 }),
217 func() tea.Msg {
218 start := time.Now()
219 err := providerConfig.TestConnection(config.Get().Resolver())
220 // intentionally wait for at least 750ms to make sure the user sees the spinner
221 elapsed := time.Since(start)
222 if elapsed < 750*time.Millisecond {
223 time.Sleep(750*time.Millisecond - elapsed)
224 }
225 if err == nil {
226 m.isAPIKeyValid = true
227 return APIKeyStateChangeMsg{
228 State: APIKeyInputStateVerified,
229 }
230 }
231 return APIKeyStateChangeMsg{
232 State: APIKeyInputStateError,
233 }
234 },
235 )
236 }
237
238 // Check if provider is configured
239 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
240 return m, tea.Sequence(
241 util.CmdHandler(dialogs.CloseDialogMsg{}),
242 util.CmdHandler(ModelSelectedMsg{
243 Model: config.SelectedModel{
244 Model: selectedItem.Model.ID,
245 Provider: string(selectedItem.Provider.ID),
246 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
247 MaxTokens: selectedItem.Model.DefaultMaxTokens,
248 },
249 ModelType: modelType,
250 }),
251 )
252 } else {
253 if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
254 m.showClaudeAuthMethodChooser = true
255 m.keyMap.isClaudeAuthChoiseHelp = true
256 return m, nil
257 }
258 askForApiKey()
259 return m, nil
260 }
261 case key.Matches(msg, m.keyMap.Tab):
262 switch {
263 case m.showClaudeAuthMethodChooser:
264 m.claudeAuthMethodChooser.ToggleChoice()
265 return m, nil
266 case m.needsAPIKey:
267 u, cmd := m.apiKeyInput.Update(msg)
268 m.apiKeyInput = u.(*APIKeyInput)
269 return m, cmd
270 case m.modelList.GetModelType() == LargeModelType:
271 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
272 return m, m.modelList.SetModelType(SmallModelType)
273 default:
274 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
275 return m, m.modelList.SetModelType(LargeModelType)
276 }
277 case key.Matches(msg, m.keyMap.Close):
278 if m.showClaudeAuthMethodChooser {
279 m.claudeAuthMethodChooser.SetDefaults()
280 m.showClaudeAuthMethodChooser = false
281 m.keyMap.isClaudeAuthChoiseHelp = false
282 m.keyMap.isClaudeOAuthHelp = false
283 return m, nil
284 }
285 if m.needsAPIKey {
286 if m.isAPIKeyValid {
287 return m, nil
288 }
289 // Go back to model selection
290 m.needsAPIKey = false
291 m.selectedModel = nil
292 m.isAPIKeyValid = false
293 m.apiKeyValue = ""
294 m.apiKeyInput.Reset()
295 return m, nil
296 }
297 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
298 default:
299 if m.showClaudeAuthMethodChooser {
300 u, cmd := m.claudeAuthMethodChooser.Update(msg)
301 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
302 return m, cmd
303 } else if m.showClaudeOAuth2 {
304 u, cmd := m.claudeOAuth2.Update(msg)
305 m.claudeOAuth2 = u.(*claude.OAuth2)
306 return m, cmd
307 } else if m.needsAPIKey {
308 u, cmd := m.apiKeyInput.Update(msg)
309 m.apiKeyInput = u.(*APIKeyInput)
310 return m, cmd
311 } else {
312 u, cmd := m.modelList.Update(msg)
313 m.modelList = u
314 return m, cmd
315 }
316 }
317 case tea.PasteMsg:
318 if m.showClaudeOAuth2 {
319 u, cmd := m.claudeOAuth2.Update(msg)
320 m.claudeOAuth2 = u.(*claude.OAuth2)
321 return m, cmd
322 } else if m.needsAPIKey {
323 u, cmd := m.apiKeyInput.Update(msg)
324 m.apiKeyInput = u.(*APIKeyInput)
325 return m, cmd
326 } else {
327 var cmd tea.Cmd
328 m.modelList, cmd = m.modelList.Update(msg)
329 return m, cmd
330 }
331 case spinner.TickMsg:
332 if m.showClaudeOAuth2 {
333 u, cmd := m.claudeOAuth2.Update(msg)
334 m.claudeOAuth2 = u.(*claude.OAuth2)
335 return m, cmd
336 } else {
337 u, cmd := m.apiKeyInput.Update(msg)
338 m.apiKeyInput = u.(*APIKeyInput)
339 return m, cmd
340 }
341 }
342 return m, nil
343}
344
345func (m *modelDialogCmp) View() string {
346 t := styles.CurrentTheme()
347
348 switch {
349 case m.showClaudeAuthMethodChooser:
350 chooserView := m.claudeAuthMethodChooser.View()
351 content := lipgloss.JoinVertical(
352 lipgloss.Left,
353 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
354 chooserView,
355 "",
356 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
357 )
358 return m.style().Render(content)
359 case m.showClaudeOAuth2:
360 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
361 oauth2View := m.claudeOAuth2.View()
362 content := lipgloss.JoinVertical(
363 lipgloss.Left,
364 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
365 oauth2View,
366 "",
367 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
368 )
369 return m.style().Render(content)
370 case m.needsAPIKey:
371 // Show API key input
372 m.keyMap.isAPIKeyHelp = true
373 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
374 apiKeyView := m.apiKeyInput.View()
375 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
376 content := lipgloss.JoinVertical(
377 lipgloss.Left,
378 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
379 apiKeyView,
380 "",
381 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
382 )
383 return m.style().Render(content)
384 }
385
386 // Show model selection
387 listView := m.modelList.View()
388 radio := m.modelTypeRadio()
389 content := lipgloss.JoinVertical(
390 lipgloss.Left,
391 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
392 listView,
393 "",
394 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
395 )
396 return m.style().Render(content)
397}
398
399func (m *modelDialogCmp) Cursor() *tea.Cursor {
400 if m.showClaudeAuthMethodChooser {
401 return nil
402 }
403 if m.showClaudeOAuth2 {
404 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
405 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
406 return m.moveCursor(cursor)
407 }
408 return nil
409 }
410 if m.needsAPIKey {
411 cursor := m.apiKeyInput.Cursor()
412 if cursor != nil {
413 cursor = m.moveCursor(cursor)
414 return cursor
415 }
416 } else {
417 cursor := m.modelList.Cursor()
418 if cursor != nil {
419 cursor = m.moveCursor(cursor)
420 return cursor
421 }
422 }
423 return nil
424}
425
426func (m *modelDialogCmp) style() lipgloss.Style {
427 t := styles.CurrentTheme()
428 return t.S().Base.
429 Width(m.width).
430 Border(lipgloss.RoundedBorder()).
431 BorderForeground(t.BorderFocus)
432}
433
// listWidth is the width given to the model list: the dialog width minus two
// columns (presumably the left and right border — confirm).
func (m *modelDialogCmp) listWidth() int {
	return m.width - 2
}
437
// listHeight is the height given to the model list: half of the last known
// window height.
func (m *modelDialogCmp) listHeight() int {
	return m.wHeight / 2
}
441
442func (m *modelDialogCmp) Position() (int, int) {
443 row := m.wHeight/4 - 2 // just a bit above the center
444 col := m.wWidth / 2
445 col -= m.width / 2
446 return row, col
447}
448
449func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
450 row, col := m.Position()
451 if m.needsAPIKey {
452 offset := row + 3 // Border + title + API key input offset
453 cursor.Y += offset
454 cursor.X = cursor.X + col + 2
455 } else {
456 offset := row + 3 // Border + title
457 cursor.Y += offset
458 cursor.X = cursor.X + col + 2
459 }
460 return cursor
461}
462
// ID returns the dialog's identifier, ModelsDialogID.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
466
467func (m *modelDialogCmp) modelTypeRadio() string {
468 t := styles.CurrentTheme()
469 choices := []string{"Large Task", "Small Task"}
470 iconSelected := "◉"
471 iconUnselected := "○"
472 if m.modelList.GetModelType() == LargeModelType {
473 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
474 }
475 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
476}
477
478func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
479 cfg := config.Get()
480 if _, ok := cfg.Providers.Get(providerID); ok {
481 return true
482 }
483 return false
484}
485
486func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
487 cfg := config.Get()
488 providers, err := config.Providers(cfg)
489 if err != nil {
490 return nil, err
491 }
492 for _, p := range providers {
493 if p.ID == providerID {
494 return &p, nil
495 }
496 }
497 return nil, nil
498}
499
500func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
501 if m.selectedModel == nil {
502 return util.ReportError(fmt.Errorf("no model selected"))
503 }
504
505 cfg := config.Get()
506 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
507 if err != nil {
508 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
509 }
510
511 // Reset API key state and continue with model selection
512 selectedModel := *m.selectedModel
513 var cmds []tea.Cmd
514 if close {
515 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
516 }
517 cmds = append(
518 cmds,
519 util.CmdHandler(ModelSelectedMsg{
520 Model: config.SelectedModel{
521 Model: selectedModel.Model.ID,
522 Provider: string(selectedModel.Provider.ID),
523 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
524 MaxTokens: selectedModel.Model.DefaultMaxTokens,
525 },
526 ModelType: m.selectedModelType,
527 }),
528 )
529 return tea.Sequence(cmds...)
530}