Properly cancel requests

Agus Zubiaga and Mikayla Maki created

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>

Change summary

crates/acp/src/acp.rs              | 170 ++++++++++++++++++++++++++++++-
crates/acp/src/server.rs           |  41 +++---
crates/acp/src/thread_view.rs      | 143 ++++++++++++++------------
crates/agent_ui/src/agent_panel.rs |   2 
4 files changed, 264 insertions(+), 92 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -141,6 +141,7 @@ pub enum ToolCallStatus {
         status: acp::ToolCallStatus,
     },
     Rejected,
+    Canceled,
 }
 
 #[derive(Debug)]
@@ -359,6 +360,7 @@ pub struct AcpThread {
     server: Arc<AcpServer>,
     title: SharedString,
     project: Entity<Project>,
+    send_task: Option<Task<()>>,
 }
 
 enum AcpThreadEvent {
@@ -366,6 +368,13 @@ enum AcpThreadEvent {
     EntryUpdated(usize),
 }
 
+#[derive(PartialEq, Eq)]
+pub enum ThreadStatus {
+    Idle,
+    WaitingForToolConfirmation,
+    Generating,
+}
+
 impl EventEmitter<AcpThreadEvent> for AcpThread {}
 
 impl AcpThread {
@@ -378,7 +387,7 @@ impl AcpThread {
     ) -> Self {
         let mut next_entry_id = ThreadEntryId(0);
         Self {
-            title: "A new agent2 thread".into(),
+            title: "ACP Thread".into(),
             entries: entries
                 .into_iter()
                 .map(|entry| ThreadEntry {
@@ -390,6 +399,7 @@ impl AcpThread {
             id: thread_id,
             next_entry_id,
             project,
+            send_task: None,
         }
     }
 
@@ -401,6 +411,18 @@ impl AcpThread {
         &self.entries
     }
 
+    pub fn status(&self) -> ThreadStatus {
+        if self.send_task.is_some() {
+            if self.waiting_for_tool_confirmation() {
+                ThreadStatus::WaitingForToolConfirmation
+            } else {
+                ThreadStatus::Generating
+            }
+        } else {
+            ThreadStatus::Idle
+        }
+    }
+
     pub fn push_entry(
         &mut self,
         entry: AgentThreadEntryContent,
@@ -577,6 +599,10 @@ impl AcpThread {
                     ToolCallStatus::Rejected => {
                         anyhow::bail!("Tool call was rejected and therefore can't be updated")
                     }
+                    ToolCallStatus::Canceled => {
+                        // todo! test this case with fake server
+                        call.status = ToolCallStatus::Allowed { status: new_status };
+                    }
                 }
             }
             _ => anyhow::bail!("Entry is not a tool call"),
@@ -597,11 +623,14 @@ impl AcpThread {
 
     /// Returns true if the last turn is awaiting tool authorization
     pub fn waiting_for_tool_confirmation(&self) -> bool {
+        // todo!("should we use a hashmap?")
         for entry in self.entries.iter().rev() {
             match &entry.content {
                 AgentThreadEntryContent::ToolCall(call) => match call.status {
                     ToolCallStatus::WaitingForConfirmation { .. } => return true,
-                    ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
+                    ToolCallStatus::Allowed { .. }
+                    | ToolCallStatus::Rejected
+                    | ToolCallStatus::Canceled => continue,
                 },
                 AgentThreadEntryContent::Message(_) => {
                     // Reached the beginning of the turn
@@ -612,9 +641,14 @@ impl AcpThread {
         false
     }
 
-    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
+    pub fn send(
+        &mut self,
+        message: &str,
+        cx: &mut Context<Self>,
+    ) -> impl use<> + Future<Output = Result<()>> {
         let agent = self.server.clone();
         let id = self.id.clone();
+
         let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
         let message = Message {
             role: Role::User,
@@ -622,10 +656,65 @@ impl AcpThread {
         };
         self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
         let acp_message = message.into_acp(cx);
-        cx.spawn(async move |_, cx| {
-            agent.send_message(id, acp_message, cx).await?;
-            Ok(())
-        })
+
+        let (tx, rx) = oneshot::channel();
+        let cancel = self.cancel(cx);
+
+        self.send_task = Some(cx.spawn(async move |this, cx| {
+            cancel.await.log_err();
+
+            let result = agent.send_message(id, acp_message, cx).await;
+            tx.send(result).log_err();
+            this.update(cx, |this, _cx| this.send_task.take()).log_err();
+        }));
+
+        async move {
+            match rx.await {
+                Ok(result) => result,
+                Err(_) => Ok(()),
+            }
+        }
+    }
+
+    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let agent = self.server.clone();
+        let id = self.id.clone();
+
+        if self.send_task.take().is_some() {
+            cx.spawn(async move |this, cx| {
+                agent.cancel_send_message(id, cx).await?;
+
+                this.update(cx, |this, _cx| {
+                    for entry in this.entries.iter_mut() {
+                        if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
+                            let cancel = matches!(
+                                call.status,
+                                ToolCallStatus::WaitingForConfirmation { .. }
+                                    | ToolCallStatus::Allowed {
+                                        status: acp::ToolCallStatus::Running
+                                    }
+                            );
+
+                            if cancel {
+                                let curr_status =
+                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
+
+                                if let ToolCallStatus::WaitingForConfirmation {
+                                    respond_tx, ..
+                                } = curr_status
+                                {
+                                    respond_tx
+                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
+                                        .ok();
+                                }
+                            }
+                        }
+                    }
+                })
+            })
+        } else {
+            Task::ready(Ok(()))
+        }
     }
 }
 
@@ -815,6 +904,73 @@ mod tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_gemini_cancel(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        cx.executor().allow_parking();
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
+        let server = gemini_acp_server(project.clone(), cx).await;
+        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
+        let full_turn = thread.update(cx, |thread, cx| {
+            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
+        });
+
+        run_until_tool_call(&thread, cx).await;
+
+        thread.read_with(cx, |thread, _cx| {
+            let AgentThreadEntryContent::ToolCall(ToolCall {
+                id,
+                status:
+                    ToolCallStatus::WaitingForConfirmation {
+                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
+                        ..
+                    },
+                ..
+            }) = &thread.entries()[1].content
+            else {
+                panic!();
+            };
+
+            assert_eq!(root_command, "echo");
+
+            *id
+        });
+
+        thread
+            .update(cx, |thread, cx| thread.cancel(cx))
+            .await
+            .unwrap();
+        full_turn.await.unwrap();
+        thread.read_with(cx, |thread, _| {
+            let AgentThreadEntryContent::ToolCall(ToolCall {
+                status: ToolCallStatus::Canceled,
+                ..
+            }) = &thread.entries()[1].content
+            else {
+                panic!();
+            };
+        });
+
+        thread
+            .update(cx, |thread, cx| {
+                thread.send(r#"Stop running and say goodbye to me."#, cx)
+            })
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, _| {
+            let AgentThreadEntryContent::Message(Message {
+                role: Role::Assistant,
+                ..
+            }) = &thread.entries()[3].content
+            else {
+                panic!();
+            };
+        });
+    }
+
     async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
         let (mut tx, mut rx) = mpsc::channel::<()>(1);
 

crates/acp/src/server.rs 🔗

@@ -20,23 +20,14 @@ pub struct AcpServer {
 }
 
 struct AcpClientDelegate {
-    project: Entity<Project>,
     threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
     cx: AsyncApp,
     // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 }
 
 impl AcpClientDelegate {
-    fn new(
-        project: Entity<Project>,
-        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
-        cx: AsyncApp,
-    ) -> Self {
-        Self {
-            project,
-            threads,
-            cx: cx,
-        }
+    fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
+        Self { threads, cx: cx }
     }
 
     fn update_thread<R>(
@@ -143,7 +134,7 @@ impl AcpServer {
 
         let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
         let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
-            AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
+            AcpClientDelegate::new(threads.clone(), cx.to_async()),
             stdin,
             stdout,
         );
@@ -193,14 +184,14 @@ impl AcpServer {
 
         let thread_id: ThreadId = response.thread_id.into();
         let server = self.clone();
-        let thread = cx.new(|_| AcpThread {
-            // todo!
-            title: "ACP Thread".into(),
-            id: thread_id.clone(), // Either<ErrorState, Id>
-            next_entry_id: ThreadEntryId(0),
-            entries: Vec::default(),
-            project: self.project.clone(),
-            server,
+        let thread = cx.new(|cx| {
+            AcpThread::new(
+                server,
+                thread_id.clone(),
+                Vec::default(),
+                self.project.clone(),
+                cx,
+            )
         })?;
         self.threads.lock().insert(thread_id, thread.downgrade());
         Ok(thread)
@@ -222,6 +213,16 @@ impl AcpServer {
         Ok(())
     }
 
+    pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
+        self.connection
+            .request(acp::CancelSendMessageParams {
+                thread_id: thread_id.clone().into(),
+            })
+            .await
+            .map_err(to_anyhow)?;
+        Ok(())
+    }
+
     pub fn exit_status(&self) -> Option<ExitStatus> {
         *self.exit_status.lock()
     }

crates/acp/src/thread_view.rs 🔗

@@ -4,7 +4,6 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use agentic_coding_protocol::{self as acp};
-use anyhow::Result;
 use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
 use gpui::{
     Animation, AnimationExt, App, EdgesRefinement, Empty, Entity, Focusable, ListState,
@@ -25,7 +24,8 @@ use zed_actions::agent::Chat;
 
 use crate::{
     AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, Diff, MessageChunk, Role,
-    ThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus,
+    ThreadEntry, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId,
+    ToolCallStatus,
 };
 
 pub struct AcpThreadView {
@@ -36,7 +36,6 @@ pub struct AcpThreadView {
     message_editor: Entity<Editor>,
     last_error: Option<Entity<Markdown>>,
     list_state: ListState,
-    send_task: Option<Task<Result<()>>>,
     auth_task: Option<Task<()>>,
 }
 
@@ -123,7 +122,6 @@ impl AcpThreadView {
             agent,
             message_editor,
             thread_entry_views: Vec::new(),
-            send_task: None,
             list_state: list_state,
             last_error: None,
             auth_task: None,
@@ -203,8 +201,12 @@ impl AcpThreadView {
         }
     }
 
-    pub fn cancel(&mut self) {
-        self.send_task.take();
+    pub fn cancel(&mut self, cx: &mut Context<Self>) {
+        self.last_error.take();
+
+        if let Some(thread) = self.thread() {
+            thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
+        }
     }
 
     fn chat(&mut self, _: &Chat, window: &mut Window, cx: &mut Context<Self>) {
@@ -217,7 +219,7 @@ impl AcpThreadView {
 
         let task = thread.update(cx, |thread, cx| thread.send(&text, cx));
 
-        self.send_task = Some(cx.spawn(async move |this, cx| {
+        cx.spawn(async move |this, cx| {
             let result = task.await;
 
             this.update(cx, |this, cx| {
@@ -227,9 +229,9 @@ impl AcpThreadView {
                             Markdown::new(format!("Error: {err}").into(), None, None, cx)
                         }))
                 }
-                this.send_task.take();
             })
-        }));
+        })
+        .detach();
 
         self.message_editor.update(cx, |editor, cx| {
             editor.clear(window, cx);
@@ -467,6 +469,7 @@ impl AcpThreadView {
                 .size(IconSize::Small)
                 .into_any_element(),
             ToolCallStatus::Rejected
+            | ToolCallStatus::Canceled
             | ToolCallStatus::Allowed {
                 status: acp::ToolCallStatus::Error,
                 ..
@@ -487,15 +490,17 @@ impl AcpThreadView {
                     cx,
                 ))
             }
-            ToolCallStatus::Allowed { .. } => tool_call.content.as_ref().map(|content| {
-                div()
-                    .border_color(cx.theme().colors().border)
-                    .border_t_1()
-                    .px_2()
-                    .py_1p5()
-                    .child(self.render_tool_call_content(entry_ix, content, window, cx))
-                    .into_any_element()
-            }),
+            ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
+                tool_call.content.as_ref().map(|content| {
+                    div()
+                        .border_color(cx.theme().colors().border)
+                        .border_t_1()
+                        .px_2()
+                        .py_1p5()
+                        .child(self.render_tool_call_content(entry_ix, content, window, cx))
+                        .into_any_element()
+                })
+            }
             ToolCallStatus::Rejected => None,
         };
 
@@ -1016,18 +1021,21 @@ impl Render for AcpThreadView {
                             .with_sizing_behavior(gpui::ListSizingBehavior::Auto)
                             .flex_grow(),
                     )
-                    .child(div().px_3().children(if self.send_task.is_none() {
-                        None
-                    } else {
-                        Label::new(if thread.read(cx).waiting_for_tool_confirmation() {
-                            "Waiting for tool confirmation"
-                        } else {
-                            "Generating..."
-                        })
-                        .color(Color::Muted)
-                        .size(LabelSize::Small)
-                        .into()
-                    })),
+                    .child(
+                        div().px_3().children(match thread.read(cx).status() {
+                            ThreadStatus::Idle => None,
+                            ThreadStatus::WaitingForToolConfirmation => {
+                                Label::new("Waiting for tool confirmation")
+                                    .color(Color::Muted)
+                                    .size(LabelSize::Small)
+                                    .into()
+                            }
+                            ThreadStatus::Generating => Label::new("Generating...")
+                                .color(Color::Muted)
+                                .size(LabelSize::Small)
+                                .into(),
+                        }),
+                    ),
             })
             .when_some(self.last_error.clone(), |el, error| {
                 el.child(
@@ -1052,40 +1060,47 @@ impl Render for AcpThreadView {
                     .p_2()
                     .gap_2()
                     .child(self.message_editor.clone())
-                    .child(h_flex().justify_end().child(if self.send_task.is_some() {
-                        IconButton::new("stop-generation", IconName::StopFilled)
-                            .icon_color(Color::Error)
-                            .style(ButtonStyle::Tinted(ui::TintColor::Error))
-                            .tooltip(move |window, cx| {
-                                Tooltip::for_action(
-                                    "Stop Generation",
-                                    &editor::actions::Cancel,
-                                    window,
-                                    cx,
-                                )
-                            })
-                            .disabled(is_editor_empty)
-                            .on_click(cx.listener(|this, _event, _, _| this.cancel()))
-                    } else {
-                        IconButton::new("send-message", IconName::Send)
-                            .icon_color(Color::Accent)
-                            .style(ButtonStyle::Filled)
-                            .disabled(is_editor_empty)
-                            .on_click({
-                                let focus_handle = focus_handle.clone();
-                                move |_event, window, cx| {
-                                    focus_handle.dispatch_action(&Chat, window, cx);
-                                }
-                            })
-                            .when(!is_editor_empty, |button| {
-                                button.tooltip(move |window, cx| {
-                                    Tooltip::for_action("Send", &Chat, window, cx)
-                                })
-                            })
-                            .when(is_editor_empty, |button| {
-                                button.tooltip(Tooltip::text("Type a message to submit"))
-                            })
-                    })),
+                    .child({
+                        let thread = self.thread();
+
+                        h_flex().justify_end().child(
+                            if thread.map_or(true, |thread| {
+                                thread.read(cx).status() == ThreadStatus::Idle
+                            }) {
+                                IconButton::new("send-message", IconName::Send)
+                                    .icon_color(Color::Accent)
+                                    .style(ButtonStyle::Filled)
+                                    .disabled(thread.is_none() || is_editor_empty)
+                                    .on_click({
+                                        let focus_handle = focus_handle.clone();
+                                        move |_event, window, cx| {
+                                            focus_handle.dispatch_action(&Chat, window, cx);
+                                        }
+                                    })
+                                    .when(!is_editor_empty, |button| {
+                                        button.tooltip(move |window, cx| {
+                                            Tooltip::for_action("Send", &Chat, window, cx)
+                                        })
+                                    })
+                                    .when(is_editor_empty, |button| {
+                                        button.tooltip(Tooltip::text("Type a message to submit"))
+                                    })
+                            } else {
+                                IconButton::new("stop-generation", IconName::StopFilled)
+                                    .icon_color(Color::Error)
+                                    .style(ButtonStyle::Tinted(ui::TintColor::Error))
+                                    .tooltip(move |window, cx| {
+                                        Tooltip::for_action(
+                                            "Stop Generation",
+                                            &editor::actions::Cancel,
+                                            window,
+                                            cx,
+                                        )
+                                    })
+                                    .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx)))
+                            },
+                        )
+                    }),
             )
     }
 }

crates/agent_ui/src/agent_panel.rs 🔗

@@ -753,7 +753,7 @@ impl AgentPanel {
                 thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
             }
             ActiveView::AcpThread { thread_view, .. } => {
-                thread_view.update(cx, |thread_element, _cx| thread_element.cancel());
+                thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx));
             }
             ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
         }