diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 91463fe4c24be90b743bcdb654f865ce60ecf2af..436aa27d95e4b86f83c20c3f46b2e1434986e89d 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -10,6 +10,7 @@ import ( "io" "log/slog" "maps" + "net/http" "os" "slices" "strings" @@ -130,32 +131,42 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg) - if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() { - slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID) - if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil { - slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr) - return nil, refreshErr - } - - // Rebuild models with refreshed token - if updateErr := c.UpdateModels(ctx); updateErr != nil { - slog.Error("Failed to update models after token refresh", "error", updateErr) - return nil, updateErr + run := func() (*fantasy.AgentResult, error) { + return c.currentAgent.Run(ctx, SessionAgentCall{ + SessionID: sessionID, + Prompt: prompt, + Attachments: attachments, + MaxOutputTokens: maxTokens, + ProviderOptions: mergedOptions, + Temperature: temp, + TopP: topP, + TopK: topK, + FrequencyPenalty: freqPenalty, + PresencePenalty: presPenalty, + }) + } + result, originalErr := run() + + if c.isUnauthorized(originalErr) { + switch { + case providerCfg.OAuthToken != nil: + slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID) + if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil { + return nil, originalErr + } + slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID) + return run() + case strings.Contains(providerCfg.APIKeyTemplate, "$"): + slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID) + if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil { + return nil, originalErr + } + slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID) + return run() } } - result, err := c.currentAgent.Run(ctx, SessionAgentCall{ - SessionID: sessionID, - Prompt: prompt, - Attachments: attachments, - MaxOutputTokens: maxTokens, - ProviderOptions: mergedOptions, - Temperature: temp, - TopP: topP, - TopK: topK, - FrequencyPenalty: freqPenalty, - PresencePenalty: presPenalty, - }) - return result, err + + return result, originalErr } func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions { @@ -773,3 +784,35 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { } return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg)) } + +func (c *coordinator) isUnauthorized(err error) bool { + var providerErr *fantasy.ProviderError + return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized +} + +func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error { + if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil { + slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err) + return err + } + if err := c.UpdateModels(ctx); err != nil { + return err + } + return nil +} + +func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error { + newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate) + if err != nil { + slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err) + return err + } + + providerCfg.APIKey = newAPIKey + c.cfg.Providers.Set(providerCfg.ID, providerCfg) + + if err := c.UpdateModels(ctx); err != nil { + return err + } + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 464dc14bc8c6d12cdf1db17c681c4faa68a59339..4c9dc7bafe83ff0b75b0a0238fcd71ba9e63a3bf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -95,6 +95,8 @@ type ProviderConfig struct { Type catwalk.Type `json:"type,omitempty" jsonschema:"description=Provider type that determines the API format,enum=openai,enum=openai-compat,enum=anthropic,enum=gemini,enum=azure,enum=vertexai,default=openai"` // The provider's API key. APIKey string `json:"api_key,omitempty" jsonschema:"description=API key for authentication with the provider,example=$OPENAI_API_KEY"` + // The original API key template before resolution (for re-resolution on auth errors). + APIKeyTemplate string `json:"-"` // OAuthToken for providers that use OAuth2 authentication. OAuthToken *oauth.Token `json:"oauth,omitempty" jsonschema:"description=OAuth2 token for authentication with the provider"` // Marks the provider as disabled. @@ -469,6 +471,7 @@ func (c *Config) SetConfigField(key string, value any) error { return nil } +// RefreshOAuthToken refreshes the OAuth token for the given provider. func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error { providerConfig, exists := c.Providers.Get(providerID) if !exists { @@ -479,7 +482,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - // Only Anthropic provider uses OAuth for now + // Only Anthropic provider uses OAuth for now. if providerID != string(catwalk.InferenceProviderAnthropic) { return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } @@ -489,7 +492,7 @@ func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err) } - slog.Info("Successfully refreshed OAuth token in background", "provider", providerID) + slog.Info("Successfully refreshed OAuth token", "provider", providerID) providerConfig.OAuthToken = newToken providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken) providerConfig.SetupClaudeCode() diff --git a/internal/config/load.go b/internal/config/load.go index 7645861198eefbceb1e283ee7815d3f130b0b868..8f3ad171d2ae2e196584223e55d24d42b200e073 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -1,7 +1,6 @@ package config import ( - "cmp" "context" "encoding/json" "fmt" @@ -19,11 +18,9 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" - "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/log" - "github.com/charmbracelet/crush/internal/oauth/claude" powernapConfig "github.com/charmbracelet/x/powernap/pkg/config" ) @@ -189,6 +186,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know Name: p.Name, BaseURL: p.APIEndpoint, APIKey: p.APIKey, + APIKeyTemplate: p.APIKey, // Store original template for re-resolution OAuthToken: config.OAuthToken, Type: p.Type, Disable: config.Disable, @@ -200,25 +198,6 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know } if p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil { - if config.OAuthToken.IsExpired() { - newToken, err := claude.RefreshToken(context.TODO(), config.OAuthToken.RefreshToken) - if err == nil { - slog.Info("Successfully refreshed Anthropic OAuth token") - config.OAuthToken = newToken - prepared.OAuthToken = newToken - if err := cmp.Or( - c.SetConfigField("providers.anthropic.api_key", newToken.AccessToken), - c.SetConfigField("providers.anthropic.oauth", newToken), - ); err != nil { - return err - } - } else { - slog.Error("Failed to refresh Anthropic OAuth token", "error", err) - event.Error(err) - } - } else { - slog.Info("Using existing non-expired Anthropic OAuth token") - } prepared.SetupClaudeCode() }