models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/atotto/clipboard"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/tui/components/core"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 20	"github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
 21	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 22	"github.com/charmbracelet/crush/internal/tui/exp/list"
 23	"github.com/charmbracelet/crush/internal/tui/styles"
 24	"github.com/charmbracelet/crush/internal/tui/util"
 25)
 26
 27const (
 28	ModelsDialogID dialogs.DialogID = "models"
 29
 30	defaultWidth = 60
 31)
 32
 33const (
 34	LargeModelType int = iota
 35	SmallModelType
 36
 37	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 38	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 39)
 40
// ModelSelectedMsg is sent when a model has been chosen. It carries the
// selection (provider, model ID and the model's default reasoning effort and
// max tokens) and which model slot, large or small, it was chosen for.
type ModelSelectedMsg struct {
	Model     config.SelectedModel     // the chosen provider/model
	ModelType config.SelectedModelType // large or small model slot
}
 46
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): it is not emitted anywhere in the visible code, which closes
// the dialog with dialogs.CloseDialogMsg instead; confirm whether it is still
// needed.
type CloseModelDialogMsg struct{}
 49
// ModelDialog is the public interface of the model selection dialog. It adds
// nothing to dialogs.DialogModel; it exists so NewModelDialogCmp can return
// a named type.
type ModelDialog interface {
	dialogs.DialogModel
}
 54
// ModelOption pairs a catwalk model with the provider that offers it. It is
// the type of the entry currently selected in the model list.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 59
// modelDialogCmp implements ModelDialog. It shows the model list and, when
// the chosen provider still needs credentials, one of the authentication
// sub-views whose state is tracked below. At most one sub-view is visible.
type modelDialogCmp struct {
	// width is the dialog's own width; wWidth and wHeight are the last
	// window size received via tea.WindowSizeMsg.
	width   int
	wWidth  int
	wHeight int

	// modelList is the main view; keyMap and help render the help line.
	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state: needsAPIKey shows apiKeyInput for selectedModel /
	// selectedModelType; isAPIKeyValid becomes true once apiKeyValue has
	// passed the provider connection test.
	needsAPIKey       bool
	apiKeyInput       *APIKeyInput
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	isAPIKeyValid     bool
	apiKeyValue       string

	// Hyper device flow state; the flow is drawn only while
	// showHyperDeviceFlow is true.
	hyperDeviceFlow     *hyper.DeviceFlow
	showHyperDeviceFlow bool

	// Copilot device flow state; the flow is drawn only while
	// showCopilotDeviceFlow is true.
	copilotDeviceFlow     *copilot.DeviceFlow
	showCopilotDeviceFlow bool

	// Claude state: first the auth method chooser (API key vs OAuth2),
	// then, if OAuth2 was chosen, the OAuth2 view.
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
 91
 92func NewModelDialogCmp() ModelDialog {
 93	keyMap := DefaultKeyMap()
 94
 95	listKeyMap := list.DefaultKeyMap()
 96	listKeyMap.Down.SetEnabled(false)
 97	listKeyMap.Up.SetEnabled(false)
 98	listKeyMap.DownOneItem = keyMap.Next
 99	listKeyMap.UpOneItem = keyMap.Previous
100
101	t := styles.CurrentTheme()
102	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103	apiKeyInput := NewAPIKeyInput()
104	apiKeyInput.SetShowTitle(false)
105	help := help.New()
106	help.Styles = t.S().Help
107
108	return &modelDialogCmp{
109		modelList:   modelList,
110		apiKeyInput: apiKeyInput,
111		width:       defaultWidth,
112		keyMap:      DefaultKeyMap(),
113		help:        help,
114
115		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116		claudeOAuth2:            claude.NewOAuth2(),
117	}
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121	return tea.Batch(
122		m.modelList.Init(),
123		m.apiKeyInput.Init(),
124		m.claudeAuthMethodChooser.Init(),
125		m.claudeOAuth2.Init(),
126	)
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130	switch msg := msg.(type) {
131	case tea.WindowSizeMsg:
132		m.wWidth = msg.Width
133		m.wHeight = msg.Height
134		m.apiKeyInput.SetWidth(m.width - 2)
135		m.help.SetWidth(m.width - 2)
136		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138	case APIKeyStateChangeMsg:
139		u, cmd := m.apiKeyInput.Update(msg)
140		m.apiKeyInput = u.(*APIKeyInput)
141		return m, cmd
142	case hyper.DeviceFlowCompletedMsg:
143		return m, m.saveOauthTokenAndContinue(msg.Token, true)
144	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145		if m.hyperDeviceFlow != nil {
146			u, cmd := m.hyperDeviceFlow.Update(msg)
147			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148			return m, cmd
149		}
150		return m, nil
151	case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152		if m.copilotDeviceFlow != nil {
153			u, cmd := m.copilotDeviceFlow.Update(msg)
154			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155			return m, cmd
156		}
157		return m, nil
158	case copilot.DeviceFlowCompletedMsg:
159		return m, m.saveOauthTokenAndContinue(msg.Token, true)
160	case claude.ValidationCompletedMsg:
161		var cmds []tea.Cmd
162		u, cmd := m.claudeOAuth2.Update(msg)
163		m.claudeOAuth2 = u.(*claude.OAuth2)
164		cmds = append(cmds, cmd)
165
166		if msg.State == claude.OAuthValidationStateValid {
167			cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168			m.keyMap.isClaudeOAuthHelpComplete = true
169		}
170
171		return m, tea.Batch(cmds...)
172	case claude.AuthenticationCompleteMsg:
173		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174	case tea.KeyPressMsg:
175		switch {
176		// Handle Hyper device flow keys
177		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && (m.showHyperDeviceFlow || m.showCopilotDeviceFlow):
178			if m.hyperDeviceFlow != nil {
179				return m, m.hyperDeviceFlow.CopyCode()
180			}
181			if m.copilotDeviceFlow != nil {
182				return m, m.copilotDeviceFlow.CopyCode()
183			}
184		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
185			return m, tea.Sequence(
186				tea.SetClipboard(m.claudeOAuth2.URL),
187				func() tea.Msg {
188					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
189					return nil
190				},
191				util.ReportInfo("URL copied to clipboard"),
192			)
193		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
194			m.claudeAuthMethodChooser.ToggleChoice()
195			return m, nil
196		case key.Matches(msg, m.keyMap.Select):
197			// If showing device flow, enter copies code and opens URL
198			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
199				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
200			}
201			if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
202				return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
203			}
204			selectedItem := m.modelList.SelectedModel()
205
206			modelType := config.SelectedModelTypeLarge
207			if m.modelList.GetModelType() == SmallModelType {
208				modelType = config.SelectedModelTypeSmall
209			}
210
211			askForApiKey := func() {
212				m.keyMap.isClaudeAuthChoiceHelp = false
213				m.keyMap.isClaudeOAuthHelp = false
214				m.keyMap.isAPIKeyHelp = true
215				m.showHyperDeviceFlow = false
216				m.showCopilotDeviceFlow = false
217				m.showClaudeAuthMethodChooser = false
218				m.needsAPIKey = true
219				m.selectedModel = selectedItem
220				m.selectedModelType = modelType
221				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
222			}
223
224			if m.showClaudeAuthMethodChooser {
225				switch m.claudeAuthMethodChooser.State {
226				case claude.AuthMethodAPIKey:
227					askForApiKey()
228				case claude.AuthMethodOAuth2:
229					m.selectedModel = selectedItem
230					m.selectedModelType = modelType
231					m.showClaudeAuthMethodChooser = false
232					m.showClaudeOAuth2 = true
233					m.keyMap.isClaudeAuthChoiceHelp = false
234					m.keyMap.isClaudeOAuthHelp = true
235				}
236				return m, nil
237			}
238			if m.showClaudeOAuth2 {
239				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
240				m.claudeOAuth2 = m2.(*claude.OAuth2)
241				return m, cmd2
242			}
243			if m.isAPIKeyValid {
244				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
245			}
246			if m.needsAPIKey {
247				// Handle API key submission
248				m.apiKeyValue = m.apiKeyInput.Value()
249				provider, err := m.getProvider(m.selectedModel.Provider.ID)
250				if err != nil || provider == nil {
251					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
252				}
253				providerConfig := config.ProviderConfig{
254					ID:      string(m.selectedModel.Provider.ID),
255					Name:    m.selectedModel.Provider.Name,
256					APIKey:  m.apiKeyValue,
257					Type:    provider.Type,
258					BaseURL: provider.APIEndpoint,
259				}
260				return m, tea.Sequence(
261					util.CmdHandler(APIKeyStateChangeMsg{
262						State: APIKeyInputStateVerifying,
263					}),
264					func() tea.Msg {
265						start := time.Now()
266						err := providerConfig.TestConnection(config.Get().Resolver())
267						// intentionally wait for at least 750ms to make sure the user sees the spinner
268						elapsed := time.Since(start)
269						if elapsed < 750*time.Millisecond {
270							time.Sleep(750*time.Millisecond - elapsed)
271						}
272						if err == nil {
273							m.isAPIKeyValid = true
274							return APIKeyStateChangeMsg{
275								State: APIKeyInputStateVerified,
276							}
277						}
278						return APIKeyStateChangeMsg{
279							State: APIKeyInputStateError,
280						}
281					},
282				)
283			}
284
285			// Check if provider is configured
286			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
287				return m, tea.Sequence(
288					util.CmdHandler(dialogs.CloseDialogMsg{}),
289					util.CmdHandler(ModelSelectedMsg{
290						Model: config.SelectedModel{
291							Model:           selectedItem.Model.ID,
292							Provider:        string(selectedItem.Provider.ID),
293							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
294							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
295						},
296						ModelType: modelType,
297					}),
298				)
299			}
300			switch selectedItem.Provider.ID {
301			case catwalk.InferenceProviderAnthropic:
302				m.showClaudeAuthMethodChooser = true
303				m.keyMap.isClaudeAuthChoiceHelp = true
304				return m, nil
305			case hyperp.Name:
306				m.showHyperDeviceFlow = true
307				m.selectedModel = selectedItem
308				m.selectedModelType = modelType
309				m.hyperDeviceFlow = hyper.NewDeviceFlow()
310				m.hyperDeviceFlow.SetWidth(m.width - 2)
311				return m, m.hyperDeviceFlow.Init()
312			case catwalk.InferenceProviderCopilot:
313				m.showCopilotDeviceFlow = true
314				m.selectedModel = selectedItem
315				m.selectedModelType = modelType
316				m.copilotDeviceFlow = copilot.NewDeviceFlow()
317				m.copilotDeviceFlow.SetWidth(m.width - 2)
318				return m, m.copilotDeviceFlow.Init()
319			}
320			// For other providers, show API key input
321			askForApiKey()
322			return m, nil
323		case key.Matches(msg, m.keyMap.Tab):
324			switch {
325			case m.showClaudeAuthMethodChooser:
326				m.claudeAuthMethodChooser.ToggleChoice()
327				return m, nil
328			case m.needsAPIKey:
329				u, cmd := m.apiKeyInput.Update(msg)
330				m.apiKeyInput = u.(*APIKeyInput)
331				return m, cmd
332			case m.modelList.GetModelType() == LargeModelType:
333				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
334				return m, m.modelList.SetModelType(SmallModelType)
335			default:
336				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
337				return m, m.modelList.SetModelType(LargeModelType)
338			}
339		case key.Matches(msg, m.keyMap.Close):
340			if m.showHyperDeviceFlow {
341				if m.hyperDeviceFlow != nil {
342					m.hyperDeviceFlow.Cancel()
343				}
344				m.showHyperDeviceFlow = false
345				m.selectedModel = nil
346			}
347			if m.showCopilotDeviceFlow {
348				if m.copilotDeviceFlow != nil {
349					m.copilotDeviceFlow.Cancel()
350				}
351				m.showCopilotDeviceFlow = false
352				m.selectedModel = nil
353			}
354			if m.showClaudeAuthMethodChooser {
355				m.claudeAuthMethodChooser.SetDefaults()
356				m.showClaudeAuthMethodChooser = false
357				m.keyMap.isClaudeAuthChoiceHelp = false
358				m.keyMap.isClaudeOAuthHelp = false
359				return m, nil
360			}
361			if m.needsAPIKey {
362				if m.isAPIKeyValid {
363					return m, nil
364				}
365				// Go back to model selection
366				m.needsAPIKey = false
367				m.selectedModel = nil
368				m.isAPIKeyValid = false
369				m.apiKeyValue = ""
370				m.apiKeyInput.Reset()
371				return m, nil
372			}
373			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
374		default:
375			if m.showClaudeAuthMethodChooser {
376				u, cmd := m.claudeAuthMethodChooser.Update(msg)
377				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
378				return m, cmd
379			} else if m.showClaudeOAuth2 {
380				u, cmd := m.claudeOAuth2.Update(msg)
381				m.claudeOAuth2 = u.(*claude.OAuth2)
382				return m, cmd
383			} else if m.needsAPIKey {
384				u, cmd := m.apiKeyInput.Update(msg)
385				m.apiKeyInput = u.(*APIKeyInput)
386				return m, cmd
387			} else {
388				u, cmd := m.modelList.Update(msg)
389				m.modelList = u
390				return m, cmd
391			}
392		}
393	case tea.PasteMsg:
394		if m.showClaudeOAuth2 {
395			u, cmd := m.claudeOAuth2.Update(msg)
396			m.claudeOAuth2 = u.(*claude.OAuth2)
397			return m, cmd
398		} else if m.needsAPIKey {
399			u, cmd := m.apiKeyInput.Update(msg)
400			m.apiKeyInput = u.(*APIKeyInput)
401			return m, cmd
402		} else {
403			var cmd tea.Cmd
404			m.modelList, cmd = m.modelList.Update(msg)
405			return m, cmd
406		}
407	case spinner.TickMsg:
408		u, cmd := m.apiKeyInput.Update(msg)
409		m.apiKeyInput = u.(*APIKeyInput)
410		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
411			u, cmd = m.hyperDeviceFlow.Update(msg)
412			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
413		}
414		if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
415			u, cmd = m.copilotDeviceFlow.Update(msg)
416			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
417		}
418		return m, cmd
419	default:
420		// Pass all other messages to the device flow for spinner animation
421		switch {
422		case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
423			u, cmd := m.hyperDeviceFlow.Update(msg)
424			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
425			return m, cmd
426		case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
427			u, cmd := m.copilotDeviceFlow.Update(msg)
428			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
429			return m, cmd
430		case m.showClaudeOAuth2:
431			u, cmd := m.claudeOAuth2.Update(msg)
432			m.claudeOAuth2 = u.(*claude.OAuth2)
433			return m, cmd
434		default:
435			u, cmd := m.apiKeyInput.Update(msg)
436			m.apiKeyInput = u.(*APIKeyInput)
437			return m, cmd
438		}
439	}
440	return m, nil
441}
442
443func (m *modelDialogCmp) View() string {
444	t := styles.CurrentTheme()
445
446	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
447		// Show Hyper device flow
448		m.keyMap.isHyperDeviceFlow = true
449		deviceFlowView := m.hyperDeviceFlow.View()
450		content := lipgloss.JoinVertical(
451			lipgloss.Left,
452			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
453			deviceFlowView,
454			"",
455			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
456		)
457		return m.style().Render(content)
458	}
459	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
460		// Show Hyper device flow
461		m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
462		m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
463		deviceFlowView := m.copilotDeviceFlow.View()
464		content := lipgloss.JoinVertical(
465			lipgloss.Left,
466			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
467			deviceFlowView,
468			"",
469			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
470		)
471		return m.style().Render(content)
472	}
473
474	// Reset the flags when not showing device flow
475	m.keyMap.isHyperDeviceFlow = false
476	m.keyMap.isCopilotDeviceFlow = false
477	m.keyMap.isCopilotUnavailable = false
478
479	switch {
480	case m.showClaudeAuthMethodChooser:
481		chooserView := m.claudeAuthMethodChooser.View()
482		content := lipgloss.JoinVertical(
483			lipgloss.Left,
484			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
485			chooserView,
486			"",
487			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
488		)
489		return m.style().Render(content)
490	case m.showClaudeOAuth2:
491		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
492		oauth2View := m.claudeOAuth2.View()
493		content := lipgloss.JoinVertical(
494			lipgloss.Left,
495			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
496			oauth2View,
497			"",
498			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
499		)
500		return m.style().Render(content)
501	case m.needsAPIKey:
502		// Show API key input
503		m.keyMap.isAPIKeyHelp = true
504		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
505		apiKeyView := m.apiKeyInput.View()
506		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
507		content := lipgloss.JoinVertical(
508			lipgloss.Left,
509			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
510			apiKeyView,
511			"",
512			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
513		)
514		return m.style().Render(content)
515	}
516
517	// Show model selection
518	listView := m.modelList.View()
519	radio := m.modelTypeRadio()
520	content := lipgloss.JoinVertical(
521		lipgloss.Left,
522		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
523		listView,
524		"",
525		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
526	)
527	return m.style().Render(content)
528}
529
530func (m *modelDialogCmp) Cursor() *tea.Cursor {
531	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
532		return m.hyperDeviceFlow.Cursor()
533	}
534	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
535		return m.copilotDeviceFlow.Cursor()
536	}
537	if m.showClaudeAuthMethodChooser {
538		return nil
539	}
540	if m.showClaudeOAuth2 {
541		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
542			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
543			return m.moveCursor(cursor)
544		}
545		return nil
546	}
547	if m.needsAPIKey {
548		cursor := m.apiKeyInput.Cursor()
549		if cursor != nil {
550			cursor = m.moveCursor(cursor)
551			return cursor
552		}
553	} else {
554		cursor := m.modelList.Cursor()
555		if cursor != nil {
556			cursor = m.moveCursor(cursor)
557			return cursor
558		}
559	}
560	return nil
561}
562
563func (m *modelDialogCmp) style() lipgloss.Style {
564	t := styles.CurrentTheme()
565	return t.S().Base.
566		Width(m.width).
567		Border(lipgloss.RoundedBorder()).
568		BorderForeground(t.BorderFocus)
569}
570
// listWidth is the width given to the model list: the dialog width minus 2.
// NOTE(review): the 2 presumably accounts for the border added by style();
// confirm.
func (m *modelDialogCmp) listWidth() int {
	return m.width - 2
}
574
// listHeight is the height given to the model list: half of the last known
// window height (integer division).
func (m *modelDialogCmp) listHeight() int {
	return m.wHeight / 2
}
578
579func (m *modelDialogCmp) Position() (int, int) {
580	row := m.wHeight/4 - 2 // just a bit above the center
581	col := m.wWidth / 2
582	col -= m.width / 2
583	return row, col
584}
585
586func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
587	row, col := m.Position()
588	if m.needsAPIKey {
589		offset := row + 3 // Border + title + API key input offset
590		cursor.Y += offset
591		cursor.X = cursor.X + col + 2
592	} else {
593		offset := row + 3 // Border + title
594		cursor.Y += offset
595		cursor.X = cursor.X + col + 2
596	}
597	return cursor
598}
599
// ID returns ModelsDialogID, the identifier of this dialog.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
603
604func (m *modelDialogCmp) modelTypeRadio() string {
605	t := styles.CurrentTheme()
606	choices := []string{"Large Task", "Small Task"}
607	iconSelected := "◉"
608	iconUnselected := "○"
609	if m.modelList.GetModelType() == LargeModelType {
610		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
611	}
612	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
613}
614
615func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
616	cfg := config.Get()
617	_, ok := cfg.Providers.Get(providerID)
618	return ok
619}
620
621func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
622	cfg := config.Get()
623	providers, err := config.Providers(cfg)
624	if err != nil {
625		return nil, err
626	}
627	for _, p := range providers {
628		if p.ID == providerID {
629			return &p, nil
630		}
631	}
632	return nil, nil
633}
634
635func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
636	if m.selectedModel == nil {
637		return util.ReportError(fmt.Errorf("no model selected"))
638	}
639
640	cfg := config.Get()
641	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
642	if err != nil {
643		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
644	}
645
646	// Reset API key state and continue with model selection
647	selectedModel := *m.selectedModel
648	var cmds []tea.Cmd
649	if close {
650		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
651	}
652	cmds = append(
653		cmds,
654		util.CmdHandler(ModelSelectedMsg{
655			Model: config.SelectedModel{
656				Model:           selectedModel.Model.ID,
657				Provider:        string(selectedModel.Provider.ID),
658				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
659				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
660			},
661			ModelType: m.selectedModelType,
662		}),
663	)
664	return tea.Sequence(cmds...)
665}