@@ -10,6 +10,7 @@ import (
"io"
"log/slog"
"maps"
+ "net/http"
"os"
"slices"
"strings"
@@ -130,32 +131,42 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID)
- if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil {
- slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr)
- return nil, refreshErr
- }
-
- // Rebuild models with refreshed token
- if updateErr := c.UpdateModels(ctx); updateErr != nil {
- slog.Error("Failed to update models after token refresh", "error", updateErr)
- return nil, updateErr
+ run := func() (*fantasy.AgentResult, error) {
+ return c.currentAgent.Run(ctx, SessionAgentCall{
+ SessionID: sessionID,
+ Prompt: prompt,
+ Attachments: attachments,
+ MaxOutputTokens: maxTokens,
+ ProviderOptions: mergedOptions,
+ Temperature: temp,
+ TopP: topP,
+ TopK: topK,
+ FrequencyPenalty: freqPenalty,
+ PresencePenalty: presPenalty,
+ })
+ }
+ result, originalErr := run()
+
+ if c.isUnauthorized(originalErr) {
+ switch {
+ case providerCfg.OAuthToken != nil:
+ slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
+ if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
+ return nil, originalErr
+ }
+ slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
+ return run()
+ case strings.Contains(providerCfg.APIKeyTemplate, "$"):
+ slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
+ if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
+ return nil, originalErr
+ }
+ slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
+ return run()
}
}
- result, err := c.currentAgent.Run(ctx, SessionAgentCall{
- SessionID: sessionID,
- Prompt: prompt,
- Attachments: attachments,
- MaxOutputTokens: maxTokens,
- ProviderOptions: mergedOptions,
- Temperature: temp,
- TopP: topP,
- TopK: topK,
- FrequencyPenalty: freqPenalty,
- PresencePenalty: presPenalty,
- })
- return result, err
+
+ return result, originalErr
}
func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
@@ -773,3 +784,35 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
}
return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
}
+
+func (c *coordinator) isUnauthorized(err error) bool {
+ var providerErr *fantasy.ProviderError
+ return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
+}
+
+func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
+ if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
+ slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
+ return err
+ }
+ if err := c.UpdateModels(ctx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
+ newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
+ if err != nil {
+ slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
+ return err
+ }
+
+ providerCfg.APIKey = newAPIKey
+ c.cfg.Providers.Set(providerCfg.ID, providerCfg)
+
+ if err := c.UpdateModels(ctx); err != nil {
+ return err
+ }
+ return nil
+}
@@ -95,6 +95,8 @@ type ProviderConfig struct {
Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
// The provider's API key.
APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
+ // The original API key template before resolution (for re-resolution on auth errors).
+ APIKeyTemplate string `json:"-"`
// OAuthToken for providers that use OAuth2 authentication.
OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"`
// Marks the provider as disabled.
@@ -469,6 +471,7 @@ func (c *Config) SetConfigField(key string, value any) error {
return nil
}
+// RefreshOAuthToken refreshes the OAuth token for the given provider.
func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error {
providerConfig, exists := c.Providers.Get(providerID)
if !exists {
@@ -479,7 +482,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
return fmt.Errorf("provider %s does not have an OAuth token", providerID)
}
- // Only Anthropic provider uses OAuth for now
+ // Only Anthropic provider uses OAuth for now.
if providerID != string(catwalk.InferenceProviderAnthropic) {
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
@@ -489,7 +492,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err)
}
- slog.Info("Successfully refreshed OAuth token in background", "provider", providerID)
+ slog.Info("Successfully refreshed OAuth token", "provider", providerID)
providerConfig.OAuthToken = newToken
providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
providerConfig.SetupClaudeCode()
@@ -1,7 +1,6 @@
package config
import (
- "cmp"
"context"
"encoding/json"
"fmt"
@@ -19,11 +18,9 @@ import (
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/env"
- "github.com/charmbracelet/crush/internal/event"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/home"
"github.com/charmbracelet/crush/internal/log"
- "github.com/charmbracelet/crush/internal/oauth/claude"
powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
)
@@ -189,6 +186,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
Name: p.Name,
BaseURL: p.APIEndpoint,
APIKey: p.APIKey,
+ APIKeyTemplate: p.APIKey, // Store original template for re-resolution
OAuthToken: config.OAuthToken,
Type: p.Type,
Disable: config.Disable,
@@ -200,25 +198,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
}
if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil {
- if config.OAuthToken.IsExpired() {
- newToken, err := claude.RefreshToken(context.TODO(), config.OAuthToken.RefreshToken)
- if err == nil {
- slog.Info("Successfully refreshed Anthropic OAuth token")
- config.OAuthToken = newToken
- prepared.OAuthToken = newToken
- if err := cmp.Or(
- c.SetConfigField("providers.anthropic.api_key", newToken.AccessToken),
- c.SetConfigField("providers.anthropic.oauth", newToken),
- ); err != nil {
- return err
- }
- } else {
- slog.Error("Failed to refresh Anthropic OAuth token", "error", err)
- event.Error(err)
- }
- } else {
- slog.Info("Using existing non-expired Anthropic OAuth token")
- }
prepared.SetupClaudeCode()
}