1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/llm/agent"
14 "github.com/charmbracelet/crush/internal/llm/provider"
15 "github.com/charmbracelet/crush/internal/tui/components/core"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/exp/list"
18 "github.com/charmbracelet/crush/internal/tui/styles"
19 "github.com/charmbracelet/crush/internal/tui/util"
20 "github.com/charmbracelet/lipgloss/v2"
21)
22
const (
	// ModelsDialogID is the dialog identifier returned by this dialog's ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the dialog width assigned by NewModelDialogCmp.
	defaultWidth = 60
)
28
const (
	// LargeModelType and SmallModelType are the two model types the list can
	// show. They are compared against the list's GetModelType result and
	// passed to its SetModelType method.
	LargeModelType int = iota
	SmallModelType

	// Input placeholders applied to the model list when switching between the
	// large and small model types.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
36
// ModelSelectedMsg is sent when a model is selected. In this file it is
// always emitted in sequence after dialogs.CloseDialogMsg.
type ModelSelectedMsg struct {
	// Model holds the selected model ID and the ID of its provider.
	Model agent.Model
	// ModelType records whether the selection applies to the large or the
	// small model type.
	ModelType config.SelectedModelType
}
42
// CloseModelDialogMsg is presumably meant to signal that the model dialog
// should close.
// NOTE(review): nothing in the visible part of this file sends or handles
// it; the dialog closes itself with dialogs.CloseDialogMsg. Confirm whether
// this type is still used elsewhere.
type CloseModelDialogMsg struct{}
45
// ModelDialog is the interface returned by NewModelDialogCmp for the model
// selection dialog. It currently adds no methods beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
50
// ModelOption pairs a model with the provider it belongs to. It is the value
// the model list reports as the current selection.
type ModelOption struct {
	// Provider is the catalog entry of the provider; its ID and Name are read
	// when a selection is made.
	Provider catwalk.Provider
	// Model is the catalog entry of the model; its ID is sent in ModelSelectedMsg.
	Model catwalk.Model
}
55
// modelDialogCmp is the model selection dialog. It renders either the model
// list or, while needsAPIKey is set, an API key prompt for the provider of
// the model that was picked.
type modelDialogCmp struct {
	// width is the dialog width; the constructor sets it to defaultWidth.
	width int
	// wWidth and wHeight hold the last window size received through
	// tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	config    *config.Config
	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state

	// needsAPIKey is true while the API key prompt is shown instead of the list.
	needsAPIKey bool
	apiKeyInput *APIKeyInput
	// selectedModel is the model waiting for an API key; it is nil until a
	// model with an unconfigured provider is picked and is reset to nil when
	// the prompt is dismissed.
	selectedModel *ModelOption
	// selectedModelType is the model type captured together with selectedModel.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid is set once the connection test with the entered key succeeds.
	isAPIKeyValid bool
	// apiKeyValue is the key text captured when the prompt is submitted.
	apiKeyValue string
}
74
75func NewModelDialogCmp(cfg *config.Config) ModelDialog {
76 keyMap := DefaultKeyMap()
77
78 listKeyMap := list.DefaultKeyMap()
79 listKeyMap.Down.SetEnabled(false)
80 listKeyMap.Up.SetEnabled(false)
81 listKeyMap.DownOneItem = keyMap.Next
82 listKeyMap.UpOneItem = keyMap.Previous
83
84 t := styles.CurrentTheme()
85 modelList := NewModelListComponent(cfg, listKeyMap, "Choose a model for large, complex tasks", true)
86 apiKeyInput := NewAPIKeyInput()
87 apiKeyInput.SetShowTitle(false)
88 help := help.New()
89 help.Styles = t.S().Help
90
91 return &modelDialogCmp{
92 modelList: modelList,
93 apiKeyInput: apiKeyInput,
94 width: defaultWidth,
95 keyMap: DefaultKeyMap(),
96 help: help,
97 config: cfg,
98 }
99}
100
101func (m *modelDialogCmp) Init() tea.Cmd {
102 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
103}
104
105func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
106 switch msg := msg.(type) {
107 case tea.WindowSizeMsg:
108 m.wWidth = msg.Width
109 m.wHeight = msg.Height
110 m.apiKeyInput.SetWidth(m.width - 2)
111 m.help.Width = m.width - 2
112 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
113 case APIKeyStateChangeMsg:
114 u, cmd := m.apiKeyInput.Update(msg)
115 m.apiKeyInput = u.(*APIKeyInput)
116 return m, cmd
117 case tea.KeyPressMsg:
118 switch {
119 case key.Matches(msg, m.keyMap.Select):
120 if m.isAPIKeyValid {
121 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
122 }
123 if m.needsAPIKey {
124 // Handle API key submission
125 m.apiKeyValue = m.apiKeyInput.Value()
126 selectedProvider, err := m.getProvider(m.selectedModel.Provider.ID)
127 if err != nil || selectedProvider == nil {
128 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
129 }
130 providerConfig := provider.Config{
131 ID: string(m.selectedModel.Provider.ID),
132 Name: m.selectedModel.Provider.Name,
133 APIKey: m.apiKeyValue,
134 Type: selectedProvider.Type,
135 BaseURL: selectedProvider.APIEndpoint,
136 }
137 return m, tea.Sequence(
138 util.CmdHandler(APIKeyStateChangeMsg{
139 State: APIKeyInputStateVerifying,
140 }),
141 func() tea.Msg {
142 start := time.Now()
143 err := providerConfig.TestConnection(m.config.Resolver())
144 // intentionally wait for at least 750ms to make sure the user sees the spinner
145 elapsed := time.Since(start)
146 if elapsed < 750*time.Millisecond {
147 time.Sleep(750*time.Millisecond - elapsed)
148 }
149 if err == nil {
150 m.isAPIKeyValid = true
151 return APIKeyStateChangeMsg{
152 State: APIKeyInputStateVerified,
153 }
154 }
155 return APIKeyStateChangeMsg{
156 State: APIKeyInputStateError,
157 }
158 },
159 )
160 }
161 // Normal model selection
162 selectedItem := m.modelList.SelectedModel()
163
164 var modelType config.SelectedModelType
165 if m.modelList.GetModelType() == LargeModelType {
166 modelType = config.SelectedModelTypeLarge
167 } else {
168 modelType = config.SelectedModelTypeSmall
169 }
170
171 // Check if provider is configured
172 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
173 return m, tea.Sequence(
174 util.CmdHandler(dialogs.CloseDialogMsg{}),
175 util.CmdHandler(ModelSelectedMsg{
176 Model: agent.Model{
177 Model: selectedItem.Model.ID,
178 Provider: string(selectedItem.Provider.ID),
179 },
180 ModelType: modelType,
181 }),
182 )
183 } else {
184 // Provider not configured, show API key input
185 m.needsAPIKey = true
186 m.selectedModel = selectedItem
187 m.selectedModelType = modelType
188 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
189 return m, nil
190 }
191 case key.Matches(msg, m.keyMap.Tab):
192 if m.needsAPIKey {
193 u, cmd := m.apiKeyInput.Update(msg)
194 m.apiKeyInput = u.(*APIKeyInput)
195 return m, cmd
196 }
197 if m.modelList.GetModelType() == LargeModelType {
198 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
199 return m, m.modelList.SetModelType(SmallModelType)
200 } else {
201 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
202 return m, m.modelList.SetModelType(LargeModelType)
203 }
204 case key.Matches(msg, m.keyMap.Close):
205 if m.needsAPIKey {
206 if m.isAPIKeyValid {
207 return m, nil
208 }
209 // Go back to model selection
210 m.needsAPIKey = false
211 m.selectedModel = nil
212 m.isAPIKeyValid = false
213 m.apiKeyValue = ""
214 m.apiKeyInput.Reset()
215 return m, nil
216 }
217 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
218 default:
219 if m.needsAPIKey {
220 u, cmd := m.apiKeyInput.Update(msg)
221 m.apiKeyInput = u.(*APIKeyInput)
222 return m, cmd
223 } else {
224 u, cmd := m.modelList.Update(msg)
225 m.modelList = u
226 return m, cmd
227 }
228 }
229 case tea.PasteMsg:
230 if m.needsAPIKey {
231 u, cmd := m.apiKeyInput.Update(msg)
232 m.apiKeyInput = u.(*APIKeyInput)
233 return m, cmd
234 } else {
235 var cmd tea.Cmd
236 m.modelList, cmd = m.modelList.Update(msg)
237 return m, cmd
238 }
239 case spinner.TickMsg:
240 u, cmd := m.apiKeyInput.Update(msg)
241 m.apiKeyInput = u.(*APIKeyInput)
242 return m, cmd
243 }
244 return m, nil
245}
246
247func (m *modelDialogCmp) View() string {
248 t := styles.CurrentTheme()
249
250 if m.needsAPIKey {
251 // Show API key input
252 m.keyMap.isAPIKeyHelp = true
253 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
254 apiKeyView := m.apiKeyInput.View()
255 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
256 content := lipgloss.JoinVertical(
257 lipgloss.Left,
258 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
259 apiKeyView,
260 "",
261 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
262 )
263 return m.style().Render(content)
264 }
265
266 // Show model selection
267 listView := m.modelList.View()
268 radio := m.modelTypeRadio()
269 content := lipgloss.JoinVertical(
270 lipgloss.Left,
271 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
272 listView,
273 "",
274 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
275 )
276 return m.style().Render(content)
277}
278
279func (m *modelDialogCmp) Cursor() *tea.Cursor {
280 if m.needsAPIKey {
281 cursor := m.apiKeyInput.Cursor()
282 if cursor != nil {
283 cursor = m.moveCursor(cursor)
284 return cursor
285 }
286 } else {
287 cursor := m.modelList.Cursor()
288 if cursor != nil {
289 cursor = m.moveCursor(cursor)
290 return cursor
291 }
292 }
293 return nil
294}
295
296func (m *modelDialogCmp) style() lipgloss.Style {
297 t := styles.CurrentTheme()
298 return t.S().Base.
299 Width(m.width).
300 Border(lipgloss.RoundedBorder()).
301 BorderForeground(t.BorderFocus)
302}
303
304func (m *modelDialogCmp) listWidth() int {
305 return m.width - 2
306}
307
308func (m *modelDialogCmp) listHeight() int {
309 return m.wHeight / 2
310}
311
312func (m *modelDialogCmp) Position() (int, int) {
313 row := m.wHeight/4 - 2 // just a bit above the center
314 col := m.wWidth / 2
315 col -= m.width / 2
316 return row, col
317}
318
319func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
320 row, col := m.Position()
321 if m.needsAPIKey {
322 offset := row + 3 // Border + title + API key input offset
323 cursor.Y += offset
324 cursor.X = cursor.X + col + 2
325 } else {
326 offset := row + 3 // Border + title
327 cursor.Y += offset
328 cursor.X = cursor.X + col + 2
329 }
330 return cursor
331}
332
// ID returns the dialog identifier, ModelsDialogID.
func (m *modelDialogCmp) ID() dialogs.DialogID {
	return ModelsDialogID
}
336
337func (m *modelDialogCmp) modelTypeRadio() string {
338 t := styles.CurrentTheme()
339 choices := []string{"Large Task", "Small Task"}
340 iconSelected := "◉"
341 iconUnselected := "○"
342 if m.modelList.GetModelType() == LargeModelType {
343 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
344 }
345 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
346}
347
348func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
349 if _, ok := m.config.Providers.Get(providerID); ok {
350 return true
351 }
352 return false
353}
354
355func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
356 providers, err := config.Providers()
357 if err != nil {
358 return nil, err
359 }
360 for _, p := range providers {
361 if p.ID == providerID {
362 return &p, nil
363 }
364 }
365 return nil, nil
366}
367
368func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
369 if m.selectedModel == nil {
370 return util.ReportError(fmt.Errorf("no model selected"))
371 }
372
373 err := m.config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
374 if err != nil {
375 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
376 }
377
378 // Reset API key state and continue with model selection
379 selectedModel := *m.selectedModel
380 return tea.Sequence(
381 util.CmdHandler(dialogs.CloseDialogMsg{}),
382 util.CmdHandler(ModelSelectedMsg{
383 Model: agent.Model{
384 Model: selectedModel.Model.ID,
385 Provider: string(selectedModel.Provider.ID),
386 },
387 ModelType: m.selectedModelType,
388 }),
389 )
390}