chore: add api key step

Kujtim Hoxha created

Change summary

internal/config/config.go                        |  52 +++++++
internal/config/load.go                          |   1 
internal/tui/components/chat/splash/keys.go      |   7 
internal/tui/components/chat/splash/splash.go    | 126 ++++++++++++++++-
internal/tui/components/dialogs/models/apikey.go |  20 +-
internal/tui/page/chat/chat.go                   |  26 +++
6 files changed, 204 insertions(+), 28 deletions(-)

Detailed changes

internal/config/config.go 🔗

@@ -211,8 +211,9 @@ type Config struct {
 	// TODO: most likely remove this concept when I come back to it
 	Agents map[string]Agent `json:"-"`
 	// TODO: find a better way to do this this should probably not be part of the config
-	resolver      VariableResolver
-	dataConfigDir string `json:"-"`
+	resolver       VariableResolver
+	dataConfigDir  string              `json:"-"`
+	knownProviders []provider.Provider `json:"-"`
 }
 
 func (c *Config) WorkingDir() string {
@@ -323,3 +324,50 @@ func (c *Config) SetConfigField(key string, value any) error {
 	}
 	return nil
 }
+
+func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
+	// First save to the config file
+	err := c.SetConfigField("providers."+providerID+".api_key", apiKey)
+	if err != nil {
+		return fmt.Errorf("failed to save API key to config file: %w", err)
+	}
+
+	if c.Providers == nil {
+		c.Providers = make(map[string]ProviderConfig)
+	}
+
+	providerConfig, exists := c.Providers[providerID]
+	if exists {
+		providerConfig.APIKey = apiKey
+		c.Providers[providerID] = providerConfig
+		return nil
+	}
+
+	var foundProvider *provider.Provider
+	for _, p := range c.knownProviders {
+		if string(p.ID) == providerID {
+			foundProvider = &p
+			break
+		}
+	}
+
+	if foundProvider != nil {
+		// Create new provider config based on known provider
+		providerConfig = ProviderConfig{
+			ID:           providerID,
+			Name:         foundProvider.Name,
+			BaseURL:      foundProvider.APIEndpoint,
+			Type:         foundProvider.Type,
+			APIKey:       apiKey,
+			Disable:      false,
+			ExtraHeaders: make(map[string]string),
+			ExtraParams:  make(map[string]string),
+			Models:       foundProvider.Models,
+		}
+	} else {
+		return fmt.Errorf("provider with ID %s not found in known providers", providerID)
+	}
+	// Store the updated provider config
+	c.Providers[providerID] = providerConfig
+	return nil
+}

internal/config/load.go 🔗

@@ -65,6 +65,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
 	if err != nil || len(providers) == 0 {
 		return nil, fmt.Errorf("failed to load providers: %w", err)
 	}
+	cfg.knownProviders = providers
 
 	env := env.New()
 	// Configure providers

internal/tui/components/chat/splash/keys.go 🔗

@@ -11,7 +11,8 @@ type KeyMap struct {
 	Yes,
 	No,
 	Tab,
-	LeftRight key.Binding
+	LeftRight,
+	Back key.Binding
 }
 
 func DefaultKeyMap() KeyMap {
@@ -44,5 +45,9 @@ func DefaultKeyMap() KeyMap {
 			key.WithKeys("left", "right"),
 			key.WithHelp("←/→", "switch"),
 		),
+		Back: key.NewBinding(
+			key.WithKeys("esc"),
+			key.WithHelp("esc", "back"),
+		),
 	}
 }

internal/tui/components/chat/splash/splash.go 🔗

@@ -2,6 +2,7 @@ package splash
 
 import (
 	"fmt"
+	"log/slog"
 	"slices"
 
 	"github.com/charmbracelet/bubbles/v2/key"
@@ -48,10 +49,12 @@ type splashCmp struct {
 	// State
 	isOnboarding     bool
 	needsProjectInit bool
+	needsAPIKey      bool
 	selectedNo       bool
 
-	modelList            *models.ModelListComponent
-	cursorRow, cursorCol int
+	modelList     *models.ModelListComponent
+	apiKeyInput   *models.APIKeyInput
+	selectedModel *models.ModelOption
 }
 
 func New() Splash {
@@ -69,12 +72,15 @@ func New() Splash {
 	t := styles.CurrentTheme()
 	inputStyle := t.S().Base.Padding(0, 1, 0, 1)
 	modelList := models.NewModelListComponent(listKeyMap, inputStyle, "Find your fave")
+	apiKeyInput := models.NewAPIKeyInput()
+
 	return &splashCmp{
 		width:        0,
 		height:       0,
 		keyMap:       keyMap,
 		logoRendered: "",
 		modelList:    modelList,
+		apiKeyInput:  apiKeyInput,
 		selectedNo:   false,
 	}
 }
@@ -114,7 +120,7 @@ func (s *splashCmp) GetSize() (int, int) {
 
 // Init implements SplashPage.
 func (s *splashCmp) Init() tea.Cmd {
-	return s.modelList.Init()
+	return tea.Batch(s.modelList.Init(), s.apiKeyInput.Init())
 }
 
 // SetSize implements SplashPage.
@@ -125,8 +131,6 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
 	listHeigh := min(40, height-(SplashScreenPaddingY*2)-lipgloss.Height(s.logoRendered)-2) // -1 for the title
 	listWidth := min(60, width-(SplashScreenPaddingX*2))
 
-	// Calculate the cursor position based on the height and logo size
-	s.cursorRow = height - listHeigh
 	return s.modelList.SetSize(listWidth, listHeigh)
 }
 
@@ -137,8 +141,16 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		return s, s.SetSize(msg.Width, msg.Height)
 	case tea.KeyPressMsg:
 		switch {
+		case key.Matches(msg, s.keyMap.Back):
+			slog.Info("Back key pressed in splash screen")
+			if s.needsAPIKey {
+				// Go back to model selection
+				s.needsAPIKey = false
+				s.selectedModel = nil
+				return s, nil
+			}
 		case key.Matches(msg, s.keyMap.Select):
-			if s.isOnboarding {
+			if s.isOnboarding && !s.needsAPIKey {
 				modelInx := s.modelList.SelectedIndex()
 				items := s.modelList.Items()
 				selectedItem := items[modelInx].(completions.CompletionItem).Value().(models.ModelOption)
@@ -146,6 +158,18 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 					cmd := s.setPreferredModel(selectedItem)
 					s.isOnboarding = false
 					return s, tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
+				} else {
+					// Provider not configured, show API key input
+					s.needsAPIKey = true
+					s.selectedModel = &selectedItem
+					s.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
+					return s, nil
+				}
+			} else if s.needsAPIKey {
+				// Handle API key submission
+				apiKey := s.apiKeyInput.Value()
+				if apiKey != "" {
+					return s, s.saveAPIKeyAndContinue(apiKey)
 				}
 			} else if s.needsProjectInit {
 				return s, s.initializeProject()
@@ -165,16 +189,50 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 				return s, util.CmdHandler(OnboardingCompleteMsg{})
 			}
 		default:
-			if s.isOnboarding {
+			if s.needsAPIKey {
+				u, cmd := s.apiKeyInput.Update(msg)
+				s.apiKeyInput = u.(*models.APIKeyInput)
+				return s, cmd
+			} else if s.isOnboarding {
 				u, cmd := s.modelList.Update(msg)
 				s.modelList = u
 				return s, cmd
 			}
 		}
+	case tea.PasteMsg:
+		if s.needsAPIKey {
+			u, cmd := s.apiKeyInput.Update(msg)
+			s.apiKeyInput = u.(*models.APIKeyInput)
+			return s, cmd
+		} else if s.isOnboarding {
+			var cmd tea.Cmd
+			s.modelList, cmd = s.modelList.Update(msg)
+			return s, cmd
+		}
 	}
 	return s, nil
 }
 
+func (s *splashCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
+	if s.selectedModel == nil {
+		return util.ReportError(fmt.Errorf("no model selected"))
+	}
+
+	cfg := config.Get()
+	err := cfg.SetProviderAPIKey(string(s.selectedModel.Provider.ID), apiKey)
+	if err != nil {
+		return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
+	}
+
+	// Reset API key state and continue with model selection
+	s.needsAPIKey = false
+	cmd := s.setPreferredModel(*s.selectedModel)
+	s.isOnboarding = false
+	s.selectedModel = nil
+
+	return tea.Batch(cmd, util.CmdHandler(OnboardingCompleteMsg{}))
+}
+
 func (s *splashCmp) initializeProject() tea.Cmd {
 	s.needsProjectInit = false
 	prompt := `Please analyze this codebase and create a CRUSH.md file containing:
@@ -283,7 +341,21 @@ func (s *splashCmp) View() string {
 	t := styles.CurrentTheme()
 
 	var content string
-	if s.isOnboarding {
+	if s.needsAPIKey {
+		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
+		apiKeyView := s.apiKeyInput.View()
+		apiKeySelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
+			lipgloss.JoinVertical(
+				lipgloss.Left,
+				apiKeyView,
+			),
+		)
+		content = lipgloss.JoinVertical(
+			lipgloss.Left,
+			s.logoRendered,
+			apiKeySelector,
+		)
+	} else if s.isOnboarding {
 		remainingHeight := s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2)
 		modelListView := s.modelList.View()
 		modelSelector := t.S().Base.AlignVertical(lipgloss.Bottom).Height(remainingHeight).Render(
@@ -363,7 +435,12 @@ func (s *splashCmp) View() string {
 }
 
 func (s *splashCmp) Cursor() *tea.Cursor {
-	if s.isOnboarding {
+	if s.needsAPIKey {
+		cursor := s.apiKeyInput.Cursor()
+		if cursor != nil {
+			return s.moveCursor(cursor)
+		}
+	} else if s.isOnboarding {
 		cursor := s.modelList.Cursor()
 		if cursor != nil {
 			return s.moveCursor(cursor)
@@ -391,15 +468,38 @@ func (m *splashCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
 	if cursor == nil {
 		return nil
 	}
-	offset := m.cursorRow
-	cursor.Y += offset
-	cursor.X = cursor.X + 3 // 3 for padding
+
+	// Calculate the correct Y offset based on current state
+	logoHeight := lipgloss.Height(m.logoRendered)
+	baseOffset := logoHeight + SplashScreenPaddingY
+
+	if m.needsAPIKey {
+		// For API key input, position at the bottom of the remaining space
+		remainingHeight := m.height - logoHeight - (SplashScreenPaddingY * 2)
+		offset := baseOffset + remainingHeight - lipgloss.Height(m.apiKeyInput.View())
+		cursor.Y += offset
+		// API key input already includes prompt in its cursor positioning
+		cursor.X = cursor.X + SplashScreenPaddingX
+	} else if m.isOnboarding {
+		// For model list, use the original calculation
+		listHeight := min(40, m.height-(SplashScreenPaddingY*2)-logoHeight-2)
+		offset := m.height - listHeight
+		cursor.Y += offset
+		// Model list doesn't have a prompt, so add padding + space for list styling
+		cursor.X = cursor.X + SplashScreenPaddingX + 1
+	}
+
 	return cursor
 }
 
 // Bindings implements SplashPage.
 func (s *splashCmp) Bindings() []key.Binding {
-	if s.isOnboarding {
+	if s.needsAPIKey {
+		return []key.Binding{
+			s.keyMap.Select,
+			s.keyMap.Back,
+		}
+	} else if s.isOnboarding {
 		return []key.Binding{
 			s.keyMap.Select,
 			s.keyMap.Next,

internal/tui/components/dialogs/models/apikey.go 🔗

@@ -12,9 +12,10 @@ import (
 )
 
 type APIKeyInput struct {
-	input  textinput.Model
-	width  int
-	height int
+	input        textinput.Model
+	width        int
+	height       int
+	providerName string
 }
 
 func NewAPIKeyInput() *APIKeyInput {
@@ -29,11 +30,16 @@ func NewAPIKeyInput() *APIKeyInput {
 	ti.Focus()
 
 	return &APIKeyInput{
-		input: ti,
-		width: 60,
+		input:        ti,
+		width:        60,
+		providerName: "Provider",
 	}
 }
 
+func (a *APIKeyInput) SetProviderName(name string) {
+	a.providerName = name
+}
+
 func (a *APIKeyInput) Init() tea.Cmd {
 	return textinput.Blink
 }
@@ -54,9 +60,9 @@ func (a *APIKeyInput) View() string {
 	t := styles.CurrentTheme()
 
 	title := t.S().Base.
-		Foreground(t.Secondary).
+		Foreground(t.Primary).
 		Bold(true).
-		Render("Enter your Anthropic API Key")
+		Render(fmt.Sprintf("Enter your %s API Key", a.providerName))
 
 	inputView := a.input.View()
 

internal/tui/page/chat/chat.go 🔗

@@ -256,7 +256,9 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			p.changeFocus()
 			return p, nil
 		case key.Matches(msg, p.keyMap.Cancel):
-			return p, p.cancel()
+			if p.session.ID != "" && p.app.CoderAgent.IsBusy() {
+				return p, p.cancel()
+			}
 		case key.Matches(msg, p.keyMap.Details):
 			p.showDetails()
 			return p, nil
@@ -276,6 +278,24 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			p.splash = u.(splash.Splash)
 			cmds = append(cmds, cmd)
 		}
+	case tea.PasteMsg:
+		switch p.focusedPane {
+		case PanelTypeEditor:
+			u, cmd := p.editor.Update(msg)
+			p.editor = u.(editor.Editor)
+			cmds = append(cmds, cmd)
+			return p, tea.Batch(cmds...)
+		case PanelTypeChat:
+			u, cmd := p.chat.Update(msg)
+			p.chat = u.(chat.MessageListCmp)
+			cmds = append(cmds, cmd)
+			return p, tea.Batch(cmds...)
+		case PanelTypeSplash:
+			u, cmd := p.splash.Update(msg)
+			p.splash = u.(splash.Splash)
+			cmds = append(cmds, cmd)
+			return p, tea.Batch(cmds...)
+		}
 	}
 	return p, tea.Batch(cmds...)
 }
@@ -479,10 +499,6 @@ func (p *chatPage) changeFocus() {
 }
 
 func (p *chatPage) cancel() tea.Cmd {
-	if p.session.ID == "" || !p.app.CoderAgent.IsBusy() {
-		return nil
-	}
-
 	if p.isCanceling {
 		p.isCanceling = false
 		p.app.CoderAgent.Cancel(p.session.ID)