diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index e65644deddceb12d1c954d20a0658fc9f4c264e5..31d6c194b4325c5d8e68bc55c17ac8b1b4c7f2b5 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -1789,6 +1789,101 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext verify_thread_recovery(&thread, &fake_model, cx).await; } +#[gpui::test] +async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) { + // This test verifies that tools which properly handle cancellation via + // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly + // to cancellation and report that they were cancelled. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + always_allow_tools(cx); + let fake_model = model.as_fake(); + + let (tool, was_cancelled) = CancellationAwareTool::new(); + + let mut events = thread + .update(cx, |thread, cx| { + thread.add_tool(tool); + thread.send( + UserMessageId::new(), + ["call the cancellation aware tool"], + cx, + ) + }) + .unwrap(); + + cx.run_until_parked(); + + // Simulate the model calling the cancellation-aware tool + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "cancellation_aware_1".into(), + name: "cancellation_aware".into(), + raw_input: r#"{}"#.into(), + input: json!({}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + + cx.run_until_parked(); + + // Wait for the tool call to be reported + let mut tool_started = false; + let deadline = cx.executor().num_cpus() * 100; + for _ in 0..deadline { + cx.run_until_parked(); + + while let Some(Some(event)) = events.next().now_or_never() { + if let Ok(ThreadEvent::ToolCall(tool_call)) = &event { + if tool_call.title == "Cancellation Aware Tool" { + tool_started = true; + break; + } + } + } + + if tool_started { + break; + } + + cx.background_executor + .timer(Duration::from_millis(10)) + .await; + } + assert!(tool_started, "expected cancellation aware tool to start"); + + // Cancel the thread and wait for it to complete + let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx)); + + // The cancel task should complete promptly because the tool handles cancellation + let timeout = cx.background_executor.timer(Duration::from_secs(5)); + futures::select! { + _ = cancel_task.fuse() => {} + _ = timeout.fuse() => { + panic!("cancel task timed out - tool did not respond to cancellation"); + } + } + + // Verify the tool detected cancellation via its flag + assert!( + was_cancelled.load(std::sync::atomic::Ordering::SeqCst), + "tool should have detected cancellation via event_stream.cancelled_by_user()" + ); + + // Collect remaining events + let remaining_events = collect_events_until_stop(&mut events, cx).await; + + // Verify we got a cancellation stop event + assert_eq!( + stop_events(remaining_events), + vec![acp::StopReason::Cancelled], + ); + + // Verify we can send a new message after cancellation + verify_thread_recovery(&thread, &fake_model, cx).await; +} + /// Helper to verify thread can recover after cancellation by sending a simple message. async fn verify_thread_recovery( thread: &Entity, @@ -3236,6 +3331,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { WordListTool::name(): true, ToolRequiringPermission::name(): true, InfiniteTool::name(): true, + CancellationAwareTool::name(): true, ThinkingTool::name(): true, "terminal": true, } diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index 2275d23c2f8a924efce2d2d4d8bcf6a6f3a59def..58faedcabe935ec64728344d1b1256e53dc7a2f5 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -2,6 +2,7 @@ use super::*; use anyhow::Result; use gpui::{App, SharedString, Task}; use std::future; +use std::sync::atomic::{AtomicBool, Ordering}; /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] @@ -168,6 +169,62 @@ impl AgentTool for InfiniteTool { } } +/// A tool that loops forever but properly handles cancellation via `select!`, +/// similar to how edit_file_tool handles cancellation. +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct CancellationAwareToolInput {} + +pub struct CancellationAwareTool { + pub was_cancelled: Arc, +} + +impl CancellationAwareTool { + pub fn new() -> (Self, Arc) { + let was_cancelled = Arc::new(AtomicBool::new(false)); + ( + Self { + was_cancelled: was_cancelled.clone(), + }, + was_cancelled, + ) + } +} + +impl AgentTool for CancellationAwareTool { + type Input = CancellationAwareToolInput; + type Output = String; + + fn name() -> &'static str { + "cancellation_aware" + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Cancellation Aware Tool".into() + } + + fn run( + self: Arc, + _input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.foreground_executor().spawn(async move { + // Wait for cancellation - this tool does nothing but wait to be cancelled + event_stream.cancelled_by_user().await; + self.was_cancelled.store(true, Ordering::SeqCst); + anyhow::bail!("Tool cancelled by user"); + }) + } +} + /// A tool that takes an object with map from letters to random words starting with that letter. /// All fiealds are required! Pass a word for every letter! #[derive(JsonSchema, Serialize, Deserialize)] diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 30f2be95ef66e27a6802205a1bee5707dcee443d..d73e1d7ddae3573f7a980c0e8f66a39132d59bd5 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -1,8 +1,9 @@ use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; use context_server::{ContextServerId, client::NotificationSubscription}; +use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task}; use project::context_server_store::{ContextServerStatus, ContextServerStore}; use std::sync::Arc; @@ -337,7 +338,7 @@ impl AnyAgentTool for ContextServerTool { authorize.await?; let Some(protocol) = server.client() else { - bail!("Context server not initialized"); + anyhow::bail!("Context server not initialized"); }; let arguments = if let serde_json::Value::Object(map) = input { @@ -351,15 +352,21 @@ impl AnyAgentTool for ContextServerTool { tool_name, arguments ); - let response = protocol - .request::( - context_server::types::CallToolParams { - name: tool_name, - arguments, - meta: None, - }, - ) - .await?; + + let request = protocol.request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ); + + let response = futures::select! { + response = request.fuse() => response?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("MCP tool cancelled by user"); + } + }; let mut result = String::new(); for content in response.content { diff --git a/crates/agent/src/tools/copy_path_tool.rs b/crates/agent/src/tools/copy_path_tool.rs index 236978c78f0c2fee7ecf611486349bab094b3cec..e1fdded7125ec847bbed77e898838a97b3dc2271 100644 --- a/crates/agent/src/tools/copy_path_tool.rs +++ b/crates/agent/src/tools/copy_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, Task}; use project::Project; use schemars::JsonSchema; @@ -75,7 +76,7 @@ impl AgentTool for CopyPathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let copy_task = self.project.update(cx, |project, cx| { @@ -98,7 +99,13 @@ impl AgentTool for CopyPathTool { }); cx.background_spawn(async move { - let _ = copy_task.await.with_context(|| { + let result = futures::select! { + result = copy_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Copy cancelled by user"); + } + }; + let _ = result.with_context(|| { format!( "Copying {} to {}", input.source_path, input.destination_path diff --git a/crates/agent/src/tools/create_directory_tool.rs b/crates/agent/src/tools/create_directory_tool.rs index b6240e99cf4dd6698bf9f46edd8d4681247d8f64..6ef6566679f0e5dd2f2efb702054d03284115893 100644 --- a/crates/agent/src/tools/create_directory_tool.rs +++ b/crates/agent/src/tools/create_directory_tool.rs @@ -1,5 +1,6 @@ use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -64,7 +65,7 @@ impl AgentTool for CreateDirectoryTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project_path = match self.project.read(cx).find_project_path(&input.path, cx) { @@ -80,9 +81,14 @@ impl AgentTool for CreateDirectoryTool { }); cx.spawn(async move |_cx| { - create_entry - .await - .with_context(|| format!("Creating directory {destination_path}"))?; + futures::select! { + result = create_entry.fuse() => { + result.with_context(|| format!("Creating directory {destination_path}"))?; + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Create directory cancelled by user"); + } + } Ok(format!("Created directory {destination_path}")) }) diff --git a/crates/agent/src/tools/delete_path_tool.rs b/crates/agent/src/tools/delete_path_tool.rs index 2b7482000856bf929b1f20ced02e39b2e55ec04c..3ec7ad7a7fc3604e6ef1bfc53d5f00f624dc98df 100644 --- a/crates/agent/src/tools/delete_path_tool.rs +++ b/crates/agent/src/tools/delete_path_tool.rs @@ -2,7 +2,7 @@ use crate::{AgentTool, ToolCallEventStream}; use action_log::ActionLog; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; -use futures::{SinkExt, StreamExt, channel::mpsc}; +use futures::{FutureExt as _, SinkExt, StreamExt, channel::mpsc}; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::{Project, ProjectPath}; use schemars::JsonSchema; @@ -67,7 +67,7 @@ impl AgentTool for DeletePathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let path = input.path; @@ -113,7 +113,16 @@ impl AgentTool for DeletePathTool { let project = self.project.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { - while let Some(path) = paths_rx.next().await { + loop { + let path_result = futures::select! { + path = paths_rx.next().fuse() => path, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Delete cancelled by user"); + } + }; + let Some(path) = path_result else { + break; + }; if let Ok(buffer) = project .update(cx, |project, cx| project.open_buffer(path, cx)) .await @@ -131,9 +140,15 @@ impl AgentTool for DeletePathTool { .with_context(|| { format!("Couldn't delete {path} because that path isn't in this project.") })?; - deletion_task - .await - .with_context(|| format!("Deleting {path}"))?; + + futures::select! { + result = deletion_task.fuse() => { + result.with_context(|| format!("Deleting {path}"))?; + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Delete cancelled by user"); + } + } Ok(format!("Deleted {path}")) }) } diff --git a/crates/agent/src/tools/diagnostics_tool.rs b/crates/agent/src/tools/diagnostics_tool.rs index ea98f1830d07874d847d93fa532b0b6b3806a1d9..28d2d2c7e7e127f5afc106999e9d1d459f10dc2c 100644 --- a/crates/agent/src/tools/diagnostics_tool.rs +++ b/crates/agent/src/tools/diagnostics_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, Task}; use language::{DiagnosticSeverity, OffsetRangeExt}; use project::Project; @@ -89,7 +90,7 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { match input.path { @@ -98,13 +99,18 @@ impl AgentTool for DiagnosticsTool { return Task::ready(Err(anyhow!("Could not find path {path} in project",))); }; - let buffer = self + let open_buffer_task = self .project .update(cx, |project, cx| project.open_buffer(project_path, cx)); cx.spawn(async move |cx| { + let buffer = futures::select! { + result = open_buffer_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Diagnostics cancelled by user"); + } + }; let mut output = String::new(); - let buffer = buffer.await?; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); for (_, group) in snapshot.diagnostic_groups(None) { diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 72bb051e867d783ff8188210bb4ed5539be2bb8f..20e624efa1cf5f56365f36ef1d152773f7e9cbe7 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -7,6 +7,7 @@ use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields} use anyhow::{Context as _, Result, anyhow}; use cloud_llm_client::CompletionIntent; use collections::HashSet; +use futures::{FutureExt as _, StreamExt as _}; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; use language::language_settings::{self, FormatOnSave}; @@ -18,7 +19,6 @@ use project::{Project, ProjectPath}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; -use smol::stream::StreamExt as _; use std::ffi::OsStr; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -395,7 +395,16 @@ impl AgentTool for EditFileTool { let mut hallucinated_old_text = false; let mut ambiguous_ranges = Vec::new(); let mut emitted_location = false; - while let Some(event) = events.next().await { + loop { + let event = futures::select! { + event = events.next().fuse() => match event { + Some(event) => event, + None => break, + }, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Edit cancelled by user"); + } + }; match event { EditAgentOutputEvent::Edited(range) => { if !emitted_location { diff --git a/crates/agent/src/tools/fetch_tool.rs b/crates/agent/src/tools/fetch_tool.rs index 60654ac863acdc559aeaad90f1c73727f33d1b59..a62d24863563d77f4a4140eceb19a31c9da7c8a2 100644 --- a/crates/agent/src/tools/fetch_tool.rs +++ b/crates/agent/src/tools/fetch_tool.rs @@ -4,7 +4,7 @@ use std::{borrow::Cow, cell::RefCell}; use agent_client_protocol as acp; use anyhow::{Context as _, Result, bail}; -use futures::AsyncReadExt as _; +use futures::{AsyncReadExt as _, FutureExt as _}; use gpui::{App, AppContext as _, Task}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; use http_client::{AsyncBody, HttpClientWithUrl}; @@ -145,7 +145,7 @@ impl AgentTool for FetchTool { ) -> Task> { let authorize = event_stream.authorize(input.url.clone(), cx); - let text = cx.background_spawn({ + let fetch_task = cx.background_spawn({ let http_client = self.http_client.clone(); async move { authorize.await?; @@ -154,7 +154,12 @@ impl AgentTool for FetchTool { }); cx.foreground_executor().spawn(async move { - let text = text.await?; + let text = futures::select! { + result = fetch_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Fetch cancelled by user"); + } + }; if text.trim().is_empty() { bail!("no textual content found"); } diff --git a/crates/agent/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs index 2a33b14b4c87d87154e2aa1ee25363b397189f89..5f64a9cbbf569e4729777b3a4e266e788961dadd 100644 --- a/crates/agent/src/tools/find_path_tool.rs +++ b/crates/agent/src/tools/find_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use language_model::LanguageModelToolResultContent; use project::Project; @@ -114,7 +115,12 @@ impl AgentTool for FindPathTool { let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); cx.background_spawn(async move { - let matches = search_paths_task.await?; + let matches = futures::select! { + result = search_paths_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Path search cancelled by user"); + } + }; let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len()) ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index 20c94ba6902af2a01d3a80062eb9fa803d8c25b0..886b52226f0b351d81771be475db6645873640ee 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -1,7 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; -use futures::StreamExt; +use futures::{FutureExt as _, StreamExt}; use gpui::{App, Entity, SharedString, Task}; use language::{OffsetRangeExt, ParseStatus, Point}; use project::{ @@ -117,7 +117,7 @@ impl AgentTool for GrepTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; @@ -186,7 +186,16 @@ impl AgentTool for GrepTool { let mut matches_found = 0; let mut has_more_matches = false; - 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = rx.next().await { + 'outer: loop { + let search_result = futures::select! { + result = rx.next().fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Search cancelled by user"); + } + }; + let Some(SearchResult::Buffer { buffer, ranges }) = search_result else { + break; + }; if ranges.is_empty() { continue; } diff --git a/crates/agent/src/tools/move_path_tool.rs b/crates/agent/src/tools/move_path_tool.rs index ae58145126f6356beaa1457d719812bb56d6e7db..8b340089e8b2941c45c7da376b181eaf7eaa2b66 100644 --- a/crates/agent/src/tools/move_path_tool.rs +++ b/crates/agent/src/tools/move_path_tool.rs @@ -1,6 +1,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -89,7 +90,7 @@ impl AgentTool for MovePathTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let rename_task = self.project.update(cx, |project, cx| { @@ -112,7 +113,13 @@ impl AgentTool for MovePathTool { }); cx.background_spawn(async move { - let _ = rename_task.await.with_context(|| { + let result = futures::select! { + result = rename_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Move cancelled by user"); + } + }; + let _ = result.with_context(|| { format!("Moving {} to {}", input.source_path, input.destination_path) })?; Ok(format!( diff --git a/crates/agent/src/tools/open_tool.rs b/crates/agent/src/tools/open_tool.rs index 8826d1529ce43df0ea4a3e21795386874168de58..2a3d14fb3ad0129fad7032e798ac5ec58c59b7f3 100644 --- a/crates/agent/src/tools/open_tool.rs +++ b/crates/agent/src/tools/open_tool.rs @@ -1,6 +1,7 @@ use crate::AgentTool; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result}; +use futures::FutureExt as _; use gpui::{App, AppContext, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; @@ -67,7 +68,12 @@ impl AgentTool for OpenTool { let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx); cx.background_spawn(async move { - authorize.await?; + futures::select! { + result = authorize.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Open cancelled by user"); + } + } match abs_path { Some(path) => open::that(path), diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 2fa6efa9cdffa229fd1c2c447345220e460d286f..bc7647739035a41b91c481d2f25b5fbd0f7856c7 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -1,6 +1,7 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task, WeakEntity}; use indoc::formatdoc; use language::Point; @@ -192,13 +193,18 @@ impl AgentTool for ReadFileTool { let action_log = self.action_log.clone(); cx.spawn(async move |cx| { - let buffer = cx - .update(|cx| { - project.update(cx, |project, cx| { - project.open_buffer(project_path.clone(), cx) - }) + let open_buffer_task = cx.update(|cx| { + project.update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) }) - .await?; + }); + + let buffer = futures::select! { + result = open_buffer_task.fuse() => result?, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("File read cancelled by user"); + } + }; if buffer.read_with(cx, |buffer, _| { buffer .file() diff --git a/crates/agent/src/tools/restore_file_from_disk_tool.rs b/crates/agent/src/tools/restore_file_from_disk_tool.rs index eb2a027c723c80e4225380af7397e51b1af68d2b..0429706c5d9b449b7b54d19655fadeadd5adfbbc 100644 --- a/crates/agent/src/tools/restore_file_from_disk_tool.rs +++ b/crates/agent/src/tools/restore_file_from_disk_tool.rs @@ -1,6 +1,7 @@ use agent_client_protocol as acp; use anyhow::Result; use collections::FxHashSet; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; use project::Project; @@ -61,7 +62,7 @@ impl AgentTool for RestoreFileFromDiskTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project = self.project.clone(); @@ -88,11 +89,18 @@ impl AgentTool for RestoreFileFromDiskTool { let open_buffer_task = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let buffer = match open_buffer_task.await { - Ok(buffer) => buffer, - Err(error) => { - open_errors.push((path, error.to_string())); - continue; + let buffer = futures::select! { + result = open_buffer_task.fuse() => { + match result { + Ok(buffer) => buffer, + Err(error) => { + open_errors.push((path, error.to_string())); + continue; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Restore cancelled by user"); } }; @@ -111,7 +119,13 @@ impl AgentTool for RestoreFileFromDiskTool { project.reload_buffers(buffers_to_reload, true, cx) }); - if let Err(error) = reload_task.await { + let result = futures::select! { + result = reload_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Restore cancelled by user"); + } + }; + if let Err(error) = result { reload_errors.push(error.to_string()); } } diff --git a/crates/agent/src/tools/save_file_tool.rs b/crates/agent/src/tools/save_file_tool.rs index dab69433bb515614bed2d3509f9d71b7c1c9ef0c..384675bd32728f22effcc7a97ae61ec19bb49d37 100644 --- a/crates/agent/src/tools/save_file_tool.rs +++ b/crates/agent/src/tools/save_file_tool.rs @@ -1,6 +1,7 @@ use agent_client_protocol as acp; use anyhow::Result; use collections::FxHashSet; +use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; use project::Project; @@ -58,7 +59,7 @@ impl AgentTool for SaveFileTool { fn run( self: Arc, input: Self::Input, - _event_stream: ToolCallEventStream, + event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let project = self.project.clone(); @@ -85,11 +86,18 @@ impl AgentTool for SaveFileTool { let open_buffer_task = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let buffer = match open_buffer_task.await { - Ok(buffer) => buffer, - Err(error) => { - open_errors.push((path, error.to_string())); - continue; + let buffer = futures::select! { + result = open_buffer_task.fuse() => { + match result { + Ok(buffer) => buffer, + Err(error) => { + open_errors.push((path, error.to_string())); + continue; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Save cancelled by user"); } }; @@ -116,7 +124,13 @@ impl AgentTool for SaveFileTool { let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx)); - if let Err(error) = save_task.await { + let save_result = futures::select! { + result = save_task.fuse() => result, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Save cancelled by user"); + } + }; + if let Err(error) = save_result { save_errors.push((path_for_buffer, error.to_string())); } } diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index eb4ebacea2a8e48d6efa9032f46b336ca30c39b6..3d3998cf000783f5fe347432873cc9579dc14f2f 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -4,6 +4,7 @@ use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use cloud_llm_client::WebSearchResponse; +use futures::FutureExt as _; use gpui::{App, AppContext, Task}; use language_model::{ LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, @@ -73,12 +74,19 @@ impl AgentTool for WebSearchTool { let search_task = provider.search(input.query, cx); cx.background_spawn(async move { - let response = match search_task.await { - Ok(response) => response, - Err(err) => { - event_stream - .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed")); - return Err(err); + let response = futures::select! { + result = search_task.fuse() => { + match result { + Ok(response) => response, + Err(err) => { + event_stream + .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed")); + return Err(err); + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Web search cancelled by user"); } };