diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 139242fdee9da968986b3fc9537bf9e5292b7dc5..e8c95c630b65870bfc8a78b9e965373a2604879d 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -310,11 +310,11 @@ async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sleep 1000".to_string(), cd: ".".to_string(), timeout_ms: Some(5), - }, + }), event_stream, cx, ) @@ -377,11 +377,11 @@ async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAp let _task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sleep 1000".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -3991,11 +3991,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "rm -rf /".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4043,11 +4043,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "echo hello".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4101,11 +4101,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let _task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "sudo rm file".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -4148,11 +4148,11 @@ async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::TerminalToolInput { + ToolInput::resolved(crate::TerminalToolInput { command: "echo hello".to_string(), cd: ".".to_string(), timeout_ms: None, - }, + }), event_stream, cx, ) @@ -5309,11 +5309,11 @@ async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit sensitive file".to_string(), path: "root/sensitive_config.txt".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5359,9 +5359,9 @@ async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext let task = cx.update(|cx| { tool.run( - crate::DeletePathToolInput { + ToolInput::resolved(crate::DeletePathToolInput { path: "root/important_data.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5411,10 +5411,10 @@ async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContex let task = cx.update(|cx| { tool.run( - crate::MovePathToolInput { + ToolInput::resolved(crate::MovePathToolInput { source_path: "root/safe.txt".to_string(), destination_path: "root/protected/safe.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5467,10 +5467,10 @@ async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::MovePathToolInput { + ToolInput::resolved(crate::MovePathToolInput { source_path: "root/secret.txt".to_string(), destination_path: "root/public/not_secret.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5525,10 +5525,10 @@ async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::CopyPathToolInput { + ToolInput::resolved(crate::CopyPathToolInput { source_path: "root/confidential.txt".to_string(), destination_path: "root/dest/copy.txt".to_string(), - }, + }), event_stream, cx, ) @@ -5580,12 +5580,12 @@ async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) let task = cx.update(|cx| { tool.run( - crate::SaveFileToolInput { + ToolInput::resolved(crate::SaveFileToolInput { paths: vec![ std::path::PathBuf::from("root/normal.txt"), std::path::PathBuf::from("root/readonly/config.txt"), ], - }, + }), event_stream, cx, ) @@ -5632,9 +5632,9 @@ async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) { let task = cx.update(|cx| { tool.run( - crate::SaveFileToolInput { + ToolInput::resolved(crate::SaveFileToolInput { paths: vec![std::path::PathBuf::from("root/config.secret")], - }, + }), event_stream, cx, ) @@ -5676,7 +5676,7 @@ async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) { let input: crate::WebSearchToolInput = serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let result = task.await; assert!(result.is_err(), "expected search to be blocked"); @@ -5741,11 +5741,11 @@ async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppConte let _task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit README".to_string(), path: "root/README.md".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5811,11 +5811,11 @@ async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut Tes let (event_stream, mut rx) = crate::ToolCallEventStream::test(); let _task = cx.update(|cx| { tool.run( - crate::EditFileToolInput { + ToolInput::resolved(crate::EditFileToolInput { display_description: "Edit local settings".to_string(), path: "root/.zed/settings.json".into(), mode: crate::EditFileMode::Edit, - }, + }), event_stream, cx, ) @@ -5855,7 +5855,7 @@ async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) { let input: crate::FetchToolInput = serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let result = task.await; assert!(result.is_err(), "expected fetch to be blocked"); @@ -5893,7 +5893,7 @@ async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) let input: crate::FetchToolInput = serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap(); - let _task = cx.update(|cx| tool.run(input, event_stream, cx)); + let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); cx.run_until_parked(); diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index 0ed2eef90271538c575cc84b56a28df106e4bd41..e0794ee322cdf2c77c37d1d22f30ec77c5642d24 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -3,6 +3,7 @@ use agent_settings::AgentSettings; use gpui::{App, SharedString, Task}; use std::future; use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] @@ -33,11 +34,17 @@ impl AgentTool for EchoTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - Task::ready(Ok(input.text)) + cx.spawn(async move |_cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + Ok(input.text) + }) } } @@ -74,7 +81,7 @@ impl AgentTool for DelayTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> @@ -83,6 +90,10 @@ impl AgentTool for DelayTool { { let executor = cx.background_executor().clone(); cx.foreground_executor().spawn(async move { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; executor.timer(Duration::from_millis(input.ms)).await; Ok("Ding".to_string()) }) @@ -114,28 +125,38 @@ impl AgentTool for ToolRequiringPermission { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings(Self::NAME, &[String::new()], settings); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = crate::ToolPermissionContext::new( - "tool_requiring_permission", - vec![String::new()], - ); - Some(event_stream.authorize("Authorize?", context, cx)) - } - }; + cx.spawn(async move |cx| { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let decision = cx.update(|cx| { + decide_permission_from_settings( + Self::NAME, + &[String::new()], + AgentSettings::get_global(cx), + ) + }); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => Some(cx.update(|cx| { + let context = crate::ToolPermissionContext::new( + "tool_requiring_permission", + vec![String::new()], + ); + event_stream.authorize("Authorize?", context, cx) + })), + }; - cx.foreground_executor().spawn(async move { if let Some(authorize) = authorize { authorize.await.map_err(|e| e.to_string())?; } @@ -169,11 +190,15 @@ impl AgentTool for InfiniteTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { cx.foreground_executor().spawn(async move { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; future::pending::<()>().await; unreachable!() }) @@ -221,11 +246,15 @@ impl AgentTool for CancellationAwareTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { cx.foreground_executor().spawn(async move { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; // Wait for cancellation - this tool does nothing but wait to be cancelled event_stream.cancelled_by_user().await; self.was_cancelled.store(true, Ordering::SeqCst); @@ -276,10 +305,16 @@ impl AgentTool for WordListTool { fn run( self: Arc, - _input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - Task::ready(Ok("ok".to_string())) + cx.spawn(async move |_cx| { + let _input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + Ok("ok".to_string()) + }) } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 5d4de36cb69335de7a77eb7ad7a15f75b8e2b0b7..f9be3bfbeacfd137b06da7dc99eef7ae34422325 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -45,11 +45,13 @@ use language_model::{ use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file}; use smol::stream::StreamExt; use std::{ collections::BTreeMap, + marker::PhantomData, ops::RangeInclusive, path::Path, rc::Rc, @@ -1360,7 +1362,6 @@ impl Thread { self.project.clone(), cx.weak_entity(), language_registry, - Templates::new(), )); self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); self.add_tool(FindPathTool::new(self.project.clone())); @@ -1664,6 +1665,7 @@ impl Thread { event_stream: event_stream.clone(), tools: self.enabled_tools(profile, &model, cx), cancellation_tx, + streaming_tool_inputs: HashMap::default(), _task: cx.spawn(async move |this, cx| { log::debug!("Starting agent turn execution"); @@ -2068,10 +2070,6 @@ impl Thread { self.send_or_update_tool_use(&tool_use, title, kind, event_stream); - if !tool_use.is_input_complete { - return None; - } - let Some(tool) = tool else { let content = format!("No tool named {} exists", tool_use.name); return Some(Task::ready(LanguageModelToolResult { @@ -2083,9 +2081,72 @@ impl Thread { })); }; + if !tool_use.is_input_complete { + if tool.supports_input_streaming() { + let running_turn = self.running_turn.as_mut()?; + if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) { + sender.send_partial(tool_use.input); + return None; + } + + let (sender, tool_input) = ToolInputSender::channel(); + sender.send_partial(tool_use.input); + running_turn + .streaming_tool_inputs + .insert(tool_use.id.clone(), sender); + + let tool = tool.clone(); + log::debug!("Running streaming tool {}", tool_use.name); + return Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )); + } else { + return None; + } + } + + if let Some(sender) = self + .running_turn + .as_mut()? + .streaming_tool_inputs + .remove(&tool_use.id) + { + sender.send_final(tool_use.input); + return None; + } + + log::debug!("Running tool {}", tool_use.name); + let tool_input = ToolInput::ready(tool_use.input); + Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )) + } + + fn run_tool( + &self, + tool: Arc, + tool_input: ToolInput, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + event_stream: &ThreadEventStream, + cancellation_rx: watch::Receiver, + cx: &mut Context, + ) -> Task { let fs = self.project.read(cx).fs().clone(); let tool_event_stream = ToolCallEventStream::new( - tool_use.id.clone(), + tool_use_id.clone(), event_stream.clone(), Some(fs), cancellation_rx, @@ -2094,9 +2155,8 @@ impl Thread { acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress), ); let supports_images = self.model().is_some_and(|model| model.supports_images()); - let tool_result = tool.run(tool_use.input, tool_event_stream, cx); - log::debug!("Running tool {}", tool_use.name); - Some(cx.foreground_executor().spawn(async move { + let tool_result = tool.run(tool_input, tool_event_stream, cx); + cx.foreground_executor().spawn(async move { let (is_error, output) = match tool_result.await { Ok(mut output) => { if let LanguageModelToolResultContent::Image(_) = &output.llm_output @@ -2114,13 +2174,13 @@ impl Thread { }; LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, + tool_use_id, + tool_name, is_error, content: output.llm_output, output: Some(output.raw_output), } - })) + }) } fn handle_tool_use_json_parse_error_event( @@ -2776,6 +2836,9 @@ struct RunningTurn { /// Sender to signal tool cancellation. When cancel is called, this is /// set to true so all tools can detect user-initiated cancellation. cancellation_tx: watch::Sender, + /// Senders for tools that support input streaming and have already been + /// started but are still receiving input from the LLM. + streaming_tool_inputs: HashMap, } impl RunningTurn { @@ -2795,6 +2858,103 @@ pub struct TitleUpdated; impl EventEmitter for Thread {} +/// A channel-based wrapper that delivers tool input to a running tool. +/// +/// For non-streaming tools, created via `ToolInput::ready()` so `.recv()` resolves immediately. +/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams +/// them, followed by the final complete input available through `.recv()`. +pub struct ToolInput { + partial_rx: mpsc::UnboundedReceiver, + final_rx: oneshot::Receiver, + _phantom: PhantomData, +} + +impl ToolInput { + #[cfg(any(test, feature = "test-support"))] + pub fn resolved(input: impl Serialize) -> Self { + let value = serde_json::to_value(input).expect("failed to serialize tool input"); + Self::ready(value) + } + + pub fn ready(value: serde_json::Value) -> Self { + let (partial_tx, partial_rx) = mpsc::unbounded(); + drop(partial_tx); + let (final_tx, final_rx) = oneshot::channel(); + final_tx.send(value).ok(); + Self { + partial_rx, + final_rx, + _phantom: PhantomData, + } + } + + #[cfg(any(test, feature = "test-support"))] + pub fn test() -> (ToolInputSender, Self) { + let (sender, input) = ToolInputSender::channel(); + (sender, input.cast()) + } + + /// Wait for the final deserialized input, ignoring all partial updates. + /// Non-streaming tools can use this to wait until the whole input is available. + pub async fn recv(mut self) -> Result { + // Drain any remaining partials + while self.partial_rx.next().await.is_some() {} + let value = self + .final_rx + .await + .map_err(|_| anyhow!("tool input sender was dropped before sending final input"))?; + serde_json::from_value(value).map_err(Into::into) + } + + /// Returns the next partial JSON snapshot, or `None` when input is complete. + /// Once this returns `None`, call `recv()` to get the final input. + pub async fn recv_partial(&mut self) -> Option { + self.partial_rx.next().await + } + + fn cast(self) -> ToolInput { + ToolInput { + partial_rx: self.partial_rx, + final_rx: self.final_rx, + _phantom: PhantomData, + } + } +} + +pub struct ToolInputSender { + partial_tx: mpsc::UnboundedSender, + final_tx: Option>, +} + +impl ToolInputSender { + pub(crate) fn channel() -> (Self, ToolInput) { + let (partial_tx, partial_rx) = mpsc::unbounded(); + let (final_tx, final_rx) = oneshot::channel(); + let sender = Self { + partial_tx, + final_tx: Some(final_tx), + }; + let input = ToolInput { + partial_rx, + final_rx, + _phantom: PhantomData, + }; + (sender, input) + } + + pub(crate) fn send_partial(&self, value: serde_json::Value) { + self.partial_tx.unbounded_send(value).ok(); + } + + pub(crate) fn send_final(mut self, value: serde_json::Value) { + // Close the partial channel so recv_partial() returns None + self.partial_tx.close_channel(); + if let Some(final_tx) = self.final_tx.take() { + final_tx.send(value).ok(); + } + } +} + pub trait AgentTool where Self: 'static + Sized, @@ -2828,6 +2988,11 @@ where language_model::tool_schema::root_schema_for::(format) } + /// Returns whether the tool supports streaming of tool use parameters. + fn supports_input_streaming() -> bool { + false + } + /// Some tools rely on a provider for the underlying billing or other reasons. /// Allow the tool to check if they are compatible, or should be filtered out. fn supports_provider(_provider: &LanguageModelProviderId) -> bool { @@ -2843,7 +3008,7 @@ where /// still signaling whether the invocation succeeded or failed. fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; @@ -2888,13 +3053,16 @@ pub trait AnyAgentTool { fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supports_input_streaming(&self) -> bool { + false + } fn supports_provider(&self, _provider: &LanguageModelProviderId) -> bool { true } /// See [`AgentTool::run`] for why this returns `Result`. fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; @@ -2923,6 +3091,10 @@ where T::kind() } + fn supports_input_streaming(&self) -> bool { + T::supports_input_streaming() + } + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString { let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); self.0.initial_title(parsed_input, _cx) @@ -2940,35 +3112,31 @@ where fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - cx.spawn(async move |cx| { - let input: T::Input = serde_json::from_value(input).map_err(|e| { - AgentToolOutput::from_error(format!("Failed to parse tool input: {e}")) - })?; - let task = cx.update(|cx| self.0.clone().run(input, event_stream, cx)); - match task.await { - Ok(output) => { - let raw_output = serde_json::to_value(&output).map_err(|e| { - AgentToolOutput::from_error(format!("Failed to serialize tool output: {e}")) - })?; - Ok(AgentToolOutput { - llm_output: output.into(), - raw_output, - }) - } - Err(error_output) => { - let raw_output = serde_json::to_value(&error_output).unwrap_or_else(|e| { - log::error!("Failed to serialize tool error output: {e}"); - serde_json::Value::Null - }); - Err(AgentToolOutput { - llm_output: error_output.into(), - raw_output, - }) - } + let tool_input: ToolInput = input.cast(); + let task = self.0.clone().run(tool_input, event_stream, cx); + cx.spawn(async move |_cx| match task.await { + Ok(output) => { + let raw_output = serde_json::to_value(&output).map_err(|e| { + AgentToolOutput::from_error(format!("Failed to serialize tool output: {e}")) + })?; + Ok(AgentToolOutput { + llm_output: output.into(), + raw_output, + }) + } + Err(error_output) => { + let raw_output = serde_json::to_value(&error_output).unwrap_or_else(|e| { + log::error!("Failed to serialize tool error output: {e}"); + serde_json::Value::Null + }); + Err(AgentToolOutput { + llm_output: error_output.into(), + raw_output, + }) } }) } diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 694e28750cd69facc49b7a0bf862203a00043b4c..1c7590d8097a5de50b879d5b253c5dbabd3dcbab 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -1,4 +1,4 @@ -use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream}; +use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol::ToolKind; use anyhow::Result; use collections::{BTreeMap, HashMap}; @@ -329,7 +329,7 @@ impl AnyAgentTool for ContextServerTool { fn run( self: Arc, - input: serde_json::Value, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { @@ -339,14 +339,15 @@ impl AnyAgentTool for ContextServerTool { let tool_name = self.tool.name.clone(); let tool_id = mcp_tool_id(&self.server_id.0, &self.tool.name); let display_name = self.tool.name.clone(); - let authorize = event_stream.authorize_third_party_tool( - self.initial_title(input.clone(), cx), - tool_id, - display_name, - cx, - ); + let initial_title = self.initial_title(serde_json::Value::Null, cx); + let authorize = + event_stream.authorize_third_party_tool(initial_title, tool_id, display_name, cx); cx.spawn(async move |_cx| { + let input = input.recv().await.map_err(|e| { + AgentToolOutput::from_error(format!("Failed to receive tool input: {e}")) + })?; + authorize.await.map_err(|e| AgentToolOutput::from_error(e.to_string()))?; let Some(protocol) = server.client() else { diff --git a/crates/agent/src/tools/copy_path_tool.rs b/crates/agent/src/tools/copy_path_tool.rs index c82d9e930e1987d389ece84347c1a0f43c601182..7f53a5c36a7979a01de96535f19e421fa3119e16 100644 --- a/crates/agent/src/tools/copy_path_tool.rs +++ b/crates/agent/src/tools/copy_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_paths}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_paths, +}; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; use futures::FutureExt as _; @@ -79,19 +81,24 @@ impl AgentTool for CopyPathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let paths = vec![input.source_path.clone(), input.destination_path.clone()]; - let decision = decide_permission_for_paths(Self::NAME, &paths, settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let paths = vec![input.source_path.clone(), input.destination_path.clone()]; + let decision = cx.update(|cx| { + decide_permission_for_paths(Self::NAME, &paths, &AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -248,7 +255,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -302,7 +309,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; drop(auth); @@ -354,7 +361,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -430,7 +437,9 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let result = cx.update(|cx| tool.run(input, event_stream, cx)).await; + let result = cx + .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)) + .await; assert!(result.is_err(), "Tool should fail when policy denies"); assert!( diff --git a/crates/agent/src/tools/create_directory_tool.rs b/crates/agent/src/tools/create_directory_tool.rs index 500b5f00289db245898d5918a79dc684a6f0f110..5d8930f3c7400428d55cfe7d14bafc16d94be43a 100644 --- a/crates/agent/src/tools/create_directory_tool.rs +++ b/crates/agent/src/tools/create_directory_tool.rs @@ -13,7 +13,9 @@ use settings::Settings; use std::sync::Arc; use util::markdown::MarkdownInlineCode; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; use std::path::Path; /// Creates a new directory at the specified path within the project. Returns confirmation that the directory was created. @@ -68,21 +70,26 @@ impl AgentTool for CreateDirectoryTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_for_path(Self::NAME, &input.path, settings); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &input.path, AgentSettings::get_global(cx)) + }); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } - let destination_path: Arc = input.path.as_str().into(); + let destination_path: Arc = input.path.as_str().into(); - let project = self.project.clone(); - cx.spawn(async move |cx| { let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -218,9 +225,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -277,9 +284,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -336,9 +343,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -415,9 +422,9 @@ mod tests { let result = cx .update(|cx| { tool.run( - CreateDirectoryToolInput { + ToolInput::resolved(CreateDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/delete_path_tool.rs b/crates/agent/src/tools/delete_path_tool.rs index 048f4bd8292077874b49bd74b09cbea38b4fafc5..27ab68db667a4cf3223e6521682814dc1c245bb7 100644 --- a/crates/agent/src/tools/delete_path_tool.rs +++ b/crates/agent/src/tools/delete_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_access, canonicalize_worktree_roots, detect_symlink_escape, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; use action_log::ActionLog; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; @@ -71,22 +73,27 @@ impl AgentTool for DeletePathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let path = input.path; - - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_for_path(Self::NAME, &path, settings); - - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let path = input.path; + + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path, AgentSettings::get_global(cx)) + }); + + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -278,9 +285,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -345,9 +352,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -405,9 +412,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -488,9 +495,9 @@ mod tests { let result = cx .update(|cx| { tool.run( - DeletePathToolInput { + ToolInput::resolved(DeletePathToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/diagnostics_tool.rs b/crates/agent/src/tools/diagnostics_tool.rs index fea16d531ed5f4212e6b1374aee04cee67b2fc0b..5889f66c2edbe06055678b19474447e0f23e2b0f 100644 --- a/crates/agent/src/tools/diagnostics_tool.rs +++ b/crates/agent/src/tools/diagnostics_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::Result; use futures::FutureExt as _; @@ -87,21 +87,27 @@ impl AgentTool for DiagnosticsTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - match input.path { - Some(path) if !path.is_empty() => { - let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else { - return Task::ready(Err(format!("Could not find path {path} in project"))); - }; - - let open_buffer_task = self - .project - .update(cx, |project, cx| project.open_buffer(project_path, cx)); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + match input.path { + Some(path) if !path.is_empty() => { + let (_project_path, open_buffer_task) = project.update(cx, |project, cx| { + let Some(project_path) = project.find_project_path(&path, cx) else { + return Err(format!("Could not find path {path} in project")); + }; + let task = project.open_buffer(project_path.clone(), cx); + Ok((project_path, task)) + })?; - cx.spawn(async move |cx| { let buffer = futures::select! { result = open_buffer_task.fuse() => result.map_err(|e| e.to_string())?, _ = event_stream.cancelled_by_user().fuse() => { @@ -135,36 +141,40 @@ impl AgentTool for DiagnosticsTool { } else { Ok(output) } - }) - } - _ => { - let project = self.project.read(cx); - let mut output = String::new(); - let mut has_diagnostics = false; - - for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { - if summary.error_count > 0 || summary.warning_count > 0 { - let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) - else { - continue; - }; - - has_diagnostics = true; - output.push_str(&format!( - "{}: {} error(s), {} warning(s)\n", - worktree.read(cx).absolutize(&project_path.path).display(), - summary.error_count, - summary.warning_count - )); - } } + _ => { + let (output, has_diagnostics) = project.read_with(cx, |project, cx| { + let mut output = String::new(); + let mut has_diagnostics = false; + + for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { + if summary.error_count > 0 || summary.warning_count > 0 { + let Some(worktree) = + project.worktree_for_id(project_path.worktree_id, cx) + else { + continue; + }; + + has_diagnostics = true; + output.push_str(&format!( + "{}: {} error(s), {} warning(s)\n", + worktree.read(cx).absolutize(&project_path.path).display(), + summary.error_count, + summary.warning_count + )); + } + } + + (output, has_diagnostics) + }); - if has_diagnostics { - Task::ready(Ok(output)) - } else { - Task::ready(Ok("No errors or warnings found in the project.".into())) + if has_diagnostics { + Ok(output) + } else { + Ok("No errors or warnings found in the project.".into()) + } } } - } + }) } } diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 788bf06529a6f0b87242379ffcdb83f38e4c7126..3e1e0661f126d464c8d4611e2b3d85a9f668a5ca 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -2,7 +2,7 @@ use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use super::tool_permissions::authorize_file_edit; use crate::{ - AgentTool, Templates, Thread, ToolCallEventStream, + AgentTool, Templates, Thread, ToolCallEventStream, ToolInput, edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}, }; use acp_thread::Diff; @@ -237,39 +237,44 @@ impl AgentTool for EditFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Ok(project) = self - .thread - .read_with(cx, |thread, _cx| thread.project().clone()) - else { - return Task::ready(Err(EditFileToolOutput::Error { - error: "thread was dropped".to_string(), - })); - }; - let project_path = match resolve_path(&input, project.clone(), cx) { - Ok(path) => path, - Err(err) => { - return Task::ready(Err(EditFileToolOutput::Error { - error: err.to_string(), - })); - } - }; - let abs_path = project.read(cx).absolute_path(&project_path, cx); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields( - ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]), - ); - } - let allow_thinking = self - .thread - .read_with(cx, |thread, _cx| thread.thinking_enabled()) - .unwrap_or(true); - - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { + let input = input.recv().await.map_err(|e| EditFileToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let project = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + .map_err(|_| EditFileToolOutput::Error { + error: "thread was dropped".to_string(), + })?; + + let (project_path, abs_path, allow_thinking, authorize) = + cx.update(|cx| { + let project_path = resolve_path(&input, project.clone(), cx).map_err(|err| { + EditFileToolOutput::Error { + error: err.to_string(), + } + })?; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![acp::ToolCallLocation::new(abs_path)]), + ); + } + let allow_thinking = self + .thread + .read_with(cx, |thread, _cx| thread.thinking_enabled()) + .unwrap_or(true); + let authorize = self.authorize(&input, &event_stream, cx); + Ok::<_, EditFileToolOutput>((project_path, abs_path, allow_thinking, authorize)) + })?; + let result: anyhow::Result = async { authorize.await?; @@ -672,7 +677,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -881,7 +890,11 @@ mod tests { language_registry.clone(), Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -940,7 +953,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -1027,7 +1044,11 @@ mod tests { language_registry.clone(), Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -1082,7 +1103,11 @@ mod tests { language_registry, Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -2081,11 +2106,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2111,11 +2136,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2139,11 +2164,11 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { tool.run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), mode: EditFileMode::Edit, - }, + }), stream_tx, cx, ) @@ -2199,11 +2224,11 @@ mod tests { // Read the file to record the read time cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2227,11 +2252,11 @@ mod tests { // Read the file again - should update the entry cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2298,11 +2323,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2314,11 +2339,11 @@ mod tests { let edit_result = { let edit_task = cx.update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "First edit".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2343,11 +2368,11 @@ mod tests { let edit_result = { let edit_task = cx.update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Second edit".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2412,11 +2437,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2456,11 +2481,11 @@ mod tests { let result = cx .update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit after external change".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2523,11 +2548,11 @@ mod tests { // Read the file first cx.update(|cx| { read_tool.clone().run( - crate::ReadFileToolInput { + ToolInput::resolved(crate::ReadFileToolInput { path: "root/test.txt".to_string(), start_line: None, end_line: None, - }, + }), ToolCallEventStream::test().0, cx, ) @@ -2560,11 +2585,11 @@ mod tests { let result = cx .update(|cx| { edit_tool.clone().run( - EditFileToolInput { + ToolInput::resolved(EditFileToolInput { display_description: "Edit with dirty buffer".into(), path: "root/test.txt".into(), mode: EditFileMode::Edit, - }, + }), ToolCallEventStream::test().0, cx, ) diff --git a/crates/agent/src/tools/fetch_tool.rs b/crates/agent/src/tools/fetch_tool.rs index e573c2202b09d1283d75c3eda20b65be1bcd82a7..75880801595ad0604c9f3a1fac58bd916809a8ba 100644 --- a/crates/agent/src/tools/fetch_tool.rs +++ b/crates/agent/src/tools/fetch_tool.rs @@ -16,7 +16,8 @@ use ui::SharedString; use util::markdown::{MarkdownEscaped, MarkdownInlineCode}; use crate::{ - AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings, + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, + decide_permission_from_settings, }; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] @@ -141,41 +142,52 @@ impl AgentTool for FetchTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = - decide_permission_from_settings(Self::NAME, std::slice::from_ref(&input.url), settings); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]); - Some(event_stream.authorize( - format!("Fetch {}", MarkdownInlineCode(&input.url)), - context, - cx, - )) - } - }; + let http_client = self.http_client.clone(); + cx.spawn(async move |cx| { + let input: FetchToolInput = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let decision = cx.update(|cx| { + decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.url), + AgentSettings::get_global(cx), + ) + }); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => Some(cx.update(|cx| { + let context = + crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]); + event_stream.authorize( + format!("Fetch {}", MarkdownInlineCode(&input.url)), + context, + cx, + ) + })), + }; - let fetch_task = cx.background_spawn({ - let http_client = self.http_client.clone(); - async move { - if let Some(authorize) = authorize { - authorize.await?; + let fetch_task = cx.background_spawn({ + let http_client = http_client.clone(); + let url = input.url.clone(); + async move { + if let Some(authorize) = authorize { + authorize.await?; + } + Self::build_message(http_client, &url).await } - Self::build_message(http_client, &input.url).await - } - }); + }); - cx.foreground_executor().spawn(async move { let text = futures::select! { result = fetch_task.fuse() => result.map_err(|e| e.to_string())?, _ = event_stream.cancelled_by_user().fuse() => { diff --git a/crates/agent/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs index 4ba60c61063c08ac002dc7dc16fa11b987cbab74..9c65461503225171bcda482d58871a94743481e3 100644 --- a/crates/agent/src/tools/find_path_tool.rs +++ b/crates/agent/src/tools/find_path_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use futures::FutureExt as _; @@ -121,13 +121,18 @@ impl AgentTool for FindPathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); + let project = self.project.clone(); + cx.spawn(async move |cx| { + let input = input.recv().await.map_err(|e| FindPathToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let search_paths_task = cx.update(|cx| search_paths(&input.glob, project, cx)); - cx.background_spawn(async move { let matches = futures::select! { result = search_paths_task.fuse() => result.map_err(|e| FindPathToolOutput::Error { error: e.to_string() })?, _ = event_stream.cancelled_by_user().fuse() => { diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index 16162107dff84ab40117b7783e04b633d144a214..fbfdc18585b822361effb6fd770e678b3e434a17 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -1,4 +1,4 @@ -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol as acp; use anyhow::Result; use futures::{FutureExt as _, StreamExt}; @@ -114,66 +114,64 @@ impl AgentTool for GrepTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { const CONTEXT_LINES: u32 = 2; const MAX_ANCESTOR_LINES: u32 = 10; - let path_style = self.project.read(cx).path_style(cx); - - let include_matcher = match PathMatcher::new( - input - .include_pattern - .as_ref() - .into_iter() - .collect::>(), - path_style, - ) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(format!("invalid include glob pattern: {error}"))); - } - }; - - // Exclude global file_scan_exclusions and private_files settings - let exclude_matcher = { - let global_settings = WorktreeSettings::get_global(cx); - let exclude_patterns = global_settings - .file_scan_exclusions - .sources() - .chain(global_settings.private_files.sources()); - - match PathMatcher::new(exclude_patterns, path_style) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(format!("invalid exclude pattern: {error}"))); - } - } - }; - - let query = match SearchQuery::regex( - &input.regex, - false, - input.case_sensitive, - false, - false, - include_matcher, - exclude_matcher, - true, // Always match file include pattern against *full project paths* that start with a project root. - None, - ) { - Ok(query) => query, - Err(error) => return Task::ready(Err(error.to_string())), - }; - - let results = self - .project - .update(cx, |project, cx| project.search(query, cx)); - - let project = self.project.downgrade(); + let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + let results = cx.update(|cx| { + let path_style = project.read(cx).path_style(cx); + + let include_matcher = PathMatcher::new( + input + .include_pattern + .as_ref() + .into_iter() + .collect::>(), + path_style, + ) + .map_err(|error| format!("invalid include glob pattern: {error}"))?; + + // Exclude global file_scan_exclusions and private_files settings + let exclude_matcher = { + let global_settings = WorktreeSettings::get_global(cx); + let exclude_patterns = global_settings + .file_scan_exclusions + .sources() + .chain(global_settings.private_files.sources()); + + PathMatcher::new(exclude_patterns, path_style) + .map_err(|error| format!("invalid exclude pattern: {error}"))? + }; + + let query = SearchQuery::regex( + &input.regex, + false, + input.case_sensitive, + false, + false, + include_matcher, + exclude_matcher, + true, // Always match file include pattern against *full project paths* that start with a project root. + None, + ) + .map_err(|error| error.to_string())?; + + Ok::<_, String>( + project.update(cx, |project, cx| project.search(query, cx)), + ) + })?; + + let project = project.downgrade(); // Keep the search alive for the duration of result iteration. Dropping this task is the // cancellation mechanism; we intentionally do not detach it. let SearchResults {rx, _task_handle} = results; @@ -787,7 +785,13 @@ mod tests { cx: &mut TestAppContext, ) -> String { let tool = Arc::new(GrepTool { project }); - let task = cx.update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)); + let task = cx.update(|cx| { + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }); match task.await { Ok(result) => { diff --git a/crates/agent/src/tools/list_directory_tool.rs b/crates/agent/src/tools/list_directory_tool.rs index 5dddee94904283ccb9198ce56aa4005250b5908a..1a674aaa71fef5bf9c11688e82982a5dbcfee331 100644 --- a/crates/agent/src/tools/list_directory_tool.rs +++ b/crates/agent/src/tools/list_directory_tool.rs @@ -2,7 +2,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; use agent_client_protocol::ToolKind; use anyhow::{Context as _, Result, anyhow}; use gpui::{App, Entity, SharedString, Task}; @@ -146,34 +146,39 @@ impl AgentTool for ListDirectoryTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - // Sometimes models will return these even though we tell it to give a path and not a glob. - // When this happens, just list the root worktree directories. - if matches!(input.path.as_str(), "." | "" | "./" | "*") { - let output = self - .project - .read(cx) - .worktrees(cx) - .filter_map(|worktree| { - let worktree = worktree.read(cx); - let root_entry = worktree.root_entry()?; - if root_entry.is_dir() { - Some(root_entry.path.display(worktree.path_style())) - } else { - None - } - }) - .collect::>() - .join("\n"); - - return Task::ready(Ok(output)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Sometimes models will return these even though we tell it to give a path and not a glob. + // When this happens, just list the root worktree directories. + if matches!(input.path.as_str(), "." | "" | "./" | "*") { + let output = project.read_with(cx, |project, cx| { + project + .worktrees(cx) + .filter_map(|worktree| { + let worktree = worktree.read(cx); + let root_entry = worktree.root_entry()?; + if root_entry.is_dir() { + Some(root_entry.path.display(worktree.path_style())) + } else { + None + } + }) + .collect::>() + .join("\n") + }); + + return Ok(output); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -323,7 +328,13 @@ mod tests { path: "project".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!( @@ -344,7 +355,13 @@ mod tests { path: "project/src".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!( @@ -365,7 +382,13 @@ mod tests { path: "project/tests".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(!output.contains("# Folders:")); @@ -393,7 +416,13 @@ mod tests { path: "project/empty_dir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert_eq!(output, "project/empty_dir is empty.\n"); @@ -420,7 +449,13 @@ mod tests { path: "project/nonexistent".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("Path not found")); @@ -429,7 +464,13 @@ mod tests { path: "project/file.txt".into(), }; let output = cx - .update(|cx| tool.run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("is not a directory")); } @@ -493,7 +534,13 @@ mod tests { path: "project".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); @@ -520,7 +567,13 @@ mod tests { path: "project/.secretdir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!( output.unwrap_err().contains("file_scan_exclusions"), @@ -532,7 +585,13 @@ mod tests { path: "project/visible_dir".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); @@ -637,7 +696,13 @@ mod tests { path: "worktree1/src".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("main.rs"), "Should list main.rs"); @@ -655,7 +720,13 @@ mod tests { path: "worktree1/tests".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("test.rs"), "Should list test.rs"); @@ -669,7 +740,13 @@ mod tests { path: "worktree2/lib".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("public.js"), "Should list public.js"); @@ -687,7 +764,13 @@ mod tests { path: "worktree2/docs".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await .unwrap(); assert!(output.contains("README.md"), "Should list README.md"); @@ -701,7 +784,13 @@ mod tests { path: "worktree1/src/secret.rs".into(), }; let output = cx - .update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx)) + .update(|cx| { + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) .await; assert!(output.unwrap_err().contains("Cannot list directory"),); } @@ -743,9 +832,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -804,9 +893,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -871,9 +960,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_to_external".into(), - }, + }), event_stream, cx, ) @@ -924,9 +1013,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/src".into(), - }, + }), event_stream, cx, ) @@ -981,9 +1070,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - ListDirectoryToolInput { + ToolInput::resolved(ListDirectoryToolInput { path: "project/link_dir".into(), - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/move_path_tool.rs b/crates/agent/src/tools/move_path_tool.rs index 4c337d0ec2827ad7c63c87ef206f0e74dc63091f..c246b3c5b0661546f4617bb5521766f9da3839fb 100644 --- a/crates/agent/src/tools/move_path_tool.rs +++ b/crates/agent/src/tools/move_path_tool.rs @@ -2,7 +2,9 @@ use super::tool_permissions::{ SensitiveSettingsKind, authorize_symlink_escapes, canonicalize_worktree_roots, collect_symlink_escapes, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_paths}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_paths, +}; use agent_client_protocol::ToolKind; use agent_settings::AgentSettings; use futures::FutureExt as _; @@ -92,19 +94,24 @@ impl AgentTool for MovePathTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let paths = vec![input.source_path.clone(), input.destination_path.clone()]; - let decision = decide_permission_for_paths(Self::NAME, &paths, settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let paths = vec![input.source_path.clone(), input.destination_path.clone()]; + let decision = cx.update(|cx| { + decide_permission_for_paths(Self::NAME, &paths, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -255,7 +262,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -309,7 +316,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; drop(auth); @@ -361,7 +368,7 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)); let auth = event_rx.expect_authorization().await; let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); @@ -437,7 +444,9 @@ mod tests { }; let (event_stream, mut event_rx) = ToolCallEventStream::test(); - let result = cx.update(|cx| tool.run(input, event_stream, cx)).await; + let result = cx + .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx)) + .await; assert!(result.is_err(), "Tool should fail when policy denies"); assert!( diff --git a/crates/agent/src/tools/now_tool.rs b/crates/agent/src/tools/now_tool.rs index 689d70ff20d15cbc56fcc0e14a3b46408647f737..fe1cafe5881d14c9700813f742e1f2df0aa1203e 100644 --- a/crates/agent/src/tools/now_tool.rs +++ b/crates/agent/src/tools/now_tool.rs @@ -6,7 +6,7 @@ use gpui::{App, SharedString, Task}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, ToolInput}; #[derive(Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -48,14 +48,20 @@ impl AgentTool for NowTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, _event_stream: ToolCallEventStream, - _cx: &mut App, + cx: &mut App, ) -> Task> { - let now = match input.timezone { - Timezone::Utc => Utc::now().to_rfc3339(), - Timezone::Local => Local::now().to_rfc3339(), - }; - Task::ready(Ok(format!("The current datetime is {now}."))) + cx.spawn(async move |_cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + let now = match input.timezone { + Timezone::Utc => Utc::now().to_rfc3339(), + Timezone::Local => Local::now().to_rfc3339(), + }; + Ok(format!("The current datetime is {now}.")) + }) } } diff --git a/crates/agent/src/tools/open_tool.rs b/crates/agent/src/tools/open_tool.rs index c0b24efbec6418c437e9e3d14ffb5d40b45c91b0..344a513d10c2d62e4247dd3e47bcdf428586d6f0 100644 --- a/crates/agent/src/tools/open_tool.rs +++ b/crates/agent/src/tools/open_tool.rs @@ -2,7 +2,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::AgentTool; +use crate::{AgentTool, ToolInput}; use agent_client_protocol::ToolKind; use futures::FutureExt as _; use gpui::{App, AppContext as _, Entity, SharedString, Task}; @@ -61,16 +61,24 @@ impl AgentTool for OpenTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: crate::ToolCallEventStream, cx: &mut App, ) -> Task> { - // If path_or_url turns out to be a path in the project, make it absolute. - let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); - let initial_title = self.initial_title(Ok(input.clone()), cx); - let project = self.project.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // If path_or_url turns out to be a path in the project, make it absolute. + let (abs_path, initial_title) = cx.update(|cx| { + let abs_path = to_absolute_path(&input.path_or_url, project.clone(), cx); + let initial_title = self.initial_title(Ok(input.clone()), cx); + (abs_path, initial_title) + }); + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index efd33fe5caece4cee4fc02aab8c1a0ebee92f94e..bbc67cf68c7d104772c18ad222478621ce4d7a54 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -21,7 +21,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, Thread, ToolCallEventStream, outline}; +use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline}; /// Reads the content of the given file in the project. /// @@ -114,7 +114,7 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { @@ -122,6 +122,10 @@ impl AgentTool for ReadFileTool { let thread = self.thread.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(tool_content_err)?; let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -398,7 +402,7 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream, cx) + tool.run(ToolInput::resolved(input), event_stream, cx) }) .await; assert_eq!( @@ -442,7 +446,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "This is a small file content".into()); @@ -485,7 +493,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -510,7 +522,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -570,7 +586,11 @@ mod test { start_line: Some(2), end_line: Some(4), }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4\n".into()); @@ -613,7 +633,11 @@ mod test { start_line: Some(0), end_line: Some(2), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 1\nLine 2\n".into()); @@ -626,7 +650,11 @@ mod test { start_line: Some(1), end_line: Some(0), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 1\n".into()); @@ -639,7 +667,11 @@ mod test { start_line: Some(3), end_line: Some(2), }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!(result.unwrap(), "Line 3\n".into()); @@ -744,7 +776,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -760,7 +796,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -776,7 +816,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -791,7 +835,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -807,7 +855,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -822,7 +874,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -837,7 +893,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -853,7 +913,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!(result.is_ok(), "Should be able to read normal files"); @@ -867,7 +931,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, ToolCallEventStream::test().0, cx) + tool.run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; assert!( @@ -911,11 +979,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let read_task = cx.update(|cx| { tool.run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "root/secret.png".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1039,7 +1107,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -1057,7 +1129,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1075,7 +1151,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1093,7 +1173,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await .unwrap(); @@ -1111,7 +1195,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1129,7 +1217,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1148,7 +1240,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, ToolCallEventStream::test().0, cx) + tool.clone().run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1210,11 +1306,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1286,11 +1382,11 @@ mod test { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) @@ -1367,11 +1463,11 @@ mod test { let result = cx .update(|cx| { tool.clone().run( - ReadFileToolInput { + ToolInput::resolved(ReadFileToolInput { path: "project/secret_link.txt".to_string(), start_line: None, end_line: None, - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/restore_file_from_disk_tool.rs b/crates/agent/src/tools/restore_file_from_disk_tool.rs index 304e0d1180fe626482206bfdc2dfa6d53f529816..c1aa8690a840ea6911dcb94c26c8cef3cb5f313d 100644 --- a/crates/agent/src/tools/restore_file_from_disk_tool.rs +++ b/crates/agent/src/tools/restore_file_from_disk_tool.rs @@ -17,7 +17,9 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use util::markdown::MarkdownInlineCode; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; /// Discards unsaved changes in open buffers by reloading file contents from disk. /// @@ -66,25 +68,31 @@ impl AgentTool for RestoreFileFromDiskTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx).clone(); - - // Check for any immediate deny before spawning async work. - for path in &input.paths { - let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - } - let project = self.project.clone(); - let input_paths = input.paths; cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Check for any immediate deny before doing async work. + for path in &input.paths { + let path_str = path.to_string_lossy(); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + } + + let input_paths = input.paths; + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -92,7 +100,9 @@ impl AgentTool for RestoreFileFromDiskTool { for path in &input_paths { let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); let symlink_escape = project.read_with(cx, |project, cx| { path_has_symlink_escape(project, path, &canonical_roots, cx) }); @@ -378,12 +388,12 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![ PathBuf::from("root/dirty.txt"), PathBuf::from("root/clean.txt"), ], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -428,7 +438,7 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { paths: vec![] }, + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![] }), ToolCallEventStream::test().0, cx, ) @@ -441,9 +451,9 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("nonexistent/path.txt")], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -495,9 +505,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -564,9 +574,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -623,9 +633,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - RestoreFileFromDiskToolInput { + ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/save_file_tool.rs b/crates/agent/src/tools/save_file_tool.rs index 20140c77d113d96c741d5afbe672882f708870d6..99e937b9dff2a1b4781dde16bd2bf6d64edd25ad 100644 --- a/crates/agent/src/tools/save_file_tool.rs +++ b/crates/agent/src/tools/save_file_tool.rs @@ -17,7 +17,9 @@ use super::tool_permissions::{ canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path, sensitive_settings_kind, }; -use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path}; +use crate::{ + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path, +}; /// Saves files that have unsaved changes. /// @@ -63,25 +65,31 @@ impl AgentTool for SaveFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx).clone(); - - // Check for any immediate deny before spawning async work. - for path in &input.paths { - let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); - if let ToolPermissionDecision::Deny(reason) = decision { - return Task::ready(Err(reason)); - } - } - let project = self.project.clone(); - let input_paths = input.paths; cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; + + // Check for any immediate deny before doing async work. + for path in &input.paths { + let path_str = path.to_string_lossy(); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); + if let ToolPermissionDecision::Deny(reason) = decision { + return Err(reason); + } + } + + let input_paths = input.paths; + let fs = project.read_with(cx, |project, _cx| project.fs().clone()); let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await; @@ -89,7 +97,9 @@ impl AgentTool for SaveFileTool { for path in &input_paths { let path_str = path.to_string_lossy(); - let decision = decide_permission_for_path(Self::NAME, &path_str, &settings); + let decision = cx.update(|cx| { + decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx)) + }); let symlink_escape = project.read_with(cx, |project, cx| { path_has_symlink_escape(project, path, &canonical_roots, cx) }); @@ -382,12 +392,12 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![ PathBuf::from("root/dirty.txt"), PathBuf::from("root/clean.txt"), ], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -425,7 +435,7 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { paths: vec![] }, + ToolInput::resolved(SaveFileToolInput { paths: vec![] }), ToolCallEventStream::test().0, cx, ) @@ -438,9 +448,9 @@ mod tests { let output = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("nonexistent/path.txt")], - }, + }), ToolCallEventStream::test().0, cx, ) @@ -490,9 +500,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -559,9 +569,9 @@ mod tests { let result = cx .update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -618,9 +628,9 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![PathBuf::from("project/link.txt")], - }, + }), event_stream, cx, ) @@ -702,12 +712,12 @@ mod tests { let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { tool.clone().run( - SaveFileToolInput { + ToolInput::resolved(SaveFileToolInput { paths: vec![ PathBuf::from("project/dirty.txt"), PathBuf::from("project/link.txt"), ], - }, + }), event_stream, cx, ) diff --git a/crates/agent/src/tools/spawn_agent_tool.rs b/crates/agent/src/tools/spawn_agent_tool.rs index e2dd78d4476de48465cb5c48e225e2ae5a0a7767..69529282544cc35a01f792dcb45df6eb8bdf67d5 100644 --- a/crates/agent/src/tools/spawn_agent_tool.rs +++ b/crates/agent/src/tools/spawn_agent_tool.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use std::rc::Rc; use std::sync::Arc; -use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream}; +use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream, ToolInput}; /// Spawns an agent to perform a delegated task. /// @@ -97,61 +97,78 @@ impl AgentTool for SpawnAgentTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Some(parent_thread_entity) = self.parent_thread.upgrade() else { - return Task::ready(Err(SpawnAgentToolOutput::Error { - session_id: None, - error: "Parent thread no longer exists".to_string(), - })); - }; - - let subagent = if let Some(session_id) = input.session_id { - self.environment - .resume_subagent(parent_thread_entity, session_id, input.message, cx) - } else { - self.environment - .create_subagent(parent_thread_entity, input.label, input.message, cx) - }; - let subagent = match subagent { - Ok(subagent) => subagent, - Err(err) => { - return Task::ready(Err(SpawnAgentToolOutput::Error { + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| SpawnAgentToolOutput::Error { + session_id: None, + error: format!("Failed to receive tool input: {e}"), + })?; + + let (subagent, subagent_session_id) = cx.update(|cx| { + let Some(parent_thread_entity) = self.parent_thread.upgrade() else { + return Err(SpawnAgentToolOutput::Error { + session_id: None, + error: "Parent thread no longer exists".to_string(), + }); + }; + + let subagent = if let Some(session_id) = input.session_id { + self.environment.resume_subagent( + parent_thread_entity, + session_id, + input.message, + cx, + ) + } else { + self.environment.create_subagent( + parent_thread_entity, + input.label, + input.message, + cx, + ) + }; + let subagent = subagent.map_err(|err| SpawnAgentToolOutput::Error { session_id: None, error: err.to_string(), - })); - } - }; - let subagent_session_id = subagent.id(); - - event_stream.subagent_spawned(subagent_session_id.clone()); - let meta = acp::Meta::from_iter([( - SUBAGENT_SESSION_ID_META_KEY.into(), - subagent_session_id.to_string().into(), - )]); - event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); - - cx.spawn(async move |cx| match subagent.wait_for_output(cx).await { - Ok(output) => { - event_stream.update_fields( - acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), - ); - Ok(SpawnAgentToolOutput::Success { - session_id: subagent_session_id, - output, - }) - } - Err(e) => { - let error = e.to_string(); - event_stream.update_fields( - acp::ToolCallUpdateFields::new().content(vec![error.clone().into()]), - ); - Err(SpawnAgentToolOutput::Error { - session_id: Some(subagent_session_id), - error, - }) + })?; + let subagent_session_id = subagent.id(); + + event_stream.subagent_spawned(subagent_session_id.clone()); + let meta = acp::Meta::from_iter([( + SUBAGENT_SESSION_ID_META_KEY.into(), + subagent_session_id.to_string().into(), + )]); + event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta)); + + Ok((subagent, subagent_session_id)) + })?; + + match subagent.wait_for_output(cx).await { + Ok(output) => { + event_stream.update_fields( + acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]), + ); + Ok(SpawnAgentToolOutput::Success { + session_id: subagent_session_id, + output, + }) + } + Err(e) => { + let error = e.to_string(); + event_stream.update_fields( + acp::ToolCallUpdateFields::new().content(vec![error.clone().into()]), + ); + Err(SpawnAgentToolOutput::Error { + session_id: Some(subagent_session_id), + error, + }) + } } }) } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index dd5445142a001fbd9106af548444165bc8331581..95651b44bac44ad3cc67c25c0ef13fc885342ce3 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -2,7 +2,7 @@ use super::edit_file_tool::EditFileTool; use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use crate::{ - AgentTool, Templates, Thread, ToolCallEventStream, + AgentTool, Thread, ToolCallEventStream, ToolInput, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher, }; use acp_thread::Diff; @@ -164,8 +164,6 @@ pub struct StreamingEditFileTool { thread: WeakEntity, language_registry: Arc, project: Entity, - #[allow(dead_code)] - templates: Arc, } impl StreamingEditFileTool { @@ -173,13 +171,11 @@ impl StreamingEditFileTool { project: Entity, thread: WeakEntity, language_registry: Arc, - templates: Arc, ) -> Self { Self { project, thread, language_registry, - templates, } } @@ -188,7 +184,6 @@ impl StreamingEditFileTool { project: self.project.clone(), thread: new_thread, language_registry: self.language_registry.clone(), - templates: self.templates.clone(), } } @@ -268,38 +263,41 @@ impl AgentTool for StreamingEditFileTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let Ok(project) = self - .thread - .read_with(cx, |thread, _cx| thread.project().clone()) - else { - return Task::ready(Err(StreamingEditFileToolOutput::Error { - error: "thread was dropped".to_string(), - })); - }; - - let project_path = match resolve_path(&input, project.clone(), cx) { - Ok(path) => path, - Err(err) => { - return Task::ready(Err(StreamingEditFileToolOutput::Error { - error: err.to_string(), - })); - } - }; - - let abs_path = project.read(cx).absolute_path(&project_path, cx); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields( - ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]), - ); - } - - let authorize = self.authorize(&input, &event_stream, cx); - cx.spawn(async move |cx: &mut AsyncApp| { + let input = input.recv().await.map_err(|e| { + StreamingEditFileToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + } + })?; + + let project = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + .map_err(|_| StreamingEditFileToolOutput::Error { + error: "thread was dropped".to_string(), + })?; + + let (project_path, abs_path, authorize) = cx.update(|cx| { + let project_path = + resolve_path(&input, project.clone(), cx).map_err(|err| { + StreamingEditFileToolOutput::Error { + error: err.to_string(), + } + })?; + let abs_path = project.read(cx).absolute_path(&project_path, cx); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![acp::ToolCallLocation::new(abs_path)]), + ); + } + let authorize = self.authorize(&input, &event_stream, cx); + Ok::<_, StreamingEditFileToolOutput>((project_path, abs_path, authorize)) + })?; let result: anyhow::Result = async { authorize.await?; @@ -787,9 +785,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -836,9 +837,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -896,9 +900,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -958,9 +965,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1023,9 +1033,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1088,9 +1101,12 @@ mod tests { project.clone(), thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1141,9 +1157,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1192,9 +1211,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; @@ -1262,9 +1284,12 @@ mod tests { project, thread.downgrade(), language_registry, - Templates::new(), )) - .run(input, ToolCallEventStream::test().0, cx) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) }) .await; diff --git a/crates/agent/src/tools/terminal_tool.rs b/crates/agent/src/tools/terminal_tool.rs index 57b3278da256c01408f704a8e2f6f7e075057597..6396bd1b0e63b46a0207dd7df9b9f2fcd00176b7 100644 --- a/crates/agent/src/tools/terminal_tool.rs +++ b/crates/agent/src/tools/terminal_tool.rs @@ -15,7 +15,7 @@ use std::{ }; use crate::{ - AgentTool, ThreadEnvironment, ToolCallEventStream, ToolPermissionDecision, + AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_from_settings, }; @@ -85,34 +85,45 @@ impl AgentTool for TerminalTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let working_dir = match working_dir(&input, &self.project, cx) { - Ok(dir) => dir, - Err(err) => return Task::ready(Err(err.to_string())), - }; + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| format!("Failed to receive tool input: {e}"))?; - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings( - Self::NAME, - std::slice::from_ref(&input.command), - settings, - ); + let (working_dir, authorize) = cx.update(|cx| { + let working_dir = + working_dir(&input, &self.project, cx).map_err(|err| err.to_string())?; - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(reason)); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.command.clone()]); - Some(event_stream.authorize(self.initial_title(Ok(input.clone()), cx), context, cx)) - } - }; - cx.spawn(async move |cx| { + let decision = decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.command), + AgentSettings::get_global(cx), + ); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(reason); + } + ToolPermissionDecision::Confirm => { + let context = crate::ToolPermissionContext::new( + Self::NAME, + vec![input.command.clone()], + ); + Some(event_stream.authorize( + self.initial_title(Ok(input.clone()), cx), + context, + cx, + )) + } + }; + Ok((working_dir, authorize)) + })?; if let Some(authorize) = authorize { authorize.await.map_err(|e| e.to_string())?; } diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index c536f45ba65c109d3068b0722db1ffb1cad8b87c..c697a5b78f1fe8c84d6ed58db13f651a493ae8c3 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -1,14 +1,15 @@ use std::sync::Arc; use crate::{ - AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings, + AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, + decide_permission_from_settings, }; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::Result; use cloud_llm_client::WebSearchResponse; use futures::FutureExt as _; -use gpui::{App, AppContext, Task}; +use gpui::{App, Task}; use language_model::{ LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, }; @@ -73,41 +74,51 @@ impl AgentTool for WebSearchTool { fn run( self: Arc, - input: Self::Input, + input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let settings = AgentSettings::get_global(cx); - let decision = decide_permission_from_settings( - Self::NAME, - std::slice::from_ref(&input.query), - settings, - ); - - let authorize = match decision { - ToolPermissionDecision::Allow => None, - ToolPermissionDecision::Deny(reason) => { - return Task::ready(Err(WebSearchToolOutput::Error { error: reason })); - } - ToolPermissionDecision::Confirm => { - let context = - crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]); - Some(event_stream.authorize( - format!("Search the web for {}", MarkdownInlineCode(&input.query)), - context, - cx, - )) - } - }; + cx.spawn(async move |cx| { + let input = input + .recv() + .await + .map_err(|e| WebSearchToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; + + let (authorize, search_task) = cx.update(|cx| { + let decision = decide_permission_from_settings( + Self::NAME, + std::slice::from_ref(&input.query), + AgentSettings::get_global(cx), + ); + + let authorize = match decision { + ToolPermissionDecision::Allow => None, + ToolPermissionDecision::Deny(reason) => { + return Err(WebSearchToolOutput::Error { error: reason }); + } + ToolPermissionDecision::Confirm => { + let context = + crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]); + Some(event_stream.authorize( + format!("Search the web for {}", MarkdownInlineCode(&input.query)), + context, + cx, + )) + } + }; + + let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { + return Err(WebSearchToolOutput::Error { + error: "Web search is not available.".to_string(), + }); + }; - let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { - return Task::ready(Err(WebSearchToolOutput::Error { - error: "Web search is not available.".to_string(), - })); - }; + let search_task = provider.search(input.query, cx); + Ok((authorize, search_task)) + })?; - let search_task = provider.search(input.query, cx); - cx.background_spawn(async move { if let Some(authorize) = authorize { authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?; } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 9d673182bc64e192e6db13a927392d611c53407d..f15382b67557fa9a9b0eda2a9d4438aa33c7cff3 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,7 +2,9 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. use crate::headless_project::HeadlessProject; -use agent::{AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream}; +use agent::{ + AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput, +}; use client::{Client, UserStore}; use clock::FakeSystemClock; use collections::{HashMap, HashSet}; @@ -1962,7 +1964,11 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); let (event_stream, _) = ToolCallEventStream::test(); - let exists_result = cx.update(|cx| read_tool.clone().run(input, event_stream.clone(), cx)); + let exists_result = cx.update(|cx| { + read_tool + .clone() + .run(ToolInput::resolved(input), event_stream.clone(), cx) + }); let output = exists_result.await.unwrap(); assert_eq!(output, LanguageModelToolResultContent::Text("B".into())); @@ -1971,7 +1977,8 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu start_line: None, end_line: None, }; - let does_not_exist_result = cx.update(|cx| read_tool.run(input, event_stream, cx)); + let does_not_exist_result = + cx.update(|cx| read_tool.run(ToolInput::resolved(input), event_stream, cx)); does_not_exist_result.await.unwrap_err(); } diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index 09340dcec641ae2a6c1ea871e770886d14276529..b7471321db203075ac6c71eee0b3ef29c5edaefc 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -1962,7 +1962,7 @@ fn run_agent_thread_view_test( cx: &mut VisualTestAppContext, update_baseline: bool, ) -> Result { - use agent::AgentTool; + use agent::{AgentTool, ToolInput}; use agent_ui::AgentPanel; // Create a temporary directory with the test image @@ -2047,7 +2047,10 @@ fn run_agent_thread_view_test( start_line: None, end_line: None, }; - let run_task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let run_task = cx.update(|cx| { + tool.clone() + .run(ToolInput::resolved(input), event_stream, cx) + }); cx.background_executor.allow_parking(); let run_result = cx.foreground_executor.block_test(run_task);