agent2: Emit cancellation stop reason on cancel (#36381)

Created by Ben Brandt and Antonio Scandurra

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/agent2/src/tests/mod.rs |  66 ++++++++++
crates/agent2/src/thread.rs    | 218 ++++++++++++++++++++---------------
2 files changed, 191 insertions(+), 93 deletions(-)

Detailed changes

crates/agent2/src/tests/mod.rs 🔗

@@ -941,7 +941,15 @@ async fn test_cancellation(cx: &mut TestAppContext) {
     // Cancel the current send and ensure that the event stream is closed, even
     // if one of the tools is still running.
     thread.update(cx, |thread, _cx| thread.cancel());
-    events.collect::<Vec<_>>().await;
+    let events = events.collect::<Vec<_>>().await;
+    let last_event = events.last();
+    assert!(
+        matches!(
+            last_event,
+            Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+        ),
+        "unexpected event {last_event:?}"
+    );
 
     // Ensure we can still send a new message after cancellation.
     let events = thread
@@ -965,6 +973,62 @@ async fn test_cancellation(cx: &mut TestAppContext) {
     assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 }
 
+#[gpui::test]
+async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let events_1 = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hello 1"], cx)
+    });
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
+    cx.run_until_parked();
+
+    let events_2 = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hello 2"], cx)
+    });
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+    fake_model.end_last_completion_stream();
+
+    let events_1 = events_1.collect::<Vec<_>>().await;
+    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
+    let events_2 = events_2.collect::<Vec<_>>().await;
+    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
+}
+
+#[gpui::test]
+async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let events_1 = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hello 1"], cx)
+    });
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+    fake_model.end_last_completion_stream();
+    let events_1 = events_1.collect::<Vec<_>>().await;
+
+    let events_2 = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hello 2"], cx)
+    });
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+    fake_model.end_last_completion_stream();
+    let events_2 = events_2.collect::<Vec<_>>().await;
+
+    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
+    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
+}
+
 #[gpui::test]
 async fn test_refusal(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;

crates/agent2/src/thread.rs 🔗

@@ -461,7 +461,7 @@ pub struct Thread {
     /// Holds the task that handles agent interaction until the end of the turn.
     /// Survives across multiple requests as the model performs tool calls and
     /// we run tools, report their results.
-    running_turn: Option<Task<()>>,
+    running_turn: Option<RunningTurn>,
     pending_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     tool_use_limit_reached: bool,
@@ -554,8 +554,9 @@ impl Thread {
     }
 
     pub fn cancel(&mut self) {
-        // TODO: do we need to emit a stop::cancel for ACP?
-        self.running_turn.take();
+        if let Some(running_turn) = self.running_turn.take() {
+            running_turn.cancel();
+        }
         self.flush_pending_message();
     }
 
@@ -616,108 +617,118 @@ impl Thread {
         &mut self,
         cx: &mut Context<Self>,
     ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+        self.cancel();
+
         let model = self.model.clone();
         let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
         let event_stream = AgentResponseEventStream(events_tx);
         let message_ix = self.messages.len().saturating_sub(1);
         self.tool_use_limit_reached = false;
-        self.running_turn = Some(cx.spawn(async move |this, cx| {
-            log::info!("Starting agent turn execution");
-            let turn_result: Result<()> = async {
-                let mut completion_intent = CompletionIntent::UserPrompt;
-                loop {
-                    log::debug!(
-                        "Building completion request with intent: {:?}",
-                        completion_intent
-                    );
-                    let request = this.update(cx, |this, cx| {
-                        this.build_completion_request(completion_intent, cx)
-                    })?;
-
-                    log::info!("Calling model.stream_completion");
-                    let mut events = model.stream_completion(request, cx).await?;
-                    log::debug!("Stream completion started successfully");
-
-                    let mut tool_use_limit_reached = false;
-                    let mut tool_uses = FuturesUnordered::new();
-                    while let Some(event) = events.next().await {
-                        match event? {
-                            LanguageModelCompletionEvent::StatusUpdate(
-                                CompletionRequestStatus::ToolUseLimitReached,
-                            ) => {
-                                tool_use_limit_reached = true;
-                            }
-                            LanguageModelCompletionEvent::Stop(reason) => {
-                                event_stream.send_stop(reason);
-                                if reason == StopReason::Refusal {
-                                    this.update(cx, |this, _cx| {
-                                        this.flush_pending_message();
-                                        this.messages.truncate(message_ix);
-                                    })?;
-                                    return Ok(());
+        self.running_turn = Some(RunningTurn {
+            event_stream: event_stream.clone(),
+            _task: cx.spawn(async move |this, cx| {
+                log::info!("Starting agent turn execution");
+                let turn_result: Result<()> = async {
+                    let mut completion_intent = CompletionIntent::UserPrompt;
+                    loop {
+                        log::debug!(
+                            "Building completion request with intent: {:?}",
+                            completion_intent
+                        );
+                        let request = this.update(cx, |this, cx| {
+                            this.build_completion_request(completion_intent, cx)
+                        })?;
+
+                        log::info!("Calling model.stream_completion");
+                        let mut events = model.stream_completion(request, cx).await?;
+                        log::debug!("Stream completion started successfully");
+
+                        let mut tool_use_limit_reached = false;
+                        let mut tool_uses = FuturesUnordered::new();
+                        while let Some(event) = events.next().await {
+                            match event? {
+                                LanguageModelCompletionEvent::StatusUpdate(
+                                    CompletionRequestStatus::ToolUseLimitReached,
+                                ) => {
+                                    tool_use_limit_reached = true;
+                                }
+                                LanguageModelCompletionEvent::Stop(reason) => {
+                                    event_stream.send_stop(reason);
+                                    if reason == StopReason::Refusal {
+                                        this.update(cx, |this, _cx| {
+                                            this.flush_pending_message();
+                                            this.messages.truncate(message_ix);
+                                        })?;
+                                        return Ok(());
+                                    }
+                                }
+                                event => {
+                                    log::trace!("Received completion event: {:?}", event);
+                                    this.update(cx, |this, cx| {
+                                        tool_uses.extend(this.handle_streamed_completion_event(
+                                            event,
+                                            &event_stream,
+                                            cx,
+                                        ));
+                                    })
+                                    .ok();
                                 }
-                            }
-                            event => {
-                                log::trace!("Received completion event: {:?}", event);
-                                this.update(cx, |this, cx| {
-                                    tool_uses.extend(this.handle_streamed_completion_event(
-                                        event,
-                                        &event_stream,
-                                        cx,
-                                    ));
-                                })
-                                .ok();
                             }
                         }
-                    }
 
-                    let used_tools = tool_uses.is_empty();
-                    while let Some(tool_result) = tool_uses.next().await {
-                        log::info!("Tool finished {:?}", tool_result);
-
-                        event_stream.update_tool_call_fields(
-                            &tool_result.tool_use_id,
-                            acp::ToolCallUpdateFields {
-                                status: Some(if tool_result.is_error {
-                                    acp::ToolCallStatus::Failed
-                                } else {
-                                    acp::ToolCallStatus::Completed
-                                }),
-                                raw_output: tool_result.output.clone(),
-                                ..Default::default()
-                            },
-                        );
-                        this.update(cx, |this, _cx| {
-                            this.pending_message()
-                                .tool_results
-                                .insert(tool_result.tool_use_id.clone(), tool_result);
-                        })
-                        .ok();
-                    }
+                        let used_tools = tool_uses.is_empty();
+                        while let Some(tool_result) = tool_uses.next().await {
+                            log::info!("Tool finished {:?}", tool_result);
+
+                            event_stream.update_tool_call_fields(
+                                &tool_result.tool_use_id,
+                                acp::ToolCallUpdateFields {
+                                    status: Some(if tool_result.is_error {
+                                        acp::ToolCallStatus::Failed
+                                    } else {
+                                        acp::ToolCallStatus::Completed
+                                    }),
+                                    raw_output: tool_result.output.clone(),
+                                    ..Default::default()
+                                },
+                            );
+                            this.update(cx, |this, _cx| {
+                                this.pending_message()
+                                    .tool_results
+                                    .insert(tool_result.tool_use_id.clone(), tool_result);
+                            })
+                            .ok();
+                        }
 
-                    if tool_use_limit_reached {
-                        log::info!("Tool use limit reached, completing turn");
-                        this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
-                        return Err(language_model::ToolUseLimitReachedError.into());
-                    } else if used_tools {
-                        log::info!("No tool uses found, completing turn");
-                        return Ok(());
-                    } else {
-                        this.update(cx, |this, _| this.flush_pending_message())?;
-                        completion_intent = CompletionIntent::ToolResults;
+                        if tool_use_limit_reached {
+                            log::info!("Tool use limit reached, completing turn");
+                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
+                            return Err(language_model::ToolUseLimitReachedError.into());
+                        } else if used_tools {
+                            log::info!("No tool uses found, completing turn");
+                            return Ok(());
+                        } else {
+                            this.update(cx, |this, _| this.flush_pending_message())?;
+                            completion_intent = CompletionIntent::ToolResults;
+                        }
                     }
                 }
-            }
-            .await;
+                .await;
 
-            this.update(cx, |this, _| this.flush_pending_message()).ok();
-            if let Err(error) = turn_result {
-                log::error!("Turn execution failed: {:?}", error);
-                event_stream.send_error(error);
-            } else {
-                log::info!("Turn execution completed successfully");
-            }
-        }));
+                if let Err(error) = turn_result {
+                    log::error!("Turn execution failed: {:?}", error);
+                    event_stream.send_error(error);
+                } else {
+                    log::info!("Turn execution completed successfully");
+                }
+
+                this.update(cx, |this, _| {
+                    this.flush_pending_message();
+                    this.running_turn.take();
+                })
+                .ok();
+            }),
+        });
         events_rx
     }
 
@@ -1125,6 +1136,23 @@ impl Thread {
     }
 }
 
+struct RunningTurn {
+    /// Holds the task that handles agent interaction until the end of the turn.
+    /// Survives across multiple requests as the model performs tool calls and
+    /// we run tools, report their results.
+    _task: Task<()>,
+    /// The current event stream for the running turn. Used to report a final
+    /// cancellation event if we cancel the turn.
+    event_stream: AgentResponseEventStream,
+}
+
+impl RunningTurn {
+    fn cancel(self) {
+        log::debug!("Cancelling in progress turn");
+        self.event_stream.send_canceled();
+    }
+}
+
 pub trait AgentTool
 where
     Self: 'static + Sized,
@@ -1336,6 +1364,12 @@ impl AgentResponseEventStream {
         }
     }
 
+    fn send_canceled(&self) {
+        self.0
+            .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+            .ok();
+    }
+
     fn send_error(&self, error: impl Into<anyhow::Error>) {
         self.0.unbounded_send(Err(error.into())).ok();
     }