api_key_input.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"strings"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"charm.land/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/ui/styles"
 17	"github.com/charmbracelet/crush/internal/ui/util"
 18	uv "github.com/charmbracelet/ultraviolet"
 19	"github.com/charmbracelet/x/exp/charmtone"
 20)
 21
 22type APIKeyInputState int
 23
 24const (
 25	APIKeyInputStateInitial APIKeyInputState = iota
 26	APIKeyInputStateVerifying
 27	APIKeyInputStateVerified
 28	APIKeyInputStateError
 29)
 30
 31// APIKeyInputID is the identifier for the model selection dialog.
 32const APIKeyInputID = "api_key_input"
 33
 34// APIKeyInput represents a model selection dialog.
 35type APIKeyInput struct {
 36	com          *common.Common
 37	isOnboarding bool
 38
 39	provider  catwalk.Provider
 40	model     config.SelectedModel
 41	modelType config.SelectedModelType
 42
 43	width int
 44	state APIKeyInputState
 45
 46	keyMap struct {
 47		Submit key.Binding
 48		Close  key.Binding
 49	}
 50	input   textinput.Model
 51	spinner spinner.Model
 52	help    help.Model
 53}
 54
 55var _ Dialog = (*APIKeyInput)(nil)
 56
 57// NewAPIKeyInput creates a new Models dialog.
 58func NewAPIKeyInput(
 59	com *common.Common,
 60	isOnboarding bool,
 61	provider catwalk.Provider,
 62	model config.SelectedModel,
 63	modelType config.SelectedModelType,
 64) (*APIKeyInput, tea.Cmd) {
 65	t := com.Styles
 66
 67	m := APIKeyInput{}
 68	m.com = com
 69	m.isOnboarding = isOnboarding
 70	m.provider = provider
 71	m.model = model
 72	m.modelType = modelType
 73	m.width = 60
 74
 75	innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
 76
 77	m.input = textinput.New()
 78	m.input.SetVirtualCursor(false)
 79	m.input.Placeholder = "Enter your API key..."
 80	m.input.SetStyles(com.Styles.TextInput)
 81	m.input.Focus()
 82	m.input.SetWidth(max(0, innerWidth-t.Dialog.InputPrompt.GetHorizontalFrameSize()-1)) // (1) cursor padding
 83
 84	m.spinner = spinner.New(
 85		spinner.WithSpinner(spinner.Dot),
 86		spinner.WithStyle(t.Base.Foreground(t.Green)),
 87	)
 88
 89	m.help = help.New()
 90	m.help.Styles = t.DialogHelpStyles()
 91
 92	m.keyMap.Submit = key.NewBinding(
 93		key.WithKeys("enter", "ctrl+y"),
 94		key.WithHelp("enter", "submit"),
 95	)
 96	m.keyMap.Close = CloseKey
 97
 98	return &m, nil
 99}
100
101// ID implements Dialog.
102func (m *APIKeyInput) ID() string {
103	return APIKeyInputID
104}
105
106// HandleMsg implements [Dialog].
107func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
108	switch msg := msg.(type) {
109	case ActionChangeAPIKeyState:
110		m.state = msg.State
111		switch m.state {
112		case APIKeyInputStateVerifying:
113			cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
114			return ActionCmd{cmd}
115		}
116	case spinner.TickMsg:
117		switch m.state {
118		case APIKeyInputStateVerifying:
119			var cmd tea.Cmd
120			m.spinner, cmd = m.spinner.Update(msg)
121			if cmd != nil {
122				return ActionCmd{cmd}
123			}
124		}
125	case tea.KeyPressMsg:
126		switch {
127		case m.state == APIKeyInputStateVerifying:
128			// do nothing
129		case key.Matches(msg, m.keyMap.Close):
130			switch m.state {
131			case APIKeyInputStateVerified:
132				return m.saveKeyAndContinue()
133			default:
134				return ActionClose{}
135			}
136		case key.Matches(msg, m.keyMap.Submit):
137			switch m.state {
138			case APIKeyInputStateInitial, APIKeyInputStateError:
139				return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
140			case APIKeyInputStateVerified:
141				return m.saveKeyAndContinue()
142			}
143		default:
144			var cmd tea.Cmd
145			m.input, cmd = m.input.Update(msg)
146			if cmd != nil {
147				return ActionCmd{cmd}
148			}
149		}
150	case tea.PasteMsg:
151		var cmd tea.Cmd
152		m.input, cmd = m.input.Update(msg)
153		if cmd != nil {
154			return ActionCmd{cmd}
155		}
156	}
157	return nil
158}
159
160// Draw implements [Dialog].
161func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
162	t := m.com.Styles
163
164	textStyle := t.Dialog.SecondaryText
165	helpStyle := t.Dialog.HelpView
166	dialogStyle := t.Dialog.View.Width(m.width)
167	inputStyle := t.Dialog.InputPrompt
168	helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
169
170	m.input.Prompt = m.spinner.View()
171
172	content := strings.Join([]string{
173		m.headerView(),
174		inputStyle.Render(m.inputView()),
175		textStyle.Render("This will be written in your global configuration:"),
176		textStyle.Render(config.GlobalConfigData()),
177		"",
178		helpStyle.Render(m.help.View(m)),
179	}, "\n")
180
181	cur := m.Cursor()
182
183	if m.isOnboarding {
184		view := content
185		DrawOnboardingCursor(scr, area, view, cur)
186
187		// FIXME(@andreynering): Figure it out how to properly fix this
188		if cur != nil {
189			cur.Y -= 1
190			cur.X -= 1
191		}
192	} else {
193		view := dialogStyle.Render(content)
194		DrawCenterCursor(scr, area, view, cur)
195	}
196	return cur
197}
198
199func (m *APIKeyInput) headerView() string {
200	var (
201		t           = m.com.Styles
202		titleStyle  = t.Dialog.Title
203		textStyle   = t.Dialog.PrimaryText
204		dialogStyle = t.Dialog.View.Width(m.width)
205	)
206	if m.isOnboarding {
207		return textStyle.Render(m.dialogTitle())
208	}
209	headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
210	return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset, m.com.Styles.Primary, m.com.Styles.Secondary)
211}
212
213func (m *APIKeyInput) dialogTitle() string {
214	var (
215		t           = m.com.Styles
216		textStyle   = t.Dialog.TitleText
217		errorStyle  = t.Dialog.TitleError
218		accentStyle = t.Dialog.TitleAccent
219	)
220	switch m.state {
221	case APIKeyInputStateInitial:
222		return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
223	case APIKeyInputStateVerifying:
224		return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
225	case APIKeyInputStateVerified:
226		return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
227	case APIKeyInputStateError:
228		return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
229	}
230	return ""
231}
232
233func (m *APIKeyInput) inputView() string {
234	t := m.com.Styles
235
236	switch m.state {
237	case APIKeyInputStateInitial:
238		m.input.Prompt = "> "
239		m.input.SetStyles(t.TextInput)
240		m.input.Focus()
241	case APIKeyInputStateVerifying:
242		ts := t.TextInput
243		ts.Blurred.Prompt = ts.Focused.Prompt
244
245		m.input.Prompt = m.spinner.View()
246		m.input.SetStyles(ts)
247		m.input.Blur()
248	case APIKeyInputStateVerified:
249		ts := t.TextInput
250		ts.Blurred.Prompt = ts.Focused.Prompt
251
252		m.input.Prompt = styles.CheckIcon + " "
253		m.input.SetStyles(ts)
254		m.input.Blur()
255	case APIKeyInputStateError:
256		ts := t.TextInput
257		ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
258
259		m.input.Prompt = styles.LSPErrorIcon + " "
260		m.input.SetStyles(ts)
261		m.input.Focus()
262	}
263	return m.input.View()
264}
265
266// Cursor returns the cursor position relative to the dialog.
267func (m *APIKeyInput) Cursor() *tea.Cursor {
268	return InputCursor(m.com.Styles, m.input.Cursor())
269}
270
271// FullHelp returns the full help view.
272func (m *APIKeyInput) FullHelp() [][]key.Binding {
273	return [][]key.Binding{
274		{
275			m.keyMap.Submit,
276			m.keyMap.Close,
277		},
278	}
279}
280
281// ShortHelp returns the full help view.
282func (m *APIKeyInput) ShortHelp() []key.Binding {
283	return []key.Binding{
284		m.keyMap.Submit,
285		m.keyMap.Close,
286	}
287}
288
289func (m *APIKeyInput) verifyAPIKey() tea.Msg {
290	start := time.Now()
291
292	providerConfig := config.ProviderConfig{
293		ID:      string(m.provider.ID),
294		Name:    m.provider.Name,
295		APIKey:  m.input.Value(),
296		Type:    m.provider.Type,
297		BaseURL: m.provider.APIEndpoint,
298	}
299	err := providerConfig.TestConnection(m.com.Config().Resolver())
300
301	// intentionally wait for at least 750ms to make sure the user sees the spinner
302	elapsed := time.Since(start)
303	minimum := 750 * time.Millisecond
304	if elapsed < minimum {
305		time.Sleep(minimum - elapsed)
306	}
307
308	if err == nil {
309		return ActionChangeAPIKeyState{APIKeyInputStateVerified}
310	}
311	return ActionChangeAPIKeyState{APIKeyInputStateError}
312}
313
314func (m *APIKeyInput) saveKeyAndContinue() Action {
315	cfg := m.com.Config()
316
317	err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value())
318	if err != nil {
319		return ActionCmd{util.ReportError(fmt.Errorf("failed to save API key: %w", err))}
320	}
321
322	return ActionSelectModel{
323		Provider:  m.provider,
324		Model:     m.model,
325		ModelType: m.modelType,
326	}
327}