From ecc392804056a1daa200197f4decf2009fdab455 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Sat, 10 Jan 2026 19:30:25 -0500 Subject: [PATCH] Fix cancellation regression: make edit_file_tool handle cancellation (#46527) PR #46306 changed cancellation to wait for tools to complete before returning. This was correct behavior - it allows tools like terminal to capture their output on cancellation. The real issue was that many tools didn't check for cancellation, so they would continue running until they finished. ## The Problem When the user pressed Escape to cancel during a tool operation, tools would continue running because they never checked for the cancellation signal. The thread correctly waited for tools to complete (so terminal could capture output), but tools like edit_file, grep, fetch, etc. would just keep going. ## The Fix Add cancellation handling to all tools using the same pattern as `terminal_tool`: use `select!` to race between the tool's main work and `event_stream.cancelled_by_user()`. When cancelled, tools break out of their loops or return early. ## All Tools Now Cancellation-Aware | Tool | Change | |------|--------| | `edit_file_tool` | Checks cancellation in edit event processing loop | | `terminal_tool` | Already handled cancellation | | `grep_tool` | Checks cancellation in search result iteration loop | | `fetch_tool` | Checks cancellation during HTTP fetch | | `web_search_tool` | Checks cancellation during web search | | `find_path_tool` | Checks cancellation during path search | | `read_file_tool` | Checks cancellation during buffer open | | `copy_path_tool` | Checks cancellation during file copy | | `move_path_tool` | Checks cancellation during file move/rename | | `delete_path_tool` | Checks cancellation during delete operation | | `create_directory_tool` | Checks cancellation during directory creation | | `save_file_tool` | Checks cancellation during buffer open and save | | `restore_file_from_disk_tool` | Checks cancellation during buffer open and reload | | `open_tool` | Checks cancellation during authorization | | `diagnostics_tool` | Checks cancellation during buffer open | | `ContextServerTool` (MCP) | Checks cancellation during external server calls | **Synchronous tools (no async work, return immediately):** - `list_directory_tool` - Reads worktree snapshot synchronously - `now_tool` - Returns current time immediately - `thinking_tool` - Returns immediately ## MCP Tools Automatically Handled MCP tools (user-defined tools via context servers) are now automatically cancellation-aware without any user action. The `ContextServerTool` wrapper races the external server request against `event_stream.cancelled_by_user()`. ## Testing - Added `CancellationAwareTool` test helper that mirrors the cancellation pattern - Updated `test_cancellation_aware_tool_responds_to_cancellation` to properly await the cancel task and verify the tool detected cancellation Release Notes: - Fixed a regression where pressing Escape wouldn't immediately cancel in-progress tool operations --- crates/agent/src/tests/mod.rs | 96 +++++++++++++++++++ crates/agent/src/tests/test_tools.rs | 57 +++++++++++ .../src/tools/context_server_registry.rs | 29 +++--- crates/agent/src/tools/copy_path_tool.rs | 11 ++- .../agent/src/tools/create_directory_tool.rs | 14 ++- crates/agent/src/tools/delete_path_tool.rs | 27 ++++-- crates/agent/src/tools/diagnostics_tool.rs | 12 ++- crates/agent/src/tools/edit_file_tool.rs | 13 ++- crates/agent/src/tools/fetch_tool.rs | 11 ++- crates/agent/src/tools/find_path_tool.rs | 8 +- crates/agent/src/tools/grep_tool.rs | 15 ++- crates/agent/src/tools/move_path_tool.rs | 11 ++- crates/agent/src/tools/open_tool.rs | 8 +- crates/agent/src/tools/read_file_tool.rs | 18 ++-- .../src/tools/restore_file_from_disk_tool.rs | 28 ++++-- crates/agent/src/tools/save_file_tool.rs | 28 ++++-- crates/agent/src/tools/web_search_tool.rs | 20 ++-- 17 files changed, 342 insertions(+), 64 deletions(-) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index e65644deddceb12d1c954d20a0658fc9f4c264e5..31d6c194b4325c5d8e68bc55c17ac8b1b4c7f2b5 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -1789,6 +1789,101 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext verify_thread_recovery(&thread, &fake_model, cx).await; } +#[gpui::test] +async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) { + // This test verifies that tools which properly handle cancellation via + // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly + // to cancellation and report that they were cancelled. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let (tool, was_cancelled) = CancellationAwareTool::new(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(tool); + thread.send( + UserMessageId::new(), + ["call the cancellation aware tool"], + cx, + ) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the cancellation-aware tool + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "cancellation_aware_1".into(), + name: "cancellation_aware".into(), + raw_input: r#"{}"#.into(), + input: json!({}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Wait for the tool call to be reported + let mut tool_started = false; + let deadline = cx.executor().num_cpus() * 100; + for _ in 0..deadline { + cx.run_until_parked(); + + while let Some(Some(event)) = events.next().now_or_never() { + if let Ok(ThreadEvent::ToolCall(tool_call)) = &event { + if tool_call.title == "Cancellation Aware Tool" { + tool_started = true; + break; + } + } + } + + if tool_started { + break; + } + + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + assert!(tool_started, "expected cancellation aware tool to start"); + + // Cancel the thread and wait for it to complete + let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx)); + + // The cancel task should complete promptly because the tool handles cancellation + let timeout = cx.background_executor.timer(Duration::from_secs(5)); + futures::select! { + _ = cancel_task.fuse() => {} + _ = timeout.fuse() => { + panic!("cancel task timed out - tool did not respond to cancellation"); + } + } + + // Verify the tool detected cancellation via its flag + assert!( + was_cancelled.load(std::sync::atomic::Ordering::SeqCst), + "tool should have detected cancellation via event_stream.cancelled_by_user()" + ); + + // Collect remaining events + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify we got a cancellation stop event + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::Cancelled], + ); + + // Verify we can send a new message after cancellation + verify_thread_recovery(&thread, &fake_model, cx).await; +} + /// Helper to verify thread can recover after cancellation by sending a simple message. async fn verify_thread_recovery( thread: &Entity, @@ -3236,6 +3331,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { WordListTool::name(): true, ToolRequiringPermission::name(): true, InfiniteTool::name(): true, + CancellationAwareTool::name(): true, ThinkingTool::name(): true, "terminal": true, } diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index 2275d23c2f8a924efce2d2d4d8bcf6a6f3a59def..58faedcabe935ec64728344d1b1256e53dc7a2f5 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -2,6 +2,7 @@ use super::*; use anyhow::Result; use gpui::{App, SharedString, Task}; use std::future; +use std::sync::atomic::{AtomicBool, Ordering}; /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] @@ -168,6 +169,62 @@ impl AgentTool for InfiniteTool { } } +/// A tool that loops forever but properly handles cancellation via `select!`, +/// similar to how edit_file_tool handles cancellation. +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct CancellationAwareToolInput {} + +pub struct CancellationAwareTool { + pub was_cancelled: Arc, +} + +impl CancellationAwareTool { + pub fn new() -> (Self, Arc) { + let was_cancelled = Arc::new(AtomicBool::new(false)); + ( + Self { + was_cancelled: was_cancelled.clone(), + }, + was_cancelled, + ) + } +} + +impl AgentTool for CancellationAwareTool { + type Input = CancellationAwareToolInput; + type Output = String; + + fn name() -> &'static str { + "cancellation_aware" + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Cancellation Aware Tool".into() + } + + fn run( + self: Arc, + _input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.foreground_executor().spawn(async move { + // Wait for cancellation - this tool does nothing but wait to be cancelled + event_stream.cancelled_by_user().await; + self.was_cancelled.store(true, Ordering::SeqCst); + anyhow::bail!("Tool cancelled by user"); + }) + } +} + /// A tool that takes an object with map from letters to random words starting with that letter. /// All fiealds are required! Pass a word for every letter! #[derive(JsonSchema, Serialize, Deserialize)] diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 30f2be95ef66e27a6802205a1bee5707dcee443d..d73e1d7ddae3573f7a980c0e8f66a39132d59bd5 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -1,8 +1,9 @@ use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; use context_server::{ContextServerId, client::NotificationSubscription}; +use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task}; use project::context_server_store::{ContextServerStatus, ContextServerStore}; use std::sync::Arc; @@ -337,7 +338,7 @@ impl AnyAgentTool for ContextServerTool { authorize.await?; let Some(protocol) = server.client() else { - bail!("Context server not initialized"); + anyhow::bail!("Context server not initialized"); }; let arguments = if let serde_json::Value::Object(map) = input { @@ -351,15 +352,21 @@ impl AnyAgentTool for ContextServerTool { tool_name, arguments ); - let response = protocol - .request::( - context_server::types::CallToolParams { - name: tool_name, - arguments, - meta: None, - }, - ) - .await?; + + let request = protocol.request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ); + + let response = futures::select! { + response = request.fuse() => response?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("MCP tool cancelled by user"); + } + }; let mut result = String::new(); for content in response.content { diff --git a/crates/agent/src/tools/copy_path_tool.rs b/crates/agent/src/tools/copy_path_tool.rs index 236978c78f0c2fee7ecf611486349bab094b3cec..e1fdded7125ec847bbed77e898838a97b3dc2271 100644 --- a/crates/agent/src/tools/copy_path_tool.rs +++ b/crates/agent/src/tools/copy_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, Task}; use project::Project; use schemars::JsonSchema; @@ -75,7 +76,7 @@ impl AgentTool for CopyPathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let copy_task = self.project.update(cx, |project, cx| { @@ -98,7 +99,13 @@ impl AgentTool for CopyPathTool { }); cx.background_spawn(async move { - let _ = copy_task.await.with_context(|| { + let result = futures::select! { + result = copy_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Copy cancelled by user"); + } + }; + let _ = result.with_context(|| { format!( "Copying {} to {}", input.source_path, input.destination_path diff --git a/crates/agent/src/tools/create_directory_tool.rs b/crates/agent/src/tools/create_directory_tool.rs index b6240e99cf4dd6698bf9f46edd8d4681247d8f64..6ef6566679f0e5dd2f2efb702054d03284115893 100644 --- a/crates/agent/src/tools/create_directory_tool.rs +++ b/crates/agent/src/tools/create_directory_tool.rs @@ -1,5 +1,6 @@ use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -64,7 +65,7 @@ impl AgentTool for CreateDirectoryTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project_path = match self.project.read(cx).find_project_path(&input.path, cx) { @@ -80,9 +81,14 @@ impl AgentTool for CreateDirectoryTool { }); cx.spawn(async move |_cx| { - create_entry - .await - .with_context(|| format!("Creating directory {destination_path}"))?; + futures::select! { + result = create_entry.fuse() => { + result.with_context(|| format!("Creating directory {destination_path}"))?; + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Create directory cancelled by user"); + } + } Ok(format!("Created directory {destination_path}")) }) diff --git a/crates/agent/src/tools/delete_path_tool.rs b/crates/agent/src/tools/delete_path_tool.rs index 2b7482000856bf929b1f20ced02e39b2e55ec04c..3ec7ad7a7fc3604e6ef1bfc53d5f00f624dc98df 100644 --- a/crates/agent/src/tools/delete_path_tool.rs +++ b/crates/agent/src/tools/delete_path_tool.rs @@ -2,7 +2,7 @@ use crate::{AgentTool, ToolCallEventStream}; use action_log::ActionLog; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; -use futures::{SinkExt, StreamExt, channel::mpsc}; +use futures::{FutureExt as _, SinkExt, StreamExt, channel::mpsc}; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::{Project, ProjectPath}; use schemars::JsonSchema; @@ -67,7 +67,7 @@ impl AgentTool for DeletePathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let path = input.path; @@ -113,7 +113,16 @@ impl AgentTool for DeletePathTool { let project = self.project.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { - while let Some(path) = paths_rx.next().await { + loop { + let path_result = futures::select! { + path = paths_rx.next().fuse() => path, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Delete cancelled by user"); + } + }; + let Some(path) = path_result else { + break; + }; if let Ok(buffer) = project .update(cx, |project, cx| project.open_buffer(path, cx)) .await @@ -131,9 +140,15 @@ impl AgentTool for DeletePathTool { .with_context(|| { format!("Couldn't delete {path} because that path isn't in this project.") })?; - deletion_task - .await - .with_context(|| format!("Deleting {path}"))?; + + futures::select! { + result = deletion_task.fuse() => { + result.with_context(|| format!("Deleting {path}"))?; + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Delete cancelled by user"); + } + } Ok(format!("Deleted {path}")) }) } diff --git a/crates/agent/src/tools/diagnostics_tool.rs b/crates/agent/src/tools/diagnostics_tool.rs index ea98f1830d07874d847d93fa532b0b6b3806a1d9..28d2d2c7e7e127f5afc106999e9d1d459f10dc2c 100644 --- a/crates/agent/src/tools/diagnostics_tool.rs +++ b/crates/agent/src/tools/diagnostics_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, Task}; use language::{DiagnosticSeverity, OffsetRangeExt}; use project::Project; @@ -89,7 +90,7 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { match input.path { @@ -98,13 +99,18 @@ impl AgentTool for DiagnosticsTool { return Task::ready(Err(anyhow!("Could not find path {path} in project",))); }; - let buffer = self + let open_buffer_task = self .project .update(cx, |project, cx| project.open_buffer(project_path, cx)); cx.spawn(async move |cx| { + let buffer = futures::select! { + result = open_buffer_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Diagnostics cancelled by user"); + } + }; let mut output = String::new(); - let buffer = buffer.await?; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); for (_, group) in snapshot.diagnostic_groups(None) { diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 72bb051e867d783ff8188210bb4ed5539be2bb8f..20e624efa1cf5f56365f36ef1d152773f7e9cbe7 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -7,6 +7,7 @@ use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields} use anyhow::{Context as _, Result, anyhow}; use cloud_llm_client::CompletionIntent; use collections::HashSet; +use futures::{FutureExt as _, StreamExt as _}; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; use language::language_settings::{self, FormatOnSave}; @@ -18,7 +19,6 @@ use project::{Project, ProjectPath}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; -use smol::stream::StreamExt as _; use std::ffi::OsStr; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -395,7 +395,16 @@ impl AgentTool for EditFileTool { let mut hallucinated_old_text = false; let mut ambiguous_ranges = Vec::new(); let mut emitted_location = false; - while let Some(event) = events.next().await { + loop { + let event = futures::select! { + event = events.next().fuse() => match event { + Some(event) => event, + None => break, + }, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Edit cancelled by user"); + } + }; match event { EditAgentOutputEvent::Edited(range) => { if !emitted_location { diff --git a/crates/agent/src/tools/fetch_tool.rs b/crates/agent/src/tools/fetch_tool.rs index 60654ac863acdc559aeaad90f1c73727f33d1b59..a62d24863563d77f4a4140eceb19a31c9da7c8a2 100644 --- a/crates/agent/src/tools/fetch_tool.rs +++ b/crates/agent/src/tools/fetch_tool.rs @@ -4,7 +4,7 @@ use std::{borrow::Cow, cell::RefCell}; use agent_client_protocol as acp; use anyhow::{Context as _, Result, bail}; -use futures::AsyncReadExt as _; +use futures::{AsyncReadExt as _, FutureExt as _}; use gpui::{App, AppContext as _, Task}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; use http_client::{AsyncBody, HttpClientWithUrl}; @@ -145,7 +145,7 @@ impl AgentTool for FetchTool { ) -> Task> { let authorize = event_stream.authorize(input.url.clone(), cx); - let text = cx.background_spawn({ + let fetch_task = cx.background_spawn({ let http_client = self.http_client.clone(); async move { authorize.await?; @@ -154,7 +154,12 @@ impl AgentTool for FetchTool { }); cx.foreground_executor().spawn(async move { - let text = text.await?; + let text = futures::select! { + result = fetch_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Fetch cancelled by user"); + } + }; if text.trim().is_empty() { bail!("no textual content found"); } diff --git a/crates/agent/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs index 2a33b14b4c87d87154e2aa1ee25363b397189f89..5f64a9cbbf569e4729777b3a4e266e788961dadd 100644 --- a/crates/agent/src/tools/find_path_tool.rs +++ b/crates/agent/src/tools/find_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use language_model::LanguageModelToolResultContent; use project::Project; @@ -114,7 +115,12 @@ impl AgentTool for FindPathTool { let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); cx.background_spawn(async move { - let matches = search_paths_task.await?; + let matches = futures::select! { + result = search_paths_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Path search cancelled by user"); + } + }; let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len()) ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index 20c94ba6902af2a01d3a80062eb9fa803d8c25b0..886b52226f0b351d81771be475db6645873640ee 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -1,7 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; -use futures::StreamExt; +use futures::{FutureExt as _, StreamExt}; use gpui::{App, Entity, SharedString, Task}; use language::{OffsetRangeExt, ParseStatus, Point}; use project::{ @@ -117,7 +117,7 @@ impl AgentTool for GrepTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; @@ -186,7 +186,16 @@ impl AgentTool for GrepTool { let mut matches_found = 0; let mut has_more_matches = false; - 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = rx.next().await { + 'outer: loop { + let search_result = futures::select! { + result = rx.next().fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Search cancelled by user"); + } + }; + let Some(SearchResult::Buffer { buffer, ranges }) = search_result else { + break; + }; if ranges.is_empty() { continue; } diff --git a/crates/agent/src/tools/move_path_tool.rs b/crates/agent/src/tools/move_path_tool.rs index ae58145126f6356beaa1457d719812bb56d6e7db..8b340089e8b2941c45c7da376b181eaf7eaa2b66 100644 --- a/crates/agent/src/tools/move_path_tool.rs +++ b/crates/agent/src/tools/move_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -89,7 +90,7 @@ impl AgentTool for MovePathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let rename_task = self.project.update(cx, |project, cx| { @@ -112,7 +113,13 @@ impl AgentTool for MovePathTool { }); cx.background_spawn(async move { - let _ = rename_task.await.with_context(|| { + let result = futures::select! { + result = rename_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Move cancelled by user"); + } + }; + let _ = result.with_context(|| { format!("Moving {} to {}", input.source_path, input.destination_path) })?; Ok(format!( diff --git a/crates/agent/src/tools/open_tool.rs b/crates/agent/src/tools/open_tool.rs index 8826d1529ce43df0ea4a3e21795386874168de58..2a3d14fb3ad0129fad7032e798ac5ec58c59b7f3 100644 --- a/crates/agent/src/tools/open_tool.rs +++ b/crates/agent/src/tools/open_tool.rs @@ -1,6 +1,7 @@ use crate::AgentTool; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -67,7 +68,12 @@ impl AgentTool for OpenTool { let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx); cx.background_spawn(async move { - authorize.await?; + futures::select! { + result = authorize.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Open cancelled by user"); + } + } match abs_path { Some(path) => open::that(path), diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 2fa6efa9cdffa229fd1c2c447345220e460d286f..bc7647739035a41b91c481d2f25b5fbd0f7856c7 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -1,6 +1,7 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task, WeakEntity}; use indoc::formatdoc; use language::Point; @@ -192,13 +193,18 @@ impl AgentTool for ReadFileTool { let action_log = self.action_log.clone(); cx.spawn(async move |cx| { - let buffer = cx - .update(|cx| { - project.update(cx, |project, cx| { - project.open_buffer(project_path.clone(), cx) - }) + let open_buffer_task = cx.update(|cx| { + project.update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) }) - .await?; + }); + + let buffer = futures::select! { + result = open_buffer_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("File read cancelled by user"); + } + }; if buffer.read_with(cx, |buffer, _| { buffer .file() diff --git a/crates/agent/src/tools/restore_file_from_disk_tool.rs b/crates/agent/src/tools/restore_file_from_disk_tool.rs index eb2a027c723c80e4225380af7397e51b1af68d2b..0429706c5d9b449b7b54d19655fadeadd5adfbbc 100644 --- a/crates/agent/src/tools/restore_file_from_disk_tool.rs +++ b/crates/agent/src/tools/restore_file_from_disk_tool.rs @@ -1,6 +1,7 @@ use agent_client_protocol as acp; use anyhow::Result; use collections::FxHashSet; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; use project::Project; @@ -61,7 +62,7 @@ impl AgentTool for RestoreFileFromDiskTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project = self.project.clone(); @@ -88,11 +89,18 @@ impl AgentTool for RestoreFileFromDiskTool { let open_buffer_task = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let buffer = match open_buffer_task.await { - Ok(buffer) => buffer, - Err(error) => { - open_errors.push((path, error.to_string())); - continue; + let buffer = futures::select! { + result = open_buffer_task.fuse() => { + match result { + Ok(buffer) => buffer, + Err(error) => { + open_errors.push((path, error.to_string())); + continue; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Restore cancelled by user"); } }; @@ -111,7 +119,13 @@ impl AgentTool for RestoreFileFromDiskTool { project.reload_buffers(buffers_to_reload, true, cx) }); - if let Err(error) = reload_task.await { + let result = futures::select! { + result = reload_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Restore cancelled by user"); + } + }; + if let Err(error) = result { reload_errors.push(error.to_string()); } } diff --git a/crates/agent/src/tools/save_file_tool.rs b/crates/agent/src/tools/save_file_tool.rs index dab69433bb515614bed2d3509f9d71b7c1c9ef0c..384675bd32728f22effcc7a97ae61ec19bb49d37 100644 --- a/crates/agent/src/tools/save_file_tool.rs +++ b/crates/agent/src/tools/save_file_tool.rs @@ -1,6 +1,7 @@ use agent_client_protocol as acp; use anyhow::Result; use collections::FxHashSet; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; use project::Project; @@ -58,7 +59,7 @@ impl AgentTool for SaveFileTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project = self.project.clone(); @@ -85,11 +86,18 @@ impl AgentTool for SaveFileTool { let open_buffer_task = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let buffer = match open_buffer_task.await { - Ok(buffer) => buffer, - Err(error) => { - open_errors.push((path, error.to_string())); - continue; + let buffer = futures::select! { + result = open_buffer_task.fuse() => { + match result { + Ok(buffer) => buffer, + Err(error) => { + open_errors.push((path, error.to_string())); + continue; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Save cancelled by user"); } }; @@ -116,7 +124,13 @@ impl AgentTool for SaveFileTool { let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx)); - if let Err(error) = save_task.await { + let save_result = futures::select! { + result = save_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Save cancelled by user"); + } + }; + if let Err(error) = save_result { save_errors.push((path_for_buffer, error.to_string())); } } diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index eb4ebacea2a8e48d6efa9032f46b336ca30c39b6..3d3998cf000783f5fe347432873cc9579dc14f2f 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -4,6 +4,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use cloud_llm_client::WebSearchResponse; +use futures::FutureExt as _; use gpui::{App, AppContext, Task}; use language_model::{ LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, @@ -73,12 +74,19 @@ impl AgentTool for WebSearchTool { let search_task = provider.search(input.query, cx); cx.background_spawn(async move { - let response = match search_task.await { - Ok(response) => response, - Err(err) => { - event_stream - .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed")); - return Err(err); + let response = futures::select! { + result = search_task.fuse() => { + match result { + Ok(response) => response, + Err(err) => { + event_stream + .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed")); + return Err(err); + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Web search cancelled by user"); } };