1package dialog
2
3import (
4 "fmt"
5 "strings"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/ui/styles"
17 "github.com/charmbracelet/crush/internal/uiutil"
18 uv "github.com/charmbracelet/ultraviolet"
19 "github.com/charmbracelet/x/exp/charmtone"
20)
21
// APIKeyInputState tracks where the API key dialog is in its
// enter → verify → result flow.
type APIKeyInputState int

const (
	// APIKeyInputStateInitial is the starting state, before any verification
	// attempt; the input is focused for typing.
	APIKeyInputStateInitial APIKeyInputState = iota
	// APIKeyInputStateVerifying means a connection test with the entered key
	// is in flight.
	APIKeyInputStateVerifying
	// APIKeyInputStateVerified means the connection test succeeded.
	APIKeyInputStateVerified
	// APIKeyInputStateError means the connection test failed.
	APIKeyInputStateError
)

// APIKeyInputID is the identifier for the API key input dialog.
const APIKeyInputID = "api_key_input"
33
34// APIKeyInput represents a model selection dialog.
35type APIKeyInput struct {
36 com *common.Common
37 isOnboarding bool
38
39 provider catwalk.Provider
40 model config.SelectedModel
41 modelType config.SelectedModelType
42
43 width int
44 state APIKeyInputState
45
46 keyMap struct {
47 Submit key.Binding
48 Close key.Binding
49 }
50 input textinput.Model
51 spinner spinner.Model
52 help help.Model
53}
54
55var _ Dialog = (*APIKeyInput)(nil)
56
57// NewAPIKeyInput creates a new Models dialog.
58func NewAPIKeyInput(
59 com *common.Common,
60 isOnboarding bool,
61 provider catwalk.Provider,
62 model config.SelectedModel,
63 modelType config.SelectedModelType,
64) (*APIKeyInput, tea.Cmd) {
65 t := com.Styles
66
67 m := APIKeyInput{}
68 m.com = com
69 m.isOnboarding = isOnboarding
70 m.provider = provider
71 m.model = model
72 m.modelType = modelType
73 m.width = 60
74
75 innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
76
77 m.input = textinput.New()
78 m.input.SetVirtualCursor(false)
79 m.input.Placeholder = "Enter you API key..."
80 m.input.SetStyles(com.Styles.TextInput)
81 m.input.Focus()
82 m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
83
84 m.spinner = spinner.New(
85 spinner.WithSpinner(spinner.Dot),
86 spinner.WithStyle(t.Base.Foreground(t.Green)),
87 )
88
89 m.help = help.New()
90 m.help.Styles = t.DialogHelpStyles()
91
92 m.keyMap.Submit = key.NewBinding(
93 key.WithKeys("enter", "ctrl+y"),
94 key.WithHelp("enter", "submit"),
95 )
96 m.keyMap.Close = CloseKey
97
98 return &m, nil
99}
100
101// ID implements Dialog.
102func (m *APIKeyInput) ID() string {
103 return APIKeyInputID
104}
105
106// HandleMsg implements [Dialog].
107func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
108 switch msg := msg.(type) {
109 case ActionChangeAPIKeyState:
110 m.state = msg.State
111 switch m.state {
112 case APIKeyInputStateVerifying:
113 cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
114 return ActionCmd{cmd}
115 }
116 case spinner.TickMsg:
117 switch m.state {
118 case APIKeyInputStateVerifying:
119 var cmd tea.Cmd
120 m.spinner, cmd = m.spinner.Update(msg)
121 if cmd != nil {
122 return ActionCmd{cmd}
123 }
124 }
125 case tea.KeyPressMsg:
126 switch {
127 case m.state == APIKeyInputStateVerifying:
128 // do nothing
129 case key.Matches(msg, m.keyMap.Close):
130 switch m.state {
131 case APIKeyInputStateVerified:
132 return m.saveKeyAndContinue()
133 default:
134 return ActionClose{}
135 }
136 case key.Matches(msg, m.keyMap.Submit):
137 switch m.state {
138 case APIKeyInputStateInitial, APIKeyInputStateError:
139 return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
140 case APIKeyInputStateVerified:
141 return m.saveKeyAndContinue()
142 }
143 default:
144 var cmd tea.Cmd
145 m.input, cmd = m.input.Update(msg)
146 if cmd != nil {
147 return ActionCmd{cmd}
148 }
149 }
150 case tea.PasteMsg:
151 var cmd tea.Cmd
152 m.input, cmd = m.input.Update(msg)
153 if cmd != nil {
154 return ActionCmd{cmd}
155 }
156 }
157 return nil
158}
159
160// Draw implements [Dialog].
161func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
162 t := m.com.Styles
163
164 textStyle := t.Dialog.SecondaryText
165 helpStyle := t.Dialog.HelpView
166 dialogStyle := t.Dialog.View.Width(m.width)
167 inputStyle := t.Dialog.InputPrompt
168 helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
169
170 m.input.Prompt = m.spinner.View()
171
172 content := strings.Join([]string{
173 m.headerView(),
174 inputStyle.Render(m.inputView()),
175 textStyle.Render("This will be written in your global configuration:"),
176 textStyle.Render(config.GlobalConfigData()),
177 "",
178 helpStyle.Render(m.help.View(m)),
179 }, "\n")
180
181 cur := m.Cursor()
182
183 if m.isOnboarding {
184 view := content
185 DrawOnboardingCursor(scr, area, view, cur)
186
187 // FIXME(@andreynering): Figure it out how to properly fix this
188 if cur != nil {
189 cur.Y -= 1
190 cur.X -= 1
191 }
192 } else {
193 view := dialogStyle.Render(content)
194 DrawCenterCursor(scr, area, view, cur)
195 }
196 return cur
197}
198
199func (m *APIKeyInput) headerView() string {
200 var (
201 t = m.com.Styles
202 titleStyle = t.Dialog.Title
203 textStyle = t.Dialog.PrimaryText
204 dialogStyle = t.Dialog.View.Width(m.width)
205 )
206 if m.isOnboarding {
207 return textStyle.Render(m.dialogTitle())
208 }
209 headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
210 return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Primary, m.com.Styles.Secondary)
211}
212
213func (m *APIKeyInput) dialogTitle() string {
214 var (
215 t = m.com.Styles
216 textStyle = t.Dialog.TitleText
217 errorStyle = t.Dialog.TitleError
218 accentStyle = t.Dialog.TitleAccent
219 )
220 switch m.state {
221 case APIKeyInputStateInitial:
222 return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
223 case APIKeyInputStateVerifying:
224 return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
225 case APIKeyInputStateVerified:
226 return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
227 case APIKeyInputStateError:
228 return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
229 }
230 return ""
231}
232
233func (m *APIKeyInput) inputView() string {
234 t := m.com.Styles
235
236 switch m.state {
237 case APIKeyInputStateInitial:
238 m.input.Prompt = "> "
239 m.input.SetStyles(t.TextInput)
240 m.input.Focus()
241 case APIKeyInputStateVerifying:
242 ts := t.TextInput
243 ts.Blurred.Prompt = ts.Focused.Prompt
244
245 m.input.Prompt = m.spinner.View()
246 m.input.SetStyles(ts)
247 m.input.Blur()
248 case APIKeyInputStateVerified:
249 ts := t.TextInput
250 ts.Blurred.Prompt = ts.Focused.Prompt
251
252 m.input.Prompt = styles.CheckIcon + " "
253 m.input.SetStyles(ts)
254 m.input.Blur()
255 case APIKeyInputStateError:
256 ts := t.TextInput
257 ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
258
259 m.input.Prompt = styles.ErrorIcon + " "
260 m.input.SetStyles(ts)
261 m.input.Focus()
262 }
263 return m.input.View()
264}
265
266// Cursor returns the cursor position relative to the dialog.
267func (m *APIKeyInput) Cursor() *tea.Cursor {
268 return InputCursor(m.com.Styles, m.input.Cursor())
269}
270
271// FullHelp returns the full help view.
272func (m *APIKeyInput) FullHelp() [][]key.Binding {
273 return [][]key.Binding{
274 {
275 m.keyMap.Submit,
276 m.keyMap.Close,
277 },
278 }
279}
280
281// ShortHelp returns the full help view.
282func (m *APIKeyInput) ShortHelp() []key.Binding {
283 return []key.Binding{
284 m.keyMap.Submit,
285 m.keyMap.Close,
286 }
287}
288
289func (m *APIKeyInput) verifyAPIKey() tea.Msg {
290 start := time.Now()
291
292 providerConfig := config.ProviderConfig{
293 ID: string(m.provider.ID),
294 Name: m.provider.Name,
295 APIKey: m.input.Value(),
296 Type: m.provider.Type,
297 BaseURL: m.provider.APIEndpoint,
298 }
299 err := providerConfig.TestConnection(config.Get().Resolver())
300
301 // intentionally wait for at least 750ms to make sure the user sees the spinner
302 elapsed := time.Since(start)
303 minimum := 750 * time.Millisecond
304 if elapsed < minimum {
305 time.Sleep(minimum - elapsed)
306 }
307
308 if err == nil {
309 return ActionChangeAPIKeyState{APIKeyInputStateVerified}
310 }
311 return ActionChangeAPIKeyState{APIKeyInputStateError}
312}
313
314func (m *APIKeyInput) saveKeyAndContinue() Action {
315 cfg := m.com.Config()
316
317 err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value())
318 if err != nil {
319 return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))}
320 }
321
322 return ActionSelectModel{
323 Provider: m.provider,
324 Model: m.model,
325 ModelType: m.modelType,
326 }
327}