1package models
2
3import (
4 "fmt"
5 "time"
6
7 "charm.land/bubbles/v2/help"
8 "charm.land/bubbles/v2/key"
9 "charm.land/bubbles/v2/spinner"
10 tea "charm.land/bubbletea/v2"
11 "charm.land/lipgloss/v2"
12 "github.com/atotto/clipboard"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
18 "github.com/charmbracelet/crush/internal/tui/exp/list"
19 "github.com/charmbracelet/crush/internal/tui/styles"
20 "github.com/charmbracelet/crush/internal/tui/util"
21)
22
23const (
24 ModelsDialogID dialogs.DialogID = "models"
25
26 defaultWidth = 60
27)
28
// Model type selectors for the model list. LargeModelType selects the model
// used for large, complex tasks; SmallModelType the one for small, simple
// tasks.
const (
	LargeModelType int = iota
	SmallModelType
)

// Placeholder text for the model list input, one per model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
36
// ModelSelectedMsg is sent when a model is selected.
//
// Model identifies the chosen provider/model pair; as built in this file it
// also carries the model's default reasoning effort and max tokens.
// ModelType records whether the model was chosen for the large or the small
// model slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
42
// CloseModelDialogMsg signals that the model dialog should be closed.
//
// NOTE(review): this message is not emitted anywhere in this file (closing is
// done via dialogs.CloseDialogMsg here); the previous comment, "sent when a
// model is selected", was copied from ModelSelectedMsg.
type CloseModelDialogMsg struct{}
45
// ModelDialog is the interface for the model selection dialog. It adds no
// methods of its own beyond those of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
50
// ModelOption pairs a model with its provider. The model list's
// SelectedModel returns one of these for the highlighted entry.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
55
// modelDialogCmp implements ModelDialog. It shows the model list and, when
// the chosen provider is not configured yet, walks the user through
// authenticating it: API key entry, or for Anthropic a choice between an API
// key and the Claude OAuth2 flow.
type modelDialogCmp struct {
	// width is the dialog width; NewModelDialogCmp sets it to defaultWidth.
	width int
	// wWidth and wHeight are the window size from the last tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	// needsAPIKey is true while the API key prompt is shown.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel is the model waiting for its provider to be
	// authenticated before a ModelSelectedMsg is emitted for it.
	selectedModel *ModelOption
	// selectedModelType is the slot (large/small) selectedModel was chosen for.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is true once the entered key passed the connection test.
	isAPIKeyValid bool
	// apiKeyValue is the key most recently submitted for verification.
	apiKeyValue string

	// Claude state
	claudeAuthMethodChooser *claude.AuthMethodChooser
	claudeOAuth2            *claude.OAuth2
	// showClaudeAuthMethodChooser is true while the API-key/OAuth2 chooser is shown.
	showClaudeAuthMethodChooser bool
	// showClaudeOAuth2 is true while the Claude OAuth2 flow is shown.
	showClaudeOAuth2 bool
}
79
80func NewModelDialogCmp() ModelDialog {
81 keyMap := DefaultKeyMap()
82
83 listKeyMap := list.DefaultKeyMap()
84 listKeyMap.Down.SetEnabled(false)
85 listKeyMap.Up.SetEnabled(false)
86 listKeyMap.DownOneItem = keyMap.Next
87 listKeyMap.UpOneItem = keyMap.Previous
88
89 t := styles.CurrentTheme()
90 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
91 apiKeyInput := NewAPIKeyInput()
92 apiKeyInput.SetShowTitle(false)
93 help := help.New()
94 help.Styles = t.S().Help
95
96 return &modelDialogCmp{
97 modelList: modelList,
98 apiKeyInput: apiKeyInput,
99 width: defaultWidth,
100 keyMap: DefaultKeyMap(),
101 help: help,
102
103 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
104 claudeOAuth2: claude.NewOAuth2(),
105 }
106}
107
108func (m *modelDialogCmp) Init() tea.Cmd {
109 return tea.Batch(
110 m.modelList.Init(),
111 m.apiKeyInput.Init(),
112 m.claudeAuthMethodChooser.Init(),
113 m.claudeOAuth2.Init(),
114 )
115}
116
117func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
118 switch msg := msg.(type) {
119 case tea.WindowSizeMsg:
120 m.wWidth = msg.Width
121 m.wHeight = msg.Height
122 m.apiKeyInput.SetWidth(m.width - 2)
123 m.help.SetWidth(m.width - 2)
124 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
125 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
126 case APIKeyStateChangeMsg:
127 u, cmd := m.apiKeyInput.Update(msg)
128 m.apiKeyInput = u.(*APIKeyInput)
129 return m, cmd
130 case claude.ValidationCompletedMsg:
131 var cmds []tea.Cmd
132 u, cmd := m.claudeOAuth2.Update(msg)
133 m.claudeOAuth2 = u.(*claude.OAuth2)
134 cmds = append(cmds, cmd)
135
136 if msg.State == claude.OAuthValidationStateValid {
137 cmds = append(cmds, m.saveAPIKeyAndContinue(msg.Token, false))
138 m.keyMap.isClaudeOAuthHelpComplete = true
139 }
140
141 return m, tea.Batch(cmds...)
142 case claude.AuthenticationCompleteMsg:
143 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
144 case tea.KeyPressMsg:
145 switch {
146 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))):
147 if m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL {
148 return m, tea.Sequence(
149 tea.SetClipboard(m.claudeOAuth2.URL),
150 func() tea.Msg {
151 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
152 return nil
153 },
154 util.ReportInfo("URL copied to clipboard"),
155 )
156 }
157 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
158 m.claudeAuthMethodChooser.ToggleChoice()
159 return m, nil
160 case key.Matches(msg, m.keyMap.Select):
161 selectedItem := m.modelList.SelectedModel()
162
163 modelType := config.SelectedModelTypeLarge
164 if m.modelList.GetModelType() == SmallModelType {
165 modelType = config.SelectedModelTypeSmall
166 }
167
168 askForApiKey := func() {
169 m.keyMap.isClaudeAuthChoiseHelp = false
170 m.keyMap.isClaudeOAuthHelp = false
171 m.keyMap.isAPIKeyHelp = true
172 m.showClaudeAuthMethodChooser = false
173 m.needsAPIKey = true
174 m.selectedModel = selectedItem
175 m.selectedModelType = modelType
176 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
177 }
178
179 if m.showClaudeAuthMethodChooser {
180 switch m.claudeAuthMethodChooser.State {
181 case claude.AuthMethodAPIKey:
182 askForApiKey()
183 case claude.AuthMethodOAuth2:
184 m.selectedModel = selectedItem
185 m.selectedModelType = modelType
186 m.showClaudeAuthMethodChooser = false
187 m.showClaudeOAuth2 = true
188 m.keyMap.isClaudeAuthChoiseHelp = false
189 m.keyMap.isClaudeOAuthHelp = true
190 }
191 return m, nil
192 }
193 if m.showClaudeOAuth2 {
194 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
195 m.claudeOAuth2 = m2.(*claude.OAuth2)
196 return m, cmd2
197 }
198 if m.isAPIKeyValid {
199 return m, m.saveAPIKeyAndContinue(m.apiKeyValue, true)
200 }
201 if m.needsAPIKey {
202 // Handle API key submission
203 m.apiKeyValue = m.apiKeyInput.Value()
204 provider, err := m.getProvider(m.selectedModel.Provider.ID)
205 if err != nil || provider == nil {
206 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
207 }
208 providerConfig := config.ProviderConfig{
209 ID: string(m.selectedModel.Provider.ID),
210 Name: m.selectedModel.Provider.Name,
211 APIKey: m.apiKeyValue,
212 Type: provider.Type,
213 BaseURL: provider.APIEndpoint,
214 }
215 return m, tea.Sequence(
216 util.CmdHandler(APIKeyStateChangeMsg{
217 State: APIKeyInputStateVerifying,
218 }),
219 func() tea.Msg {
220 start := time.Now()
221 err := providerConfig.TestConnection(config.Get().Resolver())
222 // intentionally wait for at least 750ms to make sure the user sees the spinner
223 elapsed := time.Since(start)
224 if elapsed < 750*time.Millisecond {
225 time.Sleep(750*time.Millisecond - elapsed)
226 }
227 if err == nil {
228 m.isAPIKeyValid = true
229 return APIKeyStateChangeMsg{
230 State: APIKeyInputStateVerified,
231 }
232 }
233 return APIKeyStateChangeMsg{
234 State: APIKeyInputStateError,
235 }
236 },
237 )
238 }
239
240 // Check if provider is configured
241 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
242 return m, tea.Sequence(
243 util.CmdHandler(dialogs.CloseDialogMsg{}),
244 util.CmdHandler(ModelSelectedMsg{
245 Model: config.SelectedModel{
246 Model: selectedItem.Model.ID,
247 Provider: string(selectedItem.Provider.ID),
248 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
249 MaxTokens: selectedItem.Model.DefaultMaxTokens,
250 },
251 ModelType: modelType,
252 }),
253 )
254 } else {
255 if selectedItem.Provider.ID == catwalk.InferenceProviderAnthropic {
256 m.showClaudeAuthMethodChooser = true
257 m.keyMap.isClaudeAuthChoiseHelp = true
258 return m, nil
259 }
260 askForApiKey()
261 return m, nil
262 }
263 case key.Matches(msg, m.keyMap.Tab):
264 switch {
265 case m.showClaudeAuthMethodChooser:
266 m.claudeAuthMethodChooser.ToggleChoice()
267 return m, nil
268 case m.needsAPIKey:
269 u, cmd := m.apiKeyInput.Update(msg)
270 m.apiKeyInput = u.(*APIKeyInput)
271 return m, cmd
272 case m.modelList.GetModelType() == LargeModelType:
273 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
274 return m, m.modelList.SetModelType(SmallModelType)
275 default:
276 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
277 return m, m.modelList.SetModelType(LargeModelType)
278 }
279 case key.Matches(msg, m.keyMap.Close):
280 if m.showClaudeAuthMethodChooser {
281 m.claudeAuthMethodChooser.SetDefaults()
282 m.showClaudeAuthMethodChooser = false
283 m.keyMap.isClaudeAuthChoiseHelp = false
284 m.keyMap.isClaudeOAuthHelp = false
285 return m, nil
286 }
287 if m.needsAPIKey {
288 if m.isAPIKeyValid {
289 return m, nil
290 }
291 // Go back to model selection
292 m.needsAPIKey = false
293 m.selectedModel = nil
294 m.isAPIKeyValid = false
295 m.apiKeyValue = ""
296 m.apiKeyInput.Reset()
297 return m, nil
298 }
299 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
300 default:
301 if m.showClaudeAuthMethodChooser {
302 u, cmd := m.claudeAuthMethodChooser.Update(msg)
303 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
304 return m, cmd
305 } else if m.showClaudeOAuth2 {
306 u, cmd := m.claudeOAuth2.Update(msg)
307 m.claudeOAuth2 = u.(*claude.OAuth2)
308 return m, cmd
309 } else if m.needsAPIKey {
310 u, cmd := m.apiKeyInput.Update(msg)
311 m.apiKeyInput = u.(*APIKeyInput)
312 return m, cmd
313 } else {
314 u, cmd := m.modelList.Update(msg)
315 m.modelList = u
316 return m, cmd
317 }
318 }
319 case tea.PasteMsg:
320 if m.showClaudeOAuth2 {
321 u, cmd := m.claudeOAuth2.Update(msg)
322 m.claudeOAuth2 = u.(*claude.OAuth2)
323 return m, cmd
324 } else if m.needsAPIKey {
325 u, cmd := m.apiKeyInput.Update(msg)
326 m.apiKeyInput = u.(*APIKeyInput)
327 return m, cmd
328 } else {
329 var cmd tea.Cmd
330 m.modelList, cmd = m.modelList.Update(msg)
331 return m, cmd
332 }
333 case spinner.TickMsg:
334 if m.showClaudeOAuth2 {
335 u, cmd := m.claudeOAuth2.Update(msg)
336 m.claudeOAuth2 = u.(*claude.OAuth2)
337 return m, cmd
338 } else {
339 u, cmd := m.apiKeyInput.Update(msg)
340 m.apiKeyInput = u.(*APIKeyInput)
341 return m, cmd
342 }
343 }
344 return m, nil
345}
346
347func (m *modelDialogCmp) View() string {
348 t := styles.CurrentTheme()
349
350 switch {
351 case m.showClaudeAuthMethodChooser:
352 chooserView := m.claudeAuthMethodChooser.View()
353 content := lipgloss.JoinVertical(
354 lipgloss.Left,
355 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
356 chooserView,
357 "",
358 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
359 )
360 return m.style().Render(content)
361 case m.showClaudeOAuth2:
362 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
363 oauth2View := m.claudeOAuth2.View()
364 content := lipgloss.JoinVertical(
365 lipgloss.Left,
366 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
367 oauth2View,
368 "",
369 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
370 )
371 return m.style().Render(content)
372 case m.needsAPIKey:
373 // Show API key input
374 m.keyMap.isAPIKeyHelp = true
375 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
376 apiKeyView := m.apiKeyInput.View()
377 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
378 content := lipgloss.JoinVertical(
379 lipgloss.Left,
380 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
381 apiKeyView,
382 "",
383 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
384 )
385 return m.style().Render(content)
386 }
387
388 // Show model selection
389 listView := m.modelList.View()
390 radio := m.modelTypeRadio()
391 content := lipgloss.JoinVertical(
392 lipgloss.Left,
393 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
394 listView,
395 "",
396 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
397 )
398 return m.style().Render(content)
399}
400
401func (m *modelDialogCmp) Cursor() *tea.Cursor {
402 if m.showClaudeAuthMethodChooser {
403 return nil
404 }
405 if m.showClaudeOAuth2 {
406 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
407 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
408 return m.moveCursor(cursor)
409 }
410 return nil
411 }
412 if m.needsAPIKey {
413 cursor := m.apiKeyInput.Cursor()
414 if cursor != nil {
415 cursor = m.moveCursor(cursor)
416 return cursor
417 }
418 } else {
419 cursor := m.modelList.Cursor()
420 if cursor != nil {
421 cursor = m.moveCursor(cursor)
422 return cursor
423 }
424 }
425 return nil
426}
427
428func (m *modelDialogCmp) style() lipgloss.Style {
429 t := styles.CurrentTheme()
430 return t.S().Base.
431 Width(m.width).
432 Border(lipgloss.RoundedBorder()).
433 BorderForeground(t.BorderFocus)
434}
435
436func (m *modelDialogCmp) listWidth() int {
437 return m.width - 2
438}
439
440func (m *modelDialogCmp) listHeight() int {
441 return m.wHeight / 2
442}
443
444func (m *modelDialogCmp) Position() (int, int) {
445 row := m.wHeight/4 - 2 // just a bit above the center
446 col := m.wWidth / 2
447 col -= m.width / 2
448 return row, col
449}
450
451func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
452 row, col := m.Position()
453 if m.needsAPIKey {
454 offset := row + 3 // Border + title + API key input offset
455 cursor.Y += offset
456 cursor.X = cursor.X + col + 2
457 } else {
458 offset := row + 3 // Border + title
459 cursor.Y += offset
460 cursor.X = cursor.X + col + 2
461 }
462 return cursor
463}
464
465func (m *modelDialogCmp) ID() dialogs.DialogID {
466 return ModelsDialogID
467}
468
469func (m *modelDialogCmp) modelTypeRadio() string {
470 t := styles.CurrentTheme()
471 choices := []string{"Large Task", "Small Task"}
472 iconSelected := "◉"
473 iconUnselected := "○"
474 if m.modelList.GetModelType() == LargeModelType {
475 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
476 }
477 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
478}
479
480func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
481 cfg := config.Get()
482 if _, ok := cfg.Providers.Get(providerID); ok {
483 return true
484 }
485 return false
486}
487
488func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
489 cfg := config.Get()
490 providers, err := config.Providers(cfg)
491 if err != nil {
492 return nil, err
493 }
494 for _, p := range providers {
495 if p.ID == providerID {
496 return &p, nil
497 }
498 }
499 return nil, nil
500}
501
502func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey any, close bool) tea.Cmd {
503 if m.selectedModel == nil {
504 return util.ReportError(fmt.Errorf("no model selected"))
505 }
506
507 cfg := config.Get()
508 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
509 if err != nil {
510 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
511 }
512
513 // Reset API key state and continue with model selection
514 selectedModel := *m.selectedModel
515 var cmds []tea.Cmd
516 if close {
517 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
518 }
519 cmds = append(
520 cmds,
521 util.CmdHandler(ModelSelectedMsg{
522 Model: config.SelectedModel{
523 Model: selectedModel.Model.ID,
524 Provider: string(selectedModel.Provider.ID),
525 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
526 MaxTokens: selectedModel.Model.DefaultMaxTokens,
527 },
528 ModelType: m.selectedModelType,
529 }),
530 )
531 return tea.Sequence(cmds...)
532}