1// Package models provides the model selection dialog for the TUI.
2package models
3
4import (
5 "fmt"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 tea "charm.land/bubbletea/v2"
12 "charm.land/lipgloss/v2"
13 "github.com/atotto/clipboard"
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/tui/components/core"
18 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
19 "github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
20 "github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
21 "github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
22 "github.com/charmbracelet/crush/internal/tui/exp/list"
23 "github.com/charmbracelet/crush/internal/tui/styles"
24 "github.com/charmbracelet/crush/internal/tui/util"
25)
26
27const (
28 ModelsDialogID dialogs.DialogID = "models"
29
30 defaultWidth = 60
31)
32
// Model type identifiers the model list uses to tell its two selection
// modes apart: models for large tasks and models for small tasks.
const (
	LargeModelType int = iota
	SmallModelType
)

// Placeholder text for the model list's input, one per selection mode.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
40
41// ModelSelectedMsg is sent when a model is selected
42type ModelSelectedMsg struct {
43 Model config.SelectedModel
44 ModelType config.SelectedModelType
45}
46
// CloseModelDialogMsg is a message for closing the model dialog.
//
// NOTE(review): this dialog closes itself by emitting dialogs.CloseDialogMsg;
// confirm where this type is still produced or consumed.
type CloseModelDialogMsg struct{}
49
50// ModelDialog interface for the model selection dialog
51type ModelDialog interface {
52 dialogs.DialogModel
53}
54
55type ModelOption struct {
56 Provider catwalk.Provider
57 Model catwalk.Model
58}
59
60type modelDialogCmp struct {
61 width int
62 wWidth int
63 wHeight int
64
65 modelList *ModelListComponent
66 keyMap KeyMap
67 help help.Model
68
69 // API key state
70 needsAPIKey bool
71 apiKeyInput *APIKeyInput
72 selectedModel *ModelOption
73 selectedModelType config.SelectedModelType
74 isAPIKeyValid bool
75 apiKeyValue string
76
77 // Hyper device flow state
78 hyperDeviceFlow *hyper.DeviceFlow
79 showHyperDeviceFlow bool
80
81 // Copilot device flow state
82 copilotDeviceFlow *copilot.DeviceFlow
83 showCopilotDeviceFlow bool
84
85 // Claude state
86 claudeAuthMethodChooser *claude.AuthMethodChooser
87 claudeOAuth2 *claude.OAuth2
88 showClaudeAuthMethodChooser bool
89 showClaudeOAuth2 bool
90}
91
92func NewModelDialogCmp() ModelDialog {
93 keyMap := DefaultKeyMap()
94
95 listKeyMap := list.DefaultKeyMap()
96 listKeyMap.Down.SetEnabled(false)
97 listKeyMap.Up.SetEnabled(false)
98 listKeyMap.DownOneItem = keyMap.Next
99 listKeyMap.UpOneItem = keyMap.Previous
100
101 t := styles.CurrentTheme()
102 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103 apiKeyInput := NewAPIKeyInput()
104 apiKeyInput.SetShowTitle(false)
105 help := help.New()
106 help.Styles = t.S().Help
107
108 return &modelDialogCmp{
109 modelList: modelList,
110 apiKeyInput: apiKeyInput,
111 width: defaultWidth,
112 keyMap: DefaultKeyMap(),
113 help: help,
114
115 claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116 claudeOAuth2: claude.NewOAuth2(),
117 }
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121 return tea.Batch(
122 m.modelList.Init(),
123 m.apiKeyInput.Init(),
124 m.claudeAuthMethodChooser.Init(),
125 m.claudeOAuth2.Init(),
126 )
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130 switch msg := msg.(type) {
131 case tea.WindowSizeMsg:
132 m.wWidth = msg.Width
133 m.wHeight = msg.Height
134 m.apiKeyInput.SetWidth(m.width - 2)
135 m.help.SetWidth(m.width - 2)
136 m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138 case APIKeyStateChangeMsg:
139 u, cmd := m.apiKeyInput.Update(msg)
140 m.apiKeyInput = u.(*APIKeyInput)
141 return m, cmd
142 case hyper.DeviceFlowCompletedMsg:
143 return m, m.saveOauthTokenAndContinue(msg.Token, true)
144 case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145 if m.hyperDeviceFlow != nil {
146 u, cmd := m.hyperDeviceFlow.Update(msg)
147 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148 return m, cmd
149 }
150 return m, nil
151 case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152 if m.copilotDeviceFlow != nil {
153 u, cmd := m.copilotDeviceFlow.Update(msg)
154 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155 return m, cmd
156 }
157 return m, nil
158 case copilot.DeviceFlowCompletedMsg:
159 return m, m.saveOauthTokenAndContinue(msg.Token, true)
160 case claude.ValidationCompletedMsg:
161 var cmds []tea.Cmd
162 u, cmd := m.claudeOAuth2.Update(msg)
163 m.claudeOAuth2 = u.(*claude.OAuth2)
164 cmds = append(cmds, cmd)
165
166 if msg.State == claude.OAuthValidationStateValid {
167 cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168 m.keyMap.isClaudeOAuthHelpComplete = true
169 }
170
171 return m, tea.Batch(cmds...)
172 case claude.AuthenticationCompleteMsg:
173 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174 case tea.KeyPressMsg:
175 switch {
176 // Handle Hyper device flow keys
177 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
178 return m, m.hyperDeviceFlow.CopyCode()
179 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
180 return m, m.copilotDeviceFlow.CopyCode()
181 case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
182 return m, tea.Sequence(
183 tea.SetClipboard(m.claudeOAuth2.URL),
184 func() tea.Msg {
185 _ = clipboard.WriteAll(m.claudeOAuth2.URL)
186 return nil
187 },
188 util.ReportInfo("URL copied to clipboard"),
189 )
190 case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
191 m.claudeAuthMethodChooser.ToggleChoice()
192 return m, nil
193 case key.Matches(msg, m.keyMap.Select):
194 // If showing device flow, enter copies code and opens URL
195 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
196 return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
197 }
198 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
199 return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
200 }
201 selectedItem := m.modelList.SelectedModel()
202
203 modelType := config.SelectedModelTypeLarge
204 if m.modelList.GetModelType() == SmallModelType {
205 modelType = config.SelectedModelTypeSmall
206 }
207
208 askForApiKey := func() {
209 m.keyMap.isClaudeAuthChoiceHelp = false
210 m.keyMap.isClaudeOAuthHelp = false
211 m.keyMap.isAPIKeyHelp = true
212 m.showHyperDeviceFlow = false
213 m.showCopilotDeviceFlow = false
214 m.showClaudeAuthMethodChooser = false
215 m.needsAPIKey = true
216 m.selectedModel = selectedItem
217 m.selectedModelType = modelType
218 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
219 }
220
221 if m.showClaudeAuthMethodChooser {
222 switch m.claudeAuthMethodChooser.State {
223 case claude.AuthMethodAPIKey:
224 askForApiKey()
225 case claude.AuthMethodOAuth2:
226 m.selectedModel = selectedItem
227 m.selectedModelType = modelType
228 m.showClaudeAuthMethodChooser = false
229 m.showClaudeOAuth2 = true
230 m.keyMap.isClaudeAuthChoiceHelp = false
231 m.keyMap.isClaudeOAuthHelp = true
232 }
233 return m, nil
234 }
235 if m.showClaudeOAuth2 {
236 m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
237 m.claudeOAuth2 = m2.(*claude.OAuth2)
238 return m, cmd2
239 }
240 if m.isAPIKeyValid {
241 return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
242 }
243 if m.needsAPIKey {
244 // Handle API key submission
245 m.apiKeyValue = m.apiKeyInput.Value()
246 provider, err := m.getProvider(m.selectedModel.Provider.ID)
247 if err != nil || provider == nil {
248 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
249 }
250 providerConfig := config.ProviderConfig{
251 ID: string(m.selectedModel.Provider.ID),
252 Name: m.selectedModel.Provider.Name,
253 APIKey: m.apiKeyValue,
254 Type: provider.Type,
255 BaseURL: provider.APIEndpoint,
256 }
257 return m, tea.Sequence(
258 util.CmdHandler(APIKeyStateChangeMsg{
259 State: APIKeyInputStateVerifying,
260 }),
261 func() tea.Msg {
262 start := time.Now()
263 err := providerConfig.TestConnection(config.Get().Resolver())
264 // intentionally wait for at least 750ms to make sure the user sees the spinner
265 elapsed := time.Since(start)
266 if elapsed < 750*time.Millisecond {
267 time.Sleep(750*time.Millisecond - elapsed)
268 }
269 if err == nil {
270 m.isAPIKeyValid = true
271 return APIKeyStateChangeMsg{
272 State: APIKeyInputStateVerified,
273 }
274 }
275 return APIKeyStateChangeMsg{
276 State: APIKeyInputStateError,
277 }
278 },
279 )
280 }
281
282 // Check if provider is configured
283 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284 return m, tea.Sequence(
285 util.CmdHandler(dialogs.CloseDialogMsg{}),
286 util.CmdHandler(ModelSelectedMsg{
287 Model: config.SelectedModel{
288 Model: selectedItem.Model.ID,
289 Provider: string(selectedItem.Provider.ID),
290 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
291 MaxTokens: selectedItem.Model.DefaultMaxTokens,
292 },
293 ModelType: modelType,
294 }),
295 )
296 }
297 switch selectedItem.Provider.ID {
298 case catwalk.InferenceProviderAnthropic:
299 m.showClaudeAuthMethodChooser = true
300 m.keyMap.isClaudeAuthChoiceHelp = true
301 return m, nil
302 case hyperp.Name:
303 m.showHyperDeviceFlow = true
304 m.selectedModel = selectedItem
305 m.selectedModelType = modelType
306 m.hyperDeviceFlow = hyper.NewDeviceFlow()
307 m.hyperDeviceFlow.SetWidth(m.width - 2)
308 return m, m.hyperDeviceFlow.Init()
309 case catwalk.InferenceProviderCopilot:
310 if token, ok := config.Get().ImportCopilot(); ok {
311 m.selectedModel = selectedItem
312 m.selectedModelType = modelType
313 return m, m.saveOauthTokenAndContinue(token, true)
314 }
315 m.showCopilotDeviceFlow = true
316 m.selectedModel = selectedItem
317 m.selectedModelType = modelType
318 m.copilotDeviceFlow = copilot.NewDeviceFlow()
319 m.copilotDeviceFlow.SetWidth(m.width - 2)
320 return m, m.copilotDeviceFlow.Init()
321 }
322 // For other providers, show API key input
323 askForApiKey()
324 return m, nil
325 case key.Matches(msg, m.keyMap.Tab):
326 switch {
327 case m.showClaudeAuthMethodChooser:
328 m.claudeAuthMethodChooser.ToggleChoice()
329 return m, nil
330 case m.needsAPIKey:
331 u, cmd := m.apiKeyInput.Update(msg)
332 m.apiKeyInput = u.(*APIKeyInput)
333 return m, cmd
334 case m.modelList.GetModelType() == LargeModelType:
335 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
336 return m, m.modelList.SetModelType(SmallModelType)
337 default:
338 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
339 return m, m.modelList.SetModelType(LargeModelType)
340 }
341 case key.Matches(msg, m.keyMap.Close):
342 switch {
343 case m.showHyperDeviceFlow:
344 if m.hyperDeviceFlow != nil {
345 m.hyperDeviceFlow.Cancel()
346 }
347 m.showHyperDeviceFlow = false
348 m.selectedModel = nil
349 case m.showCopilotDeviceFlow:
350 if m.copilotDeviceFlow != nil {
351 m.copilotDeviceFlow.Cancel()
352 }
353 m.showCopilotDeviceFlow = false
354 m.selectedModel = nil
355 case m.showClaudeAuthMethodChooser:
356 m.claudeAuthMethodChooser.SetDefaults()
357 m.showClaudeAuthMethodChooser = false
358 m.keyMap.isClaudeAuthChoiceHelp = false
359 m.keyMap.isClaudeOAuthHelp = false
360 return m, nil
361 case m.needsAPIKey:
362 if m.isAPIKeyValid {
363 return m, nil
364 }
365 // Go back to model selection
366 m.needsAPIKey = false
367 m.selectedModel = nil
368 m.isAPIKeyValid = false
369 m.apiKeyValue = ""
370 m.apiKeyInput.Reset()
371 return m, nil
372 default:
373 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
374 }
375 default:
376 switch {
377 case m.showClaudeAuthMethodChooser:
378 u, cmd := m.claudeAuthMethodChooser.Update(msg)
379 m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
380 return m, cmd
381 case m.showClaudeOAuth2:
382 u, cmd := m.claudeOAuth2.Update(msg)
383 m.claudeOAuth2 = u.(*claude.OAuth2)
384 return m, cmd
385 case m.needsAPIKey:
386 u, cmd := m.apiKeyInput.Update(msg)
387 m.apiKeyInput = u.(*APIKeyInput)
388 return m, cmd
389 default:
390 u, cmd := m.modelList.Update(msg)
391 m.modelList = u
392 return m, cmd
393 }
394 }
395 case tea.PasteMsg:
396 switch {
397 case m.showClaudeOAuth2:
398 u, cmd := m.claudeOAuth2.Update(msg)
399 m.claudeOAuth2 = u.(*claude.OAuth2)
400 return m, cmd
401 case m.needsAPIKey:
402 u, cmd := m.apiKeyInput.Update(msg)
403 m.apiKeyInput = u.(*APIKeyInput)
404 return m, cmd
405 default:
406 var cmd tea.Cmd
407 m.modelList, cmd = m.modelList.Update(msg)
408 return m, cmd
409 }
410 case spinner.TickMsg:
411 u, cmd := m.apiKeyInput.Update(msg)
412 m.apiKeyInput = u.(*APIKeyInput)
413 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
414 u, cmd = m.hyperDeviceFlow.Update(msg)
415 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
416 }
417 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
418 u, cmd = m.copilotDeviceFlow.Update(msg)
419 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
420 }
421 return m, cmd
422 default:
423 // Pass all other messages to the device flow for spinner animation
424 switch {
425 case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
426 u, cmd := m.hyperDeviceFlow.Update(msg)
427 m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
428 return m, cmd
429 case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
430 u, cmd := m.copilotDeviceFlow.Update(msg)
431 m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
432 return m, cmd
433 case m.showClaudeOAuth2:
434 u, cmd := m.claudeOAuth2.Update(msg)
435 m.claudeOAuth2 = u.(*claude.OAuth2)
436 return m, cmd
437 default:
438 u, cmd := m.apiKeyInput.Update(msg)
439 m.apiKeyInput = u.(*APIKeyInput)
440 return m, cmd
441 }
442 }
443 return m, nil
444}
445
446func (m *modelDialogCmp) View() string {
447 t := styles.CurrentTheme()
448
449 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
450 // Show Hyper device flow
451 m.keyMap.isHyperDeviceFlow = true
452 deviceFlowView := m.hyperDeviceFlow.View()
453 content := lipgloss.JoinVertical(
454 lipgloss.Left,
455 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
456 deviceFlowView,
457 "",
458 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
459 )
460 return m.style().Render(content)
461 }
462 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
463 // Show Hyper device flow
464 m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
465 m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
466 deviceFlowView := m.copilotDeviceFlow.View()
467 content := lipgloss.JoinVertical(
468 lipgloss.Left,
469 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
470 deviceFlowView,
471 "",
472 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
473 )
474 return m.style().Render(content)
475 }
476
477 // Reset the flags when not showing device flow
478 m.keyMap.isHyperDeviceFlow = false
479 m.keyMap.isCopilotDeviceFlow = false
480 m.keyMap.isCopilotUnavailable = false
481
482 switch {
483 case m.showClaudeAuthMethodChooser:
484 chooserView := m.claudeAuthMethodChooser.View()
485 content := lipgloss.JoinVertical(
486 lipgloss.Left,
487 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
488 chooserView,
489 "",
490 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
491 )
492 return m.style().Render(content)
493 case m.showClaudeOAuth2:
494 m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
495 oauth2View := m.claudeOAuth2.View()
496 content := lipgloss.JoinVertical(
497 lipgloss.Left,
498 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
499 oauth2View,
500 "",
501 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
502 )
503 return m.style().Render(content)
504 case m.needsAPIKey:
505 // Show API key input
506 m.keyMap.isAPIKeyHelp = true
507 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
508 apiKeyView := m.apiKeyInput.View()
509 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
510 content := lipgloss.JoinVertical(
511 lipgloss.Left,
512 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
513 apiKeyView,
514 "",
515 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
516 )
517 return m.style().Render(content)
518 }
519
520 // Show model selection
521 listView := m.modelList.View()
522 radio := m.modelTypeRadio()
523 content := lipgloss.JoinVertical(
524 lipgloss.Left,
525 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
526 listView,
527 "",
528 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
529 )
530 return m.style().Render(content)
531}
532
533func (m *modelDialogCmp) Cursor() *tea.Cursor {
534 if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
535 return m.hyperDeviceFlow.Cursor()
536 }
537 if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
538 return m.copilotDeviceFlow.Cursor()
539 }
540 if m.showClaudeAuthMethodChooser {
541 return nil
542 }
543 if m.showClaudeOAuth2 {
544 if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
545 cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
546 return m.moveCursor(cursor)
547 }
548 return nil
549 }
550 if m.needsAPIKey {
551 cursor := m.apiKeyInput.Cursor()
552 if cursor != nil {
553 cursor = m.moveCursor(cursor)
554 return cursor
555 }
556 } else {
557 cursor := m.modelList.Cursor()
558 if cursor != nil {
559 cursor = m.moveCursor(cursor)
560 return cursor
561 }
562 }
563 return nil
564}
565
566func (m *modelDialogCmp) style() lipgloss.Style {
567 t := styles.CurrentTheme()
568 return t.S().Base.
569 Width(m.width).
570 Border(lipgloss.RoundedBorder()).
571 BorderForeground(t.BorderFocus)
572}
573
574func (m *modelDialogCmp) listWidth() int {
575 return m.width - 2
576}
577
578func (m *modelDialogCmp) listHeight() int {
579 return m.wHeight / 2
580}
581
582func (m *modelDialogCmp) Position() (int, int) {
583 row := m.wHeight/4 - 2 // just a bit above the center
584 col := m.wWidth / 2
585 col -= m.width / 2
586 return row, col
587}
588
589func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
590 row, col := m.Position()
591 if m.needsAPIKey {
592 offset := row + 3 // Border + title + API key input offset
593 cursor.Y += offset
594 cursor.X = cursor.X + col + 2
595 } else {
596 offset := row + 3 // Border + title
597 cursor.Y += offset
598 cursor.X = cursor.X + col + 2
599 }
600 return cursor
601}
602
603func (m *modelDialogCmp) ID() dialogs.DialogID {
604 return ModelsDialogID
605}
606
607func (m *modelDialogCmp) modelTypeRadio() string {
608 t := styles.CurrentTheme()
609 choices := []string{"Large Task", "Small Task"}
610 iconSelected := "◉"
611 iconUnselected := "○"
612 if m.modelList.GetModelType() == LargeModelType {
613 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
614 }
615 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
616}
617
618func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
619 cfg := config.Get()
620 _, ok := cfg.Providers.Get(providerID)
621 return ok
622}
623
624func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
625 cfg := config.Get()
626 providers, err := config.Providers(cfg)
627 if err != nil {
628 return nil, err
629 }
630 for _, p := range providers {
631 if p.ID == providerID {
632 return &p, nil
633 }
634 }
635 return nil, nil
636}
637
638func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
639 if m.selectedModel == nil {
640 return util.ReportError(fmt.Errorf("no model selected"))
641 }
642
643 cfg := config.Get()
644 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
645 if err != nil {
646 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
647 }
648
649 // Reset API key state and continue with model selection
650 selectedModel := *m.selectedModel
651 var cmds []tea.Cmd
652 if close {
653 cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
654 }
655 cmds = append(
656 cmds,
657 util.CmdHandler(ModelSelectedMsg{
658 Model: config.SelectedModel{
659 Model: selectedModel.Model.ID,
660 Provider: string(selectedModel.Provider.ID),
661 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
662 MaxTokens: selectedModel.Model.DefaultMaxTokens,
663 },
664 ModelType: m.selectedModelType,
665 }),
666 )
667 return tea.Sequence(cmds...)
668}