@@ -941,7 +941,15 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel());
- events.collect::<Vec<_>>().await;
+ let events = events.collect::<Vec<_>>().await;
+ let last_event = events.last();
+ assert!(
+ matches!(
+ last_event,
+ Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+ ),
+ "unexpected event {last_event:?}"
+ );
// Ensure we can still send a new message after cancellation.
let events = thread
@@ -965,6 +973,62 @@ async fn test_cancellation(cx: &mut TestAppContext) {
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
+#[gpui::test] // A send issued while a previous turn is still streaming must cancel that turn.
+async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+ // Start the first turn and stream one text chunk, but never send it a stop event.
+ let events_1 = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 1"], cx)
+ });
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey 1!");
+ cx.run_until_parked();
+ // Send again while the first turn is still open.
+ let events_2 = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 2"], cx)
+ });
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey 2!"); // presumably hits turn 2's stream — confirm
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ // The interrupted turn reports only `Canceled`; the second turn reports only `EndTurn`.
+ let events_1 = events_1.collect::<Vec<_>>().await;
+ assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
+ let events_2 = events_2.collect::<Vec<_>>().await;
+ assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
+}
+
+#[gpui::test] // A turn that already finished must not report `Canceled` when a later send starts.
+async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+ // Run the first turn to its `EndTurn` stop and drain its event stream.
+ let events_1 = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 1"], cx)
+ });
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey 1!");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ let events_1 = events_1.collect::<Vec<_>>().await;
+ // Only now start the second turn, and complete it the same way.
+ let events_2 = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello 2"], cx)
+ });
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey 2!");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+ let events_2 = events_2.collect::<Vec<_>>().await;
+ // Both turns report `EndTurn` as their only stop reason.
+ assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
+ assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
+}
+
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
@@ -461,7 +461,7 @@ pub struct Thread {
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
- running_turn: Option<Task<()>>,
+ running_turn: Option<RunningTurn>,
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
@@ -554,8 +554,9 @@ impl Thread {
}
pub fn cancel(&mut self) {
- // TODO: do we need to emit a stop::cancel for ACP?
- self.running_turn.take();
+ if let Some(running_turn) = self.running_turn.take() { // nothing to do when no turn is running
+ running_turn.cancel(); // reports `Canceled` on that turn's event stream
+ }
self.flush_pending_message();
}
@@ -616,108 +617,118 @@ impl Thread {
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+ self.cancel();
+
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
- self.running_turn = Some(cx.spawn(async move |this, cx| {
- log::info!("Starting agent turn execution");
- let turn_result: Result<()> = async {
- let mut completion_intent = CompletionIntent::UserPrompt;
- loop {
- log::debug!(
- "Building completion request with intent: {:?}",
- completion_intent
- );
- let request = this.update(cx, |this, cx| {
- this.build_completion_request(completion_intent, cx)
- })?;
-
- log::info!("Calling model.stream_completion");
- let mut events = model.stream_completion(request, cx).await?;
- log::debug!("Stream completion started successfully");
-
- let mut tool_use_limit_reached = false;
- let mut tool_uses = FuturesUnordered::new();
- while let Some(event) = events.next().await {
- match event? {
- LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::ToolUseLimitReached,
- ) => {
- tool_use_limit_reached = true;
- }
- LanguageModelCompletionEvent::Stop(reason) => {
- event_stream.send_stop(reason);
- if reason == StopReason::Refusal {
- this.update(cx, |this, _cx| {
- this.flush_pending_message();
- this.messages.truncate(message_ix);
- })?;
- return Ok(());
+ self.running_turn = Some(RunningTurn {
+ event_stream: event_stream.clone(),
+ _task: cx.spawn(async move |this, cx| {
+ log::info!("Starting agent turn execution");
+ let turn_result: Result<()> = async {
+ let mut completion_intent = CompletionIntent::UserPrompt;
+ loop {
+ log::debug!(
+ "Building completion request with intent: {:?}",
+ completion_intent
+ );
+ let request = this.update(cx, |this, cx| {
+ this.build_completion_request(completion_intent, cx)
+ })?;
+
+ log::info!("Calling model.stream_completion");
+ let mut events = model.stream_completion(request, cx).await?;
+ log::debug!("Stream completion started successfully");
+
+ let mut tool_use_limit_reached = false;
+ let mut tool_uses = FuturesUnordered::new();
+ while let Some(event) = events.next().await {
+ match event? {
+ LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::ToolUseLimitReached,
+ ) => {
+ tool_use_limit_reached = true;
+ }
+ LanguageModelCompletionEvent::Stop(reason) => {
+ event_stream.send_stop(reason);
+ if reason == StopReason::Refusal {
+ this.update(cx, |this, _cx| {
+ this.flush_pending_message();
+ this.messages.truncate(message_ix);
+ })?;
+ return Ok(());
+ }
+ }
+ event => {
+ log::trace!("Received completion event: {:?}", event);
+ this.update(cx, |this, cx| {
+ tool_uses.extend(this.handle_streamed_completion_event(
+ event,
+ &event_stream,
+ cx,
+ ));
+ })
+ .ok();
}
- }
- event => {
- log::trace!("Received completion event: {:?}", event);
- this.update(cx, |this, cx| {
- tool_uses.extend(this.handle_streamed_completion_event(
- event,
- &event_stream,
- cx,
- ));
- })
- .ok();
}
}
- }
- let used_tools = tool_uses.is_empty();
- while let Some(tool_result) = tool_uses.next().await {
- log::info!("Tool finished {:?}", tool_result);
-
- event_stream.update_tool_call_fields(
- &tool_result.tool_use_id,
- acp::ToolCallUpdateFields {
- status: Some(if tool_result.is_error {
- acp::ToolCallStatus::Failed
- } else {
- acp::ToolCallStatus::Completed
- }),
- raw_output: tool_result.output.clone(),
- ..Default::default()
- },
- );
- this.update(cx, |this, _cx| {
- this.pending_message()
- .tool_results
- .insert(tool_result.tool_use_id.clone(), tool_result);
- })
- .ok();
- }
+ let used_tools = tool_uses.is_empty();
+ while let Some(tool_result) = tool_uses.next().await {
+ log::info!("Tool finished {:?}", tool_result);
+
+ event_stream.update_tool_call_fields(
+ &tool_result.tool_use_id,
+ acp::ToolCallUpdateFields {
+ status: Some(if tool_result.is_error {
+ acp::ToolCallStatus::Failed
+ } else {
+ acp::ToolCallStatus::Completed
+ }),
+ raw_output: tool_result.output.clone(),
+ ..Default::default()
+ },
+ );
+ this.update(cx, |this, _cx| {
+ this.pending_message()
+ .tool_results
+ .insert(tool_result.tool_use_id.clone(), tool_result);
+ })
+ .ok();
+ }
- if tool_use_limit_reached {
- log::info!("Tool use limit reached, completing turn");
- this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
- return Err(language_model::ToolUseLimitReachedError.into());
- } else if used_tools {
- log::info!("No tool uses found, completing turn");
- return Ok(());
- } else {
- this.update(cx, |this, _| this.flush_pending_message())?;
- completion_intent = CompletionIntent::ToolResults;
+ if tool_use_limit_reached {
+ log::info!("Tool use limit reached, completing turn");
+ this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
+ return Err(language_model::ToolUseLimitReachedError.into());
+ } else if used_tools {
+ log::info!("No tool uses found, completing turn");
+ return Ok(());
+ } else {
+ this.update(cx, |this, _| this.flush_pending_message())?;
+ completion_intent = CompletionIntent::ToolResults;
+ }
}
}
- }
- .await;
+ .await;
- this.update(cx, |this, _| this.flush_pending_message()).ok();
- if let Err(error) = turn_result {
- log::error!("Turn execution failed: {:?}", error);
- event_stream.send_error(error);
- } else {
- log::info!("Turn execution completed successfully");
- }
- }));
+ if let Err(error) = turn_result {
+ log::error!("Turn execution failed: {:?}", error);
+ event_stream.send_error(error);
+ } else {
+ log::info!("Turn execution completed successfully");
+ }
+
+ this.update(cx, |this, _| {
+ this.flush_pending_message();
+ this.running_turn.take();
+ })
+ .ok();
+ }),
+ });
events_rx
}
@@ -1125,6 +1136,23 @@ impl Thread {
}
}
+struct RunningTurn { // Per-turn state kept by `Thread` so an in-flight turn can be cancelled.
+ /// Holds the task that handles agent interaction until the end of the turn.
+ /// Survives across multiple requests as the model performs tool calls and
+ /// we run tools and report their results.
+ _task: Task<()>, // never read here; presumably held only to keep the task alive — confirm
+ /// A clone of the event stream for the running turn. Used to report a final
+ /// cancellation event if we cancel the turn.
+ event_stream: AgentResponseEventStream,
+}
+
+impl RunningTurn {
+ fn cancel(self) { // takes `self` by value, so `_task` is dropped when this returns
+ log::debug!("Cancelling in progress turn");
+ self.event_stream.send_canceled(); // emits `Stop(Canceled)` on the turn's event stream
+ }
+}
+
pub trait AgentTool
where
Self: 'static + Sized,
@@ -1336,6 +1364,12 @@ impl AgentResponseEventStream {
}
}
+ fn send_canceled(&self) { // Pushes a `Stop(Canceled)` event onto the stream.
+ self.0
+ .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+ .ok(); // a failed send is ignored
+ }
+
fn send_error(&self, error: impl Into<anyhow::Error>) {
self.0.unbounded_send(Err(error.into())).ok();
}