1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/tui/components/core"
14 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
15 "github.com/charmbracelet/crush/internal/tui/exp/list"
16 "github.com/charmbracelet/crush/internal/tui/styles"
17 "github.com/charmbracelet/crush/internal/tui/util"
18 "github.com/charmbracelet/lipgloss/v2"
19)
20
21const (
22 ModelsDialogID dialogs.DialogID = "models"
23
24 defaultWidth = 60
25)
26
// Model types the dialog can list, compared against the model list's
// GetModelType.
const (
	LargeModelType int = iota
	SmallModelType
)

// Placeholder texts for the model list's input, one per model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
34
// ModelSelectedMsg reports the model the user chose in the model dialog.
// The dialog emits it right after dialogs.CloseDialogMsg, in one of two cases:
// directly when the model's provider is already configured, or after an
// entered API key has been saved.
type ModelSelectedMsg struct {
	Model     config.SelectedModel     // model ID and provider ID of the choice
	ModelType config.SelectedModelType // whether the model is for large or small tasks
}
40
// CloseModelDialogMsg is, judging by its name, a request to close the model
// dialog.
//
// NOTE(review): nothing in this file emits it; the dialog closes itself with
// dialogs.CloseDialogMsg. Confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
43
// ModelDialog is the interface returned by NewModelDialogCmp for the model
// selection dialog. It adds no methods beyond dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
48
// ModelOption pairs a catwalk model with the provider it belongs to. The
// model list returns one for its highlighted entry via SelectedModel.
type ModelOption struct {
	Provider catwalk.Provider // its ID and Name are used when selecting or saving a key
	Model    catwalk.Model    // its ID is reported in ModelSelectedMsg
}
53
54type modelDialogCmp struct {
55 width int
56 wWidth int
57 wHeight int
58
59 modelList *ModelListComponent
60 keyMap KeyMap
61 help help.Model
62
63 // API key state
64 needsAPIKey bool
65 apiKeyInput *APIKeyInput
66 selectedModel *ModelOption
67 selectedModelType config.SelectedModelType
68 isAPIKeyValid bool
69 apiKeyValue string
70}
71
72func NewModelDialogCmp() ModelDialog {
73 keyMap := DefaultKeyMap()
74
75 listKeyMap := list.DefaultKeyMap()
76 listKeyMap.Down.SetEnabled(false)
77 listKeyMap.Up.SetEnabled(false)
78 listKeyMap.DownOneItem = keyMap.Next
79 listKeyMap.UpOneItem = keyMap.Previous
80
81 t := styles.CurrentTheme()
82 modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
83 apiKeyInput := NewAPIKeyInput()
84 apiKeyInput.SetShowTitle(false)
85 help := help.New()
86 help.Styles = t.S().Help
87
88 return &modelDialogCmp{
89 modelList: modelList,
90 apiKeyInput: apiKeyInput,
91 width: defaultWidth,
92 keyMap: DefaultKeyMap(),
93 help: help,
94 }
95}
96
97func (m *modelDialogCmp) Init() tea.Cmd {
98 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
99}
100
101func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
102 switch msg := msg.(type) {
103 case tea.WindowSizeMsg:
104 m.wWidth = msg.Width
105 m.wHeight = msg.Height
106 m.apiKeyInput.SetWidth(m.width - 2)
107 m.help.Width = m.width - 2
108 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
109 case APIKeyStateChangeMsg:
110 u, cmd := m.apiKeyInput.Update(msg)
111 m.apiKeyInput = u.(*APIKeyInput)
112 return m, cmd
113 case tea.KeyPressMsg:
114 switch {
115 case key.Matches(msg, m.keyMap.Select):
116 if m.isAPIKeyValid {
117 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
118 }
119 if m.needsAPIKey {
120 // Handle API key submission
121 m.apiKeyValue = m.apiKeyInput.Value()
122 provider, err := m.getProvider(m.selectedModel.Provider.ID)
123 if err != nil || provider == nil {
124 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
125 }
126 providerConfig := config.ProviderConfig{
127 ID: string(m.selectedModel.Provider.ID),
128 Name: m.selectedModel.Provider.Name,
129 APIKey: m.apiKeyValue,
130 Type: provider.Type,
131 BaseURL: provider.APIEndpoint,
132 }
133 return m, tea.Sequence(
134 util.CmdHandler(APIKeyStateChangeMsg{
135 State: APIKeyInputStateVerifying,
136 }),
137 func() tea.Msg {
138 start := time.Now()
139 err := providerConfig.TestConnection(config.Get().Resolver())
140 // intentionally wait for at least 750ms to make sure the user sees the spinner
141 elapsed := time.Since(start)
142 if elapsed < 750*time.Millisecond {
143 time.Sleep(750*time.Millisecond - elapsed)
144 }
145 if err == nil {
146 m.isAPIKeyValid = true
147 return APIKeyStateChangeMsg{
148 State: APIKeyInputStateVerified,
149 }
150 }
151 return APIKeyStateChangeMsg{
152 State: APIKeyInputStateError,
153 }
154 },
155 )
156 }
157 // Normal model selection
158 selectedItem := m.modelList.SelectedModel()
159
160 var modelType config.SelectedModelType
161 if m.modelList.GetModelType() == LargeModelType {
162 modelType = config.SelectedModelTypeLarge
163 } else {
164 modelType = config.SelectedModelTypeSmall
165 }
166
167 // Check if provider is configured
168 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
169 return m, tea.Sequence(
170 util.CmdHandler(dialogs.CloseDialogMsg{}),
171 util.CmdHandler(ModelSelectedMsg{
172 Model: config.SelectedModel{
173 Model: selectedItem.Model.ID,
174 Provider: string(selectedItem.Provider.ID),
175 },
176 ModelType: modelType,
177 }),
178 )
179 } else {
180 // Provider not configured, show API key input
181 m.needsAPIKey = true
182 m.selectedModel = selectedItem
183 m.selectedModelType = modelType
184 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
185 return m, nil
186 }
187 case key.Matches(msg, m.keyMap.Tab):
188 if m.needsAPIKey {
189 u, cmd := m.apiKeyInput.Update(msg)
190 m.apiKeyInput = u.(*APIKeyInput)
191 return m, cmd
192 }
193 if m.modelList.GetModelType() == LargeModelType {
194 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
195 return m, m.modelList.SetModelType(SmallModelType)
196 } else {
197 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
198 return m, m.modelList.SetModelType(LargeModelType)
199 }
200 case key.Matches(msg, m.keyMap.Close):
201 if m.needsAPIKey {
202 if m.isAPIKeyValid {
203 return m, nil
204 }
205 // Go back to model selection
206 m.needsAPIKey = false
207 m.selectedModel = nil
208 m.isAPIKeyValid = false
209 m.apiKeyValue = ""
210 m.apiKeyInput.Reset()
211 return m, nil
212 }
213 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
214 default:
215 if m.needsAPIKey {
216 u, cmd := m.apiKeyInput.Update(msg)
217 m.apiKeyInput = u.(*APIKeyInput)
218 return m, cmd
219 } else {
220 u, cmd := m.modelList.Update(msg)
221 m.modelList = u
222 return m, cmd
223 }
224 }
225 case tea.PasteMsg:
226 if m.needsAPIKey {
227 u, cmd := m.apiKeyInput.Update(msg)
228 m.apiKeyInput = u.(*APIKeyInput)
229 return m, cmd
230 } else {
231 var cmd tea.Cmd
232 m.modelList, cmd = m.modelList.Update(msg)
233 return m, cmd
234 }
235 case spinner.TickMsg:
236 u, cmd := m.apiKeyInput.Update(msg)
237 m.apiKeyInput = u.(*APIKeyInput)
238 return m, cmd
239 }
240 return m, nil
241}
242
243func (m *modelDialogCmp) View() string {
244 t := styles.CurrentTheme()
245
246 if m.needsAPIKey {
247 // Show API key input
248 m.keyMap.isAPIKeyHelp = true
249 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
250 apiKeyView := m.apiKeyInput.View()
251 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
252 content := lipgloss.JoinVertical(
253 lipgloss.Left,
254 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
255 apiKeyView,
256 "",
257 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
258 )
259 return m.style().Render(content)
260 }
261
262 // Show model selection
263 listView := m.modelList.View()
264 radio := m.modelTypeRadio()
265 content := lipgloss.JoinVertical(
266 lipgloss.Left,
267 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
268 listView,
269 "",
270 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
271 )
272 return m.style().Render(content)
273}
274
275func (m *modelDialogCmp) Cursor() *tea.Cursor {
276 if m.needsAPIKey {
277 cursor := m.apiKeyInput.Cursor()
278 if cursor != nil {
279 cursor = m.moveCursor(cursor)
280 return cursor
281 }
282 } else {
283 cursor := m.modelList.Cursor()
284 if cursor != nil {
285 cursor = m.moveCursor(cursor)
286 return cursor
287 }
288 }
289 return nil
290}
291
292func (m *modelDialogCmp) style() lipgloss.Style {
293 t := styles.CurrentTheme()
294 return t.S().Base.
295 Width(m.width).
296 Border(lipgloss.RoundedBorder()).
297 BorderForeground(t.BorderFocus)
298}
299
300func (m *modelDialogCmp) listWidth() int {
301 return m.width - 2
302}
303
304func (m *modelDialogCmp) listHeight() int {
305 return m.wHeight / 2
306}
307
308func (m *modelDialogCmp) Position() (int, int) {
309 row := m.wHeight/4 - 2 // just a bit above the center
310 col := m.wWidth / 2
311 col -= m.width / 2
312 return row, col
313}
314
315func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
316 row, col := m.Position()
317 if m.needsAPIKey {
318 offset := row + 3 // Border + title + API key input offset
319 cursor.Y += offset
320 cursor.X = cursor.X + col + 2
321 } else {
322 offset := row + 3 // Border + title
323 cursor.Y += offset
324 cursor.X = cursor.X + col + 2
325 }
326 return cursor
327}
328
329func (m *modelDialogCmp) ID() dialogs.DialogID {
330 return ModelsDialogID
331}
332
333func (m *modelDialogCmp) modelTypeRadio() string {
334 t := styles.CurrentTheme()
335 choices := []string{"Large Task", "Small Task"}
336 iconSelected := "◉"
337 iconUnselected := "○"
338 if m.modelList.GetModelType() == LargeModelType {
339 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
340 }
341 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
342}
343
344func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
345 cfg := config.Get()
346 if _, ok := cfg.Providers.Get(providerID); ok {
347 return true
348 }
349 return false
350}
351
352func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
353 providers, err := config.Providers()
354 if err != nil {
355 return nil, err
356 }
357 for _, p := range providers {
358 if p.ID == providerID {
359 return &p, nil
360 }
361 }
362 return nil, nil
363}
364
365func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
366 if m.selectedModel == nil {
367 return util.ReportError(fmt.Errorf("no model selected"))
368 }
369
370 cfg := config.Get()
371 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
372 if err != nil {
373 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
374 }
375
376 // Reset API key state and continue with model selection
377 selectedModel := *m.selectedModel
378 return tea.Sequence(
379 util.CmdHandler(dialogs.CloseDialogMsg{}),
380 util.CmdHandler(ModelSelectedMsg{
381 Model: config.SelectedModel{
382 Model: selectedModel.Model.ID,
383 Provider: string(selectedModel.Provider.ID),
384 },
385 ModelType: m.selectedModelType,
386 }),
387 )
388}