Detailed changes
@@ -1789,6 +1789,101 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext
verify_thread_recovery(&thread, &fake_model, cx).await;
}
+#[gpui::test]
+async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
+ // This test verifies that tools which properly handle cancellation via
+ // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
+ // to cancellation and report that they were cancelled.
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ always_allow_tools(cx);
+ let fake_model = model.as_fake();
+
+ let (tool, was_cancelled) = CancellationAwareTool::new();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.add_tool(tool);
+ thread.send(
+ UserMessageId::new(),
+ ["call the cancellation aware tool"],
+ cx,
+ )
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // Simulate the model calling the cancellation-aware tool
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "cancellation_aware_1".into(),
+ name: "cancellation_aware".into(),
+ raw_input: r#"{}"#.into(),
+ input: json!({}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ cx.run_until_parked();
+
+ // Wait for the tool call to be reported
+ let mut tool_started = false;
+ let deadline = cx.executor().num_cpus() * 100;
+ for _ in 0..deadline {
+ cx.run_until_parked();
+
+ while let Some(Some(event)) = events.next().now_or_never() {
+ if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
+ if tool_call.title == "Cancellation Aware Tool" {
+ tool_started = true;
+ break;
+ }
+ }
+ }
+
+ if tool_started {
+ break;
+ }
+
+ cx.background_executor
+ .timer(Duration::from_millis(10))
+ .await;
+ }
+ assert!(tool_started, "expected cancellation aware tool to start");
+
+ // Cancel the thread and wait for it to complete
+ let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
+
+ // The cancel task should complete promptly because the tool handles cancellation
+ let timeout = cx.background_executor.timer(Duration::from_secs(5));
+ futures::select! {
+ _ = cancel_task.fuse() => {}
+ _ = timeout.fuse() => {
+ panic!("cancel task timed out - tool did not respond to cancellation");
+ }
+ }
+
+ // Verify the tool detected cancellation via its flag
+ assert!(
+ was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
+ "tool should have detected cancellation via event_stream.cancelled_by_user()"
+ );
+
+ // Collect remaining events
+ let remaining_events = collect_events_until_stop(&mut events, cx).await;
+
+ // Verify we got a cancellation stop event
+ assert_eq!(
+ stop_events(remaining_events),
+ vec![acp::StopReason::Cancelled],
+ );
+
+ // Verify we can send a new message after cancellation
+ verify_thread_recovery(&thread, &fake_model, cx).await;
+}
+
/// Helper to verify thread can recover after cancellation by sending a simple message.
async fn verify_thread_recovery(
thread: &Entity<Thread>,
@@ -3236,6 +3331,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
WordListTool::name(): true,
ToolRequiringPermission::name(): true,
InfiniteTool::name(): true,
+ CancellationAwareTool::name(): true,
ThinkingTool::name(): true,
"terminal": true,
}
@@ -2,6 +2,7 @@ use super::*;
use anyhow::Result;
use gpui::{App, SharedString, Task};
use std::future;
+use std::sync::atomic::{AtomicBool, Ordering};
/// A tool that echoes its input
#[derive(JsonSchema, Serialize, Deserialize)]
@@ -168,6 +169,62 @@ impl AgentTool for InfiniteTool {
}
}
+/// A tool that loops forever but properly handles cancellation via `select!`,
+/// similar to how edit_file_tool handles cancellation.
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct CancellationAwareToolInput {}
+
+pub struct CancellationAwareTool {
+ pub was_cancelled: Arc<AtomicBool>,
+}
+
+impl CancellationAwareTool {
+ pub fn new() -> (Self, Arc<AtomicBool>) {
+ let was_cancelled = Arc::new(AtomicBool::new(false));
+ (
+ Self {
+ was_cancelled: was_cancelled.clone(),
+ },
+ was_cancelled,
+ )
+ }
+}
+
+impl AgentTool for CancellationAwareTool {
+ type Input = CancellationAwareToolInput;
+ type Output = String;
+
+ fn name() -> &'static str {
+ "cancellation_aware"
+ }
+
+ fn kind() -> acp::ToolKind {
+ acp::ToolKind::Other
+ }
+
+ fn initial_title(
+ &self,
+ _input: Result<Self::Input, serde_json::Value>,
+ _cx: &mut App,
+ ) -> SharedString {
+ "Cancellation Aware Tool".into()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ _input: Self::Input,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<String>> {
+ cx.foreground_executor().spawn(async move {
+ // Wait for cancellation - this tool does nothing but wait to be cancelled
+ event_stream.cancelled_by_user().await;
+ self.was_cancelled.store(true, Ordering::SeqCst);
+ anyhow::bail!("Tool cancelled by user");
+ })
+ }
+}
+
/// A tool that takes an object with map from letters to random words starting with that letter.
/// All fiealds are required! Pass a word for every letter!
#[derive(JsonSchema, Serialize, Deserialize)]
@@ -1,8 +1,9 @@
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
use agent_client_protocol::ToolKind;
-use anyhow::{Result, anyhow, bail};
+use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
use context_server::{ContextServerId, client::NotificationSubscription};
+use futures::FutureExt as _;
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use std::sync::Arc;
@@ -337,7 +338,7 @@ impl AnyAgentTool for ContextServerTool {
authorize.await?;
let Some(protocol) = server.client() else {
- bail!("Context server not initialized");
+ anyhow::bail!("Context server not initialized");
};
let arguments = if let serde_json::Value::Object(map) = input {
@@ -351,15 +352,21 @@ impl AnyAgentTool for ContextServerTool {
tool_name,
arguments
);
- let response = protocol
- .request::<context_server::types::requests::CallTool>(
- context_server::types::CallToolParams {
- name: tool_name,
- arguments,
- meta: None,
- },
- )
- .await?;
+
+ let request = protocol.request::<context_server::types::requests::CallTool>(
+ context_server::types::CallToolParams {
+ name: tool_name,
+ arguments,
+ meta: None,
+ },
+ );
+
+ let response = futures::select! {
+ response = request.fuse() => response?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("MCP tool cancelled by user");
+ }
+ };
let mut result = String::new();
for content in response.content {
@@ -1,6 +1,7 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol::ToolKind;
use anyhow::{Context as _, Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, AppContext, Entity, Task};
use project::Project;
use schemars::JsonSchema;
@@ -75,7 +76,7 @@ impl AgentTool for CopyPathTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let copy_task = self.project.update(cx, |project, cx| {
@@ -98,7 +99,13 @@ impl AgentTool for CopyPathTool {
});
cx.background_spawn(async move {
- let _ = copy_task.await.with_context(|| {
+ let result = futures::select! {
+ result = copy_task.fuse() => result,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Copy cancelled by user");
+ }
+ };
+ let _ = result.with_context(|| {
format!(
"Copying {} to {}",
input.source_path, input.destination_path
@@ -1,5 +1,6 @@
use agent_client_protocol::ToolKind;
use anyhow::{Context as _, Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
@@ -64,7 +65,7 @@ impl AgentTool for CreateDirectoryTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let project_path = match self.project.read(cx).find_project_path(&input.path, cx) {
@@ -80,9 +81,14 @@ impl AgentTool for CreateDirectoryTool {
});
cx.spawn(async move |_cx| {
- create_entry
- .await
- .with_context(|| format!("Creating directory {destination_path}"))?;
+ futures::select! {
+ result = create_entry.fuse() => {
+ result.with_context(|| format!("Creating directory {destination_path}"))?;
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Create directory cancelled by user");
+ }
+ }
Ok(format!("Created directory {destination_path}"))
})
@@ -2,7 +2,7 @@ use crate::{AgentTool, ToolCallEventStream};
use action_log::ActionLog;
use agent_client_protocol::ToolKind;
use anyhow::{Context as _, Result, anyhow};
-use futures::{SinkExt, StreamExt, channel::mpsc};
+use futures::{FutureExt as _, SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
@@ -67,7 +67,7 @@ impl AgentTool for DeletePathTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let path = input.path;
@@ -113,7 +113,16 @@ impl AgentTool for DeletePathTool {
let project = self.project.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
- while let Some(path) = paths_rx.next().await {
+ loop {
+ let path_result = futures::select! {
+ path = paths_rx.next().fuse() => path,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Delete cancelled by user");
+ }
+ };
+ let Some(path) = path_result else {
+ break;
+ };
if let Ok(buffer) = project
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
@@ -131,9 +140,15 @@ impl AgentTool for DeletePathTool {
.with_context(|| {
format!("Couldn't delete {path} because that path isn't in this project.")
})?;
- deletion_task
- .await
- .with_context(|| format!("Deleting {path}"))?;
+
+ futures::select! {
+ result = deletion_task.fuse() => {
+ result.with_context(|| format!("Deleting {path}"))?;
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Delete cancelled by user");
+ }
+ }
Ok(format!("Deleted {path}"))
})
}
@@ -1,6 +1,7 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use project::Project;
@@ -89,7 +90,7 @@ impl AgentTool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
match input.path {
@@ -98,13 +99,18 @@ impl AgentTool for DiagnosticsTool {
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
};
- let buffer = self
+ let open_buffer_task = self
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx));
cx.spawn(async move |cx| {
+ let buffer = futures::select! {
+ result = open_buffer_task.fuse() => result?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Diagnostics cancelled by user");
+ }
+ };
let mut output = String::new();
- let buffer = buffer.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
for (_, group) in snapshot.diagnostic_groups(None) {
@@ -7,6 +7,7 @@ use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
+use futures::{FutureExt as _, StreamExt as _};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::language_settings::{self, FormatOnSave};
@@ -18,7 +19,6 @@ use project::{Project, ProjectPath};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
-use smol::stream::StreamExt as _;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -395,7 +395,16 @@ impl AgentTool for EditFileTool {
let mut hallucinated_old_text = false;
let mut ambiguous_ranges = Vec::new();
let mut emitted_location = false;
- while let Some(event) = events.next().await {
+ loop {
+ let event = futures::select! {
+ event = events.next().fuse() => match event {
+ Some(event) => event,
+ None => break,
+ },
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Edit cancelled by user");
+ }
+ };
match event {
EditAgentOutputEvent::Edited(range) => {
if !emitted_location {
@@ -4,7 +4,7 @@ use std::{borrow::Cow, cell::RefCell};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, bail};
-use futures::AsyncReadExt as _;
+use futures::{AsyncReadExt as _, FutureExt as _};
use gpui::{App, AppContext as _, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
@@ -145,7 +145,7 @@ impl AgentTool for FetchTool {
) -> Task<Result<Self::Output>> {
let authorize = event_stream.authorize(input.url.clone(), cx);
- let text = cx.background_spawn({
+ let fetch_task = cx.background_spawn({
let http_client = self.http_client.clone();
async move {
authorize.await?;
@@ -154,7 +154,12 @@ impl AgentTool for FetchTool {
});
cx.foreground_executor().spawn(async move {
- let text = text.await?;
+ let text = futures::select! {
+ result = fetch_task.fuse() => result?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Fetch cancelled by user");
+ }
+ };
if text.trim().is_empty() {
bail!("no textual content found");
}
@@ -1,6 +1,7 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project;
@@ -114,7 +115,12 @@ impl AgentTool for FindPathTool {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
- let matches = search_paths_task.await?;
+ let matches = futures::select! {
+ result = search_paths_task.fuse() => result?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Path search cancelled by user");
+ }
+ };
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
@@ -1,7 +1,7 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
-use futures::StreamExt;
+use futures::{FutureExt as _, StreamExt};
use gpui::{App, Entity, SharedString, Task};
use language::{OffsetRangeExt, ParseStatus, Point};
use project::{
@@ -117,7 +117,7 @@ impl AgentTool for GrepTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
const CONTEXT_LINES: u32 = 2;
@@ -186,7 +186,16 @@ impl AgentTool for GrepTool {
let mut matches_found = 0;
let mut has_more_matches = false;
- 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = rx.next().await {
+ 'outer: loop {
+ let search_result = futures::select! {
+ result = rx.next().fuse() => result,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Search cancelled by user");
+ }
+ };
+ let Some(SearchResult::Buffer { buffer, ranges }) = search_result else {
+ break;
+ };
if ranges.is_empty() {
continue;
}
@@ -1,6 +1,7 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol::ToolKind;
use anyhow::{Context as _, Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
@@ -89,7 +90,7 @@ impl AgentTool for MovePathTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let rename_task = self.project.update(cx, |project, cx| {
@@ -112,7 +113,13 @@ impl AgentTool for MovePathTool {
});
cx.background_spawn(async move {
- let _ = rename_task.await.with_context(|| {
+ let result = futures::select! {
+ result = rename_task.fuse() => result,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Move cancelled by user");
+ }
+ };
+ let _ = result.with_context(|| {
format!("Moving {} to {}", input.source_path, input.destination_path)
})?;
Ok(format!(
@@ -1,6 +1,7 @@
use crate::AgentTool;
use agent_client_protocol::ToolKind;
use anyhow::{Context as _, Result};
+use futures::FutureExt as _;
use gpui::{App, AppContext, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
@@ -67,7 +68,12 @@ impl AgentTool for OpenTool {
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx);
cx.background_spawn(async move {
- authorize.await?;
+ futures::select! {
+ result = authorize.fuse() => result?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Open cancelled by user");
+ }
+ }
match abs_path {
Some(path) => open::that(path),
@@ -1,6 +1,7 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
+use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use indoc::formatdoc;
use language::Point;
@@ -192,13 +193,18 @@ impl AgentTool for ReadFileTool {
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
- let buffer = cx
- .update(|cx| {
- project.update(cx, |project, cx| {
- project.open_buffer(project_path.clone(), cx)
- })
+ let open_buffer_task = cx.update(|cx| {
+ project.update(cx, |project, cx| {
+ project.open_buffer(project_path.clone(), cx)
})
- .await?;
+ });
+
+ let buffer = futures::select! {
+ result = open_buffer_task.fuse() => result?,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("File read cancelled by user");
+ }
+ };
if buffer.read_with(cx, |buffer, _| {
buffer
.file()
@@ -1,6 +1,7 @@
use agent_client_protocol as acp;
use anyhow::Result;
use collections::FxHashSet;
+use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use project::Project;
@@ -61,7 +62,7 @@ impl AgentTool for RestoreFileFromDiskTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let project = self.project.clone();
@@ -88,11 +89,18 @@ impl AgentTool for RestoreFileFromDiskTool {
let open_buffer_task =
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
- let buffer = match open_buffer_task.await {
- Ok(buffer) => buffer,
- Err(error) => {
- open_errors.push((path, error.to_string()));
- continue;
+ let buffer = futures::select! {
+ result = open_buffer_task.fuse() => {
+ match result {
+ Ok(buffer) => buffer,
+ Err(error) => {
+ open_errors.push((path, error.to_string()));
+ continue;
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Restore cancelled by user");
}
};
@@ -111,7 +119,13 @@ impl AgentTool for RestoreFileFromDiskTool {
project.reload_buffers(buffers_to_reload, true, cx)
});
- if let Err(error) = reload_task.await {
+ let result = futures::select! {
+ result = reload_task.fuse() => result,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Restore cancelled by user");
+ }
+ };
+ if let Err(error) = result {
reload_errors.push(error.to_string());
}
}
@@ -1,6 +1,7 @@
use agent_client_protocol as acp;
use anyhow::Result;
use collections::FxHashSet;
+use futures::FutureExt as _;
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use project::Project;
@@ -58,7 +59,7 @@ impl AgentTool for SaveFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- _event_stream: ToolCallEventStream,
+ event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let project = self.project.clone();
@@ -85,11 +86,18 @@ impl AgentTool for SaveFileTool {
let open_buffer_task =
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
- let buffer = match open_buffer_task.await {
- Ok(buffer) => buffer,
- Err(error) => {
- open_errors.push((path, error.to_string()));
- continue;
+ let buffer = futures::select! {
+ result = open_buffer_task.fuse() => {
+ match result {
+ Ok(buffer) => buffer,
+ Err(error) => {
+ open_errors.push((path, error.to_string()));
+ continue;
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Save cancelled by user");
}
};
@@ -116,7 +124,13 @@ impl AgentTool for SaveFileTool {
let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
- if let Err(error) = save_task.await {
+ let save_result = futures::select! {
+ result = save_task.fuse() => result,
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Save cancelled by user");
+ }
+ };
+ if let Err(error) = save_result {
save_errors.push((path_for_buffer, error.to_string()));
}
}
@@ -4,6 +4,7 @@ use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse;
+use futures::FutureExt as _;
use gpui::{App, AppContext, Task};
use language_model::{
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
@@ -73,12 +74,19 @@ impl AgentTool for WebSearchTool {
let search_task = provider.search(input.query, cx);
cx.background_spawn(async move {
- let response = match search_task.await {
- Ok(response) => response,
- Err(err) => {
- event_stream
- .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
- return Err(err);
+ let response = futures::select! {
+ result = search_task.fuse() => {
+ match result {
+ Ok(response) => response,
+ Err(err) => {
+ event_stream
+ .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
+ return Err(err);
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ anyhow::bail!("Web search cancelled by user");
}
};