1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/tui/components/core"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/exp/list"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
const (
	// ModelsDialogID is the identifier this dialog reports from its ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the dialog width assigned by NewModelDialogCmp.
	defaultWidth = 60
)
26
const (
	// LargeModelType and SmallModelType are the values the dialog compares
	// against the model list's GetModelType and passes to its SetModelType.
	LargeModelType int = iota
	SmallModelType

	// Placeholder texts handed to the model list's input for each model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
34
// ModelSelectedMsg is sent when a model is selected, either directly (its
// provider is already configured) or after an API key has been saved for it.
type ModelSelectedMsg struct {
	// Model carries the model ID, provider ID and the model's default
	// reasoning effort and max tokens.
	Model config.SelectedModel
	// ModelType is the large/small slot the selection applies to.
	ModelType config.SelectedModelType
}
40
// CloseModelDialogMsg is an empty message type; judging by its name it asks
// for the model dialog to be closed.
// NOTE(review): it is not referenced in this part of the file (closing is
// done with dialogs.CloseDialogMsg) — confirm whether it is still used.
type CloseModelDialogMsg struct{}
43
// ModelDialog is the interface returned by NewModelDialogCmp for the model
// selection dialog. It currently adds nothing beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
48
// ModelOption pairs a catwalk model with a catwalk provider. It is the value
// the model list hands back for the currently selected entry.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
53
// modelDialogCmp is the dialog model created by NewModelDialogCmp. It renders
// either the model list or, while needsAPIKey is set, an API key prompt for
// the provider of the chosen model.
type modelDialogCmp struct {
	width   int // dialog width; set to defaultWidth by NewModelDialogCmp
	wWidth  int // window width from the last tea.WindowSizeMsg
	wHeight int // window height from the last tea.WindowSizeMsg

	modelList *ModelListComponent
	keyMap    KeyMap
	help      help.Model

	// API key state
	needsAPIKey       bool                     // true while the API key prompt replaces the model list
	apiKeyInput       *APIKeyInput             // input component for the API key prompt
	selectedModel     *ModelOption             // model waiting for an API key; nil until one is chosen and after going back
	selectedModelType config.SelectedModelType // large/small slot recorded together with selectedModel
	isAPIKeyValid     bool                     // set after the connection test for the entered key succeeds
	apiKeyValue       string                   // key text captured from apiKeyInput when it was submitted
}
71
72func NewModelDialogCmp() ModelDialog {
73 keyMap := DefaultKeyMap()
74
75 listKeyMap := list.DefaultKeyMap()
76 listKeyMap.Down.SetEnabled(false)
77 listKeyMap.Up.SetEnabled(false)
78 listKeyMap.DownOneItem = keyMap.Next
79 listKeyMap.UpOneItem = keyMap.Previous
80
81 t := styles.CurrentTheme()
82 modelList := NewModelListComponent(listKeyMap, largeModelInputPlaceholder, true)
83 apiKeyInput := NewAPIKeyInput()
84 apiKeyInput.SetShowTitle(false)
85 help := help.New()
86 help.Styles = t.S().Help
87
88 return &modelDialogCmp{
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 width: defaultWidth,
92 keyMap: DefaultKeyMap(),
93 help: help,
94 }
95}
96
97func (m *modelDialogCmp) Init() tea.Cmd {
98 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102 switch msg := msg.(type) {
103 case tea.WindowSizeMsg:
104 m.wWidth = msg.Width
105 m.wHeight = msg.Height
106 m.apiKeyInput.SetWidth(m.width - 2)
107 m.help.Width = m.width - 2
108 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
109 case APIKeyStateChangeMsg:
110 u, cmd := m.apiKeyInput.Update(msg)
111 m.apiKeyInput = u.(*APIKeyInput)
112 return m, cmd
113 case tea.KeyPressMsg:
114 switch {
115 case key.Matches(msg, m.keyMap.Select):
116 if m.isAPIKeyValid {
117 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
118 }
119 if m.needsAPIKey {
120 // Handle API key submission
121 m.apiKeyValue = m.apiKeyInput.Value()
122 provider, err := m.getProvider(m.selectedModel.Provider.ID)
123 if err != nil || provider == nil {
124 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
125 }
126 providerConfig := config.ProviderConfig{
127 ID: string(m.selectedModel.Provider.ID),
128 Name: m.selectedModel.Provider.Name,
129 APIKey: m.apiKeyValue,
130 Type: provider.Type,
131 BaseURL: provider.APIEndpoint,
132 }
133 return m, tea.Sequence(
134 util.CmdHandler(APIKeyStateChangeMsg{
135 State: APIKeyInputStateVerifying,
136 }),
137 func() tea.Msg {
138 start := time.Now()
139 err := providerConfig.TestConnection(config.Get().Resolver())
140 // intentionally wait for at least 750ms to make sure the user sees the spinner
141 elapsed := time.Since(start)
142 if elapsed < 750*time.Millisecond {
143 time.Sleep(750*time.Millisecond - elapsed)
144 }
145 if err == nil {
146 m.isAPIKeyValid = true
147 return APIKeyStateChangeMsg{
148 State: APIKeyInputStateVerified,
149 }
150 }
151 return APIKeyStateChangeMsg{
152 State: APIKeyInputStateError,
153 }
154 },
155 )
156 }
157 // Normal model selection
158 selectedItem := m.modelList.SelectedModel()
159
160 var modelType config.SelectedModelType
161 if m.modelList.GetModelType() == LargeModelType {
162 modelType = config.SelectedModelTypeLarge
163 } else {
164 modelType = config.SelectedModelTypeSmall
165 }
166
167 // Check if provider is configured
168 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
169 return m, tea.Sequence(
170 util.CmdHandler(dialogs.CloseDialogMsg{}),
171 util.CmdHandler(ModelSelectedMsg{
172 Model: config.SelectedModel{
173 Model: selectedItem.Model.ID,
174 Provider: string(selectedItem.Provider.ID),
175 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
176 MaxTokens: selectedItem.Model.DefaultMaxTokens,
177 },
178 ModelType: modelType,
179 }),
180 )
181 } else {
182 // Provider not configured, show API key input
183 m.needsAPIKey = true
184 m.selectedModel = selectedItem
185 m.selectedModelType = modelType
186 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
187 return m, nil
188 }
189 case key.Matches(msg, m.keyMap.Tab):
190 if m.needsAPIKey {
191 u, cmd := m.apiKeyInput.Update(msg)
192 m.apiKeyInput = u.(*APIKeyInput)
193 return m, cmd
194 }
195 if m.modelList.GetModelType() == LargeModelType {
196 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
197 return m, m.modelList.SetModelType(SmallModelType)
198 } else {
199 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
200 return m, m.modelList.SetModelType(LargeModelType)
201 }
202 case key.Matches(msg, m.keyMap.Close):
203 if m.needsAPIKey {
204 if m.isAPIKeyValid {
205 return m, nil
206 }
207 // Go back to model selection
208 m.needsAPIKey = false
209 m.selectedModel = nil
210 m.isAPIKeyValid = false
211 m.apiKeyValue = ""
212 m.apiKeyInput.Reset()
213 return m, nil
214 }
215 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
216 default:
217 if m.needsAPIKey {
218 u, cmd := m.apiKeyInput.Update(msg)
219 m.apiKeyInput = u.(*APIKeyInput)
220 return m, cmd
221 } else {
222 u, cmd := m.modelList.Update(msg)
223 m.modelList = u
224 return m, cmd
225 }
226 }
227 case tea.PasteMsg:
228 if m.needsAPIKey {
229 u, cmd := m.apiKeyInput.Update(msg)
230 m.apiKeyInput = u.(*APIKeyInput)
231 return m, cmd
232 } else {
233 var cmd tea.Cmd
234 m.modelList, cmd = m.modelList.Update(msg)
235 return m, cmd
236 }
237 case spinner.TickMsg:
238 u, cmd := m.apiKeyInput.Update(msg)
239 m.apiKeyInput = u.(*APIKeyInput)
240 return m, cmd
241 }
242 return m, nil
243}
244
245func (m *modelDialogCmp) View() string {
246 t := styles.CurrentTheme()
247
248 if m.needsAPIKey {
249 // Show API key input
250 m.keyMap.isAPIKeyHelp = true
251 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
252 apiKeyView := m.apiKeyInput.View()
253 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
254 content := lipgloss.JoinVertical(
255 lipgloss.Left,
256 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
257 apiKeyView,
258 "",
259 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
260 )
261 return m.style().Render(content)
262 }
263
264 // Show model selection
265 listView := m.modelList.View()
266 radio := m.modelTypeRadio()
267 content := lipgloss.JoinVertical(
268 lipgloss.Left,
269 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
270 listView,
271 "",
272 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
273 )
274 return m.style().Render(content)
275}
276
277func (m *modelDialogCmp) Cursor() *tea.Cursor {
278 if m.needsAPIKey {
279 cursor := m.apiKeyInput.Cursor()
280 if cursor != nil {
281 cursor = m.moveCursor(cursor)
282 return cursor
283 }
284 } else {
285 cursor := m.modelList.Cursor()
286 if cursor != nil {
287 cursor = m.moveCursor(cursor)
288 return cursor
289 }
290 }
291 return nil
292}
293
294func (m *modelDialogCmp) style() lipgloss.Style {
295 t := styles.CurrentTheme()
296 return t.S().Base.
297 Width(m.width).
298 Border(lipgloss.RoundedBorder()).
299 BorderForeground(t.BorderFocus)
300}
301
302func (m *modelDialogCmp) listWidth() int {
303 return m.width - 2
304}
305
306func (m *modelDialogCmp) listHeight() int {
307 return m.wHeight / 2
308}
309
310func (m *modelDialogCmp) Position() (int, int) {
311 row := m.wHeight/4 - 2 // just a bit above the center
312 col := m.wWidth / 2
313 col -= m.width / 2
314 return row, col
315}
316
317func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
318 row, col := m.Position()
319 if m.needsAPIKey {
320 offset := row + 3 // Border + title + API key input offset
321 cursor.Y += offset
322 cursor.X = cursor.X + col + 2
323 } else {
324 offset := row + 3 // Border + title
325 cursor.Y += offset
326 cursor.X = cursor.X + col + 2
327 }
328 return cursor
329}
330
331func (m *modelDialogCmp) ID() dialogs.DialogID {
332 return ModelsDialogID
333}
334
335func (m *modelDialogCmp) modelTypeRadio() string {
336 t := styles.CurrentTheme()
337 choices := []string{"Large Task", "Small Task"}
338 iconSelected := "◉"
339 iconUnselected := "○"
340 if m.modelList.GetModelType() == LargeModelType {
341 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
342 }
343 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
344}
345
346func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
347 cfg := config.Get()
348 if _, ok := cfg.Providers.Get(providerID); ok {
349 return true
350 }
351 return false
352}
353
354func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
355 providers, err := config.Providers()
356 if err != nil {
357 return nil, err
358 }
359 for _, p := range providers {
360 if p.ID == providerID {
361 return &p, nil
362 }
363 }
364 return nil, nil
365}
366
367func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
368 if m.selectedModel == nil {
369 return util.ReportError(fmt.Errorf("no model selected"))
370 }
371
372 cfg := config.Get()
373 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
374 if err != nil {
375 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
376 }
377
378 // Reset API key state and continue with model selection
379 selectedModel := *m.selectedModel
380 return tea.Sequence(
381 util.CmdHandler(dialogs.CloseDialogMsg{}),
382 util.CmdHandler(ModelSelectedMsg{
383 Model: config.SelectedModel{
384 Model: selectedModel.Model.ID,
385 Provider: string(selectedModel.Provider.ID),
386 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
387 MaxTokens: selectedModel.Model.DefaultMaxTokens,
388 },
389 ModelType: m.selectedModelType,
390 }),
391 )
392}