From 8cd4786cc66b29da2412e6cc9f40cfa364cd1a17 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 6 May 2026 17:23:32 -0300 Subject: [PATCH] refactor(coordinator): extract token refresh helpers to reduce duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract `refreshTokenIfExpired` and `retryAfterUnauthorized` functions to eliminate duplicated OAuth/API key refresh logic in both `Run` and `Summarize`. 💘 Generated with Crush Assisted-by: Kimi K2.6 via Crush --- internal/agent/coordinator.go | 71 ++++++++++++++++------------------- 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index c4c1b6e6204dca3fd1e13c5ef304daf8c889cdbb..f9209ee6bf853b698211c56353c315b3379459f3 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -184,13 +184,10 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg) - if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() { - slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID) - if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil { - // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate - // depends on the flow below. If refresh fails, proceed with the token we have. - slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err) - } + if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil { + // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate + // depends on the flow below. If refresh fails, proceed with the token we have. + slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err) } run := func() (*fantasy.AgentResult, error) { @@ -212,20 +209,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, logTurnSkillUsage(sessionID, prompt, c.activeSkills, c.skillTracker, beforeLoaded) if c.isUnauthorized(originalErr) { - switch { - case providerCfg.OAuthToken != nil: - slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID) - if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil { - return nil, originalErr - } - slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID) - return run() - case strings.Contains(providerCfg.APIKeyTemplate, "$"): - slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID) - if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil { - return nil, originalErr - } - slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID) + if err := c.retryAfterUnauthorized(ctx, providerCfg); err == nil { return run() } } @@ -962,12 +946,8 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { return errModelProviderNotConfigured } - // Proactively refresh OAuth token if expired, same as Run(). - if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() { - slog.Debug("Token needs to be refreshed before summarize", "provider", providerCfg.ID) - if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil { - slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err) - } + if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil { + slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err) } summarize := func() error { @@ -976,18 +956,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { err := summarize() if err != nil && c.isUnauthorized(err) { - switch { - case providerCfg.OAuthToken != nil: - slog.Debug("Received 401 during summarize. Refreshing token and retrying", "provider", providerCfg.ID) - if refreshErr := c.refreshOAuth2Token(ctx, providerCfg); refreshErr != nil { - return err - } - return summarize() - case strings.Contains(providerCfg.APIKeyTemplate, "$"): - slog.Debug("Received 401 during summarize. Refreshing API Key template and retrying", "provider", providerCfg.ID) - if refreshErr := c.refreshApiKeyTemplate(ctx, providerCfg); refreshErr != nil { - return err - } + if retryErr := c.retryAfterUnauthorized(ctx, providerCfg); retryErr == nil { return summarize() } } @@ -995,6 +964,30 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { return err } +// refreshTokenIfExpired proactively refreshes the OAuth token if it has expired. +func (c *coordinator) refreshTokenIfExpired(ctx context.Context, providerCfg config.ProviderConfig) error { + if providerCfg.OAuthToken == nil || !providerCfg.OAuthToken.IsExpired() { + return nil + } + slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID) + return c.refreshOAuth2Token(ctx, providerCfg) +} + +// retryAfterUnauthorized attempts to refresh credentials after receiving a 401 +// and returns nil if retry should be attempted. +func (c *coordinator) retryAfterUnauthorized(ctx context.Context, providerCfg config.ProviderConfig) error { + switch { + case providerCfg.OAuthToken != nil: + slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID) + return c.refreshOAuth2Token(ctx, providerCfg) + case strings.Contains(providerCfg.APIKeyTemplate, "$"): + slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID) + return c.refreshApiKeyTemplate(ctx, providerCfg) + default: + return nil + } +} + func (c *coordinator) isUnauthorized(err error) bool { var providerErr *fantasy.ProviderError return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized