api_key_input.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"strings"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"charm.land/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/ui/styles"
 17	"github.com/charmbracelet/crush/internal/ui/util"
 18	uv "github.com/charmbracelet/ultraviolet"
 19	"github.com/charmbracelet/x/exp/charmtone"
 20)
 21
 22type APIKeyInputState int
 23
 24const (
 25	APIKeyInputStateInitial APIKeyInputState = iota
 26	APIKeyInputStateVerifying
 27	APIKeyInputStateVerified
 28	APIKeyInputStateError
 29)
 30
 31// APIKeyInputID is the identifier for the model selection dialog.
 32const APIKeyInputID = "api_key_input"
 33
 34// APIKeyInput represents a model selection dialog.
 35type APIKeyInput struct {
 36	com          *common.Common
 37	isOnboarding bool
 38
 39	provider  catwalk.Provider
 40	model     config.SelectedModel
 41	modelType config.SelectedModelType
 42
 43	width int
 44	state APIKeyInputState
 45
 46	keyMap struct {
 47		Submit key.Binding
 48		Close  key.Binding
 49	}
 50	input   textinput.Model
 51	spinner spinner.Model
 52	help    help.Model
 53}
 54
 55var _ Dialog = (*APIKeyInput)(nil)
 56
 57// NewAPIKeyInput creates a new Models dialog.
 58func NewAPIKeyInput(
 59	com *common.Common,
 60	isOnboarding bool,
 61	provider catwalk.Provider,
 62	model config.SelectedModel,
 63	modelType config.SelectedModelType,
 64) (*APIKeyInput, tea.Cmd) {
 65	t := com.Styles
 66
 67	m := APIKeyInput{}
 68	m.com = com
 69	m.isOnboarding = isOnboarding
 70	m.provider = provider
 71	m.model = model
 72	m.modelType = modelType
 73	m.width = 60
 74
 75	innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
 76
 77	m.input = textinput.New()
 78	m.input.SetVirtualCursor(false)
 79	m.input.Placeholder = "Enter your API key..."
 80	m.input.SetStyles(com.Styles.TextInput)
 81	m.input.Focus()
 82	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
 83
 84	m.spinner = spinner.New(
 85		spinner.WithSpinner(spinner.Dot),
 86		spinner.WithStyle(t.Base.Foreground(t.Green)),
 87	)
 88
 89	m.help = help.New()
 90	m.help.Styles = t.DialogHelpStyles()
 91
 92	m.keyMap.Submit = key.NewBinding(
 93		key.WithKeys("enter", "ctrl+y"),
 94		key.WithHelp("enter", "submit"),
 95	)
 96	m.keyMap.Close = CloseKey
 97
 98	return &m, nil
 99}
100
101// ID implements Dialog.
102func (m *APIKeyInput) ID() string {
103	return APIKeyInputID
104}
105
106// HandleMsg implements [Dialog].
107func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
108	switch msg := msg.(type) {
109	case ActionChangeAPIKeyState:
110		m.state = msg.State
111		switch m.state {
112		case APIKeyInputStateVerifying:
113			cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
114			return ActionCmd{cmd}
115		}
116	case spinner.TickMsg:
117		switch m.state {
118		case APIKeyInputStateVerifying:
119			var cmd tea.Cmd
120			m.spinner, cmd = m.spinner.Update(msg)
121			if cmd != nil {
122				return ActionCmd{cmd}
123			}
124		}
125	case tea.KeyPressMsg:
126		switch {
127		case m.state == APIKeyInputStateVerifying:
128			// do nothing
129		case key.Matches(msg, m.keyMap.Close):
130			switch m.state {
131			case APIKeyInputStateVerified:
132				return m.saveKeyAndContinue()
133			default:
134				return ActionClose{}
135			}
136		case key.Matches(msg, m.keyMap.Submit):
137			switch m.state {
138			case APIKeyInputStateInitial, APIKeyInputStateError:
139				return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
140			case APIKeyInputStateVerified:
141				return m.saveKeyAndContinue()
142			}
143		default:
144			var cmd tea.Cmd
145			m.input, cmd = m.input.Update(msg)
146			if cmd != nil {
147				return ActionCmd{cmd}
148			}
149		}
150	case tea.PasteMsg:
151		var cmd tea.Cmd
152		m.input, cmd = m.input.Update(msg)
153		if cmd != nil {
154			return ActionCmd{cmd}
155		}
156	}
157	return nil
158}
159
160// Draw implements [Dialog].
161func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
162	t := m.com.Styles
163
164	textStyle := t.Dialog.SecondaryText
165	helpStyle := t.Dialog.HelpView
166	dialogStyle := t.Dialog.View.Width(m.width)
167	inputStyle := t.Dialog.InputPrompt
168	helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
169
170	m.input.Prompt = m.spinner.View()
171
172	content := strings.Join([]string{
173		m.headerView(),
174		inputStyle.Render(m.inputView()),
175		textStyle.Render("This will be written in your global configuration:"),
176		textStyle.Render(config.GlobalConfigData()),
177		"",
178		helpStyle.Render(m.help.View(m)),
179	}, "\n")
180
181	cur := m.Cursor()
182
183	if m.isOnboarding {
184		view := content
185		cur = adjustOnboardingInputCursor(t, cur)
186		DrawOnboardingCursor(scr, area, view, cur)
187	} else {
188		view := dialogStyle.Render(content)
189		DrawCenterCursor(scr, area, view, cur)
190	}
191	return cur
192}
193
194func (m *APIKeyInput) headerView() string {
195	var (
196		t           = m.com.Styles
197		titleStyle  = t.Dialog.Title
198		textStyle   = t.Dialog.PrimaryText
199		dialogStyle = t.Dialog.View.Width(m.width)
200	)
201	if m.isOnboarding {
202		return textStyle.Render(m.dialogTitle())
203	}
204	headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
205	return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Primary, m.com.Styles.Secondary)
206}
207
208func (m *APIKeyInput) dialogTitle() string {
209	var (
210		t           = m.com.Styles
211		textStyle   = t.Dialog.TitleText
212		errorStyle  = t.Dialog.TitleError
213		accentStyle = t.Dialog.TitleAccent
214	)
215	switch m.state {
216	case APIKeyInputStateInitial:
217		return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
218	case APIKeyInputStateVerifying:
219		return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
220	case APIKeyInputStateVerified:
221		return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
222	case APIKeyInputStateError:
223		return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
224	}
225	return ""
226}
227
228func (m *APIKeyInput) inputView() string {
229	t := m.com.Styles
230
231	switch m.state {
232	case APIKeyInputStateInitial:
233		m.input.Prompt = "> "
234		m.input.SetStyles(t.TextInput)
235		m.input.Focus()
236	case APIKeyInputStateVerifying:
237		ts := t.TextInput
238		ts.Blurred.Prompt = ts.Focused.Prompt
239
240		m.input.Prompt = m.spinner.View()
241		m.input.SetStyles(ts)
242		m.input.Blur()
243	case APIKeyInputStateVerified:
244		ts := t.TextInput
245		ts.Blurred.Prompt = ts.Focused.Prompt
246
247		m.input.Prompt = styles.CheckIcon + " "
248		m.input.SetStyles(ts)
249		m.input.Blur()
250	case APIKeyInputStateError:
251		ts := t.TextInput
252		ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
253
254		m.input.Prompt = styles.LSPErrorIcon + " "
255		m.input.SetStyles(ts)
256		m.input.Focus()
257	}
258	return m.input.View()
259}
260
261// Cursor returns the cursor position relative to the dialog.
262func (m *APIKeyInput) Cursor() *tea.Cursor {
263	return InputCursor(m.com.Styles, m.input.Cursor())
264}
265
266// FullHelp returns the full help view.
267func (m *APIKeyInput) FullHelp() [][]key.Binding {
268	return [][]key.Binding{
269		{
270			m.keyMap.Submit,
271			m.keyMap.Close,
272		},
273	}
274}
275
276// ShortHelp returns the full help view.
277func (m *APIKeyInput) ShortHelp() []key.Binding {
278	return []key.Binding{
279		m.keyMap.Submit,
280		m.keyMap.Close,
281	}
282}
283
284func (m *APIKeyInput) verifyAPIKey() tea.Msg {
285	start := time.Now()
286
287	providerConfig := config.ProviderConfig{
288		ID:      string(m.provider.ID),
289		Name:    m.provider.Name,
290		APIKey:  m.input.Value(),
291		Type:    m.provider.Type,
292		BaseURL: m.provider.APIEndpoint,
293	}
294	err := providerConfig.TestConnection(m.com.Workspace.Resolver())
295
296	// intentionally wait for at least 750ms to make sure the user sees the spinner
297	elapsed := time.Since(start)
298	minimum := 750 * time.Millisecond
299	if elapsed < minimum {
300		time.Sleep(minimum - elapsed)
301	}
302
303	if err == nil {
304		return ActionChangeAPIKeyState{APIKeyInputStateVerified}
305	}
306	return ActionChangeAPIKeyState{APIKeyInputStateError}
307}
308
309func (m *APIKeyInput) saveKeyAndContinue() Action {
310	err := m.com.Workspace.SetProviderAPIKey(config.ScopeGlobal, string(m.provider.ID), m.input.Value())
311	if err != nil {
312		return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
313	}
314
315	return ActionSelectModel{
316		Provider:  m.provider,
317		Model:     m.model,
318		ModelType: m.modelType,
319	}
320}