From e6e23d04f8e40afc0b1dd122713b0c1674ae4d70 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 14:56:10 +0200 Subject: [PATCH] Checkpoint --- crates/acp_thread/src/acp_thread.rs | 12 +++- crates/acp_thread/src/diff.rs | 13 ++-- crates/agent2/src/agent.rs | 3 +- crates/agent2/src/tools/edit_file_tool.rs | 82 +++++++++++++++------- crates/agent2/src/tools/web_search_tool.rs | 67 +++++++++++------- 5 files changed, 113 insertions(+), 64 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 0fb0d9e779a323c3d22146a8e8ef25947d2f4b4a..a0e62c29e3abec25206cb273b4d206b29ae8922c 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -539,9 +539,15 @@ impl ToolCallContent { acp::ToolCallContent::Content { content } => { Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) } - acp::ToolCallContent::Diff { diff } => { - Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) - } + acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| { + Diff::finalized( + diff.path, + diff.old_text, + diff.new_text, + language_registry, + cx, + ) + })), } } diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index a2c2d6c3229ae96bf45dfc870e8600a5f778a6f0..a67e37bcb84a6e2c8ba9f9cbaeedd5fc120444b3 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -1,4 +1,3 @@ -use agent_client_protocol as acp; use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{MultiBuffer, PathKey}; @@ -21,17 +20,13 @@ pub enum Diff { } impl Diff { - pub fn from_acp( - diff: acp::Diff, + pub fn finalized( + path: PathBuf, + old_text: Option, + new_text: String, language_registry: Arc, cx: &mut Context, ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5b27f0a048f7d96bd038cb0322b7fa65005ab33b..398054739e48c6ef45a1fa6ac21b461e1758458a 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -587,11 +587,12 @@ impl NativeAgentConnection { action_log: Entity, cx: &mut Context, ) { + let language_registry = project.read(cx).languages().clone(); thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.weak_entity())); + thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry)); thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone())); diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 62774ac2b16c0919dc60a385dee2b65a679d54b2..01fa77e22ddc58b27d67e8fab6d1cf0bd64ae84e 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent; use collections::HashSet; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; -use language::ToPoint; use language::language_settings::{self, FormatOnSave}; +use language::{LanguageRegistry, ToPoint}; use language_model::LanguageModelToolResultContent; use paths; use project::lsp_store::{FormatTrigger, LspFormatTarget}; @@ -98,11 +98,13 @@ pub enum EditFileMode { #[derive(Debug, Serialize, Deserialize)] pub struct EditFileToolOutput { + #[serde(alias = "original_path")] input_path: PathBuf, - project_path: PathBuf, new_text: String, old_text: Arc, + #[serde(default)] diff: String, + #[serde(alias = "raw_output")] edit_agent_output: EditAgentOutput, } @@ -123,11 +125,15 @@ impl From for LanguageModelToolResultContent { pub struct EditFileTool { thread: WeakEntity, + language_registry: Arc, } impl EditFileTool { - pub fn new(thread: WeakEntity) -> Self { - Self { thread } + pub fn new(thread: WeakEntity, language_registry: Arc) -> Self { + Self { + thread, + language_registry, + } } fn authorize( @@ -419,7 +425,6 @@ impl AgentTool for EditFileTool { Ok(EditFileToolOutput { input_path: input.path, - project_path: project_path.path.to_path_buf(), new_text: new_text.clone(), old_text, diff: unified_diff, @@ -427,6 +432,26 @@ impl AgentTool for EditFileTool { }) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + dbg!(&output); + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + output.input_path, + Some(output.old_text.to_string()), + output.new_text, + self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } } /// Validate that the file path is valid, meaning: @@ -515,6 +540,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -537,7 +563,7 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -754,11 +780,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( - input, - ToolCallEventStream::test().0, - cx, - ) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) + .run(input, ToolCallEventStream::test().0, cx) }); // Stream the unformatted content @@ -811,7 +837,7 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -857,6 +883,7 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -896,11 +923,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( - input, - ToolCallEventStream::test().0, - cx, - ) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) + .run(input, ToolCallEventStream::test().0, cx) }); // Stream the content with trailing whitespace @@ -948,7 +975,7 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -985,6 +1012,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1000,7 +1028,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1122,6 +1150,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1137,7 +1166,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1231,7 +1260,7 @@ mod tests { cx, ) .await; - + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1247,7 +1276,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees let test_cases = vec![ @@ -1313,6 +1342,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1328,7 +1358,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases let test_cases = vec![ @@ -1397,6 +1427,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1412,7 +1443,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values let modes = vec![ @@ -1478,6 +1509,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1493,7 +1525,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( tool.initial_title(Err(json!({ diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index c1c09707426431bf8a3ad4c59a012a567366d392..d71a128bfe4f70a95aa71d776b76bd4f5426800a 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool { } }; - let result_text = if response.results.len() == 1 { - "1 result".to_string() - } else { - format!("{} results", response.results.len()) - }; - event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(format!("Searched the web: {result_text}")), - content: Some( - response - .results - .iter() - .map(|result| acp::ToolCallContent::Content { - content: acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: result.title.clone(), - uri: result.url.clone(), - title: Some(result.title.clone()), - description: Some(result.text.clone()), - mime_type: None, - annotations: None, - size: None, - }), - }) - .collect(), - ), - ..Default::default() - }); + emit_update(&response, &event_stream); Ok(WebSearchToolOutput(response)) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + emit_update(&output.0, &event_stream); + Ok(()) + } +} + +fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) { + let result_text = if response.results.len() == 1 { + "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); }