api_key_input.go

  1package dialog
  2
  3import (
  4	"errors"
  5	"fmt"
  6	"maps"
  7	"strings"
  8	"time"
  9
 10	"charm.land/bubbles/v2/help"
 11	"charm.land/bubbles/v2/key"
 12	"charm.land/bubbles/v2/spinner"
 13	"charm.land/bubbles/v2/textinput"
 14	tea "charm.land/bubbletea/v2"
 15	"charm.land/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/ui/common"
 18	"github.com/charmbracelet/crush/internal/ui/styles"
 19	"github.com/charmbracelet/crush/internal/ui/util"
 20	uv "github.com/charmbracelet/ultraviolet"
 21	"github.com/charmbracelet/x/exp/charmtone"
 22)
 23
 24type APIKeyInputState int
 25
 26const (
 27	APIKeyInputStateInitial APIKeyInputState = iota
 28	APIKeyInputStateVerifying
 29	APIKeyInputStateVerified
 30	// APIKeyInputStateUnverified indicates the key was saved but the
 31	// provider does not expose a deterministic validation probe, so
 32	// authentication could not be proven.
 33	APIKeyInputStateUnverified
 34	APIKeyInputStateError
 35)
 36
 37// APIKeyInputID is the identifier for the model selection dialog.
 38const APIKeyInputID = "api_key_input"
 39
 40// APIKeyInput represents a model selection dialog.
 41type APIKeyInput struct {
 42	com          *common.Common
 43	isOnboarding bool
 44
 45	provider  catwalk.Provider
 46	model     config.SelectedModel
 47	modelType config.SelectedModelType
 48
 49	width int
 50	state APIKeyInputState
 51
 52	keyMap struct {
 53		Submit key.Binding
 54		Close  key.Binding
 55	}
 56	input   textinput.Model
 57	spinner spinner.Model
 58	help    help.Model
 59}
 60
 61var _ Dialog = (*APIKeyInput)(nil)
 62
 63// NewAPIKeyInput creates a new Models dialog.
 64func NewAPIKeyInput(
 65	com *common.Common,
 66	isOnboarding bool,
 67	provider catwalk.Provider,
 68	model config.SelectedModel,
 69	modelType config.SelectedModelType,
 70) (*APIKeyInput, tea.Cmd) {
 71	t := com.Styles
 72
 73	m := APIKeyInput{}
 74	m.com = com
 75	m.isOnboarding = isOnboarding
 76	m.provider = provider
 77	m.model = model
 78	m.modelType = modelType
 79	m.width = 60
 80
 81	innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
 82
 83	m.input = textinput.New()
 84	m.input.SetVirtualCursor(false)
 85	m.input.Placeholder = "Enter your API key..."
 86	m.input.SetStyles(com.Styles.TextInput)
 87	m.input.Focus()
 88	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
 89
 90	m.spinner = spinner.New(
 91		spinner.WithSpinner(spinner.Dot),
 92		spinner.WithStyle(t.Dialog.APIKey.Spinner),
 93	)
 94
 95	m.help = help.New()
 96	m.help.Styles = t.DialogHelpStyles()
 97
 98	m.keyMap.Submit = key.NewBinding(
 99		key.WithKeys("enter", "ctrl+y"),
100		key.WithHelp("enter", "submit"),
101	)
102	m.keyMap.Close = CloseKey
103
104	return &m, nil
105}
106
107// ID implements Dialog.
108func (m *APIKeyInput) ID() string {
109	return APIKeyInputID
110}
111
112// HandleMsg implements [Dialog].
113func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
114	switch msg := msg.(type) {
115	case ActionChangeAPIKeyState:
116		m.state = msg.State
117		switch m.state {
118		case APIKeyInputStateVerifying:
119			cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
120			return ActionCmd{cmd}
121		}
122	case spinner.TickMsg:
123		switch m.state {
124		case APIKeyInputStateVerifying:
125			var cmd tea.Cmd
126			m.spinner, cmd = m.spinner.Update(msg)
127			if cmd != nil {
128				return ActionCmd{cmd}
129			}
130		}
131	case tea.KeyPressMsg:
132		switch {
133		case m.state == APIKeyInputStateVerifying:
134			// do nothing
135		case key.Matches(msg, m.keyMap.Close):
136			switch m.state {
137			case APIKeyInputStateVerified, APIKeyInputStateUnverified:
138				return m.saveKeyAndContinue()
139			default:
140				return ActionClose{}
141			}
142		case key.Matches(msg, m.keyMap.Submit):
143			switch m.state {
144			case APIKeyInputStateInitial, APIKeyInputStateError:
145				return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
146			case APIKeyInputStateVerified, APIKeyInputStateUnverified:
147				return m.saveKeyAndContinue()
148			}
149		default:
150			var cmd tea.Cmd
151			m.input, cmd = m.input.Update(msg)
152			if cmd != nil {
153				return ActionCmd{cmd}
154			}
155		}
156	case tea.PasteMsg:
157		var cmd tea.Cmd
158		m.input, cmd = m.input.Update(msg)
159		if cmd != nil {
160			return ActionCmd{cmd}
161		}
162	}
163	return nil
164}
165
166// Draw implements [Dialog].
167func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
168	t := m.com.Styles
169
170	textStyle := t.Dialog.SecondaryText
171	helpStyle := t.Dialog.HelpView
172	dialogStyle := t.Dialog.View.Width(m.width)
173	inputStyle := t.Dialog.InputPrompt
174	helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
175
176	m.input.Prompt = m.spinner.View()
177
178	content := strings.Join([]string{
179		m.headerView(),
180		inputStyle.Render(m.inputView()),
181		textStyle.Render("This will be written in your global configuration:"),
182		textStyle.Render(config.GlobalConfigData()),
183		"",
184		helpStyle.Render(m.help.View(m)),
185	}, "\n")
186
187	cur := m.Cursor()
188
189	if m.isOnboarding {
190		view := content
191		cur = adjustOnboardingInputCursor(t, cur)
192		DrawOnboardingCursor(scr, area, view, cur)
193	} else {
194		view := dialogStyle.Render(content)
195		DrawCenterCursor(scr, area, view, cur)
196	}
197	return cur
198}
199
200func (m *APIKeyInput) headerView() string {
201	var (
202		t           = m.com.Styles
203		titleStyle  = t.Dialog.Title
204		textStyle   = t.Dialog.PrimaryText
205		dialogStyle = t.Dialog.View.Width(m.width)
206	)
207	if m.isOnboarding {
208		return textStyle.Render(m.dialogTitle())
209	}
210	headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
211	return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Dialog.TitleGradFromColor, m.com.Styles.Dialog.TitleGradToColor)
212}
213
214func (m *APIKeyInput) dialogTitle() string {
215	var (
216		t           = m.com.Styles
217		textStyle   = t.Dialog.TitleText
218		errorStyle  = t.Dialog.TitleError
219		accentStyle = t.Dialog.TitleAccent
220	)
221	switch m.state {
222	case APIKeyInputStateInitial:
223		return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
224	case APIKeyInputStateVerifying:
225		return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
226	case APIKeyInputStateVerified:
227		return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
228	case APIKeyInputStateUnverified:
229		return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" saved (not verified).")
230	case APIKeyInputStateError:
231		return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
232	}
233	return ""
234}
235
236func (m *APIKeyInput) inputView() string {
237	t := m.com.Styles
238
239	switch m.state {
240	case APIKeyInputStateInitial:
241		m.input.Prompt = "> "
242		m.input.SetStyles(t.TextInput)
243		m.input.Focus()
244	case APIKeyInputStateVerifying:
245		ts := t.TextInput
246		ts.Blurred.Prompt = ts.Focused.Prompt
247
248		m.input.Prompt = m.spinner.View()
249		m.input.SetStyles(ts)
250		m.input.Blur()
251	case APIKeyInputStateVerified, APIKeyInputStateUnverified:
252		ts := t.TextInput
253		ts.Blurred.Prompt = ts.Focused.Prompt
254
255		m.input.Prompt = styles.CheckIcon + " "
256		m.input.SetStyles(ts)
257		m.input.Blur()
258	case APIKeyInputStateError:
259		ts := t.TextInput
260		ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
261
262		m.input.Prompt = styles.LSPErrorIcon + " "
263		m.input.SetStyles(ts)
264		m.input.Focus()
265	}
266	return m.input.View()
267}
268
269// Cursor returns the cursor position relative to the dialog.
270func (m *APIKeyInput) Cursor() *tea.Cursor {
271	return InputCursor(m.com.Styles, m.input.Cursor())
272}
273
274// FullHelp returns the full help view.
275func (m *APIKeyInput) FullHelp() [][]key.Binding {
276	return [][]key.Binding{
277		{
278			m.keyMap.Submit,
279			m.keyMap.Close,
280		},
281	}
282}
283
284// ShortHelp returns the full help view.
285func (m *APIKeyInput) ShortHelp() []key.Binding {
286	return []key.Binding{
287		m.keyMap.Submit,
288		m.keyMap.Close,
289	}
290}
291
292func (m *APIKeyInput) verifyAPIKey() tea.Msg {
293	start := time.Now()
294
295	providerConfig := providerConfigForVerify(m.provider, m.input.Value())
296	err := providerConfig.TestConnection(m.com.Workspace.Resolver())
297
298	// intentionally wait for at least 750ms to make sure the user sees the spinner
299	elapsed := time.Since(start)
300	minimum := 750 * time.Millisecond
301	if elapsed < minimum {
302		time.Sleep(minimum - elapsed)
303	}
304
305	return ActionChangeAPIKeyState{State: apiKeyStateForVerifyErr(err)}
306}
307
308// providerConfigForVerify builds the [config.ProviderConfig] used to probe a
309// provider's API key from the dialog. In particular it propagates the
310// provider's [catwalk.Provider.DefaultHeaders] into [config.ProviderConfig.ExtraHeaders]
311// so validation probes carry any headers the provider expects (e.g. routing
312// or tenant hints) — matching the behaviour used for real inference traffic.
313func providerConfigForVerify(provider catwalk.Provider, apiKey string) config.ProviderConfig {
314	cfg := config.ProviderConfig{
315		ID:      string(provider.ID),
316		Name:    provider.Name,
317		APIKey:  apiKey,
318		Type:    provider.Type,
319		BaseURL: provider.APIEndpoint,
320	}
321	if len(provider.DefaultHeaders) > 0 {
322		cfg.ExtraHeaders = make(map[string]string, len(provider.DefaultHeaders))
323		maps.Copy(cfg.ExtraHeaders, provider.DefaultHeaders)
324	}
325	return cfg
326}
327
328// apiKeyStateForVerifyErr maps a [config.ProviderConfig.TestConnection] error
329// to the dialog state the UI should transition into. Extracted so the
330// mapping can be unit-tested without spinning up a full [common.Common].
331func apiKeyStateForVerifyErr(err error) APIKeyInputState {
332	switch {
333	case err == nil:
334		return APIKeyInputStateVerified
335	case errors.Is(err, config.ErrValidationUnsupported):
336		return APIKeyInputStateUnverified
337	default:
338		return APIKeyInputStateError
339	}
340}
341
342func (m *APIKeyInput) saveKeyAndContinue() Action {
343	err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value())
344	if err != nil {
345		return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
346	}
347
348	return ActionSelectModel{
349		Provider:  m.provider,
350		Model:     m.model,
351		ModelType: m.modelType,
352	}
353}