Detailed changes
@@ -1,10 +1,13 @@
package config
import (
+ "context"
"fmt"
+ "net/http"
"os"
"slices"
"strings"
+ "time"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -435,3 +438,52 @@ func (c *Config) SetupAgents() {
}
c.Agents = agents
}
+
+func (c *Config) Resolver() VariableResolver {
+ return c.resolver
+}
+
+func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
+ testURL := ""
+ headers := make(map[string]string)
+ apiKey, _ := resolver.ResolveValue(c.APIKey)
+ switch c.Type {
+ case provider.TypeOpenAI:
+ baseURL, _ := resolver.ResolveValue(c.BaseURL)
+ if baseURL == "" {
+ baseURL = "https://api.openai.com/v1"
+ }
+ testURL = baseURL + "/models"
+ headers["Authorization"] = "Bearer " + apiKey
+ case provider.TypeAnthropic:
+ baseURL, _ := resolver.ResolveValue(c.BaseURL)
+ if baseURL == "" {
+ baseURL = "https://api.anthropic.com/v1"
+ }
+ testURL = baseURL + "/models"
+ headers["x-api-key"] = apiKey
+ headers["anthropic-version"] = "2023-06-01"
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ client := &http.Client{}
+ req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ for k, v := range c.ExtraHeaders {
+ req.Header.Set(k, v)
+ }
+ b, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
+ }
+ if b.StatusCode != http.StatusOK {
+ return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
+ }
+ _ = b.Body.Close()
+ return nil
+}
@@ -109,5 +109,5 @@ func HasInitialDataConfig() bool {
if _, err := os.Stat(cfgPath); err != nil {
return false
}
- return true
+ return Get().IsConfigured()
}
@@ -9,6 +9,7 @@ import (
"runtime"
"slices"
"strings"
+ "sync"
"github.com/charmbracelet/crush/internal/env"
"github.com/charmbracelet/crush/internal/fur/client"
@@ -75,6 +76,35 @@ func Load(workingDir string, debug bool) (*Config, error) {
return nil, fmt.Errorf("failed to configure providers: %w", err)
}
+ // Test provider connections in parallel
+ var testResults sync.Map
+ var wg sync.WaitGroup
+
+ for _, p := range cfg.Providers {
+ if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
+ wg.Add(1)
+ go func(provider ProviderConfig) {
+ defer wg.Done()
+ err := provider.TestConnection(cfg.resolver)
+ testResults.Store(provider.ID, err == nil)
+ if err != nil {
+ slog.Error("Provider connection test failed", "provider", provider.ID, "error", err)
+ }
+ }(p)
+ }
+ }
+ wg.Wait()
+
+ // Remove failed providers
+ testResults.Range(func(key, value any) bool {
+ providerID := key.(string)
+ passed := value.(bool)
+ if !passed {
+ delete(cfg.Providers, providerID)
+ }
+ return true
+ })
+
if !cfg.IsConfigured() {
slog.Warn("No providers configured")
return cfg, nil
@@ -5,8 +5,10 @@ import (
"os"
"slices"
"strings"
+ "time"
"github.com/charmbracelet/bubbles/v2/key"
+ "github.com/charmbracelet/bubbles/v2/spinner"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -36,6 +38,9 @@ type Splash interface {
// Showing API key input
IsShowingAPIKey() bool
+
+ // IsAPIKeyValid returns whether the API key is valid
+ IsAPIKeyValid() bool
}
const (
@@ -45,7 +50,10 @@ const (
)
// OnboardingCompleteMsg is sent when onboarding is complete
-type OnboardingCompleteMsg struct{}
+type (
+ OnboardingCompleteMsg struct{}
+ SubmitAPIKeyMsg struct{}
+)
type splashCmp struct {
width, height int
@@ -62,6 +70,8 @@ type splashCmp struct {
modelList *models.ModelListComponent
apiKeyInput *models.APIKeyInput
selectedModel *models.ModelOption
+ isAPIKeyValid bool
+ apiKeyValue string
}
func New() Splash {
@@ -141,6 +151,7 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd {
// remove padding, logo height, gap, title space
s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2
listWidth := min(60, width)
+ s.apiKeyInput.SetWidth(width - 2)
return s.modelList.SetSize(listWidth, s.listHeight)
}
@@ -149,16 +160,38 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
return s, s.SetSize(msg.Width, msg.Height)
+ case models.APIKeyStateChangeMsg:
+ u, cmd := s.apiKeyInput.Update(msg)
+ s.apiKeyInput = u.(*models.APIKeyInput)
+ if msg.State == models.APIKeyInputStateVerified {
+ return s, tea.Tick(5*time.Second, func(t time.Time) tea.Msg {
+ return SubmitAPIKeyMsg{}
+ })
+ }
+ return s, cmd
+ case SubmitAPIKeyMsg:
+ if s.isAPIKeyValid {
+ return s, s.saveAPIKeyAndContinue(s.apiKeyValue)
+ }
case tea.KeyPressMsg:
switch {
case key.Matches(msg, s.keyMap.Back):
+ if s.isAPIKeyValid {
+ return s, nil
+ }
if s.needsAPIKey {
// Go back to model selection
s.needsAPIKey = false
s.selectedModel = nil
+ s.isAPIKeyValid = false
+ s.apiKeyValue = ""
+ s.apiKeyInput.Reset()
return s, nil
}
case key.Matches(msg, s.keyMap.Select):
+ if s.isAPIKeyValid {
+ return s, s.saveAPIKeyAndContinue(s.apiKeyValue)
+ }
if s.isOnboarding && !s.needsAPIKey {
modelInx := s.modelList.SelectedIndex()
items := s.modelList.Items()
@@ -176,23 +209,75 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
} else if s.needsAPIKey {
// Handle API key submission
- apiKey := s.apiKeyInput.Value()
- if apiKey != "" {
- return s, s.saveAPIKeyAndContinue(apiKey)
+ s.apiKeyValue = strings.TrimSpace(s.apiKeyInput.Value())
+ if s.apiKeyValue == "" {
+ return s, nil
}
+
+ provider, err := s.getProvider(s.selectedModel.Provider.ID)
+ if err != nil || provider == nil {
+ return s, util.ReportError(fmt.Errorf("provider %s not found", s.selectedModel.Provider.ID))
+ }
+ providerConfig := config.ProviderConfig{
+ ID: string(s.selectedModel.Provider.ID),
+ Name: s.selectedModel.Provider.Name,
+ APIKey: s.apiKeyValue,
+ Type: provider.Type,
+ BaseURL: provider.APIEndpoint,
+ }
+ return s, tea.Sequence(
+ util.CmdHandler(models.APIKeyStateChangeMsg{
+ State: models.APIKeyInputStateVerifying,
+ }),
+ func() tea.Msg {
+ start := time.Now()
+ err := providerConfig.TestConnection(config.Get().Resolver())
+ // intentionally wait for at least 750ms to make sure the user sees the spinner
+ elapsed := time.Since(start)
+ if elapsed < 750*time.Millisecond {
+ time.Sleep(750*time.Millisecond - elapsed)
+ }
+ if err == nil {
+ s.isAPIKeyValid = true
+ return models.APIKeyStateChangeMsg{
+ State: models.APIKeyInputStateVerified,
+ }
+ }
+ return models.APIKeyStateChangeMsg{
+ State: models.APIKeyInputStateError,
+ }
+ },
+ )
} else if s.needsProjectInit {
return s, s.initializeProject()
}
case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight):
+ if s.needsAPIKey {
+ u, cmd := s.apiKeyInput.Update(msg)
+ s.apiKeyInput = u.(*models.APIKeyInput)
+ return s, cmd
+ }
if s.needsProjectInit {
s.selectedNo = !s.selectedNo
return s, nil
}
case key.Matches(msg, s.keyMap.Yes):
+ if s.needsAPIKey {
+ u, cmd := s.apiKeyInput.Update(msg)
+ s.apiKeyInput = u.(*models.APIKeyInput)
+ return s, cmd
+ }
+
if s.needsProjectInit {
return s, s.initializeProject()
}
case key.Matches(msg, s.keyMap.No):
+ if s.needsAPIKey {
+ u, cmd := s.apiKeyInput.Update(msg)
+ s.apiKeyInput = u.(*models.APIKeyInput)
+ return s, cmd
+ }
+
s.selectedNo = true
return s, s.initializeProject()
default:
@@ -216,6 +301,10 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s.modelList, cmd = s.modelList.Update(msg)
return s, cmd
}
+ case spinner.TickMsg:
+ u, cmd := s.apiKeyInput.Update(msg)
+ s.apiKeyInput = u.(*models.APIKeyInput)
+ return s, cmd
}
return s, nil
}
@@ -629,3 +718,7 @@ func (s *splashCmp) mcpBlock() string {
func (s *splashCmp) IsShowingAPIKey() bool {
return s.needsAPIKey
}
+
+func (s *splashCmp) IsAPIKeyValid() bool {
+ return s.isAPIKeyValid
+}
@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
+ "github.com/charmbracelet/bubbles/v2/spinner"
"github.com/charmbracelet/bubbles/v2/textinput"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
@@ -11,11 +12,27 @@ import (
"github.com/charmbracelet/lipgloss/v2"
)
+type APIKeyInputState int
+
+const (
+ APIKeyInputStateInitial APIKeyInputState = iota
+ APIKeyInputStateVerifying
+ APIKeyInputStateVerified
+ APIKeyInputStateError
+)
+
+type APIKeyStateChangeMsg struct {
+ State APIKeyInputState
+}
+
type APIKeyInput struct {
input textinput.Model
width int
- height int
+ spinner spinner.Model
providerName string
+ state APIKeyInputState
+ title string
+ showTitle bool
}
func NewAPIKeyInput() *APIKeyInput {
@@ -23,32 +40,59 @@ func NewAPIKeyInput() *APIKeyInput {
ti := textinput.New()
ti.Placeholder = "Enter your API key..."
- ti.SetWidth(50)
ti.SetVirtualCursor(false)
ti.Prompt = "> "
ti.SetStyles(t.S().TextInput)
ti.Focus()
return &APIKeyInput{
- input: ti,
- width: 60,
+ input: ti,
+ state: APIKeyInputStateInitial,
+ spinner: spinner.New(
+ spinner.WithSpinner(spinner.Dot),
+ spinner.WithStyle(t.S().Base.Foreground(t.Green)),
+ ),
providerName: "Provider",
+ showTitle: true,
}
}
func (a *APIKeyInput) SetProviderName(name string) {
a.providerName = name
+ a.updateStatePresentation()
+}
+
+func (a *APIKeyInput) SetShowTitle(show bool) {
+ a.showTitle = show
+}
+
+func (a *APIKeyInput) GetTitle() string {
+ return a.title
}
func (a *APIKeyInput) Init() tea.Cmd {
- return textinput.Blink
+ a.updateStatePresentation()
+ return a.spinner.Tick
}
func (a *APIKeyInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
- case tea.WindowSizeMsg:
- a.width = msg.Width
- a.height = msg.Height
+ case spinner.TickMsg:
+ if a.state == APIKeyInputStateVerifying {
+ var cmd tea.Cmd
+ a.spinner, cmd = a.spinner.Update(msg)
+ a.updateStatePresentation()
+ return a, cmd
+ }
+ return a, nil
+ case APIKeyStateChangeMsg:
+ a.state = msg.State
+ var cmd tea.Cmd
+ if msg.State == APIKeyInputStateVerifying {
+ cmd = a.spinner.Tick
+ }
+ a.updateStatePresentation()
+ return a, cmd
}
var cmd tea.Cmd
@@ -56,36 +100,79 @@ func (a *APIKeyInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return a, cmd
}
-func (a *APIKeyInput) View() string {
+func (a *APIKeyInput) updateStatePresentation() {
t := styles.CurrentTheme()
- title := t.S().Base.
- Foreground(t.Primary).
- Bold(true).
- Render(fmt.Sprintf("Enter your %s API Key", a.providerName))
+ prefixStyle := t.S().Base.
+ Foreground(t.Primary)
+ accentStyle := t.S().Base.Foreground(t.Green).Bold(true)
+ errorStyle := t.S().Base.Foreground(t.Cherry)
+
+ switch a.state {
+ case APIKeyInputStateInitial:
+ titlePrefix := prefixStyle.Render("Enter your ")
+ a.title = titlePrefix + accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render(".")
+ a.input.SetStyles(t.S().TextInput)
+ a.input.Prompt = "> "
+ case APIKeyInputStateVerifying:
+ titlePrefix := prefixStyle.Render("Verifying your ")
+ a.title = titlePrefix + accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render("...")
+ ts := t.S().TextInput
+ // make the blurred state be the same
+ ts.Blurred.Prompt = ts.Focused.Prompt
+ a.input.Prompt = a.spinner.View()
+ a.input.Blur()
+ case APIKeyInputStateVerified:
+ a.title = accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render(" validated.")
+ ts := t.S().TextInput
+ // make the blurred state be the same
+ ts.Blurred.Prompt = ts.Focused.Prompt
+ a.input.SetStyles(ts)
+ a.input.Prompt = styles.CheckIcon + " "
+ a.input.Blur()
+ case APIKeyInputStateError:
+ a.title = errorStyle.Render("Invalid ") + accentStyle.Render(a.providerName+" API Key") + errorStyle.Render(". Try again?")
+ ts := t.S().TextInput
+ ts.Focused.Prompt = ts.Focused.Prompt.Foreground(t.Cherry)
+ a.input.Focus()
+ a.input.SetStyles(ts)
+ a.input.Prompt = styles.ErrorIcon + " "
+ }
+}
+func (a *APIKeyInput) View() string {
inputView := a.input.View()
dataPath := config.GlobalConfigData()
dataPath = strings.Replace(dataPath, config.HomeDir(), "~", 1)
- helpText := t.S().Muted.
+ helpText := styles.CurrentTheme().S().Muted.
Render(fmt.Sprintf("This will be written to the global configuration: %s", dataPath))
- content := lipgloss.JoinVertical(
- lipgloss.Left,
- title,
- "",
- inputView,
- "",
- helpText,
- )
+ var content string
+ if a.showTitle && a.title != "" {
+ content = lipgloss.JoinVertical(
+ lipgloss.Left,
+ a.title,
+ "",
+ inputView,
+ "",
+ helpText,
+ )
+ } else {
+ content = lipgloss.JoinVertical(
+ lipgloss.Left,
+ inputView,
+ "",
+ helpText,
+ )
+ }
return content
}
func (a *APIKeyInput) Cursor() *tea.Cursor {
cursor := a.input.Cursor()
- if cursor != nil {
+ if cursor != nil && a.showTitle {
cursor.Y += 2 // Adjust for title and spacing
}
return cursor
@@ -94,3 +181,22 @@ func (a *APIKeyInput) Cursor() *tea.Cursor {
func (a *APIKeyInput) Value() string {
return a.input.Value()
}
+
+func (a *APIKeyInput) Tick() tea.Cmd {
+ if a.state == APIKeyInputStateVerifying {
+ return a.spinner.Tick
+ }
+ return nil
+}
+
+func (a *APIKeyInput) SetWidth(width int) {
+ a.width = width
+ a.input.SetWidth(width - 4)
+}
+
+func (a *APIKeyInput) Reset() {
+ a.state = APIKeyInputStateInitial
+ a.input.SetValue("")
+ a.input.Focus()
+ a.updateStatePresentation()
+}
@@ -10,6 +10,9 @@ type KeyMap struct {
Previous,
Tab,
Close key.Binding
+
+ isAPIKeyHelp bool
+ isAPIKeyValid bool
}
func DefaultKeyMap() KeyMap {
@@ -61,6 +64,15 @@ func (k KeyMap) FullHelp() [][]key.Binding {
// ShortHelp implements help.KeyMap.
func (k KeyMap) ShortHelp() []key.Binding {
+ if k.isAPIKeyHelp && !k.isAPIKeyValid {
+ return []key.Binding{
+ k.Close,
+ }
+ } else if k.isAPIKeyValid {
+ return []key.Binding{
+ k.Select,
+ }
+ }
return []key.Binding{
key.NewBinding(
key.WithKeys("down", "up"),
@@ -1,8 +1,12 @@
package models
import (
+ "fmt"
+ "time"
+
"github.com/charmbracelet/bubbles/v2/help"
"github.com/charmbracelet/bubbles/v2/key"
+ "github.com/charmbracelet/bubbles/v2/spinner"
tea "github.com/charmbracelet/bubbletea/v2"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/fur/provider"
@@ -56,6 +60,14 @@ type modelDialogCmp struct {
modelList *ModelListComponent
keyMap KeyMap
help help.Model
+
+ // API key state
+ needsAPIKey bool
+ apiKeyInput *APIKeyInput
+ selectedModel *ModelOption
+ selectedModelType config.SelectedModelType
+ isAPIKeyValid bool
+ apiKeyValue string
}
func NewModelDialogCmp() ModelDialog {
@@ -75,19 +87,22 @@ func NewModelDialogCmp() ModelDialog {
t := styles.CurrentTheme()
inputStyle := t.S().Base.Padding(0, 1, 0, 1)
modelList := NewModelListComponent(listKeyMap, inputStyle, "Choose a model for large, complex tasks")
+ apiKeyInput := NewAPIKeyInput()
+ apiKeyInput.SetShowTitle(false)
help := help.New()
help.Styles = t.S().Help
return &modelDialogCmp{
- modelList: modelList,
- width: defaultWidth,
- keyMap: DefaultKeyMap(),
- help: help,
+ modelList: modelList,
+ apiKeyInput: apiKeyInput,
+ width: defaultWidth,
+ keyMap: DefaultKeyMap(),
+ help: help,
}
}
func (m *modelDialogCmp) Init() tea.Cmd {
- return m.modelList.Init()
+ return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init())
}
func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -95,10 +110,58 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
m.wWidth = msg.Width
m.wHeight = msg.Height
+ m.apiKeyInput.SetWidth(m.width - 2)
+ m.help.Width = m.width - 2
return m, m.modelList.SetSize(m.listWidth(), m.listHeight())
+ case APIKeyStateChangeMsg:
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ return m, cmd
case tea.KeyPressMsg:
switch {
case key.Matches(msg, m.keyMap.Select):
+ if m.isAPIKeyValid {
+ return m, m.saveAPIKeyAndContinue(m.apiKeyValue)
+ }
+ if m.needsAPIKey {
+ // Handle API key submission
+ m.apiKeyValue = m.apiKeyInput.Value()
+ provider, err := m.getProvider(m.selectedModel.Provider.ID)
+ if err != nil || provider == nil {
+ return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID))
+ }
+ providerConfig := config.ProviderConfig{
+ ID: string(m.selectedModel.Provider.ID),
+ Name: m.selectedModel.Provider.Name,
+ APIKey: m.apiKeyValue,
+ Type: provider.Type,
+ BaseURL: provider.APIEndpoint,
+ }
+ return m, tea.Sequence(
+ util.CmdHandler(APIKeyStateChangeMsg{
+ State: APIKeyInputStateVerifying,
+ }),
+ func() tea.Msg {
+ start := time.Now()
+ err := providerConfig.TestConnection(config.Get().Resolver())
+ // intentionally wait for at least 750ms to make sure the user sees the spinner
+ elapsed := time.Since(start)
+ if elapsed < 750*time.Millisecond {
+ time.Sleep(750*time.Millisecond - elapsed)
+ }
+ if err == nil {
+ m.isAPIKeyValid = true
+ return APIKeyStateChangeMsg{
+ State: APIKeyInputStateVerified,
+ }
+ }
+ return APIKeyStateChangeMsg{
+ State: APIKeyInputStateError,
+ }
+ },
+ )
+ }
+ // Normal model selection
selectedItemInx := m.modelList.SelectedIndex()
if selectedItemInx == list.NoSelection {
return m, nil
@@ -113,17 +176,32 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
modelType = config.SelectedModelTypeSmall
}
- return m, tea.Sequence(
- util.CmdHandler(dialogs.CloseDialogMsg{}),
- util.CmdHandler(ModelSelectedMsg{
- Model: config.SelectedModel{
- Model: selectedItem.Model.ID,
- Provider: string(selectedItem.Provider.ID),
- },
- ModelType: modelType,
- }),
- )
+ // Check if provider is configured
+ if m.isProviderConfigured(string(selectedItem.Provider.ID)) {
+ return m, tea.Sequence(
+ util.CmdHandler(dialogs.CloseDialogMsg{}),
+ util.CmdHandler(ModelSelectedMsg{
+ Model: config.SelectedModel{
+ Model: selectedItem.Model.ID,
+ Provider: string(selectedItem.Provider.ID),
+ },
+ ModelType: modelType,
+ }),
+ )
+ } else {
+ // Provider not configured, show API key input
+ m.needsAPIKey = true
+ m.selectedModel = &selectedItem
+ m.selectedModelType = modelType
+ m.apiKeyInput.SetProviderName(selectedItem.Provider.Name)
+ return m, nil
+ }
case key.Matches(msg, m.keyMap.Tab):
+ if m.needsAPIKey {
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ return m, cmd
+ }
if m.modelList.GetModelType() == LargeModelType {
m.modelList.SetInputPlaceholder(smallModelInputPlaceholder)
return m, m.modelList.SetModelType(SmallModelType)
@@ -132,18 +210,68 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, m.modelList.SetModelType(LargeModelType)
}
case key.Matches(msg, m.keyMap.Close):
+ if m.needsAPIKey {
+ if m.isAPIKeyValid {
+ return m, nil
+ }
+ // Go back to model selection
+ m.needsAPIKey = false
+ m.selectedModel = nil
+ m.isAPIKeyValid = false
+ m.apiKeyValue = ""
+ m.apiKeyInput.Reset()
+ return m, nil
+ }
return m, util.CmdHandler(dialogs.CloseDialogMsg{})
default:
- u, cmd := m.modelList.Update(msg)
- m.modelList = u
+ if m.needsAPIKey {
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ return m, cmd
+ } else {
+ u, cmd := m.modelList.Update(msg)
+ m.modelList = u
+ return m, cmd
+ }
+ }
+ case tea.PasteMsg:
+ if m.needsAPIKey {
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ return m, cmd
+ } else {
+ var cmd tea.Cmd
+ m.modelList, cmd = m.modelList.Update(msg)
return m, cmd
}
+ case spinner.TickMsg:
+ u, cmd := m.apiKeyInput.Update(msg)
+ m.apiKeyInput = u.(*APIKeyInput)
+ return m, cmd
}
return m, nil
}
func (m *modelDialogCmp) View() string {
t := styles.CurrentTheme()
+
+ if m.needsAPIKey {
+ // Show API key input
+ m.keyMap.isAPIKeyHelp = true
+ m.keyMap.isAPIKeyValid = m.isAPIKeyValid
+ apiKeyView := m.apiKeyInput.View()
+ apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView)
+ content := lipgloss.JoinVertical(
+ lipgloss.Left,
+ t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)),
+ apiKeyView,
+ "",
+ t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)),
+ )
+ return m.style().Render(content)
+ }
+
+ // Show model selection
listView := m.modelList.View()
radio := m.modelTypeRadio()
content := lipgloss.JoinVertical(
@@ -157,10 +285,18 @@ func (m *modelDialogCmp) View() string {
}
func (m *modelDialogCmp) Cursor() *tea.Cursor {
- cursor := m.modelList.Cursor()
- if cursor != nil {
- cursor = m.moveCursor(cursor)
- return cursor
+ if m.needsAPIKey {
+ cursor := m.apiKeyInput.Cursor()
+ if cursor != nil {
+ cursor = m.moveCursor(cursor)
+ return cursor
+ }
+ } else {
+ cursor := m.modelList.Cursor()
+ if cursor != nil {
+ cursor = m.moveCursor(cursor)
+ return cursor
+ }
}
return nil
}
@@ -192,9 +328,15 @@ func (m *modelDialogCmp) Position() (int, int) {
func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor {
row, col := m.Position()
- offset := row + 3 // Border + title
- cursor.Y += offset
- cursor.X = cursor.X + col + 2
+ if m.needsAPIKey {
+ offset := row + 3 // Border + title + API key input offset
+ cursor.Y += offset
+ cursor.X = cursor.X + col + 2
+ } else {
+ offset := row + 3 // Border + title
+ cursor.Y += offset
+ cursor.X = cursor.X + col + 2
+ }
return cursor
}
@@ -212,3 +354,49 @@ func (m *modelDialogCmp) modelTypeRadio() string {
}
return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1])
}
+
+func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
+ cfg := config.Get()
+ if _, ok := cfg.Providers[providerID]; ok {
+ return true
+ }
+ return false
+}
+
+func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
+ providers, err := config.Providers()
+ if err != nil {
+ return nil, err
+ }
+ for _, p := range providers {
+ if p.ID == providerID {
+ return &p, nil
+ }
+ }
+ return nil, nil
+}
+
+func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd {
+ if m.selectedModel == nil {
+ return util.ReportError(fmt.Errorf("no model selected"))
+ }
+
+ cfg := config.Get()
+ err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey)
+ if err != nil {
+ return util.ReportError(fmt.Errorf("failed to save API key: %w", err))
+ }
+
+ // Reset API key state and continue with model selection
+ selectedModel := *m.selectedModel
+ return tea.Sequence(
+ util.CmdHandler(dialogs.CloseDialogMsg{}),
+ util.CmdHandler(ModelSelectedMsg{
+ Model: config.SelectedModel{
+ Model: selectedModel.Model.ID,
+ Provider: string(selectedModel.Provider.ID),
+ },
+ ModelType: m.selectedModelType,
+ }),
+ )
+}
@@ -25,6 +25,7 @@ import (
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
"github.com/charmbracelet/crush/internal/tui/components/dialogs/filepicker"
+ "github.com/charmbracelet/crush/internal/tui/components/dialogs/models"
"github.com/charmbracelet/crush/internal/tui/page"
"github.com/charmbracelet/crush/internal/tui/styles"
"github.com/charmbracelet/crush/internal/tui/util"
@@ -172,6 +173,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return p, p.sendMessage(msg.Text, msg.Attachments)
case chat.SessionSelectedMsg:
return p, p.setSession(msg)
+ case splash.SubmitAPIKeyMsg:
+ u, cmd := p.splash.Update(msg)
+ p.splash = u.(splash.Splash)
+ cmds = append(cmds, cmd)
+ return p, tea.Batch(cmds...)
case commands.ToggleCompactModeMsg:
p.forceCompact = !p.forceCompact
var cmd tea.Cmd
@@ -212,12 +218,25 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, cmd)
return p, tea.Batch(cmds...)
+ case models.APIKeyStateChangeMsg:
+ if p.focusedPane == PanelTypeSplash {
+ u, cmd := p.splash.Update(msg)
+ p.splash = u.(splash.Splash)
+ cmds = append(cmds, cmd)
+ }
+ return p, tea.Batch(cmds...)
case pubsub.Event[message.Message],
anim.StepMsg,
spinner.TickMsg:
- u, cmd := p.chat.Update(msg)
- p.chat = u.(chat.MessageListCmp)
- cmds = append(cmds, cmd)
+ if p.focusedPane == PanelTypeSplash {
+ u, cmd := p.splash.Update(msg)
+ p.splash = u.(splash.Splash)
+ cmds = append(cmds, cmd)
+ } else {
+ u, cmd := p.chat.Update(msg)
+ p.chat = u.(chat.MessageListCmp)
+ cmds = append(cmds, cmd)
+ }
return p, tea.Batch(cmds...)
case pubsub.Event[history.File], sidebar.SessionFilesMsg:
@@ -655,12 +674,23 @@ func (p *chatPage) Help() help.KeyMap {
fullList = append(fullList, []key.Binding{v})
}
case p.isOnboarding && p.splash.IsShowingAPIKey():
+ if p.splash.IsAPIKeyValid() {
+ shortList = append(shortList,
+ key.NewBinding(
+ key.WithKeys("enter"),
+ key.WithHelp("enter", "continue"),
+ ),
+ )
+ } else {
+ shortList = append(shortList,
+ // Go back
+ key.NewBinding(
+ key.WithKeys("esc"),
+ key.WithHelp("esc", "back"),
+ ),
+ )
+ }
shortList = append(shortList,
- // Go back
- key.NewBinding(
- key.WithKeys("esc"),
- key.WithHelp("esc", "back"),
- ),
// Quit
key.NewBinding(
key.WithKeys("ctrl+c"),
@@ -52,5 +52,6 @@ func NewCrushTheme() *Theme {
Red: charmtone.Coral,
RedDark: charmtone.Sriracha,
RedLight: charmtone.Salmon,
+ Cherry: charmtone.Cherry,
}
}
@@ -72,6 +72,7 @@ type Theme struct {
Red color.Color
RedDark color.Color
RedLight color.Color
+ Cherry color.Color
styles *Styles
}
@@ -150,15 +151,15 @@ func (t *Theme) buildStyles() *Styles {
TextInput: textinput.Styles{
Focused: textinput.StyleState{
Text: base,
- Placeholder: base.Foreground(t.FgMuted),
+ Placeholder: base.Foreground(t.FgSubtle),
Prompt: base.Foreground(t.Tertiary),
- Suggestion: base.Foreground(t.FgMuted),
+ Suggestion: base.Foreground(t.FgSubtle),
},
Blurred: textinput.StyleState{
Text: base.Foreground(t.FgMuted),
- Placeholder: base.Foreground(t.FgMuted),
+ Placeholder: base.Foreground(t.FgSubtle),
Prompt: base.Foreground(t.FgMuted),
- Suggestion: base.Foreground(t.FgMuted),
+ Suggestion: base.Foreground(t.FgSubtle),
},
Cursor: textinput.CursorStyle{
Color: t.Secondary,
@@ -173,7 +174,7 @@ func (t *Theme) buildStyles() *Styles {
LineNumber: base.Foreground(t.FgSubtle),
CursorLine: base,
CursorLineNumber: base.Foreground(t.FgSubtle),
- Placeholder: base.Foreground(t.FgMuted),
+ Placeholder: base.Foreground(t.FgSubtle),
Prompt: base.Foreground(t.Tertiary),
},
Blurred: textarea.StyleState{
@@ -182,7 +183,7 @@ func (t *Theme) buildStyles() *Styles {
LineNumber: base.Foreground(t.FgMuted),
CursorLine: base,
CursorLineNumber: base.Foreground(t.FgMuted),
- Placeholder: base.Foreground(t.FgMuted),
+ Placeholder: base.Foreground(t.FgSubtle),
Prompt: base.Foreground(t.FgMuted),
},
Cursor: textarea.CursorStyle{
@@ -12,6 +12,7 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat"
+ "github.com/charmbracelet/crush/internal/tui/components/chat/splash"
"github.com/charmbracelet/crush/internal/tui/components/completions"
"github.com/charmbracelet/crush/internal/tui/components/core"
"github.com/charmbracelet/crush/internal/tui/components/core/layout"
@@ -48,8 +49,9 @@ type appModel struct {
app *app.App
- dialog dialogs.DialogCmp
- completions completions.Completions
+ dialog dialogs.DialogCmp
+ completions completions.Completions
+ isConfigured bool
// Chat Page Specific
selectedSessionID string // The ID of the currently selected session
@@ -72,6 +74,7 @@ func (a appModel) Init() tea.Cmd {
func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
var cmd tea.Cmd
+ a.isConfigured = config.HasInitialDataConfig()
switch msg := msg.(type) {
case tea.KeyboardEnhancementsMsg:
@@ -223,10 +226,28 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
+ return a, tea.Batch(cmds...)
+ case splash.OnboardingCompleteMsg:
+ a.isConfigured = config.HasInitialDataConfig()
+ updated, cmd := a.pages[a.currentPage].Update(msg)
+ a.pages[a.currentPage] = updated.(util.Model)
+ cmds = append(cmds, cmd)
return a, tea.Batch(cmds...)
// Key Press Messages
case tea.KeyPressMsg:
return a, a.handleKeyPressMsg(msg)
+
+ case tea.PasteMsg:
+ if a.dialog.HasDialogs() {
+ u, dialogCmd := a.dialog.Update(msg)
+ a.dialog = u.(dialogs.DialogCmp)
+ cmds = append(cmds, dialogCmd)
+ } else {
+ updated, cmd := a.pages[a.currentPage].Update(msg)
+ a.pages[a.currentPage] = updated.(util.Model)
+ cmds = append(cmds, cmd)
+ }
+ return a, tea.Batch(cmds...)
}
s, _ := a.status.Update(msg)
a.status = s.(status.StatusCmp)
@@ -307,6 +328,10 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
})
case key.Matches(msg, a.keyMap.Commands):
+ // if the app is not configured show no commands
+ if !a.isConfigured {
+ return nil
+ }
if a.dialog.ActiveDialogID() == commands.CommandsDialogID {
return util.CmdHandler(dialogs.CloseDialogMsg{})
}
@@ -317,6 +342,10 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
Model: commands.NewCommandDialog(a.selectedSessionID),
})
case key.Matches(msg, a.keyMap.Sessions):
+ // if the app is not configured show no sessions
+ if !a.isConfigured {
+ return nil
+ }
if a.dialog.ActiveDialogID() == sessions.SessionsDialogID {
return util.CmdHandler(dialogs.CloseDialogMsg{})
}