@@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
-use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
+use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
+ #[serde(alias = "original_path")]
input_path: PathBuf,
- project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
+ #[serde(default)]
diff: String,
+ #[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@@ -123,11 +125,15 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
pub struct EditFileTool {
thread: WeakEntity<Thread>,
+ language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
- pub fn new(thread: WeakEntity<Thread>) -> Self {
- Self { thread }
+ pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
+ Self {
+ thread,
+ language_registry,
+ }
}
fn authorize(
@@ -419,7 +425,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
- project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@@ -427,6 +432,26 @@ impl AgentTool for EditFileTool {
})
})
}
+
+ fn replay(
+ &self,
+ _input: Self::Input,
+ output: Self::Output,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()> {
+ dbg!(&output);
+ event_stream.update_diff(cx.new(|cx| {
+ Diff::finalized(
+ output.input_path,
+ Some(output.old_text.to_string()),
+ output.new_text,
+ self.language_registry.clone(),
+ cx,
+ )
+ }));
+ Ok(())
+ }
}
/// Validate that the file path is valid, meaning:
@@ -515,6 +540,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -537,7 +563,7 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
- Arc::new(EditFileTool::new(thread.downgrade())).run(
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@@ -754,11 +780,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool::new(thread.downgrade())).run(
- input,
- ToolCallEventStream::test().0,
- cx,
- )
+ Arc::new(EditFileTool::new(
+ thread.downgrade(),
+ language_registry.clone(),
+ ))
+ .run(input, ToolCallEventStream::test().0, cx)
});
// Stream the unformatted content
@@ -811,7 +837,7 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool::new(thread.downgrade())).run(
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@@ -857,6 +883,7 @@ mod tests {
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -896,11 +923,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool::new(thread.downgrade())).run(
- input,
- ToolCallEventStream::test().0,
- cx,
- )
+ Arc::new(EditFileTool::new(
+ thread.downgrade(),
+ language_registry.clone(),
+ ))
+ .run(input, ToolCallEventStream::test().0, cx)
});
// Stream the content with trailing whitespace
@@ -948,7 +975,7 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool::new(thread.downgrade())).run(
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
@@ -985,6 +1012,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1000,7 +1028,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@@ -1122,6 +1150,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1137,7 +1166,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@@ -1231,7 +1260,7 @@ mod tests {
cx,
)
.await;
-
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1247,7 +1276,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@@ -1313,6 +1342,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1328,7 +1358,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@@ -1397,6 +1427,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1412,7 +1443,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@@ -1478,6 +1509,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1493,7 +1525,7 @@ mod tests {
cx,
)
});
- let tool = Arc::new(EditFileTool::new(thread.downgrade()));
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({
@@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
}
};
- let result_text = if response.results.len() == 1 {
- "1 result".to_string()
- } else {
- format!("{} results", response.results.len())
- };
- event_stream.update_fields(acp::ToolCallUpdateFields {
- title: Some(format!("Searched the web: {result_text}")),
- content: Some(
- response
- .results
- .iter()
- .map(|result| acp::ToolCallContent::Content {
- content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
- name: result.title.clone(),
- uri: result.url.clone(),
- title: Some(result.title.clone()),
- description: Some(result.text.clone()),
- mime_type: None,
- annotations: None,
- size: None,
- }),
- })
- .collect(),
- ),
- ..Default::default()
- });
+ emit_update(&response, &event_stream);
Ok(WebSearchToolOutput(response))
})
}
+
+ fn replay(
+ &self,
+ _input: Self::Input,
+ output: Self::Output,
+ event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> Result<()> {
+ emit_update(&output.0, &event_stream);
+ Ok(())
+ }
+}
+
+fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
+ let result_text = if response.results.len() == 1 {
+ "1 result".to_string()
+ } else {
+ format!("{} results", response.results.len())
+ };
+ event_stream.update_fields(acp::ToolCallUpdateFields {
+ title: Some(format!("Searched the web: {result_text}")),
+ content: Some(
+ response
+ .results
+ .iter()
+ .map(|result| acp::ToolCallContent::Content {
+ content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
+ name: result.title.clone(),
+ uri: result.url.clone(),
+ title: Some(result.title.clone()),
+ description: Some(result.text.clone()),
+ mime_type: None,
+ annotations: None,
+ size: None,
+ }),
+ })
+ .collect(),
+ ),
+ ..Default::default()
+ });
}