1// Package models provides the model selection dialog for the TUI.
2package models
3
4import (
5 "fmt"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/lipgloss/v2"
13 "github.com/atotto/clipboard"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/tui/components/core"
18 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
20 "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
21 "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
22 "github.com/charmbracelet/crush/internal/tui/exp/list"
23 "github.com/charmbracelet/crush/internal/tui/styles"
24 "github.com/charmbracelet/crush/internal/tui/util"
25)
26
27const (
28 ModelsDialogID dialogs.DialogID = "models"
29
30 defaultWidth = 60
31)
32
// Model types the model list can be switched between.
const (
	// LargeModelType selects the model used for large, complex tasks.
	LargeModelType int = iota
	// SmallModelType selects the model used for small, simple tasks.
	SmallModelType
)

// Placeholder text for the model list input, one per model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
40
// ModelSelectedMsg is sent when a model is selected. Model carries the chosen
// model and provider together with the model's default reasoning effort and
// max tokens; ModelType says whether it was chosen for the large or the small
// model slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}

// CloseModelDialogMsg presumably requests that the model dialog be closed.
//
// NOTE(review): nothing in this file sends or handles this message (the
// dialog closes itself with dialogs.CloseDialogMsg). Its previous comment,
// "sent when a model is selected", was copied from ModelSelectedMsg. Confirm
// whether it is still used elsewhere.
type CloseModelDialogMsg struct{}

// ModelDialog is the interface for the model selection dialog. It currently
// adds nothing beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}

// ModelOption pairs a model with the provider that offers it. A pointer to a
// ModelOption is what the model list's SelectedModel returns.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
59
// modelDialogCmp implements ModelDialog. The sub-view it shows is selected by
// the needsAPIKey and show* flags below. The possible sub-views are the model
// list, the API key prompt, the Hyper or Copilot device flow, and the Claude
// auth-method chooser or Claude OAuth2 flow.
type modelDialogCmp struct {
	// width is the dialog width (defaultWidth, set by NewModelDialogCmp).
	width int
	// wWidth and wHeight are the window size from the most recent
	// tea.WindowSizeMsg; Position and listHeight derive from them.
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	// keyMap holds the dialog's bindings plus flags (isAPIKeyHelp,
	// isHyperDeviceFlow, ...) toggled in Update and View. Presumably KeyMap
	// uses them to pick the bindings the help footer shows; confirm in KeyMap.
	keyMap KeyMap
	help   help.Model

	// API key state
	// needsAPIKey is true while the API key prompt is shown.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel and selectedModelType record the model being
	// authenticated. saveOauthTokenAndContinue emits them once a credential
	// has been saved. selectedModel is nil when nothing is pending.
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true once the entered key passes the provider
	// connection test. apiKeyValue is the key that was submitted for testing.
	isAPIKeyValid bool
	apiKeyValue   string

	// Hyper device flow state
	// hyperDeviceFlow is created when a Hyper model is chosen.
	// showHyperDeviceFlow is true while that flow is displayed.
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Copilot device flow state
	// copilotDeviceFlow is created when a Copilot model is chosen and no
	// existing Copilot token could be imported.
	copilotDeviceFlow     *copilot.DeviceFlow
	showCopilotDeviceFlow bool

	// Claude state
	// Both Claude sub-views are created up front by NewModelDialogCmp. The
	// two show* flags select which one, if any, is visible.
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
91
92func NewModelDialogCmp() ModelDialog {
93 keyMap := DefaultKeyMap()
94
95 listKeyMap := list.DefaultKeyMap()
96 listKeyMap.Down.SetEnabled(false)
97 listKeyMap.Up.SetEnabled(false)
98 listKeyMap.DownOneItem = keyMap.Next
99 listKeyMap.UpOneItem = keyMap.Previous
100
101 t := styles.CurrentTheme()
102 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103 apiKeyInput := NewAPIKeyInput()
104 apiKeyInput.SetShowTitle(false)
105 help := help.New()
106 help.Styles = t.S().Help
107
108 return &modelDialogCmp{
109 modelList: modelList,
110 apiKeyInput: apiKeyInput,
111 width: defaultWidth,
112 keyMap: DefaultKeyMap(),
113 help: help,
114
115 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116 claudeOAuth2: claude.NewOAuth2(),
117 }
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121 return tea.Batch(
122 m.modelList.Init(),
123 m.apiKeyInput.Init(),
124 m.claudeAuthMethodChooser.Init(),
125 m.claudeOAuth2.Init(),
126 )
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130 switch msg := msg.(type) {
131 case tea.WindowSizeMsg:
132 m.wWidth = msg.Width
133 m.wHeight = msg.Height
134 m.apiKeyInput.SetWidth(m.width - 2)
135 m.help.SetWidth(m.width - 2)
136 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138 case APIKeyStateChangeMsg:
139 u, cmd := m.apiKeyInput.Update(msg)
140 m.apiKeyInput = u.(*APIKeyInput)
141 return m, cmd
142 case hyper.DeviceFlowCompletedMsg:
143 return m, m.saveOauthTokenAndContinue(msg.Token, true)
144 case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145 if m.hyperDeviceFlow != nil {
146 u, cmd := m.hyperDeviceFlow.Update(msg)
147 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148 return m, cmd
149 }
150 return m, nil
151 case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152 if m.copilotDeviceFlow != nil {
153 u, cmd := m.copilotDeviceFlow.Update(msg)
154 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155 return m, cmd
156 }
157 return m, nil
158 case copilot.DeviceFlowCompletedMsg:
159 return m, m.saveOauthTokenAndContinue(msg.Token, true)
160 case claude.ValidationCompletedMsg:
161 var cmds []tea.Cmd
162 u, cmd := m.claudeOAuth2.Update(msg)
163 m.claudeOAuth2 = u.(*claude.OAuth2)
164 cmds = append(cmds, cmd)
165
166 if msg.State == claude.OAuthValidationStateValid {
167 cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168 m.keyMap.isClaudeOAuthHelpComplete = true
169 }
170
171 return m, tea.Batch(cmds...)
172 case claude.AuthenticationCompleteMsg:
173 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174 case tea.KeyPressMsg:
175 switch {
176 // Handle Hyper device flow keys
177 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
178 return m, m.hyperDeviceFlow.CopyCode()
179 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
180 return m, m.copilotDeviceFlow.CopyCode()
181 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
182 return m, tea.Sequence(
183 tea.SetClipboard(m.claudeOAuth2.URL),
184 func() tea.Msg {
185 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
186 return nil
187 },
188 util.ReportInfo("URL copied to clipboard"),
189 )
190 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
191 m.claudeAuthMethodChooser.ToggleChoice()
192 return m, nil
193 case key.Matches(msg, m.keyMap.Select):
194 // If showing device flow, enter copies code and opens URL
195 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
196 return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
197 }
198 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
199 return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
200 }
201 selectedItem := m.modelList.SelectedModel()
202 if selectedItem == nil {
203 return m, nil
204 }
205
206 modelType := config.SelectedModelTypeLarge
207 if m.modelList.GetModelType() == SmallModelType {
208 modelType = config.SelectedModelTypeSmall
209 }
210
211 askForApiKey := func() {
212 m.keyMap.isClaudeAuthChoiceHelp = false
213 m.keyMap.isClaudeOAuthHelp = false
214 m.keyMap.isAPIKeyHelp = true
215 m.showHyperDeviceFlow = false
216 m.showCopilotDeviceFlow = false
217 m.showClaudeAuthMethodChooser = false
218 m.needsAPIKey = true
219 m.selectedModel = selectedItem
220 m.selectedModelType = modelType
221 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
222 }
223
224 if m.showClaudeAuthMethodChooser {
225 switch m.claudeAuthMethodChooser.State {
226 case claude.AuthMethodAPIKey:
227 askForApiKey()
228 case claude.AuthMethodOAuth2:
229 m.selectedModel = selectedItem
230 m.selectedModelType = modelType
231 m.showClaudeAuthMethodChooser = false
232 m.showClaudeOAuth2 = true
233 m.keyMap.isClaudeAuthChoiceHelp = false
234 m.keyMap.isClaudeOAuthHelp = true
235 }
236 return m, nil
237 }
238 if m.showClaudeOAuth2 {
239 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
240 m.claudeOAuth2 = m2.(*claude.OAuth2)
241 return m, cmd2
242 }
243 if m.isAPIKeyValid {
244 return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
245 }
246 if m.needsAPIKey {
247 // Handle API key submission
248 m.apiKeyValue = m.apiKeyInput.Value()
249 provider, err := m.getProvider(m.selectedModel.Provider.ID)
250 if err != nil || provider == nil {
251 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
252 }
253 providerConfig := config.ProviderConfig{
254 ID: string(m.selectedModel.Provider.ID),
255 Name: m.selectedModel.Provider.Name,
256 APIKey: m.apiKeyValue,
257 Type: provider.Type,
258 BaseURL: provider.APIEndpoint,
259 }
260 return m, tea.Sequence(
261 util.CmdHandler(APIKeyStateChangeMsg{
262 State: APIKeyInputStateVerifying,
263 }),
264 func() tea.Msg {
265 start := time.Now()
266 err := providerConfig.TestConnection(config.Get().Resolver())
267 // intentionally wait for at least 750ms to make sure the user sees the spinner
268 elapsed := time.Since(start)
269 if elapsed < 750*time.Millisecond {
270 time.Sleep(750*time.Millisecond - elapsed)
271 }
272 if err == nil {
273 m.isAPIKeyValid = true
274 return APIKeyStateChangeMsg{
275 State: APIKeyInputStateVerified,
276 }
277 }
278 return APIKeyStateChangeMsg{
279 State: APIKeyInputStateError,
280 }
281 },
282 )
283 }
284
285 // Check if provider is configured
286 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
287 return m, tea.Sequence(
288 util.CmdHandler(dialogs.CloseDialogMsg{}),
289 util.CmdHandler(ModelSelectedMsg{
290 Model: config.SelectedModel{
291 Model: selectedItem.Model.ID,
292 Provider: string(selectedItem.Provider.ID),
293 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
294 MaxTokens: selectedItem.Model.DefaultMaxTokens,
295 },
296 ModelType: modelType,
297 }),
298 )
299 }
300 switch selectedItem.Provider.ID {
301 case catwalk.InferenceProviderAnthropic:
302 m.showClaudeAuthMethodChooser = true
303 m.keyMap.isClaudeAuthChoiceHelp = true
304 return m, nil
305 case hyperp.Name:
306 m.showHyperDeviceFlow = true
307 m.selectedModel = selectedItem
308 m.selectedModelType = modelType
309 m.hyperDeviceFlow = hyper.NewDeviceFlow()
310 m.hyperDeviceFlow.SetWidth(m.width - 2)
311 return m, m.hyperDeviceFlow.Init()
312 case catwalk.InferenceProviderCopilot:
313 if token, ok := config.Get().ImportCopilot(); ok {
314 m.selectedModel = selectedItem
315 m.selectedModelType = modelType
316 return m, m.saveOauthTokenAndContinue(token, true)
317 }
318 m.showCopilotDeviceFlow = true
319 m.selectedModel = selectedItem
320 m.selectedModelType = modelType
321 m.copilotDeviceFlow = copilot.NewDeviceFlow()
322 m.copilotDeviceFlow.SetWidth(m.width - 2)
323 return m, m.copilotDeviceFlow.Init()
324 }
325 // For other providers, show API key input
326 askForApiKey()
327 return m, nil
328 case key.Matches(msg, m.keyMap.Tab):
329 switch {
330 case m.showClaudeAuthMethodChooser:
331 m.claudeAuthMethodChooser.ToggleChoice()
332 return m, nil
333 case m.needsAPIKey:
334 u, cmd := m.apiKeyInput.Update(msg)
335 m.apiKeyInput = u.(*APIKeyInput)
336 return m, cmd
337 case m.modelList.GetModelType() == LargeModelType:
338 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
339 return m, m.modelList.SetModelType(SmallModelType)
340 default:
341 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
342 return m, m.modelList.SetModelType(LargeModelType)
343 }
344 case key.Matches(msg, m.keyMap.Close):
345 switch {
346 case m.showHyperDeviceFlow:
347 if m.hyperDeviceFlow != nil {
348 m.hyperDeviceFlow.Cancel()
349 }
350 m.showHyperDeviceFlow = false
351 m.selectedModel = nil
352 case m.showCopilotDeviceFlow:
353 if m.copilotDeviceFlow != nil {
354 m.copilotDeviceFlow.Cancel()
355 }
356 m.showCopilotDeviceFlow = false
357 m.selectedModel = nil
358 case m.showClaudeAuthMethodChooser:
359 m.claudeAuthMethodChooser.SetDefaults()
360 m.showClaudeAuthMethodChooser = false
361 m.keyMap.isClaudeAuthChoiceHelp = false
362 m.keyMap.isClaudeOAuthHelp = false
363 return m, nil
364 case m.needsAPIKey:
365 if m.isAPIKeyValid {
366 return m, nil
367 }
368 // Go back to model selection
369 m.needsAPIKey = false
370 m.selectedModel = nil
371 m.isAPIKeyValid = false
372 m.apiKeyValue = ""
373 m.apiKeyInput.Reset()
374 return m, nil
375 default:
376 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
377 }
378 default:
379 switch {
380 case m.showClaudeAuthMethodChooser:
381 u, cmd := m.claudeAuthMethodChooser.Update(msg)
382 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
383 return m, cmd
384 case m.showClaudeOAuth2:
385 u, cmd := m.claudeOAuth2.Update(msg)
386 m.claudeOAuth2 = u.(*claude.OAuth2)
387 return m, cmd
388 case m.needsAPIKey:
389 u, cmd := m.apiKeyInput.Update(msg)
390 m.apiKeyInput = u.(*APIKeyInput)
391 return m, cmd
392 default:
393 u, cmd := m.modelList.Update(msg)
394 m.modelList = u
395 return m, cmd
396 }
397 }
398 case tea.PasteMsg:
399 switch {
400 case m.showClaudeOAuth2:
401 u, cmd := m.claudeOAuth2.Update(msg)
402 m.claudeOAuth2 = u.(*claude.OAuth2)
403 return m, cmd
404 case m.needsAPIKey:
405 u, cmd := m.apiKeyInput.Update(msg)
406 m.apiKeyInput = u.(*APIKeyInput)
407 return m, cmd
408 default:
409 var cmd tea.Cmd
410 m.modelList, cmd = m.modelList.Update(msg)
411 return m, cmd
412 }
413 case spinner.TickMsg:
414 u, cmd := m.apiKeyInput.Update(msg)
415 m.apiKeyInput = u.(*APIKeyInput)
416 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
417 u, cmd = m.hyperDeviceFlow.Update(msg)
418 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
419 }
420 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
421 u, cmd = m.copilotDeviceFlow.Update(msg)
422 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
423 }
424 return m, cmd
425 default:
426 // Pass all other messages to the device flow for spinner animation
427 switch {
428 case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
429 u, cmd := m.hyperDeviceFlow.Update(msg)
430 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
431 return m, cmd
432 case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
433 u, cmd := m.copilotDeviceFlow.Update(msg)
434 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
435 return m, cmd
436 case m.showClaudeOAuth2:
437 u, cmd := m.claudeOAuth2.Update(msg)
438 m.claudeOAuth2 = u.(*claude.OAuth2)
439 return m, cmd
440 default:
441 u, cmd := m.apiKeyInput.Update(msg)
442 m.apiKeyInput = u.(*APIKeyInput)
443 return m, cmd
444 }
445 }
446 return m, nil
447}
448
449func (m *modelDialogCmp) View() string {
450 t := styles.CurrentTheme()
451
452 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
453 // Show Hyper device flow
454 m.keyMap.isHyperDeviceFlow = true
455 deviceFlowView := m.hyperDeviceFlow.View()
456 content := lipgloss.JoinVertical(
457 lipgloss.Left,
458 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
459 deviceFlowView,
460 "",
461 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
462 )
463 return m.style().Render(content)
464 }
465 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
466 // Show Hyper device flow
467 m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
468 m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
469 deviceFlowView := m.copilotDeviceFlow.View()
470 content := lipgloss.JoinVertical(
471 lipgloss.Left,
472 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
473 deviceFlowView,
474 "",
475 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
476 )
477 return m.style().Render(content)
478 }
479
480 // Reset the flags when not showing device flow
481 m.keyMap.isHyperDeviceFlow = false
482 m.keyMap.isCopilotDeviceFlow = false
483 m.keyMap.isCopilotUnavailable = false
484
485 switch {
486 case m.showClaudeAuthMethodChooser:
487 chooserView := m.claudeAuthMethodChooser.View()
488 content := lipgloss.JoinVertical(
489 lipgloss.Left,
490 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
491 chooserView,
492 "",
493 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
494 )
495 return m.style().Render(content)
496 case m.showClaudeOAuth2:
497 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
498 oauth2View := m.claudeOAuth2.View()
499 content := lipgloss.JoinVertical(
500 lipgloss.Left,
501 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
502 oauth2View,
503 "",
504 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
505 )
506 return m.style().Render(content)
507 case m.needsAPIKey:
508 // Show API key input
509 m.keyMap.isAPIKeyHelp = true
510 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
511 apiKeyView := m.apiKeyInput.View()
512 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
513 content := lipgloss.JoinVertical(
514 lipgloss.Left,
515 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
516 apiKeyView,
517 "",
518 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
519 )
520 return m.style().Render(content)
521 }
522
523 // Show model selection
524 listView := m.modelList.View()
525 radio := m.modelTypeRadio()
526 content := lipgloss.JoinVertical(
527 lipgloss.Left,
528 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
529 listView,
530 "",
531 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
532 )
533 return m.style().Render(content)
534}
535
536func (m *modelDialogCmp) Cursor() *tea.Cursor {
537 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
538 return m.hyperDeviceFlow.Cursor()
539 }
540 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
541 return m.copilotDeviceFlow.Cursor()
542 }
543 if m.showClaudeAuthMethodChooser {
544 return nil
545 }
546 if m.showClaudeOAuth2 {
547 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
548 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
549 return m.moveCursor(cursor)
550 }
551 return nil
552 }
553 if m.needsAPIKey {
554 cursor := m.apiKeyInput.Cursor()
555 if cursor != nil {
556 cursor = m.moveCursor(cursor)
557 return cursor
558 }
559 } else {
560 cursor := m.modelList.Cursor()
561 if cursor != nil {
562 cursor = m.moveCursor(cursor)
563 return cursor
564 }
565 }
566 return nil
567}
568
569func (m *modelDialogCmp) style() lipgloss.Style {
570 t := styles.CurrentTheme()
571 return t.S().Base.
572 Width(m.width).
573 Border(lipgloss.RoundedBorder()).
574 BorderForeground(t.BorderFocus)
575}
576
577func (m *modelDialogCmp) listWidth() int {
578 return m.width - 2
579}
580
581func (m *modelDialogCmp) listHeight() int {
582 return m.wHeight / 2
583}
584
585func (m *modelDialogCmp) Position() (int, int) {
586 row := m.wHeight/4 - 2 // just a bit above the center
587 col := m.wWidth / 2
588 col -= m.width / 2
589 return row, col
590}
591
592func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
593 row, col := m.Position()
594 if m.needsAPIKey {
595 offset := row + 3 // Border + title + API key input offset
596 cursor.Y += offset
597 cursor.X = cursor.X + col + 2
598 } else {
599 offset := row + 3 // Border + title
600 cursor.Y += offset
601 cursor.X = cursor.X + col + 2
602 }
603 return cursor
604}
605
606func (m *modelDialogCmp) ID() dialogs.DialogID {
607 return ModelsDialogID
608}
609
610func (m *modelDialogCmp) modelTypeRadio() string {
611 t := styles.CurrentTheme()
612 choices := []string{"Large Task", "Small Task"}
613 iconSelected := "◉"
614 iconUnselected := "○"
615 if m.modelList.GetModelType() == LargeModelType {
616 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
617 }
618 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
619}
620
621func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
622 cfg := config.Get()
623 _, ok := cfg.Providers.Get(providerID)
624 return ok
625}
626
627func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
628 cfg := config.Get()
629 providers, err := config.Providers(cfg)
630 if err != nil {
631 return nil, err
632 }
633 for _, p := range providers {
634 if p.ID == providerID {
635 return &p, nil
636 }
637 }
638 return nil, nil
639}
640
641func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
642 if m.selectedModel == nil {
643 return util.ReportError(fmt.Errorf("no model selected"))
644 }
645
646 cfg := config.Get()
647 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
648 if err != nil {
649 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
650 }
651
652 // Reset API key state and continue with model selection
653 selectedModel := *m.selectedModel
654 var cmds []tea.Cmd
655 if close {
656 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
657 }
658 cmds = append(
659 cmds,
660 util.CmdHandler(ModelSelectedMsg{
661 Model: config.SelectedModel{
662 Model: selectedModel.Model.ID,
663 Provider: string(selectedModel.Provider.ID),
664 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
665 MaxTokens: selectedModel.Model.DefaultMaxTokens,
666 },
667 ModelType: m.selectedModelType,
668 }),
669 )
670 return tea.Sequence(cmds...)
671}