@@ -184,13 +184,10 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate
- // depends on the flow below. If refresh fails, proceed with the token we have.
- slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err)
- }
+ if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil {
+ // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate
+ // depends on the flow below. If refresh fails, proceed with the token we have.
+ slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err)
}
run := func() (*fantasy.AgentResult, error) {
@@ -212,20 +209,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
logTurnSkillUsage(sessionID, prompt, c.activeSkills, c.skillTracker, beforeLoaded)
if c.isUnauthorized(originalErr) {
- switch {
- case providerCfg.OAuthToken != nil:
- slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
- return run()
- case strings.Contains(providerCfg.APIKeyTemplate, "$"):
- slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
- if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
- return nil, originalErr
- }
- slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
+ if err := c.retryAfterUnauthorized(ctx, providerCfg); err == nil {
return run()
}
}
@@ -962,12 +946,8 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
return errModelProviderNotConfigured
}
- // Proactively refresh OAuth token if expired, same as Run().
- if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
- slog.Debug("Token needs to be refreshed before summarize", "provider", providerCfg.ID)
- if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
- slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err)
- }
+ if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil {
+ slog.Error("Failed to refresh OAuth2 token before summarize. Proceeding with existing token.", "error", err)
}
summarize := func() error {
@@ -976,18 +956,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
err := summarize()
if err != nil && c.isUnauthorized(err) {
- switch {
- case providerCfg.OAuthToken != nil:
- slog.Debug("Received 401 during summarize. Refreshing token and retrying", "provider", providerCfg.ID)
- if refreshErr := c.refreshOAuth2Token(ctx, providerCfg); refreshErr != nil {
- return err
- }
- return summarize()
- case strings.Contains(providerCfg.APIKeyTemplate, "$"):
- slog.Debug("Received 401 during summarize. Refreshing API Key template and retrying", "provider", providerCfg.ID)
- if refreshErr := c.refreshApiKeyTemplate(ctx, providerCfg); refreshErr != nil {
- return err
- }
+ if retryErr := c.retryAfterUnauthorized(ctx, providerCfg); retryErr == nil {
return summarize()
}
}
@@ -995,6 +964,30 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
return err
}
+// refreshTokenIfExpired proactively refreshes the OAuth token if it has expired.
+func (c *coordinator) refreshTokenIfExpired(ctx context.Context, providerCfg config.ProviderConfig) error {
+ if providerCfg.OAuthToken == nil || !providerCfg.OAuthToken.IsExpired() {
+ return nil
+ }
+ slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
+ return c.refreshOAuth2Token(ctx, providerCfg)
+}
+
+// retryAfterUnauthorized attempts to refresh credentials after receiving a 401
+// and returns nil if retry should be attempted.
+func (c *coordinator) retryAfterUnauthorized(ctx context.Context, providerCfg config.ProviderConfig) error {
+ switch {
+ case providerCfg.OAuthToken != nil:
+ slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
+ return c.refreshOAuth2Token(ctx, providerCfg)
+ case strings.Contains(providerCfg.APIKeyTemplate, "$"):
+ slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
+ return c.refreshApiKeyTemplate(ctx, providerCfg)
+ default:
+ return nil
+ }
+}
+
func (c *coordinator) isUnauthorized(err error) bool {
var providerErr *fantasy.ProviderError
return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized