acp: Have `AcpThread` handle all interrupting (#36417)

Agus Zubiaga and Danilo created

The view was cancelling the generation, but `AcpThread` already handles
that, so we removed that extra code and fixed a bug where an update from
the first user message would appear after the second one.

Release Notes:

- N/A

Co-authored-by: Danilo <danilo@zed.dev>

Change summary

crates/acp_thread/src/acp_thread.rs    |  22 ++-
crates/acp_thread/src/connection.rs    |  27 +++--
crates/agent_ui/src/acp/thread_view.rs | 135 +++++++++++++++++++++++++++
3 files changed, 164 insertions(+), 20 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1200,17 +1200,21 @@ impl AcpThread {
         } else {
             None
         };
-        self.push_entry(
-            AgentThreadEntry::UserMessage(UserMessage {
-                id: message_id.clone(),
-                content: block,
-                chunks: message,
-                checkpoint: None,
-            }),
-            cx,
-        );
 
         self.run_turn(cx, async move |this, cx| {
+            this.update(cx, |this, cx| {
+                this.push_entry(
+                    AgentThreadEntry::UserMessage(UserMessage {
+                        id: message_id.clone(),
+                        content: block,
+                        chunks: message,
+                        checkpoint: None,
+                    }),
+                    cx,
+                );
+            })
+            .ok();
+
             let old_checkpoint = git_store
                 .update(cx, |git, cx| git.checkpoint(cx))?
                 .await

crates/acp_thread/src/connection.rs 🔗

@@ -201,7 +201,7 @@ mod test_support {
 
     struct Session {
         thread: WeakEntity<AcpThread>,
-        response_tx: Option<oneshot::Sender<()>>,
+        response_tx: Option<oneshot::Sender<acp::StopReason>>,
     }
 
     impl StubAgentConnection {
@@ -242,12 +242,12 @@ mod test_support {
                 .unwrap()
                 .thread
                 .update(cx, |thread, cx| {
-                    thread.handle_session_update(update.clone(), cx).unwrap();
+                    thread.handle_session_update(update, cx).unwrap();
                 })
                 .unwrap();
         }
 
-        pub fn end_turn(&self, session_id: acp::SessionId) {
+        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
             self.sessions
                 .lock()
                 .get_mut(&session_id)
@@ -255,7 +255,7 @@ mod test_support {
                 .response_tx
                 .take()
                 .expect("No pending turn")
-                .send(())
+                .send(stop_reason)
                 .unwrap();
         }
     }
@@ -308,10 +308,8 @@ mod test_support {
                 let (tx, rx) = oneshot::channel();
                 response_tx.replace(tx);
                 cx.spawn(async move |_| {
-                    rx.await?;
-                    Ok(acp::PromptResponse {
-                        stop_reason: acp::StopReason::EndTurn,
-                    })
+                    let stop_reason = rx.await?;
+                    Ok(acp::PromptResponse { stop_reason })
                 })
             } else {
                 for update in self.next_prompt_updates.lock().drain(..) {
@@ -353,8 +351,17 @@ mod test_support {
             }
         }
 
-        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
-            unimplemented!()
+        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
+            if let Some(end_turn_tx) = self
+                .sessions
+                .lock()
+                .get_mut(session_id)
+                .unwrap()
+                .response_tx
+                .take()
+            {
+                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
+            }
         }
 
         fn session_editor(

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -4283,7 +4283,7 @@ pub(crate) mod tests {
                 },
                 cx,
             );
-            connection.end_turn(session_id);
+            connection.end_turn(session_id, acp::StopReason::EndTurn);
         });
 
         thread_view.read_with(cx, |view, _cx| {
@@ -4302,4 +4302,137 @@ pub(crate) mod tests {
             );
         });
     }
+
+    #[gpui::test]
+    async fn test_interrupt(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let connection = StubAgentConnection::new();
+
+        let (thread_view, cx) =
+            setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
+        add_to_workspace(thread_view.clone(), cx);
+
+        let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Message 1", window, cx);
+        });
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.send(window, cx);
+        });
+
+        let (thread, session_id) = thread_view.read_with(cx, |view, cx| {
+            let thread = view.thread().unwrap();
+
+            (thread.clone(), thread.read(cx).session_id().clone())
+        });
+
+        cx.run_until_parked();
+
+        cx.update(|_, cx| {
+            connection.send_update(
+                session_id.clone(),
+                acp::SessionUpdate::AgentMessageChunk {
+                    content: "Message 1 resp".into(),
+                },
+                cx,
+            );
+        });
+
+        cx.run_until_parked();
+
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc::indoc! {"
+                    ## User
+
+                    Message 1
+
+                    ## Assistant
+
+                    Message 1 resp
+
+                "}
+            )
+        });
+
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Message 2", window, cx);
+        });
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.send(window, cx);
+        });
+
+        cx.update(|_, cx| {
+            // Simulate a response sent after beginning to cancel
+            connection.send_update(
+                session_id.clone(),
+                acp::SessionUpdate::AgentMessageChunk {
+                    content: "onse".into(),
+                },
+                cx,
+            );
+        });
+
+        cx.run_until_parked();
+
+        // Last Message 1 response should appear before Message 2
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc::indoc! {"
+                    ## User
+
+                    Message 1
+
+                    ## Assistant
+
+                    Message 1 response
+
+                    ## User
+
+                    Message 2
+
+                "}
+            )
+        });
+
+        cx.update(|_, cx| {
+            connection.send_update(
+                session_id.clone(),
+                acp::SessionUpdate::AgentMessageChunk {
+                    content: "Message 2 response".into(),
+                },
+                cx,
+            );
+            connection.end_turn(session_id.clone(), acp::StopReason::EndTurn);
+        });
+
+        cx.run_until_parked();
+
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc::indoc! {"
+                    ## User
+
+                    Message 1
+
+                    ## Assistant
+
+                    Message 1 response
+
+                    ## User
+
+                    Message 2
+
+                    ## Assistant
+
+                    Message 2 response
+
+                "}
+            )
+        });
+    }
 }