1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/tui/components/core"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/exp/list"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30
 31	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 32	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 33)
 34
// ModelSelectedMsg is sent when a model is selected.
//
// Model carries the chosen model ID, provider ID, reasoning effort and max
// tokens; ModelType says whether the selection applies to the large or the
// small model slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
 40
// CloseModelDialogMsg is a message type for closing the model dialog.
//
// NOTE(review): nothing in this file sends it; the dialog closes itself by
// emitting dialogs.CloseDialogMsg. Confirm whether this type is still used.
type CloseModelDialogMsg struct{}
 43
// ModelDialog is the interface returned by NewModelDialogCmp for the model
// selection dialog. It currently adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 48
// ModelOption pairs a model with the provider that offers it. It is the value
// the model list reports as the current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 53
// modelDialogCmp is the model selection dialog. It shows either the model
// list or, when the chosen model's provider is not configured, an API key
// input for that provider.
type modelDialogCmp struct {
	// width is the dialog width; wWidth and wHeight hold the most recent
	// window size received through tea.WindowSizeMsg.
	width   int
	wWidth  int
	wHeight int

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	// needsAPIKey is true while the API key input is shown instead of the list.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel and selectedModelType remember the selection that is
	// waiting for an API key.
	selectedModel     *ModelOption
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true once the entered key passed the provider
	// connection test; apiKeyValue is the key captured when it was submitted.
	isAPIKeyValid bool
	apiKeyValue   string
}
 71
 72func NewModelDialogCmp() ModelDialog {
 73	keyMap := DefaultKeyMap()
 74
 75	listKeyMap := list.DefaultKeyMap()
 76	listKeyMap.Down.SetEnabled(false)
 77	listKeyMap.Up.SetEnabled(false)
 78	listKeyMap.DownOneItem = keyMap.Next
 79	listKeyMap.UpOneItem = keyMap.Previous
 80
 81	t := styles.CurrentTheme()
 82	modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
 83	apiKeyInput := NewAPIKeyInput()
 84	apiKeyInput.SetShowTitle(false)
 85	help := help.New()
 86	help.Styles = t.S().Help
 87
 88	return &modelDialogCmp{
 89		modelList:   modelList,
 90		apiKeyInput: apiKeyInput,
 91		width:       defaultWidth,
 92		keyMap:      DefaultKeyMap(),
 93		help:        help,
 94	}
 95}
 96
 97func (m *modelDialogCmp) Init() tea.Cmd {
 98	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.apiKeyInput.SetWidth(m.width - 2)
107		m.help.Width = m.width - 2
108		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
109	case APIKeyStateChangeMsg:
110		u, cmd := m.apiKeyInput.Update(msg)
111		m.apiKeyInput = u.(*APIKeyInput)
112		return m, cmd
113	case tea.KeyPressMsg:
114		switch {
115		case key.Matches(msg, m.keyMap.Select):
116			if m.isAPIKeyValid {
117				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
118			}
119			if m.needsAPIKey {
120				// Handle API key submission
121				m.apiKeyValue = m.apiKeyInput.Value()
122				provider, err := m.getProvider(m.selectedModel.Provider.ID)
123				if err != nil || provider == nil {
124					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
125				}
126				providerConfig := config.ProviderConfig{
127					ID:      string(m.selectedModel.Provider.ID),
128					Name:    m.selectedModel.Provider.Name,
129					APIKey:  m.apiKeyValue,
130					Type:    provider.Type,
131					BaseURL: provider.APIEndpoint,
132				}
133				return m, tea.Sequence(
134					util.CmdHandler(APIKeyStateChangeMsg{
135						State: APIKeyInputStateVerifying,
136					}),
137					func() tea.Msg {
138						start := time.Now()
139						err := providerConfig.TestConnection(config.Get().Resolver())
140						// intentionally wait for at least 750ms to make sure the user sees the spinner
141						elapsed := time.Since(start)
142						if elapsed < 750*time.Millisecond {
143							time.Sleep(750*time.Millisecond - elapsed)
144						}
145						if err == nil {
146							m.isAPIKeyValid = true
147							return APIKeyStateChangeMsg{
148								State: APIKeyInputStateVerified,
149							}
150						}
151						return APIKeyStateChangeMsg{
152							State: APIKeyInputStateError,
153						}
154					},
155				)
156			}
157			// Normal model selection
158			selectedItem := m.modelList.SelectedModel()
159
160			var modelType config.SelectedModelType
161			if m.modelList.GetModelType() == LargeModelType {
162				modelType = config.SelectedModelTypeLarge
163			} else {
164				modelType = config.SelectedModelTypeSmall
165			}
166
167			// Check if provider is configured
168			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
169				return m, tea.Sequence(
170					util.CmdHandler(dialogs.CloseDialogMsg{}),
171					util.CmdHandler(ModelSelectedMsg{
172						Model: config.SelectedModel{
173							Model:           selectedItem.Model.ID,
174							Provider:        string(selectedItem.Provider.ID),
175							ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
176							MaxTokens:       selectedItem.Model.DefaultMaxTokens,
177						},
178						ModelType: modelType,
179					}),
180				)
181			} else {
182				// Provider not configured, show API key input
183				m.needsAPIKey = true
184				m.selectedModel = selectedItem
185				m.selectedModelType = modelType
186				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
187				return m, nil
188			}
189		case key.Matches(msg, m.keyMap.Tab):
190			if m.needsAPIKey {
191				u, cmd := m.apiKeyInput.Update(msg)
192				m.apiKeyInput = u.(*APIKeyInput)
193				return m, cmd
194			}
195			if m.modelList.GetModelType() == LargeModelType {
196				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
197				return m, m.modelList.SetModelType(SmallModelType)
198			} else {
199				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
200				return m, m.modelList.SetModelType(LargeModelType)
201			}
202		case key.Matches(msg, m.keyMap.Close):
203			if m.needsAPIKey {
204				if m.isAPIKeyValid {
205					return m, nil
206				}
207				// Go back to model selection
208				m.needsAPIKey = false
209				m.selectedModel = nil
210				m.isAPIKeyValid = false
211				m.apiKeyValue = ""
212				m.apiKeyInput.Reset()
213				return m, nil
214			}
215			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
216		default:
217			if m.needsAPIKey {
218				u, cmd := m.apiKeyInput.Update(msg)
219				m.apiKeyInput = u.(*APIKeyInput)
220				return m, cmd
221			} else {
222				u, cmd := m.modelList.Update(msg)
223				m.modelList = u
224				return m, cmd
225			}
226		}
227	case tea.PasteMsg:
228		if m.needsAPIKey {
229			u, cmd := m.apiKeyInput.Update(msg)
230			m.apiKeyInput = u.(*APIKeyInput)
231			return m, cmd
232		} else {
233			var cmd tea.Cmd
234			m.modelList, cmd = m.modelList.Update(msg)
235			return m, cmd
236		}
237	case spinner.TickMsg:
238		u, cmd := m.apiKeyInput.Update(msg)
239		m.apiKeyInput = u.(*APIKeyInput)
240		return m, cmd
241	}
242	return m, nil
243}
244
245func (m *modelDialogCmp) View() string {
246	t := styles.CurrentTheme()
247
248	if m.needsAPIKey {
249		// Show API key input
250		m.keyMap.isAPIKeyHelp = true
251		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
252		apiKeyView := m.apiKeyInput.View()
253		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
254		content := lipgloss.JoinVertical(
255			lipgloss.Left,
256			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
257			apiKeyView,
258			"",
259			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
260		)
261		return m.style().Render(content)
262	}
263
264	// Show model selection
265	listView := m.modelList.View()
266	radio := m.modelTypeRadio()
267	content := lipgloss.JoinVertical(
268		lipgloss.Left,
269		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
270		listView,
271		"",
272		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
273	)
274	return m.style().Render(content)
275}
276
277func (m *modelDialogCmp) Cursor() *tea.Cursor {
278	if m.needsAPIKey {
279		cursor := m.apiKeyInput.Cursor()
280		if cursor != nil {
281			cursor = m.moveCursor(cursor)
282			return cursor
283		}
284	} else {
285		cursor := m.modelList.Cursor()
286		if cursor != nil {
287			cursor = m.moveCursor(cursor)
288			return cursor
289		}
290	}
291	return nil
292}
293
294func (m *modelDialogCmp) style() lipgloss.Style {
295	t := styles.CurrentTheme()
296	return t.S().Base.
297		Width(m.width).
298		Border(lipgloss.RoundedBorder()).
299		BorderForeground(t.BorderFocus)
300}
301
302func (m *modelDialogCmp) listWidth() int {
303	return m.width - 2
304}
305
306func (m *modelDialogCmp) listHeight() int {
307	return m.wHeight / 2
308}
309
310func (m *modelDialogCmp) Position() (int, int) {
311	row := m.wHeight/4 - 2 // just a bit above the center
312	col := m.wWidth / 2
313	col -= m.width / 2
314	return row, col
315}
316
317func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
318	row, col := m.Position()
319	if m.needsAPIKey {
320		offset := row + 3 // Border + title + API key input offset
321		cursor.Y += offset
322		cursor.X = cursor.X + col + 2
323	} else {
324		offset := row + 3 // Border + title
325		cursor.Y += offset
326		cursor.X = cursor.X + col + 2
327	}
328	return cursor
329}
330
331func (m *modelDialogCmp) ID() dialogs.DialogID {
332	return ModelsDialogID
333}
334
335func (m *modelDialogCmp) modelTypeRadio() string {
336	t := styles.CurrentTheme()
337	choices := []string{"Large Task", "Small Task"}
338	iconSelected := "◉"
339	iconUnselected := "○"
340	if m.modelList.GetModelType() == LargeModelType {
341		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
342	}
343	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
344}
345
346func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
347	cfg := config.Get()
348	if _, ok := cfg.Providers.Get(providerID); ok {
349		return true
350	}
351	return false
352}
353
354func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
355	providers, err := config.Providers()
356	if err != nil {
357		return nil, err
358	}
359	for _, p := range providers {
360		if p.ID == providerID {
361			return &p, nil
362		}
363	}
364	return nil, nil
365}
366
367func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
368	if m.selectedModel == nil {
369		return util.ReportError(fmt.Errorf("no model selected"))
370	}
371
372	cfg := config.Get()
373	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
374	if err != nil {
375		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
376	}
377
378	// Reset API key state and continue with model selection
379	selectedModel := *m.selectedModel
380	return tea.Sequence(
381		util.CmdHandler(dialogs.CloseDialogMsg{}),
382		util.CmdHandler(ModelSelectedMsg{
383			Model: config.SelectedModel{
384				Model:           selectedModel.Model.ID,
385				Provider:        string(selectedModel.Provider.ID),
386				ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
387				MaxTokens:       selectedModel.Model.DefaultMaxTokens,
388			},
389			ModelType: m.selectedModelType,
390		}),
391	)
392}