agent2: Expand auto-retries for completion errors (#39787)

Marshall Bowers and David Kleingeld created

This PR expands our automatic retry behavior for certain classes of
completion errors (e.g., rate limit errors).

Previously this was only available when using burn mode.

We now auto-retry when:

- Using the Zed provider while on a token-based plan
- Using the Zed provider while on a legacy plan with burn mode enabled
- Using a non-Zed provider

Release Notes:

- Expanded automatic retry behavior for errors in the Agent. Errors
classified as "retryable" (such as rate limit errors) will now
automatically be retried when:
  - Using the Zed provider while on a token-based plan
  - Using the Zed provider while on a legacy plan with burn mode enabled
  - Using a non-Zed provider

---------

Co-authored-by: David Kleingeld <davidsk@zed.dev>

Change summary

crates/agent2/src/thread.rs | 43 +++++++++++++++++++++++++++++---------
1 file changed, 32 insertions(+), 11 deletions(-)

Detailed changes

crates/agent2/src/thread.rs 🔗

@@ -15,10 +15,11 @@ use agent_settings::{
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use chrono::{DateTime, Utc};
-use client::{ModelRequestUsage, RequestUsage};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use client::{ModelRequestUsage, RequestUsage, UserStore};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
+use futures::stream;
 use futures::{
     FutureExt,
     channel::{mpsc, oneshot},
@@ -34,7 +35,7 @@ use language_model::{
     LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
-    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
+    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use project::{
     Project,
@@ -585,6 +586,7 @@ pub struct Thread {
     pending_title_generation: Option<Task<()>>,
     summary: Option<SharedString>,
     messages: Vec<Message>,
+    user_store: Entity<UserStore>,
     completion_mode: CompletionMode,
     /// Holds the task that handles agent interaction until the end of the turn.
     /// Survives across multiple requests as the model performs tool calls and
@@ -641,6 +643,7 @@ impl Thread {
             pending_title_generation: None,
             summary: None,
             messages: Vec::new(),
+            user_store: project.read(cx).user_store(),
             completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
             running_turn: None,
             pending_message: None,
@@ -820,6 +823,7 @@ impl Thread {
             pending_title_generation: None,
             summary: db_thread.detailed_summary,
             messages: db_thread.messages,
+            user_store: project.read(cx).user_store(),
             completion_mode: db_thread.completion_mode.unwrap_or_default(),
             running_turn: None,
             pending_message: None,
@@ -1249,12 +1253,12 @@ impl Thread {
             );
 
             log::debug!("Calling model.stream_completion, attempt {}", attempt);
-            let mut events = model
-                .stream_completion(request, cx)
-                .await
-                .map_err(|error| anyhow!(error))?;
+
+            let (mut events, mut error) = match model.stream_completion(request, cx).await {
+                Ok(events) => (events, None),
+                Err(err) => (stream::empty().boxed(), Some(err)),
+            };
             let mut tool_results = FuturesUnordered::new();
-            let mut error = None;
             while let Some(event) = events.next().await {
                 log::trace!("Received completion event: {:?}", event);
                 match event {
@@ -1302,8 +1306,10 @@ impl Thread {
 
             if let Some(error) = error {
                 attempt += 1;
-                let retry =
-                    this.update(cx, |this, _| this.handle_completion_error(error, attempt))??;
+                let retry = this.update(cx, |this, cx| {
+                    let user_store = this.user_store.read(cx);
+                    this.handle_completion_error(error, attempt, user_store.plan())
+                })??;
                 let timer = cx.background_executor().timer(retry.duration);
                 event_stream.send_retry(retry);
                 timer.await;
@@ -1330,8 +1336,23 @@ impl Thread {
         &mut self,
         error: LanguageModelCompletionError,
         attempt: u8,
+        plan: Option<Plan>,
     ) -> Result<acp_thread::RetryStatus> {
-        if self.completion_mode == CompletionMode::Normal {
+        let Some(model) = self.model.as_ref() else {
+            return Err(anyhow!(error));
+        };
+
+        let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID {
+            match plan {
+                Some(Plan::V2(_)) => true,
+                Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn,
+                None => false,
+            }
+        } else {
+            true
+        };
+
+        if !auto_retry {
             return Err(anyhow!(error));
         }