Detailed changes
@@ -1,39 +1 @@
-{
- "flagWords": [],
- "words": [
- "afero",
- "alecthomas",
- "bubbletea",
- "charmbracelet",
- "charmtone",
- "Charple",
- "crush",
- "diffview",
- "Emph",
- "filepicker",
- "Focusable",
- "fsext",
- "GROQ",
- "Guac",
- "imageorient",
- "Lanczos",
- "lipgloss",
- "lsps",
- "lucasb",
- "nfnt",
- "oksvg",
- "Preproc",
- "rasterx",
- "rivo",
- "Sourcegraph",
- "srwiley",
- "Strikethrough",
- "termenv",
- "textinput",
- "trashhalo",
- "uniseg",
- "Unticked"
- ],
- "version": "0.2",
- "language": "en"
-}
+{"language":"en","flagWords":[],"version":"0.2","words":["afero","alecthomas","bubbletea","charmbracelet","charmtone","Charple","crush","diffview","Emph","filepicker","Focusable","fsext","GROQ","Guac","imageorient","Lanczos","lipgloss","lsps","lucasb","nfnt","oksvg","Preproc","rasterx","rivo","Sourcegraph","srwiley","Strikethrough","termenv","textinput","trashhalo","uniseg","Unticked","genai"]}
@@ -592,11 +592,11 @@ func mergeProviderConfigs(base, global, local *Config) {
if cfg == nil {
continue
}
- for providerName, globalProvider := range cfg.Providers {
+ for providerName, p := range cfg.Providers {
if _, ok := base.Providers[providerName]; !ok {
- base.Providers[providerName] = globalProvider
+ base.Providers[providerName] = p
} else {
- base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], globalProvider)
+ base.Providers[providerName] = mergeProviderConfig(providerName, base.Providers[providerName], p)
}
}
}
@@ -0,0 +1,74 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/llm/tools/shell"
+ "github.com/charmbracelet/crush/internal/logging"
+)
+
+// ExecuteCommand executes a shell command and returns the output
+// This is a shared utility that can be used by both provider config and tools
+func ExecuteCommand(ctx context.Context, command string, workingDir string) (string, error) {
+ if workingDir == "" {
+ workingDir = WorkingDirectory()
+ }
+
+ persistentShell := shell.GetPersistentShell(workingDir)
+
+ stdout, stderr, err := persistentShell.Exec(ctx, command)
+ if err != nil {
+ logging.Debug("Command execution failed", "command", command, "error", err, "stderr", stderr)
+ return "", fmt.Errorf("command execution failed: %w", err)
+ }
+
+ return strings.TrimSpace(stdout), nil
+}
+
+// ResolveAPIKey resolves an API key that can be either:
+// - A direct string value
+// - An environment variable (prefixed with $)
+// - A shell command (wrapped in $(...))
+func ResolveAPIKey(apiKey string) (string, error) {
+ if !strings.HasPrefix(apiKey, "$") {
+ return apiKey, nil
+ }
+
+ if strings.HasPrefix(apiKey, "$(") && strings.HasSuffix(apiKey, ")") {
+ command := strings.TrimSuffix(strings.TrimPrefix(apiKey, "$("), ")")
+ logging.Debug("Resolving API key from command", "command", command)
+ return resolveCommandAPIKey(command)
+ }
+
+ envVar := strings.TrimPrefix(apiKey, "$")
+ if value := os.Getenv(envVar); value != "" {
+ logging.Debug("Resolved environment variable", "envVar", envVar, "value", value)
+ return value, nil
+ }
+
+ logging.Debug("Environment variable not found", "envVar", envVar)
+
+ return "", fmt.Errorf("environment variable %s not found", envVar)
+}
+
+// resolveCommandAPIKey executes a command to get an API key, with caching support
+func resolveCommandAPIKey(command string) (string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ logging.Debug("Executing command for API key", "command", command)
+
+ workingDir := WorkingDirectory()
+
+ result, err := ExecuteCommand(ctx, command, workingDir)
+ if err != nil {
+ return "", fmt.Errorf("failed to execute API key command: %w", err)
+ }
+ logging.Debug("Command executed successfully", "command", command, "result", result)
+ return result, nil
+}
+
@@ -253,10 +253,10 @@ func (a *agent) IsBusy() bool {
if cancelFunc, ok := value.(context.CancelFunc); ok {
if cancelFunc != nil {
busy = true
- return false // Stop iterating
+ return false
}
}
- return true // Continue iterating
+ return true
})
return busy
}
@@ -21,12 +21,20 @@ import (
type anthropicClient struct {
providerOptions providerClientOptions
+ useBedrock bool
client anthropic.Client
}
type AnthropicClient ProviderClient
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
+ return &anthropicClient{
+ providerOptions: opts,
+ client: createAnthropicClient(opts, useBedrock),
+ }
+}
+
+func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
anthropicClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
@@ -34,12 +42,7 @@ func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicCl
if useBedrock {
anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
}
-
- client := anthropic.NewClient(anthropicClientOptions...)
- return &anthropicClient{
- providerOptions: opts,
- client: client,
- }
+ return anthropic.NewClient(anthropicClientOptions...)
}
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
@@ -385,12 +388,21 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
}
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- var apierr *anthropic.Error
- if !errors.As(err, &apierr) {
+ var apiErr *anthropic.Error
+ if !errors.As(err, &apiErr) {
return false, 0, err
}
- if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
+ if apiErr.StatusCode == 401 {
+ a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
+ return true, 0, nil
+ }
+
+ if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
return false, 0, err
}
@@ -399,7 +411,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
+ retryAfterValues := apiErr.Response.Header.Values("Retry-After")
backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
@@ -25,7 +25,7 @@ type geminiClient struct {
type GeminiClient ProviderClient
func newGeminiClient(opts providerClientOptions) GeminiClient {
- client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
+ client, err := createGeminiClient(opts)
if err != nil {
logging.Error("Failed to create Gemini client", "error", err)
return nil
@@ -37,6 +37,14 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
}
}
+func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
+ client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI})
+ if err != nil {
+ return nil, err
+ }
+ return client, nil
+}
+
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
var history []*genai.Content
for _, msg := range messages {
@@ -414,6 +422,19 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
errMsg := err.Error()
isRateLimit := contains(errMsg, "rate limit", "quota exceeded", "too many requests")
+ // Check for token expiration (401 Unauthorized)
+ if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
+ g.providerOptions.apiKey, err = config.ResolveAPIKey(g.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ g.client, err = createGeminiClient(g.providerOptions)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to create Gemini client after API key refresh: %w", err)
+ }
+ return true, 0, nil
+ }
+
// Check for common rate limit error messages
if !isRateLimit {
@@ -26,6 +26,13 @@ type openaiClient struct {
type OpenAIClient ProviderClient
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
+ return &openaiClient{
+ providerOptions: opts,
+ client: createOpenAIClient(opts),
+ }
+}
+
+func createOpenAIClient(opts providerClientOptions) openai.Client {
openaiClientOptions := []option.RequestOption{}
if opts.apiKey != "" {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
@@ -40,11 +47,7 @@ func newOpenAIClient(opts providerClientOptions) OpenAIClient {
}
}
- client := openai.NewClient(openaiClientOptions...)
- return &openaiClient{
- providerOptions: opts,
- client: client,
- }
+ return openai.NewClient(openaiClientOptions...)
}
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
@@ -339,12 +342,22 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
}
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
- var apierr *openai.Error
- if !errors.As(err, &apierr) {
+ var apiErr *openai.Error
+ if !errors.As(err, &apiErr) {
return false, 0, err
}
- if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
+ // Check for token expiration (401 Unauthorized)
+ if apiErr.StatusCode == 401 {
+ o.providerOptions.apiKey, err = config.ResolveAPIKey(o.providerOptions.config.APIKey)
+ if err != nil {
+ return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+ }
+ o.client = createOpenAIClient(o.providerOptions)
+ return true, 0, nil
+ }
+
+ if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
return false, 0, err
}
@@ -353,7 +366,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
}
retryMs := 0
- retryAfterValues := apierr.Response.Header.Values("Retry-After")
+ retryAfterValues := apiErr.Response.Header.Values("Retry-After")
backoffMs := 2000 * (1 << (attempts - 1))
jitterMs := int(float64(backoffMs) * 0.2)
@@ -60,6 +60,7 @@ type Provider interface {
type providerClientOptions struct {
baseURL string
+ config config.ProviderConfig
apiKey string
modelType config.ModelType
model func(config.ModelType) config.Model
@@ -134,9 +135,15 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+ resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
+ }
+
clientOptions := providerClientOptions{
baseURL: cfg.BaseURL,
- apiKey: cfg.APIKey,
+ config: cfg,
+ apiKey: resolvedAPIKey,
extraHeaders: cfg.ExtraHeaders,
model: func(tp config.ModelType) config.Model {
return config.GetModel(tp)