fix: refresh oauth token in the background

Raphael Amorim and Kujtim Hoxha created

Co-authored-by: Kujtim Hoxha <kujtimii.h@gmail.com>

Change summary

internal/agent/coordinator.go | 16 ++++++++++++++
internal/config/config.go     | 38 +++++++++++++++++++++++++++++++++++++
2 files changed, 53 insertions(+), 1 deletion(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -130,7 +130,20 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 
 	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 
-	return c.currentAgent.Run(ctx, SessionAgentCall{
+	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
+		slog.Info("Detected expired OAuth token, attempting refresh", "provider", providerCfg.ID)
+		if refreshErr := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); refreshErr != nil {
+			slog.Error("Failed to refresh OAuth token", "provider", providerCfg.ID, "error", refreshErr)
+			return nil, refreshErr
+		}
+
+		// Rebuild models with refreshed token
+		if updateErr := c.UpdateModels(ctx); updateErr != nil {
+			slog.Error("Failed to update models after token refresh", "error", updateErr)
+			return nil, updateErr
+		}
+	}
+	result, err := c.currentAgent.Run(ctx, SessionAgentCall{
 		SessionID:        sessionID,
 		Prompt:           prompt,
 		Attachments:      attachments,
@@ -142,6 +155,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 		FrequencyPenalty: freqPenalty,
 		PresencePenalty:  presPenalty,
 	})
+	return result, err
 }
 
 func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {

internal/config/config.go 🔗

@@ -16,6 +16,7 @@ import (
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/env"
 	"github.com/charmbracelet/crush/internal/oauth"
+	"github.com/charmbracelet/crush/internal/oauth/claude"
 	"github.com/invopop/jsonschema"
 	"github.com/tidwall/sjson"
 )
@@ -468,6 +469,43 @@ func (c *Config) SetConfigField(key string, value any) error {
 	return nil
 }
 
+func (c *Config) RefreshOAuthToken(ctx context.Context, providerID string) error {
+	providerConfig, exists := c.Providers.Get(providerID)
+	if !exists {
+		return fmt.Errorf("provider %s not found", providerID)
+	}
+
+	if providerConfig.OAuthToken == nil {
+		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
+	}
+
+	// Only Anthropic provider uses OAuth for now
+	if providerID != string(catwalk.InferenceProviderAnthropic) {
+		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
+	}
+
+	newToken, err := claude.RefreshToken(ctx, providerConfig.OAuthToken.RefreshToken)
+	if err != nil {
+		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, err)
+	}
+
+	slog.Info("Successfully refreshed OAuth token in background", "provider", providerID)
+	providerConfig.OAuthToken = newToken
+	providerConfig.APIKey = fmt.Sprintf("Bearer %s", newToken.AccessToken)
+	providerConfig.SetupClaudeCode()
+
+	c.Providers.Set(providerID, providerConfig)
+
+	if err := cmp.Or(
+		c.SetConfigField(fmt.Sprintf("providers.%s.api_key", providerID), newToken.AccessToken),
+		c.SetConfigField(fmt.Sprintf("providers.%s.oauth", providerID), newToken),
+	); err != nil {
+		return fmt.Errorf("failed to persist refreshed token: %w", err)
+	}
+
+	return nil
+}
+
 func (c *Config) SetProviderAPIKey(providerID string, apiKey any) error {
 	var providerConfig ProviderConfig
 	var exists bool