refactor(coordinator): extract token refresh helpers to reduce duplication

Andrey Nering created

Extract `refreshTokenIfExpired` and `retryAfterUnauthorized` functions
to eliminate duplicated OAuth/API key refresh logic in both `Run` and
`Summarize`.

💘 Generated with Crush

Assisted-by: Kimi K2.6 via Crush <crush@charm.land>

Change summary

internal/agent/coordinator.go | 71 ++++++++++++++++--------------------
1 file changed, 32 insertions(+), 39 deletions(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -184,13 +184,10 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 
 	mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
 
-	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
-		slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
-		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
-			// NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate
-			// depends on the flow below. If refresh fails, proceed with the token we have.
-			slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err)
-		}
+	if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil {
+		// NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate
+		// depends on the flow below. If refresh fails, proceed with the token we have.
+		slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err)
 	}
 
 	run := func() (*fantasy.AgentResult, error) {
@@ -212,20 +209,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
 	logTurnSkillUsage(sessionID, prompt, c.activeSkills, c.skillTracker, beforeLoaded)
 
 	if c.isUnauthorized(originalErr) {
-		switch {
-		case providerCfg.OAuthToken != nil:
-			slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
-			if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
-				return nil, originalErr
-			}
-			slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
-			return run()
-		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
-			slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
-			if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
-				return nil, originalErr
-			}
-			slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
+		if err := c.retryAfterUnauthorized(ctx, providerCfg); err == nil {
 			return run()
 		}
 	}
@@ -962,12 +946,8 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 		return errModelProviderNotConfigured
 	}
 
-	// Proactively refresh OAuth token if expired, same as Run().
-	if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
-		slog.Debug("Token needs to be refreshed before summarize", "provider", providerCfg.ID)
-		if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
-			slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err)
-		}
+	if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil {
+		slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err)
 	}
 
 	summarize := func() error {
@@ -976,18 +956,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 
 	err := summarize()
 	if err != nil && c.isUnauthorized(err) {
-		switch {
-		case providerCfg.OAuthToken != nil:
-			slog.Debug("Received 401 during summarize. Refreshing token and retrying", "provider", providerCfg.ID)
-			if refreshErr := c.refreshOAuth2Token(ctx, providerCfg); refreshErr != nil {
-				return err
-			}
-			return summarize()
-		case strings.Contains(providerCfg.APIKeyTemplate, "$"):
-			slog.Debug("Received 401 during summarize. Refreshing API Key template and retrying", "provider", providerCfg.ID)
-			if refreshErr := c.refreshApiKeyTemplate(ctx, providerCfg); refreshErr != nil {
-				return err
-			}
+		if retryErr := c.retryAfterUnauthorized(ctx, providerCfg); retryErr == nil {
 			return summarize()
 		}
 	}
@@ -995,6 +964,30 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 	return err
 }
 
+// refreshTokenIfExpired proactively refreshes the OAuth token if it has expired.
+func (c *coordinator) refreshTokenIfExpired(ctx context.Context, providerCfg config.ProviderConfig) error {
+	if providerCfg.OAuthToken == nil || !providerCfg.OAuthToken.IsExpired() {
+		return nil
+	}
+	slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
+	return c.refreshOAuth2Token(ctx, providerCfg)
+}
+
+// retryAfterUnauthorized attempts to refresh credentials after receiving a 401
+// and returns nil if retry should be attempted.
+func (c *coordinator) retryAfterUnauthorized(ctx context.Context, providerCfg config.ProviderConfig) error {
+	switch {
+	case providerCfg.OAuthToken != nil:
+		slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
+		return c.refreshOAuth2Token(ctx, providerCfg)
+	case strings.Contains(providerCfg.APIKeyTemplate, "$"):
+		slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
+		return c.refreshApiKeyTemplate(ctx, providerCfg)
+	default:
+		return nil
+	}
+}
+
 func (c *coordinator) isUnauthorized(err error) bool {
 	var providerErr *fantasy.ProviderError
 	return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized