1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/tui/components/completions"
14 "github.com/charmbracelet/crush/internal/tui/components/core"
15 "github.com/charmbracelet/crush/internal/tui/components/core/list"
16 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
17 "github.com/charmbracelet/crush/internal/tui/styles"
18 "github.com/charmbracelet/crush/internal/tui/util"
19 "github.com/charmbracelet/lipgloss/v2"
20)
21
22const (
23 ModelsDialogID dialogs.DialogID = "models"
24
25 defaultWidth = 60
26)
27
// Model types the model list can be switched between; compared against
// ModelListComponent.GetModelType.
const (
	LargeModelType int = iota
	SmallModelType
)

// Input placeholders shown by the model list for each model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
35
36// ModelSelectedMsg is sent when a model is selected
37type ModelSelectedMsg struct {
38 Model config.SelectedModel
39 ModelType config.SelectedModelType
40}
41
// CloseModelDialogMsg is a message type for closing the model dialog.
// NOTE(review): nothing in this file sends or handles it — the dialog closes
// itself with dialogs.CloseDialogMsg — so confirm whether it is still used
// elsewhere before relying on it.
type CloseModelDialogMsg struct{}
44
45// ModelDialog interface for the model selection dialog
46type ModelDialog interface {
47 dialogs.DialogModel
48}
49
50type ModelOption struct {
51 Provider catwalk.Provider
52 Model catwalk.Model
53}
54
55type modelDialogCmp struct {
56 width int
57 wWidth int
58 wHeight int
59
60 modelList *ModelListComponent
61 keyMap KeyMap
62 help help.Model
63
64 // API key state
65 needsAPIKey bool
66 apiKeyInput *APIKeyInput
67 selectedModel *ModelOption
68 selectedModelType config.SelectedModelType
69 isAPIKeyValid bool
70 apiKeyValue string
71}
72
73func NewModelDialogCmp() ModelDialog {
74 listKeyMap := list.DefaultKeyMap()
75 keyMap := DefaultKeyMap()
76
77 listKeyMap.Down.SetEnabled(false)
78 listKeyMap.Up.SetEnabled(false)
79 listKeyMap.HalfPageDown.SetEnabled(false)
80 listKeyMap.HalfPageUp.SetEnabled(false)
81 listKeyMap.Home.SetEnabled(false)
82 listKeyMap.End.SetEnabled(false)
83
84 listKeyMap.DownOneItem = keyMap.Next
85 listKeyMap.UpOneItem = keyMap.Previous
86
87 t := styles.CurrentTheme()
88 inputStyle := t.S().Base.Padding(0, 1, 0, 1)
89 modelList := NewModelListComponent(listKeyMap, inputStyle, "Choose a model for large, complex tasks")
90 apiKeyInput := NewAPIKeyInput()
91 apiKeyInput.SetShowTitle(false)
92 help := help.New()
93 help.Styles = t.S().Help
94
95 return &modelDialogCmp{
96 modelList: modelList,
97 apiKeyInput: apiKeyInput,
98 width: defaultWidth,
99 keyMap: DefaultKeyMap(),
100 help: help,
101 }
102}
103
104func (m *modelDialogCmp) Init() tea.Cmd {
105 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
106}
107
108func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
109 switch msg := msg.(type) {
110 case tea.WindowSizeMsg:
111 m.wWidth = msg.Width
112 m.wHeight = msg.Height
113 m.apiKeyInput.SetWidth(m.width - 2)
114 m.help.Width = m.width - 2
115 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
116 case APIKeyStateChangeMsg:
117 u, cmd := m.apiKeyInput.Update(msg)
118 m.apiKeyInput = u.(*APIKeyInput)
119 return m, cmd
120 case tea.KeyPressMsg:
121 switch {
122 case key.Matches(msg, m.keyMap.Select):
123 if m.isAPIKeyValid {
124 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
125 }
126 if m.needsAPIKey {
127 // Handle API key submission
128 m.apiKeyValue = m.apiKeyInput.Value()
129 provider, err := m.getProvider(m.selectedModel.Provider.ID)
130 if err != nil || provider == nil {
131 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
132 }
133 providerConfig := config.ProviderConfig{
134 ID: string(m.selectedModel.Provider.ID),
135 Name: m.selectedModel.Provider.Name,
136 APIKey: m.apiKeyValue,
137 Type: provider.Type,
138 BaseURL: provider.APIEndpoint,
139 }
140 return m, tea.Sequence(
141 util.CmdHandler(APIKeyStateChangeMsg{
142 State: APIKeyInputStateVerifying,
143 }),
144 func() tea.Msg {
145 start := time.Now()
146 err := providerConfig.TestConnection(config.Get().Resolver())
147 // intentionally wait for at least 750ms to make sure the user sees the spinner
148 elapsed := time.Since(start)
149 if elapsed < 750*time.Millisecond {
150 time.Sleep(750*time.Millisecond - elapsed)
151 }
152 if err == nil {
153 m.isAPIKeyValid = true
154 return APIKeyStateChangeMsg{
155 State: APIKeyInputStateVerified,
156 }
157 }
158 return APIKeyStateChangeMsg{
159 State: APIKeyInputStateError,
160 }
161 },
162 )
163 }
164 // Normal model selection
165 selectedItemInx := m.modelList.SelectedIndex()
166 if selectedItemInx == list.NoSelection {
167 return m, nil
168 }
169 items := m.modelList.Items()
170 selectedItem := items[selectedItemInx].(completions.CompletionItem).Value().(ModelOption)
171
172 var modelType config.SelectedModelType
173 if m.modelList.GetModelType() == LargeModelType {
174 modelType = config.SelectedModelTypeLarge
175 } else {
176 modelType = config.SelectedModelTypeSmall
177 }
178
179 // Check if provider is configured
180 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
181 return m, tea.Sequence(
182 util.CmdHandler(dialogs.CloseDialogMsg{}),
183 util.CmdHandler(ModelSelectedMsg{
184 Model: config.SelectedModel{
185 Model: selectedItem.Model.ID,
186 Provider: string(selectedItem.Provider.ID),
187 },
188 ModelType: modelType,
189 }),
190 )
191 } else {
192 // Provider not configured, show API key input
193 m.needsAPIKey = true
194 m.selectedModel = &selectedItem
195 m.selectedModelType = modelType
196 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
197 return m, nil
198 }
199 case key.Matches(msg, m.keyMap.Tab):
200 if m.needsAPIKey {
201 u, cmd := m.apiKeyInput.Update(msg)
202 m.apiKeyInput = u.(*APIKeyInput)
203 return m, cmd
204 }
205 if m.modelList.GetModelType() == LargeModelType {
206 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
207 return m, m.modelList.SetModelType(SmallModelType)
208 } else {
209 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
210 return m, m.modelList.SetModelType(LargeModelType)
211 }
212 case key.Matches(msg, m.keyMap.Close):
213 if m.needsAPIKey {
214 if m.isAPIKeyValid {
215 return m, nil
216 }
217 // Go back to model selection
218 m.needsAPIKey = false
219 m.selectedModel = nil
220 m.isAPIKeyValid = false
221 m.apiKeyValue = ""
222 m.apiKeyInput.Reset()
223 return m, nil
224 }
225 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
226 default:
227 if m.needsAPIKey {
228 u, cmd := m.apiKeyInput.Update(msg)
229 m.apiKeyInput = u.(*APIKeyInput)
230 return m, cmd
231 } else {
232 u, cmd := m.modelList.Update(msg)
233 m.modelList = u
234 return m, cmd
235 }
236 }
237 case tea.PasteMsg:
238 if m.needsAPIKey {
239 u, cmd := m.apiKeyInput.Update(msg)
240 m.apiKeyInput = u.(*APIKeyInput)
241 return m, cmd
242 } else {
243 var cmd tea.Cmd
244 m.modelList, cmd = m.modelList.Update(msg)
245 return m, cmd
246 }
247 case spinner.TickMsg:
248 u, cmd := m.apiKeyInput.Update(msg)
249 m.apiKeyInput = u.(*APIKeyInput)
250 return m, cmd
251 }
252 return m, nil
253}
254
255func (m *modelDialogCmp) View() string {
256 t := styles.CurrentTheme()
257
258 if m.needsAPIKey {
259 // Show API key input
260 m.keyMap.isAPIKeyHelp = true
261 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
262 apiKeyView := m.apiKeyInput.View()
263 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
264 content := lipgloss.JoinVertical(
265 lipgloss.Left,
266 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
267 apiKeyView,
268 "",
269 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
270 )
271 return m.style().Render(content)
272 }
273
274 // Show model selection
275 listView := m.modelList.View()
276 radio := m.modelTypeRadio()
277 content := lipgloss.JoinVertical(
278 lipgloss.Left,
279 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
280 listView,
281 "",
282 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
283 )
284 return m.style().Render(content)
285}
286
287func (m *modelDialogCmp) Cursor() *tea.Cursor {
288 if m.needsAPIKey {
289 cursor := m.apiKeyInput.Cursor()
290 if cursor != nil {
291 cursor = m.moveCursor(cursor)
292 return cursor
293 }
294 } else {
295 cursor := m.modelList.Cursor()
296 if cursor != nil {
297 cursor = m.moveCursor(cursor)
298 return cursor
299 }
300 }
301 return nil
302}
303
304func (m *modelDialogCmp) style() lipgloss.Style {
305 t := styles.CurrentTheme()
306 return t.S().Base.
307 Width(m.width).
308 Border(lipgloss.RoundedBorder()).
309 BorderForeground(t.BorderFocus)
310}
311
312func (m *modelDialogCmp) listWidth() int {
313 return defaultWidth - 2 // 4 for padding
314}
315
316func (m *modelDialogCmp) listHeight() int {
317 items := m.modelList.Items()
318 listHeigh := len(items) + 2 + 4
319 return min(listHeigh, m.wHeight/2)
320}
321
322func (m *modelDialogCmp) Position() (int, int) {
323 row := m.wHeight/4 - 2 // just a bit above the center
324 col := m.wWidth / 2
325 col -= m.width / 2
326 return row, col
327}
328
329func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
330 row, col := m.Position()
331 if m.needsAPIKey {
332 offset := row + 3 // Border + title + API key input offset
333 cursor.Y += offset
334 cursor.X = cursor.X + col + 2
335 } else {
336 offset := row + 3 // Border + title
337 cursor.Y += offset
338 cursor.X = cursor.X + col + 2
339 }
340 return cursor
341}
342
343func (m *modelDialogCmp) ID() dialogs.DialogID {
344 return ModelsDialogID
345}
346
347func (m *modelDialogCmp) modelTypeRadio() string {
348 t := styles.CurrentTheme()
349 choices := []string{"Large Task", "Small Task"}
350 iconSelected := "◉"
351 iconUnselected := "○"
352 if m.modelList.GetModelType() == LargeModelType {
353 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
354 }
355 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
356}
357
358func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
359 cfg := config.Get()
360 if _, ok := cfg.Providers[providerID]; ok {
361 return true
362 }
363 return false
364}
365
366func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
367 providers, err := config.Providers()
368 if err != nil {
369 return nil, err
370 }
371 for _, p := range providers {
372 if p.ID == providerID {
373 return &p, nil
374 }
375 }
376 return nil, nil
377}
378
379func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
380 if m.selectedModel == nil {
381 return util.ReportError(fmt.Errorf("no model selected"))
382 }
383
384 cfg := config.Get()
385 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
386 if err != nil {
387 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
388 }
389
390 // Reset API key state and continue with model selection
391 selectedModel := *m.selectedModel
392 return tea.Sequence(
393 util.CmdHandler(dialogs.CloseDialogMsg{}),
394 util.CmdHandler(ModelSelectedMsg{
395 Model: config.SelectedModel{
396 Model: selectedModel.Model.ID,
397 Provider: string(selectedModel.Provider.ID),
398 },
399 ModelType: m.selectedModelType,
400 }),
401 )
402}