1package dialog
2
3import (
4 "fmt"
5 "strings"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "charm.land/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/ui/styles"
17 "github.com/charmbracelet/crush/internal/ui/util"
18 uv "github.com/charmbracelet/ultraviolet"
19 "github.com/charmbracelet/x/exp/charmtone"
20)
21
// APIKeyInputState identifies which phase of API key entry the dialog is in.
type APIKeyInputState int

// Phases of the API key input dialog.
const (
	// APIKeyInputStateInitial is the zero value: the input is focused and
	// waiting for the user to enter a key.
	APIKeyInputStateInitial APIKeyInputState = iota
	// APIKeyInputStateVerifying means the entered key is being tested; the
	// spinner is shown and key presses are ignored.
	APIKeyInputStateVerifying
	// APIKeyInputStateVerified means the connection test returned no error.
	APIKeyInputStateVerified
	// APIKeyInputStateError means the connection test returned an error; the
	// input is focused again so the user can retry.
	APIKeyInputStateError
)
30
// APIKeyInputID is the identifier for the API key input dialog.
const APIKeyInputID = "api_key_input"
33
// APIKeyInput is a dialog that asks for a provider API key, verifies it, and
// saves it before handing the pending model selection back to the caller.
type APIKeyInput struct {
	com          *common.Common
	isOnboarding bool // selects the onboarding draw path instead of the centered dialog

	// provider is the provider whose key is being entered; model and
	// modelType are passed through unchanged in the resulting
	// ActionSelectModel.
	provider  catwalk.Provider
	model     config.SelectedModel
	modelType config.SelectedModelType

	width int              // total dialog width in cells
	state APIKeyInputState // current phase; drives title, prompt, and key handling

	keyMap struct {
		Submit key.Binding
		Close  key.Binding
	}
	input   textinput.Model
	spinner spinner.Model // shown as the input prompt while verifying
	help    help.Model
}
54
// Compile-time assertion that *APIKeyInput satisfies the Dialog interface.
var _ Dialog = (*APIKeyInput)(nil)
56
57// NewAPIKeyInput creates a new Models dialog.
58func NewAPIKeyInput(
59 com *common.Common,
60 isOnboarding bool,
61 provider catwalk.Provider,
62 model config.SelectedModel,
63 modelType config.SelectedModelType,
64) (*APIKeyInput, tea.Cmd) {
65 t := com.Styles
66
67 m := APIKeyInput{}
68 m.com = com
69 m.isOnboarding = isOnboarding
70 m.provider = provider
71 m.model = model
72 m.modelType = modelType
73 m.width = 60
74
75 innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
76
77 m.input = textinput.New()
78 m.input.SetVirtualCursor(false)
79 m.input.Placeholder = "Enter your API key..."
80 m.input.SetStyles(com.Styles.TextInput)
81 m.input.Focus()
82 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
83
84 m.spinner = spinner.New(
85 spinner.WithSpinner(spinner.Dot),
86 spinner.WithStyle(t.Base.Foreground(t.Green)),
87 )
88
89 m.help = help.New()
90 m.help.Styles = t.DialogHelpStyles()
91
92 m.keyMap.Submit = key.NewBinding(
93 key.WithKeys("enter", "ctrl+y"),
94 key.WithHelp("enter", "submit"),
95 )
96 m.keyMap.Close = CloseKey
97
98 return &m, nil
99}
100
// ID implements [Dialog]. It returns APIKeyInputID.
func (m *APIKeyInput) ID() string {
	return APIKeyInputID
}
105
106// HandleMsg implements [Dialog].
107func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
108 switch msg := msg.(type) {
109 case ActionChangeAPIKeyState:
110 m.state = msg.State
111 switch m.state {
112 case APIKeyInputStateVerifying:
113 cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
114 return ActionCmd{cmd}
115 }
116 case spinner.TickMsg:
117 switch m.state {
118 case APIKeyInputStateVerifying:
119 var cmd tea.Cmd
120 m.spinner, cmd = m.spinner.Update(msg)
121 if cmd != nil {
122 return ActionCmd{cmd}
123 }
124 }
125 case tea.KeyPressMsg:
126 switch {
127 case m.state == APIKeyInputStateVerifying:
128 // do nothing
129 case key.Matches(msg, m.keyMap.Close):
130 switch m.state {
131 case APIKeyInputStateVerified:
132 return m.saveKeyAndContinue()
133 default:
134 return ActionClose{}
135 }
136 case key.Matches(msg, m.keyMap.Submit):
137 switch m.state {
138 case APIKeyInputStateInitial, APIKeyInputStateError:
139 return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
140 case APIKeyInputStateVerified:
141 return m.saveKeyAndContinue()
142 }
143 default:
144 var cmd tea.Cmd
145 m.input, cmd = m.input.Update(msg)
146 if cmd != nil {
147 return ActionCmd{cmd}
148 }
149 }
150 case tea.PasteMsg:
151 var cmd tea.Cmd
152 m.input, cmd = m.input.Update(msg)
153 if cmd != nil {
154 return ActionCmd{cmd}
155 }
156 }
157 return nil
158}
159
160// Draw implements [Dialog].
161func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
162 t := m.com.Styles
163
164 textStyle := t.Dialog.SecondaryText
165 helpStyle := t.Dialog.HelpView
166 dialogStyle := t.Dialog.View.Width(m.width)
167 inputStyle := t.Dialog.InputPrompt
168 helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
169
170 m.input.Prompt = m.spinner.View()
171
172 content := strings.Join([]string{
173 m.headerView(),
174 inputStyle.Render(m.inputView()),
175 textStyle.Render("This will be written in your global configuration:"),
176 textStyle.Render(config.GlobalConfigData()),
177 "",
178 helpStyle.Render(m.help.View(m)),
179 }, "\n")
180
181 cur := m.Cursor()
182
183 if m.isOnboarding {
184 view := content
185 cur = adjustOnboardingInputCursor(t, cur)
186 DrawOnboardingCursor(scr, area, view, cur)
187 } else {
188 view := dialogStyle.Render(content)
189 DrawCenterCursor(scr, area, view, cur)
190 }
191 return cur
192}
193
194func (m *APIKeyInput) headerView() string {
195 var (
196 t = m.com.Styles
197 titleStyle = t.Dialog.Title
198 textStyle = t.Dialog.PrimaryText
199 dialogStyle = t.Dialog.View.Width(m.width)
200 )
201 if m.isOnboarding {
202 return textStyle.Render(m.dialogTitle())
203 }
204 headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
205 return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Primary, m.com.Styles.Secondary)
206}
207
208func (m *APIKeyInput) dialogTitle() string {
209 var (
210 t = m.com.Styles
211 textStyle = t.Dialog.TitleText
212 errorStyle = t.Dialog.TitleError
213 accentStyle = t.Dialog.TitleAccent
214 )
215 switch m.state {
216 case APIKeyInputStateInitial:
217 return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
218 case APIKeyInputStateVerifying:
219 return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
220 case APIKeyInputStateVerified:
221 return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
222 case APIKeyInputStateError:
223 return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
224 }
225 return ""
226}
227
228func (m *APIKeyInput) inputView() string {
229 t := m.com.Styles
230
231 switch m.state {
232 case APIKeyInputStateInitial:
233 m.input.Prompt = "> "
234 m.input.SetStyles(t.TextInput)
235 m.input.Focus()
236 case APIKeyInputStateVerifying:
237 ts := t.TextInput
238 ts.Blurred.Prompt = ts.Focused.Prompt
239
240 m.input.Prompt = m.spinner.View()
241 m.input.SetStyles(ts)
242 m.input.Blur()
243 case APIKeyInputStateVerified:
244 ts := t.TextInput
245 ts.Blurred.Prompt = ts.Focused.Prompt
246
247 m.input.Prompt = styles.CheckIcon + " "
248 m.input.SetStyles(ts)
249 m.input.Blur()
250 case APIKeyInputStateError:
251 ts := t.TextInput
252 ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
253
254 m.input.Prompt = styles.LSPErrorIcon + " "
255 m.input.SetStyles(ts)
256 m.input.Focus()
257 }
258 return m.input.View()
259}
260
261// Cursor returns the cursor position relative to the dialog.
262func (m *APIKeyInput) Cursor() *tea.Cursor {
263 return InputCursor(m.com.Styles, m.input.Cursor())
264}
265
266// FullHelp returns the full help view.
267func (m *APIKeyInput) FullHelp() [][]key.Binding {
268 return [][]key.Binding{
269 {
270 m.keyMap.Submit,
271 m.keyMap.Close,
272 },
273 }
274}
275
276// ShortHelp returns the full help view.
277func (m *APIKeyInput) ShortHelp() []key.Binding {
278 return []key.Binding{
279 m.keyMap.Submit,
280 m.keyMap.Close,
281 }
282}
283
284func (m *APIKeyInput) verifyAPIKey() tea.Msg {
285 start := time.Now()
286
287 providerConfig := config.ProviderConfig{
288 ID: string(m.provider.ID),
289 Name: m.provider.Name,
290 APIKey: m.input.Value(),
291 Type: m.provider.Type,
292 BaseURL: m.provider.APIEndpoint,
293 }
294 err := providerConfig.TestConnection(m.com.Workspace.Resolver())
295
296 // intentionally wait for at least 750ms to make sure the user sees the spinner
297 elapsed := time.Since(start)
298 minimum := 750 * time.Millisecond
299 if elapsed < minimum {
300 time.Sleep(minimum - elapsed)
301 }
302
303 if err == nil {
304 return ActionChangeAPIKeyState{APIKeyInputStateVerified}
305 }
306 return ActionChangeAPIKeyState{APIKeyInputStateError}
307}
308
309func (m *APIKeyInput) saveKeyAndContinue() Action {
310 err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value())
311 if err != nil {
312 return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
313 }
314
315 return ActionSelectModel{
316 Provider: m.provider,
317 Model: m.model,
318 ModelType: m.modelType,
319 }
320}