fix(claude): add authentication refresh on 401 errors (#1581)

Created by Kieran Klukas and Andrey Nering

Co-authored-by: Andrey Nering <andreynering@users.noreply.github.com>

Change summary

internal/agent/coordinator.go | 91 +++++++++++++++++++++++++++---------
internal/config/config.go     |  7 ++
internal/config/load.go       | 23 --------
3 files changed, 73 insertions(+), 48 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -10,6 +10,7 @@ import (
 	"io"
 	"log/slog"
 	"maps"
+	"net/http"
 	"os"
 	"slices"
 	"strings"
@@ -130,32 +131,42 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 
 	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 
-	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
-		slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID)
-		if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil {
-			slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr)
-			return nil, refreshErr
-		}
-
-		// Rebuild models with refreshed token
-		if updateErr := c.UpdateModels(ctx); updateErr != nil {
-			slog.Error("Failed to update models after token refresh", "error", updateErr)
-			return nil, updateErr
+	run := func() (*fantasy.AgentResult, error) {
+		return c.currentAgent.Run(ctx, SessionAgentCall{
+			SessionID:        sessionID,
+			Prompt:           prompt,
+			Attachments:      attachments,
+			MaxOutputTokens:  maxTokens,
+			ProviderOptions:  mergedOptions,
+			Temperature:      temp,
+			TopP:             topP,
+			TopK:             topK,
+			FrequencyPenalty: freqPenalty,
+			PresencePenalty:  presPenalty,
+		})
+	}
+	result, originalErr := run()
+
+	if c.isUnauthorized(originalErr) {
+		switch {
+		case providerCfg.OAuthToken != nil:
+			slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
+			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
+				return nil, originalErr
+			}
+			slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
+			return run()
+		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
+			slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
+			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
+				return nil, originalErr
+			}
+			slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
+			return run()
 		}
 	}
-	result, err := c.currentAgent.Run(ctx, SessionAgentCall{
-		SessionID:        sessionID,
-		Prompt:           prompt,
-		Attachments:      attachments,
-		MaxOutputTokens:  maxTokens,
-		ProviderOptions:  mergedOptions,
-		Temperature:      temp,
-		TopP:             topP,
-		TopK:             topK,
-		FrequencyPenalty: freqPenalty,
-		PresencePenalty:  presPenalty,
-	})
-	return result, err
+
+	return result, originalErr
 }
 
 func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
@@ -773,3 +784,35 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 	}
 	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
 }
+
+func (c *coordinator) isUnauthorized(err error) bool {
+	var providerErr *fantasy.ProviderError
+	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
+}
+
+func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
+	if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
+		slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
+		return err
+	}
+	if err := c.UpdateModels(ctx); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
+	newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
+	if err != nil {
+		slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
+		return err
+	}
+
+	providerCfg.APIKey = newAPIKey
+	c.cfg.Providers.Set(providerCfg.ID, providerCfg)
+
+	if err := c.UpdateModels(ctx); err != nil {
+		return err
+	}
+	return nil
+}

internal/config/config.go 🔗

@@ -95,6 +95,8 @@ type ProviderConfig struct {
 	Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"`
 	// The provider's API key.
 	APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"`
+	// The original API key template before resolution (for re-resolution on auth errors).
+	APIKeyTemplate string `json:"-"`
 	// OAuthToken for providers that use OAuth2 authentication.
 	OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"`
 	// Marks the provider as disabled.
@@ -469,6 +471,7 @@ func (c *Config) SetConfigField(key string, value any) error {
 	return nil
 }
 
+// RefreshOAuthToken refreshes the OAuth token for the given provider.
 func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error {
 	providerConfig, exists := c.Providers.Get(providerID)
 	if !exists {
@@ -479,7 +482,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
 		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
 	}
 
-	// Only Anthropic provider uses OAuth for now
+	// Only Anthropic provider uses OAuth for now.
 	if providerID != string(catwalk.InferenceProviderAnthropic) {
 		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
 	}
@@ -489,7 +492,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error
 		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err)
 	}
 
-	slog.Info("Successfully refreshed OAuth token in background", "provider", providerID)
+	slog.Info("Successfully refreshed OAuth token", "provider", providerID)
 	providerConfig.OAuthToken = newToken
 	providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
 	providerConfig.SetupClaudeCode()

internal/config/load.go 🔗

@@ -1,7 +1,6 @@
 package config
 
 import (
-	"cmp"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -19,11 +18,9 @@ import (
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/env"
-	"github.com/charmbracelet/crush/internal/event"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/home"
 	"github.com/charmbracelet/crush/internal/log"
-	"github.com/charmbracelet/crush/internal/oauth/claude"
 	powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
 )
 
@@ -189,6 +186,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
 			Name:               p.Name,
 			BaseURL:            p.APIEndpoint,
 			APIKey:             p.APIKey,
+			APIKeyTemplate:     p.APIKey, // Store original template for re-resolution
 			OAuthToken:         config.OAuthToken,
 			Type:               p.Type,
 			Disable:            config.Disable,
@@ -200,25 +198,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
 		}
 
 		if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil {
-			if config.OAuthToken.IsExpired() {
-				newToken, err := claude.RefreshToken(context.TODO(), config.OAuthToken.RefreshToken)
-				if err == nil {
-					slog.Info("Successfully refreshed Anthropic OAuth token")
-					config.OAuthToken = newToken
-					prepared.OAuthToken = newToken
-					if err := cmp.Or(
-						c.SetConfigField("providers.anthropic.api_key", newToken.AccessToken),
-						c.SetConfigField("providers.anthropic.oauth", newToken),
-					); err != nil {
-						return err
-					}
-				} else {
-					slog.Error("Failed to refresh Anthropic OAuth token", "error", err)
-					event.Error(err)
-				}
-			} else {
-				slog.Info("Using existing non-expired Anthropic OAuth token")
-			}
 			prepared.SetupClaudeCode()
 		}