@@ -61,6 +61,8 @@ fn init_test(cx: &mut TestAppContext) {
struct FakeTerminalHandle {
killed: Arc<AtomicBool>,
+ stopped_by_user: Arc<AtomicBool>,
+ exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
output: acp::TerminalOutputResponse,
id: acp::TerminalId,
@@ -69,23 +71,22 @@ struct FakeTerminalHandle {
impl FakeTerminalHandle {
fn new_never_exits(cx: &mut App) -> Self {
let killed = Arc::new(AtomicBool::new(false));
+ let stopped_by_user = Arc::new(AtomicBool::new(false));
+
+ let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
- let killed_for_task = killed.clone();
let wait_for_exit = cx
- .spawn(async move |cx| {
- loop {
- if killed_for_task.load(Ordering::SeqCst) {
- return acp::TerminalExitStatus::new();
- }
- cx.background_executor()
- .timer(Duration::from_millis(1))
- .await;
- }
+ .spawn(async move |_cx| {
+ // Wait for the exit signal (sent when kill() is called)
+ let _ = exit_receiver.await;
+ acp::TerminalExitStatus::new()
})
.shared();
Self {
killed,
+ stopped_by_user,
+ exit_sender: std::cell::RefCell::new(Some(exit_sender)),
wait_for_exit,
output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
id: acp::TerminalId::new("fake_terminal".to_string()),
@@ -95,6 +96,16 @@ impl FakeTerminalHandle {
fn was_killed(&self) -> bool {
self.killed.load(Ordering::SeqCst)
}
+
+ fn set_stopped_by_user(&self, stopped: bool) {
+ self.stopped_by_user.store(stopped, Ordering::SeqCst);
+ }
+
+ fn signal_exit(&self) {
+ if let Some(sender) = self.exit_sender.borrow_mut().take() {
+ let _ = sender.send(());
+ }
+ }
}
impl crate::TerminalHandle for FakeTerminalHandle {
@@ -112,11 +123,12 @@ impl crate::TerminalHandle for FakeTerminalHandle {
fn kill(&self, _cx: &AsyncApp) -> Result<()> {
self.killed.store(true, Ordering::SeqCst);
+ self.signal_exit();
Ok(())
}
fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
- Ok(false)
+ Ok(self.stopped_by_user.load(Ordering::SeqCst))
}
}
@@ -136,6 +148,37 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment {
}
}
+/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
+struct MultiTerminalEnvironment {
+ handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
+}
+
+impl MultiTerminalEnvironment {
+ fn new() -> Self {
+ Self {
+ handles: std::cell::RefCell::new(Vec::new()),
+ }
+ }
+
+ fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
+ self.handles.borrow().clone()
+ }
+}
+
+impl crate::ThreadEnvironment for MultiTerminalEnvironment {
+ fn create_terminal(
+ &self,
+ _command: String,
+ _cwd: Option<std::path::PathBuf>,
+ _output_byte_limit: Option<u64>,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ self.handles.borrow_mut().push(handle.clone());
+ Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
+ }
+}
+
fn always_allow_tools(cx: &mut TestAppContext) {
cx.update(|cx| {
let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
@@ -1596,7 +1639,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
- thread.update(cx, |thread, cx| thread.cancel(cx));
+ thread.update(cx, |thread, cx| thread.cancel(cx)).await;
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
@@ -1630,6 +1673,563 @@ async fn test_cancellation(cx: &mut TestAppContext) {
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
+#[gpui::test]
+async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment {
+ handle: handle.clone(),
+ });
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(crate::TerminalTool::new(
+ thread.project().clone(),
+ environment,
+ ));
+ thread.send(UserMessageId::new(), ["run a command"], cx)
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling the terminal tool
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_1".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
+ input: json!({"command": "sleep 1000", "cd": "."}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Wait for the terminal tool to start running
+ wait_for_terminal_tool_started(&mut events, cx).await;
+
+ // Cancel the thread while the terminal is running
+ thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
+
+ // Collect remaining events, driving the executor to let cancellation complete
+ let remaining_events = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify the terminal was killed
+ assert!(
+ handle.was_killed(),
+ "expected terminal handle to be killed on cancellation"
+ );
+
+ // Verify we got a cancellation stop event
+ assert_eq!(
+ stop_events(remaining_events),
+ vec![acp::StopReason::Cancelled],
+ );
+
+ // Verify the tool result contains the terminal output, not just "Tool canceled by user"
+ thread.update(cx, |thread, _cx| {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
+
+ let tool_use = agent_message
+ .content
+ .iter()
+ .find_map(|content| match content {
+ AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
+ _ => None,
+ })
+ .expect("expected tool use in agent message");
+
+ let tool_result = agent_message
+ .tool_results
+ .get(&tool_use.id)
+ .expect("expected tool result");
+
+ let result_text = match &tool_result.content {
+ language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+ _ => panic!("expected text content in tool result"),
+ };
+
+ // "partial output" comes from FakeTerminalHandle's output field
+ assert!(
+ result_text.contains("partial output"),
+ "expected tool result to contain terminal output, got: {result_text}"
+ );
+ // Match the actual format from process_content in terminal_tool.rs
+ assert!(
+ result_text.contains("The user stopped this command"),
+ "expected tool result to indicate user stopped, got: {result_text}"
+ );
+ });
+
+ // Verify we can send a new message after cancellation
+ verify_thread_recovery(&thread, &fake_model, cx).await;
+}
+
+/// Helper to verify thread can recover after cancellation by sending a simple message.
+async fn verify_thread_recovery(
+ thread: &Entity<Thread>,
+ fake_model: &FakeLanguageModel,
+ cx: &mut TestAppContext,
+) {
+ let events = thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ UserMessageId::new(),
+ ["Testing: reply with 'Hello' then stop."],
+ cx,
+ )
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hello");
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+
+ let events = events.collect::<Vec<_>>().await;
+ thread.update(cx, |thread, _cx| {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
+ assert_eq!(
+ agent_message.content,
+ vec![AgentMessageContent::Text("Hello".to_string())]
+ );
+ });
+ assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
+}
+
+/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
+async fn wait_for_terminal_tool_started(
+ events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+ cx: &mut TestAppContext,
+) {
+ let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
+ for _ in 0..deadline {
+ cx.run_until_parked();
+
+ while let Some(Some(event)) = events.next().now_or_never() {
+ if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+ update,
+ ))) = &event
+ {
+ if update.fields.content.as_ref().is_some_and(|content| {
+ content
+ .iter()
+ .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
+ }) {
+ return;
+ }
+ }
+ }
+
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+ panic!("terminal tool did not start within the expected time");
+}
+
+/// Collects events until a Stop event is received, driving the executor to completion.
+async fn collect_events_until_stop(
+ events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+ cx: &mut TestAppContext,
+) -> Vec<Result<ThreadEvent>> {
+ let mut collected = Vec::new();
+ let deadline = cx.executor().num_cpus() * 200;
+
+ for _ in 0..deadline {
+ cx.executor().advance_clock(Duration::from_millis(10));
+ cx.run_until_parked();
+
+ while let Some(Some(event)) = events.next().now_or_never() {
+ let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
+ collected.push(event);
+ if is_stop {
+ return collected;
+ }
+ }
+ }
+ panic!(
+ "did not receive Stop event within the expected time; collected {} events",
+ collected.len()
+ );
+}
+
+#[gpui::test]
+async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment {
+ handle: handle.clone(),
+ });
+
+ let message_id = UserMessageId::new();
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(crate::TerminalTool::new(
+ thread.project().clone(),
+ environment,
+ ));
+ thread.send(message_id.clone(), ["run a command"], cx)
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling the terminal tool
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_1".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
+ input: json!({"command": "sleep 1000", "cd": "."}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Wait for the terminal tool to start running
+ wait_for_terminal_tool_started(&mut events, cx).await;
+
+ // Truncate the thread while the terminal is running
+ thread
+ .update(cx, |thread, cx| thread.truncate(message_id, cx))
+ .unwrap();
+
+ // Drive the executor to let cancellation complete
+ let _ = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify the terminal was killed
+ assert!(
+ handle.was_killed(),
+ "expected terminal handle to be killed on truncate"
+ );
+
+ // Verify the thread is empty after truncation
+ thread.update(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ "",
+ "expected thread to be empty after truncating the only message"
+ );
+ });
+
+ // Verify we can send a new message after truncation
+ verify_thread_recovery(&thread, &fake_model, cx).await;
+}
+
+#[gpui::test]
+async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
+ // Tests that cancellation properly kills all running terminal tools when multiple are active.
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let environment = Rc::new(MultiTerminalEnvironment::new());
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(crate::TerminalTool::new(
+ thread.project().clone(),
+ environment.clone(),
+ ));
+ thread.send(UserMessageId::new(), ["run multiple commands"], cx)
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling two terminal tools
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_1".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
+ input: json!({"command": "sleep 1000", "cd": "."}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_2".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
+ input: json!({"command": "sleep 2000", "cd": "."}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Wait for both terminal tools to start by counting terminal content updates
+ let mut terminals_started = 0;
+ let deadline = cx.executor().num_cpus() * 100;
+ for _ in 0..deadline {
+ cx.run_until_parked();
+
+ while let Some(Some(event)) = events.next().now_or_never() {
+ if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+ update,
+ ))) = &event
+ {
+ if update.fields.content.as_ref().is_some_and(|content| {
+ content
+ .iter()
+ .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
+ }) {
+ terminals_started += 1;
+ if terminals_started >= 2 {
+ break;
+ }
+ }
+ }
+ }
+ if terminals_started >= 2 {
+ break;
+ }
+
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+ assert!(
+ terminals_started >= 2,
+ "expected 2 terminal tools to start, got {terminals_started}"
+ );
+
+ // Cancel the thread while both terminals are running
+ thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
+
+ // Collect remaining events
+ let remaining_events = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify both terminal handles were killed
+ let handles = environment.handles();
+ assert_eq!(
+ handles.len(),
+ 2,
+ "expected 2 terminal handles to be created"
+ );
+ assert!(
+ handles[0].was_killed(),
+ "expected first terminal handle to be killed on cancellation"
+ );
+ assert!(
+ handles[1].was_killed(),
+ "expected second terminal handle to be killed on cancellation"
+ );
+
+ // Verify we got a cancellation stop event
+ assert_eq!(
+ stop_events(remaining_events),
+ vec![acp::StopReason::Cancelled],
+ );
+}
+
+#[gpui::test]
+async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
+ // Tests that clicking the stop button on the terminal card (as opposed to the main
+ // cancel button) properly reports user stopped via the was_stopped_by_user path.
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment {
+ handle: handle.clone(),
+ });
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(crate::TerminalTool::new(
+ thread.project().clone(),
+ environment,
+ ));
+ thread.send(UserMessageId::new(), ["run a command"], cx)
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling the terminal tool
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_1".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
+ input: json!({"command": "sleep 1000", "cd": "."}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Wait for the terminal tool to start running
+ wait_for_terminal_tool_started(&mut events, cx).await;
+
+ // Simulate user clicking stop on the terminal card itself.
+ // This sets the flag and signals exit (simulating what the real UI would do).
+ handle.set_stopped_by_user(true);
+ handle.killed.store(true, Ordering::SeqCst);
+ handle.signal_exit();
+
+ // Wait for the tool to complete
+ cx.run_until_parked();
+
+ // The thread continues after tool completion - simulate the model ending its turn
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+
+ // Collect remaining events
+ let remaining_events = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
+ assert_eq!(
+ stop_events(remaining_events),
+ vec![acp::StopReason::EndTurn],
+ );
+
+ // Verify the tool result indicates user stopped
+ thread.update(cx, |thread, _cx| {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
+
+ let tool_use = agent_message
+ .content
+ .iter()
+ .find_map(|content| match content {
+ AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
+ _ => None,
+ })
+ .expect("expected tool use in agent message");
+
+ let tool_result = agent_message
+ .tool_results
+ .get(&tool_use.id)
+ .expect("expected tool result");
+
+ let result_text = match &tool_result.content {
+ language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+ _ => panic!("expected text content in tool result"),
+ };
+
+ assert!(
+ result_text.contains("The user stopped this command"),
+ "expected tool result to indicate user stopped, got: {result_text}"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
+ // Tests that when a timeout is configured and expires, the tool result indicates timeout.
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
+ let environment = Rc::new(FakeThreadEnvironment {
+ handle: handle.clone(),
+ });
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(crate::TerminalTool::new(
+ thread.project().clone(),
+ environment,
+ ));
+ thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling the terminal tool with a short timeout
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "terminal_tool_1".into(),
+ name: "terminal".into(),
+ raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
+ input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Wait for the terminal tool to start running
+ wait_for_terminal_tool_started(&mut events, cx).await;
+
+ // Advance clock past the timeout
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.run_until_parked();
+
+ // The thread continues after tool completion - simulate the model ending its turn
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+ fake_model.end_last_completion_stream();
+
+ // Collect remaining events
+ let remaining_events = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify the terminal was killed due to timeout
+ assert!(
+ handle.was_killed(),
+ "expected terminal handle to be killed on timeout"
+ );
+
+ // Verify we got an EndTurn (the tool completed, just with timeout)
+ assert_eq!(
+ stop_events(remaining_events),
+ vec![acp::StopReason::EndTurn],
+ );
+
+ // Verify the tool result indicates timeout, not user stopped
+ thread.update(cx, |thread, _cx| {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
+
+ let tool_use = agent_message
+ .content
+ .iter()
+ .find_map(|content| match content {
+ AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
+ _ => None,
+ })
+ .expect("expected tool use in agent message");
+
+ let tool_result = agent_message
+ .tool_results
+ .get(&tool_use.id)
+ .expect("expected tool result");
+
+ let result_text = match &tool_result.content {
+ language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+ _ => panic!("expected text content in tool result"),
+ };
+
+ assert!(
+ result_text.contains("timed out"),
+ "expected tool result to indicate timeout, got: {result_text}"
+ );
+ assert!(
+ !result_text.contains("The user stopped"),
+ "tool result should not mention user stopped when it timed out, got: {result_text}"
+ );
+ });
+}
+
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
@@ -2616,6 +3216,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
ToolRequiringPermission::name(): true,
InfiniteTool::name(): true,
ThinkingTool::name(): true,
+ "terminal": true,
}
}
}
@@ -1055,11 +1055,21 @@ impl Thread {
}
}
- pub fn cancel(&mut self, cx: &mut Context<Self>) {
- if let Some(running_turn) = self.running_turn.take() {
- running_turn.cancel();
- }
- self.flush_pending_message(cx);
+ pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
+ let Some(running_turn) = self.running_turn.take() else {
+ self.flush_pending_message(cx);
+ return Task::ready(());
+ };
+
+ let turn_task = running_turn.cancel();
+
+ cx.spawn(async move |this, cx| {
+ turn_task.await;
+ this.update(cx, |this, cx| {
+ this.flush_pending_message(cx);
+ })
+ .ok();
+ })
}
fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
@@ -1074,7 +1084,10 @@ impl Thread {
}
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
- self.cancel(cx);
+ self.cancel(cx).detach();
+ // Clear pending message since cancel will try to flush it asynchronously,
+ // and we don't want that content to be added after we truncate
+ self.pending_message.take();
let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else {
@@ -1255,7 +1268,11 @@ impl Thread {
&mut self,
cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
- self.cancel(cx);
+ // Flush the old pending message synchronously before cancelling,
+ // to avoid a race where the detached cancel task might flush the NEW
+ // turn's pending message instead of the old one.
+ self.flush_pending_message(cx);
+ self.cancel(cx).detach();
let model = self.model.clone().context("No language model configured")?;
let profile = AgentSettings::get_global(cx)
@@ -1267,7 +1284,7 @@ impl Thread {
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.clear_summary();
- let (cancellation_tx, cancellation_rx) = watch::channel(false);
+ let (cancellation_tx, mut cancellation_rx) = watch::channel(false);
self.running_turn = Some(RunningTurn {
event_stream: event_stream.clone(),
tools: self.enabled_tools(profile, &model, cx),
@@ -1275,8 +1292,23 @@ impl Thread {
_task: cx.spawn(async move |this, cx| {
log::debug!("Starting agent turn execution");
- let turn_result =
- Self::run_turn_internal(&this, model, &event_stream, cancellation_rx, cx).await;
+ let turn_result = Self::run_turn_internal(
+ &this,
+ model,
+ &event_stream,
+ cancellation_rx.clone(),
+ cx,
+ )
+ .await;
+
+ // Check if we were cancelled - if so, cancel() already took running_turn
+ // and we shouldn't touch it (it might be a NEW turn now)
+ let was_cancelled = *cancellation_rx.borrow();
+ if was_cancelled {
+ log::debug!("Turn was cancelled, skipping cleanup");
+ return;
+ }
+
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
match turn_result {
@@ -1311,7 +1343,7 @@ impl Thread {
this: &WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
event_stream: &ThreadEventStream,
- cancellation_rx: watch::Receiver<bool>,
+ mut cancellation_rx: watch::Receiver<bool>,
cx: &mut AsyncApp,
) -> Result<()> {
let mut attempt = 0;
@@ -1336,7 +1368,22 @@ impl Thread {
Err(err) => (stream::empty().boxed(), Some(err)),
};
let mut tool_results = FuturesUnordered::new();
- while let Some(event) = events.next().await {
+ let mut cancelled = false;
+ loop {
+ // Race between getting the next event and cancellation
+ let event = futures::select! {
+ event = events.next().fuse() => event,
+ _ = cancellation_rx.changed().fuse() => {
+ if *cancellation_rx.borrow() {
+ cancelled = true;
+ break;
+ }
+ continue;
+ }
+ };
+ let Some(event) = event else {
+ break;
+ };
log::trace!("Received completion event: {:?}", event);
match event {
Ok(event) => {
@@ -1384,6 +1431,11 @@ impl Thread {
}
})?;
+ if cancelled {
+ log::debug!("Turn cancelled by user, exiting");
+ return Ok(());
+ }
+
if let Some(error) = error {
attempt += 1;
let retry = this.update(cx, |this, cx| {
@@ -2264,10 +2316,11 @@ struct RunningTurn {
}
impl RunningTurn {
- fn cancel(mut self) {
+ fn cancel(mut self) -> Task<()> {
log::debug!("Cancelling in progress turn");
self.cancellation_tx.send(true).ok();
self.event_stream.send_canceled();
+ self._task
}
}