diff --git a/cspell.json b/cspell.json index 2595963182b8e2aa6fe575bbe9ca6a5da0f70a9b..d62c817e8c8699e6172e576eb0f91602dd8417a3 100644 --- a/cspell.json +++ b/cspell.json @@ -1,39 +1 @@ -{ - "flagWords": [], - "words": [ - "afero", - "alecthomas", - "bubbletea", - "charmbracelet", - "charmtone", - "Charple", - "crush", - "diffview", - "Emph", - "filepicker", - "Focusable", - "fsext", - "GROQ", - "Guac", - "imageorient", - "Lanczos", - "lipgloss", - "lsps", - "lucasb", - "nfnt", - "oksvg", - "Preproc", - "rasterx", - "rivo", - "Sourcegraph", - "srwiley", - "Strikethrough", - "termenv", - "textinput", - "trashhalo", - "uniseg", - "Unticked" - ], - "version": "0.2", - "language": "en" -} +{"language":"en","flagWords":[],"version":"0.2","words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai"]} \ No newline at end of file diff --git a/internal/config/config.go b/internal/config/config.go index 32ca8729295bb3994af27ec4359a1b4960527671..74c0c63ecdd2843da0daf4875295d9d1f8ad20d7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -592,11 +592,11 @@ func mergeProviderConfigs(base, global, local *Config) { if cfg == nil { continue } - for providerName, globalProvider := range cfg.Providers { + for providerName, p := range cfg.Providers { if _, ok := base.Providers[providerName]; !ok { - base.Providers[providerName] = globalProvider + base.Providers[providerName] = p } else { - base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider) + base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p) } } } diff --git a/internal/config/shell.go b/internal/config/shell.go new file mode 100644 index 0000000000000000000000000000000000000000..a12ecd1da3b82c113175a1f4825877a7fb94a95c --- /dev/null +++ b/internal/config/shell.go @@ -0,0 +1,74 @@ +package config + +import ( + "context" + "fmt" + "os" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/llm/tools/shell" + "github.com/charmbracelet/crush/internal/logging" +) + +// ExecuteCommand executes a shell command and returns the output +// This is a shared utility that can be used by both provider config and tools +func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) { + if workingDir == "" { + workingDir = WorkingDirectory() + } + + persistentShell := shell.GetPersistentShell(workingDir) + + stdout, stderr, err := persistentShell.Exec(ctx, command) + if err != nil { + logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr) + return "", fmt.Errorf("command execution failed: %w", err) + } + + return strings.TrimSpace(stdout), nil +} + +// ResolveAPIKey resolves an API key that can be either: +// - A direct string value +// - An environment variable (prefixed with $) +// - A shell command (wrapped in $(...)) +func ResolveAPIKey(apiKey string) (string, error) { + if !strings.HasPrefix(apiKey, "$") { + return apiKey, nil + } + + if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") { + command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")") + logging.Debug("Resolving API key from command", "command", command) + return resolveCommandAPIKey(command) + } + + envVar := strings.TrimPrefix(apiKey, "$") + if value := os.Getenv(envVar); value != "" { + logging.Debug("Resolved environment variable", "envVar", envVar, "value", value) + return value, nil + } + + logging.Debug("Environment variable not found", "envVar", envVar) + + return "", fmt.Errorf("environment variable %s not found", envVar) +} + +// resolveCommandAPIKey executes a command to get an API key, with caching support +func resolveCommandAPIKey(command string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + logging.Debug("Executing command for API key", "command", command) + + workingDir := WorkingDirectory() + + result, err := ExecuteCommand(ctx, command, workingDir) + if err != nil { + return "", fmt.Errorf("failed to execute API key command: %w", err) + } + logging.Debug("Command executed successfully", "command", command, "result", result) + return result, nil +} + diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 57771a7dc98efd2fa897d655aa04b7fef628dab5..d165921f639ffee7127e4044c42d154f091a0dca 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -253,10 +253,10 @@ func (a *agent) IsBusy() bool { if cancelFunc, ok := value.(context.CancelFunc); ok { if cancelFunc != nil { busy = true - return false // Stop iterating + return false } } - return true // Continue iterating + return true }) return busy } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index df6b8490ebc48abc7c01a2a938c6f7d395526654..c86f4372acc4fafd2a829f42489f545c4d589861 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -21,12 +21,20 @@ import ( type anthropicClient struct { providerOptions providerClientOptions + useBedrock bool client anthropic.Client } type AnthropicClient ProviderClient func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { + return &anthropicClient{ + providerOptions: opts, + client: createAnthropicClient(opts, useBedrock), + } +} + +func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { anthropicClientOptions := []option.RequestOption{} if opts.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey)) @@ -34,12 +42,7 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl if useBedrock { anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) } - - client := anthropic.NewClient(anthropicClientOptions...) - return &anthropicClient{ - providerOptions: opts, - client: client, - } + return anthropic.NewClient(anthropicClientOptions...) } func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) { @@ -385,12 +388,21 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message } func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) { - var apierr *anthropic.Error - if !errors.As(err, &apierr) { + var apiErr *anthropic.Error + if !errors.As(err, &apiErr) { return false, 0, err } - if apierr.StatusCode != 429 && apierr.StatusCode != 529 { + if apiErr.StatusCode == 401 { + a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + a.client = createAnthropicClient(a.providerOptions, a.useBedrock) + return true, 0, nil + } + + if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 { return false, 0, err } @@ -399,7 +411,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") + retryAfterValues := apiErr.Response.Header.Values("Retry-After") backoffMs := 2000 * (1 << (attempts - 1)) jitterMs := int(float64(backoffMs) * 0.2) diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index f644d118b4ef642c5f9e835ecfaa450d9f835f4d..e80af34d0815695ea6ed76d01c25262381a836ec 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -25,7 +25,7 @@ type geminiClient struct { type GeminiClient ProviderClient func newGeminiClient(opts providerClientOptions) GeminiClient { - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) + client, err := createGeminiClient(opts) if err != nil { logging.Error("Failed to create Gemini client", "error", err) return nil @@ -37,6 +37,14 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { } } +func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { + client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) + if err != nil { + return nil, err + } + return client, nil +} + func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content { var history []*genai.Content for _, msg := range messages { @@ -414,6 +422,19 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) errMsg := err.Error() isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests") + // Check for token expiration (401 Unauthorized) + if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { + g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + g.client, err = createGeminiClient(g.providerOptions) + if err != nil { + return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err) + } + return true, 0, nil + } + // Check for common rate limit error messages if !isRateLimit { diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 1ae8847db441181a1a65bcacc8b4bd039b45a0fc..e045029651f3e9fc158c9f38cf810584e2c06724 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -26,6 +26,13 @@ type openaiClient struct { type OpenAIClient ProviderClient func newOpenAIClient(opts providerClientOptions) OpenAIClient { + return &openaiClient{ + providerOptions: opts, + client: createOpenAIClient(opts), + } +} + +func createOpenAIClient(opts providerClientOptions) openai.Client { openaiClientOptions := []option.RequestOption{} if opts.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) @@ -40,11 +47,7 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient { } } - client := openai.NewClient(openaiClientOptions...) - return &openaiClient{ - providerOptions: opts, - client: client, - } + return openai.NewClient(openaiClientOptions...) } func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) { @@ -339,12 +342,22 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t } func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) { - var apierr *openai.Error - if !errors.As(err, &apierr) { + var apiErr *openai.Error + if !errors.As(err, &apiErr) { return false, 0, err } - if apierr.StatusCode != 429 && apierr.StatusCode != 500 { + // Check for token expiration (401 Unauthorized) + if apiErr.StatusCode == 401 { + o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey) + if err != nil { + return false, 0, fmt.Errorf("failed to resolve API key: %w", err) + } + o.client = createOpenAIClient(o.providerOptions) + return true, 0, nil + } + + if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 { return false, 0, err } @@ -353,7 +366,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) } retryMs := 0 - retryAfterValues := apierr.Response.Header.Values("Retry-After") + retryAfterValues := apiErr.Response.Header.Values("Retry-After") backoffMs := 2000 * (1 << (attempts - 1)) jitterMs := int(float64(backoffMs) * 0.2) diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 2133e23309b4d92d8c8b2efbf1bb386a2e7753cd..3ffbf86c00c5e3ca27f1b68965f4ff950f1f7454 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -60,6 +60,7 @@ type Provider interface { type providerClientOptions struct { baseURL string + config config.ProviderConfig apiKey string modelType config.ModelType model func(config.ModelType) config.Model @@ -134,9 +135,15 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { + resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey) + if err != nil { + return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) + } + clientOptions := providerClientOptions{ baseURL: cfg.BaseURL, - apiKey: cfg.APIKey, + config: cfg, + apiKey: resolvedAPIKey, extraHeaders: cfg.ExtraHeaders, model: func(tp config.ModelType) config.Model { return config.GetModel(tp)