models.go

  1// Package models provides the model selection dialog for the TUI.
  2package models
  3
  4import (
  5	"fmt"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	tea "charm.land/bubbletea/v2"
 12	"charm.land/lipgloss/v2"
 13	"github.com/atotto/clipboard"
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/tui/components/core"
 18	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 19	"github.com/charmbracelet/crush/internal/tui/components/dialogs/claude"
 20	"github.com/charmbracelet/crush/internal/tui/components/dialogs/copilot"
 21	"github.com/charmbracelet/crush/internal/tui/components/dialogs/hyper"
 22	"github.com/charmbracelet/crush/internal/tui/exp/list"
 23	"github.com/charmbracelet/crush/internal/tui/styles"
 24	"github.com/charmbracelet/crush/internal/tui/util"
 25)
 26
 27const (
 28	ModelsDialogID dialogs.DialogID = "models"
 29
 30	defaultWidth = 60
 31)
 32
 33const (
 34	LargeModelType int = iota
 35	SmallModelType
 36
 37	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 38	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 39)
 40
 41// ModelSelectedMsg is sent when a model is selected
 42type ModelSelectedMsg struct {
 43	Model     config.SelectedModel
 44	ModelType config.SelectedModelType
 45}
 46
 47// CloseModelDialogMsg is sent when a model is selected
 48type CloseModelDialogMsg struct{}
 49
 50// ModelDialog interface for the model selection dialog
 51type ModelDialog interface {
 52	dialogs.DialogModel
 53}
 54
 55type ModelOption struct {
 56	Provider catwalk.Provider
 57	Model    catwalk.Model
 58}
 59
 60type modelDialogCmp struct {
 61	width   int
 62	wWidth  int
 63	wHeight int
 64
 65	modelList *ModelListComponent
 66	keyMap    KeyMap
 67	help      help.Model
 68
 69	// API key state
 70	needsAPIKey       bool
 71	apiKeyInput       *APIKeyInput
 72	selectedModel     *ModelOption
 73	selectedModelType config.SelectedModelType
 74	isAPIKeyValid     bool
 75	apiKeyValue       string
 76
 77	// Hyper device flow state
 78	hyperDeviceFlow     *hyper.DeviceFlow
 79	showHyperDeviceFlow bool
 80
 81	// Copilot device flow state
 82	copilotDeviceFlow     *copilot.DeviceFlow
 83	showCopilotDeviceFlow bool
 84
 85	// Claude state
 86	claudeAuthMethodChooser     *claude.AuthMethodChooser
 87	claudeOAuth2                *claude.OAuth2
 88	showClaudeAuthMethodChooser bool
 89	showClaudeOAuth2            bool
 90}
 91
 92func NewModelDialogCmp() ModelDialog {
 93	keyMap := DefaultKeyMap()
 94
 95	listKeyMap := list.DefaultKeyMap()
 96	listKeyMap.Down.SetEnabled(false)
 97	listKeyMap.Up.SetEnabled(false)
 98	listKeyMap.DownOneItem = keyMap.Next
 99	listKeyMap.UpOneItem = keyMap.Previous
100
101	t := styles.CurrentTheme()
102	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
103	apiKeyInput := NewAPIKeyInput()
104	apiKeyInput.SetShowTitle(false)
105	help := help.New()
106	help.Styles = t.S().Help
107
108	return &modelDialogCmp{
109		modelList:   modelList,
110		apiKeyInput: apiKeyInput,
111		width:       defaultWidth,
112		keyMap:      DefaultKeyMap(),
113		help:        help,
114
115		claudeAuthMethodChooser: claude.NewAuthMethodChooser(),
116		claudeOAuth2:            claude.NewOAuth2(),
117	}
118}
119
120func (m *modelDialogCmp) Init() tea.Cmd {
121	return tea.Batch(
122		m.modelList.Init(),
123		m.apiKeyInput.Init(),
124		m.claudeAuthMethodChooser.Init(),
125		m.claudeOAuth2.Init(),
126	)
127}
128
129func (m *modelDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
130	switch msg := msg.(type) {
131	case tea.WindowSizeMsg:
132		m.wWidth = msg.Width
133		m.wHeight = msg.Height
134		m.apiKeyInput.SetWidth(m.width - 2)
135		m.help.SetWidth(m.width - 2)
136		m.claudeAuthMethodChooser.SetWidth(m.width - 2)
137		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
138	case APIKeyStateChangeMsg:
139		u, cmd := m.apiKeyInput.Update(msg)
140		m.apiKeyInput = u.(*APIKeyInput)
141		return m, cmd
142	case hyper.DeviceFlowCompletedMsg:
143		return m, m.saveOauthTokenAndContinue(msg.Token, true)
144	case hyper.DeviceAuthInitiatedMsg, hyper.DeviceFlowErrorMsg:
145		if m.hyperDeviceFlow != nil {
146			u, cmd := m.hyperDeviceFlow.Update(msg)
147			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
148			return m, cmd
149		}
150		return m, nil
151	case copilot.DeviceAuthInitiatedMsg, copilot.DeviceFlowErrorMsg:
152		if m.copilotDeviceFlow != nil {
153			u, cmd := m.copilotDeviceFlow.Update(msg)
154			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
155			return m, cmd
156		}
157		return m, nil
158	case copilot.DeviceFlowCompletedMsg:
159		return m, m.saveOauthTokenAndContinue(msg.Token, true)
160	case claude.ValidationCompletedMsg:
161		var cmds []tea.Cmd
162		u, cmd := m.claudeOAuth2.Update(msg)
163		m.claudeOAuth2 = u.(*claude.OAuth2)
164		cmds = append(cmds, cmd)
165
166		if msg.State == claude.OAuthValidationStateValid {
167			cmds = append(cmds, m.saveOauthTokenAndContinue(msg.Token, false))
168			m.keyMap.isClaudeOAuthHelpComplete = true
169		}
170
171		return m, tea.Batch(cmds...)
172	case claude.AuthenticationCompleteMsg:
173		return m, util.CmdHandler(dialogs.CloseDialogMsg{})
174	case tea.KeyPressMsg:
175		switch {
176		// Handle Hyper device flow keys
177		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showHyperDeviceFlow:
178			return m, m.hyperDeviceFlow.CopyCode()
179		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showCopilotDeviceFlow:
180			return m, m.copilotDeviceFlow.CopyCode()
181		case key.Matches(msg, key.NewBinding(key.WithKeys("c", "C"))) && m.showClaudeOAuth2 && m.claudeOAuth2.State == claude.OAuthStateURL:
182			return m, tea.Sequence(
183				tea.SetClipboard(m.claudeOAuth2.URL),
184				func() tea.Msg {
185					_ = clipboard.WriteAll(m.claudeOAuth2.URL)
186					return nil
187				},
188				util.ReportInfo("URL copied to clipboard"),
189			)
190		case key.Matches(msg, m.keyMap.Choose) && m.showClaudeAuthMethodChooser:
191			m.claudeAuthMethodChooser.ToggleChoice()
192			return m, nil
193		case key.Matches(msg, m.keyMap.Select):
194			// If showing device flow, enter copies code and opens URL
195			if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
196				return m, m.hyperDeviceFlow.CopyCodeAndOpenURL()
197			}
198			if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
199				return m, m.copilotDeviceFlow.CopyCodeAndOpenURL()
200			}
201			selectedItem := m.modelList.SelectedModel()
202
203			modelType := config.SelectedModelTypeLarge
204			if m.modelList.GetModelType() == SmallModelType {
205				modelType = config.SelectedModelTypeSmall
206			}
207
208			askForApiKey := func() {
209				m.keyMap.isClaudeAuthChoiceHelp = false
210				m.keyMap.isClaudeOAuthHelp = false
211				m.keyMap.isAPIKeyHelp = true
212				m.showHyperDeviceFlow = false
213				m.showCopilotDeviceFlow = false
214				m.showClaudeAuthMethodChooser = false
215				m.needsAPIKey = true
216				m.selectedModel = selectedItem
217				m.selectedModelType = modelType
218				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
219			}
220
221			if m.showClaudeAuthMethodChooser {
222				switch m.claudeAuthMethodChooser.State {
223				case claude.AuthMethodAPIKey:
224					askForApiKey()
225				case claude.AuthMethodOAuth2:
226					m.selectedModel = selectedItem
227					m.selectedModelType = modelType
228					m.showClaudeAuthMethodChooser = false
229					m.showClaudeOAuth2 = true
230					m.keyMap.isClaudeAuthChoiceHelp = false
231					m.keyMap.isClaudeOAuthHelp = true
232				}
233				return m, nil
234			}
235			if m.showClaudeOAuth2 {
236				m2, cmd2 := m.claudeOAuth2.ValidationConfirm()
237				m.claudeOAuth2 = m2.(*claude.OAuth2)
238				return m, cmd2
239			}
240			if m.isAPIKeyValid {
241				return m, m.saveOauthTokenAndContinue(m.apiKeyValue, true)
242			}
243			if m.needsAPIKey {
244				// Handle API key submission
245				m.apiKeyValue = m.apiKeyInput.Value()
246				provider, err := m.getProvider(m.selectedModel.Provider.ID)
247				if err != nil || provider == nil {
248					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
249				}
250				providerConfig := config.ProviderConfig{
251					ID:      string(m.selectedModel.Provider.ID),
252					Name:    m.selectedModel.Provider.Name,
253					APIKey:  m.apiKeyValue,
254					Type:    provider.Type,
255					BaseURL: provider.APIEndpoint,
256				}
257				return m, tea.Sequence(
258					util.CmdHandler(APIKeyStateChangeMsg{
259						State: APIKeyInputStateVerifying,
260					}),
261					func() tea.Msg {
262						start := time.Now()
263						err := providerConfig.TestConnection(config.Get().Resolver())
264						// intentionally wait for at least 750ms to make sure the user sees the spinner
265						elapsed := time.Since(start)
266						if elapsed < 750*time.Millisecond {
267							time.Sleep(750*time.Millisecond - elapsed)
268						}
269						if err == nil {
270							m.isAPIKeyValid = true
271							return APIKeyStateChangeMsg{
272								State: APIKeyInputStateVerified,
273							}
274						}
275						return APIKeyStateChangeMsg{
276							State: APIKeyInputStateError,
277						}
278					},
279				)
280			}
281
282			// Check if provider is configured
283			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
284				return m, tea.Sequence(
285					util.CmdHandler(dialogs.CloseDialogMsg{}),
286					util.CmdHandler(ModelSelectedMsg{
287						Model: config.SelectedModel{
288							Model:           selectedItem.Model.ID,
289							Provider:        string(selectedItem.Provider.ID),
290							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
291							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
292						},
293						ModelType: modelType,
294					}),
295				)
296			}
297			switch selectedItem.Provider.ID {
298			case catwalk.InferenceProviderAnthropic:
299				m.showClaudeAuthMethodChooser = true
300				m.keyMap.isClaudeAuthChoiceHelp = true
301				return m, nil
302			case hyperp.Name:
303				m.showHyperDeviceFlow = true
304				m.selectedModel = selectedItem
305				m.selectedModelType = modelType
306				m.hyperDeviceFlow = hyper.NewDeviceFlow()
307				m.hyperDeviceFlow.SetWidth(m.width - 2)
308				return m, m.hyperDeviceFlow.Init()
309			case catwalk.InferenceProviderCopilot:
310				m.showCopilotDeviceFlow = true
311				m.selectedModel = selectedItem
312				m.selectedModelType = modelType
313				m.copilotDeviceFlow = copilot.NewDeviceFlow()
314				m.copilotDeviceFlow.SetWidth(m.width - 2)
315				return m, m.copilotDeviceFlow.Init()
316			}
317			// For other providers, show API key input
318			askForApiKey()
319			return m, nil
320		case key.Matches(msg, m.keyMap.Tab):
321			switch {
322			case m.showClaudeAuthMethodChooser:
323				m.claudeAuthMethodChooser.ToggleChoice()
324				return m, nil
325			case m.needsAPIKey:
326				u, cmd := m.apiKeyInput.Update(msg)
327				m.apiKeyInput = u.(*APIKeyInput)
328				return m, cmd
329			case m.modelList.GetModelType() == LargeModelType:
330				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
331				return m, m.modelList.SetModelType(SmallModelType)
332			default:
333				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
334				return m, m.modelList.SetModelType(LargeModelType)
335			}
336		case key.Matches(msg, m.keyMap.Close):
337			switch {
338			case m.showHyperDeviceFlow:
339				if m.hyperDeviceFlow != nil {
340					m.hyperDeviceFlow.Cancel()
341				}
342				m.showHyperDeviceFlow = false
343				m.selectedModel = nil
344			case m.showCopilotDeviceFlow:
345				if m.copilotDeviceFlow != nil {
346					m.copilotDeviceFlow.Cancel()
347				}
348				m.showCopilotDeviceFlow = false
349				m.selectedModel = nil
350			case m.showClaudeAuthMethodChooser:
351				m.claudeAuthMethodChooser.SetDefaults()
352				m.showClaudeAuthMethodChooser = false
353				m.keyMap.isClaudeAuthChoiceHelp = false
354				m.keyMap.isClaudeOAuthHelp = false
355				return m, nil
356			case m.needsAPIKey:
357				if m.isAPIKeyValid {
358					return m, nil
359				}
360				// Go back to model selection
361				m.needsAPIKey = false
362				m.selectedModel = nil
363				m.isAPIKeyValid = false
364				m.apiKeyValue = ""
365				m.apiKeyInput.Reset()
366				return m, nil
367			default:
368				return m, util.CmdHandler(dialogs.CloseDialogMsg{})
369			}
370		default:
371			switch {
372			case m.showClaudeAuthMethodChooser:
373				u, cmd := m.claudeAuthMethodChooser.Update(msg)
374				m.claudeAuthMethodChooser = u.(*claude.AuthMethodChooser)
375				return m, cmd
376			case m.showClaudeOAuth2:
377				u, cmd := m.claudeOAuth2.Update(msg)
378				m.claudeOAuth2 = u.(*claude.OAuth2)
379				return m, cmd
380			case m.needsAPIKey:
381				u, cmd := m.apiKeyInput.Update(msg)
382				m.apiKeyInput = u.(*APIKeyInput)
383				return m, cmd
384			default:
385				u, cmd := m.modelList.Update(msg)
386				m.modelList = u
387				return m, cmd
388			}
389		}
390	case tea.PasteMsg:
391		switch {
392		case m.showClaudeOAuth2:
393			u, cmd := m.claudeOAuth2.Update(msg)
394			m.claudeOAuth2 = u.(*claude.OAuth2)
395			return m, cmd
396		case m.needsAPIKey:
397			u, cmd := m.apiKeyInput.Update(msg)
398			m.apiKeyInput = u.(*APIKeyInput)
399			return m, cmd
400		default:
401			var cmd tea.Cmd
402			m.modelList, cmd = m.modelList.Update(msg)
403			return m, cmd
404		}
405	case spinner.TickMsg:
406		u, cmd := m.apiKeyInput.Update(msg)
407		m.apiKeyInput = u.(*APIKeyInput)
408		if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
409			u, cmd = m.hyperDeviceFlow.Update(msg)
410			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
411		}
412		if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
413			u, cmd = m.copilotDeviceFlow.Update(msg)
414			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
415		}
416		return m, cmd
417	default:
418		// Pass all other messages to the device flow for spinner animation
419		switch {
420		case m.showHyperDeviceFlow && m.hyperDeviceFlow != nil:
421			u, cmd := m.hyperDeviceFlow.Update(msg)
422			m.hyperDeviceFlow = u.(*hyper.DeviceFlow)
423			return m, cmd
424		case m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil:
425			u, cmd := m.copilotDeviceFlow.Update(msg)
426			m.copilotDeviceFlow = u.(*copilot.DeviceFlow)
427			return m, cmd
428		case m.showClaudeOAuth2:
429			u, cmd := m.claudeOAuth2.Update(msg)
430			m.claudeOAuth2 = u.(*claude.OAuth2)
431			return m, cmd
432		default:
433			u, cmd := m.apiKeyInput.Update(msg)
434			m.apiKeyInput = u.(*APIKeyInput)
435			return m, cmd
436		}
437	}
438	return m, nil
439}
440
441func (m *modelDialogCmp) View() string {
442	t := styles.CurrentTheme()
443
444	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
445		// Show Hyper device flow
446		m.keyMap.isHyperDeviceFlow = true
447		deviceFlowView := m.hyperDeviceFlow.View()
448		content := lipgloss.JoinVertical(
449			lipgloss.Left,
450			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with Hyper", m.width-4)),
451			deviceFlowView,
452			"",
453			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
454		)
455		return m.style().Render(content)
456	}
457	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
458		// Show Hyper device flow
459		m.keyMap.isCopilotDeviceFlow = m.copilotDeviceFlow.State != copilot.DeviceFlowStateUnavailable
460		m.keyMap.isCopilotUnavailable = m.copilotDeviceFlow.State == copilot.DeviceFlowStateUnavailable
461		deviceFlowView := m.copilotDeviceFlow.View()
462		content := lipgloss.JoinVertical(
463			lipgloss.Left,
464			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Authenticate with GitHub Copilot", m.width-4)),
465			deviceFlowView,
466			"",
467			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
468		)
469		return m.style().Render(content)
470	}
471
472	// Reset the flags when not showing device flow
473	m.keyMap.isHyperDeviceFlow = false
474	m.keyMap.isCopilotDeviceFlow = false
475	m.keyMap.isCopilotUnavailable = false
476
477	switch {
478	case m.showClaudeAuthMethodChooser:
479		chooserView := m.claudeAuthMethodChooser.View()
480		content := lipgloss.JoinVertical(
481			lipgloss.Left,
482			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
483			chooserView,
484			"",
485			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
486		)
487		return m.style().Render(content)
488	case m.showClaudeOAuth2:
489		m.keyMap.isClaudeOAuthURLState = m.claudeOAuth2.State == claude.OAuthStateURL
490		oauth2View := m.claudeOAuth2.View()
491		content := lipgloss.JoinVertical(
492			lipgloss.Left,
493			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Let's Auth Anthropic", m.width-4)),
494			oauth2View,
495			"",
496			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
497		)
498		return m.style().Render(content)
499	case m.needsAPIKey:
500		// Show API key input
501		m.keyMap.isAPIKeyHelp = true
502		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
503		apiKeyView := m.apiKeyInput.View()
504		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
505		content := lipgloss.JoinVertical(
506			lipgloss.Left,
507			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
508			apiKeyView,
509			"",
510			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
511		)
512		return m.style().Render(content)
513	}
514
515	// Show model selection
516	listView := m.modelList.View()
517	radio := m.modelTypeRadio()
518	content := lipgloss.JoinVertical(
519		lipgloss.Left,
520		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
521		listView,
522		"",
523		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
524	)
525	return m.style().Render(content)
526}
527
528func (m *modelDialogCmp) Cursor() *tea.Cursor {
529	if m.showHyperDeviceFlow && m.hyperDeviceFlow != nil {
530		return m.hyperDeviceFlow.Cursor()
531	}
532	if m.showCopilotDeviceFlow && m.copilotDeviceFlow != nil {
533		return m.copilotDeviceFlow.Cursor()
534	}
535	if m.showClaudeAuthMethodChooser {
536		return nil
537	}
538	if m.showClaudeOAuth2 {
539		if cursor := m.claudeOAuth2.CodeInput.Cursor(); cursor != nil {
540			cursor.Y += 2 // FIXME(@andreynering): Why do we need this?
541			return m.moveCursor(cursor)
542		}
543		return nil
544	}
545	if m.needsAPIKey {
546		cursor := m.apiKeyInput.Cursor()
547		if cursor != nil {
548			cursor = m.moveCursor(cursor)
549			return cursor
550		}
551	} else {
552		cursor := m.modelList.Cursor()
553		if cursor != nil {
554			cursor = m.moveCursor(cursor)
555			return cursor
556		}
557	}
558	return nil
559}
560
561func (m *modelDialogCmp) style() lipgloss.Style {
562	t := styles.CurrentTheme()
563	return t.S().Base.
564		Width(m.width).
565		Border(lipgloss.RoundedBorder()).
566		BorderForeground(t.BorderFocus)
567}
568
569func (m *modelDialogCmp) listWidth() int {
570	return m.width - 2
571}
572
573func (m *modelDialogCmp) listHeight() int {
574	return m.wHeight / 2
575}
576
577func (m *modelDialogCmp) Position() (int, int) {
578	row := m.wHeight/4 - 2 // just a bit above the center
579	col := m.wWidth / 2
580	col -= m.width / 2
581	return row, col
582}
583
584func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
585	row, col := m.Position()
586	if m.needsAPIKey {
587		offset := row + 3 // Border + title + API key input offset
588		cursor.Y += offset
589		cursor.X = cursor.X + col + 2
590	} else {
591		offset := row + 3 // Border + title
592		cursor.Y += offset
593		cursor.X = cursor.X + col + 2
594	}
595	return cursor
596}
597
598func (m *modelDialogCmp) ID() dialogs.DialogID {
599	return ModelsDialogID
600}
601
602func (m *modelDialogCmp) modelTypeRadio() string {
603	t := styles.CurrentTheme()
604	choices := []string{"Large Task", "Small Task"}
605	iconSelected := "◉"
606	iconUnselected := "○"
607	if m.modelList.GetModelType() == LargeModelType {
608		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
609	}
610	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
611}
612
613func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
614	cfg := config.Get()
615	_, ok := cfg.Providers.Get(providerID)
616	return ok
617}
618
619func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
620	cfg := config.Get()
621	providers, err := config.Providers(cfg)
622	if err != nil {
623		return nil, err
624	}
625	for _, p := range providers {
626		if p.ID == providerID {
627			return &p, nil
628		}
629	}
630	return nil, nil
631}
632
633func (m *modelDialogCmp) saveOauthTokenAndContinue(apiKey any, close bool) tea.Cmd {
634	if m.selectedModel == nil {
635		return util.ReportError(fmt.Errorf("no model selected"))
636	}
637
638	cfg := config.Get()
639	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
640	if err != nil {
641		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
642	}
643
644	// Reset API key state and continue with model selection
645	selectedModel := *m.selectedModel
646	var cmds []tea.Cmd
647	if close {
648		cmds = append(cmds, util.CmdHandler(dialogs.CloseDialogMsg{}))
649	}
650	cmds = append(
651		cmds,
652		util.CmdHandler(ModelSelectedMsg{
653			Model: config.SelectedModel{
654				Model:           selectedModel.Model.ID,
655				Provider:        string(selectedModel.Provider.ID),
656				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
657				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
658			},
659			ModelType: m.selectedModelType,
660		}),
661	)
662	return tea.Sequence(cmds...)
663}