1package models
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/bubbles/v2/help"
  8	"github.com/charmbracelet/bubbles/v2/key"
  9	"github.com/charmbracelet/bubbles/v2/spinner"
 10	tea "github.com/charmbracelet/bubbletea/v2"
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/tui/components/core"
 14	"github.com/charmbracelet/crush/internal/tui/components/dialogs"
 15	"github.com/charmbracelet/crush/internal/tui/exp/list"
 16	"github.com/charmbracelet/crush/internal/tui/styles"
 17	"github.com/charmbracelet/crush/internal/tui/util"
 18	"github.com/charmbracelet/lipgloss/v2"
 19)
 20
 21const (
 22	ModelsDialogID dialogs.DialogID = "models"
 23
 24	defaultWidth = 60
 25)
 26
 27const (
 28	LargeModelType int = iota
 29	SmallModelType
 30
 31	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
 32	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
 33)
 34
 35// ModelSelectedMsg is sent when a model is selected
 36type ModelSelectedMsg struct {
 37	Model     config.SelectedModel
 38	ModelType config.SelectedModelType
 39}
 40
 41// CloseModelDialogMsg is sent when a model is selected
 42type CloseModelDialogMsg struct{}
 43
// ModelDialog is the interface of the model selection dialog returned by
// NewModelDialogCmp. It currently adds nothing beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
 48
// ModelOption pairs a selectable catwalk model with the provider that serves
// it. Provider.ID identifies the provider configuration the model needs.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
 53
 54type modelDialogCmp struct {
 55	width   int
 56	wWidth  int
 57	wHeight int
 58
 59	modelList *ModelListComponent
 60	keyMap    KeyMap
 61	help      help.Model
 62
 63	// API key state
 64	needsAPIKey       bool
 65	apiKeyInput       *APIKeyInput
 66	selectedModel     *ModelOption
 67	selectedModelType config.SelectedModelType
 68	isAPIKeyValid     bool
 69	apiKeyValue       string
 70}
 71
 72func NewModelDialogCmp() ModelDialog {
 73	keyMap := DefaultKeyMap()
 74
 75	listKeyMap := list.DefaultKeyMap()
 76	listKeyMap.Down.SetEnabled(false)
 77	listKeyMap.Up.SetEnabled(false)
 78	listKeyMap.DownOneItem = keyMap.Next
 79	listKeyMap.UpOneItem = keyMap.Previous
 80
 81	t := styles.CurrentTheme()
 82	modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
 83	apiKeyInput := NewAPIKeyInput()
 84	apiKeyInput.SetShowTitle(false)
 85	help := help.New()
 86	help.Styles = t.S().Help
 87
 88	return &modelDialogCmp{
 89		modelList:   modelList,
 90		apiKeyInput: apiKeyInput,
 91		width:       defaultWidth,
 92		keyMap:      DefaultKeyMap(),
 93		help:        help,
 94	}
 95}
 96
 97func (m *modelDialogCmp) Init() tea.Cmd {
 98	return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
 99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102	switch msg := msg.(type) {
103	case tea.WindowSizeMsg:
104		m.wWidth = msg.Width
105		m.wHeight = msg.Height
106		m.apiKeyInput.SetWidth(m.width - 2)
107		m.help.Width = m.width - 2
108		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
109	case APIKeyStateChangeMsg:
110		u, cmd := m.apiKeyInput.Update(msg)
111		m.apiKeyInput = u.(*APIKeyInput)
112		return m, cmd
113	case tea.KeyPressMsg:
114		switch {
115		case key.Matches(msg, m.keyMap.Select):
116			if m.isAPIKeyValid {
117				return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
118			}
119			if m.needsAPIKey {
120				// Handle API key submission
121				m.apiKeyValue = m.apiKeyInput.Value()
122				provider, err := m.getProvider(m.selectedModel.Provider.ID)
123				if err != nil || provider == nil {
124					return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
125				}
126				providerConfig := config.ProviderConfig{
127					ID:      string(m.selectedModel.Provider.ID),
128					Name:    m.selectedModel.Provider.Name,
129					APIKey:  m.apiKeyValue,
130					Type:    provider.Type,
131					BaseURL: provider.APIEndpoint,
132				}
133				return m, tea.Sequence(
134					util.CmdHandler(APIKeyStateChangeMsg{
135						State: APIKeyInputStateVerifying,
136					}),
137					func() tea.Msg {
138						start := time.Now()
139						err := providerConfig.TestConnection(config.Get().Resolver())
140						// intentionally wait for at least 750ms to make sure the user sees the spinner
141						elapsed := time.Since(start)
142						if elapsed < 750*time.Millisecond {
143							time.Sleep(750*time.Millisecond - elapsed)
144						}
145						if err == nil {
146							m.isAPIKeyValid = true
147							return APIKeyStateChangeMsg{
148								State: APIKeyInputStateVerified,
149							}
150						}
151						return APIKeyStateChangeMsg{
152							State: APIKeyInputStateError,
153						}
154					},
155				)
156			}
157			// Normal model selection
158			selectedItem := m.modelList.SelectedModel()
159
160			var modelType config.SelectedModelType
161			if m.modelList.GetModelType() == LargeModelType {
162				modelType = config.SelectedModelTypeLarge
163			} else {
164				modelType = config.SelectedModelTypeSmall
165			}
166
167			// Check if provider is configured
168			if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
169				return m, tea.Sequence(
170					util.CmdHandler(dialogs.CloseDialogMsg{}),
171					util.CmdHandler(ModelSelectedMsg{
172						Model: config.SelectedModel{
173							Model:    selectedItem.Model.ID,
174							Provider: string(selectedItem.Provider.ID),
175						},
176						ModelType: modelType,
177					}),
178				)
179			} else {
180				// Provider not configured, show API key input
181				m.needsAPIKey = true
182				m.selectedModel = selectedItem
183				m.selectedModelType = modelType
184				m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
185				return m, nil
186			}
187		case key.Matches(msg, m.keyMap.Tab):
188			if m.needsAPIKey {
189				u, cmd := m.apiKeyInput.Update(msg)
190				m.apiKeyInput = u.(*APIKeyInput)
191				return m, cmd
192			}
193			if m.modelList.GetModelType() == LargeModelType {
194				m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
195				return m, m.modelList.SetModelType(SmallModelType)
196			} else {
197				m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
198				return m, m.modelList.SetModelType(LargeModelType)
199			}
200		case key.Matches(msg, m.keyMap.Close):
201			if m.needsAPIKey {
202				if m.isAPIKeyValid {
203					return m, nil
204				}
205				// Go back to model selection
206				m.needsAPIKey = false
207				m.selectedModel = nil
208				m.isAPIKeyValid = false
209				m.apiKeyValue = ""
210				m.apiKeyInput.Reset()
211				return m, nil
212			}
213			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
214		default:
215			if m.needsAPIKey {
216				u, cmd := m.apiKeyInput.Update(msg)
217				m.apiKeyInput = u.(*APIKeyInput)
218				return m, cmd
219			} else {
220				u, cmd := m.modelList.Update(msg)
221				m.modelList = u
222				return m, cmd
223			}
224		}
225	case tea.PasteMsg:
226		if m.needsAPIKey {
227			u, cmd := m.apiKeyInput.Update(msg)
228			m.apiKeyInput = u.(*APIKeyInput)
229			return m, cmd
230		} else {
231			var cmd tea.Cmd
232			m.modelList, cmd = m.modelList.Update(msg)
233			return m, cmd
234		}
235	case spinner.TickMsg:
236		u, cmd := m.apiKeyInput.Update(msg)
237		m.apiKeyInput = u.(*APIKeyInput)
238		return m, cmd
239	}
240	return m, nil
241}
242
243func (m *modelDialogCmp) View() string {
244	t := styles.CurrentTheme()
245
246	if m.needsAPIKey {
247		// Show API key input
248		m.keyMap.isAPIKeyHelp = true
249		m.keyMap.isAPIKeyValid = m.isAPIKeyValid
250		apiKeyView := m.apiKeyInput.View()
251		apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
252		content := lipgloss.JoinVertical(
253			lipgloss.Left,
254			t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
255			apiKeyView,
256			"",
257			t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
258		)
259		return m.style().Render(content)
260	}
261
262	// Show model selection
263	listView := m.modelList.View()
264	radio := m.modelTypeRadio()
265	content := lipgloss.JoinVertical(
266		lipgloss.Left,
267		t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
268		listView,
269		"",
270		t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
271	)
272	return m.style().Render(content)
273}
274
275func (m *modelDialogCmp) Cursor() *tea.Cursor {
276	if m.needsAPIKey {
277		cursor := m.apiKeyInput.Cursor()
278		if cursor != nil {
279			cursor = m.moveCursor(cursor)
280			return cursor
281		}
282	} else {
283		cursor := m.modelList.Cursor()
284		if cursor != nil {
285			cursor = m.moveCursor(cursor)
286			return cursor
287		}
288	}
289	return nil
290}
291
292func (m *modelDialogCmp) style() lipgloss.Style {
293	t := styles.CurrentTheme()
294	return t.S().Base.
295		Width(m.width).
296		Border(lipgloss.RoundedBorder()).
297		BorderForeground(t.BorderFocus)
298}
299
300func (m *modelDialogCmp) listWidth() int {
301	return m.width - 2
302}
303
304func (m *modelDialogCmp) listHeight() int {
305	return m.wHeight / 2
306}
307
308func (m *modelDialogCmp) Position() (int, int) {
309	row := m.wHeight/4 - 2 // just a bit above the center
310	col := m.wWidth / 2
311	col -= m.width / 2
312	return row, col
313}
314
315func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
316	row, col := m.Position()
317	if m.needsAPIKey {
318		offset := row + 3 // Border + title + API key input offset
319		cursor.Y += offset
320		cursor.X = cursor.X + col + 2
321	} else {
322		offset := row + 3 // Border + title
323		cursor.Y += offset
324		cursor.X = cursor.X + col + 2
325	}
326	return cursor
327}
328
329func (m *modelDialogCmp) ID() dialogs.DialogID {
330	return ModelsDialogID
331}
332
333func (m *modelDialogCmp) modelTypeRadio() string {
334	t := styles.CurrentTheme()
335	choices := []string{"Large Task", "Small Task"}
336	iconSelected := "◉"
337	iconUnselected := "○"
338	if m.modelList.GetModelType() == LargeModelType {
339		return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + "  " + iconUnselected + " " + choices[1])
340	}
341	return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + "  " + iconSelected + " " + choices[1])
342}
343
344func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
345	cfg := config.Get()
346	if _, ok := cfg.Providers.Get(providerID); ok {
347		return true
348	}
349	return false
350}
351
352func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
353	providers, err := config.Providers()
354	if err != nil {
355		return nil, err
356	}
357	for _, p := range providers {
358		if p.ID == providerID {
359			return &p, nil
360		}
361	}
362	return nil, nil
363}
364
365func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
366	if m.selectedModel == nil {
367		return util.ReportError(fmt.Errorf("no model selected"))
368	}
369
370	cfg := config.Get()
371	err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
372	if err != nil {
373		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
374	}
375
376	// Reset API key state and continue with model selection
377	selectedModel := *m.selectedModel
378	return tea.Sequence(
379		util.CmdHandler(dialogs.CloseDialogMsg{}),
380		util.CmdHandler(ModelSelectedMsg{
381			Model: config.SelectedModel{
382				Model:    selectedModel.Model.ID,
383				Provider: string(selectedModel.Provider.ID),
384			},
385			ModelType: m.selectedModelType,
386		}),
387	)
388}