1package models
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/bubbles/v2/help"
8 "github.com/charmbracelet/bubbles/v2/key"
9 "github.com/charmbracelet/bubbles/v2/spinner"
10 tea "github.com/charmbracelet/bubbletea/v2"
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/proto"
14 "github.com/charmbracelet/crush/internal/tui/components/core"
15 "github.com/charmbracelet/crush/internal/tui/components/dialogs"
16 "github.com/charmbracelet/crush/internal/tui/exp/list"
17 "github.com/charmbracelet/crush/internal/tui/styles"
18 "github.com/charmbracelet/crush/internal/tui/util"
19 "github.com/charmbracelet/lipgloss/v2"
20)
21
const (
	// ModelsDialogID is the dialog identifier returned by this dialog's ID method.
	ModelsDialogID dialogs.DialogID = "models"

	// defaultWidth is the width NewModelDialogCmp assigns to a new dialog.
	defaultWidth = 60
)
27
const (
	// LargeModelType and SmallModelType are the values compared against the
	// model list's GetModelType and passed to its SetModelType.
	LargeModelType int = iota
	SmallModelType

	// Input placeholder texts handed to the model list, one per model type.
	largeModelInputPlaceholder = "Choose a model for large, complex tasks"
	smallModelInputPlaceholder = "Choose a model for small, simple tasks"
)
35
// ModelSelectedMsg is emitted by the dialog, right after a close-dialog
// message, once the user has confirmed a model.
type ModelSelectedMsg struct {
	// Model carries the chosen model ID, provider ID and the model's default
	// reasoning effort and max tokens.
	Model config.SelectedModel
	// ModelType tells whether the choice was made for the large or the small
	// model type.
	ModelType config.SelectedModelType
}
41
// CloseModelDialogMsg is a message type for closing the model dialog.
//
// NOTE(review): nothing in the visible code sends or handles this type; the
// dialog closes itself via dialogs.CloseDialogMsg. Confirm whether it is
// still used elsewhere.
type CloseModelDialogMsg struct{}
44
// ModelDialog is the interface returned by NewModelDialogCmp. It currently
// adds nothing on top of dialogs.DialogModel.
type ModelDialog interface {
	dialogs.DialogModel
}
49
// ModelOption is a provider/model pair; the model list's SelectedModel
// result is stored by the dialog as a *ModelOption.
type ModelOption struct {
	// Provider is the catwalk provider of the option.
	Provider catwalk.Provider
	// Model is the catwalk model of the option.
	Model catwalk.Model
}
54
// modelDialogCmp is the model selection dialog. It shows either the model
// list or, when the chosen model's provider is not configured, an API key
// input for that provider.
type modelDialogCmp struct {
	// width is the dialog's own width (initialised to defaultWidth).
	width int
	// wWidth and wHeight hold the most recent window size received via
	// tea.WindowSizeMsg.
	wWidth  int
	wHeight int

	// modelList is the selectable list of models.
	modelList *ModelListComponent
	// keyMap holds the dialog's key bindings and is what the help view renders.
	keyMap KeyMap
	// help renders the key binding help line at the bottom of the dialog.
	help help.Model

	// API key state
	// needsAPIKey is true while the API key input is shown instead of the list.
	needsAPIKey bool
	// apiKeyInput is the input component used to enter an API key.
	apiKeyInput *APIKeyInput
	// selectedModel is the model picked while its provider was not configured;
	// it is nil outside the API key flow.
	selectedModel *ModelOption
	// selectedModelType is the model type that was active when selectedModel
	// was picked.
	selectedModelType config.SelectedModelType
	// isAPIKeyValid becomes true once the connection test for the entered key
	// has succeeded.
	isAPIKeyValid bool
	// apiKeyValue is the key read from apiKeyInput when it was submitted.
	apiKeyValue string

	// ins is the instance whose Config and ShellResolver the dialog uses.
	ins *proto.Instance
}
74
75func NewModelDialogCmp(ins *proto.Instance) ModelDialog {
76 keyMap := DefaultKeyMap()
77
78 listKeyMap := list.DefaultKeyMap()
79 listKeyMap.Down.SetEnabled(false)
80 listKeyMap.Up.SetEnabled(false)
81 listKeyMap.DownOneItem = keyMap.Next
82 listKeyMap.UpOneItem = keyMap.Previous
83
84 t := styles.CurrentTheme()
85 modelList := NewModelListComponent(ins.Config, listKeyMap, largeModelInputPlaceholder, true)
86 apiKeyInput := NewAPIKeyInput()
87 apiKeyInput.SetShowTitle(false)
88 help := help.New()
89 help.Styles = t.S().Help
90
91 return &modelDialogCmp{
92 modelList: modelList,
93 apiKeyInput: apiKeyInput,
94 width: defaultWidth,
95 keyMap: DefaultKeyMap(),
96 help: help,
97 ins: ins,
98 }
99}
100
101func (m *modelDialogCmp) Init() tea.Cmd {
102 return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
103}
104
105func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
106 switch msg := msg.(type) {
107 case tea.WindowSizeMsg:
108 m.wWidth = msg.Width
109 m.wHeight = msg.Height
110 m.apiKeyInput.SetWidth(m.width - 2)
111 m.help.Width = m.width - 2
112 return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
113 case APIKeyStateChangeMsg:
114 u, cmd := m.apiKeyInput.Update(msg)
115 m.apiKeyInput = u.(*APIKeyInput)
116 return m, cmd
117 case tea.KeyPressMsg:
118 switch {
119 case key.Matches(msg, m.keyMap.Select):
120 if m.isAPIKeyValid {
121 return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
122 }
123 if m.needsAPIKey {
124 // Handle API key submission
125 m.apiKeyValue = m.apiKeyInput.Value()
126 provider, err := m.getProvider(m.selectedModel.Provider.ID)
127 if err != nil || provider == nil {
128 return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
129 }
130 providerConfig := config.ProviderConfig{
131 ID: string(m.selectedModel.Provider.ID),
132 Name: m.selectedModel.Provider.Name,
133 APIKey: m.apiKeyValue,
134 Type: provider.Type,
135 BaseURL: provider.APIEndpoint,
136 }
137 return m, tea.Sequence(
138 util.CmdHandler(APIKeyStateChangeMsg{
139 State: APIKeyInputStateVerifying,
140 }),
141 func() tea.Msg {
142 start := time.Now()
143 err := providerConfig.TestConnection(m.ins.ShellResolver())
144 // intentionally wait for at least 750ms to make sure the user sees the spinner
145 elapsed := time.Since(start)
146 if elapsed < 750*time.Millisecond {
147 time.Sleep(750*time.Millisecond - elapsed)
148 }
149 if err == nil {
150 m.isAPIKeyValid = true
151 return APIKeyStateChangeMsg{
152 State: APIKeyInputStateVerified,
153 }
154 }
155 return APIKeyStateChangeMsg{
156 State: APIKeyInputStateError,
157 }
158 },
159 )
160 }
161 // Normal model selection
162 selectedItem := m.modelList.SelectedModel()
163
164 var modelType config.SelectedModelType
165 if m.modelList.GetModelType() == LargeModelType {
166 modelType = config.SelectedModelTypeLarge
167 } else {
168 modelType = config.SelectedModelTypeSmall
169 }
170
171 // Check if provider is configured
172 if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
173 return m, tea.Sequence(
174 util.CmdHandler(dialogs.CloseDialogMsg{}),
175 util.CmdHandler(ModelSelectedMsg{
176 Model: config.SelectedModel{
177 Model: selectedItem.Model.ID,
178 Provider: string(selectedItem.Provider.ID),
179 ReasoningEffort: selectedItem.Model.DefaultReasoningEffort,
180 MaxTokens: selectedItem.Model.DefaultMaxTokens,
181 },
182 ModelType: modelType,
183 }),
184 )
185 } else {
186 // Provider not configured, show API key input
187 m.needsAPIKey = true
188 m.selectedModel = selectedItem
189 m.selectedModelType = modelType
190 m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
191 return m, nil
192 }
193 case key.Matches(msg, m.keyMap.Tab):
194 if m.needsAPIKey {
195 u, cmd := m.apiKeyInput.Update(msg)
196 m.apiKeyInput = u.(*APIKeyInput)
197 return m, cmd
198 }
199 if m.modelList.GetModelType() == LargeModelType {
200 m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
201 return m, m.modelList.SetModelType(SmallModelType)
202 } else {
203 m.modelList.SetInputPlaceholder(largeModelInputPlaceholder)
204 return m, m.modelList.SetModelType(LargeModelType)
205 }
206 case key.Matches(msg, m.keyMap.Close):
207 if m.needsAPIKey {
208 if m.isAPIKeyValid {
209 return m, nil
210 }
211 // Go back to model selection
212 m.needsAPIKey = false
213 m.selectedModel = nil
214 m.isAPIKeyValid = false
215 m.apiKeyValue = ""
216 m.apiKeyInput.Reset()
217 return m, nil
218 }
219 return m, util.CmdHandler(dialogs.CloseDialogMsg{})
220 default:
221 if m.needsAPIKey {
222 u, cmd := m.apiKeyInput.Update(msg)
223 m.apiKeyInput = u.(*APIKeyInput)
224 return m, cmd
225 } else {
226 u, cmd := m.modelList.Update(msg)
227 m.modelList = u
228 return m, cmd
229 }
230 }
231 case tea.PasteMsg:
232 if m.needsAPIKey {
233 u, cmd := m.apiKeyInput.Update(msg)
234 m.apiKeyInput = u.(*APIKeyInput)
235 return m, cmd
236 } else {
237 var cmd tea.Cmd
238 m.modelList, cmd = m.modelList.Update(msg)
239 return m, cmd
240 }
241 case spinner.TickMsg:
242 u, cmd := m.apiKeyInput.Update(msg)
243 m.apiKeyInput = u.(*APIKeyInput)
244 return m, cmd
245 }
246 return m, nil
247}
248
249func (m *modelDialogCmp) View() string {
250 t := styles.CurrentTheme()
251
252 if m.needsAPIKey {
253 // Show API key input
254 m.keyMap.isAPIKeyHelp = true
255 m.keyMap.isAPIKeyValid = m.isAPIKeyValid
256 apiKeyView := m.apiKeyInput.View()
257 apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
258 content := lipgloss.JoinVertical(
259 lipgloss.Left,
260 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
261 apiKeyView,
262 "",
263 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
264 )
265 return m.style().Render(content)
266 }
267
268 // Show model selection
269 listView := m.modelList.View()
270 radio := m.modelTypeRadio()
271 content := lipgloss.JoinVertical(
272 lipgloss.Left,
273 t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Switch Model", m.width-lipgloss.Width(radio)-5)+" "+radio),
274 listView,
275 "",
276 t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
277 )
278 return m.style().Render(content)
279}
280
281func (m *modelDialogCmp) Cursor() *tea.Cursor {
282 if m.needsAPIKey {
283 cursor := m.apiKeyInput.Cursor()
284 if cursor != nil {
285 cursor = m.moveCursor(cursor)
286 return cursor
287 }
288 } else {
289 cursor := m.modelList.Cursor()
290 if cursor != nil {
291 cursor = m.moveCursor(cursor)
292 return cursor
293 }
294 }
295 return nil
296}
297
298func (m *modelDialogCmp) style() lipgloss.Style {
299 t := styles.CurrentTheme()
300 return t.S().Base.
301 Width(m.width).
302 Border(lipgloss.RoundedBorder()).
303 BorderForeground(t.BorderFocus)
304}
305
306func (m *modelDialogCmp) listWidth() int {
307 return m.width - 2
308}
309
310func (m *modelDialogCmp) listHeight() int {
311 return m.wHeight / 2
312}
313
314func (m *modelDialogCmp) Position() (int, int) {
315 row := m.wHeight/4 - 2 // just a bit above the center
316 col := m.wWidth / 2
317 col -= m.width / 2
318 return row, col
319}
320
321func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
322 row, col := m.Position()
323 if m.needsAPIKey {
324 offset := row + 3 // Border + title + API key input offset
325 cursor.Y += offset
326 cursor.X = cursor.X + col + 2
327 } else {
328 offset := row + 3 // Border + title
329 cursor.Y += offset
330 cursor.X = cursor.X + col + 2
331 }
332 return cursor
333}
334
335func (m *modelDialogCmp) ID() dialogs.DialogID {
336 return ModelsDialogID
337}
338
339func (m *modelDialogCmp) modelTypeRadio() string {
340 t := styles.CurrentTheme()
341 choices := []string{"Large Task", "Small Task"}
342 iconSelected := "◉"
343 iconUnselected := "○"
344 if m.modelList.GetModelType() == LargeModelType {
345 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1])
346 }
347 return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
348}
349
350func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
351 if _, ok := m.ins.Config.Providers.Get(providerID); ok {
352 return true
353 }
354 return false
355}
356
357func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
358 providers, err := config.Providers(m.ins.Config)
359 if err != nil {
360 return nil, err
361 }
362 for _, p := range providers {
363 if p.ID == providerID {
364 return &p, nil
365 }
366 }
367 return nil, nil
368}
369
370func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
371 if m.selectedModel == nil {
372 return util.ReportError(fmt.Errorf("no model selected"))
373 }
374
375 err := m.ins.Config.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
376 if err != nil {
377 return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
378 }
379
380 // Reset API key state and continue with model selection
381 selectedModel := *m.selectedModel
382 return tea.Sequence(
383 util.CmdHandler(dialogs.CloseDialogMsg{}),
384 util.CmdHandler(ModelSelectedMsg{
385 Model: config.SelectedModel{
386 Model: selectedModel.Model.ID,
387 Provider: string(selectedModel.Provider.ID),
388 ReasoningEffort: selectedModel.Model.DefaultReasoningEffort,
389 MaxTokens: selectedModel.Model.DefaultMaxTokens,
390 },
391 ModelType: m.selectedModelType,
392 }),
393 )
394}