1package models
  2
  3import (
  4	"fmt"
  5	"slices"
  6	"time"
  7
  8	"github.com/charmbracelet/bubbles/v2/help"
  9	"github.com/charmbracelet/bubbles/v2/key"
 10	"github.com/charmbracelet/bubbles/v2/spinner"
 11	tea "github.com/charmbracelet/bubbletea/v2"
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/tui/components/core"
 15	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 16	"github.com/charmbracelet/crush/internal/tui/exp/list"
 17	"github.com/charmbracelet/crush/internal/tui/styles"
 18	"github.com/charmbracelet/crush/internal/tui/util"
 19	"github.com/charmbracelet/lipgloss/v2"
 20)
 21
 22const (
 23	ModelsDialogID dialogs.DialogID = "models"
 24
 25	defaultWidth = 60
 26)
 27
 28const (
 29	LargeModelType int = iota
 30	SmallModelType
 31
 32	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 33	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 34)
 35
// ModelSelectedMsg is sent when a model is selected: either directly, when
// the model's provider is already configured, or after a newly entered API
// key for that provider has been saved. The dialog sends it right after
// dialogs.CloseDialogMsg.
type ModelSelectedMsg struct {
	// Model identifies the chosen model and the provider that serves it.
	Model config.SelectedModel
	// ModelType is the slot (large or small) the model was chosen for.
	ModelType config.SelectedModelType
}
 41
// CloseModelDialogMsg is a message type for closing the model dialog.
//
// NOTE(review): this file closes the dialog with dialogs.CloseDialogMsg and
// never emits CloseModelDialogMsg; confirm whether anything still uses it.
type CloseModelDialogMsg struct{}
 44
// ModelDialog is the interface of the model selection dialog returned by
// NewModelDialogCmp. It adds nothing beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 49
// ModelOption pairs a model with the provider that serves it. The model list
// returns a *ModelOption for the highlighted entry (see SelectedModel).
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 54
// modelDialogCmp implements ModelDialog. It has two steps:
//  1. picking a model (and its large/small slot) from modelList, and
//  2. when that model's provider is not configured yet, entering an API key
//     in apiKeyInput, verifying it, and saving it.
type modelDialogCmp struct {
	// width is the dialog width; wWidth and wHeight are the terminal size
	// taken from the most recent tea.WindowSizeMsg.
	width   int
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	// needsAPIKey is true while the API key step is shown instead of the list.
	needsAPIKey       bool
	apiKeyInput       *APIKeyInput
	// selectedModel and selectedModelType hold the choice awaiting an API key.
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is set once the entered key passed the connection test;
	// apiKeyValue is the key text captured when it was submitted.
	isAPIKeyValid     bool
	apiKeyValue       string
}
 72
 73func NewModelDialogCmp() ModelDialog {
 74	keyMap := DefaultKeyMap()
 75
 76	listKeyMap := list.DefaultKeyMap()
 77	listKeyMap.Down.SetEnabled(false)
 78	listKeyMap.Up.SetEnabled(false)
 79	listKeyMap.DownOneItem = keyMap.Next
 80	listKeyMap.UpOneItem = keyMap.Previous
 81
 82	t := styles.CurrentTheme()
 83	modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
 84	apiKeyInput := NewAPIKeyInput()
 85	apiKeyInput.SetShowTitle(false)
 86	help := help.New()
 87	help.Styles = t.S().Help
 88
 89	return &modelDialogCmp{
 90		modelList:   modelList,
 91		apiKeyInput: apiKeyInput,
 92		width:       defaultWidth,
 93		keyMap:      DefaultKeyMap(),
 94		help:        help,
 95	}
 96}
 97
 98func (m *modelDialogCmp) Init() tea.Cmd {
 99	providers, err := config.Providers()
100	if err == nil {
101		filteredProviders := []catwalk.Provider{}
102		simpleProviders := []string{
103			"anthropic",
104			"openai",
105			"gemini",
106			"xai",
107			"groq",
108			"openrouter",
109		}
110		for _, p := range providers {
111			if slices.Contains(simpleProviders, string(p.ID)) {
112				filteredProviders = append(filteredProviders, p)
113			}
114		}
115		m.modelList.SetProviders(filteredProviders)
116	}
117	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
118}
119
120func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
121	switch msg := msg.(type) {
122	case tea.WindowSizeMsg:
123		m.wWidth = msg.Width
124		m.wHeight = msg.Height
125		m.apiKeyInput.SetWidth(m.width - 2)
126		m.help.Width = m.width - 2
127		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
128	case APIKeyStateChangeMsg:
129		u, cmd := m.apiKeyInput.Update(msg)
130		m.apiKeyInput = u.(*APIKeyInput)
131		return m, cmd
132	case tea.KeyPressMsg:
133		switch {
134		case key.Matches(msg, m.keyMap.Select):
135			if m.isAPIKeyValid {
136				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
137			}
138			if m.needsAPIKey {
139				// Handle API key submission
140				m.apiKeyValue = m.apiKeyInput.Value()
141				provider, err := m.getProvider(m.selectedModel.Provider.ID)
142				if err != nil || provider == nil {
143					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
144				}
145				providerConfig := config.ProviderConfig{
146					ID:      string(m.selectedModel.Provider.ID),
147					Name:    m.selectedModel.Provider.Name,
148					APIKey:  m.apiKeyValue,
149					Type:    provider.Type,
150					BaseURL: provider.APIEndpoint,
151				}
152				return m, tea.Sequence(
153					util.CmdHandler(APIKeyStateChangeMsg{
154						State: APIKeyInputStateVerifying,
155					}),
156					func() tea.Msg {
157						start := time.Now()
158						err := providerConfig.TestConnection(config.Get().Resolver())
159						// intentionally wait for at least 750ms to make sure the user sees the spinner
160						elapsed := time.Since(start)
161						if elapsed < 750*time.Millisecond {
162							time.Sleep(750*time.Millisecond - elapsed)
163						}
164						if err == nil {
165							m.isAPIKeyValid = true
166							return APIKeyStateChangeMsg{
167								State: APIKeyInputStateVerified,
168							}
169						}
170						return APIKeyStateChangeMsg{
171							State: APIKeyInputStateError,
172						}
173					},
174				)
175			}
176			// Normal model selection
177			selectedItem := m.modelList.SelectedModel()
178
179			var modelType config.SelectedModelType
180			if m.modelList.GetModelType() == LargeModelType {
181				modelType = config.SelectedModelTypeLarge
182			} else {
183				modelType = config.SelectedModelTypeSmall
184			}
185
186			// Check if provider is configured
187			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
188				return m, tea.Sequence(
189					util.CmdHandler(dialogs.CloseDialogMsg{}),
190					util.CmdHandler(ModelSelectedMsg{
191						Model: config.SelectedModel{
192							Model:    selectedItem.Model.ID,
193							Provider: string(selectedItem.Provider.ID),
194						},
195						ModelType: modelType,
196					}),
197				)
198			} else {
199				// Provider not configured, show API key input
200				m.needsAPIKey = true
201				m.selectedModel = selectedItem
202				m.selectedModelType = modelType
203				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
204				return m, nil
205			}
206		case key.Matches(msg, m.keyMap.Tab):
207			if m.needsAPIKey {
208				u, cmd := m.apiKeyInput.Update(msg)
209				m.apiKeyInput = u.(*APIKeyInput)
210				return m, cmd
211			}
212			if m.modelList.GetModelType() == LargeModelType {
213				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
214				return m, m.modelList.SetModelType(SmallModelType)
215			} else {
216				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
217				return m, m.modelList.SetModelType(LargeModelType)
218			}
219		case key.Matches(msg, m.keyMap.Close):
220			if m.needsAPIKey {
221				if m.isAPIKeyValid {
222					return m, nil
223				}
224				// Go back to model selection
225				m.needsAPIKey = false
226				m.selectedModel = nil
227				m.isAPIKeyValid = false
228				m.apiKeyValue = ""
229				m.apiKeyInput.Reset()
230				return m, nil
231			}
232			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
233		default:
234			if m.needsAPIKey {
235				u, cmd := m.apiKeyInput.Update(msg)
236				m.apiKeyInput = u.(*APIKeyInput)
237				return m, cmd
238			} else {
239				u, cmd := m.modelList.Update(msg)
240				m.modelList = u
241				return m, cmd
242			}
243		}
244	case tea.PasteMsg:
245		if m.needsAPIKey {
246			u, cmd := m.apiKeyInput.Update(msg)
247			m.apiKeyInput = u.(*APIKeyInput)
248			return m, cmd
249		} else {
250			var cmd tea.Cmd
251			m.modelList, cmd = m.modelList.Update(msg)
252			return m, cmd
253		}
254	case spinner.TickMsg:
255		u, cmd := m.apiKeyInput.Update(msg)
256		m.apiKeyInput = u.(*APIKeyInput)
257		return m, cmd
258	}
259	return m, nil
260}
261
262func (m *modelDialogCmp) View() string {
263	t := styles.CurrentTheme()
264
265	if m.needsAPIKey {
266		// Show API key input
267		m.keyMap.isAPIKeyHelp = true
268		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
269		apiKeyView := m.apiKeyInput.View()
270		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
271		content := lipgloss.JoinVertical(
272			lipgloss.Left,
273			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
274			apiKeyView,
275			"",
276			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
277		)
278		return m.style().Render(content)
279	}
280
281	// Show model selection
282	listView := m.modelList.View()
283	radio := m.modelTypeRadio()
284	content := lipgloss.JoinVertical(
285		lipgloss.Left,
286		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
287		listView,
288		"",
289		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
290	)
291	return m.style().Render(content)
292}
293
294func (m *modelDialogCmp) Cursor() *tea.Cursor {
295	if m.needsAPIKey {
296		cursor := m.apiKeyInput.Cursor()
297		if cursor != nil {
298			cursor = m.moveCursor(cursor)
299			return cursor
300		}
301	} else {
302		cursor := m.modelList.Cursor()
303		if cursor != nil {
304			cursor = m.moveCursor(cursor)
305			return cursor
306		}
307	}
308	return nil
309}
310
311func (m *modelDialogCmp) style() lipgloss.Style {
312	t := styles.CurrentTheme()
313	return t.S().Base.
314		Width(m.width).
315		Border(lipgloss.RoundedBorder()).
316		BorderForeground(t.BorderFocus)
317}
318
319func (m *modelDialogCmp) listWidth() int {
320	return m.width - 2
321}
322
323func (m *modelDialogCmp) listHeight() int {
324	return m.wHeight / 2
325}
326
327func (m *modelDialogCmp) Position() (int, int) {
328	row := m.wHeight/4 - 2 // just a bit above the center
329	col := m.wWidth / 2
330	col -= m.width / 2
331	return row, col
332}
333
334func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
335	row, col := m.Position()
336	if m.needsAPIKey {
337		offset := row + 3 // Border + title + API key input offset
338		cursor.Y += offset
339		cursor.X = cursor.X + col + 2
340	} else {
341		offset := row + 3 // Border + title
342		cursor.Y += offset
343		cursor.X = cursor.X + col + 2
344	}
345	return cursor
346}
347
348func (m *modelDialogCmp) ID() dialogs.DialogID {
349	return ModelsDialogID
350}
351
352func (m *modelDialogCmp) modelTypeRadio() string {
353	t := styles.CurrentTheme()
354	choices := []string{"Large Task", "Small Task"}
355	iconSelected := "◉"
356	iconUnselected := "○"
357	if m.modelList.GetModelType() == LargeModelType {
358		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
359	}
360	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
361}
362
363func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
364	cfg := config.Get()
365	if _, ok := cfg.Providers.Get(providerID); ok {
366		return true
367	}
368	return false
369}
370
371func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
372	providers, err := config.Providers()
373	if err != nil {
374		return nil, err
375	}
376	for _, p := range providers {
377		if p.ID == providerID {
378			return &p, nil
379		}
380	}
381	return nil, nil
382}
383
384func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
385	if m.selectedModel == nil {
386		return util.ReportError(fmt.Errorf("no model selected"))
387	}
388
389	cfg := config.Get()
390	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
391	if err != nil {
392		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
393	}
394
395	// Reset API key state and continue with model selection
396	selectedModel := *m.selectedModel
397	return tea.Sequence(
398		util.CmdHandler(dialogs.CloseDialogMsg{}),
399		util.CmdHandler(ModelSelectedMsg{
400			Model: config.SelectedModel{
401				Model:    selectedModel.Model.ID,
402				Provider: string(selectedModel.Provider.ID),
403			},
404			ModelType: m.selectedModelType,
405		}),
406	)
407}