1package dialog
2
3import (
4 "errors"
5 "fmt"
6 "maps"
7 "strings"
8 "time"
9
10 "charm.land/bubbles/v2/help"
11 "charm.land/bubbles/v2/key"
12 "charm.land/bubbles/v2/spinner"
13 "charm.land/bubbles/v2/textinput"
14 tea "charm.land/bubbletea/v2"
15 "charm.land/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/ui/common"
18 "github.com/charmbracelet/crush/internal/ui/styles"
19 "github.com/charmbracelet/crush/internal/ui/util"
20 uv "github.com/charmbracelet/ultraviolet"
21 "github.com/charmbracelet/x/exp/charmtone"
22)
23
// APIKeyInputState enumerates the steps the API key dialog moves through
// while a key is entered, checked against the provider, and saved.
type APIKeyInputState int

// Values are assigned in declaration order starting at zero, so the zero
// value of APIKeyInputState is APIKeyInputStateInitial.
const (
	// APIKeyInputStateInitial is the starting state; the key can be edited
	// and submitted for verification.
	APIKeyInputStateInitial APIKeyInputState = iota
	// APIKeyInputStateVerifying means a verification probe is in flight.
	APIKeyInputStateVerifying
	// APIKeyInputStateVerified means the provider accepted the key.
	APIKeyInputStateVerified
	// APIKeyInputStateUnverified indicates the provider does not expose a
	// deterministic validation probe, so authentication could not be
	// proven; the key can still be saved.
	APIKeyInputStateUnverified
	// APIKeyInputStateError means verification failed; the key can be
	// edited and resubmitted.
	APIKeyInputStateError
)
36
// APIKeyInputID is the identifier of the API key input dialog, as reported
// by [APIKeyInput.ID].
const APIKeyInputID = "api_key_input"
39
40// APIKeyInput represents a model selection dialog.
41type APIKeyInput struct {
42 com *common.Common
43 isOnboarding bool
44
45 provider catwalk.Provider
46 model config.SelectedModel
47 modelType config.SelectedModelType
48
49 width int
50 state APIKeyInputState
51
52 keyMap struct {
53 Submit key.Binding
54 Close key.Binding
55 }
56 input textinput.Model
57 spinner spinner.Model
58 help help.Model
59}
60
61var _ Dialog = (*APIKeyInput)(nil)
62
63// NewAPIKeyInput creates a new Models dialog.
64func NewAPIKeyInput(
65 com *common.Common,
66 isOnboarding bool,
67 provider catwalk.Provider,
68 model config.SelectedModel,
69 modelType config.SelectedModelType,
70) (*APIKeyInput, tea.Cmd) {
71 t := com.Styles
72
73 m := APIKeyInput{}
74 m.com = com
75 m.isOnboarding = isOnboarding
76 m.provider = provider
77 m.model = model
78 m.modelType = modelType
79 m.width = 60
80
81 innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
82
83 m.input = textinput.New()
84 m.input.SetVirtualCursor(false)
85 m.input.Placeholder = "Enter your API key..."
86 m.input.SetStyles(com.Styles.TextInput)
87 m.input.Focus()
88 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
89
90 m.spinner = spinner.New(
91 spinner.WithSpinner(spinner.Dot),
92 spinner.WithStyle(t.Dialog.APIKey.Spinner),
93 )
94
95 m.help = help.New()
96 m.help.Styles = t.DialogHelpStyles()
97
98 m.keyMap.Submit = key.NewBinding(
99 key.WithKeys("enter", "ctrl+y"),
100 key.WithHelp("enter", "submit"),
101 )
102 m.keyMap.Close = CloseKey
103
104 return &m, nil
105}
106
107// ID implements Dialog.
108func (m *APIKeyInput) ID() string {
109 return APIKeyInputID
110}
111
112// HandleMsg implements [Dialog].
113func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
114 switch msg := msg.(type) {
115 case ActionChangeAPIKeyState:
116 m.state = msg.State
117 switch m.state {
118 case APIKeyInputStateVerifying:
119 cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
120 return ActionCmd{cmd}
121 }
122 case spinner.TickMsg:
123 switch m.state {
124 case APIKeyInputStateVerifying:
125 var cmd tea.Cmd
126 m.spinner, cmd = m.spinner.Update(msg)
127 if cmd != nil {
128 return ActionCmd{cmd}
129 }
130 }
131 case tea.KeyPressMsg:
132 switch {
133 case m.state == APIKeyInputStateVerifying:
134 // do nothing
135 case key.Matches(msg, m.keyMap.Close):
136 switch m.state {
137 case APIKeyInputStateVerified, APIKeyInputStateUnverified:
138 return m.saveKeyAndContinue()
139 default:
140 return ActionClose{}
141 }
142 case key.Matches(msg, m.keyMap.Submit):
143 switch m.state {
144 case APIKeyInputStateInitial, APIKeyInputStateError:
145 return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
146 case APIKeyInputStateVerified, APIKeyInputStateUnverified:
147 return m.saveKeyAndContinue()
148 }
149 default:
150 var cmd tea.Cmd
151 m.input, cmd = m.input.Update(msg)
152 if cmd != nil {
153 return ActionCmd{cmd}
154 }
155 }
156 case tea.PasteMsg:
157 var cmd tea.Cmd
158 m.input, cmd = m.input.Update(msg)
159 if cmd != nil {
160 return ActionCmd{cmd}
161 }
162 }
163 return nil
164}
165
166// Draw implements [Dialog].
167func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
168 t := m.com.Styles
169
170 textStyle := t.Dialog.SecondaryText
171 helpStyle := t.Dialog.HelpView
172 dialogStyle := t.Dialog.View.Width(m.width)
173 inputStyle := t.Dialog.InputPrompt
174 helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
175
176 m.input.Prompt = m.spinner.View()
177
178 content := strings.Join([]string{
179 m.headerView(),
180 inputStyle.Render(m.inputView()),
181 textStyle.Render("This will be written in your global configuration:"),
182 textStyle.Render(config.GlobalConfigData()),
183 "",
184 helpStyle.Render(m.help.View(m)),
185 }, "\n")
186
187 cur := m.Cursor()
188
189 if m.isOnboarding {
190 view := content
191 cur = adjustOnboardingInputCursor(t, cur)
192 DrawOnboardingCursor(scr, area, view, cur)
193 } else {
194 view := dialogStyle.Render(content)
195 DrawCenterCursor(scr, area, view, cur)
196 }
197 return cur
198}
199
200func (m *APIKeyInput) headerView() string {
201 var (
202 t = m.com.Styles
203 titleStyle = t.Dialog.Title
204 textStyle = t.Dialog.PrimaryText
205 dialogStyle = t.Dialog.View.Width(m.width)
206 )
207 if m.isOnboarding {
208 return textStyle.Render(m.dialogTitle())
209 }
210 headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
211 return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Dialog.TitleGradFromColor, m.com.Styles.Dialog.TitleGradToColor)
212}
213
214func (m *APIKeyInput) dialogTitle() string {
215 var (
216 t = m.com.Styles
217 textStyle = t.Dialog.TitleText
218 errorStyle = t.Dialog.TitleError
219 accentStyle = t.Dialog.TitleAccent
220 )
221 switch m.state {
222 case APIKeyInputStateInitial:
223 return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
224 case APIKeyInputStateVerifying:
225 return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
226 case APIKeyInputStateVerified:
227 return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
228 case APIKeyInputStateUnverified:
229 return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" saved (not verified).")
230 case APIKeyInputStateError:
231 return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
232 }
233 return ""
234}
235
236func (m *APIKeyInput) inputView() string {
237 t := m.com.Styles
238
239 switch m.state {
240 case APIKeyInputStateInitial:
241 m.input.Prompt = "> "
242 m.input.SetStyles(t.TextInput)
243 m.input.Focus()
244 case APIKeyInputStateVerifying:
245 ts := t.TextInput
246 ts.Blurred.Prompt = ts.Focused.Prompt
247
248 m.input.Prompt = m.spinner.View()
249 m.input.SetStyles(ts)
250 m.input.Blur()
251 case APIKeyInputStateVerified, APIKeyInputStateUnverified:
252 ts := t.TextInput
253 ts.Blurred.Prompt = ts.Focused.Prompt
254
255 m.input.Prompt = styles.CheckIcon + " "
256 m.input.SetStyles(ts)
257 m.input.Blur()
258 case APIKeyInputStateError:
259 ts := t.TextInput
260 ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
261
262 m.input.Prompt = styles.LSPErrorIcon + " "
263 m.input.SetStyles(ts)
264 m.input.Focus()
265 }
266 return m.input.View()
267}
268
269// Cursor returns the cursor position relative to the dialog.
270func (m *APIKeyInput) Cursor() *tea.Cursor {
271 return InputCursor(m.com.Styles, m.input.Cursor())
272}
273
274// FullHelp returns the full help view.
275func (m *APIKeyInput) FullHelp() [][]key.Binding {
276 return [][]key.Binding{
277 {
278 m.keyMap.Submit,
279 m.keyMap.Close,
280 },
281 }
282}
283
284// ShortHelp returns the full help view.
285func (m *APIKeyInput) ShortHelp() []key.Binding {
286 return []key.Binding{
287 m.keyMap.Submit,
288 m.keyMap.Close,
289 }
290}
291
292func (m *APIKeyInput) verifyAPIKey() tea.Msg {
293 start := time.Now()
294
295 providerConfig := providerConfigForVerify(m.provider, m.input.Value())
296 err := providerConfig.TestConnection(m.com.Workspace.Resolver())
297
298 // intentionally wait for at least 750ms to make sure the user sees the spinner
299 elapsed := time.Since(start)
300 minimum := 750 * time.Millisecond
301 if elapsed < minimum {
302 time.Sleep(minimum - elapsed)
303 }
304
305 return ActionChangeAPIKeyState{State: apiKeyStateForVerifyErr(err)}
306}
307
308// providerConfigForVerify builds the [config.ProviderConfig] used to probe a
309// provider's API key from the dialog. In particular it propagates the
310// provider's [catwalk.Provider.DefaultHeaders] into [config.ProviderConfig.ExtraHeaders]
311// so validation probes carry any headers the provider expects (e.g. routing
312// or tenant hints) — matching the behaviour used for real inference traffic.
313func providerConfigForVerify(provider catwalk.Provider, apiKey string) config.ProviderConfig {
314 cfg := config.ProviderConfig{
315 ID: string(provider.ID),
316 Name: provider.Name,
317 APIKey: apiKey,
318 Type: provider.Type,
319 BaseURL: provider.APIEndpoint,
320 }
321 if len(provider.DefaultHeaders) > 0 {
322 cfg.ExtraHeaders = make(map[string]string, len(provider.DefaultHeaders))
323 maps.Copy(cfg.ExtraHeaders, provider.DefaultHeaders)
324 }
325 return cfg
326}
327
// apiKeyStateForVerifyErr maps the error from
// [config.ProviderConfig.TestConnection] to the next dialog state:
//
//   - nil gives Verified.
//   - An error wrapping [config.ErrValidationUnsupported] gives Unverified.
//   - Any other error gives Error.
//
// It is kept separate so the mapping can be unit-tested without a full
// [common.Common].
func apiKeyStateForVerifyErr(err error) APIKeyInputState {
	if err == nil {
		return APIKeyInputStateVerified
	}
	if errors.Is(err, config.ErrValidationUnsupported) {
		return APIKeyInputStateUnverified
	}
	return APIKeyInputStateError
}
341
// saveKeyAndContinue writes the entered key for the provider in the global
// config scope.
//
// On success it returns ActionSelectModel for the provider, model and model
// type the dialog was created with. On failure it returns a command that
// reports the wrapped error.
func (m *APIKeyInput) saveKeyAndContinue() Action {
	providerID := string(m.provider.ID)
	if err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, providerID, m.input.Value()); err != nil {
		return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
	}

	return ActionSelectModel{Provider: m.provider, Model: m.model, ModelType: m.modelType}
}