models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/atotto/clipboard"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/tui/components/core"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 20	"github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
 21	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 22	"github.com/charmbracelet/crush/internal/tui/exp/list"
 23	"github.com/charmbracelet/crush/internal/tui/styles"
 24	"github.com/charmbracelet/crush/internal/tui/util"
 25)
 26
const (
	// ModelsDialogID is the dialog identifier reported by the model dialog's
	// ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the dialog width assigned by NewModelDialogCmp.
	defaultWidth = 60
)
 32
const (
	// LargeModelType and SmallModelType are the values compared against
	// modelList.GetModelType() to decide which model slot the list is
	// currently showing.
	LargeModelType int = iota
	SmallModelType

	// Placeholders handed to the model list input for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
 40
// ModelSelectedMsg is sent when a model is selected. Model carries the chosen
// model and provider IDs together with the model's default reasoning effort
// and max tokens; ModelType says whether it was picked for the large or the
// small slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
 46
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file sends or handles it (the dialog closes
// itself via dialogs.CloseDialogMsg) — confirm whether it is still used.
type CloseModelDialogMsg struct{}
 49
// ModelDialog is the interface returned by NewModelDialogCmp. It embeds
// dialogs.DialogModel and adds no methods of its own.
type ModelDialog interface {
	dialogs.DialogModel
}
 54
// ModelOption pairs a model with the provider that offers it. It is the value
// the model list returns for the currently selected entry.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 59
// modelDialogCmp is the concrete ModelDialog. Besides the model list it owns
// the sub-views for each authentication path (API key input, Hyper and Copilot
// device flows, Claude auth-method chooser and OAuth2); the show*/needsAPIKey
// flags decide which of them Update, View and Cursor talk to.
type modelDialogCmp struct {
	width   int // dialog width; set to defaultWidth by the constructor
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool                     // true while the API key input is shown
	apiKeyInput       *APIKeyInput             // input component for the key
	selectedModel     *ModelOption             // model picked when an auth flow was started
	selectedModelType config.SelectedModelType // slot (large/small) the model was picked for
	isAPIKeyValid     bool                     // true once the submitted key passed the connection test
	apiKeyValue       string                   // key captured from the input on submission

	// Hyper device flow state
	hyperDeviceFlow     *hyper.DeviceFlow // created when the flow is started; nil before that
	showHyperDeviceFlow bool

	// Copilot device flow state
	copilotDeviceFlow     *copilot.DeviceFlow // created when the flow is started; nil before that
	showCopilotDeviceFlow bool

	// Claude state
	claudeAuthMethodChooser     *claude.AuthMethodChooser
	claudeOAuth2                *claude.OAuth2
	showClaudeAuthMethodChooser bool
	showClaudeOAuth2            bool
}
 91
 92func NewModelDialogCmp() ModelDialog {
 93	keyMap := DefaultKeyMap()
 94
 95	listKeyMap := list.DefaultKeyMap()
 96	listKeyMap.Down.SetEnabled(false)
 97	listKeyMap.Up.SetEnabled(false)
 98	listKeyMap.DownOneItem = keyMap.Next
 99	listKeyMap.UpOneItem = keyMap.Previous
100
101	t := styles.CurrentTheme()
102	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103	apiKeyInput := NewAPIKeyInput()
104	apiKeyInput.SetShowTitle(false)
105	help := help.New()
106	help.Styles = t.S().Help
107
108	return &modelDialogCmp{
109		modelList:   modelList,
110		apiKeyInput: apiKeyInput,
111		width:       defaultWidth,
112		keyMap:      DefaultKeyMap(),
113		help:        help,
114
115		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116		claudeOAuth2:            claude.NewOAuth2(),
117	}
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121	return tea.Batch(
122		m.modelList.Init(),
123		m.apiKeyInput.Init(),
124		m.claudeAuthMethodChooser.Init(),
125		m.claudeOAuth2.Init(),
126	)
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130	switch msg := msg.(type) {
131	case tea.WindowSizeMsg:
132		m.wWidth = msg.Width
133		m.wHeight = msg.Height
134		m.apiKeyInput.SetWidth(m.width - 2)
135		m.help.SetWidth(m.width - 2)
136		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138	case APIKeyStateChangeMsg:
139		u, cmd := m.apiKeyInput.Update(msg)
140		m.apiKeyInput = u.(*APIKeyInput)
141		return m, cmd
142	case hyper.DeviceFlowCompletedMsg:
143		return m, m.saveOauthTokenAndContinue(msg.Token, true)
144	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145		if m.hyperDeviceFlow != nil {
146			u, cmd := m.hyperDeviceFlow.Update(msg)
147			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148			return m, cmd
149		}
150		return m, nil
151	case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152		if m.copilotDeviceFlow != nil {
153			u, cmd := m.copilotDeviceFlow.Update(msg)
154			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155			return m, cmd
156		}
157		return m, nil
158	case copilot.DeviceFlowCompletedMsg:
159		return m, m.saveOauthTokenAndContinue(msg.Token, true)
160	case claude.ValidationCompletedMsg:
161		var cmds []tea.Cmd
162		u, cmd := m.claudeOAuth2.Update(msg)
163		m.claudeOAuth2 = u.(*claude.OAuth2)
164		cmds = append(cmds, cmd)
165
166		if msg.State == claude.OAuthValidationStateValid {
167			cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168			m.keyMap.isClaudeOAuthHelpComplete = true
169		}
170
171		return m, tea.Batch(cmds...)
172	case claude.AuthenticationCompleteMsg:
173		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174	case tea.KeyPressMsg:
175		switch {
176		// Handle Hyper device flow keys
177		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
178			return m, m.hyperDeviceFlow.CopyCode()
179		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
180			return m, m.copilotDeviceFlow.CopyCode()
181		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
182			return m, tea.Sequence(
183				tea.SetClipboard(m.claudeOAuth2.URL),
184				func() tea.Msg {
185					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
186					return nil
187				},
188				util.ReportInfo("URL copied to clipboard"),
189			)
190		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
191			m.claudeAuthMethodChooser.ToggleChoice()
192			return m, nil
193		case key.Matches(msg, m.keyMap.Select):
194			// If showing device flow, enter copies code and opens URL
195			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
196				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
197			}
198			if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
199				return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
200			}
201			selectedItem := m.modelList.SelectedModel()
202
203			modelType := config.SelectedModelTypeLarge
204			if m.modelList.GetModelType() == SmallModelType {
205				modelType = config.SelectedModelTypeSmall
206			}
207
208			askForApiKey := func() {
209				m.keyMap.isClaudeAuthChoiceHelp = false
210				m.keyMap.isClaudeOAuthHelp = false
211				m.keyMap.isAPIKeyHelp = true
212				m.showHyperDeviceFlow = false
213				m.showCopilotDeviceFlow = false
214				m.showClaudeAuthMethodChooser = false
215				m.needsAPIKey = true
216				m.selectedModel = selectedItem
217				m.selectedModelType = modelType
218				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
219			}
220
221			if m.showClaudeAuthMethodChooser {
222				switch m.claudeAuthMethodChooser.State {
223				case claude.AuthMethodAPIKey:
224					askForApiKey()
225				case claude.AuthMethodOAuth2:
226					m.selectedModel = selectedItem
227					m.selectedModelType = modelType
228					m.showClaudeAuthMethodChooser = false
229					m.showClaudeOAuth2 = true
230					m.keyMap.isClaudeAuthChoiceHelp = false
231					m.keyMap.isClaudeOAuthHelp = true
232				}
233				return m, nil
234			}
235			if m.showClaudeOAuth2 {
236				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
237				m.claudeOAuth2 = m2.(*claude.OAuth2)
238				return m, cmd2
239			}
240			if m.isAPIKeyValid {
241				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
242			}
243			if m.needsAPIKey {
244				// Handle API key submission
245				m.apiKeyValue = m.apiKeyInput.Value()
246				provider, err := m.getProvider(m.selectedModel.Provider.ID)
247				if err != nil || provider == nil {
248					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
249				}
250				providerConfig := config.ProviderConfig{
251					ID:      string(m.selectedModel.Provider.ID),
252					Name:    m.selectedModel.Provider.Name,
253					APIKey:  m.apiKeyValue,
254					Type:    provider.Type,
255					BaseURL: provider.APIEndpoint,
256				}
257				return m, tea.Sequence(
258					util.CmdHandler(APIKeyStateChangeMsg{
259						State: APIKeyInputStateVerifying,
260					}),
261					func() tea.Msg {
262						start := time.Now()
263						err := providerConfig.TestConnection(config.Get().Resolver())
264						// intentionally wait for at least 750ms to make sure the user sees the spinner
265						elapsed := time.Since(start)
266						if elapsed < 750*time.Millisecond {
267							time.Sleep(750*time.Millisecond - elapsed)
268						}
269						if err == nil {
270							m.isAPIKeyValid = true
271							return APIKeyStateChangeMsg{
272								State: APIKeyInputStateVerified,
273							}
274						}
275						return APIKeyStateChangeMsg{
276							State: APIKeyInputStateError,
277						}
278					},
279				)
280			}
281
282			// Check if provider is configured
283			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284				return m, tea.Sequence(
285					util.CmdHandler(dialogs.CloseDialogMsg{}),
286					util.CmdHandler(ModelSelectedMsg{
287						Model: config.SelectedModel{
288							Model:           selectedItem.Model.ID,
289							Provider:        string(selectedItem.Provider.ID),
290							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
291							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
292						},
293						ModelType: modelType,
294					}),
295				)
296			}
297			switch selectedItem.Provider.ID {
298			case catwalk.InferenceProviderAnthropic:
299				m.showClaudeAuthMethodChooser = true
300				m.keyMap.isClaudeAuthChoiceHelp = true
301				return m, nil
302			case hyperp.Name:
303				m.showHyperDeviceFlow = true
304				m.selectedModel = selectedItem
305				m.selectedModelType = modelType
306				m.hyperDeviceFlow = hyper.NewDeviceFlow()
307				m.hyperDeviceFlow.SetWidth(m.width - 2)
308				return m, m.hyperDeviceFlow.Init()
309			case catwalk.InferenceProviderCopilot:
310				if token, ok := config.Get().ImportCopilot(); ok {
311					m.selectedModel = selectedItem
312					m.selectedModelType = modelType
313					return m, m.saveOauthTokenAndContinue(token, true)
314				}
315				m.showCopilotDeviceFlow = true
316				m.selectedModel = selectedItem
317				m.selectedModelType = modelType
318				m.copilotDeviceFlow = copilot.NewDeviceFlow()
319				m.copilotDeviceFlow.SetWidth(m.width - 2)
320				return m, m.copilotDeviceFlow.Init()
321			}
322			// For other providers, show API key input
323			askForApiKey()
324			return m, nil
325		case key.Matches(msg, m.keyMap.Tab):
326			switch {
327			case m.showClaudeAuthMethodChooser:
328				m.claudeAuthMethodChooser.ToggleChoice()
329				return m, nil
330			case m.needsAPIKey:
331				u, cmd := m.apiKeyInput.Update(msg)
332				m.apiKeyInput = u.(*APIKeyInput)
333				return m, cmd
334			case m.modelList.GetModelType() == LargeModelType:
335				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
336				return m, m.modelList.SetModelType(SmallModelType)
337			default:
338				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
339				return m, m.modelList.SetModelType(LargeModelType)
340			}
341		case key.Matches(msg, m.keyMap.Close):
342			switch {
343			case m.showHyperDeviceFlow:
344				if m.hyperDeviceFlow != nil {
345					m.hyperDeviceFlow.Cancel()
346				}
347				m.showHyperDeviceFlow = false
348				m.selectedModel = nil
349			case m.showCopilotDeviceFlow:
350				if m.copilotDeviceFlow != nil {
351					m.copilotDeviceFlow.Cancel()
352				}
353				m.showCopilotDeviceFlow = false
354				m.selectedModel = nil
355			case m.showClaudeAuthMethodChooser:
356				m.claudeAuthMethodChooser.SetDefaults()
357				m.showClaudeAuthMethodChooser = false
358				m.keyMap.isClaudeAuthChoiceHelp = false
359				m.keyMap.isClaudeOAuthHelp = false
360				return m, nil
361			case m.needsAPIKey:
362				if m.isAPIKeyValid {
363					return m, nil
364				}
365				// Go back to model selection
366				m.needsAPIKey = false
367				m.selectedModel = nil
368				m.isAPIKeyValid = false
369				m.apiKeyValue = ""
370				m.apiKeyInput.Reset()
371				return m, nil
372			default:
373				return m, util.CmdHandler(dialogs.CloseDialogMsg{})
374			}
375		default:
376			switch {
377			case m.showClaudeAuthMethodChooser:
378				u, cmd := m.claudeAuthMethodChooser.Update(msg)
379				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
380				return m, cmd
381			case m.showClaudeOAuth2:
382				u, cmd := m.claudeOAuth2.Update(msg)
383				m.claudeOAuth2 = u.(*claude.OAuth2)
384				return m, cmd
385			case m.needsAPIKey:
386				u, cmd := m.apiKeyInput.Update(msg)
387				m.apiKeyInput = u.(*APIKeyInput)
388				return m, cmd
389			default:
390				u, cmd := m.modelList.Update(msg)
391				m.modelList = u
392				return m, cmd
393			}
394		}
395	case tea.PasteMsg:
396		switch {
397		case m.showClaudeOAuth2:
398			u, cmd := m.claudeOAuth2.Update(msg)
399			m.claudeOAuth2 = u.(*claude.OAuth2)
400			return m, cmd
401		case m.needsAPIKey:
402			u, cmd := m.apiKeyInput.Update(msg)
403			m.apiKeyInput = u.(*APIKeyInput)
404			return m, cmd
405		default:
406			var cmd tea.Cmd
407			m.modelList, cmd = m.modelList.Update(msg)
408			return m, cmd
409		}
410	case spinner.TickMsg:
411		u, cmd := m.apiKeyInput.Update(msg)
412		m.apiKeyInput = u.(*APIKeyInput)
413		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
414			u, cmd = m.hyperDeviceFlow.Update(msg)
415			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
416		}
417		if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
418			u, cmd = m.copilotDeviceFlow.Update(msg)
419			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
420		}
421		return m, cmd
422	default:
423		// Pass all other messages to the device flow for spinner animation
424		switch {
425		case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
426			u, cmd := m.hyperDeviceFlow.Update(msg)
427			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
428			return m, cmd
429		case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
430			u, cmd := m.copilotDeviceFlow.Update(msg)
431			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
432			return m, cmd
433		case m.showClaudeOAuth2:
434			u, cmd := m.claudeOAuth2.Update(msg)
435			m.claudeOAuth2 = u.(*claude.OAuth2)
436			return m, cmd
437		default:
438			u, cmd := m.apiKeyInput.Update(msg)
439			m.apiKeyInput = u.(*APIKeyInput)
440			return m, cmd
441		}
442	}
443	return m, nil
444}
445
446func (m *modelDialogCmp) View() string {
447	t := styles.CurrentTheme()
448
449	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
450		// Show Hyper device flow
451		m.keyMap.isHyperDeviceFlow = true
452		deviceFlowView := m.hyperDeviceFlow.View()
453		content := lipgloss.JoinVertical(
454			lipgloss.Left,
455			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
456			deviceFlowView,
457			"",
458			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
459		)
460		return m.style().Render(content)
461	}
462	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
463		// Show Hyper device flow
464		m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
465		m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
466		deviceFlowView := m.copilotDeviceFlow.View()
467		content := lipgloss.JoinVertical(
468			lipgloss.Left,
469			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
470			deviceFlowView,
471			"",
472			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
473		)
474		return m.style().Render(content)
475	}
476
477	// Reset the flags when not showing device flow
478	m.keyMap.isHyperDeviceFlow = false
479	m.keyMap.isCopilotDeviceFlow = false
480	m.keyMap.isCopilotUnavailable = false
481
482	switch {
483	case m.showClaudeAuthMethodChooser:
484		chooserView := m.claudeAuthMethodChooser.View()
485		content := lipgloss.JoinVertical(
486			lipgloss.Left,
487			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
488			chooserView,
489			"",
490			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
491		)
492		return m.style().Render(content)
493	case m.showClaudeOAuth2:
494		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
495		oauth2View := m.claudeOAuth2.View()
496		content := lipgloss.JoinVertical(
497			lipgloss.Left,
498			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
499			oauth2View,
500			"",
501			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
502		)
503		return m.style().Render(content)
504	case m.needsAPIKey:
505		// Show API key input
506		m.keyMap.isAPIKeyHelp = true
507		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
508		apiKeyView := m.apiKeyInput.View()
509		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
510		content := lipgloss.JoinVertical(
511			lipgloss.Left,
512			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
513			apiKeyView,
514			"",
515			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
516		)
517		return m.style().Render(content)
518	}
519
520	// Show model selection
521	listView := m.modelList.View()
522	radio := m.modelTypeRadio()
523	content := lipgloss.JoinVertical(
524		lipgloss.Left,
525		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
526		listView,
527		"",
528		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
529	)
530	return m.style().Render(content)
531}
532
533func (m *modelDialogCmp) Cursor() *tea.Cursor {
534	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
535		return m.hyperDeviceFlow.Cursor()
536	}
537	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
538		return m.copilotDeviceFlow.Cursor()
539	}
540	if m.showClaudeAuthMethodChooser {
541		return nil
542	}
543	if m.showClaudeOAuth2 {
544		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
545			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
546			return m.moveCursor(cursor)
547		}
548		return nil
549	}
550	if m.needsAPIKey {
551		cursor := m.apiKeyInput.Cursor()
552		if cursor != nil {
553			cursor = m.moveCursor(cursor)
554			return cursor
555		}
556	} else {
557		cursor := m.modelList.Cursor()
558		if cursor != nil {
559			cursor = m.moveCursor(cursor)
560			return cursor
561		}
562	}
563	return nil
564}
565
566func (m *modelDialogCmp) style() lipgloss.Style {
567	t := styles.CurrentTheme()
568	return t.S().Base.
569		Width(m.width).
570		Border(lipgloss.RoundedBorder()).
571		BorderForeground(t.BorderFocus)
572}
573
// listWidth returns the width given to the model list: the dialog width minus
// 2 (presumably the left and right border columns — TODO confirm).
func (m *modelDialogCmp) listWidth() int {
	return m.width - 2
}
577
// listHeight returns the height given to the model list: half of the last
// known window height (integer division).
func (m *modelDialogCmp) listHeight() int {
	return m.wHeight / 2
}
581
582func (m *modelDialogCmp) Position() (int, int) {
583	row := m.wHeight/4 - 2 // just a bit above the center
584	col := m.wWidth / 2
585	col -= m.width / 2
586	return row, col
587}
588
589func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
590	row, col := m.Position()
591	if m.needsAPIKey {
592		offset := row + 3 // Border + title + API key input offset
593		cursor.Y += offset
594		cursor.X = cursor.X + col + 2
595	} else {
596		offset := row + 3 // Border + title
597		cursor.Y += offset
598		cursor.X = cursor.X + col + 2
599	}
600	return cursor
601}
602
// ID returns the dialog's identifier, ModelsDialogID.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
606
607func (m *modelDialogCmp) modelTypeRadio() string {
608	t := styles.CurrentTheme()
609	choices := []string{"Large Task", "Small Task"}
610	iconSelected := "◉"
611	iconUnselected := "○"
612	if m.modelList.GetModelType() == LargeModelType {
613		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
614	}
615	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
616}
617
618func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
619	cfg := config.Get()
620	_, ok := cfg.Providers.Get(providerID)
621	return ok
622}
623
624func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
625	cfg := config.Get()
626	providers, err := config.Providers(cfg)
627	if err != nil {
628		return nil, err
629	}
630	for _, p := range providers {
631		if p.ID == providerID {
632			return &p, nil
633		}
634	}
635	return nil, nil
636}
637
638func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
639	if m.selectedModel == nil {
640		return util.ReportError(fmt.Errorf("no model selected"))
641	}
642
643	cfg := config.Get()
644	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
645	if err != nil {
646		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
647	}
648
649	// Reset API key state and continue with model selection
650	selectedModel := *m.selectedModel
651	var cmds []tea.Cmd
652	if close {
653		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
654	}
655	cmds = append(
656		cmds,
657		util.CmdHandler(ModelSelectedMsg{
658			Model: config.SelectedModel{
659				Model:           selectedModel.Model.ID,
660				Provider:        string(selectedModel.Provider.ID),
661				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
662				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
663			},
664			ModelType: m.selectedModelType,
665		}),
666	)
667	return tea.Sequence(cmds...)
668}