acp: Show retry button for errors (#36862)

Bennet Bo Fenner and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/acp_thread/src/acp_thread.rs    |   6 +
crates/acp_thread/src/connection.rs    |   8 
crates/agent2/src/agent.rs             |   8 
crates/agent2/src/tests/mod.rs         |  98 +++++++++++++++++++--
crates/agent2/src/thread.rs            | 124 +++++++++++++++------------
crates/agent_ui/src/acp/thread_view.rs |  46 ++++++++++
6 files changed, 212 insertions(+), 78 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1373,6 +1373,10 @@ impl AcpThread {
         })
     }
 
+    pub fn can_resume(&self, cx: &App) -> bool {
+        self.connection.resume(&self.session_id, cx).is_some()
+    }
+
     pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
         self.run_turn(cx, async move |this, cx| {
             this.update(cx, |this, cx| {
@@ -2659,7 +2663,7 @@ mod tests {
         fn truncate(
             &self,
             session_id: &acp::SessionId,
-            _cx: &mut App,
+            _cx: &App,
         ) -> Option<Rc<dyn AgentSessionTruncate>> {
             Some(Rc::new(FakeAgentSessionEditor {
                 _session_id: session_id.clone(),

crates/acp_thread/src/connection.rs 🔗

@@ -43,7 +43,7 @@ pub trait AgentConnection {
     fn resume(
         &self,
         _session_id: &acp::SessionId,
-        _cx: &mut App,
+        _cx: &App,
     ) -> Option<Rc<dyn AgentSessionResume>> {
         None
     }
@@ -53,7 +53,7 @@ pub trait AgentConnection {
     fn truncate(
         &self,
         _session_id: &acp::SessionId,
-        _cx: &mut App,
+        _cx: &App,
     ) -> Option<Rc<dyn AgentSessionTruncate>> {
         None
     }
@@ -61,7 +61,7 @@ pub trait AgentConnection {
     fn set_title(
         &self,
         _session_id: &acp::SessionId,
-        _cx: &mut App,
+        _cx: &App,
     ) -> Option<Rc<dyn AgentSessionSetTitle>> {
         None
     }
@@ -439,7 +439,7 @@ mod test_support {
         fn truncate(
             &self,
             _session_id: &agent_client_protocol::SessionId,
-            _cx: &mut App,
+            _cx: &App,
         ) -> Option<Rc<dyn AgentSessionTruncate>> {
             Some(Rc::new(StubAgentSessionEditor))
         }

crates/agent2/src/agent.rs 🔗

@@ -936,7 +936,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     fn resume(
         &self,
         session_id: &acp::SessionId,
-        _cx: &mut App,
+        _cx: &App,
     ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
         Some(Rc::new(NativeAgentSessionResume {
             connection: self.clone(),
@@ -956,9 +956,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     fn truncate(
         &self,
         session_id: &agent_client_protocol::SessionId,
-        cx: &mut App,
+        cx: &App,
     ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
-        self.0.update(cx, |agent, _cx| {
+        self.0.read_with(cx, |agent, _cx| {
             agent.sessions.get(session_id).map(|session| {
                 Rc::new(NativeAgentSessionEditor {
                     thread: session.thread.clone(),
@@ -971,7 +971,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
     fn set_title(
         &self,
         session_id: &acp::SessionId,
-        _cx: &mut App,
+        _cx: &App,
     ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
         Some(Rc::new(NativeAgentSessionSetTitle {
             connection: self.clone(),

crates/agent2/src/tests/mod.rs 🔗

@@ -5,6 +5,7 @@ use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
 use cloud_llm_client::CompletionIntent;
+use collections::IndexMap;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use fs::{FakeFs, Fs};
 use futures::{
@@ -673,15 +674,6 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
             "}
         )
     });
-
-    // Ensure we error if calling resume when tool use limit was *not* reached.
-    let error = thread
-        .update(cx, |thread, cx| thread.resume(cx))
-        .unwrap_err();
-    assert_eq!(
-        error.to_string(),
-        "can only resume after tool use limit is reached"
-    )
 }
 
 #[gpui::test]
@@ -2105,6 +2097,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
         .unwrap();
     cx.run_until_parked();
 
+    fake_model.send_last_completion_stream_text_chunk("Hey,");
     fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
         provider: LanguageModelProviderName::new("Anthropic"),
         retry_after: Some(Duration::from_secs(3)),
@@ -2114,8 +2107,9 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
     cx.executor().advance_clock(Duration::from_secs(3));
     cx.run_until_parked();
 
-    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.send_last_completion_stream_text_chunk("there!");
     fake_model.end_last_completion_stream();
+    cx.run_until_parked();
 
     let mut retry_events = Vec::new();
     while let Some(Ok(event)) = events.next().await {
@@ -2143,12 +2137,94 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
 
                 ## Assistant
 
-                Hey!
+                Hey,
+
+                [resume]
+
+                ## Assistant
+
+                there!
             "}
         )
     });
 }
 
+#[gpui::test]
+async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
+    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let events = thread
+        .update(cx, |thread, cx| {
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
+            thread.add_tool(EchoTool);
+            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use_1 = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: EchoTool::name().into(),
+        raw_input: json!({"text": "test"}).to_string(),
+        input: json!({"text": "test"}),
+        is_input_complete: true,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        tool_use_1.clone(),
+    ));
+    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
+        provider: LanguageModelProviderName::new("Anthropic"),
+        retry_after: Some(Duration::from_secs(3)),
+    });
+    fake_model.end_last_completion_stream();
+
+    cx.executor().advance_clock(Duration::from_secs(3));
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages[1..],
+        vec![
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Call the echo tool!".into()],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::Assistant,
+                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
+                cache: false
+            },
+            LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![language_model::MessageContent::ToolResult(
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use_1.id.clone(),
+                        tool_name: tool_use_1.name.clone(),
+                        is_error: false,
+                        content: "test".into(),
+                        output: Some("test".into())
+                    }
+                )],
+                cache: true
+            },
+        ]
+    );
+
+    fake_model.send_last_completion_stream_text_chunk("Done");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+    events.collect::<Vec<_>>().await;
+    thread.read_with(cx, |thread, _cx| {
+        assert_eq!(
+            thread.last_message(),
+            Some(Message::Agent(AgentMessage {
+                content: vec![AgentMessageContent::Text("Done".into())],
+                tool_results: IndexMap::default()
+            }))
+        );
+    })
+}
+
 #[gpui::test]
 async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
     let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;

crates/agent2/src/thread.rs 🔗

@@ -123,7 +123,7 @@ impl Message {
         match self {
             Message::User(message) => message.to_markdown(),
             Message::Agent(message) => message.to_markdown(),
-            Message::Resume => "[resumed after tool use limit was reached]".into(),
+            Message::Resume => "[resume]\n".into(),
         }
     }
 
@@ -1085,11 +1085,6 @@ impl Thread {
         &mut self,
         cx: &mut Context<Self>,
     ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
-        anyhow::ensure!(
-            self.tool_use_limit_reached,
-            "can only resume after tool use limit is reached"
-        );
-
         self.messages.push(Message::Resume);
         cx.notify();
 
@@ -1216,12 +1211,13 @@ impl Thread {
         cx: &mut AsyncApp,
     ) -> Result<()> {
         log::debug!("Stream completion started successfully");
-        let request = this.update(cx, |this, cx| {
-            this.build_completion_request(completion_intent, cx)
-        })??;
 
         let mut attempt = None;
-        'retry: loop {
+        loop {
+            let request = this.update(cx, |this, cx| {
+                this.build_completion_request(completion_intent, cx)
+            })??;
+
             telemetry::event!(
                 "Agent Thread Completion",
                 thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
@@ -1236,10 +1232,11 @@ impl Thread {
                 attempt.unwrap_or(0)
             );
             let mut events = model
-                .stream_completion(request.clone(), cx)
+                .stream_completion(request, cx)
                 .await
                 .map_err(|error| anyhow!(error))?;
             let mut tool_results = FuturesUnordered::new();
+            let mut error = None;
 
             while let Some(event) = events.next().await {
                 match event {
@@ -1249,51 +1246,9 @@ impl Thread {
                             this.handle_streamed_completion_event(event, event_stream, cx)
                         })??);
                     }
-                    Err(error) => {
-                        let completion_mode =
-                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
-                        if completion_mode == CompletionMode::Normal {
-                            return Err(anyhow!(error))?;
-                        }
-
-                        let Some(strategy) = Self::retry_strategy_for(&error) else {
-                            return Err(anyhow!(error))?;
-                        };
-
-                        let max_attempts = match &strategy {
-                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
-                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
-                        };
-
-                        let attempt = attempt.get_or_insert(0u8);
-
-                        *attempt += 1;
-
-                        let attempt = *attempt;
-                        if attempt > max_attempts {
-                            return Err(anyhow!(error))?;
-                        }
-
-                        let delay = match &strategy {
-                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
-                                let delay_secs =
-                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
-                                Duration::from_secs(delay_secs)
-                            }
-                            RetryStrategy::Fixed { delay, .. } => *delay,
-                        };
-                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
-
-                        event_stream.send_retry(acp_thread::RetryStatus {
-                            last_error: error.to_string().into(),
-                            attempt: attempt as usize,
-                            max_attempts: max_attempts as usize,
-                            started_at: Instant::now(),
-                            duration: delay,
-                        });
-
-                        cx.background_executor().timer(delay).await;
-                        continue 'retry;
+                    Err(err) => {
+                        error = Some(err);
+                        break;
                     }
                 }
             }
@@ -1320,7 +1275,58 @@ impl Thread {
                 })?;
             }
 
-            return Ok(());
+            if let Some(error) = error {
+                let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?;
+                if completion_mode == CompletionMode::Normal {
+                    return Err(anyhow!(error))?;
+                }
+
+                let Some(strategy) = Self::retry_strategy_for(&error) else {
+                    return Err(anyhow!(error))?;
+                };
+
+                let max_attempts = match &strategy {
+                    RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
+                    RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
+                };
+
+                let attempt = attempt.get_or_insert(0u8);
+
+                *attempt += 1;
+
+                let attempt = *attempt;
+                if attempt > max_attempts {
+                    return Err(anyhow!(error))?;
+                }
+
+                let delay = match &strategy {
+                    RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
+                        let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
+                        Duration::from_secs(delay_secs)
+                    }
+                    RetryStrategy::Fixed { delay, .. } => *delay,
+                };
+                log::debug!("Retry attempt {attempt} with delay {delay:?}");
+
+                event_stream.send_retry(acp_thread::RetryStatus {
+                    last_error: error.to_string().into(),
+                    attempt: attempt as usize,
+                    max_attempts: max_attempts as usize,
+                    started_at: Instant::now(),
+                    duration: delay,
+                });
+                cx.background_executor().timer(delay).await;
+                this.update(cx, |this, cx| {
+                    this.flush_pending_message(cx);
+                    if let Some(Message::Agent(message)) = this.messages.last() {
+                        if message.tool_results.is_empty() {
+                            this.messages.push(Message::Resume);
+                        }
+                    }
+                })?;
+            } else {
+                return Ok(());
+            }
         }
     }
 
@@ -1737,6 +1743,10 @@ impl Thread {
             return;
         };
 
+        if message.content.is_empty() {
+            return;
+        }
+
         for content in &message.content {
             let AgentMessageContent::ToolUse(tool_use) = content else {
                 continue;

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -820,6 +820,9 @@ impl AcpThreadView {
         let Some(thread) = self.thread() else {
             return;
         };
+        if !thread.read(cx).can_resume(cx) {
+            return;
+        }
 
         let task = thread.update(cx, |thread, cx| thread.resume(cx));
         cx.spawn(async move |this, cx| {
@@ -4459,12 +4462,53 @@ impl AcpThreadView {
     }
 
     fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout {
+        let can_resume = self
+            .thread()
+            .map_or(false, |thread| thread.read(cx).can_resume(cx));
+
+        let can_enable_burn_mode = self.as_native_thread(cx).map_or(false, |thread| {
+            let thread = thread.read(cx);
+            let supports_burn_mode = thread
+                .model()
+                .map_or(false, |model| model.supports_burn_mode());
+            supports_burn_mode && thread.completion_mode() == CompletionMode::Normal
+        });
+
         Callout::new()
             .severity(Severity::Error)
             .title("Error")
             .icon(IconName::XCircle)
             .description(error.clone())
-            .actions_slot(self.create_copy_button(error.to_string()))
+            .actions_slot(
+                h_flex()
+                    .gap_0p5()
+                    .when(can_resume && can_enable_burn_mode, |this| {
+                        this.child(
+                            Button::new("enable-burn-mode-and-retry", "Enable Burn Mode and Retry")
+                                .icon(IconName::ZedBurnMode)
+                                .icon_position(IconPosition::Start)
+                                .icon_size(IconSize::Small)
+                                .label_size(LabelSize::Small)
+                                .on_click(cx.listener(|this, _, window, cx| {
+                                    this.toggle_burn_mode(&ToggleBurnMode, window, cx);
+                                    this.resume_chat(cx);
+                                })),
+                        )
+                    })
+                    .when(can_resume, |this| {
+                        this.child(
+                            Button::new("retry", "Retry")
+                                .icon(IconName::RotateCw)
+                                .icon_position(IconPosition::Start)
+                                .icon_size(IconSize::Small)
+                                .label_size(LabelSize::Small)
+                                .on_click(cx.listener(|this, _, _window, cx| {
+                                    this.resume_chat(cx);
+                                })),
+                        )
+                    })
+                    .child(self.create_copy_button(error.to_string())),
+            )
             .dismiss_action(self.dismiss_error_button(cx))
     }