Fix acp generating status after stop (#35852)

Agus Zubiaga created

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs | 73 +++++++++++++++++-------------
1 file changed, 41 insertions(+), 32 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -6,7 +6,6 @@ use anyhow::{Context as _, Result};
 use assistant_tool::ActionLog;
 use buffer_diff::BufferDiff;
 use editor::{Bias, MultiBuffer, PathKey};
-use futures::future::{Fuse, FusedFuture};
 use futures::{FutureExt, channel::oneshot, future::BoxFuture};
 use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
 use itertools::Itertools;
@@ -580,7 +579,7 @@ pub struct AcpThread {
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
     shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
-    send_task: Option<Fuse<Task<()>>>,
+    send_task: Option<Task<()>>,
     connection: Rc<dyn AgentConnection>,
     session_id: acp::SessionId,
 }
@@ -670,11 +669,7 @@ impl AcpThread {
     }
 
     pub fn status(&self) -> ThreadStatus {
-        if self
-            .send_task
-            .as_ref()
-            .map_or(false, |t| !t.is_terminated())
-        {
+        if self.send_task.is_some() {
             if self.waiting_for_tool_confirmation() {
                 ThreadStatus::WaitingForToolConfirmation
             } else {
@@ -1049,31 +1044,29 @@ impl AcpThread {
         let (tx, rx) = oneshot::channel();
         let cancel_task = self.cancel(cx);
 
-        self.send_task = Some(
-            cx.spawn(async move |this, cx| {
-                async {
-                    cancel_task.await;
-
-                    let result = this
-                        .update(cx, |this, cx| {
-                            this.connection.prompt(
-                                acp::PromptRequest {
-                                    prompt: message,
-                                    session_id: this.session_id.clone(),
-                                },
-                                cx,
-                            )
-                        })?
-                        .await;
+        self.send_task = Some(cx.spawn(async move |this, cx| {
+            async {
+                cancel_task.await;
 
-                    tx.send(result).log_err();
-                    anyhow::Ok(())
-                }
-                .await
-                .log_err();
-            })
-            .fuse(),
-        );
+                let result = this
+                    .update(cx, |this, cx| {
+                        this.connection.prompt(
+                            acp::PromptRequest {
+                                prompt: message,
+                                session_id: this.session_id.clone(),
+                            },
+                            cx,
+                        )
+                    })?
+                    .await;
+
+                tx.send(result).log_err();
+
+                anyhow::Ok(())
+            }
+            .await
+            .log_err();
+        }));
 
         cx.spawn(async move |this, cx| match rx.await {
             Ok(Err(e)) => {
@@ -1081,7 +1074,23 @@ impl AcpThread {
                     .log_err();
                 Err(e)?
             }
-            _ => {
+            result => {
+                let cancelled = matches!(
+                    result,
+                    Ok(Ok(acp::PromptResponse {
+                        stop_reason: acp::StopReason::Cancelled
+                    }))
+                );
+
+                // We only take the task if the current prompt wasn't cancelled.
+                //
+                // This prompt may have been cancelled because another one was sent
+                // while it was still generating. In these cases, dropping `send_task`
+                // would cause the next generation to be cancelled.
+                if !cancelled {
+                    this.update(cx, |this, _cx| this.send_task.take()).ok();
+                }
+
                 this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                     .log_err();
                 Ok(())