From c235d539dd720a1e224c4e5cbf2e430da2353e38 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 25 Feb 2026 17:46:27 +0100 Subject: [PATCH] agent: Support streaming tool input (#50099) This PR introduces a `ToolInput` struct which allows tools to receive their inputs incrementally as they stream in. Right now no tool makes use of the streaming APIs, will be used for the streaming edit file tool in #50004 Release Notes: - N/A --- crates/agent/src/tests/mod.rs | 66 ++--- crates/agent/src/tests/test_tools.rs | 89 +++++-- crates/agent/src/thread.rs | 248 +++++++++++++++--- .../src/tools/context_server_registry.rs | 17 +- crates/agent/src/tools/copy_path_tool.rs | 35 ++- .../agent/src/tools/create_directory_tool.rs | 43 +-- crates/agent/src/tools/delete_path_tool.rs | 45 ++-- crates/agent/src/tools/diagnostics_tool.rs | 88 ++++--- crates/agent/src/tools/edit_file_tool.rs | 143 +++++----- crates/agent/src/tools/fetch_tool.rs | 72 ++--- crates/agent/src/tools/find_path_tool.rs | 13 +- crates/agent/src/tools/grep_tool.rs | 114 ++++---- crates/agent/src/tools/list_directory_tool.rs | 185 +++++++++---- crates/agent/src/tools/move_path_tool.rs | 35 ++- crates/agent/src/tools/now_tool.rs | 22 +- crates/agent/src/tools/open_tool.rs | 20 +- crates/agent/src/tools/read_file_tool.rs | 164 +++++++++--- .../src/tools/restore_file_from_disk_tool.rs | 62 +++-- crates/agent/src/tools/save_file_tool.rs | 66 +++-- crates/agent/src/tools/spawn_agent_tool.rs | 119 +++++---- .../src/tools/streaming_edit_file_tool.rs | 129 +++++---- crates/agent/src/tools/terminal_tool.rs | 59 +++-- crates/agent/src/tools/web_search_tool.rs | 75 +++--- .../remote_server/src/remote_editing_tests.rs | 13 +- crates/zed/src/visual_test_runner.rs | 7 +- 25 files changed, 1257 insertions(+), 672 deletions(-) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 139242fdee9da968986b3fc9537bf9e5292b7dc5..e8c95c630b65870bfc8a78b9e965373a2604879d 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -310,11 +310,11 @@ async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sleep 1000".to_string(), cd: ".".to_string(), timeout_ms: Some(5), - }, + }), event_stream, cx, ) @@ -377,11 +377,11 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp let _task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sleep 1000".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -3991,11 +3991,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "rm -rf /".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4043,11 +4043,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "echo hello".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4101,11 +4101,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let _task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sudo rm file".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4148,11 +4148,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "echo hello".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -5309,11 +5309,11 @@ async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit sensitive file".to_string(), path: "root/sensitive_config.txt".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5359,9 +5359,9 @@ async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext let task = cx.update(|cx| { tool.run( - crate::DeletePathToolInput { + ToolInput::resolved(crate::DeletePathToolInput { path: "root/important_data.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5411,10 +5411,10 @@ async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContex let task = cx.update(|cx| { tool.run( - crate::MovePathToolInput { + ToolInput::resolved(crate::MovePathToolInput { source_path: "root/safe.txt".to_string(), destination_path: "root/protected/safe.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5467,10 +5467,10 @@ async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::MovePathToolInput { + ToolInput::resolved(crate::MovePathToolInput { source_path: "root/secret.txt".to_string(), destination_path: "root/public/not_secret.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5525,10 +5525,10 @@ async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::CopyPathToolInput { + ToolInput::resolved(crate::CopyPathToolInput { source_path: "root/confidential.txt".to_string(), destination_path: "root/dest/copy.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5580,12 +5580,12 @@ async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) let task = cx.update(|cx| { tool.run( - crate::SaveFileToolInput { + ToolInput::resolved(crate::SaveFileToolInput { paths: vec![ std::path::PathBuf::from("root/normal.txt"), std::path::PathBuf::from("root/readonly/config.txt"), ], - }, + }), event_stream, cx, ) @@ -5632,9 +5632,9 @@ async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::SaveFileToolInput { + ToolInput::resolved(crate::SaveFileToolInput { paths: vec![std::path::PathBuf::from("root/config.secret")], - }, + }), event_stream, cx, ) @@ -5676,7 +5676,7 @@ async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) { let input: crate::WebSearchToolInput = serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let result = task.await; assert!(result.is_err(), "expected search to be blocked"); @@ -5741,11 +5741,11 @@ async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppConte let _task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit README".to_string(), path: "root/README.md".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5811,11 +5811,11 @@ async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut Tes let (event_stream, mut rx) = crate::ToolCallEventStream::test(); let _task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit local settings".to_string(), path: "root/.zed/settings.json".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5855,7 +5855,7 @@ async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) { let input: crate::FetchToolInput = serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let result = task.await; assert!(result.is_err(), "expected fetch to be blocked"); @@ -5893,7 +5893,7 @@ async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) let input: crate::FetchToolInput = serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap(); - let _task = cx.update(|cx| tool.run(input, event_stream, cx)); + let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); cx.run_until_parked(); diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index 0ed2eef90271538c575cc84b56a28df106e4bd41..e0794ee322cdf2c77c37d1d22f30ec77c5642d24 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -3,6 +3,7 @@ use agent_settings::AgentSettings; use gpui::{App, SharedString, Task}; use std::future; use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] @@ -33,11 +34,17 @@ impl AgentTool for EchoTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - Task::ready(Ok(input.text)) + cx.spawn(async move |_cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + Ok(input.text) + }) } } @@ -74,7 +81,7 @@ impl AgentTool for DelayTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> @@ -83,6 +90,10 @@ impl AgentTool for DelayTool { { let executor = cx.background_executor().clone(); cx.foreground_executor().spawn(async move { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; executor.timer(Duration::from_millis(input.ms)).await; Ok("Ding".to_string()) }) @@ -114,28 +125,38 @@ impl AgentTool for ToolRequiringPermission { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings(Self::NAME, &[String::new()], settings); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = crate::ToolPermissionContext::new( - "tool_requiring_permission", - vec![String::new()], - ); - Some(event_stream.authorize("Authorize?", context, cx)) - } - }; + cx.spawn(async move |cx| { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let decision = cx.update(|cx| { + decide_permission_from_settings( + Self::NAME, + &[String::new()], + AgentSettings::get_global(cx), + ) + }); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => Some(cx.update(|cx| { + let context = crate::ToolPermissionContext::new( + "tool_requiring_permission", + vec![String::new()], + ); + event_stream.authorize("Authorize?", context, cx) + })), + }; - cx.foreground_executor().spawn(async move { if let Some(authorize) = authorize { authorize.await.map_err(|e| e.to_string())?; } @@ -169,11 +190,15 @@ impl AgentTool for InfiniteTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { cx.foreground_executor().spawn(async move { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; future::pending::<()>().await; unreachable!() }) @@ -221,11 +246,15 @@ impl AgentTool for CancellationAwareTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { cx.foreground_executor().spawn(async move { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; // Wait for cancellation - this tool does nothing but wait to be cancelled event_stream.cancelled_by_user().await; self.was_cancelled.store(true, Ordering::SeqCst); @@ -276,10 +305,16 @@ impl AgentTool for WordListTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - Task::ready(Ok("ok".to_string())) + cx.spawn(async move |_cx| { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + Ok("ok".to_string()) + }) } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 5d4de36cb69335de7a77eb7ad7a15f75b8e2b0b7..f9be3bfbeacfd137b06da7dc99eef7ae34422325 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -45,11 +45,13 @@ use language_model::{ use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file}; use smol::stream::StreamExt; use std::{ collections::BTreeMap, + marker::PhantomData, ops::RangeInclusive, path::Path, rc::Rc, @@ -1360,7 +1362,6 @@ impl Thread { self.project.clone(), cx.weak_entity(), language_registry, - Templates::new(), )); self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); self.add_tool(FindPathTool::new(self.project.clone())); @@ -1664,6 +1665,7 @@ impl Thread { event_stream: event_stream.clone(), tools: self.enabled_tools(profile, &model, cx), cancellation_tx, + streaming_tool_inputs: HashMap::default(), _task: cx.spawn(async move |this, cx| { log::debug!("Starting agent turn execution"); @@ -2068,10 +2070,6 @@ impl Thread { self.send_or_update_tool_use(&tool_use, title, kind, event_stream); - if !tool_use.is_input_complete { - return None; - } - let Some(tool) = tool else { let content = format!("No tool named {} exists", tool_use.name); return Some(Task::ready(LanguageModelToolResult { @@ -2083,9 +2081,72 @@ impl Thread { })); }; + if !tool_use.is_input_complete { + if tool.supports_input_streaming() { + let running_turn = self.running_turn.as_mut()?; + if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) { + sender.send_partial(tool_use.input); + return None; + } + + let (sender, tool_input) = ToolInputSender::channel(); + sender.send_partial(tool_use.input); + running_turn + .streaming_tool_inputs + .insert(tool_use.id.clone(), sender); + + let tool = tool.clone(); + log::debug!("Running streaming tool {}", tool_use.name); + return Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )); + } else { + return None; + } + } + + if let Some(sender) = self + .running_turn + .as_mut()? + .streaming_tool_inputs + .remove(&tool_use.id) + { + sender.send_final(tool_use.input); + return None; + } + + log::debug!("Running tool {}", tool_use.name); + let tool_input = ToolInput::ready(tool_use.input); + Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )) + } + + fn run_tool( + &self, + tool: Arc, + tool_input: ToolInput, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + event_stream: &ThreadEventStream, + cancellation_rx: watch::Receiver, + cx: &mut Context, + ) -> Task { let fs = self.project.read(cx).fs().clone(); let tool_event_stream = ToolCallEventStream::new( - tool_use.id.clone(), + tool_use_id.clone(), event_stream.clone(), Some(fs), cancellation_rx, @@ -2094,9 +2155,8 @@ impl Thread { acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress), ); let supports_images = self.model().is_some_and(|model| model.supports_images()); - let tool_result = tool.run(tool_use.input, tool_event_stream, cx); - log::debug!("Running tool {}", tool_use.name); - Some(cx.foreground_executor().spawn(async move { + let tool_result = tool.run(tool_input, tool_event_stream, cx); + cx.foreground_executor().spawn(async move { let (is_error, output) = match tool_result.await { Ok(mut output) => { if let LanguageModelToolResultContent::Image(_) = &output.llm_output @@ -2114,13 +2174,13 @@ impl Thread { }; LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, + tool_use_id, + tool_name, is_error, content: output.llm_output, output: Some(output.raw_output), } - })) + }) } fn handle_tool_use_json_parse_error_event( @@ -2776,6 +2836,9 @@ struct RunningTurn { /// Sender to signal tool cancellation. When cancel is called, this is /// set to true so all tools can detect user-initiated cancellation. cancellation_tx: watch::Sender, + /// Senders for tools that support input streaming and have already been + /// started but are still receiving input from the LLM. + streaming_tool_inputs: HashMap, } impl RunningTurn { @@ -2795,6 +2858,103 @@ pub struct TitleUpdated; impl EventEmitter for Thread {} +/// A channel-based wrapper that delivers tool input to a running tool. +/// +/// For non-streaming tools, created via `ToolInput::ready()` so `.recv()` resolves immediately. +/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams +/// them, followed by the final complete input available through `.recv()`. +pub struct ToolInput { + partial_rx: mpsc::UnboundedReceiver, + final_rx: oneshot::Receiver, + _phantom: PhantomData, +} + +impl ToolInput { + #[cfg(any(test, feature = "test-support"))] + pub fn resolved(input: impl Serialize) -> Self { + let value = serde_json::to_value(input).expect("failed to serialize tool input"); + Self::ready(value) + } + + pub fn ready(value: serde_json::Value) -> Self { + let (partial_tx, partial_rx) = mpsc::unbounded(); + drop(partial_tx); + let (final_tx, final_rx) = oneshot::channel(); + final_tx.send(value).ok(); + Self { + partial_rx, + final_rx, + _phantom: PhantomData, + } + } + + #[cfg(any(test, feature = "test-support"))] + pub fn test() -> (ToolInputSender, Self) { + let (sender, input) = ToolInputSender::channel(); + (sender, input.cast()) + } + + /// Wait for the final deserialized input, ignoring all partial updates. + /// Non-streaming tools can use this to wait until the whole input is available. + pub async fn recv(mut self) -> Result { + // Drain any remaining partials + while self.partial_rx.next().await.is_some() {} + let value = self + .final_rx + .await + .map_err(|_| anyhow!("tool input sender was dropped before sending final input"))?; + serde_json::from_value(value).map_err(Into::into) + } + + /// Returns the next partial JSON snapshot, or `None` when input is complete. + /// Once this returns `None`, call `recv()` to get the final input. + pub async fn recv_partial(&mut self) -> Option { + self.partial_rx.next().await + } + + fn cast(self) -> ToolInput { + ToolInput { + partial_rx: self.partial_rx, + final_rx: self.final_rx, + _phantom: PhantomData, + } + } +} + +pub struct ToolInputSender { + partial_tx: mpsc::UnboundedSender, + final_tx: Option>, +} + +impl ToolInputSender { + pub(crate) fn channel() -> (Self, ToolInput) { + let (partial_tx, partial_rx) = mpsc::unbounded(); + let (final_tx, final_rx) = oneshot::channel(); + let sender = Self { + partial_tx, + final_tx: Some(final_tx), + }; + let input = ToolInput { + partial_rx, + final_rx, + _phantom: PhantomData, + }; + (sender, input) + } + + pub(crate) fn send_partial(&self, value: serde_json::Value) { + self.partial_tx.unbounded_send(value).ok(); + } + + pub(crate) fn send_final(mut self, value: serde_json::Value) { + // Close the partial channel so recv_partial() returns None + self.partial_tx.close_channel(); + if let Some(final_tx) = self.final_tx.take() { + final_tx.send(value).ok(); + } + } +} + pub trait AgentTool where Self: 'static + Sized, @@ -2828,6 +2988,11 @@ where language_model::tool_schema::root_schema_for::(format) } + /// Returns whether the tool supports streaming of tool use parameters. + fn supports_input_streaming() -> bool { + false + } + /// Some tools rely on a provider for the underlying billing or other reasons. /// Allow the tool to check if they are compatible, or should be filtered out. fn supports_provider(_provider: &LanguageModelProviderId) -> bool { @@ -2843,7 +3008,7 @@ where /// still signaling whether the invocation succeeded or failed. fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; @@ -2888,13 +3053,16 @@ pub trait AnyAgentTool { fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supports_input_streaming(&self) -> bool { + false + } fn supports_provider(&self, _provider: &LanguageModelProviderId) -> bool { true } /// See [`AgentTool::run`] for why this returns `Result`. fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; @@ -2923,6 +3091,10 @@ where T::kind() } + fn supports_input_streaming(&self) -> bool { + T::supports_input_streaming() + } + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString { let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); self.0.initial_title(parsed_input, _cx) @@ -2940,35 +3112,31 @@ where fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - cx.spawn(async move |cx| { - let input: T::Input = serde_json::from_value(input).map_err(|e| { - AgentToolOutput::from_error(format!("Failed to parse tool input: {e}")) - })?; - let task = cx.update(|cx| self.0.clone().run(input, event_stream, cx)); - match task.await { - Ok(output) => { - let raw_output = serde_json::to_value(&output).map_err(|e| { - AgentToolOutput::from_error(format!("Failed to serialize tool output: {e}")) - })?; - Ok(AgentToolOutput { - llm_output: output.into(), - raw_output, - }) - } - Err(error_output) => { - let raw_output = serde_json::to_value(&error_output).unwrap_or_else(|e| { - log::error!("Failed to serialize tool error output: {e}"); - serde_json::Value::Null - }); - Err(AgentToolOutput { - llm_output: error_output.into(), - raw_output, - }) - } + let tool_input: ToolInput = input.cast(); + let task = self.0.clone().run(tool_input, event_stream, cx); + cx.spawn(async move |_cx| match task.await { + Ok(output) => { + let raw_output = serde_json::to_value(&output).map_err(|e| { + AgentToolOutput::from_error(format!("Failed to serialize tool output: {e}")) + })?; + Ok(AgentToolOutput { + llm_output: output.into(), + raw_output, + }) + } + Err(error_output) => { + let raw_output = serde_json::to_value(&error_output).unwrap_or_else(|e| { + log::error!("Failed to serialize tool error output: {e}"); + serde_json::Value::Null + }); + Err(AgentToolOutput { + llm_output: error_output.into(), + raw_output, + }) } }) } diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 694e28750cd69facc49b7a0bf862203a00043b4c..1c7590d8097a5de50b879d5b253c5dbabd3dcbab 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -1,4 +1,4 @@ -use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; +use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol::ToolKind; use anyhow::Result; use collections::{BTreeMap, HashMap}; @@ -329,7 +329,7 @@ impl AnyAgentTool for ContextServerTool { fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { @@ -339,14 +339,15 @@ impl AnyAgentTool for ContextServerTool { let tool_name = self.tool.name.clone(); let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name); let display_name = self.tool.name.clone(); - let authorize = event_stream.authorize_third_party_tool( - self.initial_title(input.clone(), cx), - tool_id, - display_name, - cx, - ); + let initial_title = self.initial_title(serde_json::Value::Null, cx); + let authorize = + event_stream.authorize_third_party_tool(initial_title, tool_id, display_name, cx); cx.spawn(async move |_cx| { + let input = input.recv().await.map_err(|e| { + AgentToolOutput::from_error(format!("Failed to receive tool input: {e}")) + })?; + authorize.await.map_err(|e| AgentToolOutput::from_error(e.to_string()))?; let Some(protocol) = server.client() else { diff --git a/crates/agent/src/tools/copy_path_tool.rs b/crates/agent/src/tools/copy_path_tool.rs index c82d9e930e1987d389ece84347c1a0f43c601182..7f53a5c36a7979a01de96535f19e421fa3119e16 100644 --- a/crates/agent/src/tools/copy_path_tool.rs +++ b/crates/agent/src/tools/copy_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_paths}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_paths, +}; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; use futures::FutureExt as _; @@ -79,19 +81,24 @@ impl AgentTool for CopyPathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let paths = vec![input.source_path.clone(), input.destination_path.clone()]; - let decision = decide_permission_for_paths(Self::NAME, &paths, settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let paths = vec![input.source_path.clone(), input.destination_path.clone()]; + let decision = cx.update(|cx| { + decide_permission_for_paths(Self::NAME, &paths, &AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -248,7 +255,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -302,7 +309,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; drop(auth); @@ -354,7 +361,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -430,7 +437,9 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let result = cx.update(|cx| tool.run(input, event_stream, cx)).await; + let result = cx + .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)) + .await; assert!(result.is_err(), "Tool should fail when policy denies"); assert!( diff --git a/crates/agent/src/tools/create_directory_tool.rs b/crates/agent/src/tools/create_directory_tool.rs index 500b5f00289db245898d5918a79dc684a6f0f110..5d8930f3c7400428d55cfe7d14bafc16d94be43a 100644 --- a/crates/agent/src/tools/create_directory_tool.rs +++ b/crates/agent/src/tools/create_directory_tool.rs @@ -13,7 +13,9 @@ use settings::Settings; use std::sync::Arc; use util::markdown::MarkdownInlineCode; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; use std::path::Path; /// Creates a new directory at the specified path within the project. Returns confirmation that the directory was created. @@ -68,21 +70,26 @@ impl AgentTool for CreateDirectoryTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_for_path(Self::NAME, &input.path, settings); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &input.path, AgentSettings::get_global(cx)) + }); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } - let destination_path: Arc = input.path.as_str().into(); + let destination_path: Arc = input.path.as_str().into(); - let project = self.project.clone(); - cx.spawn(async move |cx| { let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -218,9 +225,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -277,9 +284,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -336,9 +343,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -415,9 +422,9 @@ mod tests { let result = cx .update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/delete_path_tool.rs b/crates/agent/src/tools/delete_path_tool.rs index 048f4bd8292077874b49bd74b09cbea38b4fafc5..27ab68db667a4cf3223e6521682814dc1c245bb7 100644 --- a/crates/agent/src/tools/delete_path_tool.rs +++ b/crates/agent/src/tools/delete_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_access, canonicalize_worktree_roots, detect_symlink_escape, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; use action_log::ActionLog; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; @@ -71,22 +73,27 @@ impl AgentTool for DeletePathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let path = input.path; - - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_for_path(Self::NAME, &path, settings); - - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let path = input.path; + + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path, AgentSettings::get_global(cx)) + }); + + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -278,9 +285,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -345,9 +352,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -405,9 +412,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -488,9 +495,9 @@ mod tests { let result = cx .update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/diagnostics_tool.rs b/crates/agent/src/tools/diagnostics_tool.rs index fea16d531ed5f4212e6b1374aee04cee67b2fc0b..5889f66c2edbe06055678b19474447e0f23e2b0f 100644 --- a/crates/agent/src/tools/diagnostics_tool.rs +++ b/crates/agent/src/tools/diagnostics_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::Result; use futures::FutureExt as _; @@ -87,21 +87,27 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - match input.path { - Some(path) if !path.is_empty() => { - let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else { - return Task::ready(Err(format!("Could not find path {path} in project"))); - }; - - let open_buffer_task = self - .project - .update(cx, |project, cx| project.open_buffer(project_path, cx)); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + match input.path { + Some(path) if !path.is_empty() => { + let (_project_path, open_buffer_task) = project.update(cx, |project, cx| { + let Some(project_path) = project.find_project_path(&path, cx) else { + return Err(format!("Could not find path {path} in project")); + }; + let task = project.open_buffer(project_path.clone(), cx); + Ok((project_path, task)) + })?; - cx.spawn(async move |cx| { let buffer = futures::select! { result = open_buffer_task.fuse() => result.map_err(|e| e.to_string())?, _ = event_stream.cancelled_by_user().fuse() => { @@ -135,36 +141,40 @@ impl AgentTool for DiagnosticsTool { } else { Ok(output) } - }) - } - _ => { - let project = self.project.read(cx); - let mut output = String::new(); - let mut has_diagnostics = false; - - for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { - if summary.error_count > 0 || summary.warning_count > 0 { - let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) - else { - continue; - }; - - has_diagnostics = true; - output.push_str(&format!( - "{}: {} error(s), {} warning(s)\n", - worktree.read(cx).absolutize(&project_path.path).display(), - summary.error_count, - summary.warning_count - )); - } } + _ => { + let (output, has_diagnostics) = project.read_with(cx, |project, cx| { + let mut output = String::new(); + let mut has_diagnostics = false; + + for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { + if summary.error_count > 0 || summary.warning_count > 0 { + let Some(worktree) = + project.worktree_for_id(project_path.worktree_id, cx) + else { + continue; + }; + + has_diagnostics = true; + output.push_str(&format!( + "{}: {} error(s), {} warning(s)\n", + worktree.read(cx).absolutize(&project_path.path).display(), + summary.error_count, + summary.warning_count + )); + } + } + + (output, has_diagnostics) + }); - if has_diagnostics { - Task::ready(Ok(output)) - } else { - Task::ready(Ok("No errors or warnings found in the project.".into())) + if has_diagnostics { + Ok(output) + } else { + Ok("No errors or warnings found in the project.".into()) + } } } - } + }) } } diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 788bf06529a6f0b87242379ffcdb83f38e4c7126..3e1e0661f126d464c8d4611e2b3d85a9f668a5ca 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -2,7 +2,7 @@ use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use super::tool_permissions::authorize_file_edit; use crate::{ - AgentTool, Templates, Thread, ToolCallEventStream, + AgentTool, Templates, Thread, ToolCallEventStream, ToolInput, edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}, }; use acp_thread::Diff; @@ -237,39 +237,44 @@ impl AgentTool for EditFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Ok(project) = self - .thread - .read_with(cx, |thread, _cx| thread.project().clone()) - else { - return Task::ready(Err(EditFileToolOutput::Error { - error: "thread was dropped".to_string(), - })); - }; - let project_path = match resolve_path(&input, project.clone(), cx) { - Ok(path) => path, - Err(err) => { - return Task::ready(Err(EditFileToolOutput::Error { - error: err.to_string(), - })); - } - }; - let abs_path = project.read(cx).absolute_path(&project_path, cx); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields( - ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]), - ); - } - let allow_thinking = self - .thread - .read_with(cx, |thread, _cx| thread.thinking_enabled()) - .unwrap_or(true); - - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { + let input = input.recv().await.map_err(|e| EditFileToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let project = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + .map_err(|_| EditFileToolOutput::Error { + error: "thread was dropped".to_string(), + })?; + + let (project_path, abs_path, allow_thinking, authorize) = + cx.update(|cx| { + let project_path = resolve_path(&input, project.clone(), cx).map_err(|err| { + EditFileToolOutput::Error { + error: err.to_string(), + } + })?; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![acp::ToolCallLocation::new(abs_path)]), + ); + } + let allow_thinking = self + .thread + .read_with(cx, |thread, _cx| thread.thinking_enabled()) + .unwrap_or(true); + let authorize = self.authorize(&input, &event_stream, cx); + Ok::<_, EditFileToolOutput>((project_path, abs_path, allow_thinking, authorize)) + })?; + let result: anyhow::Result = async { authorize.await?; @@ -672,7 +677,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -881,7 +890,11 @@ mod tests { language_registry.clone(), Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -940,7 +953,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -1027,7 +1044,11 @@ mod tests { language_registry.clone(), Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -1082,7 +1103,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -2081,11 +2106,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2111,11 +2136,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2139,11 +2164,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2199,11 +2224,11 @@ mod tests { // Read the file to record the read time cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2227,11 +2252,11 @@ mod tests { // Read the file again - should update the entry cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2298,11 +2323,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2314,11 +2339,11 @@ mod tests { let edit_result = { let edit_task = cx.update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "First edit".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2343,11 +2368,11 @@ mod tests { let edit_result = { let edit_task = cx.update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Second edit".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2412,11 +2437,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2456,11 +2481,11 @@ mod tests { let result = cx .update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit after external change".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2523,11 +2548,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2560,11 +2585,11 @@ mod tests { let result = cx .update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit with dirty buffer".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) diff --git a/crates/agent/src/tools/fetch_tool.rs b/crates/agent/src/tools/fetch_tool.rs index e573c2202b09d1283d75c3eda20b65be1bcd82a7..75880801595ad0604c9f3a1fac58bd916809a8ba 100644 --- a/crates/agent/src/tools/fetch_tool.rs +++ b/crates/agent/src/tools/fetch_tool.rs @@ -16,7 +16,8 @@ use ui::SharedString; use util::markdown::{MarkdownEscaped, MarkdownInlineCode}; use crate::{ - AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings, + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, + decide_permission_from_settings, }; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] @@ -141,41 +142,52 @@ impl AgentTool for FetchTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = - decide_permission_from_settings(Self::NAME, std::slice::from_ref(&input.url), settings); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]); - Some(event_stream.authorize( - format!("Fetch {}", MarkdownInlineCode(&input.url)), - context, - cx, - )) - } - }; + let http_client = self.http_client.clone(); + cx.spawn(async move |cx| { + let input: FetchToolInput = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let decision = cx.update(|cx| { + decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.url), + AgentSettings::get_global(cx), + ) + }); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => Some(cx.update(|cx| { + let context = + crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]); + event_stream.authorize( + format!("Fetch {}", MarkdownInlineCode(&input.url)), + context, + cx, + ) + })), + }; - let fetch_task = cx.background_spawn({ - let http_client = self.http_client.clone(); - async move { - if let Some(authorize) = authorize { - authorize.await?; + let fetch_task = cx.background_spawn({ + let http_client = http_client.clone(); + let url = input.url.clone(); + async move { + if let Some(authorize) = authorize { + authorize.await?; + } + Self::build_message(http_client, &url).await } - Self::build_message(http_client, &input.url).await - } - }); + }); - cx.foreground_executor().spawn(async move { let text = futures::select! { result = fetch_task.fuse() => result.map_err(|e| e.to_string())?, _ = event_stream.cancelled_by_user().fuse() => { diff --git a/crates/agent/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs index 4ba60c61063c08ac002dc7dc16fa11b987cbab74..9c65461503225171bcda482d58871a94743481e3 100644 --- a/crates/agent/src/tools/find_path_tool.rs +++ b/crates/agent/src/tools/find_path_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use futures::FutureExt as _; @@ -121,13 +121,18 @@ impl AgentTool for FindPathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input.recv().await.map_err(|e| FindPathToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let search_paths_task = cx.update(|cx| search_paths(&input.glob, project, cx)); - cx.background_spawn(async move { let matches = futures::select! { result = search_paths_task.fuse() => result.map_err(|e| FindPathToolOutput::Error { error: e.to_string() })?, _ = event_stream.cancelled_by_user().fuse() => { diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index 16162107dff84ab40117b7783e04b633d144a214..fbfdc18585b822361effb6fd770e678b3e434a17 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::Result; use futures::{FutureExt as _, StreamExt}; @@ -114,66 +114,64 @@ impl AgentTool for GrepTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; const MAX_ANCESTOR_LINES: u32 = 10; - let path_style = self.project.read(cx).path_style(cx); - - let include_matcher = match PathMatcher::new( - input - .include_pattern - .as_ref() - .into_iter() - .collect::>(), - path_style, - ) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(format!("invalid include glob pattern: {error}"))); - } - }; - - // Exclude global file_scan_exclusions and private_files settings - let exclude_matcher = { - let global_settings = WorktreeSettings::get_global(cx); - let exclude_patterns = global_settings - .file_scan_exclusions - .sources() - .chain(global_settings.private_files.sources()); - - match PathMatcher::new(exclude_patterns, path_style) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(format!("invalid exclude pattern: {error}"))); - } - } - }; - - let query = match SearchQuery::regex( - &input.regex, - false, - input.case_sensitive, - false, - false, - include_matcher, - exclude_matcher, - true, // Always match file include pattern against *full project paths* that start with a project root. - None, - ) { - Ok(query) => query, - Err(error) => return Task::ready(Err(error.to_string())), - }; - - let results = self - .project - .update(cx, |project, cx| project.search(query, cx)); - - let project = self.project.downgrade(); + let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let results = cx.update(|cx| { + let path_style = project.read(cx).path_style(cx); + + let include_matcher = PathMatcher::new( + input + .include_pattern + .as_ref() + .into_iter() + .collect::>(), + path_style, + ) + .map_err(|error| format!("invalid include glob pattern: {error}"))?; + + // Exclude global file_scan_exclusions and private_files settings + let exclude_matcher = { + let global_settings = WorktreeSettings::get_global(cx); + let exclude_patterns = global_settings + .file_scan_exclusions + .sources() + .chain(global_settings.private_files.sources()); + + PathMatcher::new(exclude_patterns, path_style) + .map_err(|error| format!("invalid exclude pattern: {error}"))? + }; + + let query = SearchQuery::regex( + &input.regex, + false, + input.case_sensitive, + false, + false, + include_matcher, + exclude_matcher, + true, // Always match file include pattern against *full project paths* that start with a project root. + None, + ) + .map_err(|error| error.to_string())?; + + Ok::<_, String>( + project.update(cx, |project, cx| project.search(query, cx)), + ) + })?; + + let project = project.downgrade(); // Keep the search alive for the duration of result iteration. Dropping this task is the // cancellation mechanism; we intentionally do not detach it. let SearchResults {rx, _task_handle} = results; @@ -787,7 +785,13 @@ mod tests { cx: &mut TestAppContext, ) -> String { let tool = Arc::new(GrepTool { project }); - let task = cx.update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)); + let task = cx.update(|cx| { + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }); match task.await { Ok(result) => { diff --git a/crates/agent/src/tools/list_directory_tool.rs b/crates/agent/src/tools/list_directory_tool.rs index 5dddee94904283ccb9198ce56aa4005250b5908a..1a674aaa71fef5bf9c11688e82982a5dbcfee331 100644 --- a/crates/agent/src/tools/list_directory_tool.rs +++ b/crates/agent/src/tools/list_directory_tool.rs @@ -2,7 +2,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; use gpui::{App, Entity, SharedString, Task}; @@ -146,34 +146,39 @@ impl AgentTool for ListDirectoryTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - // Sometimes models will return these even though we tell it to give a path and not a glob. - // When this happens, just list the root worktree directories. - if matches!(input.path.as_str(), "." | "" | "./" | "*") { - let output = self - .project - .read(cx) - .worktrees(cx) - .filter_map(|worktree| { - let worktree = worktree.read(cx); - let root_entry = worktree.root_entry()?; - if root_entry.is_dir() { - Some(root_entry.path.display(worktree.path_style())) - } else { - None - } - }) - .collect::>() - .join("\n"); - - return Task::ready(Ok(output)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Sometimes models will return these even though we tell it to give a path and not a glob. + // When this happens, just list the root worktree directories. + if matches!(input.path.as_str(), "." | "" | "./" | "*") { + let output = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter_map(|worktree| { + let worktree = worktree.read(cx); + let root_entry = worktree.root_entry()?; + if root_entry.is_dir() { + Some(root_entry.path.display(worktree.path_style())) + } else { + None + } + }) + .collect::>() + .join("\n") + }); + + return Ok(output); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -323,7 +328,13 @@ mod tests { path: "project".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!( @@ -344,7 +355,13 @@ mod tests { path: "project/src".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!( @@ -365,7 +382,13 @@ mod tests { path: "project/tests".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(!output.contains("# Folders:")); @@ -393,7 +416,13 @@ mod tests { path: "project/empty_dir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!(output, "project/empty_dir is empty.\n"); @@ -420,7 +449,13 @@ mod tests { path: "project/nonexistent".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("Path not found")); @@ -429,7 +464,13 @@ mod tests { path: "project/file.txt".into(), }; let output = cx - .update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("is not a directory")); } @@ -493,7 +534,13 @@ mod tests { path: "project".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); @@ -520,7 +567,13 @@ mod tests { path: "project/.secretdir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!( output.unwrap_err().contains("file_scan_exclusions"), @@ -532,7 +585,13 @@ mod tests { path: "project/visible_dir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); @@ -637,7 +696,13 @@ mod tests { path: "worktree1/src".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("main.rs"), "Should list main.rs"); @@ -655,7 +720,13 @@ mod tests { path: "worktree1/tests".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("test.rs"), "Should list test.rs"); @@ -669,7 +740,13 @@ mod tests { path: "worktree2/lib".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("public.js"), "Should list public.js"); @@ -687,7 +764,13 @@ mod tests { path: "worktree2/docs".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("README.md"), "Should list README.md"); @@ -701,7 +784,13 @@ mod tests { path: "worktree1/src/secret.rs".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("Cannot list directory"),); } @@ -743,9 +832,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -804,9 +893,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -871,9 +960,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -924,9 +1013,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/src".into(), - }, + }), event_stream, cx, ) @@ -981,9 +1070,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_dir".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/move_path_tool.rs b/crates/agent/src/tools/move_path_tool.rs index 4c337d0ec2827ad7c63c87ef206f0e74dc63091f..c246b3c5b0661546f4617bb5521766f9da3839fb 100644 --- a/crates/agent/src/tools/move_path_tool.rs +++ b/crates/agent/src/tools/move_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_paths}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_paths, +}; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; use futures::FutureExt as _; @@ -92,19 +94,24 @@ impl AgentTool for MovePathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let paths = vec![input.source_path.clone(), input.destination_path.clone()]; - let decision = decide_permission_for_paths(Self::NAME, &paths, settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let paths = vec![input.source_path.clone(), input.destination_path.clone()]; + let decision = cx.update(|cx| { + decide_permission_for_paths(Self::NAME, &paths, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -255,7 +262,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -309,7 +316,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; drop(auth); @@ -361,7 +368,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -437,7 +444,9 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let result = cx.update(|cx| tool.run(input, event_stream, cx)).await; + let result = cx + .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)) + .await; assert!(result.is_err(), "Tool should fail when policy denies"); assert!( diff --git a/crates/agent/src/tools/now_tool.rs b/crates/agent/src/tools/now_tool.rs index 689d70ff20d15cbc56fcc0e14a3b46408647f737..fe1cafe5881d14c9700813f742e1f2df0aa1203e 100644 --- a/crates/agent/src/tools/now_tool.rs +++ b/crates/agent/src/tools/now_tool.rs @@ -6,7 +6,7 @@ use gpui::{App, SharedString, Task}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -48,14 +48,20 @@ impl AgentTool for NowTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - let now = match input.timezone { - Timezone::Utc => Utc::now().to_rfc3339(), - Timezone::Local => Local::now().to_rfc3339(), - }; - Task::ready(Ok(format!("The current datetime is {now}."))) + cx.spawn(async move |_cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let now = match input.timezone { + Timezone::Utc => Utc::now().to_rfc3339(), + Timezone::Local => Local::now().to_rfc3339(), + }; + Ok(format!("The current datetime is {now}.")) + }) } } diff --git a/crates/agent/src/tools/open_tool.rs b/crates/agent/src/tools/open_tool.rs index c0b24efbec6418c437e9e3d14ffb5d40b45c91b0..344a513d10c2d62e4247dd3e47bcdf428586d6f0 100644 --- a/crates/agent/src/tools/open_tool.rs +++ b/crates/agent/src/tools/open_tool.rs @@ -2,7 +2,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::AgentTool; +use crate::{AgentTool, ToolInput}; use agent_client_protocol::ToolKind; use futures::FutureExt as _; use gpui::{App, AppContext as _, Entity, SharedString, Task}; @@ -61,16 +61,24 @@ impl AgentTool for OpenTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: crate::ToolCallEventStream, cx: &mut App, ) -> Task> { - // If path_or_url turns out to be a path in the project, make it absolute. - let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); - let initial_title = self.initial_title(Ok(input.clone()), cx); - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // If path_or_url turns out to be a path in the project, make it absolute. + let (abs_path, initial_title) = cx.update(|cx| { + let abs_path = to_absolute_path(&input.path_or_url, project.clone(), cx); + let initial_title = self.initial_title(Ok(input.clone()), cx); + (abs_path, initial_title) + }); + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index efd33fe5caece4cee4fc02aab8c1a0ebee92f94e..bbc67cf68c7d104772c18ad222478621ce4d7a54 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -21,7 +21,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, Thread, ToolCallEventStream, outline}; +use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline}; /// Reads the content of the given file in the project. /// @@ -114,7 +114,7 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { @@ -122,6 +122,10 @@ impl AgentTool for ReadFileTool { let thread = self.thread.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(tool_content_err)?; let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -398,7 +402,7 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream, cx) + tool.run(ToolInput::resolved(input), event_stream, cx) }) .await; assert_eq!( @@ -442,7 +446,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "This is a small file content".into()); @@ -485,7 +493,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -510,7 +522,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -570,7 +586,11 @@ mod test { start_line: Some(2), end_line: Some(4), }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4\n".into()); @@ -613,7 +633,11 @@ mod test { start_line: Some(0), end_line: Some(2), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 1\nLine 2\n".into()); @@ -626,7 +650,11 @@ mod test { start_line: Some(1), end_line: Some(0), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 1\n".into()); @@ -639,7 +667,11 @@ mod test { start_line: Some(3), end_line: Some(2), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 3\n".into()); @@ -744,7 +776,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -760,7 +796,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -776,7 +816,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -791,7 +835,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -807,7 +855,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -822,7 +874,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -837,7 +893,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -853,7 +913,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!(result.is_ok(), "Should be able to read normal files"); @@ -867,7 +931,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -911,11 +979,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let read_task = cx.update(|cx| { tool.run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "root/secret.png".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1039,7 +1107,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -1057,7 +1129,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1075,7 +1151,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1093,7 +1173,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -1111,7 +1195,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1129,7 +1217,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1148,7 +1240,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1210,11 +1306,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1286,11 +1382,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1367,11 +1463,11 @@ mod test { let result = cx .update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/restore_file_from_disk_tool.rs b/crates/agent/src/tools/restore_file_from_disk_tool.rs index 304e0d1180fe626482206bfdc2dfa6d53f529816..c1aa8690a840ea6911dcb94c26c8cef3cb5f313d 100644 --- a/crates/agent/src/tools/restore_file_from_disk_tool.rs +++ b/crates/agent/src/tools/restore_file_from_disk_tool.rs @@ -17,7 +17,9 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use util::markdown::MarkdownInlineCode; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; /// Discards unsaved changes in open buffers by reloading file contents from disk. /// @@ -66,25 +68,31 @@ impl AgentTool for RestoreFileFromDiskTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx).clone(); - - // Check for any immediate deny before spawning async work. - for path in &input.paths { - let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - } - let project = self.project.clone(); - let input_paths = input.paths; cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Check for any immediate deny before doing async work. + for path in &input.paths { + let path_str = path.to_string_lossy(); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + } + + let input_paths = input.paths; + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -92,7 +100,9 @@ impl AgentTool for RestoreFileFromDiskTool { for path in &input_paths { let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); let symlink_escape = project.read_with(cx, |project, cx| { path_has_symlink_escape(project, path, &canonical_roots, cx) }); @@ -378,12 +388,12 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![ PathBuf::from("root/dirty.txt"), PathBuf::from("root/clean.txt"), ], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -428,7 +438,7 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { paths: vec![] }, + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![] }), ToolCallEventStream::test().0, cx, ) @@ -441,9 +451,9 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("nonexistent/path.txt")], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -495,9 +505,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -564,9 +574,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -623,9 +633,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/save_file_tool.rs b/crates/agent/src/tools/save_file_tool.rs index 20140c77d113d96c741d5afbe672882f708870d6..99e937b9dff2a1b4781dde16bd2bf6d64edd25ad 100644 --- a/crates/agent/src/tools/save_file_tool.rs +++ b/crates/agent/src/tools/save_file_tool.rs @@ -17,7 +17,9 @@ use super::tool_permissions::{ canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; /// Saves files that have unsaved changes. /// @@ -63,25 +65,31 @@ impl AgentTool for SaveFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx).clone(); - - // Check for any immediate deny before spawning async work. - for path in &input.paths { - let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - } - let project = self.project.clone(); - let input_paths = input.paths; cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Check for any immediate deny before doing async work. + for path in &input.paths { + let path_str = path.to_string_lossy(); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + } + + let input_paths = input.paths; + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -89,7 +97,9 @@ impl AgentTool for SaveFileTool { for path in &input_paths { let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); let symlink_escape = project.read_with(cx, |project, cx| { path_has_symlink_escape(project, path, &canonical_roots, cx) }); @@ -382,12 +392,12 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![ PathBuf::from("root/dirty.txt"), PathBuf::from("root/clean.txt"), ], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -425,7 +435,7 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { paths: vec![] }, + ToolInput::resolved(SaveFileToolInput { paths: vec![] }), ToolCallEventStream::test().0, cx, ) @@ -438,9 +448,9 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("nonexistent/path.txt")], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -490,9 +500,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -559,9 +569,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -618,9 +628,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -702,12 +712,12 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![ PathBuf::from("project/dirty.txt"), PathBuf::from("project/link.txt"), ], - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/spawn_agent_tool.rs b/crates/agent/src/tools/spawn_agent_tool.rs index e2dd78d4476de48465cb5c48e225e2ae5a0a7767..69529282544cc35a01f792dcb45df6eb8bdf67d5 100644 --- a/crates/agent/src/tools/spawn_agent_tool.rs +++ b/crates/agent/src/tools/spawn_agent_tool.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use std::rc::Rc; use std::sync::Arc; -use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream}; +use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream, ToolInput}; /// Spawns an agent to perform a delegated task. /// @@ -97,61 +97,78 @@ impl AgentTool for SpawnAgentTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Some(parent_thread_entity) = self.parent_thread.upgrade() else { - return Task::ready(Err(SpawnAgentToolOutput::Error { - session_id: None, - error: "Parent thread no longer exists".to_string(), - })); - }; - - let subagent = if let Some(session_id) = input.session_id { - self.environment - .resume_subagent(parent_thread_entity, session_id, input.message, cx) - } else { - self.environment - .create_subagent(parent_thread_entity, input.label, input.message, cx) - }; - let subagent = match subagent { - Ok(subagent) => subagent, - Err(err) => { - return Task::ready(Err(SpawnAgentToolOutput::Error { + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| SpawnAgentToolOutput::Error { + session_id: None, + error: format!("Failed to receive tool input: {e}"), + })?; + + let (subagent, subagent_session_id) = cx.update(|cx| { + let Some(parent_thread_entity) = self.parent_thread.upgrade() else { + return Err(SpawnAgentToolOutput::Error { + session_id: None, + error: "Parent thread no longer exists".to_string(), + }); + }; + + let subagent = if let Some(session_id) = input.session_id { + self.environment.resume_subagent( + parent_thread_entity, + session_id, + input.message, + cx, + ) + } else { + self.environment.create_subagent( + parent_thread_entity, + input.label, + input.message, + cx, + ) + }; + let subagent = subagent.map_err(|err| SpawnAgentToolOutput::Error { session_id: None, error: err.to_string(), - })); - } - }; - let subagent_session_id = subagent.id(); - - event_stream.subagent_spawned(subagent_session_id.clone()); - let meta = acp::Meta::from_iter([( - SUBAGENT_SESSION_ID_META_KEY.into(), - subagent_session_id.to_string().into(), - )]); - event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); - - cx.spawn(async move |cx| match subagent.wait_for_output(cx).await { - Ok(output) => { - event_stream.update_fields( - acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), - ); - Ok(SpawnAgentToolOutput::Success { - session_id: subagent_session_id, - output, - }) - } - Err(e) => { - let error = e.to_string(); - event_stream.update_fields( - acp::ToolCallUpdateFields::new().content(vec![error.clone().into()]), - ); - Err(SpawnAgentToolOutput::Error { - session_id: Some(subagent_session_id), - error, - }) + })?; + let subagent_session_id = subagent.id(); + + event_stream.subagent_spawned(subagent_session_id.clone()); + let meta = acp::Meta::from_iter([( + SUBAGENT_SESSION_ID_META_KEY.into(), + subagent_session_id.to_string().into(), + )]); + event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); + + Ok((subagent, subagent_session_id)) + })?; + + match subagent.wait_for_output(cx).await { + Ok(output) => { + event_stream.update_fields( + acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), + ); + Ok(SpawnAgentToolOutput::Success { + session_id: subagent_session_id, + output, + }) + } + Err(e) => { + let error = e.to_string(); + event_stream.update_fields( + acp::ToolCallUpdateFields::new().content(vec![error.clone().into()]), + ); + Err(SpawnAgentToolOutput::Error { + session_id: Some(subagent_session_id), + error, + }) + } } }) } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index dd5445142a001fbd9106af548444165bc8331581..95651b44bac44ad3cc67c25c0ef13fc885342ce3 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -2,7 +2,7 @@ use super::edit_file_tool::EditFileTool; use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use crate::{ - AgentTool, Templates, Thread, ToolCallEventStream, + AgentTool, Thread, ToolCallEventStream, ToolInput, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher, }; use acp_thread::Diff; @@ -164,8 +164,6 @@ pub struct StreamingEditFileTool { thread: WeakEntity, language_registry: Arc, project: Entity, - #[allow(dead_code)] - templates: Arc, } impl StreamingEditFileTool { @@ -173,13 +171,11 @@ impl StreamingEditFileTool { project: Entity, thread: WeakEntity, language_registry: Arc, - templates: Arc, ) -> Self { Self { project, thread, language_registry, - templates, } } @@ -188,7 +184,6 @@ impl StreamingEditFileTool { project: self.project.clone(), thread: new_thread, language_registry: self.language_registry.clone(), - templates: self.templates.clone(), } } @@ -268,38 +263,41 @@ impl AgentTool for StreamingEditFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Ok(project) = self - .thread - .read_with(cx, |thread, _cx| thread.project().clone()) - else { - return Task::ready(Err(StreamingEditFileToolOutput::Error { - error: "thread was dropped".to_string(), - })); - }; - - let project_path = match resolve_path(&input, project.clone(), cx) { - Ok(path) => path, - Err(err) => { - return Task::ready(Err(StreamingEditFileToolOutput::Error { - error: err.to_string(), - })); - } - }; - - let abs_path = project.read(cx).absolute_path(&project_path, cx); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields( - ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]), - ); - } - - let authorize = self.authorize(&input, &event_stream, cx); - cx.spawn(async move |cx: &mut AsyncApp| { + let input = input.recv().await.map_err(|e| { + StreamingEditFileToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + } + })?; + + let project = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + .map_err(|_| StreamingEditFileToolOutput::Error { + error: "thread was dropped".to_string(), + })?; + + let (project_path, abs_path, authorize) = cx.update(|cx| { + let project_path = + resolve_path(&input, project.clone(), cx).map_err(|err| { + StreamingEditFileToolOutput::Error { + error: err.to_string(), + } + })?; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![acp::ToolCallLocation::new(abs_path)]), + ); + } + let authorize = self.authorize(&input, &event_stream, cx); + Ok::<_, StreamingEditFileToolOutput>((project_path, abs_path, authorize)) + })?; let result: anyhow::Result = async { authorize.await?; @@ -787,9 +785,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -836,9 +837,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -896,9 +900,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -958,9 +965,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1023,9 +1033,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1088,9 +1101,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1141,9 +1157,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1192,9 +1211,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1262,9 +1284,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; diff --git a/crates/agent/src/tools/terminal_tool.rs b/crates/agent/src/tools/terminal_tool.rs index 57b3278da256c01408f704a8e2f6f7e075057597..6396bd1b0e63b46a0207dd7df9b9f2fcd00176b7 100644 --- a/crates/agent/src/tools/terminal_tool.rs +++ b/crates/agent/src/tools/terminal_tool.rs @@ -15,7 +15,7 @@ use std::{ }; use crate::{ - AgentTool, ThreadEnvironment, ToolCallEventStream, ToolPermissionDecision, + AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_from_settings, }; @@ -85,34 +85,45 @@ impl AgentTool for TerminalTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let working_dir = match working_dir(&input, &self.project, cx) { - Ok(dir) => dir, - Err(err) => return Task::ready(Err(err.to_string())), - }; + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings( - Self::NAME, - std::slice::from_ref(&input.command), - settings, - ); + let (working_dir, authorize) = cx.update(|cx| { + let working_dir = + working_dir(&input, &self.project, cx).map_err(|err| err.to_string())?; - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.command.clone()]); - Some(event_stream.authorize(self.initial_title(Ok(input.clone()), cx), context, cx)) - } - }; - cx.spawn(async move |cx| { + let decision = decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.command), + AgentSettings::get_global(cx), + ); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => { + let context = crate::ToolPermissionContext::new( + Self::NAME, + vec![input.command.clone()], + ); + Some(event_stream.authorize( + self.initial_title(Ok(input.clone()), cx), + context, + cx, + )) + } + }; + Ok((working_dir, authorize)) + })?; if let Some(authorize) = authorize { authorize.await.map_err(|e| e.to_string())?; } diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index c536f45ba65c109d3068b0722db1ffb1cad8b87c..c697a5b78f1fe8c84d6ed58db13f651a493ae8c3 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -1,14 +1,15 @@ use std::sync::Arc; use crate::{ - AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings, + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, + decide_permission_from_settings, }; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::Result; use cloud_llm_client::WebSearchResponse; use futures::FutureExt as _; -use gpui::{App, AppContext, Task}; +use gpui::{App, Task}; use language_model::{ LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, }; @@ -73,41 +74,51 @@ impl AgentTool for WebSearchTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings( - Self::NAME, - std::slice::from_ref(&input.query), - settings, - ); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(WebSearchToolOutput::Error { error: reason })); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]); - Some(event_stream.authorize( - format!("Search the web for {}", MarkdownInlineCode(&input.query)), - context, - cx, - )) - } - }; + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| WebSearchToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let (authorize, search_task) = cx.update(|cx| { + let decision = decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.query), + AgentSettings::get_global(cx), + ); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(WebSearchToolOutput::Error { error: reason }); + } + ToolPermissionDecision::Confirm => { + let context = + crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]); + Some(event_stream.authorize( + format!("Search the web for {}", MarkdownInlineCode(&input.query)), + context, + cx, + )) + } + }; + + let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { + return Err(WebSearchToolOutput::Error { + error: "Web search is not available.".to_string(), + }); + }; - let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { - return Task::ready(Err(WebSearchToolOutput::Error { - error: "Web search is not available.".to_string(), - })); - }; + let search_task = provider.search(input.query, cx); + Ok((authorize, search_task)) + })?; - let search_task = provider.search(input.query, cx); - cx.background_spawn(async move { if let Some(authorize) = authorize { authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?; } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 9d673182bc64e192e6db13a927392d611c53407d..f15382b67557fa9a9b0eda2a9d4438aa33c7cff3 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,7 +2,9 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. use crate::headless_project::HeadlessProject; -use agent::{AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream}; +use agent::{ + AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput, +}; use client::{Client, UserStore}; use clock::FakeSystemClock; use collections::{HashMap, HashSet}; @@ -1962,7 +1964,11 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); let (event_stream, _) = ToolCallEventStream::test(); - let exists_result = cx.update(|cx| read_tool.clone().run(input, event_stream.clone(), cx)); + let exists_result = cx.update(|cx| { + read_tool + .clone() + .run(ToolInput::resolved(input), event_stream.clone(), cx) + }); let output = exists_result.await.unwrap(); assert_eq!(output, LanguageModelToolResultContent::Text("B".into())); @@ -1971,7 +1977,8 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu start_line: None, end_line: None, }; - let does_not_exist_result = cx.update(|cx| read_tool.run(input, event_stream, cx)); + let does_not_exist_result = + cx.update(|cx| read_tool.run(ToolInput::resolved(input), event_stream, cx)); does_not_exist_result.await.unwrap_err(); } diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 09340dcec641ae2a6c1ea871e770886d14276529..b7471321db203075ac6c71eee0b3ef29c5edaefc 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -1962,7 +1962,7 @@ fn run_agent_thread_view_test( cx: &mut VisualTestAppContext, update_baseline: bool, ) -> Result { - use agent::AgentTool; + use agent::{AgentTool, ToolInput}; use agent_ui::AgentPanel; // Create a temporary directory with the test image @@ -2047,7 +2047,10 @@ fn run_agent_thread_view_test( start_line: None, end_line: None, }; - let run_task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let run_task = cx.update(|cx| { + tool.clone() + .run(ToolInput::resolved(input), event_stream, cx) + }); cx.background_executor.allow_parking(); let run_result = cx.foreground_executor.block_test(run_task);