api_key_input.go

  1package dialog
  2
  3import (
  4	"fmt"
  5	"strings"
  6	"time"
  7
  8	"charm.land/bubbles/v2/help"
  9	"charm.land/bubbles/v2/key"
 10	"charm.land/bubbles/v2/spinner"
 11	"charm.land/bubbles/v2/textinput"
 12	tea "charm.land/bubbletea/v2"
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/ui/common"
 16	"github.com/charmbracelet/crush/internal/ui/styles"
 17	"github.com/charmbracelet/crush/internal/uiutil"
 18	uv "github.com/charmbracelet/ultraviolet"
 19	"github.com/charmbracelet/x/exp/charmtone"
 20)
 21
 22type APIKeyInputState int
 23
 24const (
 25	APIKeyInputStateInitial APIKeyInputState = iota
 26	APIKeyInputStateVerifying
 27	APIKeyInputStateVerified
 28	APIKeyInputStateError
 29)
 30
 31// APIKeyInputID is the identifier for the model selection dialog.
 32const APIKeyInputID = "api_key_input"
 33
 34// APIKeyInput represents a model selection dialog.
 35type APIKeyInput struct {
 36	com *common.Common
 37
 38	provider  catwalk.Provider
 39	model     config.SelectedModel
 40	modelType config.SelectedModelType
 41
 42	width int
 43	state APIKeyInputState
 44
 45	keyMap struct {
 46		Submit key.Binding
 47		Close  key.Binding
 48	}
 49	input   textinput.Model
 50	spinner spinner.Model
 51	help    help.Model
 52}
 53
 54var _ Dialog = (*APIKeyInput)(nil)
 55
 56// NewAPIKeyInput creates a new Models dialog.
 57func NewAPIKeyInput(com *common.Common, provider catwalk.Provider, model config.SelectedModel, modelType config.SelectedModelType) (*APIKeyInput, error) {
 58	t := com.Styles
 59
 60	m := APIKeyInput{}
 61	m.com = com
 62	m.provider = provider
 63	m.model = model
 64	m.modelType = modelType
 65	m.width = 60
 66
 67	innerWidth := m.width - t.Dialog.View.GetHorizontalFrameSize() - 2
 68
 69	m.input = textinput.New()
 70	m.input.SetVirtualCursor(false)
 71	m.input.Placeholder = "Enter you API key..."
 72	m.input.SetStyles(com.Styles.TextInput)
 73	m.input.Focus()
 74	m.input.SetWidth(innerWidth - t.Dialog.InputPrompt.GetHorizontalFrameSize() - 1) // (1) cursor padding
 75
 76	m.spinner = spinner.New(
 77		spinner.WithSpinner(spinner.Dot),
 78		spinner.WithStyle(t.Base.Foreground(t.Green)),
 79	)
 80
 81	m.help = help.New()
 82	m.help.Styles = t.DialogHelpStyles()
 83
 84	m.keyMap.Submit = key.NewBinding(
 85		key.WithKeys("enter", "ctrl+y"),
 86		key.WithHelp("enter", "submit"),
 87	)
 88	m.keyMap.Close = CloseKey
 89
 90	return &m, nil
 91}
 92
 93// ID implements Dialog.
 94func (m *APIKeyInput) ID() string {
 95	return APIKeyInputID
 96}
 97
 98// HandleMsg implements [Dialog].
 99func (m *APIKeyInput) HandleMsg(msg tea.Msg) Action {
100	switch msg := msg.(type) {
101	case ActionChangeAPIKeyState:
102		m.state = msg.State
103		switch m.state {
104		case APIKeyInputStateVerifying:
105			cmd := tea.Batch(m.spinner.Tick, m.verifyAPIKey)
106			return ActionCmd{cmd}
107		}
108	case spinner.TickMsg:
109		switch m.state {
110		case APIKeyInputStateVerifying:
111			var cmd tea.Cmd
112			m.spinner, cmd = m.spinner.Update(msg)
113			if cmd != nil {
114				return ActionCmd{cmd}
115			}
116		}
117	case tea.KeyPressMsg:
118		switch {
119		case m.state == APIKeyInputStateVerifying:
120			// do nothing
121		case key.Matches(msg, m.keyMap.Close):
122			switch m.state {
123			case APIKeyInputStateVerified:
124				return m.saveKeyAndContinue()
125			default:
126				return ActionClose{}
127			}
128		case key.Matches(msg, m.keyMap.Submit):
129			switch m.state {
130			case APIKeyInputStateInitial, APIKeyInputStateError:
131				return ActionChangeAPIKeyState{State: APIKeyInputStateVerifying}
132			case APIKeyInputStateVerified:
133				return m.saveKeyAndContinue()
134			}
135		default:
136			var cmd tea.Cmd
137			m.input, cmd = m.input.Update(msg)
138			if cmd != nil {
139				return ActionCmd{cmd}
140			}
141		}
142	case tea.PasteMsg:
143		var cmd tea.Cmd
144		m.input, cmd = m.input.Update(msg)
145		if cmd != nil {
146			return ActionCmd{cmd}
147		}
148	}
149	return nil
150}
151
152// Draw implements [Dialog].
153func (m *APIKeyInput) Draw(scr uv.Screen, area uv.Rectangle) *tea.Cursor {
154	t := m.com.Styles
155
156	textStyle := t.Dialog.SecondaryText
157	helpStyle := t.Dialog.HelpView
158	dialogStyle := t.Dialog.View.Width(m.width)
159	inputStyle := t.Dialog.InputPrompt
160	helpStyle = helpStyle.Width(m.width - dialogStyle.GetHorizontalFrameSize())
161
162	m.input.Prompt = m.spinner.View()
163
164	content := strings.Join([]string{
165		m.headerView(),
166		inputStyle.Render(m.inputView()),
167		textStyle.Render("This will be written in your global configuration:"),
168		textStyle.Render(config.GlobalConfigData()),
169		"",
170		helpStyle.Render(m.help.View(m)),
171	}, "\n")
172
173	view := dialogStyle.Render(content)
174
175	cur := m.Cursor()
176	DrawCenterCursor(scr, area, view, cur)
177	return cur
178}
179
180func (m *APIKeyInput) headerView() string {
181	t := m.com.Styles
182	titleStyle := t.Dialog.Title
183	dialogStyle := t.Dialog.View.Width(m.width)
184
185	headerOffset := titleStyle.GetHorizontalFrameSize() + dialogStyle.GetHorizontalFrameSize()
186	return common.DialogTitle(t, titleStyle.Render(m.dialogTitle()), m.width-headerOffset)
187}
188
189func (m *APIKeyInput) dialogTitle() string {
190	t := m.com.Styles
191	textStyle := t.Dialog.TitleText
192	errorStyle := t.Dialog.TitleError
193	accentStyle := t.Dialog.TitleAccent
194
195	switch m.state {
196	case APIKeyInputStateInitial:
197		return textStyle.Render("Enter your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(".")
198	case APIKeyInputStateVerifying:
199		return textStyle.Render("Verifying your ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render("...")
200	case APIKeyInputStateVerified:
201		return accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + textStyle.Render(" validated.")
202	case APIKeyInputStateError:
203		return errorStyle.Render("Invalid ") + accentStyle.Render(fmt.Sprintf("%s Key", m.provider.Name)) + errorStyle.Render(". Try again?")
204	}
205	return ""
206}
207
208func (m *APIKeyInput) inputView() string {
209	t := m.com.Styles
210
211	switch m.state {
212	case APIKeyInputStateInitial:
213		m.input.Prompt = "> "
214		m.input.SetStyles(t.TextInput)
215		m.input.Focus()
216	case APIKeyInputStateVerifying:
217		ts := t.TextInput
218		ts.Blurred.Prompt = ts.Focused.Prompt
219
220		m.input.Prompt = m.spinner.View()
221		m.input.SetStyles(ts)
222		m.input.Blur()
223	case APIKeyInputStateVerified:
224		ts := t.TextInput
225		ts.Blurred.Prompt = ts.Focused.Prompt
226
227		m.input.Prompt = styles.CheckIcon + " "
228		m.input.SetStyles(ts)
229		m.input.Blur()
230	case APIKeyInputStateError:
231		ts := t.TextInput
232		ts.Focused.Prompt = ts.Focused.Prompt.Foreground(charmtone.Cherry)
233
234		m.input.Prompt = styles.ErrorIcon + " "
235		m.input.SetStyles(ts)
236		m.input.Focus()
237	}
238	return m.input.View()
239}
240
241// Cursor returns the cursor position relative to the dialog.
242func (m *APIKeyInput) Cursor() *tea.Cursor {
243	return InputCursor(m.com.Styles, m.input.Cursor())
244}
245
246// FullHelp returns the full help view.
247func (m *APIKeyInput) FullHelp() [][]key.Binding {
248	return [][]key.Binding{
249		{
250			m.keyMap.Submit,
251			m.keyMap.Close,
252		},
253	}
254}
255
256// ShortHelp returns the full help view.
257func (m *APIKeyInput) ShortHelp() []key.Binding {
258	return []key.Binding{
259		m.keyMap.Submit,
260		m.keyMap.Close,
261	}
262}
263
264func (m *APIKeyInput) verifyAPIKey() tea.Msg {
265	start := time.Now()
266
267	providerConfig := config.ProviderConfig{
268		ID:      string(m.provider.ID),
269		Name:    m.provider.Name,
270		APIKey:  m.input.Value(),
271		Type:    m.provider.Type,
272		BaseURL: m.provider.APIEndpoint,
273	}
274	err := providerConfig.TestConnection(config.Get().Resolver())
275
276	// intentionally wait for at least 750ms to make sure the user sees the spinner
277	elapsed := time.Since(start)
278	minimum := 750 * time.Millisecond
279	if elapsed < minimum {
280		time.Sleep(minimum - elapsed)
281	}
282
283	if err == nil {
284		return ActionChangeAPIKeyState{APIKeyInputStateVerified}
285	}
286	return ActionChangeAPIKeyState{APIKeyInputStateError}
287}
288
289func (m *APIKeyInput) saveKeyAndContinue() Action {
290	cfg := m.com.Config()
291
292	err := cfg.SetProviderAPIKey(string(m.provider.ID), m.input.Value())
293	if err != nil {
294		return ActionCmd{uiutil.ReportError(fmt.Errorf("failed to save API key: %w", err))}
295	}
296
297	return ActionSelectModel{
298		Provider:  m.provider,
299		Model:     m.model,
300		ModelType: m.modelType,
301	}
302}