1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/tui/components/core"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/exp/list"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
27const (
28 LargeModelType int = iota
29 SmallModelType
30
31 largeModelInputPlaceholder = "Choose a model for large, complex tasks"
32 smallModelInputPlaceholder = "Choose a model for small, simple tasks"
33)
34
35// ModelSelectedMsg is sent when a model is selected
36type ModelSelectedMsg struct {
37 Model config.SelectedModel
38 ModelType config.SelectedModelType
39}
40
41// CloseModelDialogMsg is sent when a model is selected
42type CloseModelDialogMsg struct{}
43
44// ModelDialog interface for the model selection dialog
45type ModelDialog interface {
46 dialogs.DialogModel
47}
48
49type ModelOption struct {
50 Provider catwalk.Provider
51 Model catwalk.Model
52}
53
// modelDialogCmp is the ModelDialog implementation returned by
// NewModelDialogCmp. It shows a model list and, when the chosen model's
// provider is not configured, switches to an API-key prompt for it.
type modelDialogCmp struct {
	width   int // dialog width; initialised to defaultWidth
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent // selectable models
	keyMap    KeyMap              // dialog key bindings, also rendered as help
	help      help.Model          // renders the help footer

	// API key state
	needsAPIKey       bool                     // true while prompting for an API key
	apiKeyInput       *APIKeyInput             // key input shown while needsAPIKey is true
	selectedModel     *ModelOption             // model awaiting an API key; nil otherwise
	selectedModelType config.SelectedModelType // large/small type for selectedModel
	isAPIKeyValid     bool                     // set once the typed key passed verification
	apiKeyValue       string                   // key captured when verification was requested

	cfg *config.Config // used to look up providers and store API keys
}
73
74func NewModelDialogCmp(cfg *config.Config) ModelDialog {
75 keyMap := DefaultKeyMap()
76
77 listKeyMap := list.DefaultKeyMap()
78 listKeyMap.Down.SetEnabled(false)
79 listKeyMap.Up.SetEnabled(false)
80 listKeyMap.DownOneItem = keyMap.Next
81 listKeyMap.UpOneItem = keyMap.Previous
82
83 t := styles.CurrentTheme()
84 modelList := NewModelListComponent(cfg, listKeyMap, largeModelInputPlaceholder, true)
85 apiKeyInput := NewAPIKeyInput()
86 apiKeyInput.SetShowTitle(false)
87 help := help.New()
88 help.Styles = t.S().Help
89
90 return &modelDialogCmp{
91 modelList: modelList,
92 apiKeyInput: apiKeyInput,
93 width: defaultWidth,
94 keyMap: DefaultKeyMap(),
95 help: help,
96 cfg: cfg,
97 }
98}
99
100func (m *modelDialogCmp) Init() tea.Cmd {
101 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
102}
103
// Update implements tea.Model for the model selection dialog.
//
// The dialog runs in one of two modes:
//   - List mode: input goes to the model list, Tab toggles between large- and
//     small-task models, and Select picks the highlighted model.
//   - API-key mode: entered when the picked model's provider is not configured.
//     Input goes to the API key field. Select first verifies the key, and once
//     the key is verified it saves the key and completes the selection. Close
//     returns to list mode unless the key has already been verified.
func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		m.wWidth = msg.Width
		m.wHeight = msg.Height
		m.apiKeyInput.SetWidth(m.width - 2)
		m.help.Width = m.width - 2
		return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
	case APIKeyStateChangeMsg:
		// Record the verification outcome here, on the Update path. The
		// verification command runs outside Update, so writing to m from
		// inside it would be a data race. A result that arrives after the
		// user has backed out of API-key mode is ignored; otherwise a stale
		// success would leave isAPIKeyValid set while the list is shown.
		if m.needsAPIKey && msg.State == APIKeyInputStateVerified {
			m.isAPIKeyValid = true
		}
		return m.updateAPIKeyInput(msg)
	case tea.KeyPressMsg:
		return m.handleKeyPress(msg)
	case tea.PasteMsg:
		if m.needsAPIKey {
			return m.updateAPIKeyInput(msg)
		}
		var cmd tea.Cmd
		m.modelList, cmd = m.modelList.Update(msg)
		return m, cmd
	case spinner.TickMsg:
		return m.updateAPIKeyInput(msg)
	}
	return m, nil
}

// handleKeyPress dispatches a key press according to the dialog's current mode.
func (m *modelDialogCmp) handleKeyPress(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) {
	switch {
	case key.Matches(msg, m.keyMap.Select):
		switch {
		case m.isAPIKeyValid:
			return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
		case m.needsAPIKey:
			return m, m.verifyAPIKey()
		default:
			return m, m.selectHighlightedModel()
		}
	case key.Matches(msg, m.keyMap.Tab):
		if m.needsAPIKey {
			return m.updateAPIKeyInput(msg)
		}
		// Toggle the list between large- and small-task models.
		if m.modelList.GetModelType() == LargeModelType {
			m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
			return m, m.modelList.SetModelType(SmallModelType)
		}
		m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
		return m, m.modelList.SetModelType(LargeModelType)
	case key.Matches(msg, m.keyMap.Close):
		if !m.needsAPIKey {
			return m, util.CmdHandler(dialogs.CloseDialogMsg{})
		}
		if m.isAPIKeyValid {
			// Close is ignored once the key has been verified.
			return m, nil
		}
		// Leave API-key mode and go back to model selection.
		m.needsAPIKey = false
		m.selectedModel = nil
		m.isAPIKeyValid = false
		m.apiKeyValue = ""
		m.apiKeyInput.Reset()
		return m, nil
	}

	// Any other key goes to the component that owns the current mode.
	if m.needsAPIKey {
		return m.updateAPIKeyInput(msg)
	}
	u, cmd := m.modelList.Update(msg)
	m.modelList = u
	return m, cmd
}

// updateAPIKeyInput forwards msg to the API key input and stores the updated input.
func (m *modelDialogCmp) updateAPIKeyInput(msg tea.Msg) (tea.Model, tea.Cmd) {
	u, cmd := m.apiKeyInput.Update(msg)
	m.apiKeyInput = u.(*APIKeyInput)
	return m, cmd
}

// verifyAPIKey captures the typed API key for the selected model's provider.
// It returns a command sequence that switches the input to its "verifying"
// state and then tests a connection to the provider with that key. The outcome
// arrives as an APIKeyStateChangeMsg (verified or error), which Update records.
func (m *modelDialogCmp) verifyAPIKey() tea.Cmd {
	// Minimum time the verifying spinner stays visible, so users see it.
	const minVerifyDuration = 750 * time.Millisecond

	if m.selectedModel == nil {
		return util.ReportError(fmt.Errorf("no model selected"))
	}
	m.apiKeyValue = m.apiKeyInput.Value()

	providerID := m.selectedModel.Provider.ID
	provider, err := m.getProvider(providerID)
	if err != nil {
		return util.ReportError(fmt.Errorf("failed to look up provider %s: %w", providerID, err))
	}
	if provider == nil {
		return util.ReportError(fmt.Errorf("provider %s not found", providerID))
	}
	providerConfig := config.ProviderConfig{
		ID:      string(providerID),
		Name:    m.selectedModel.Provider.Name,
		APIKey:  m.apiKeyValue,
		Type:    provider.Type,
		BaseURL: provider.APIEndpoint,
	}

	// Capture cfg so the command below does not need to touch m.
	cfg := m.cfg
	return tea.Sequence(
		util.CmdHandler(APIKeyStateChangeMsg{
			State: APIKeyInputStateVerifying,
		}),
		func() tea.Msg {
			start := time.Now()
			err := providerConfig.TestConnection(cfg.Resolver())
			if elapsed := time.Since(start); elapsed < minVerifyDuration {
				time.Sleep(minVerifyDuration - elapsed)
			}
			if err != nil {
				return APIKeyStateChangeMsg{State: APIKeyInputStateError}
			}
			return APIKeyStateChangeMsg{State: APIKeyInputStateVerified}
		},
	)
}

// selectHighlightedModel acts on the model highlighted in the list.
// If its provider is already configured, the dialog closes and announces the
// selection. Otherwise the dialog switches to API-key mode for that provider.
// When nothing is highlighted (e.g. the filter matches no model), it does
// nothing.
func (m *modelDialogCmp) selectHighlightedModel() tea.Cmd {
	selectedItem := m.modelList.SelectedModel()
	if selectedItem == nil {
		return nil
	}

	var modelType config.SelectedModelType
	if m.modelList.GetModelType() == LargeModelType {
		modelType = config.SelectedModelTypeLarge
	} else {
		modelType = config.SelectedModelTypeSmall
	}

	if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
		return tea.Sequence(
			util.CmdHandler(dialogs.CloseDialogMsg{}),
			util.CmdHandler(ModelSelectedMsg{
				Model: config.SelectedModel{
					Model:           selectedItem.Model.ID,
					Provider:        string(selectedItem.Provider.ID),
					ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
					MaxTokens:       selectedItem.Model.DefaultMaxTokens,
				},
				ModelType: modelType,
			}),
		)
	}

	// Provider not configured: ask for an API key first.
	m.needsAPIKey = true
	m.selectedModel = selectedItem
	m.selectedModelType = modelType
	m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
	return nil
}
247
248func (m *modelDialogCmp) View() string {
249 t := styles.CurrentTheme()
250
251 if m.needsAPIKey {
252 // Show API key input
253 m.keyMap.isAPIKeyHelp = true
254 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
255 apiKeyView := m.apiKeyInput.View()
256 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
257 content := lipgloss.JoinVertical(
258 lipgloss.Left,
259 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
260 apiKeyView,
261 "",
262 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
263 )
264 return m.style().Render(content)
265 }
266
267 // Show model selection
268 listView := m.modelList.View()
269 radio := m.modelTypeRadio()
270 content := lipgloss.JoinVertical(
271 lipgloss.Left,
272 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
273 listView,
274 "",
275 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
276 )
277 return m.style().Render(content)
278}
279
280func (m *modelDialogCmp) Cursor() *tea.Cursor {
281 if m.needsAPIKey {
282 cursor := m.apiKeyInput.Cursor()
283 if cursor != nil {
284 cursor = m.moveCursor(cursor)
285 return cursor
286 }
287 } else {
288 cursor := m.modelList.Cursor()
289 if cursor != nil {
290 cursor = m.moveCursor(cursor)
291 return cursor
292 }
293 }
294 return nil
295}
296
297func (m *modelDialogCmp) style() lipgloss.Style {
298 t := styles.CurrentTheme()
299 return t.S().Base.
300 Width(m.width).
301 Border(lipgloss.RoundedBorder()).
302 BorderForeground(t.BorderFocus)
303}
304
305func (m *modelDialogCmp) listWidth() int {
306 return m.width - 2
307}
308
309func (m *modelDialogCmp) listHeight() int {
310 return m.wHeight / 2
311}
312
313func (m *modelDialogCmp) Position() (int, int) {
314 row := m.wHeight/4 - 2 // just a bit above the center
315 col := m.wWidth / 2
316 col -= m.width / 2
317 return row, col
318}
319
320func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
321 row, col := m.Position()
322 if m.needsAPIKey {
323 offset := row + 3 // Border + title + API key input offset
324 cursor.Y += offset
325 cursor.X = cursor.X + col + 2
326 } else {
327 offset := row + 3 // Border + title
328 cursor.Y += offset
329 cursor.X = cursor.X + col + 2
330 }
331 return cursor
332}
333
334func (m *modelDialogCmp) ID() dialogs.DialogID {
335 return ModelsDialogID
336}
337
338func (m *modelDialogCmp) modelTypeRadio() string {
339 t := styles.CurrentTheme()
340 choices := []string{"Large Task", "Small Task"}
341 iconSelected := "◉"
342 iconUnselected := "○"
343 if m.modelList.GetModelType() == LargeModelType {
344 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
345 }
346 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
347}
348
349func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
350 if _, ok := m.cfg.Providers.Get(providerID); ok {
351 return true
352 }
353 return false
354}
355
356func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
357 providers, err := config.Providers(m.cfg)
358 if err != nil {
359 return nil, err
360 }
361 for _, p := range providers {
362 if p.ID == providerID {
363 return &p, nil
364 }
365 }
366 return nil, nil
367}
368
// saveAPIKeyAndContinue stores apiKey for the selected model's provider via
// cfg.SetProviderAPIKey. On success it returns a sequence that closes the
// dialog and then sends a ModelSelectedMsg for the selected model, using the
// model's default reasoning effort and max tokens. Errors (no model selected,
// save failure) are reported through util.ReportError. The dialog's API-key
// fields are not reset here.
func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
	chosen := m.selectedModel
	if chosen == nil {
		return util.ReportError(fmt.Errorf("no model selected"))
	}

	providerID := string(chosen.Provider.ID)
	if err := m.cfg.SetProviderAPIKey(providerID, apiKey); err != nil {
		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
	}

	selection := ModelSelectedMsg{
		Model: config.SelectedModel{
			Model:           chosen.Model.ID,
			Provider:        providerID,
			ReasoningEffort: chosen.Model.DefaultReasoningEffort,
			MaxTokens:       chosen.Model.DefaultMaxTokens,
		},
		ModelType: m.selectedModelType,
	}
	return tea.Sequence(
		util.CmdHandler(dialogs.CloseDialogMsg{}),
		util.CmdHandler(selection),
	)
}