Checkpoint

Antonio Scandurra created

Change summary

crates/acp_thread/src/acp_thread.rs        | 12 ++
crates/acp_thread/src/diff.rs              | 13 +--
crates/agent2/src/agent.rs                 |  3 
crates/agent2/src/tools/edit_file_tool.rs  | 82 ++++++++++++++++-------
crates/agent2/src/tools/web_search_tool.rs | 67 ++++++++++++-------
5 files changed, 113 insertions(+), 64 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -539,9 +539,15 @@ impl ToolCallContent {
             acp::ToolCallContent::Content { content } => {
                 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
             }
-            acp::ToolCallContent::Diff { diff } => {
-                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
-            }
+            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
+                Diff::finalized(
+                    diff.path,
+                    diff.old_text,
+                    diff.new_text,
+                    language_registry,
+                    cx,
+                )
+            })),
         }
     }
 

crates/acp_thread/src/diff.rs 🔗

@@ -1,4 +1,3 @@
-use agent_client_protocol as acp;
 use anyhow::Result;
 use buffer_diff::{BufferDiff, BufferDiffSnapshot};
 use editor::{MultiBuffer, PathKey};
@@ -21,17 +20,13 @@ pub enum Diff {
 }
 
 impl Diff {
-    pub fn from_acp(
-        diff: acp::Diff,
+    pub fn finalized(
+        path: PathBuf,
+        old_text: Option<String>,
+        new_text: String,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut Context<Self>,
     ) -> Self {
-        let acp::Diff {
-            path,
-            old_text,
-            new_text,
-        } = diff;
-
         let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 
         let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

crates/agent2/src/agent.rs 🔗

@@ -587,11 +587,12 @@ impl NativeAgentConnection {
         action_log: Entity<action_log::ActionLog>,
         cx: &mut Context<Thread>,
     ) {
+        let language_registry = project.read(cx).languages().clone();
         thread.add_tool(CopyPathTool::new(project.clone()));
         thread.add_tool(CreateDirectoryTool::new(project.clone()));
         thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
         thread.add_tool(DiagnosticsTool::new(project.clone()));
-        thread.add_tool(EditFileTool::new(cx.weak_entity()));
+        thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
         thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
         thread.add_tool(FindPathTool::new(project.clone()));
         thread.add_tool(GrepTool::new(project.clone()));

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent;
 use collections::HashSet;
 use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use indoc::formatdoc;
-use language::ToPoint;
 use language::language_settings::{self, FormatOnSave};
+use language::{LanguageRegistry, ToPoint};
 use language_model::LanguageModelToolResultContent;
 use paths;
 use project::lsp_store::{FormatTrigger, LspFormatTarget};
@@ -98,11 +98,13 @@ pub enum EditFileMode {
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct EditFileToolOutput {
+    #[serde(alias = "original_path")]
     input_path: PathBuf,
-    project_path: PathBuf,
     new_text: String,
     old_text: Arc<String>,
+    #[serde(default)]
     diff: String,
+    #[serde(alias = "raw_output")]
     edit_agent_output: EditAgentOutput,
 }
 
@@ -123,11 +125,15 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 
 pub struct EditFileTool {
     thread: WeakEntity<Thread>,
+    language_registry: Arc<LanguageRegistry>,
 }
 
 impl EditFileTool {
-    pub fn new(thread: WeakEntity<Thread>) -> Self {
-        Self { thread }
+    pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
+        Self {
+            thread,
+            language_registry,
+        }
     }
 
     fn authorize(
@@ -419,7 +425,6 @@ impl AgentTool for EditFileTool {
 
             Ok(EditFileToolOutput {
                 input_path: input.path,
-                project_path: project_path.path.to_path_buf(),
                 new_text: new_text.clone(),
                 old_text,
                 diff: unified_diff,
@@ -427,6 +432,26 @@ impl AgentTool for EditFileTool {
             })
         })
     }
+
+    fn replay(
+        &self,
+        _input: Self::Input,
+        output: Self::Output,
+        event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Result<()> {
+        dbg!(&output);
+        event_stream.update_diff(cx.new(|cx| {
+            Diff::finalized(
+                output.input_path,
+                Some(output.old_text.to_string()),
+                output.new_text,
+                self.language_registry.clone(),
+                cx,
+            )
+        }));
+        Ok(())
+    }
 }
 
 /// Validate that the file path is valid, meaning:
@@ -515,6 +540,7 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         fs.insert_tree("/root", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -537,7 +563,7 @@ mod tests {
                     path: "root/nonexistent_file.txt".into(),
                     mode: EditFileMode::Edit,
                 };
-                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
                     input,
                     ToolCallEventStream::test().0,
                     cx,
@@ -754,11 +780,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool::new(thread.downgrade())).run(
-                    input,
-                    ToolCallEventStream::test().0,
-                    cx,
-                )
+                Arc::new(EditFileTool::new(
+                    thread.downgrade(),
+                    language_registry.clone(),
+                ))
+                .run(input, ToolCallEventStream::test().0, cx)
             });
 
             // Stream the unformatted content
@@ -811,7 +837,7 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
                     input,
                     ToolCallEventStream::test().0,
                     cx,
@@ -857,6 +883,7 @@ mod tests {
         .unwrap();
 
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -896,11 +923,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool::new(thread.downgrade())).run(
-                    input,
-                    ToolCallEventStream::test().0,
-                    cx,
-                )
+                Arc::new(EditFileTool::new(
+                    thread.downgrade(),
+                    language_registry.clone(),
+                ))
+                .run(input, ToolCallEventStream::test().0, cx)
             });
 
             // Stream the content with trailing whitespace
@@ -948,7 +975,7 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool::new(thread.downgrade())).run(
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
                     input,
                     ToolCallEventStream::test().0,
                     cx,
@@ -985,6 +1012,7 @@ mod tests {
         init_test(cx);
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1000,7 +1028,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
         fs.insert_tree("/root", json!({})).await;
 
         // Test 1: Path with .zed component should require confirmation
@@ -1122,6 +1150,7 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         fs.insert_tree("/project", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1137,7 +1166,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test global config paths - these should require confirmation if they exist and are outside the project
         let test_cases = vec![
@@ -1231,7 +1260,7 @@ mod tests {
             cx,
         )
         .await;
-
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1247,7 +1276,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test files in different worktrees
         let test_cases = vec![
@@ -1313,6 +1342,7 @@ mod tests {
         )
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1328,7 +1358,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test edge cases
         let test_cases = vec![
@@ -1397,6 +1427,7 @@ mod tests {
         )
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1412,7 +1443,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test different EditFileMode values
         let modes = vec![
@@ -1478,6 +1509,7 @@ mod tests {
         init_test(cx);
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1493,7 +1525,7 @@ mod tests {
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         assert_eq!(
             tool.initial_title(Err(json!({

crates/agent2/src/tools/web_search_tool.rs 🔗

@@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
                 }
             };
 
-            let result_text = if response.results.len() == 1 {
-                "1 result".to_string()
-            } else {
-                format!("{} results", response.results.len())
-            };
-            event_stream.update_fields(acp::ToolCallUpdateFields {
-                title: Some(format!("Searched the web: {result_text}")),
-                content: Some(
-                    response
-                        .results
-                        .iter()
-                        .map(|result| acp::ToolCallContent::Content {
-                            content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
-                                name: result.title.clone(),
-                                uri: result.url.clone(),
-                                title: Some(result.title.clone()),
-                                description: Some(result.text.clone()),
-                                mime_type: None,
-                                annotations: None,
-                                size: None,
-                            }),
-                        })
-                        .collect(),
-                ),
-                ..Default::default()
-            });
+            emit_update(&response, &event_stream);
             Ok(WebSearchToolOutput(response))
         })
     }
+
+    fn replay(
+        &self,
+        _input: Self::Input,
+        output: Self::Output,
+        event_stream: ToolCallEventStream,
+        _cx: &mut App,
+    ) -> Result<()> {
+        emit_update(&output.0, &event_stream);
+        Ok(())
+    }
+}
+
+fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
+    let result_text = if response.results.len() == 1 {
+        "1 result".to_string()
+    } else {
+        format!("{} results", response.results.len())
+    };
+    event_stream.update_fields(acp::ToolCallUpdateFields {
+        title: Some(format!("Searched the web: {result_text}")),
+        content: Some(
+            response
+                .results
+                .iter()
+                .map(|result| acp::ToolCallContent::Content {
+                    content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                        name: result.title.clone(),
+                        uri: result.url.clone(),
+                        title: Some(result.title.clone()),
+                        description: Some(result.text.clone()),
+                        mime_type: None,
+                        annotations: None,
+                        size: None,
+                    }),
+                })
+                .collect(),
+        ),
+        ..Default::default()
+    });
 }