agent2: Port retry logic (#36421)

Bennet Bo Fenner created

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs        |  15 +
crates/agent2/src/agent.rs                 |   5 
crates/agent2/src/tests/mod.rs             | 166 +++++++++++++
crates/agent2/src/thread.rs                | 286 ++++++++++++++++++++---
crates/agent_ui/src/acp/thread_view.rs     |  61 ++++
crates/agent_ui/src/agent_diff.rs          |   1 
crates/language_model/src/fake_provider.rs |  32 ++
7 files changed, 514 insertions(+), 52 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -24,6 +24,7 @@ use std::fmt::{Formatter, Write};
 use std::ops::Range;
 use std::process::ExitStatus;
 use std::rc::Rc;
+use std::time::{Duration, Instant};
 use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
 use ui::App;
 use util::ResultExt;
@@ -658,6 +659,15 @@ impl PlanEntry {
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct RetryStatus {
+    pub last_error: SharedString,
+    pub attempt: usize,
+    pub max_attempts: usize,
+    pub started_at: Instant,
+    pub duration: Duration,
+}
+
 pub struct AcpThread {
     title: SharedString,
     entries: Vec<AgentThreadEntry>,
@@ -676,6 +686,7 @@ pub enum AcpThreadEvent {
     EntryUpdated(usize),
     EntriesRemoved(Range<usize>),
     ToolAuthorizationRequired,
+    Retry(RetryStatus),
     Stopped,
     Error,
     ServerExited(ExitStatus),
@@ -916,6 +927,10 @@ impl AcpThread {
         cx.emit(AcpThreadEvent::NewEntry);
     }
 
+    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
+        cx.emit(AcpThreadEvent::Retry(status));
+    }
+
     pub fn update_tool_call(
         &mut self,
         update: impl Into<ToolCallUpdate>,

crates/agent2/src/agent.rs 🔗

@@ -546,6 +546,11 @@ impl NativeAgentConnection {
                                     thread.update_tool_call(update, cx)
                                 })??;
                             }
+                            AgentResponseEvent::Retry(status) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    thread.update_retry_status(status, cx)
+                                })?;
+                            }
                             AgentResponseEvent::Stop(stop_reason) => {
                                 log::debug!("Assistant message complete: {:?}", stop_reason);
                                 return Ok(acp::PromptResponse { stop_reason });

crates/agent2/src/tests/mod.rs 🔗

@@ -6,15 +6,16 @@ use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
 use fs::{FakeFs, Fs};
-use futures::channel::mpsc::UnboundedReceiver;
+use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
 use gpui::{
     App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
 };
 use indoc::indoc;
 use language_model::{
-    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
-    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
-    Role, StopReason, fake_provider::FakeLanguageModel,
+    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
+    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
+    fake_provider::FakeLanguageModel,
 };
 use pretty_assertions::assert_eq;
 use project::Project;
@@ -24,7 +25,6 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use settings::SettingsStore;
-use smol::stream::StreamExt;
 use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 use util::path;
 
@@ -1435,6 +1435,162 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
+    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let mut events = thread
+        .update(cx, |thread, cx| {
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.send(UserMessageId::new(), ["Hello!"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.end_last_completion_stream();
+
+    let mut retry_events = Vec::new();
+    while let Some(Ok(event)) = events.next().await {
+        match event {
+            AgentResponseEvent::Retry(retry_status) => {
+                retry_events.push(retry_status);
+            }
+            AgentResponseEvent::Stop(..) => break,
+            _ => {}
+        }
+    }
+
+    assert_eq!(retry_events.len(), 0);
+    thread.read_with(cx, |thread, _cx| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hello!
+
+                ## Assistant
+
+                Hey!
+            "}
+        )
+    });
+}
+
+#[gpui::test]
+async fn test_send_retry_on_error(cx: &mut TestAppContext) {
+    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let mut events = thread
+        .update(cx, |thread, cx| {
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.send(UserMessageId::new(), ["Hello!"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
+        provider: LanguageModelProviderName::new("Anthropic"),
+        retry_after: Some(Duration::from_secs(3)),
+    });
+    fake_model.end_last_completion_stream();
+
+    cx.executor().advance_clock(Duration::from_secs(3));
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.end_last_completion_stream();
+
+    let mut retry_events = Vec::new();
+    while let Some(Ok(event)) = events.next().await {
+        match event {
+            AgentResponseEvent::Retry(retry_status) => {
+                retry_events.push(retry_status);
+            }
+            AgentResponseEvent::Stop(..) => break,
+            _ => {}
+        }
+    }
+
+    assert_eq!(retry_events.len(), 1);
+    assert!(matches!(
+        retry_events[0],
+        acp_thread::RetryStatus { attempt: 1, .. }
+    ));
+    thread.read_with(cx, |thread, _cx| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hello!
+
+                ## Assistant
+
+                Hey!
+            "}
+        )
+    });
+}
+
+#[gpui::test]
+async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
+    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let mut events = thread
+        .update(cx, |thread, cx| {
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.send(UserMessageId::new(), ["Hello!"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
+        fake_model.send_last_completion_stream_error(
+            LanguageModelCompletionError::ServerOverloaded {
+                provider: LanguageModelProviderName::new("Anthropic"),
+                retry_after: Some(Duration::from_secs(3)),
+            },
+        );
+        fake_model.end_last_completion_stream();
+        cx.executor().advance_clock(Duration::from_secs(3));
+        cx.run_until_parked();
+    }
+
+    let mut errors = Vec::new();
+    let mut retry_events = Vec::new();
+    while let Some(event) = events.next().await {
+        match event {
+            Ok(AgentResponseEvent::Retry(retry_status)) => {
+                retry_events.push(retry_status);
+            }
+            Ok(AgentResponseEvent::Stop(..)) => break,
+            Err(error) => errors.push(error),
+            _ => {}
+        }
+    }
+
+    assert_eq!(
+        retry_events.len(),
+        crate::thread::MAX_RETRY_ATTEMPTS as usize
+    );
+    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
+        assert_eq!(retry_events[i].attempt, i + 1);
+    }
+    assert_eq!(errors.len(), 1);
+    let error = errors[0]
+        .downcast_ref::<LanguageModelCompletionError>()
+        .unwrap();
+    assert!(matches!(
+        error,
+        LanguageModelCompletionError::ServerOverloaded { .. }
+    ));
+}
+
 /// Filters out the stop events for asserting against in tests
 fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
     result_events

crates/agent2/src/thread.rs 🔗

@@ -12,12 +12,12 @@ use futures::{
     channel::{mpsc, oneshot},
     stream::FuturesUnordered,
 };
-use gpui::{App, Context, Entity, SharedString, Task};
+use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
 use language_model::{
-    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
-    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
-    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
+    LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
 };
 use project::Project;
 use prompt_store::ProjectContext;
@@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
 use serde::{Deserialize, Serialize};
 use settings::{Settings, update_settings_file};
 use smol::stream::StreamExt;
-use std::{collections::BTreeMap, path::Path, sync::Arc};
+use std::{
+    collections::BTreeMap,
+    path::Path,
+    sync::Arc,
+    time::{Duration, Instant},
+};
 use std::{fmt::Write, ops::Range};
 use util::{ResultExt, markdown::MarkdownCodeBlock};
 use uuid::Uuid;
@@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
     }
 }
 
+pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
+pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
+
+#[derive(Debug, Clone)]
+enum RetryStrategy {
+    ExponentialBackoff {
+        initial_delay: Duration,
+        max_attempts: u8,
+    },
+    Fixed {
+        delay: Duration,
+        max_attempts: u8,
+    },
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Message {
     User(UserMessage),
@@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
     ToolCall(acp::ToolCall),
     ToolCallUpdate(acp_thread::ToolCallUpdate),
     ToolCallAuthorization(ToolCallAuthorization),
+    Retry(acp_thread::RetryStatus),
     Stop(acp::StopReason),
 }
 
@@ -662,41 +683,18 @@ impl Thread {
                         })??;
 
                         log::info!("Calling model.stream_completion");
-                        let mut events = model.stream_completion(request, cx).await?;
-                        log::debug!("Stream completion started successfully");
 
                         let mut tool_use_limit_reached = false;
-                        let mut tool_uses = FuturesUnordered::new();
-                        while let Some(event) = events.next().await {
-                            match event? {
-                                LanguageModelCompletionEvent::StatusUpdate(
-                                    CompletionRequestStatus::ToolUseLimitReached,
-                                ) => {
-                                    tool_use_limit_reached = true;
-                                }
-                                LanguageModelCompletionEvent::Stop(reason) => {
-                                    event_stream.send_stop(reason);
-                                    if reason == StopReason::Refusal {
-                                        this.update(cx, |this, _cx| {
-                                            this.flush_pending_message();
-                                            this.messages.truncate(message_ix);
-                                        })?;
-                                        return Ok(());
-                                    }
-                                }
-                                event => {
-                                    log::trace!("Received completion event: {:?}", event);
-                                    this.update(cx, |this, cx| {
-                                        tool_uses.extend(this.handle_streamed_completion_event(
-                                            event,
-                                            &event_stream,
-                                            cx,
-                                        ));
-                                    })
-                                    .ok();
-                                }
-                            }
-                        }
+                        let mut tool_uses = Self::stream_completion_with_retries(
+                            this.clone(),
+                            model.clone(),
+                            request,
+                            message_ix,
+                            &event_stream,
+                            &mut tool_use_limit_reached,
+                            cx,
+                        )
+                        .await?;
 
                         let used_tools = tool_uses.is_empty();
                         while let Some(tool_result) = tool_uses.next().await {
@@ -754,10 +752,105 @@ impl Thread {
         Ok(events_rx)
     }
 
+    async fn stream_completion_with_retries(
+        this: WeakEntity<Self>,
+        model: Arc<dyn LanguageModel>,
+        request: LanguageModelRequest,
+        message_ix: usize,
+        event_stream: &AgentResponseEventStream,
+        tool_use_limit_reached: &mut bool,
+        cx: &mut AsyncApp,
+    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
+        log::debug!("Stream completion started successfully");
+
+        let mut attempt = None;
+        'retry: loop {
+            let mut events = model.stream_completion(request.clone(), cx).await?;
+            let mut tool_uses = FuturesUnordered::new();
+            while let Some(event) = events.next().await {
+                match event {
+                    Ok(LanguageModelCompletionEvent::StatusUpdate(
+                        CompletionRequestStatus::ToolUseLimitReached,
+                    )) => {
+                        *tool_use_limit_reached = true;
+                    }
+                    Ok(LanguageModelCompletionEvent::Stop(reason)) => {
+                        event_stream.send_stop(reason);
+                        if reason == StopReason::Refusal {
+                            this.update(cx, |this, _cx| {
+                                this.flush_pending_message();
+                                this.messages.truncate(message_ix);
+                            })?;
+                            return Ok(tool_uses);
+                        }
+                    }
+                    Ok(event) => {
+                        log::trace!("Received completion event: {:?}", event);
+                        this.update(cx, |this, cx| {
+                            tool_uses.extend(this.handle_streamed_completion_event(
+                                event,
+                                event_stream,
+                                cx,
+                            ));
+                        })
+                        .ok();
+                    }
+                    Err(error) => {
+                        let completion_mode =
+                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
+                        if completion_mode == CompletionMode::Normal {
+                            return Err(error.into());
+                        }
+
+                        let Some(strategy) = Self::retry_strategy_for(&error) else {
+                            return Err(error.into());
+                        };
+
+                        let max_attempts = match &strategy {
+                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+                        };
+
+                        let attempt = attempt.get_or_insert(0u8);
+
+                        *attempt += 1;
+
+                        let attempt = *attempt;
+                        if attempt > max_attempts {
+                            return Err(error.into());
+                        }
+
+                        let delay = match &strategy {
+                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+                                let delay_secs =
+                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+                                Duration::from_secs(delay_secs)
+                            }
+                            RetryStrategy::Fixed { delay, .. } => *delay,
+                        };
+                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+                        event_stream.send_retry(acp_thread::RetryStatus {
+                            last_error: error.to_string().into(),
+                            attempt: attempt as usize,
+                            max_attempts: max_attempts as usize,
+                            started_at: Instant::now(),
+                            duration: delay,
+                        });
+
+                        cx.background_executor().timer(delay).await;
+                        continue 'retry;
+                    }
+                }
+            }
+            return Ok(tool_uses);
+        }
+    }
+
     pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
         log::debug!("Building system message");
         let prompt = SystemPromptTemplate {
-            project: &self.project_context.read(cx),
+            project: self.project_context.read(cx),
             available_tools: self.tools.keys().cloned().collect(),
         }
         .render(&self.templates)
@@ -1158,6 +1251,113 @@ impl Thread {
     fn advance_prompt_id(&mut self) {
         self.prompt_id = PromptId::new();
     }
+
+    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
+        use LanguageModelCompletionError::*;
+        use http_client::StatusCode;
+
+        // General strategy here:
+        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
+        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
+        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
+        match error {
+            HttpResponseError {
+                status_code: StatusCode::TOO_MANY_REQUESTS,
+                ..
+            } => Some(RetryStrategy::ExponentialBackoff {
+                initial_delay: BASE_RETRY_DELAY,
+                max_attempts: MAX_RETRY_ATTEMPTS,
+            }),
+            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
+                Some(RetryStrategy::Fixed {
+                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+                    max_attempts: MAX_RETRY_ATTEMPTS,
+                })
+            }
+            UpstreamProviderError {
+                status,
+                retry_after,
+                ..
+            } => match *status {
+                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
+                    Some(RetryStrategy::Fixed {
+                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+                        max_attempts: MAX_RETRY_ATTEMPTS,
+                    })
+                }
+                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
+                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+                    // Internal Server Error could be anything, retry up to 3 times.
+                    max_attempts: 3,
+                }),
+                status => {
+                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
+                    // but we frequently get them in practice. See https://http.dev/529
+                    if status.as_u16() == 529 {
+                        Some(RetryStrategy::Fixed {
+                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+                            max_attempts: MAX_RETRY_ATTEMPTS,
+                        })
+                    } else {
+                        Some(RetryStrategy::Fixed {
+                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
+                            max_attempts: 2,
+                        })
+                    }
+                }
+            },
+            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
+                delay: BASE_RETRY_DELAY,
+                max_attempts: 3,
+            }),
+            ApiReadResponseError { .. }
+            | HttpSend { .. }
+            | DeserializeResponse { .. }
+            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
+                delay: BASE_RETRY_DELAY,
+                max_attempts: 3,
+            }),
+            // Retrying these errors definitely shouldn't help.
+            HttpResponseError {
+                status_code:
+                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
+                ..
+            }
+            | AuthenticationError { .. }
+            | PermissionError { .. }
+            | NoApiKey { .. }
+            | ApiEndpointNotFound { .. }
+            | PromptTooLarge { .. } => None,
+            // These errors might be transient, so retry them
+            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
+                delay: BASE_RETRY_DELAY,
+                max_attempts: 1,
+            }),
+            // Retry all other 4xx and 5xx errors once.
+            HttpResponseError { status_code, .. }
+                if status_code.is_client_error() || status_code.is_server_error() =>
+            {
+                Some(RetryStrategy::Fixed {
+                    delay: BASE_RETRY_DELAY,
+                    max_attempts: 3,
+                })
+            }
+            Other(err)
+                if err.is::<language_model::PaymentRequiredError>()
+                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
+            {
+                // Retrying won't help for Payment Required or Model Request Limit errors (where
+                // the user must upgrade to usage-based billing to get more requests, or else wait
+                // for a significant amount of time for the request limit to reset).
+                None
+            }
+            // Conservatively assume that any other errors are non-retryable
+            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
+                delay: BASE_RETRY_DELAY,
+                max_attempts: 2,
+            }),
+        }
+    }
 }
 
 struct RunningTurn {
@@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
             .ok();
     }
 
+    fn send_retry(&self, status: acp_thread::RetryStatus) {
+        self.0
+            .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
+            .ok();
+    }
+
     fn send_stop(&self, reason: StopReason) {
         match reason {
             StopReason::EndTurn => {

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -1,7 +1,7 @@
 use acp_thread::{
     AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
-    AuthRequired, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
-    UserMessageId,
+    AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
+    ToolCallStatus, UserMessageId,
 };
 use acp_thread::{AgentConnection, Plan};
 use action_log::ActionLog;
@@ -35,6 +35,7 @@ use prompt_store::PromptId;
 use rope::Point;
 use settings::{Settings as _, SettingsStore};
 use std::sync::Arc;
+use std::time::Instant;
 use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
 use text::Anchor;
 use theme::ThemeSettings;
@@ -115,6 +116,7 @@ pub struct AcpThreadView {
     profile_selector: Option<Entity<ProfileSelector>>,
     notifications: Vec<WindowHandle<AgentNotification>>,
     notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
+    thread_retry_status: Option<RetryStatus>,
     thread_error: Option<ThreadError>,
     list_state: ListState,
     scrollbar_state: ScrollbarState,
@@ -209,6 +211,7 @@ impl AcpThreadView {
             notification_subscriptions: HashMap::default(),
             list_state: list_state.clone(),
             scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
+            thread_retry_status: None,
             thread_error: None,
             auth_task: None,
             expanded_tool_calls: HashSet::default(),
@@ -445,6 +448,7 @@ impl AcpThreadView {
 
     pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
         self.thread_error.take();
+        self.thread_retry_status.take();
 
         if let Some(thread) = self.thread() {
             self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
@@ -775,7 +779,11 @@ impl AcpThreadView {
             AcpThreadEvent::ToolAuthorizationRequired => {
                 self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
             }
+            AcpThreadEvent::Retry(retry) => {
+                self.thread_retry_status = Some(retry.clone());
+            }
             AcpThreadEvent::Stopped => {
+                self.thread_retry_status.take();
                 let used_tools = thread.read(cx).used_tools_since_last_user_message();
                 self.notify_with_sound(
                     if used_tools {
@@ -789,6 +797,7 @@ impl AcpThreadView {
                 );
             }
             AcpThreadEvent::Error => {
+                self.thread_retry_status.take();
                 self.notify_with_sound(
                     "Agent stopped due to an error",
                     IconName::Warning,
@@ -797,6 +806,7 @@ impl AcpThreadView {
                 );
             }
             AcpThreadEvent::ServerExited(status) => {
+                self.thread_retry_status.take();
                 self.thread_state = ThreadState::ServerExited { status: *status };
             }
         }
@@ -3413,7 +3423,51 @@ impl AcpThreadView {
         })
     }
 
-    fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
+    fn render_thread_retry_status_callout(
+        &self,
+        _window: &mut Window,
+        _cx: &mut Context<Self>,
+    ) -> Option<Callout> {
+        let state = self.thread_retry_status.as_ref()?;
+
+        let next_attempt_in = state
+            .duration
+            .saturating_sub(Instant::now().saturating_duration_since(state.started_at));
+        if next_attempt_in.is_zero() {
+            return None;
+        }
+
+        let next_attempt_in_secs = next_attempt_in.as_secs() + 1;
+
+        let retry_message = if state.max_attempts == 1 {
+            if next_attempt_in_secs == 1 {
+                "Retrying. Next attempt in 1 second.".to_string()
+            } else {
+                format!("Retrying. Next attempt in {next_attempt_in_secs} seconds.")
+            }
+        } else {
+            if next_attempt_in_secs == 1 {
+                format!(
+                    "Retrying. Next attempt in 1 second (Attempt {} of {}).",
+                    state.attempt, state.max_attempts,
+                )
+            } else {
+                format!(
+                    "Retrying. Next attempt in {next_attempt_in_secs} seconds (Attempt {} of {}).",
+                    state.attempt, state.max_attempts,
+                )
+            }
+        };
+
+        Some(
+            Callout::new()
+                .severity(Severity::Warning)
+                .title(state.last_error.clone())
+                .description(retry_message),
+        )
+    }
+
+    fn render_thread_error(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
         let content = match self.thread_error.as_ref()? {
             ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
             ThreadError::PaymentRequired => self.render_payment_required_error(cx),
@@ -3678,6 +3732,7 @@ impl Render for AcpThreadView {
                 }
                 _ => this,
             })
+            .children(self.render_thread_retry_status_callout(window, cx))
             .children(self.render_thread_error(window, cx))
             .child(self.render_message_editor(window, cx))
     }

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1523,6 +1523,7 @@ impl AgentDiff {
             AcpThreadEvent::EntriesRemoved(_)
             | AcpThreadEvent::Stopped
             | AcpThreadEvent::ToolAuthorizationRequired
+            | AcpThreadEvent::Retry(_)
             | AcpThreadEvent::Error
             | AcpThreadEvent::ServerExited(_) => {}
         }

crates/language_model/src/fake_provider.rs 🔗

@@ -4,10 +4,11 @@ use crate::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, LanguageModelToolChoice,
 };
-use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
+use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
 use http_client::Result;
 use parking_lot::Mutex;
+use smol::stream::StreamExt;
 use std::sync::Arc;
 
 #[derive(Clone)]
@@ -100,7 +101,9 @@ pub struct FakeLanguageModel {
     current_completion_txs: Mutex<
         Vec<(
             LanguageModelRequest,
-            mpsc::UnboundedSender<LanguageModelCompletionEvent>,
+            mpsc::UnboundedSender<
+                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+            >,
         )>,
     >,
 }
@@ -150,7 +153,21 @@ impl FakeLanguageModel {
             .find(|(req, _)| req == request)
             .map(|(_, tx)| tx)
             .unwrap();
-        tx.unbounded_send(event.into()).unwrap();
+        tx.unbounded_send(Ok(event.into())).unwrap();
+    }
+
+    pub fn send_completion_stream_error(
+        &self,
+        request: &LanguageModelRequest,
+        error: impl Into<LanguageModelCompletionError>,
+    ) {
+        let current_completion_txs = self.current_completion_txs.lock();
+        let tx = current_completion_txs
+            .iter()
+            .find(|(req, _)| req == request)
+            .map(|(_, tx)| tx)
+            .unwrap();
+        tx.unbounded_send(Err(error.into())).unwrap();
     }
 
     pub fn end_completion_stream(&self, request: &LanguageModelRequest) {
@@ -170,6 +187,13 @@ impl FakeLanguageModel {
         self.send_completion_stream_event(self.pending_completions().last().unwrap(), event);
     }
 
+    pub fn send_last_completion_stream_error(
+        &self,
+        error: impl Into<LanguageModelCompletionError>,
+    ) {
+        self.send_completion_stream_error(self.pending_completions().last().unwrap(), error);
+    }
+
     pub fn end_last_completion_stream(&self) {
         self.end_completion_stream(self.pending_completions().last().unwrap());
     }
@@ -229,7 +253,7 @@ impl LanguageModel for FakeLanguageModel {
     > {
         let (tx, rx) = mpsc::unbounded();
         self.current_completion_txs.lock().push((request, tx));
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+        async move { Ok(rx.boxed()) }.boxed()
     }
 
     fn as_fake(&self) -> &Self {