diff --git a/internal/config/config.go b/internal/config/config.go index d63a34f73d5210c2542be8a598ef38cb06339bd9..f4467f71a86d027298c8adc5dabc872a02710d0c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,13 @@ package config import ( + "context" "fmt" + "net/http" "os" "slices" "strings" + "time" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/provider" @@ -435,3 +438,52 @@ func (c *Config) SetupAgents() { } c.Agents = agents } + +func (c *Config) Resolver() VariableResolver { + return c.resolver +} + +func (c *ProviderConfig) TestConnection(resolver VariableResolver) error { + testURL := "" + headers := make(map[string]string) + apiKey, _ := resolver.ResolveValue(c.APIKey) + switch c.Type { + case provider.TypeOpenAI: + baseURL, _ := resolver.ResolveValue(c.BaseURL) + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + testURL = baseURL + "/models" + headers["Authorization"] = "Bearer " + apiKey + case provider.TypeAnthropic: + baseURL, _ := resolver.ResolveValue(c.BaseURL) + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + testURL = baseURL + "/models" + headers["x-api-key"] = apiKey + headers["anthropic-version"] = "2023-06-01" + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client := &http.Client{} + req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil) + if err != nil { + return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + for k, v := range c.ExtraHeaders { + req.Header.Set(k, v) + } + b, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err) + } + if b.StatusCode != http.StatusOK { + return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status) + } + _ = b.Body.Close() + return nil +} diff --git a/internal/config/init.go b/internal/config/init.go index 12b30efd75f88d438e0734571cbb5c634ba231bc..827a287718e40e1fc5b9b761293c00799ec5ef3d 100644 --- a/internal/config/init.go +++ b/internal/config/init.go @@ -109,5 +109,5 @@ func HasInitialDataConfig() bool { if _, err := os.Stat(cfgPath); err != nil { return false } - return true + return Get().IsConfigured() } diff --git a/internal/config/load.go b/internal/config/load.go index e056847aeeda476e819439384a16e0e237b067e1..f481be240e9d82520cef6c9f75210d8cbd1a0776 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -9,6 +9,7 @@ import ( "runtime" "slices" "strings" + "sync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fur/client" @@ -75,6 +76,35 @@ func Load(workingDir string, debug bool) (*Config, error) { return nil, fmt.Errorf("failed to configure providers: %w", err) } + // Test provider connections in parallel + var testResults sync.Map + var wg sync.WaitGroup + + for _, p := range cfg.Providers { + if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic { + wg.Add(1) + go func(provider ProviderConfig) { + defer wg.Done() + err := provider.TestConnection(cfg.resolver) + testResults.Store(provider.ID, err == nil) + if err != nil { + slog.Error("Provider connection test failed", "provider", provider.ID, "error", err) + } + }(p) + } + } + wg.Wait() + + // Remove failed providers + testResults.Range(func(key, value any) bool { + providerID := key.(string) + passed := value.(bool) + if !passed { + delete(cfg.Providers, providerID) + } + return true + }) + if !cfg.IsConfigured() { slog.Warn("No providers configured") return cfg, nil diff --git a/internal/tui/components/chat/splash/splash.go b/internal/tui/components/chat/splash/splash.go index 0ba04f6d16f2b93ac5556cd204fe63bcca5594e2..ed5f071a2d68843f89e9f5c92255ce705984e3f1 100644 --- a/internal/tui/components/chat/splash/splash.go +++ b/internal/tui/components/chat/splash/splash.go @@ -5,8 +5,10 @@ import ( "os" "slices" "strings" + "time" "github.com/charmbracelet/bubbles/v2/key" + "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" @@ -36,6 +38,9 @@ type Splash interface { // Showing API key input IsShowingAPIKey() bool + + // IsAPIKeyValid returns whether the API key is valid + IsAPIKeyValid() bool } const ( @@ -45,7 +50,10 @@ const ( ) // OnboardingCompleteMsg is sent when onboarding is complete -type OnboardingCompleteMsg struct{} +type ( + OnboardingCompleteMsg struct{} + SubmitAPIKeyMsg struct{} +) type splashCmp struct { width, height int @@ -62,6 +70,8 @@ type splashCmp struct { modelList *models.ModelListComponent apiKeyInput *models.APIKeyInput selectedModel *models.ModelOption + isAPIKeyValid bool + apiKeyValue string } func New() Splash { @@ -141,6 +151,7 @@ func (s *splashCmp) SetSize(width int, height int) tea.Cmd { // remove padding, logo height, gap, title space s.listHeight = s.height - lipgloss.Height(s.logoRendered) - (SplashScreenPaddingY * 2) - s.logoGap() - 2 listWidth := min(60, width) + s.apiKeyInput.SetWidth(width - 2) return s.modelList.SetSize(listWidth, s.listHeight) } @@ -149,16 +160,38 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case tea.WindowSizeMsg: return s, s.SetSize(msg.Width, msg.Height) + case models.APIKeyStateChangeMsg: + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + if msg.State == models.APIKeyInputStateVerified { + return s, tea.Tick(5*time.Second, func(t time.Time) tea.Msg { + return SubmitAPIKeyMsg{} + }) + } + return s, cmd + case SubmitAPIKeyMsg: + if s.isAPIKeyValid { + return s, s.saveAPIKeyAndContinue(s.apiKeyValue) + } case tea.KeyPressMsg: switch { case key.Matches(msg, s.keyMap.Back): + if s.isAPIKeyValid { + return s, nil + } if s.needsAPIKey { // Go back to model selection s.needsAPIKey = false s.selectedModel = nil + s.isAPIKeyValid = false + s.apiKeyValue = "" + s.apiKeyInput.Reset() return s, nil } case key.Matches(msg, s.keyMap.Select): + if s.isAPIKeyValid { + return s, s.saveAPIKeyAndContinue(s.apiKeyValue) + } if s.isOnboarding && !s.needsAPIKey { modelInx := s.modelList.SelectedIndex() items := s.modelList.Items() @@ -176,23 +209,75 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } else if s.needsAPIKey { // Handle API key submission - apiKey := s.apiKeyInput.Value() - if apiKey != "" { - return s, s.saveAPIKeyAndContinue(apiKey) + s.apiKeyValue = strings.TrimSpace(s.apiKeyInput.Value()) + if s.apiKeyValue == "" { + return s, nil } + + provider, err := s.getProvider(s.selectedModel.Provider.ID) + if err != nil || provider == nil { + return s, util.ReportError(fmt.Errorf("provider %s not found", s.selectedModel.Provider.ID)) + } + providerConfig := config.ProviderConfig{ + ID: string(s.selectedModel.Provider.ID), + Name: s.selectedModel.Provider.Name, + APIKey: s.apiKeyValue, + Type: provider.Type, + BaseURL: provider.APIEndpoint, + } + return s, tea.Sequence( + util.CmdHandler(models.APIKeyStateChangeMsg{ + State: models.APIKeyInputStateVerifying, + }), + func() tea.Msg { + start := time.Now() + err := providerConfig.TestConnection(config.Get().Resolver()) + // intentionally wait for at least 750ms to make sure the user sees the spinner + elapsed := time.Since(start) + if elapsed < 750*time.Millisecond { + time.Sleep(750*time.Millisecond - elapsed) + } + if err == nil { + s.isAPIKeyValid = true + return models.APIKeyStateChangeMsg{ + State: models.APIKeyInputStateVerified, + } + } + return models.APIKeyStateChangeMsg{ + State: models.APIKeyInputStateError, + } + }, + ) } else if s.needsProjectInit { return s, s.initializeProject() } case key.Matches(msg, s.keyMap.Tab, s.keyMap.LeftRight): + if s.needsAPIKey { + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + return s, cmd + } if s.needsProjectInit { s.selectedNo = !s.selectedNo return s, nil } case key.Matches(msg, s.keyMap.Yes): + if s.needsAPIKey { + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + return s, cmd + } + if s.needsProjectInit { return s, s.initializeProject() } case key.Matches(msg, s.keyMap.No): + if s.needsAPIKey { + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + return s, cmd + } + s.selectedNo = true return s, s.initializeProject() default: @@ -216,6 +301,10 @@ func (s *splashCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { s.modelList, cmd = s.modelList.Update(msg) return s, cmd } + case spinner.TickMsg: + u, cmd := s.apiKeyInput.Update(msg) + s.apiKeyInput = u.(*models.APIKeyInput) + return s, cmd } return s, nil } @@ -629,3 +718,7 @@ func (s *splashCmp) mcpBlock() string { func (s *splashCmp) IsShowingAPIKey() bool { return s.needsAPIKey } + +func (s *splashCmp) IsAPIKeyValid() bool { + return s.isAPIKeyValid +} diff --git a/internal/tui/components/dialogs/models/apikey.go b/internal/tui/components/dialogs/models/apikey.go index d5aa034d133d2e4d5cbe676aed0fb7e1edde487c..10378ead072f01ed064fe1c48c97abd3c6feb175 100644 --- a/internal/tui/components/dialogs/models/apikey.go +++ b/internal/tui/components/dialogs/models/apikey.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/charmbracelet/bubbles/v2/spinner" "github.com/charmbracelet/bubbles/v2/textinput" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" @@ -11,11 +12,27 @@ import ( "github.com/charmbracelet/lipgloss/v2" ) +type APIKeyInputState int + +const ( + APIKeyInputStateInitial APIKeyInputState = iota + APIKeyInputStateVerifying + APIKeyInputStateVerified + APIKeyInputStateError +) + +type APIKeyStateChangeMsg struct { + State APIKeyInputState +} + type APIKeyInput struct { input textinput.Model width int - height int + spinner spinner.Model providerName string + state APIKeyInputState + title string + showTitle bool } func NewAPIKeyInput() *APIKeyInput { @@ -23,32 +40,59 @@ func NewAPIKeyInput() *APIKeyInput { ti := textinput.New() ti.Placeholder = "Enter your API key..." - ti.SetWidth(50) ti.SetVirtualCursor(false) ti.Prompt = "> " ti.SetStyles(t.S().TextInput) ti.Focus() return &APIKeyInput{ - input: ti, - width: 60, + input: ti, + state: APIKeyInputStateInitial, + spinner: spinner.New( + spinner.WithSpinner(spinner.Dot), + spinner.WithStyle(t.S().Base.Foreground(t.Green)), + ), providerName: "Provider", + showTitle: true, } } func (a *APIKeyInput) SetProviderName(name string) { a.providerName = name + a.updateStatePresentation() +} + +func (a *APIKeyInput) SetShowTitle(show bool) { + a.showTitle = show +} + +func (a *APIKeyInput) GetTitle() string { + return a.title } func (a *APIKeyInput) Init() tea.Cmd { - return textinput.Blink + a.updateStatePresentation() + return a.spinner.Tick } func (a *APIKeyInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { - case tea.WindowSizeMsg: - a.width = msg.Width - a.height = msg.Height + case spinner.TickMsg: + if a.state == APIKeyInputStateVerifying { + var cmd tea.Cmd + a.spinner, cmd = a.spinner.Update(msg) + a.updateStatePresentation() + return a, cmd + } + return a, nil + case APIKeyStateChangeMsg: + a.state = msg.State + var cmd tea.Cmd + if msg.State == APIKeyInputStateVerifying { + cmd = a.spinner.Tick + } + a.updateStatePresentation() + return a, cmd } var cmd tea.Cmd @@ -56,36 +100,79 @@ func (a *APIKeyInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return a, cmd } -func (a *APIKeyInput) View() string { +func (a *APIKeyInput) updateStatePresentation() { t := styles.CurrentTheme() - title := t.S().Base. - Foreground(t.Primary). - Bold(true). - Render(fmt.Sprintf("Enter your %s API Key", a.providerName)) + prefixStyle := t.S().Base. + Foreground(t.Primary) + accentStyle := t.S().Base.Foreground(t.Green).Bold(true) + errorStyle := t.S().Base.Foreground(t.Cherry) + + switch a.state { + case APIKeyInputStateInitial: + titlePrefix := prefixStyle.Render("Enter your ") + a.title = titlePrefix + accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render(".") + a.input.SetStyles(t.S().TextInput) + a.input.Prompt = "> " + case APIKeyInputStateVerifying: + titlePrefix := prefixStyle.Render("Verifying your ") + a.title = titlePrefix + accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render("...") + ts := t.S().TextInput + // make the blurred state be the same + ts.Blurred.Prompt = ts.Focused.Prompt + a.input.Prompt = a.spinner.View() + a.input.Blur() + case APIKeyInputStateVerified: + a.title = accentStyle.Render(a.providerName+" API Key") + prefixStyle.Render(" validated.") + ts := t.S().TextInput + // make the blurred state be the same + ts.Blurred.Prompt = ts.Focused.Prompt + a.input.SetStyles(ts) + a.input.Prompt = styles.CheckIcon + " " + a.input.Blur() + case APIKeyInputStateError: + a.title = errorStyle.Render("Invalid ") + accentStyle.Render(a.providerName+" API Key") + errorStyle.Render(". Try again?") + ts := t.S().TextInput + ts.Focused.Prompt = ts.Focused.Prompt.Foreground(t.Cherry) + a.input.Focus() + a.input.SetStyles(ts) + a.input.Prompt = styles.ErrorIcon + " " + } +} +func (a *APIKeyInput) View() string { inputView := a.input.View() dataPath := config.GlobalConfigData() dataPath = strings.Replace(dataPath, config.HomeDir(), "~", 1) - helpText := t.S().Muted. + helpText := styles.CurrentTheme().S().Muted. Render(fmt.Sprintf("This will be written to the global configuration: %s", dataPath)) - content := lipgloss.JoinVertical( - lipgloss.Left, - title, - "", - inputView, - "", - helpText, - ) + var content string + if a.showTitle && a.title != "" { + content = lipgloss.JoinVertical( + lipgloss.Left, + a.title, + "", + inputView, + "", + helpText, + ) + } else { + content = lipgloss.JoinVertical( + lipgloss.Left, + inputView, + "", + helpText, + ) + } return content } func (a *APIKeyInput) Cursor() *tea.Cursor { cursor := a.input.Cursor() - if cursor != nil { + if cursor != nil && a.showTitle { cursor.Y += 2 // Adjust for title and spacing } return cursor @@ -94,3 +181,22 @@ func (a *APIKeyInput) Cursor() *tea.Cursor { func (a *APIKeyInput) Value() string { return a.input.Value() } + +func (a *APIKeyInput) Tick() tea.Cmd { + if a.state == APIKeyInputStateVerifying { + return a.spinner.Tick + } + return nil +} + +func (a *APIKeyInput) SetWidth(width int) { + a.width = width + a.input.SetWidth(width - 4) +} + +func (a *APIKeyInput) Reset() { + a.state = APIKeyInputStateInitial + a.input.SetValue("") + a.input.Focus() + a.updateStatePresentation() +} diff --git a/internal/tui/components/dialogs/models/keys.go b/internal/tui/components/dialogs/models/keys.go index bb70785172bac66d9fda905172572c881b2ecd35..df546863d87d3a68777e51938f58eee28a5c6473 100644 --- a/internal/tui/components/dialogs/models/keys.go +++ b/internal/tui/components/dialogs/models/keys.go @@ -10,6 +10,9 @@ type KeyMap struct { Previous, Tab, Close key.Binding + + isAPIKeyHelp bool + isAPIKeyValid bool } func DefaultKeyMap() KeyMap { @@ -61,6 +64,15 @@ func (k KeyMap) FullHelp() [][]key.Binding { // ShortHelp implements help.KeyMap. func (k KeyMap) ShortHelp() []key.Binding { + if k.isAPIKeyHelp && !k.isAPIKeyValid { + return []key.Binding{ + k.Close, + } + } else if k.isAPIKeyValid { + return []key.Binding{ + k.Select, + } + } return []key.Binding{ key.NewBinding( key.WithKeys("down", "up"), diff --git a/internal/tui/components/dialogs/models/models.go b/internal/tui/components/dialogs/models/models.go index a4cb9bd47e81229b343d65660174f843a98503a8..b28efc6010582a503c34e87ad101832925d8acca 100644 --- a/internal/tui/components/dialogs/models/models.go +++ b/internal/tui/components/dialogs/models/models.go @@ -1,8 +1,12 @@ package models import ( + "fmt" + "time" + "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" + "github.com/charmbracelet/bubbles/v2/spinner" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/fur/provider" @@ -56,6 +60,14 @@ type modelDialogCmp struct { modelList *ModelListComponent keyMap KeyMap help help.Model + + // API key state + needsAPIKey bool + apiKeyInput *APIKeyInput + selectedModel *ModelOption + selectedModelType config.SelectedModelType + isAPIKeyValid bool + apiKeyValue string } func NewModelDialogCmp() ModelDialog { @@ -75,19 +87,22 @@ func NewModelDialogCmp() ModelDialog { t := styles.CurrentTheme() inputStyle := t.S().Base.Padding(0, 1, 0, 1) modelList := NewModelListComponent(listKeyMap, inputStyle, "Choose a model for large, complex tasks") + apiKeyInput := NewAPIKeyInput() + apiKeyInput.SetShowTitle(false) help := help.New() help.Styles = t.S().Help return &modelDialogCmp{ - modelList: modelList, - width: defaultWidth, - keyMap: DefaultKeyMap(), - help: help, + modelList: modelList, + apiKeyInput: apiKeyInput, + width: defaultWidth, + keyMap: DefaultKeyMap(), + help: help, } } func (m *modelDialogCmp) Init() tea.Cmd { - return m.modelList.Init() + return tea.Batch(m.modelList.Init(), m.apiKeyInput.Init()) } func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -95,10 +110,58 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.wWidth = msg.Width m.wHeight = msg.Height + m.apiKeyInput.SetWidth(m.width - 2) + m.help.Width = m.width - 2 return m, m.modelList.SetSize(m.listWidth(), m.listHeight()) + case APIKeyStateChangeMsg: + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd case tea.KeyPressMsg: switch { case key.Matches(msg, m.keyMap.Select): + if m.isAPIKeyValid { + return m, m.saveAPIKeyAndContinue(m.apiKeyValue) + } + if m.needsAPIKey { + // Handle API key submission + m.apiKeyValue = m.apiKeyInput.Value() + provider, err := m.getProvider(m.selectedModel.Provider.ID) + if err != nil || provider == nil { + return m, util.ReportError(fmt.Errorf("provider %s not found", m.selectedModel.Provider.ID)) + } + providerConfig := config.ProviderConfig{ + ID: string(m.selectedModel.Provider.ID), + Name: m.selectedModel.Provider.Name, + APIKey: m.apiKeyValue, + Type: provider.Type, + BaseURL: provider.APIEndpoint, + } + return m, tea.Sequence( + util.CmdHandler(APIKeyStateChangeMsg{ + State: APIKeyInputStateVerifying, + }), + func() tea.Msg { + start := time.Now() + err := providerConfig.TestConnection(config.Get().Resolver()) + // intentionally wait for at least 750ms to make sure the user sees the spinner + elapsed := time.Since(start) + if elapsed < 750*time.Millisecond { + time.Sleep(750*time.Millisecond - elapsed) + } + if err == nil { + m.isAPIKeyValid = true + return APIKeyStateChangeMsg{ + State: APIKeyInputStateVerified, + } + } + return APIKeyStateChangeMsg{ + State: APIKeyInputStateError, + } + }, + ) + } + // Normal model selection selectedItemInx := m.modelList.SelectedIndex() if selectedItemInx == list.NoSelection { return m, nil @@ -113,17 +176,32 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { modelType = config.SelectedModelTypeSmall } - return m, tea.Sequence( - util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(ModelSelectedMsg{ - Model: config.SelectedModel{ - Model: selectedItem.Model.ID, - Provider: string(selectedItem.Provider.ID), - }, - ModelType: modelType, - }), - ) + // Check if provider is configured + if m.isProviderConfigured(string(selectedItem.Provider.ID)) { + return m, tea.Sequence( + util.CmdHandler(dialogs.CloseDialogMsg{}), + util.CmdHandler(ModelSelectedMsg{ + Model: config.SelectedModel{ + Model: selectedItem.Model.ID, + Provider: string(selectedItem.Provider.ID), + }, + ModelType: modelType, + }), + ) + } else { + // Provider not configured, show API key input + m.needsAPIKey = true + m.selectedModel = &selectedItem + m.selectedModelType = modelType + m.apiKeyInput.SetProviderName(selectedItem.Provider.Name) + return m, nil + } case key.Matches(msg, m.keyMap.Tab): + if m.needsAPIKey { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd + } if m.modelList.GetModelType() == LargeModelType { m.modelList.SetInputPlaceholder(smallModelInputPlaceholder) return m, m.modelList.SetModelType(SmallModelType) @@ -132,18 +210,68 @@ func (m *modelDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, m.modelList.SetModelType(LargeModelType) } case key.Matches(msg, m.keyMap.Close): + if m.needsAPIKey { + if m.isAPIKeyValid { + return m, nil + } + // Go back to model selection + m.needsAPIKey = false + m.selectedModel = nil + m.isAPIKeyValid = false + m.apiKeyValue = "" + m.apiKeyInput.Reset() + return m, nil + } return m, util.CmdHandler(dialogs.CloseDialogMsg{}) default: - u, cmd := m.modelList.Update(msg) - m.modelList = u + if m.needsAPIKey { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd + } else { + u, cmd := m.modelList.Update(msg) + m.modelList = u + return m, cmd + } + } + case tea.PasteMsg: + if m.needsAPIKey { + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd + } else { + var cmd tea.Cmd + m.modelList, cmd = m.modelList.Update(msg) return m, cmd } + case spinner.TickMsg: + u, cmd := m.apiKeyInput.Update(msg) + m.apiKeyInput = u.(*APIKeyInput) + return m, cmd } return m, nil } func (m *modelDialogCmp) View() string { t := styles.CurrentTheme() + + if m.needsAPIKey { + // Show API key input + m.keyMap.isAPIKeyHelp = true + m.keyMap.isAPIKeyValid = m.isAPIKeyValid + apiKeyView := m.apiKeyInput.View() + apiKeyView = t.S().Base.Width(m.width - 3).Height(lipgloss.Height(apiKeyView)).PaddingLeft(1).Render(apiKeyView) + content := lipgloss.JoinVertical( + lipgloss.Left, + t.S().Base.Padding(0, 1, 1, 1).Render(core.Title(m.apiKeyInput.GetTitle(), m.width-4)), + apiKeyView, + "", + t.S().Base.Width(m.width-2).PaddingLeft(1).AlignHorizontal(lipgloss.Left).Render(m.help.View(m.keyMap)), + ) + return m.style().Render(content) + } + + // Show model selection listView := m.modelList.View() radio := m.modelTypeRadio() content := lipgloss.JoinVertical( @@ -157,10 +285,18 @@ func (m *modelDialogCmp) View() string { } func (m *modelDialogCmp) Cursor() *tea.Cursor { - cursor := m.modelList.Cursor() - if cursor != nil { - cursor = m.moveCursor(cursor) - return cursor + if m.needsAPIKey { + cursor := m.apiKeyInput.Cursor() + if cursor != nil { + cursor = m.moveCursor(cursor) + return cursor + } + } else { + cursor := m.modelList.Cursor() + if cursor != nil { + cursor = m.moveCursor(cursor) + return cursor + } } return nil } @@ -192,9 +328,15 @@ func (m *modelDialogCmp) Position() (int, int) { func (m *modelDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { row, col := m.Position() - offset := row + 3 // Border + title - cursor.Y += offset - cursor.X = cursor.X + col + 2 + if m.needsAPIKey { + offset := row + 3 // Border + title + API key input offset + cursor.Y += offset + cursor.X = cursor.X + col + 2 + } else { + offset := row + 3 // Border + title + cursor.Y += offset + cursor.X = cursor.X + col + 2 + } return cursor } @@ -212,3 +354,49 @@ func (m *modelDialogCmp) modelTypeRadio() string { } return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) } + +func (m *modelDialogCmp) isProviderConfigured(providerID string) bool { + cfg := config.Get() + if _, ok := cfg.Providers[providerID]; ok { + return true + } + return false +} + +func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) { + providers, err := config.Providers() + if err != nil { + return nil, err + } + for _, p := range providers { + if p.ID == providerID { + return &p, nil + } + } + return nil, nil +} + +func (m *modelDialogCmp) saveAPIKeyAndContinue(apiKey string) tea.Cmd { + if m.selectedModel == nil { + return util.ReportError(fmt.Errorf("no model selected")) + } + + cfg := config.Get() + err := cfg.SetProviderAPIKey(string(m.selectedModel.Provider.ID), apiKey) + if err != nil { + return util.ReportError(fmt.Errorf("failed to save API key: %w", err)) + } + + // Reset API key state and continue with model selection + selectedModel := *m.selectedModel + return tea.Sequence( + util.CmdHandler(dialogs.CloseDialogMsg{}), + util.CmdHandler(ModelSelectedMsg{ + Model: config.SelectedModel{ + Model: selectedModel.Model.ID, + Provider: string(selectedModel.Provider.ID), + }, + ModelType: m.selectedModelType, + }), + ) +} diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index e55e965df2cbe375143a2e9cdf0c2c3252338f95..0d28f13f3ca0a42c9ae15612f21678cdeb8f4bf2 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -25,6 +25,7 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/core/layout" "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" "github.com/charmbracelet/crush/internal/tui/components/dialogs/filepicker" + "github.com/charmbracelet/crush/internal/tui/components/dialogs/models" "github.com/charmbracelet/crush/internal/tui/page" "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" @@ -172,6 +173,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return p, p.sendMessage(msg.Text, msg.Attachments) case chat.SessionSelectedMsg: return p, p.setSession(msg) + case splash.SubmitAPIKeyMsg: + u, cmd := p.splash.Update(msg) + p.splash = u.(splash.Splash) + cmds = append(cmds, cmd) + return p, tea.Batch(cmds...) case commands.ToggleCompactModeMsg: p.forceCompact = !p.forceCompact var cmd tea.Cmd @@ -212,12 +218,25 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd) return p, tea.Batch(cmds...) + case models.APIKeyStateChangeMsg: + if p.focusedPane == PanelTypeSplash { + u, cmd := p.splash.Update(msg) + p.splash = u.(splash.Splash) + cmds = append(cmds, cmd) + } + return p, tea.Batch(cmds...) case pubsub.Event[message.Message], anim.StepMsg, spinner.TickMsg: - u, cmd := p.chat.Update(msg) - p.chat = u.(chat.MessageListCmp) - cmds = append(cmds, cmd) + if p.focusedPane == PanelTypeSplash { + u, cmd := p.splash.Update(msg) + p.splash = u.(splash.Splash) + cmds = append(cmds, cmd) + } else { + u, cmd := p.chat.Update(msg) + p.chat = u.(chat.MessageListCmp) + cmds = append(cmds, cmd) + } return p, tea.Batch(cmds...) case pubsub.Event[history.File], sidebar.SessionFilesMsg: @@ -655,12 +674,23 @@ func (p *chatPage) Help() help.KeyMap { fullList = append(fullList, []key.Binding{v}) } case p.isOnboarding && p.splash.IsShowingAPIKey(): + if p.splash.IsAPIKeyValid() { + shortList = append(shortList, + key.NewBinding( + key.WithKeys("enter"), + key.WithHelp("enter", "continue"), + ), + ) + } else { + shortList = append(shortList, + // Go back + key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "back"), + ), + ) + } shortList = append(shortList, - // Go back - key.NewBinding( - key.WithKeys("esc"), - key.WithHelp("esc", "back"), - ), // Quit key.NewBinding( key.WithKeys("ctrl+c"), diff --git a/internal/tui/styles/crush.go b/internal/tui/styles/crush.go index 48911d7096f7b0d104b0361ae5f6632c9658e536..2c54d5e41c91521b9418cdcdd4bcbc5dc7231eee 100644 --- a/internal/tui/styles/crush.go +++ b/internal/tui/styles/crush.go @@ -52,5 +52,6 @@ func NewCrushTheme() *Theme { Red: charmtone.Coral, RedDark: charmtone.Sriracha, RedLight: charmtone.Salmon, + Cherry: charmtone.Cherry, } } diff --git a/internal/tui/styles/theme.go b/internal/tui/styles/theme.go index b91b7b32bcc599a64a07802f6641ddbaeff6d4e3..1d6967684c6ccb5c8f9db2dd23300600b2b5af15 100644 --- a/internal/tui/styles/theme.go +++ b/internal/tui/styles/theme.go @@ -72,6 +72,7 @@ type Theme struct { Red color.Color RedDark color.Color RedLight color.Color + Cherry color.Color styles *Styles } @@ -150,15 +151,15 @@ func (t *Theme) buildStyles() *Styles { TextInput: textinput.Styles{ Focused: textinput.StyleState{ Text: base, - Placeholder: base.Foreground(t.FgMuted), + Placeholder: base.Foreground(t.FgSubtle), Prompt: base.Foreground(t.Tertiary), - Suggestion: base.Foreground(t.FgMuted), + Suggestion: base.Foreground(t.FgSubtle), }, Blurred: textinput.StyleState{ Text: base.Foreground(t.FgMuted), - Placeholder: base.Foreground(t.FgMuted), + Placeholder: base.Foreground(t.FgSubtle), Prompt: base.Foreground(t.FgMuted), - Suggestion: base.Foreground(t.FgMuted), + Suggestion: base.Foreground(t.FgSubtle), }, Cursor: textinput.CursorStyle{ Color: t.Secondary, @@ -173,7 +174,7 @@ func (t *Theme) buildStyles() *Styles { LineNumber: base.Foreground(t.FgSubtle), CursorLine: base, CursorLineNumber: base.Foreground(t.FgSubtle), - Placeholder: base.Foreground(t.FgMuted), + Placeholder: base.Foreground(t.FgSubtle), Prompt: base.Foreground(t.Tertiary), }, Blurred: textarea.StyleState{ @@ -182,7 +183,7 @@ func (t *Theme) buildStyles() *Styles { LineNumber: base.Foreground(t.FgMuted), CursorLine: base, CursorLineNumber: base.Foreground(t.FgMuted), - Placeholder: base.Foreground(t.FgMuted), + Placeholder: base.Foreground(t.FgSubtle), Prompt: base.Foreground(t.FgMuted), }, Cursor: textarea.CursorStyle{ diff --git a/internal/tui/tui.go b/internal/tui/tui.go index b7bfc068abe8bade0248f5d23105a52cf315b98d..dab220adf021e5b74a07508f58e951b65f9e9cdb 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -12,6 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" cmpChat "github.com/charmbracelet/crush/internal/tui/components/chat" + "github.com/charmbracelet/crush/internal/tui/components/chat/splash" "github.com/charmbracelet/crush/internal/tui/components/completions" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/core/layout" @@ -48,8 +49,9 @@ type appModel struct { app *app.App - dialog dialogs.DialogCmp - completions completions.Completions + dialog dialogs.DialogCmp + completions completions.Completions + isConfigured bool // Chat Page Specific selectedSessionID string // The ID of the currently selected session @@ -72,6 +74,7 @@ func (a appModel) Init() tea.Cmd { func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd var cmd tea.Cmd + a.isConfigured = config.HasInitialDataConfig() switch msg := msg.(type) { case tea.KeyboardEnhancementsMsg: @@ -223,10 +226,28 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + return a, tea.Batch(cmds...) + case splash.OnboardingCompleteMsg: + a.isConfigured = config.HasInitialDataConfig() + updated, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage] = updated.(util.Model) + cmds = append(cmds, cmd) return a, tea.Batch(cmds...) // Key Press Messages case tea.KeyPressMsg: return a, a.handleKeyPressMsg(msg) + + case tea.PasteMsg: + if a.dialog.HasDialogs() { + u, dialogCmd := a.dialog.Update(msg) + a.dialog = u.(dialogs.DialogCmp) + cmds = append(cmds, dialogCmd) + } else { + updated, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage] = updated.(util.Model) + cmds = append(cmds, cmd) + } + return a, tea.Batch(cmds...) } s, _ := a.status.Update(msg) a.status = s.(status.StatusCmp) @@ -307,6 +328,10 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { }) case key.Matches(msg, a.keyMap.Commands): + // if the app is not configured show no commands + if !a.isConfigured { + return nil + } if a.dialog.ActiveDialogID() == commands.CommandsDialogID { return util.CmdHandler(dialogs.CloseDialogMsg{}) } @@ -317,6 +342,10 @@ func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd { Model: commands.NewCommandDialog(a.selectedSessionID), }) case key.Matches(msg, a.keyMap.Sessions): + // if the app is not configured show no sessions + if !a.isConfigured { + return nil + } if a.dialog.ActiveDialogID() == sessions.SessionsDialogID { return util.CmdHandler(dialogs.CloseDialogMsg{}) }