1package dialog
2
3import (
4 "fmt"
5 "strings"
6 "time"
7
8 "charm.land/bubbles/v2/help"
9 "charm.land/bubbles/v2/key"
10 "charm.land/bubbles/v2/spinner"
11 "charm.land/bubbles/v2/textinput"
12 tea "charm.land/bubbletea/v2"
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/ui/common"
16 "github.com/charmbracelet/crush/internal/ui/styles"
17 "github.com/charmbracelet/crush/internal/uiutil"
18 uv "github.com/charmbracelet/ultraviolet"
19 "github.com/charmbracelet/x/exp/charmtone"
20)
21
// APIKeyInputState describes where the API key dialog is in its
// enter → verify → save flow.
type APIKeyInputState int

const (
	// APIKeyInputStateInitial is the starting state: the user is typing a key.
	APIKeyInputStateInitial APIKeyInputState = iota
	// APIKeyInputStateVerifying means the entered key is being tested against
	// the provider; key presses are ignored until a result arrives.
	APIKeyInputStateVerifying
	// APIKeyInputStateVerified means the provider connection test succeeded.
	APIKeyInputStateVerified
	// APIKeyInputStateError means the connection test failed; submitting
	// again retries verification.
	APIKeyInputStateError
)
30
// APIKeyInputID is the identifier for the API key input dialog.
const APIKeyInputID = "api_key_input"
33
34// APIKeyInput represents a model selection dialog.
35type APIKeyInput struct {
36 com *common.Common
37
38 provider catwalk.Provider
39 model config.SelectedModel
40 modelType config.SelectedModelType
41
42 width int
43 state APIKeyInputState
44
45 keyMap struct {
46 Submit key.Binding
47 Close key.Binding
48 }
49 input textinput.Model
50 spinner spinner.Model
51 help help.Model
52}
53
54var _ Dialog = (*APIKeyInput)(nil)
55
56// NewAPIKeyInput creates a new Models dialog.
57func NewAPIKeyInput(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*APIKeyInput, error) {
58 t := com.Styles
59
60 m := APIKeyInput{}
61 m.com = com
62 m.provider = provider
63 m.model = model
64 m.modelType = modelType
65 m.width = 60
66
67 innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
68
69 m.input = textinput.New()
70 m.input.SetVirtualCursor(false)
71 m.input.Placeholder = "Enter you API key..."
72 m.input.SetStyles(com.Styles.TextInput)
73 m.input.Focus()
74 m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
75
76 m.spinner = spinner.New(
77 spinner.WithSpinner(spinner.Dot),
78 spinner.WithStyle(t.Base.Foreground(t.Green)),
79 )
80
81 m.help = help.New()
82 m.help.Styles = t.DialogHelpStyles()
83
84 m.keyMap.Submit = key.NewBinding(
85 key.WithKeys("enter", "ctrl+y"),
86 key.WithHelp("enter", "submit"),
87 )
88 m.keyMap.Close = CloseKey
89
90 return &m, nil
91}
92
93// ID implements Dialog.
94func (m *APIKeyInput) ID() string {
95 return APIKeyInputID
96}
97
98// HandleMsg implements [Dialog].
99func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
100 switch msg := msg.(type) {
101 case ActionChangeAPIKeyState:
102 m.state = msg.State
103 switch m.state {
104 case APIKeyInputStateVerifying:
105 cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
106 return ActionCmd{cmd}
107 }
108 case spinner.TickMsg:
109 switch m.state {
110 case APIKeyInputStateVerifying:
111 var cmd tea.Cmd
112 m.spinner, cmd = m.spinner.Update(msg)
113 if cmd != nil {
114 return ActionCmd{cmd}
115 }
116 }
117 case tea.KeyPressMsg:
118 switch {
119 case m.state == APIKeyInputStateVerifying:
120 // do nothing
121 case key.Matches(msg, m.keyMap.Close):
122 switch m.state {
123 case APIKeyInputStateVerified:
124 return m.saveKeyAndContinue()
125 default:
126 return ActionClose{}
127 }
128 case key.Matches(msg, m.keyMap.Submit):
129 switch m.state {
130 case APIKeyInputStateInitial, APIKeyInputStateError:
131 return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
132 case APIKeyInputStateVerified:
133 return m.saveKeyAndContinue()
134 }
135 default:
136 var cmd tea.Cmd
137 m.input, cmd = m.input.Update(msg)
138 if cmd != nil {
139 return ActionCmd{cmd}
140 }
141 }
142 case tea.PasteMsg:
143 var cmd tea.Cmd
144 m.input, cmd = m.input.Update(msg)
145 if cmd != nil {
146 return ActionCmd{cmd}
147 }
148 }
149 return nil
150}
151
// Draw implements [Dialog]. It renders the title, the key input, a note
// about where the key will be written, and the help line, centers the
// result in area, and returns the cursor computed by Cursor.
func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
	t := m.com.Styles

	textStyle := t.Dialog.SecondaryText
	dialogStyle := t.Dialog.View.Width(m.width)
	inputStyle := t.Dialog.InputPrompt
	helpStyle := t.Dialog.HelpView.Width(m.width - dialogStyle.GetHorizontalFrameSize())

	// The input prompt (including the spinner while verifying) is chosen by
	// inputView for every state, so it is not assigned here.
	content := strings.Join([]string{
		m.headerView(),
		inputStyle.Render(m.inputView()),
		textStyle.Render("This will be written in your global configuration:"),
		textStyle.Render(config.GlobalConfigData()),
		"",
		helpStyle.Render(m.help.View(m)),
	}, "\n")

	view := dialogStyle.Render(content)

	// Cursor must be computed after inputView has applied focus/blur for the
	// current state.
	cur := m.Cursor()
	DrawCenterCursor(scr, area, view, cur)
	return cur
}
179
180func (m *APIKeyInput) headerView() string {
181 t := m.com.Styles
182 titleStyle := t.Dialog.Title
183 dialogStyle := t.Dialog.View.Width(m.width)
184
185 headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
186 return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset)
187}
188
// dialogTitle returns the styled title text for the current state, always
// highlighting the "<provider> Key" label. Unknown states yield "".
func (m *APIKeyInput) dialogTitle() string {
	d := m.com.Styles
	plain, failed := d.Dialog.TitleText, d.Dialog.TitleError
	label := d.Dialog.TitleAccent.Render(fmt.Sprintf("%s Key", m.provider.Name))

	switch m.state {
	case APIKeyInputStateInitial:
		return plain.Render("Enter your ") + label + plain.Render(".")
	case APIKeyInputStateVerifying:
		return plain.Render("Verifying your ") + label + plain.Render("...")
	case APIKeyInputStateVerified:
		return label + plain.Render(" validated.")
	case APIKeyInputStateError:
		return failed.Render("Invalid ") + label + failed.Render(". Try again?")
	default:
		return ""
	}
}
207
208func (m *APIKeyInput) inputView() string {
209 t := m.com.Styles
210
211 switch m.state {
212 case APIKeyInputStateInitial:
213 m.input.Prompt = "> "
214 m.input.SetStyles(t.TextInput)
215 m.input.Focus()
216 case APIKeyInputStateVerifying:
217 ts := t.TextInput
218 ts.Blurred.Prompt = ts.Focused.Prompt
219
220 m.input.Prompt = m.spinner.View()
221 m.input.SetStyles(ts)
222 m.input.Blur()
223 case APIKeyInputStateVerified:
224 ts := t.TextInput
225 ts.Blurred.Prompt = ts.Focused.Prompt
226
227 m.input.Prompt = styles.CheckIcon + " "
228 m.input.SetStyles(ts)
229 m.input.Blur()
230 case APIKeyInputStateError:
231 ts := t.TextInput
232 ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
233
234 m.input.Prompt = styles.ErrorIcon + " "
235 m.input.SetStyles(ts)
236 m.input.Focus()
237 }
238 return m.input.View()
239}
240
241// Cursor returns the cursor position relative to the dialog.
242func (m *APIKeyInput) Cursor() *tea.Cursor {
243 return InputCursor(m.com.Styles, m.input.Cursor())
244}
245
246// FullHelp returns the full help view.
247func (m *APIKeyInput) FullHelp() [][]key.Binding {
248 return [][]key.Binding{
249 {
250 m.keyMap.Submit,
251 m.keyMap.Close,
252 },
253 }
254}
255
256// ShortHelp returns the full help view.
257func (m *APIKeyInput) ShortHelp() []key.Binding {
258 return []key.Binding{
259 m.keyMap.Submit,
260 m.keyMap.Close,
261 }
262}
263
264func (m *APIKeyInput) verifyAPIKey() tea.Msg {
265 start := time.Now()
266
267 providerConfig := config.ProviderConfig{
268 ID: string(m.provider.ID),
269 Name: m.provider.Name,
270 APIKey: m.input.Value(),
271 Type: m.provider.Type,
272 BaseURL: m.provider.APIEndpoint,
273 }
274 err := providerConfig.TestConnection(config.Get().Resolver())
275
276 // intentionally wait for at least 750ms to make sure the user sees the spinner
277 elapsed := time.Since(start)
278 minimum := 750 * time.Millisecond
279 if elapsed < minimum {
280 time.Sleep(minimum - elapsed)
281 }
282
283 if err == nil {
284 return ActionChangeAPIKeyState{APIKeyInputStateVerified}
285 }
286 return ActionChangeAPIKeyState{APIKeyInputStateError}
287}
288
// saveKeyAndContinue stores the entered key for the provider in the config.
// On success it resumes the pending model selection; on failure it returns a
// command that reports the error.
func (m *APIKeyInput) saveKeyAndContinue() Action {
	if err := m.com.Config().SetProviderAPIKey(string(m.provider.ID), m.input.Value()); err != nil {
		return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))}
	}

	return ActionSelectModel{
		ModelType: m.modelType,
		Provider:  m.provider,
		Model:     m.model,
	}
}