1package models
2
3import (
4 "fmt"
5 "strings"
6 "time"
7
8 "github.com/charmbracelet/bubbles/v2/help"
9 "github.com/charmbracelet/bubbles/v2/key"
10 "github.com/charmbracelet/bubbles/v2/spinner"
11 tea "github.com/charmbracelet/bubbletea/v2"
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/tui/components/core"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
16 "github.com/charmbracelet/crush/internal/tui/exp/list"
17 "github.com/charmbracelet/crush/internal/tui/styles"
18 "github.com/charmbracelet/crush/internal/tui/util"
19 "github.com/charmbracelet/lipgloss/v2"
20)
21
22const (
23 ModelsDialogID dialogs.DialogID = "models"
24
25 defaultWidth = 60
26)
27
// Model types the dialog can pick a model for; Tab toggles between them.
const (
	// LargeModelType selects the model used for large, complex tasks.
	LargeModelType int = iota
	// SmallModelType selects the model used for small, simple tasks.
	SmallModelType
)

// Placeholder text shown in the model list's input for each model type.
const (
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
35
// ModelSelectedMsg is emitted by the model dialog, right after
// dialogs.CloseDialogMsg, when the user picks a model whose provider is
// already configured, or after an API key for the chosen model's provider has
// been verified and saved.
type ModelSelectedMsg struct {
	// Model identifies the chosen model and the provider serving it.
	Model config.SelectedModel
	// ModelType says which slot (large or small) the model was chosen for.
	ModelType config.SelectedModelType
}
41
// CloseModelDialogMsg is, presumably, a request to close the model dialog.
// NOTE(review): nothing in this file emits it — the dialog closes itself with
// dialogs.CloseDialogMsg. Confirm whether it is still used elsewhere.
type CloseModelDialogMsg struct{}
44
// ModelDialog is the public interface of the model selection dialog. It adds
// nothing beyond dialogs.DialogModel; it exists so callers can refer to the
// dialog by a named type (see NewModelDialogCmp).
type ModelDialog interface {
	dialogs.DialogModel
}
49
// ModelOption pairs a catwalk model with the provider that offers it; it is
// the item type the model list hands back as the current selection.
type ModelOption struct {
	Provider catwalk.Provider
	Model    catwalk.Model
}
54
// modelDialogCmp implements ModelDialog. It shows a model list and, when the
// chosen model's provider is not configured yet, switches to an API key entry
// step before completing the selection.
type modelDialogCmp struct {
	// width is the dialog width (initialised to defaultWidth).
	width int
	// wWidth and wHeight are the terminal size from the last tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	// modelList is the list of selectable models.
	modelList *ModelListComponent
	// keyMap holds the dialog's key bindings; View also updates its
	// help-mode flags before rendering help.
	keyMap KeyMap
	// help renders the key binding hints at the bottom of the dialog.
	help help.Model

	// API key state
	// needsAPIKey is true while the dialog shows the API key entry step.
	needsAPIKey bool
	// apiKeyInput is the input used during the API key entry step.
	apiKeyInput *APIKeyInput
	// selectedModel is the model waiting for an API key (nil otherwise).
	selectedModel *ModelOption
	// selectedModelType is the slot (large/small) selectedModel was picked for.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true once the entered key passed verification.
	isAPIKeyValid bool
	// apiKeyValue is the key text captured when verification was started.
	apiKeyValue string
}
72
73func NewModelDialogCmp() ModelDialog {
74 keyMap := DefaultKeyMap()
75
76 listKeyMap := list.DefaultKeyMap()
77 listKeyMap.Down.SetEnabled(false)
78 listKeyMap.Up.SetEnabled(false)
79 listKeyMap.DownOneItem = keyMap.Next
80 listKeyMap.UpOneItem = keyMap.Previous
81
82 t := styles.CurrentTheme()
83 modelList := NewModelListComponent(listKeyMap, "Choose a model for large, complex tasks", true)
84 apiKeyInput := NewAPIKeyInput()
85 apiKeyInput.SetShowTitle(false)
86 help := help.New()
87 help.Styles = t.S().Help
88
89 return &modelDialogCmp{
90 modelList: modelList,
91 apiKeyInput: apiKeyInput,
92 width: defaultWidth,
93 keyMap: DefaultKeyMap(),
94 help: help,
95 }
96}
97
98func (m *modelDialogCmp) Init() tea.Cmd {
99 providers, err := config.Providers()
100 if err == nil {
101 filteredProviders := []catwalk.Provider{}
102 for _, p := range providers {
103 if strings.HasPrefix(p.APIKey, "$") && p.ID != catwalk.InferenceProviderAzure {
104 filteredProviders = append(filteredProviders, p)
105 }
106 }
107 m.modelList.SetProviders(filteredProviders)
108 }
109 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
110}
111
112func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
113 switch msg := msg.(type) {
114 case tea.WindowSizeMsg:
115 m.wWidth = msg.Width
116 m.wHeight = msg.Height
117 m.apiKeyInput.SetWidth(m.width - 2)
118 m.help.Width = m.width - 2
119 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
120 case APIKeyStateChangeMsg:
121 u, cmd := m.apiKeyInput.Update(msg)
122 m.apiKeyInput = u.(*APIKeyInput)
123 return m, cmd
124 case tea.KeyPressMsg:
125 switch {
126 case key.Matches(msg, m.keyMap.Select):
127 if m.isAPIKeyValid {
128 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
129 }
130 if m.needsAPIKey {
131 // Handle API key submission
132 m.apiKeyValue = m.apiKeyInput.Value()
133 provider, err := m.getProvider(m.selectedModel.Provider.ID)
134 if err != nil || provider == nil {
135 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
136 }
137 providerConfig := config.ProviderConfig{
138 ID: string(m.selectedModel.Provider.ID),
139 Name: m.selectedModel.Provider.Name,
140 APIKey: m.apiKeyValue,
141 Type: provider.Type,
142 BaseURL: provider.APIEndpoint,
143 }
144 return m, tea.Sequence(
145 util.CmdHandler(APIKeyStateChangeMsg{
146 State: APIKeyInputStateVerifying,
147 }),
148 func() tea.Msg {
149 start := time.Now()
150 err := providerConfig.TestConnection(config.Get().Resolver())
151 // intentionally wait for at least 750ms to make sure the user sees the spinner
152 elapsed := time.Since(start)
153 if elapsed < 750*time.Millisecond {
154 time.Sleep(750*time.Millisecond - elapsed)
155 }
156 if err == nil {
157 m.isAPIKeyValid = true
158 return APIKeyStateChangeMsg{
159 State: APIKeyInputStateVerified,
160 }
161 }
162 return APIKeyStateChangeMsg{
163 State: APIKeyInputStateError,
164 }
165 },
166 )
167 }
168 // Normal model selection
169 selectedItem := m.modelList.SelectedModel()
170
171 var modelType config.SelectedModelType
172 if m.modelList.GetModelType() == LargeModelType {
173 modelType = config.SelectedModelTypeLarge
174 } else {
175 modelType = config.SelectedModelTypeSmall
176 }
177
178 // Check if provider is configured
179 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
180 return m, tea.Sequence(
181 util.CmdHandler(dialogs.CloseDialogMsg{}),
182 util.CmdHandler(ModelSelectedMsg{
183 Model: config.SelectedModel{
184 Model: selectedItem.Model.ID,
185 Provider: string(selectedItem.Provider.ID),
186 },
187 ModelType: modelType,
188 }),
189 )
190 } else {
191 // Provider not configured, show API key input
192 m.needsAPIKey = true
193 m.selectedModel = selectedItem
194 m.selectedModelType = modelType
195 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
196 return m, nil
197 }
198 case key.Matches(msg, m.keyMap.Tab):
199 if m.needsAPIKey {
200 u, cmd := m.apiKeyInput.Update(msg)
201 m.apiKeyInput = u.(*APIKeyInput)
202 return m, cmd
203 }
204 if m.modelList.GetModelType() == LargeModelType {
205 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
206 return m, m.modelList.SetModelType(SmallModelType)
207 } else {
208 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
209 return m, m.modelList.SetModelType(LargeModelType)
210 }
211 case key.Matches(msg, m.keyMap.Close):
212 if m.needsAPIKey {
213 if m.isAPIKeyValid {
214 return m, nil
215 }
216 // Go back to model selection
217 m.needsAPIKey = false
218 m.selectedModel = nil
219 m.isAPIKeyValid = false
220 m.apiKeyValue = ""
221 m.apiKeyInput.Reset()
222 return m, nil
223 }
224 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
225 default:
226 if m.needsAPIKey {
227 u, cmd := m.apiKeyInput.Update(msg)
228 m.apiKeyInput = u.(*APIKeyInput)
229 return m, cmd
230 } else {
231 u, cmd := m.modelList.Update(msg)
232 m.modelList = u
233 return m, cmd
234 }
235 }
236 case tea.PasteMsg:
237 if m.needsAPIKey {
238 u, cmd := m.apiKeyInput.Update(msg)
239 m.apiKeyInput = u.(*APIKeyInput)
240 return m, cmd
241 } else {
242 var cmd tea.Cmd
243 m.modelList, cmd = m.modelList.Update(msg)
244 return m, cmd
245 }
246 case spinner.TickMsg:
247 u, cmd := m.apiKeyInput.Update(msg)
248 m.apiKeyInput = u.(*APIKeyInput)
249 return m, cmd
250 }
251 return m, nil
252}
253
254func (m *modelDialogCmp) View() string {
255 t := styles.CurrentTheme()
256
257 if m.needsAPIKey {
258 // Show API key input
259 m.keyMap.isAPIKeyHelp = true
260 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
261 apiKeyView := m.apiKeyInput.View()
262 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
263 content := lipgloss.JoinVertical(
264 lipgloss.Left,
265 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
266 apiKeyView,
267 "",
268 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
269 )
270 return m.style().Render(content)
271 }
272
273 // Show model selection
274 listView := m.modelList.View()
275 radio := m.modelTypeRadio()
276 content := lipgloss.JoinVertical(
277 lipgloss.Left,
278 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
279 listView,
280 "",
281 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
282 )
283 return m.style().Render(content)
284}
285
286func (m *modelDialogCmp) Cursor() *tea.Cursor {
287 if m.needsAPIKey {
288 cursor := m.apiKeyInput.Cursor()
289 if cursor != nil {
290 cursor = m.moveCursor(cursor)
291 return cursor
292 }
293 } else {
294 cursor := m.modelList.Cursor()
295 if cursor != nil {
296 cursor = m.moveCursor(cursor)
297 return cursor
298 }
299 }
300 return nil
301}
302
303func (m *modelDialogCmp) style() lipgloss.Style {
304 t := styles.CurrentTheme()
305 return t.S().Base.
306 Width(m.width).
307 Border(lipgloss.RoundedBorder()).
308 BorderForeground(t.BorderFocus)
309}
310
311func (m *modelDialogCmp) listWidth() int {
312 return m.width - 2
313}
314
315func (m *modelDialogCmp) listHeight() int {
316 return m.wHeight / 2
317}
318
319func (m *modelDialogCmp) Position() (int, int) {
320 row := m.wHeight/4 - 2 // just a bit above the center
321 col := m.wWidth / 2
322 col -= m.width / 2
323 return row, col
324}
325
326func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
327 row, col := m.Position()
328 if m.needsAPIKey {
329 offset := row + 3 // Border + title + API key input offset
330 cursor.Y += offset
331 cursor.X = cursor.X + col + 2
332 } else {
333 offset := row + 3 // Border + title
334 cursor.Y += offset
335 cursor.X = cursor.X + col + 2
336 }
337 return cursor
338}
339
340func (m *modelDialogCmp) ID() dialogs.DialogID {
341 return ModelsDialogID
342}
343
344func (m *modelDialogCmp) modelTypeRadio() string {
345 t := styles.CurrentTheme()
346 choices := []string{"Large Task", "Small Task"}
347 iconSelected := "◉"
348 iconUnselected := "○"
349 if m.modelList.GetModelType() == LargeModelType {
350 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
351 }
352 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
353}
354
355func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
356 cfg := config.Get()
357 if _, ok := cfg.Providers.Get(providerID); ok {
358 return true
359 }
360 return false
361}
362
363func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
364 providers, err := config.Providers()
365 if err != nil {
366 return nil, err
367 }
368 for _, p := range providers {
369 if p.ID == providerID {
370 return &p, nil
371 }
372 }
373 return nil, nil
374}
375
376func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
377 if m.selectedModel == nil {
378 return util.ReportError(fmt.Errorf("no model selected"))
379 }
380
381 cfg := config.Get()
382 err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
383 if err != nil {
384 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
385 }
386
387 // Reset API key state and continue with model selection
388 selectedModel := *m.selectedModel
389 return tea.Sequence(
390 util.CmdHandler(dialogs.CloseDialogMsg{}),
391 util.CmdHandler(ModelSelectedMsg{
392 Model: config.SelectedModel{
393 Model: selectedModel.Model.ID,
394 Provider: string(selectedModel.Provider.ID),
395 },
396 ModelType: m.selectedModelType,
397 }),
398 )
399}