1package models
2
3import (
4 "fmt"
5 "time"
6
7 "charm.land/bubbles/v2/help"
8 "charm.land/bubbles/v2/key"
9 "charm.land/bubbles/v2/spinner"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/atotto/clipboard"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
18 "github.com/charmbracelet/crush/internal/tui/exp/list"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
23const (
24 ModelsDialogID dialogs.DialogID = "models"
25
26 defaultWidth = 60
27)
28
// Model type selectors understood by the model list. The dialog toggles
// between them and maps them onto config.SelectedModelType values when a model
// is chosen.
const (
	LargeModelType int = iota
	SmallModelType
)

// Placeholder texts shown in the model list input for each model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
36
37// ModelSelectedMsg is sent when a model is selected
38type ModelSelectedMsg struct {
39 Model config.SelectedModel
40 ModelType config.SelectedModelType
41}
42
// CloseModelDialogMsg signals that the model dialog should be closed.
//
// NOTE(review): nothing in this file emits or handles this message; the dialog
// itself closes via dialogs.CloseDialogMsg. Confirm whether it is still used
// elsewhere before relying on it.
type CloseModelDialogMsg struct{}
45
46// ModelDialog interface for the model selection dialog
47type ModelDialog interface {
48 dialogs.DialogModel
49}
50
51type ModelOption struct {
52 Provider catwalk.Provider
53 Model catwalk.Model
54}
55
56type modelDialogCmp struct {
57 width int
58 wWidth int
59 wHeight int
60
61 modelList *ModelListComponent
62 keyMap KeyMap
63 help help.Model
64
65 // API key state
66 needsAPIKey bool
67 apiKeyInput *APIKeyInput
68 selectedModel *ModelOption
69 selectedModelType config.SelectedModelType
70 isAPIKeyValid bool
71 apiKeyValue string
72
73 // Claude state
74 claudeAuthMethodChooser *claude.AuthMethodChooser
75 claudeOAuth2 *claude.OAuth2
76 showClaudeAuthMethodChooser bool
77 showClaudeOAuth2 bool
78}
79
80func NewModelDialogCmp() ModelDialog {
81 keyMap := DefaultKeyMap()
82
83 listKeyMap := list.DefaultKeyMap()
84 listKeyMap.Down.SetEnabled(false)
85 listKeyMap.Up.SetEnabled(false)
86 listKeyMap.DownOneItem = keyMap.Next
87 listKeyMap.UpOneItem = keyMap.Previous
88
89 t := styles.CurrentTheme()
90 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
91 apiKeyInput := NewAPIKeyInput()
92 apiKeyInput.SetShowTitle(false)
93 help := help.New()
94 help.Styles = t.S().Help
95
96 return &modelDialogCmp{
97 modelList: modelList,
98 apiKeyInput: apiKeyInput,
99 width: defaultWidth,
100 keyMap: DefaultKeyMap(),
101 help: help,
102
103 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104 claudeOAuth2: claude.NewOAuth2(),
105 }
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109 return tea.Batch(
110 m.modelList.Init(),
111 m.apiKeyInput.Init(),
112 m.claudeAuthMethodChooser.Init(),
113 m.claudeOAuth2.Init(),
114 )
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118 switch msg := msg.(type) {
119 case tea.WindowSizeMsg:
120 m.wWidth = msg.Width
121 m.wHeight = msg.Height
122 m.apiKeyInput.SetWidth(m.width - 2)
123 m.help.SetWidth(m.width - 2)
124 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126 case APIKeyStateChangeMsg:
127 u, cmd := m.apiKeyInput.Update(msg)
128 m.apiKeyInput = u.(*APIKeyInput)
129 return m, cmd
130 case claude.ValidationCompletedMsg:
131 var cmds []tea.Cmd
132 u, cmd := m.claudeOAuth2.Update(msg)
133 m.claudeOAuth2 = u.(*claude.OAuth2)
134 cmds = append(cmds, cmd)
135
136 if msg.State == claude.OAuthValidationStateValid {
137 cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138 m.keyMap.isClaudeOAuthHelpComplete = true
139 }
140
141 return m, tea.Batch(cmds...)
142 case claude.AuthenticationCompleteMsg:
143 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144 case tea.KeyPressMsg:
145 switch {
146 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))):
147 if m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL {
148 return m, tea.Sequence(
149 tea.SetClipboard(m.claudeOAuth2.URL),
150 func() tea.Msg {
151 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
152 return nil
153 },
154 util.ReportInfo("URL copied to clipboard"),
155 )
156 }
157 case key.Matches(msg, m.keyMap.Choose):
158 if m.showClaudeAuthMethodChooser {
159 m.claudeAuthMethodChooser.ToggleChoice()
160 return m, nil
161 }
162 case key.Matches(msg, m.keyMap.Select):
163 selectedItem := m.modelList.SelectedModel()
164
165 modelType := config.SelectedModelTypeLarge
166 if m.modelList.GetModelType() == SmallModelType {
167 modelType = config.SelectedModelTypeSmall
168 }
169
170 askForApiKey := func() {
171 m.keyMap.isClaudeAuthChoiseHelp = false
172 m.keyMap.isClaudeOAuthHelp = false
173 m.keyMap.isAPIKeyHelp = true
174 m.showClaudeAuthMethodChooser = false
175 m.needsAPIKey = true
176 m.selectedModel = selectedItem
177 m.selectedModelType = modelType
178 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
179 }
180
181 if m.showClaudeAuthMethodChooser {
182 switch m.claudeAuthMethodChooser.State {
183 case claude.AuthMethodAPIKey:
184 askForApiKey()
185 case claude.AuthMethodOAuth2:
186 m.selectedModel = selectedItem
187 m.selectedModelType = modelType
188 m.showClaudeAuthMethodChooser = false
189 m.showClaudeOAuth2 = true
190 m.keyMap.isClaudeAuthChoiseHelp = false
191 m.keyMap.isClaudeOAuthHelp = true
192 }
193 return m, nil
194 }
195 if m.showClaudeOAuth2 {
196 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
197 m.claudeOAuth2 = m2.(*claude.OAuth2)
198 return m, cmd2
199 }
200 if m.isAPIKeyValid {
201 return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
202 }
203 if m.needsAPIKey {
204 // Handle API key submission
205 m.apiKeyValue = m.apiKeyInput.Value()
206 provider, err := m.getProvider(m.selectedModel.Provider.ID)
207 if err != nil || provider == nil {
208 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
209 }
210 providerConfig := config.ProviderConfig{
211 ID: string(m.selectedModel.Provider.ID),
212 Name: m.selectedModel.Provider.Name,
213 APIKey: m.apiKeyValue,
214 Type: provider.Type,
215 BaseURL: provider.APIEndpoint,
216 }
217 return m, tea.Sequence(
218 util.CmdHandler(APIKeyStateChangeMsg{
219 State: APIKeyInputStateVerifying,
220 }),
221 func() tea.Msg {
222 start := time.Now()
223 err := providerConfig.TestConnection(config.Get().Resolver())
224 // intentionally wait for at least 750ms to make sure the user sees the spinner
225 elapsed := time.Since(start)
226 if elapsed < 750*time.Millisecond {
227 time.Sleep(750*time.Millisecond - elapsed)
228 }
229 if err == nil {
230 m.isAPIKeyValid = true
231 return APIKeyStateChangeMsg{
232 State: APIKeyInputStateVerified,
233 }
234 }
235 return APIKeyStateChangeMsg{
236 State: APIKeyInputStateError,
237 }
238 },
239 )
240 }
241
242 // Check if provider is configured
243 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
244 return m, tea.Sequence(
245 util.CmdHandler(dialogs.CloseDialogMsg{}),
246 util.CmdHandler(ModelSelectedMsg{
247 Model: config.SelectedModel{
248 Model: selectedItem.Model.ID,
249 Provider: string(selectedItem.Provider.ID),
250 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
251 MaxTokens: selectedItem.Model.DefaultMaxTokens,
252 },
253 ModelType: modelType,
254 }),
255 )
256 } else {
257 if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
258 m.showClaudeAuthMethodChooser = true
259 m.keyMap.isClaudeAuthChoiseHelp = true
260 return m, nil
261 }
262 askForApiKey()
263 return m, nil
264 }
265 case key.Matches(msg, m.keyMap.Tab):
266 switch {
267 case m.showClaudeAuthMethodChooser:
268 m.claudeAuthMethodChooser.ToggleChoice()
269 return m, nil
270 case m.needsAPIKey:
271 u, cmd := m.apiKeyInput.Update(msg)
272 m.apiKeyInput = u.(*APIKeyInput)
273 return m, cmd
274 case m.modelList.GetModelType() == LargeModelType:
275 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
276 return m, m.modelList.SetModelType(SmallModelType)
277 default:
278 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
279 return m, m.modelList.SetModelType(LargeModelType)
280 }
281 case key.Matches(msg, m.keyMap.Close):
282 if m.showClaudeAuthMethodChooser {
283 m.claudeAuthMethodChooser.SetDefaults()
284 m.showClaudeAuthMethodChooser = false
285 m.keyMap.isClaudeAuthChoiseHelp = false
286 m.keyMap.isClaudeOAuthHelp = false
287 return m, nil
288 }
289 if m.needsAPIKey {
290 if m.isAPIKeyValid {
291 return m, nil
292 }
293 // Go back to model selection
294 m.needsAPIKey = false
295 m.selectedModel = nil
296 m.isAPIKeyValid = false
297 m.apiKeyValue = ""
298 m.apiKeyInput.Reset()
299 return m, nil
300 }
301 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
302 default:
303 if m.showClaudeAuthMethodChooser {
304 u, cmd := m.claudeAuthMethodChooser.Update(msg)
305 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
306 return m, cmd
307 } else if m.showClaudeOAuth2 {
308 u, cmd := m.claudeOAuth2.Update(msg)
309 m.claudeOAuth2 = u.(*claude.OAuth2)
310 return m, cmd
311 } else if m.needsAPIKey {
312 u, cmd := m.apiKeyInput.Update(msg)
313 m.apiKeyInput = u.(*APIKeyInput)
314 return m, cmd
315 } else {
316 u, cmd := m.modelList.Update(msg)
317 m.modelList = u
318 return m, cmd
319 }
320 }
321 case tea.PasteMsg:
322 if m.showClaudeOAuth2 {
323 u, cmd := m.claudeOAuth2.Update(msg)
324 m.claudeOAuth2 = u.(*claude.OAuth2)
325 return m, cmd
326 } else if m.needsAPIKey {
327 u, cmd := m.apiKeyInput.Update(msg)
328 m.apiKeyInput = u.(*APIKeyInput)
329 return m, cmd
330 } else {
331 var cmd tea.Cmd
332 m.modelList, cmd = m.modelList.Update(msg)
333 return m, cmd
334 }
335 case spinner.TickMsg:
336 if m.showClaudeOAuth2 {
337 u, cmd := m.claudeOAuth2.Update(msg)
338 m.claudeOAuth2 = u.(*claude.OAuth2)
339 return m, cmd
340 } else {
341 u, cmd := m.apiKeyInput.Update(msg)
342 m.apiKeyInput = u.(*APIKeyInput)
343 return m, cmd
344 }
345 }
346 return m, nil
347}
348
349func (m *modelDialogCmp) View() string {
350 t := styles.CurrentTheme()
351
352 switch {
353 case m.showClaudeAuthMethodChooser:
354 chooserView := m.claudeAuthMethodChooser.View()
355 content := lipgloss.JoinVertical(
356 lipgloss.Left,
357 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
358 chooserView,
359 "",
360 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
361 )
362 return m.style().Render(content)
363 case m.showClaudeOAuth2:
364 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
365 oauth2View := m.claudeOAuth2.View()
366 content := lipgloss.JoinVertical(
367 lipgloss.Left,
368 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
369 oauth2View,
370 "",
371 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
372 )
373 return m.style().Render(content)
374 case m.needsAPIKey:
375 // Show API key input
376 m.keyMap.isAPIKeyHelp = true
377 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
378 apiKeyView := m.apiKeyInput.View()
379 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
380 content := lipgloss.JoinVertical(
381 lipgloss.Left,
382 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
383 apiKeyView,
384 "",
385 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
386 )
387 return m.style().Render(content)
388 }
389
390 // Show model selection
391 listView := m.modelList.View()
392 radio := m.modelTypeRadio()
393 content := lipgloss.JoinVertical(
394 lipgloss.Left,
395 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
396 listView,
397 "",
398 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
399 )
400 return m.style().Render(content)
401}
402
403func (m *modelDialogCmp) Cursor() *tea.Cursor {
404 if m.showClaudeAuthMethodChooser {
405 return nil
406 }
407 if m.showClaudeOAuth2 {
408 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
409 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
410 return m.moveCursor(cursor)
411 }
412 return nil
413 }
414 if m.needsAPIKey {
415 cursor := m.apiKeyInput.Cursor()
416 if cursor != nil {
417 cursor = m.moveCursor(cursor)
418 return cursor
419 }
420 } else {
421 cursor := m.modelList.Cursor()
422 if cursor != nil {
423 cursor = m.moveCursor(cursor)
424 return cursor
425 }
426 }
427 return nil
428}
429
430func (m *modelDialogCmp) style() lipgloss.Style {
431 t := styles.CurrentTheme()
432 return t.S().Base.
433 Width(m.width).
434 Border(lipgloss.RoundedBorder()).
435 BorderForeground(t.BorderFocus)
436}
437
438func (m *modelDialogCmp) listWidth() int {
439 return m.width - 2
440}
441
442func (m *modelDialogCmp) listHeight() int {
443 return m.wHeight / 2
444}
445
446func (m *modelDialogCmp) Position() (int, int) {
447 row := m.wHeight/4 - 2 // just a bit above the center
448 col := m.wWidth / 2
449 col -= m.width / 2
450 return row, col
451}
452
453func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
454 row, col := m.Position()
455 if m.needsAPIKey {
456 offset := row + 3 // Border + title + API key input offset
457 cursor.Y += offset
458 cursor.X = cursor.X + col + 2
459 } else {
460 offset := row + 3 // Border + title
461 cursor.Y += offset
462 cursor.X = cursor.X + col + 2
463 }
464 return cursor
465}
466
467func (m *modelDialogCmp) ID() dialogs.DialogID {
468 return ModelsDialogID
469}
470
471func (m *modelDialogCmp) modelTypeRadio() string {
472 t := styles.CurrentTheme()
473 choices := []string{"Large Task", "Small Task"}
474 iconSelected := "◉"
475 iconUnselected := "○"
476 if m.modelList.GetModelType() == LargeModelType {
477 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
478 }
479 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
480}
481
482func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
483 cfg := config.Get()
484 if _, ok := cfg.Providers.Get(providerID); ok {
485 return true
486 }
487 return false
488}
489
490func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
491 cfg := config.Get()
492 providers, err := config.Providers(cfg)
493 if err != nil {
494 return nil, err
495 }
496 for _, p := range providers {
497 if p.ID == providerID {
498 return &p, nil
499 }
500 }
501 return nil, nil
502}
503
504func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
505 if m.selectedModel == nil {
506 return util.ReportError(fmt.Errorf("no model selected"))
507 }
508
509 cfg := config.Get()
510 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
511 if err != nil {
512 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
513 }
514
515 // Reset API key state and continue with model selection
516 selectedModel := *m.selectedModel
517 var cmds []tea.Cmd
518 if close {
519 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
520 }
521 cmds = append(
522 cmds,
523 util.CmdHandler(ModelSelectedMsg{
524 Model: config.SelectedModel{
525 Model: selectedModel.Model.ID,
526 Provider: string(selectedModel.Provider.ID),
527 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
528 MaxTokens: selectedModel.Model.DefaultMaxTokens,
529 },
530 ModelType: m.selectedModelType,
531 }),
532 )
533 return tea.Sequence(cmds...)
534}