diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 90aecf47fad2d8c3562eff63e6dbf852c306a61c..2975ffa0d26dbaa956c1668aa4087a753491729d 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1309,7 +1309,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, cx| thread.cancel(cx)); + agent + .thread + .update(cx, |thread, cx| thread.cancel(cx)) + .detach(); } }); } diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 4a12b47a0e1df015be29fd8ca97e9dd6ba3bff15..908c5edf47e6f4e7873e0f37d0a70c69d4f1717b 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -61,6 +61,8 @@ fn init_test(cx: &mut TestAppContext) { struct FakeTerminalHandle { killed: Arc, + stopped_by_user: Arc, + exit_sender: std::cell::RefCell>>, wait_for_exit: Shared>, output: acp::TerminalOutputResponse, id: acp::TerminalId, @@ -69,23 +71,22 @@ struct FakeTerminalHandle { impl FakeTerminalHandle { fn new_never_exits(cx: &mut App) -> Self { let killed = Arc::new(AtomicBool::new(false)); + let stopped_by_user = Arc::new(AtomicBool::new(false)); + + let (exit_sender, exit_receiver) = futures::channel::oneshot::channel(); - let killed_for_task = killed.clone(); let wait_for_exit = cx - .spawn(async move |cx| { - loop { - if killed_for_task.load(Ordering::SeqCst) { - return acp::TerminalExitStatus::new(); - } - cx.background_executor() - .timer(Duration::from_millis(1)) - .await; - } + .spawn(async move |_cx| { + // Wait for the exit signal (sent when kill() is called) + let _ = exit_receiver.await; + acp::TerminalExitStatus::new() }) .shared(); Self { killed, + stopped_by_user, + exit_sender: std::cell::RefCell::new(Some(exit_sender)), wait_for_exit, output: acp::TerminalOutputResponse::new("partial output".to_string(), false), id: acp::TerminalId::new("fake_terminal".to_string()), @@ -95,6 +96,16 @@ impl FakeTerminalHandle { fn was_killed(&self) -> bool { self.killed.load(Ordering::SeqCst) } + + fn set_stopped_by_user(&self, stopped: bool) { + self.stopped_by_user.store(stopped, Ordering::SeqCst); + } + + fn signal_exit(&self) { + if let Some(sender) = self.exit_sender.borrow_mut().take() { + let _ = sender.send(()); + } + } } impl crate::TerminalHandle for FakeTerminalHandle { @@ -112,11 +123,12 @@ impl crate::TerminalHandle for FakeTerminalHandle { fn kill(&self, _cx: &AsyncApp) -> Result<()> { self.killed.store(true, Ordering::SeqCst); + self.signal_exit(); Ok(()) } fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result { - Ok(false) + Ok(self.stopped_by_user.load(Ordering::SeqCst)) } } @@ -136,6 +148,37 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment { } } +/// Environment that creates multiple independent terminal handles for testing concurrent terminals. +struct MultiTerminalEnvironment { + handles: std::cell::RefCell>>, +} + +impl MultiTerminalEnvironment { + fn new() -> Self { + Self { + handles: std::cell::RefCell::new(Vec::new()), + } + } + + fn handles(&self) -> Vec> { + self.handles.borrow().clone() + } +} + +impl crate::ThreadEnvironment for MultiTerminalEnvironment { + fn create_terminal( + &self, + _command: String, + _cwd: Option, + _output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>> { + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + self.handles.borrow_mut().push(handle.clone()); + Task::ready(Ok(handle as Rc)) + } +} + fn always_allow_tools(cx: &mut TestAppContext) { cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -1596,7 +1639,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, cx| thread.cancel(cx)); + thread.update(cx, |thread, cx| thread.cancel(cx)).await; let events = events.collect::>().await; let last_event = events.last(); assert!( @@ -1630,6 +1673,563 @@ async fn test_cancellation(cx: &mut TestAppContext) { assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } +#[gpui::test] +async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { + handle: handle.clone(), + }); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); + thread.send(UserMessageId::new(), ["run a command"], cx) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the terminal tool + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_1".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(), + input: json!({"command": "sleep 1000", "cd": "."}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + // Wait for the terminal tool to start running + wait_for_terminal_tool_started(&mut events, cx).await; + + // Cancel the thread while the terminal is running + thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + + // Collect remaining events, driving the executor to let cancellation complete + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify the terminal was killed + assert!( + handle.was_killed(), + "expected terminal handle to be killed on cancellation" + ); + + // Verify we got a cancellation stop event + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::Cancelled], + ); + + // Verify the tool result contains the terminal output, not just "Tool canceled by user" + thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + + let tool_use = agent_message + .content + .iter() + .find_map(|content| match content { + AgentMessageContent::ToolUse(tool_use) => Some(tool_use), + _ => None, + }) + .expect("expected tool use in agent message"); + + let tool_result = agent_message + .tool_results + .get(&tool_use.id) + .expect("expected tool result"); + + let result_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + _ => panic!("expected text content in tool result"), + }; + + // "partial output" comes from FakeTerminalHandle's output field + assert!( + result_text.contains("partial output"), + "expected tool result to contain terminal output, got: {result_text}" + ); + // Match the actual format from process_content in terminal_tool.rs + assert!( + result_text.contains("The user stopped this command"), + "expected tool result to indicate user stopped, got: {result_text}" + ); + }); + + // Verify we can send a new message after cancellation + verify_thread_recovery(&thread, &fake_model, cx).await; +} + +/// Helper to verify thread can recover after cancellation by sending a simple message. +async fn verify_thread_recovery( + thread: &Entity, + fake_model: &FakeLanguageModel, + cx: &mut TestAppContext, +) { + let events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Testing: reply with 'Hello' then stop."], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hello"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + let events = events.collect::>().await; + thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + assert_eq!( + agent_message.content, + vec![AgentMessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); +} + +/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content. +async fn wait_for_terminal_tool_started( + events: &mut mpsc::UnboundedReceiver>, + cx: &mut TestAppContext, +) { + let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism + for _ in 0..deadline { + cx.run_until_parked(); + + while let Some(Some(event)) = events.next().now_or_never() { + if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + update, + ))) = &event + { + if update.fields.content.as_ref().is_some_and(|content| { + content + .iter() + .any(|c| matches!(c, acp::ToolCallContent::Terminal(_))) + }) { + return; + } + } + } + + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + panic!("terminal tool did not start within the expected time"); +} + +/// Collects events until a Stop event is received, driving the executor to completion. +async fn collect_events_until_stop( + events: &mut mpsc::UnboundedReceiver>, + cx: &mut TestAppContext, +) -> Vec> { + let mut collected = Vec::new(); + let deadline = cx.executor().num_cpus() * 200; + + for _ in 0..deadline { + cx.executor().advance_clock(Duration::from_millis(10)); + cx.run_until_parked(); + + while let Some(Some(event)) = events.next().now_or_never() { + let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_))); + collected.push(event); + if is_stop { + return collected; + } + } + } + panic!( + "did not receive Stop event within the expected time; collected {} events", + collected.len() + ); +} + +#[gpui::test] +async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { + handle: handle.clone(), + }); + + let message_id = UserMessageId::new(); + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); + thread.send(message_id.clone(), ["run a command"], cx) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the terminal tool + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_1".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(), + input: json!({"command": "sleep 1000", "cd": "."}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + // Wait for the terminal tool to start running + wait_for_terminal_tool_started(&mut events, cx).await; + + // Truncate the thread while the terminal is running + thread + .update(cx, |thread, cx| thread.truncate(message_id, cx)) + .unwrap(); + + // Drive the executor to let cancellation complete + let _ = collect_events_until_stop(&mut events, cx).await; + + // Verify the terminal was killed + assert!( + handle.was_killed(), + "expected terminal handle to be killed on truncate" + ); + + // Verify the thread is empty after truncation + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + "", + "expected thread to be empty after truncating the only message" + ); + }); + + // Verify we can send a new message after truncation + verify_thread_recovery(&thread, &fake_model, cx).await; +} + +#[gpui::test] +async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) { + // Tests that cancellation properly kills all running terminal tools when multiple are active. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let environment = Rc::new(MultiTerminalEnvironment::new()); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment.clone(), + )); + thread.send(UserMessageId::new(), ["run multiple commands"], cx) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling two terminal tools + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_1".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(), + input: json!({"command": "sleep 1000", "cd": "."}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_2".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(), + input: json!({"command": "sleep 2000", "cd": "."}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + // Wait for both terminal tools to start by counting terminal content updates + let mut terminals_started = 0; + let deadline = cx.executor().num_cpus() * 100; + for _ in 0..deadline { + cx.run_until_parked(); + + while let Some(Some(event)) = events.next().now_or_never() { + if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + update, + ))) = &event + { + if update.fields.content.as_ref().is_some_and(|content| { + content + .iter() + .any(|c| matches!(c, acp::ToolCallContent::Terminal(_))) + }) { + terminals_started += 1; + if terminals_started >= 2 { + break; + } + } + } + } + if terminals_started >= 2 { + break; + } + + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + assert!( + terminals_started >= 2, + "expected 2 terminal tools to start, got {terminals_started}" + ); + + // Cancel the thread while both terminals are running + thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + + // Collect remaining events + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify both terminal handles were killed + let handles = environment.handles(); + assert_eq!( + handles.len(), + 2, + "expected 2 terminal handles to be created" + ); + assert!( + handles[0].was_killed(), + "expected first terminal handle to be killed on cancellation" + ); + assert!( + handles[1].was_killed(), + "expected second terminal handle to be killed on cancellation" + ); + + // Verify we got a cancellation stop event + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::Cancelled], + ); +} + +#[gpui::test] +async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) { + // Tests that clicking the stop button on the terminal card (as opposed to the main + // cancel button) properly reports user stopped via the was_stopped_by_user path. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { + handle: handle.clone(), + }); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); + thread.send(UserMessageId::new(), ["run a command"], cx) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the terminal tool + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_1".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(), + input: json!({"command": "sleep 1000", "cd": "."}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + // Wait for the terminal tool to start running + wait_for_terminal_tool_started(&mut events, cx).await; + + // Simulate user clicking stop on the terminal card itself. + // This sets the flag and signals exit (simulating what the real UI would do). + handle.set_stopped_by_user(true); + handle.killed.store(true, Ordering::SeqCst); + handle.signal_exit(); + + // Wait for the tool to complete + cx.run_until_parked(); + + // The thread continues after tool completion - simulate the model ending its turn + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + // Collect remaining events + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread) + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::EndTurn], + ); + + // Verify the tool result indicates user stopped + thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + + let tool_use = agent_message + .content + .iter() + .find_map(|content| match content { + AgentMessageContent::ToolUse(tool_use) => Some(tool_use), + _ => None, + }) + .expect("expected tool use in agent message"); + + let tool_result = agent_message + .tool_results + .get(&tool_use.id) + .expect("expected tool result"); + + let result_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + _ => panic!("expected text content in tool result"), + }; + + assert!( + result_text.contains("The user stopped this command"), + "expected tool result to indicate user stopped, got: {result_text}" + ); + }); +} + +#[gpui::test] +async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) { + // Tests that when a timeout is configured and expires, the tool result indicates timeout. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx))); + let environment = Rc::new(FakeThreadEnvironment { + handle: handle.clone(), + }); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(crate::TerminalTool::new( + thread.project().clone(), + environment, + )); + thread.send(UserMessageId::new(), ["run a command with timeout"], cx) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the terminal tool with a short timeout + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "terminal_tool_1".into(), + name: "terminal".into(), + raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(), + input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + // Wait for the terminal tool to start running + wait_for_terminal_tool_started(&mut events, cx).await; + + // Advance clock past the timeout + cx.executor().advance_clock(Duration::from_millis(200)); + cx.run_until_parked(); + + // The thread continues after tool completion - simulate the model ending its turn + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + // Collect remaining events + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify the terminal was killed due to timeout + assert!( + handle.was_killed(), + "expected terminal handle to be killed on timeout" + ); + + // Verify we got an EndTurn (the tool completed, just with timeout) + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::EndTurn], + ); + + // Verify the tool result indicates timeout, not user stopped + thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + + let tool_use = agent_message + .content + .iter() + .find_map(|content| match content { + AgentMessageContent::ToolUse(tool_use) => Some(tool_use), + _ => None, + }) + .expect("expected tool use in agent message"); + + let tool_result = agent_message + .tool_results + .get(&tool_use.id) + .expect("expected tool result"); + + let result_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + _ => panic!("expected text content in tool result"), + }; + + assert!( + result_text.contains("timed out"), + "expected tool result to indicate timeout, got: {result_text}" + ); + assert!( + !result_text.contains("The user stopped"), + "tool result should not mention user stopped when it timed out, got: {result_text}" + ); + }); +} + #[gpui::test] async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; @@ -2616,6 +3216,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { ToolRequiringPermission::name(): true, InfiniteTool::name(): true, ThinkingTool::name(): true, + "terminal": true, } } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index d54130240c89f76a44ed71e7e4ebc5c65ac4aa2d..0095c572e1f7539b5ad3f00f0fba684565198541 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1055,11 +1055,21 @@ impl Thread { } } - pub fn cancel(&mut self, cx: &mut Context) { - if let Some(running_turn) = self.running_turn.take() { - running_turn.cancel(); - } - self.flush_pending_message(cx); + pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + let Some(running_turn) = self.running_turn.take() else { + self.flush_pending_message(cx); + return Task::ready(()); + }; + + let turn_task = running_turn.cancel(); + + cx.spawn(async move |this, cx| { + turn_task.await; + this.update(cx, |this, cx| { + this.flush_pending_message(cx); + }) + .ok(); + }) } fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { @@ -1074,7 +1084,10 @@ impl Thread { } pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { - self.cancel(cx); + self.cancel(cx).detach(); + // Clear pending message since cancel will try to flush it asynchronously, + // and we don't want that content to be added after we truncate + self.pending_message.take(); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { @@ -1255,7 +1268,11 @@ impl Thread { &mut self, cx: &mut Context, ) -> Result>> { - self.cancel(cx); + // Flush the old pending message synchronously before cancelling, + // to avoid a race where the detached cancel task might flush the NEW + // turn's pending message instead of the old one. + self.flush_pending_message(cx); + self.cancel(cx).detach(); let model = self.model.clone().context("No language model configured")?; let profile = AgentSettings::get_global(cx) @@ -1267,7 +1284,7 @@ impl Thread { let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; self.clear_summary(); - let (cancellation_tx, cancellation_rx) = watch::channel(false); + let (cancellation_tx, mut cancellation_rx) = watch::channel(false); self.running_turn = Some(RunningTurn { event_stream: event_stream.clone(), tools: self.enabled_tools(profile, &model, cx), @@ -1275,8 +1292,23 @@ impl Thread { _task: cx.spawn(async move |this, cx| { log::debug!("Starting agent turn execution"); - let turn_result = - Self::run_turn_internal(&this, model, &event_stream, cancellation_rx, cx).await; + let turn_result = Self::run_turn_internal( + &this, + model, + &event_stream, + cancellation_rx.clone(), + cx, + ) + .await; + + // Check if we were cancelled - if so, cancel() already took running_turn + // and we shouldn't touch it (it might be a NEW turn now) + let was_cancelled = *cancellation_rx.borrow(); + if was_cancelled { + log::debug!("Turn was cancelled, skipping cleanup"); + return; + } + _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); match turn_result { @@ -1311,7 +1343,7 @@ impl Thread { this: &WeakEntity, model: Arc, event_stream: &ThreadEventStream, - cancellation_rx: watch::Receiver, + mut cancellation_rx: watch::Receiver, cx: &mut AsyncApp, ) -> Result<()> { let mut attempt = 0; @@ -1336,7 +1368,22 @@ impl Thread { Err(err) => (stream::empty().boxed(), Some(err)), }; let mut tool_results = FuturesUnordered::new(); - while let Some(event) = events.next().await { + let mut cancelled = false; + loop { + // Race between getting the next event and cancellation + let event = futures::select! { + event = events.next().fuse() => event, + _ = cancellation_rx.changed().fuse() => { + if *cancellation_rx.borrow() { + cancelled = true; + break; + } + continue; + } + }; + let Some(event) = event else { + break; + }; log::trace!("Received completion event: {:?}", event); match event { Ok(event) => { @@ -1384,6 +1431,11 @@ impl Thread { } })?; + if cancelled { + log::debug!("Turn cancelled by user, exiting"); + return Ok(()); + } + if let Some(error) = error { attempt += 1; let retry = this.update(cx, |this, cx| { @@ -2264,10 +2316,11 @@ struct RunningTurn { } impl RunningTurn { - fn cancel(mut self) { + fn cancel(mut self) -> Task<()> { log::debug!("Cancelling in progress turn"); self.cancellation_tx.send(true).ok(); self.event_stream.send_canceled(); + self._task } } diff --git a/crates/agent/src/tools/terminal_tool.rs b/crates/agent/src/tools/terminal_tool.rs index d153c047f171c2dcbb08dd0adc8ea555d872112d..914a0d1f3262334a69ce5cf6b0c8633149dcb61c 100644 --- a/crates/agent/src/tools/terminal_tool.rs +++ b/crates/agent/src/tools/terminal_tool.rs @@ -1,7 +1,7 @@ use agent_client_protocol as acp; use anyhow::Result; use futures::FutureExt as _; -use gpui::{App, AppContext, Entity, SharedString, Task}; +use gpui::{App, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -125,13 +125,12 @@ impl AgentTool for TerminalTool { let timeout = input.timeout_ms.map(Duration::from_millis); let mut timed_out = false; + let mut user_stopped_via_signal = false; let wait_for_exit = terminal.wait_for_exit(cx)?; match timeout { Some(timeout) => { - let timeout_task = cx.background_spawn(async move { - smol::Timer::after(timeout).await; - }); + let timeout_task = cx.background_executor().timer(timeout); futures::select! { _ = wait_for_exit.clone().fuse() => {}, @@ -140,17 +139,32 @@ impl AgentTool for TerminalTool { terminal.kill(cx)?; wait_for_exit.await; } + _ = event_stream.cancelled_by_user().fuse() => { + user_stopped_via_signal = true; + terminal.kill(cx)?; + wait_for_exit.await; + } } } None => { - wait_for_exit.await; + futures::select! { + _ = wait_for_exit.clone().fuse() => {}, + _ = event_stream.cancelled_by_user().fuse() => { + user_stopped_via_signal = true; + terminal.kill(cx)?; + wait_for_exit.await; + } + } } }; // Check if user stopped - we check both: // 1. The cancellation signal from RunningTurn::cancel (e.g. user pressed main Stop button) // 2. The terminal's user_stopped flag (e.g. user clicked Stop on the terminal card) - let user_stopped_via_signal = event_stream.was_cancelled_by_user(); + // Note: user_stopped_via_signal is already set above if we detected cancellation in the select! + // but we also check was_cancelled_by_user() for cases where cancellation happened after wait_for_exit completed + let user_stopped_via_signal = + user_stopped_via_signal || event_stream.was_cancelled_by_user(); let user_stopped_via_terminal = terminal.was_stopped_by_user(cx).unwrap_or(false); let user_stopped = user_stopped_via_signal || user_stopped_via_terminal;