1package models
2
3import (
4 "fmt"
5 "slices"
6 "time"
7
8 "github.com/charmbracelet/bubbles/v2/help"
9 "github.com/charmbracelet/bubbles/v2/key"
10 "github.com/charmbracelet/bubbles/v2/spinner"
11 tea "github.com/charmbracelet/bubbletea/v2"
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/tui/components/core"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
16 "github.com/charmbracelet/crush/internal/tui/exp/list"
17 "github.com/charmbracelet/crush/internal/tui/styles"
18 "github.com/charmbracelet/crush/internal/tui/util"
19 "github.com/charmbracelet/lipgloss/v2"
20)
21
22const (
23 ModelsDialogID dialogs.DialogID = "models"
24
25 defaultWidth = 60
26)
27
// Model types the dialog can assign a selection to. The Tab key toggles the
// model list between them.
const (
	LargeModelType int = iota
	SmallModelType
)

// Input placeholders shown by the model list, one per model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
35
// ModelSelectedMsg is sent, right after the dialog asks to be closed, when the
// user confirms a model. Model names the chosen provider/model pair and
// ModelType records whether it was chosen for the large or the small slot.
type ModelSelectedMsg struct {
	Model     config.SelectedModel
	ModelType config.SelectedModelType
}
41
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file emits or handles it — the dialog closes
// itself with dialogs.CloseDialogMsg; confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
44
// ModelDialog is the handle returned by NewModelDialogCmp for the model
// selection dialog. It currently adds nothing beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
49
// ModelOption pairs a catwalk model with the provider that serves it. The
// model list hands one back as its current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
54
// modelDialogCmp implements ModelDialog. It shows a model list and, when the
// chosen model's provider is not yet configured, switches to an API key prompt.
type modelDialogCmp struct {
	width   int // dialog width; initialised to defaultWidth
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool         // true while the API key prompt replaces the list
	apiKeyInput       *APIKeyInput
	selectedModel     *ModelOption // model waiting for an API key; nil otherwise
	selectedModelType config.SelectedModelType
	isAPIKeyValid     bool   // true once the entered key passed TestConnection
	apiKeyValue       string // key captured when the user submitted it
}
72
73func NewModelDialogCmp() ModelDialog {
74 keyMap := DefaultKeyMap()
75
76 listKeyMap := list.DefaultKeyMap()
77 listKeyMap.Down.SetEnabled(false)
78 listKeyMap.Up.SetEnabled(false)
79 listKeyMap.DownOneItem = keyMap.Next
80 listKeyMap.UpOneItem = keyMap.Previous
81
82 t := styles.CurrentTheme()
83 modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
84 apiKeyInput := NewAPIKeyInput()
85 apiKeyInput.SetShowTitle(false)
86 help := help.New()
87 help.Styles = t.S().Help
88
89 return &modelDialogCmp{
90 modelList: modelList,
91 apiKeyInput: apiKeyInput,
92 width: defaultWidth,
93 keyMap: DefaultKeyMap(),
94 help: help,
95 }
96}
97
98func (m *modelDialogCmp) Init() tea.Cmd {
99 providers, err := config.Providers()
100 if err == nil {
101 filteredProviders := []catwalk.Provider{}
102 simpleProviders := []string{
103 "anthropic",
104 "openai",
105 "gemini",
106 "xai",
107 "groq",
108 "openrouter",
109 }
110 for _, p := range providers {
111 if slices.Contains(simpleProviders, string(p.ID)) {
112 filteredProviders = append(filteredProviders, p)
113 }
114 }
115 m.modelList.SetProviders(filteredProviders)
116 }
117 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
118}
119
120func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
121 switch msg := msg.(type) {
122 case tea.WindowSizeMsg:
123 m.wWidth = msg.Width
124 m.wHeight = msg.Height
125 m.apiKeyInput.SetWidth(m.width - 2)
126 m.help.Width = m.width - 2
127 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
128 case APIKeyStateChangeMsg:
129 u, cmd := m.apiKeyInput.Update(msg)
130 m.apiKeyInput = u.(*APIKeyInput)
131 return m, cmd
132 case tea.KeyPressMsg:
133 switch {
134 case key.Matches(msg, m.keyMap.Select):
135 if m.isAPIKeyValid {
136 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
137 }
138 if m.needsAPIKey {
139 // Handle API key submission
140 m.apiKeyValue = m.apiKeyInput.Value()
141 provider, err := m.getProvider(m.selectedModel.Provider.ID)
142 if err != nil || provider == nil {
143 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
144 }
145 providerConfig := config.ProviderConfig{
146 ID: string(m.selectedModel.Provider.ID),
147 Name: m.selectedModel.Provider.Name,
148 APIKey: m.apiKeyValue,
149 Type: provider.Type,
150 BaseURL: provider.APIEndpoint,
151 }
152 return m, tea.Sequence(
153 util.CmdHandler(APIKeyStateChangeMsg{
154 State: APIKeyInputStateVerifying,
155 }),
156 func() tea.Msg {
157 start := time.Now()
158 err := providerConfig.TestConnection(config.Get().Resolver())
159 // intentionally wait for at least 750ms to make sure the user sees the spinner
160 elapsed := time.Since(start)
161 if elapsed < 750*time.Millisecond {
162 time.Sleep(750*time.Millisecond - elapsed)
163 }
164 if err == nil {
165 m.isAPIKeyValid = true
166 return APIKeyStateChangeMsg{
167 State: APIKeyInputStateVerified,
168 }
169 }
170 return APIKeyStateChangeMsg{
171 State: APIKeyInputStateError,
172 }
173 },
174 )
175 }
176 // Normal model selection
177 selectedItem := m.modelList.SelectedModel()
178
179 var modelType config.SelectedModelType
180 if m.modelList.GetModelType() == LargeModelType {
181 modelType = config.SelectedModelTypeLarge
182 } else {
183 modelType = config.SelectedModelTypeSmall
184 }
185
186 // Check if provider is configured
187 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
188 return m, tea.Sequence(
189 util.CmdHandler(dialogs.CloseDialogMsg{}),
190 util.CmdHandler(ModelSelectedMsg{
191 Model: config.SelectedModel{
192 Model: selectedItem.Model.ID,
193 Provider: string(selectedItem.Provider.ID),
194 },
195 ModelType: modelType,
196 }),
197 )
198 } else {
199 // Provider not configured, show API key input
200 m.needsAPIKey = true
201 m.selectedModel = selectedItem
202 m.selectedModelType = modelType
203 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
204 return m, nil
205 }
206 case key.Matches(msg, m.keyMap.Tab):
207 if m.needsAPIKey {
208 u, cmd := m.apiKeyInput.Update(msg)
209 m.apiKeyInput = u.(*APIKeyInput)
210 return m, cmd
211 }
212 if m.modelList.GetModelType() == LargeModelType {
213 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
214 return m, m.modelList.SetModelType(SmallModelType)
215 } else {
216 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
217 return m, m.modelList.SetModelType(LargeModelType)
218 }
219 case key.Matches(msg, m.keyMap.Close):
220 if m.needsAPIKey {
221 if m.isAPIKeyValid {
222 return m, nil
223 }
224 // Go back to model selection
225 m.needsAPIKey = false
226 m.selectedModel = nil
227 m.isAPIKeyValid = false
228 m.apiKeyValue = ""
229 m.apiKeyInput.Reset()
230 return m, nil
231 }
232 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
233 default:
234 if m.needsAPIKey {
235 u, cmd := m.apiKeyInput.Update(msg)
236 m.apiKeyInput = u.(*APIKeyInput)
237 return m, cmd
238 } else {
239 u, cmd := m.modelList.Update(msg)
240 m.modelList = u
241 return m, cmd
242 }
243 }
244 case tea.PasteMsg:
245 if m.needsAPIKey {
246 u, cmd := m.apiKeyInput.Update(msg)
247 m.apiKeyInput = u.(*APIKeyInput)
248 return m, cmd
249 } else {
250 var cmd tea.Cmd
251 m.modelList, cmd = m.modelList.Update(msg)
252 return m, cmd
253 }
254 case spinner.TickMsg:
255 u, cmd := m.apiKeyInput.Update(msg)
256 m.apiKeyInput = u.(*APIKeyInput)
257 return m, cmd
258 }
259 return m, nil
260}
261
262func (m *modelDialogCmp) View() string {
263 t := styles.CurrentTheme()
264
265 if m.needsAPIKey {
266 // Show API key input
267 m.keyMap.isAPIKeyHelp = true
268 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
269 apiKeyView := m.apiKeyInput.View()
270 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
271 content := lipgloss.JoinVertical(
272 lipgloss.Left,
273 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
274 apiKeyView,
275 "",
276 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
277 )
278 return m.style().Render(content)
279 }
280
281 // Show model selection
282 listView := m.modelList.View()
283 radio := m.modelTypeRadio()
284 content := lipgloss.JoinVertical(
285 lipgloss.Left,
286 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
287 listView,
288 "",
289 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
290 )
291 return m.style().Render(content)
292}
293
294func (m *modelDialogCmp) Cursor() *tea.Cursor {
295 if m.needsAPIKey {
296 cursor := m.apiKeyInput.Cursor()
297 if cursor != nil {
298 cursor = m.moveCursor(cursor)
299 return cursor
300 }
301 } else {
302 cursor := m.modelList.Cursor()
303 if cursor != nil {
304 cursor = m.moveCursor(cursor)
305 return cursor
306 }
307 }
308 return nil
309}
310
311func (m *modelDialogCmp) style() lipgloss.Style {
312 t := styles.CurrentTheme()
313 return t.S().Base.
314 Width(m.width).
315 Border(lipgloss.RoundedBorder()).
316 BorderForeground(t.BorderFocus)
317}
318
319func (m *modelDialogCmp) listWidth() int {
320 return m.width - 2
321}
322
323func (m *modelDialogCmp) listHeight() int {
324 return m.wHeight / 2
325}
326
327func (m *modelDialogCmp) Position() (int, int) {
328 row := m.wHeight/4 - 2 // just a bit above the center
329 col := m.wWidth / 2
330 col -= m.width / 2
331 return row, col
332}
333
334func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
335 row, col := m.Position()
336 if m.needsAPIKey {
337 offset := row + 3 // Border + title + API key input offset
338 cursor.Y += offset
339 cursor.X = cursor.X + col + 2
340 } else {
341 offset := row + 3 // Border + title
342 cursor.Y += offset
343 cursor.X = cursor.X + col + 2
344 }
345 return cursor
346}
347
348func (m *modelDialogCmp) ID() dialogs.DialogID {
349 return ModelsDialogID
350}
351
352func (m *modelDialogCmp) modelTypeRadio() string {
353 t := styles.CurrentTheme()
354 choices := []string{"Large Task", "Small Task"}
355 iconSelected := "◉"
356 iconUnselected := "○"
357 if m.modelList.GetModelType() == LargeModelType {
358 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
359 }
360 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
361}
362
363func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
364 cfg := config.Get()
365 if _, ok := cfg.Providers.Get(providerID); ok {
366 return true
367 }
368 return false
369}
370
371func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
372 providers, err := config.Providers()
373 if err != nil {
374 return nil, err
375 }
376 for _, p := range providers {
377 if p.ID == providerID {
378 return &p, nil
379 }
380 }
381 return nil, nil
382}
383
384func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
385 if m.selectedModel == nil {
386 return util.ReportError(fmt.Errorf("no model selected"))
387 }
388
389 cfg := config.Get()
390 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
391 if err != nil {
392 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
393 }
394
395 // Reset API key state and continue with model selection
396 selectedModel := *m.selectedModel
397 return tea.Sequence(
398 util.CmdHandler(dialogs.CloseDialogMsg{}),
399 util.CmdHandler(ModelSelectedMsg{
400 Model: config.SelectedModel{
401 Model: selectedModel.Model.ID,
402 Provider: string(selectedModel.Provider.ID),
403 },
404 ModelType: m.selectedModelType,
405 }),
406 )
407}