models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/atotto/clipboard"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/tui/components/core"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 20	"github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
 21	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 22	"github.com/charmbracelet/crush/internal/tui/exp/list"
 23	"github.com/charmbracelet/crush/internal/tui/styles"
 24	"github.com/charmbracelet/crush/internal/tui/util"
 25)
 26
 27const (
 28	ModelsDialogID dialogs.DialogID = "models"
 29
 30	defaultWidth = 60
 31)
 32
 33const (
 34	LargeModelType int = iota
 35	SmallModelType
 36
 37	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 38	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 39)
 40
 41// ModelSelectedMsg is sent when a model is selected
 42type ModelSelectedMsg struct {
 43	Model     config.SelectedModel
 44	ModelType config.SelectedModelType
 45}
 46
 47// CloseModelDialogMsg is sent when a model is selected
 48type CloseModelDialogMsg struct{}
 49
 50// ModelDialog interface for the model selection dialog
 51type ModelDialog interface {
 52	dialogs.DialogModel
 53}
 54
 55type ModelOption struct {
 56	Provider catwalk.Provider
 57	Model    catwalk.Model
 58}
 59
 60type modelDialogCmp struct {
 61	width   int
 62	wWidth  int
 63	wHeight int
 64
 65	modelList *ModelListComponent
 66	keyMap    KeyMap
 67	help      help.Model
 68
 69	// API key state
 70	needsAPIKey       bool
 71	apiKeyInput       *APIKeyInput
 72	selectedModel     *ModelOption
 73	selectedModelType config.SelectedModelType
 74	isAPIKeyValid     bool
 75	apiKeyValue       string
 76
 77	// Hyper device flow state
 78	hyperDeviceFlow     *hyper.DeviceFlow
 79	showHyperDeviceFlow bool
 80
 81	// Copilot device flow state
 82	copilotDeviceFlow     *copilot.DeviceFlow
 83	showCopilotDeviceFlow bool
 84
 85	// Claude state
 86	claudeAuthMethodChooser     *claude.AuthMethodChooser
 87	claudeOAuth2                *claude.OAuth2
 88	showClaudeAuthMethodChooser bool
 89	showClaudeOAuth2            bool
 90}
 91
 92func NewModelDialogCmp() ModelDialog {
 93	keyMap := DefaultKeyMap()
 94
 95	listKeyMap := list.DefaultKeyMap()
 96	listKeyMap.Down.SetEnabled(false)
 97	listKeyMap.Up.SetEnabled(false)
 98	listKeyMap.DownOneItem = keyMap.Next
 99	listKeyMap.UpOneItem = keyMap.Previous
100
101	t := styles.CurrentTheme()
102	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103	apiKeyInput := NewAPIKeyInput()
104	apiKeyInput.SetShowTitle(false)
105	help := help.New()
106	help.Styles = t.S().Help
107
108	return &modelDialogCmp{
109		modelList:   modelList,
110		apiKeyInput: apiKeyInput,
111		width:       defaultWidth,
112		keyMap:      DefaultKeyMap(),
113		help:        help,
114
115		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116		claudeOAuth2:            claude.NewOAuth2(),
117	}
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121	return tea.Batch(
122		m.modelList.Init(),
123		m.apiKeyInput.Init(),
124		m.claudeAuthMethodChooser.Init(),
125		m.claudeOAuth2.Init(),
126	)
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130	switch msg := msg.(type) {
131	case tea.WindowSizeMsg:
132		m.wWidth = msg.Width
133		m.wHeight = msg.Height
134		m.apiKeyInput.SetWidth(m.width - 2)
135		m.help.SetWidth(m.width - 2)
136		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138	case APIKeyStateChangeMsg:
139		u, cmd := m.apiKeyInput.Update(msg)
140		m.apiKeyInput = u.(*APIKeyInput)
141		return m, cmd
142	case hyper.DeviceFlowCompletedMsg:
143		return m, m.saveOauthTokenAndContinue(msg.Token, true)
144	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145		if m.hyperDeviceFlow != nil {
146			u, cmd := m.hyperDeviceFlow.Update(msg)
147			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148			return m, cmd
149		}
150		return m, nil
151	case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152		if m.copilotDeviceFlow != nil {
153			u, cmd := m.copilotDeviceFlow.Update(msg)
154			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155			return m, cmd
156		}
157		return m, nil
158	case copilot.DeviceFlowCompletedMsg:
159		return m, m.saveOauthTokenAndContinue(msg.Token, true)
160	case claude.ValidationCompletedMsg:
161		var cmds []tea.Cmd
162		u, cmd := m.claudeOAuth2.Update(msg)
163		m.claudeOAuth2 = u.(*claude.OAuth2)
164		cmds = append(cmds, cmd)
165
166		if msg.State == claude.OAuthValidationStateValid {
167			cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168			m.keyMap.isClaudeOAuthHelpComplete = true
169		}
170
171		return m, tea.Batch(cmds...)
172	case claude.AuthenticationCompleteMsg:
173		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174	case tea.KeyPressMsg:
175		switch {
176		// Handle Hyper device flow keys
177		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
178			return m, m.hyperDeviceFlow.CopyCode()
179		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
180			return m, m.copilotDeviceFlow.CopyCode()
181		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
182			return m, tea.Sequence(
183				tea.SetClipboard(m.claudeOAuth2.URL),
184				func() tea.Msg {
185					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
186					return nil
187				},
188				util.ReportInfo("URL copied to clipboard"),
189			)
190		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
191			m.claudeAuthMethodChooser.ToggleChoice()
192			return m, nil
193		case key.Matches(msg, m.keyMap.Select):
194			// If showing device flow, enter copies code and opens URL
195			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
196				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
197			}
198			if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
199				return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
200			}
201			selectedItem := m.modelList.SelectedModel()
202			if selectedItem == nil {
203				return m, nil
204			}
205
206			modelType := config.SelectedModelTypeLarge
207			if m.modelList.GetModelType() == SmallModelType {
208				modelType = config.SelectedModelTypeSmall
209			}
210
211			askForApiKey := func() {
212				m.keyMap.isClaudeAuthChoiceHelp = false
213				m.keyMap.isClaudeOAuthHelp = false
214				m.keyMap.isAPIKeyHelp = true
215				m.showHyperDeviceFlow = false
216				m.showCopilotDeviceFlow = false
217				m.showClaudeAuthMethodChooser = false
218				m.needsAPIKey = true
219				m.selectedModel = selectedItem
220				m.selectedModelType = modelType
221				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
222			}
223
224			if m.showClaudeAuthMethodChooser {
225				switch m.claudeAuthMethodChooser.State {
226				case claude.AuthMethodAPIKey:
227					askForApiKey()
228				case claude.AuthMethodOAuth2:
229					m.selectedModel = selectedItem
230					m.selectedModelType = modelType
231					m.showClaudeAuthMethodChooser = false
232					m.showClaudeOAuth2 = true
233					m.keyMap.isClaudeAuthChoiceHelp = false
234					m.keyMap.isClaudeOAuthHelp = true
235				}
236				return m, nil
237			}
238			if m.showClaudeOAuth2 {
239				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
240				m.claudeOAuth2 = m2.(*claude.OAuth2)
241				return m, cmd2
242			}
243			if m.isAPIKeyValid {
244				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
245			}
246			if m.needsAPIKey {
247				// Handle API key submission
248				m.apiKeyValue = m.apiKeyInput.Value()
249				provider, err := m.getProvider(m.selectedModel.Provider.ID)
250				if err != nil || provider == nil {
251					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
252				}
253				providerConfig := config.ProviderConfig{
254					ID:      string(m.selectedModel.Provider.ID),
255					Name:    m.selectedModel.Provider.Name,
256					APIKey:  m.apiKeyValue,
257					Type:    provider.Type,
258					BaseURL: provider.APIEndpoint,
259				}
260				return m, tea.Sequence(
261					util.CmdHandler(APIKeyStateChangeMsg{
262						State: APIKeyInputStateVerifying,
263					}),
264					func() tea.Msg {
265						start := time.Now()
266						err := providerConfig.TestConnection(config.Get().Resolver())
267						// intentionally wait for at least 750ms to make sure the user sees the spinner
268						elapsed := time.Since(start)
269						if elapsed < 750*time.Millisecond {
270							time.Sleep(750*time.Millisecond - elapsed)
271						}
272						if err == nil {
273							m.isAPIKeyValid = true
274							return APIKeyStateChangeMsg{
275								State: APIKeyInputStateVerified,
276							}
277						}
278						return APIKeyStateChangeMsg{
279							State: APIKeyInputStateError,
280						}
281					},
282				)
283			}
284
285			// Check if provider is configured
286			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
287				return m, tea.Sequence(
288					util.CmdHandler(dialogs.CloseDialogMsg{}),
289					util.CmdHandler(ModelSelectedMsg{
290						Model: config.SelectedModel{
291							Model:           selectedItem.Model.ID,
292							Provider:        string(selectedItem.Provider.ID),
293							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
294							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
295						},
296						ModelType: modelType,
297					}),
298				)
299			}
300			switch selectedItem.Provider.ID {
301			case catwalk.InferenceProviderAnthropic:
302				m.showClaudeAuthMethodChooser = true
303				m.keyMap.isClaudeAuthChoiceHelp = true
304				return m, nil
305			case hyperp.Name:
306				m.showHyperDeviceFlow = true
307				m.selectedModel = selectedItem
308				m.selectedModelType = modelType
309				m.hyperDeviceFlow = hyper.NewDeviceFlow()
310				m.hyperDeviceFlow.SetWidth(m.width - 2)
311				return m, m.hyperDeviceFlow.Init()
312			case catwalk.InferenceProviderCopilot:
313				if token, ok := config.Get().ImportCopilot(); ok {
314					m.selectedModel = selectedItem
315					m.selectedModelType = modelType
316					return m, m.saveOauthTokenAndContinue(token, true)
317				}
318				m.showCopilotDeviceFlow = true
319				m.selectedModel = selectedItem
320				m.selectedModelType = modelType
321				m.copilotDeviceFlow = copilot.NewDeviceFlow()
322				m.copilotDeviceFlow.SetWidth(m.width - 2)
323				return m, m.copilotDeviceFlow.Init()
324			}
325			// For other providers, show API key input
326			askForApiKey()
327			return m, nil
328		case key.Matches(msg, m.keyMap.Tab):
329			switch {
330			case m.showClaudeAuthMethodChooser:
331				m.claudeAuthMethodChooser.ToggleChoice()
332				return m, nil
333			case m.needsAPIKey:
334				u, cmd := m.apiKeyInput.Update(msg)
335				m.apiKeyInput = u.(*APIKeyInput)
336				return m, cmd
337			case m.modelList.GetModelType() == LargeModelType:
338				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
339				return m, m.modelList.SetModelType(SmallModelType)
340			default:
341				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
342				return m, m.modelList.SetModelType(LargeModelType)
343			}
344		case key.Matches(msg, m.keyMap.Close):
345			switch {
346			case m.showHyperDeviceFlow:
347				if m.hyperDeviceFlow != nil {
348					m.hyperDeviceFlow.Cancel()
349				}
350				m.showHyperDeviceFlow = false
351				m.selectedModel = nil
352			case m.showCopilotDeviceFlow:
353				if m.copilotDeviceFlow != nil {
354					m.copilotDeviceFlow.Cancel()
355				}
356				m.showCopilotDeviceFlow = false
357				m.selectedModel = nil
358			case m.showClaudeAuthMethodChooser:
359				m.claudeAuthMethodChooser.SetDefaults()
360				m.showClaudeAuthMethodChooser = false
361				m.keyMap.isClaudeAuthChoiceHelp = false
362				m.keyMap.isClaudeOAuthHelp = false
363				return m, nil
364			case m.needsAPIKey:
365				if m.isAPIKeyValid {
366					return m, nil
367				}
368				// Go back to model selection
369				m.needsAPIKey = false
370				m.selectedModel = nil
371				m.isAPIKeyValid = false
372				m.apiKeyValue = ""
373				m.apiKeyInput.Reset()
374				return m, nil
375			default:
376				return m, util.CmdHandler(dialogs.CloseDialogMsg{})
377			}
378		default:
379			switch {
380			case m.showClaudeAuthMethodChooser:
381				u, cmd := m.claudeAuthMethodChooser.Update(msg)
382				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
383				return m, cmd
384			case m.showClaudeOAuth2:
385				u, cmd := m.claudeOAuth2.Update(msg)
386				m.claudeOAuth2 = u.(*claude.OAuth2)
387				return m, cmd
388			case m.needsAPIKey:
389				u, cmd := m.apiKeyInput.Update(msg)
390				m.apiKeyInput = u.(*APIKeyInput)
391				return m, cmd
392			default:
393				u, cmd := m.modelList.Update(msg)
394				m.modelList = u
395				return m, cmd
396			}
397		}
398	case tea.PasteMsg:
399		switch {
400		case m.showClaudeOAuth2:
401			u, cmd := m.claudeOAuth2.Update(msg)
402			m.claudeOAuth2 = u.(*claude.OAuth2)
403			return m, cmd
404		case m.needsAPIKey:
405			u, cmd := m.apiKeyInput.Update(msg)
406			m.apiKeyInput = u.(*APIKeyInput)
407			return m, cmd
408		default:
409			var cmd tea.Cmd
410			m.modelList, cmd = m.modelList.Update(msg)
411			return m, cmd
412		}
413	case spinner.TickMsg:
414		u, cmd := m.apiKeyInput.Update(msg)
415		m.apiKeyInput = u.(*APIKeyInput)
416		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
417			u, cmd = m.hyperDeviceFlow.Update(msg)
418			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
419		}
420		if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
421			u, cmd = m.copilotDeviceFlow.Update(msg)
422			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
423		}
424		return m, cmd
425	default:
426		// Pass all other messages to the device flow for spinner animation
427		switch {
428		case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
429			u, cmd := m.hyperDeviceFlow.Update(msg)
430			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
431			return m, cmd
432		case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
433			u, cmd := m.copilotDeviceFlow.Update(msg)
434			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
435			return m, cmd
436		case m.showClaudeOAuth2:
437			u, cmd := m.claudeOAuth2.Update(msg)
438			m.claudeOAuth2 = u.(*claude.OAuth2)
439			return m, cmd
440		default:
441			u, cmd := m.apiKeyInput.Update(msg)
442			m.apiKeyInput = u.(*APIKeyInput)
443			return m, cmd
444		}
445	}
446	return m, nil
447}
448
449func (m *modelDialogCmp) View() string {
450	t := styles.CurrentTheme()
451
452	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
453		// Show Hyper device flow
454		m.keyMap.isHyperDeviceFlow = true
455		deviceFlowView := m.hyperDeviceFlow.View()
456		content := lipgloss.JoinVertical(
457			lipgloss.Left,
458			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
459			deviceFlowView,
460			"",
461			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
462		)
463		return m.style().Render(content)
464	}
465	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
466		// Show Hyper device flow
467		m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
468		m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
469		deviceFlowView := m.copilotDeviceFlow.View()
470		content := lipgloss.JoinVertical(
471			lipgloss.Left,
472			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
473			deviceFlowView,
474			"",
475			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
476		)
477		return m.style().Render(content)
478	}
479
480	// Reset the flags when not showing device flow
481	m.keyMap.isHyperDeviceFlow = false
482	m.keyMap.isCopilotDeviceFlow = false
483	m.keyMap.isCopilotUnavailable = false
484
485	switch {
486	case m.showClaudeAuthMethodChooser:
487		chooserView := m.claudeAuthMethodChooser.View()
488		content := lipgloss.JoinVertical(
489			lipgloss.Left,
490			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
491			chooserView,
492			"",
493			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
494		)
495		return m.style().Render(content)
496	case m.showClaudeOAuth2:
497		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
498		oauth2View := m.claudeOAuth2.View()
499		content := lipgloss.JoinVertical(
500			lipgloss.Left,
501			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
502			oauth2View,
503			"",
504			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
505		)
506		return m.style().Render(content)
507	case m.needsAPIKey:
508		// Show API key input
509		m.keyMap.isAPIKeyHelp = true
510		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
511		apiKeyView := m.apiKeyInput.View()
512		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
513		content := lipgloss.JoinVertical(
514			lipgloss.Left,
515			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
516			apiKeyView,
517			"",
518			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
519		)
520		return m.style().Render(content)
521	}
522
523	// Show model selection
524	listView := m.modelList.View()
525	radio := m.modelTypeRadio()
526	content := lipgloss.JoinVertical(
527		lipgloss.Left,
528		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
529		listView,
530		"",
531		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
532	)
533	return m.style().Render(content)
534}
535
536func (m *modelDialogCmp) Cursor() *tea.Cursor {
537	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
538		return m.hyperDeviceFlow.Cursor()
539	}
540	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
541		return m.copilotDeviceFlow.Cursor()
542	}
543	if m.showClaudeAuthMethodChooser {
544		return nil
545	}
546	if m.showClaudeOAuth2 {
547		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
548			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
549			return m.moveCursor(cursor)
550		}
551		return nil
552	}
553	if m.needsAPIKey {
554		cursor := m.apiKeyInput.Cursor()
555		if cursor != nil {
556			cursor = m.moveCursor(cursor)
557			return cursor
558		}
559	} else {
560		cursor := m.modelList.Cursor()
561		if cursor != nil {
562			cursor = m.moveCursor(cursor)
563			return cursor
564		}
565	}
566	return nil
567}
568
569func (m *modelDialogCmp) style() lipgloss.Style {
570	t := styles.CurrentTheme()
571	return t.S().Base.
572		Width(m.width).
573		Border(lipgloss.RoundedBorder()).
574		BorderForeground(t.BorderFocus)
575}
576
577func (m *modelDialogCmp) listWidth() int {
578	return m.width - 2
579}
580
581func (m *modelDialogCmp) listHeight() int {
582	return m.wHeight / 2
583}
584
585func (m *modelDialogCmp) Position() (int, int) {
586	row := m.wHeight/4 - 2 // just a bit above the center
587	col := m.wWidth / 2
588	col -= m.width / 2
589	return row, col
590}
591
592func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
593	row, col := m.Position()
594	if m.needsAPIKey {
595		offset := row + 3 // Border + title + API key input offset
596		cursor.Y += offset
597		cursor.X = cursor.X + col + 2
598	} else {
599		offset := row + 3 // Border + title
600		cursor.Y += offset
601		cursor.X = cursor.X + col + 2
602	}
603	return cursor
604}
605
606func (m *modelDialogCmp) ID() dialogs.DialogID {
607	return ModelsDialogID
608}
609
610func (m *modelDialogCmp) modelTypeRadio() string {
611	t := styles.CurrentTheme()
612	choices := []string{"Large Task", "Small Task"}
613	iconSelected := "◉"
614	iconUnselected := "○"
615	if m.modelList.GetModelType() == LargeModelType {
616		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
617	}
618	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
619}
620
621func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
622	cfg := config.Get()
623	_, ok := cfg.Providers.Get(providerID)
624	return ok
625}
626
627func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
628	cfg := config.Get()
629	providers, err := config.Providers(cfg)
630	if err != nil {
631		return nil, err
632	}
633	for _, p := range providers {
634		if p.ID == providerID {
635			return &p, nil
636		}
637	}
638	return nil, nil
639}
640
641func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
642	if m.selectedModel == nil {
643		return util.ReportError(fmt.Errorf("no model selected"))
644	}
645
646	cfg := config.Get()
647	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
648	if err != nil {
649		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
650	}
651
652	// Reset API key state and continue with model selection
653	selectedModel := *m.selectedModel
654	var cmds []tea.Cmd
655	if close {
656		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
657	}
658	cmds = append(
659		cmds,
660		util.CmdHandler(ModelSelectedMsg{
661			Model: config.SelectedModel{
662				Model:           selectedModel.Model.ID,
663				Provider:        string(selectedModel.Provider.ID),
664				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
665				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
666			},
667			ModelType: m.selectedModelType,
668		}),
669	)
670	return tea.Sequence(cmds...)
671}