Cargo.lock 🔗
@@ -76,6 +76,7 @@ dependencies = [
"clock",
"collections",
"ctor",
+ "fs",
"futures 0.3.31",
"gpui",
"indoc",
Ben Brandt , Bennet Bo Fenner , and MrSubidubi created
Since the read times always correspond to an action log call anyway, we
can let the action log track this internally, and we don't have to
provide a reference to the Thread in as many tools.
Release Notes:
- N/A
---------
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: MrSubidubi <dev@bahn.sh>
Cargo.lock | 1
crates/action_log/Cargo.toml | 1
crates/action_log/src/action_log.rs | 288 +++++++++++++++
crates/agent/src/tests/edit_file_thread_test.rs | 2
crates/agent/src/thread.rs | 9
crates/agent/src/tools/edit_file_tool.rs | 65 +--
crates/agent/src/tools/read_file_tool.rs | 211 ----------
crates/agent/src/tools/streaming_edit_file_tool.rs | 59 +--
crates/remote_server/src/remote_editing_tests.rs | 24 -
crates/zed/src/visual_test_runner.rs | 25 -
10 files changed, 349 insertions(+), 336 deletions(-)
@@ -76,6 +76,7 @@ dependencies = [
"clock",
"collections",
"ctor",
+ "fs",
"futures 0.3.31",
"gpui",
"indoc",
@@ -20,6 +20,7 @@ buffer_diff.workspace = true
log.workspace = true
clock.workspace = true
collections.workspace = true
+fs.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
@@ -1,14 +1,20 @@
use anyhow::{Context as _, Result};
use buffer_diff::BufferDiff;
use clock;
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
+use fs::MTime;
use futures::{FutureExt, StreamExt, channel::mpsc};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language::{Anchor, Buffer, BufferEvent, Point, ToOffset, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
-use std::{cmp, ops::Range, sync::Arc};
+use std::{
+ cmp,
+ ops::Range,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
use text::{Edit, Patch, Rope};
use util::{RangeExt, ResultExt as _};
@@ -54,6 +60,8 @@ pub struct ActionLog {
linked_action_log: Option<Entity<ActionLog>>,
/// Stores undo information for the most recent reject operation
last_reject_undo: Option<LastRejectUndo>,
+ /// Tracks the last time files were read by the agent, to detect external modifications
+ file_read_times: HashMap<PathBuf, MTime>,
}
impl ActionLog {
@@ -64,6 +72,7 @@ impl ActionLog {
project,
linked_action_log: None,
last_reject_undo: None,
+ file_read_times: HashMap::default(),
}
}
@@ -76,6 +85,32 @@ impl ActionLog {
&self.project
}
+ pub fn file_read_time(&self, path: &Path) -> Option<MTime> {
+ self.file_read_times.get(path).copied()
+ }
+
+ fn update_file_read_time(&mut self, buffer: &Entity<Buffer>, cx: &App) {
+ let buffer = buffer.read(cx);
+ if let Some(file) = buffer.file() {
+ if let Some(local_file) = file.as_local() {
+ if let Some(mtime) = file.disk_state().mtime() {
+ let abs_path = local_file.abs_path(cx);
+ self.file_read_times.insert(abs_path, mtime);
+ }
+ }
+ }
+ }
+
+ fn remove_file_read_time(&mut self, buffer: &Entity<Buffer>, cx: &App) {
+ let buffer = buffer.read(cx);
+ if let Some(file) = buffer.file() {
+ if let Some(local_file) = file.as_local() {
+ let abs_path = local_file.abs_path(cx);
+ self.file_read_times.remove(&abs_path);
+ }
+ }
+ }
+
fn track_buffer_internal(
&mut self,
buffer: Entity<Buffer>,
@@ -506,24 +541,69 @@ impl ActionLog {
/// Track a buffer as read by agent, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ self.buffer_read_impl(buffer, true, cx);
+ }
+
+ fn buffer_read_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_read_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
self.track_buffer_internal(buffer, false, cx);
}
/// Mark a buffer as created by agent, so we can refresh it in the context
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ self.buffer_created_impl(buffer, true, cx);
+ }
+
+ fn buffer_created_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_created_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
self.track_buffer_internal(buffer, true, cx);
}
/// Mark a buffer as edited by agent, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ self.buffer_edited_impl(buffer, true, cx);
+ }
+
+ fn buffer_edited_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_edited_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
let new_version = buffer.read(cx).version();
let tracked_buffer = self.track_buffer_internal(buffer, false, cx);
@@ -536,6 +616,8 @@ impl ActionLog {
}
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
+ // Ok to propagate file read time removal to linked action log
+ self.remove_file_read_time(&buffer, cx);
let has_linked_action_log = self.linked_action_log.is_some();
let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
match tracked_buffer.status {
@@ -2976,6 +3058,196 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_read(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_read"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_read"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_edited(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_edited"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_edited"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_created(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "existing content"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_created"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_created"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_removed_on_delete(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should exist after buffer_read"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx));
+ });
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be removed after will_delete_buffer"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_not_forwarded_to_linked_action_log(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let parent_log = cx.new(|_| ActionLog::new(project.clone()));
+ let child_log =
+ cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+ assert!(
+ child_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "child should record file_read_time on buffer_read"
+ );
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_read"
+ );
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_edited"
+ );
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ });
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_created"
+ );
+ }
+
#[derive(Debug, PartialEq)]
struct HunkStatus {
range: Range<Point>,
@@ -50,9 +50,9 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
// Add just the tools we need for this test
let language_registry = project.read(cx).languages().clone();
thread.add_tool(crate::ReadFileTool::new(
- cx.weak_entity(),
project.clone(),
thread.action_log().clone(),
+ true,
));
thread.add_tool(crate::EditFileTool::new(
project.clone(),
@@ -893,8 +893,6 @@ pub struct Thread {
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
pub(crate) project: Entity<Project>,
pub(crate) action_log: Entity<ActionLog>,
- /// Tracks the last time files were read by the agent, to detect external modifications
- pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
/// True if this thread was imported from a shared thread and can be synced.
imported: bool,
/// If this is a subagent thread, contains context about the parent
@@ -1014,7 +1012,6 @@ impl Thread {
prompt_capabilities_rx,
project,
action_log,
- file_read_times: HashMap::default(),
imported: false,
subagent_context: None,
draft_prompt: None,
@@ -1231,7 +1228,6 @@ impl Thread {
updated_at: db_thread.updated_at,
prompt_capabilities_tx,
prompt_capabilities_rx,
- file_read_times: HashMap::default(),
imported: db_thread.imported,
subagent_context: db_thread.subagent_context,
draft_prompt: db_thread.draft_prompt,
@@ -1436,6 +1432,9 @@ impl Thread {
environment: Rc<dyn ThreadEnvironment>,
cx: &mut Context<Self>,
) {
+ // Only update the agent location for the root thread, not for subagents.
+ let update_agent_location = self.parent_thread_id().is_none();
+
let language_registry = self.project.read(cx).languages().clone();
self.add_tool(CopyPathTool::new(self.project.clone()));
self.add_tool(CreateDirectoryTool::new(self.project.clone()));
@@ -1463,9 +1462,9 @@ impl Thread {
self.add_tool(NowTool);
self.add_tool(OpenTool::new(self.project.clone()));
self.add_tool(ReadFileTool::new(
- cx.weak_entity(),
self.project.clone(),
self.action_log.clone(),
+ update_agent_location,
));
self.add_tool(SaveFileTool::new(self.project.clone()));
self.add_tool(RestoreFileFromDiskTool::new(self.project.clone()));
@@ -305,13 +305,13 @@ impl AgentTool for EditFileTool {
// Check if the file has been modified since the agent last read it
if let Some(abs_path) = abs_path.as_ref() {
- let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
- let last_read = thread.file_read_times.get(abs_path).copied();
+ let last_read_mtime = action_log.read_with(cx, |log, _| log.file_read_time(abs_path));
+ let (current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.read_with(cx, |thread, cx| {
let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
let dirty = buffer.read(cx).is_dirty();
let has_save = thread.has_tool(SaveFileTool::NAME);
let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
- (last_read, current, dirty, has_save, has_restore)
+ (current, dirty, has_save, has_restore)
})?;
// Check for unsaved changes first - these indicate modifications we don't know about
@@ -470,17 +470,6 @@ impl AgentTool for EditFileTool {
log.buffer_edited(buffer.clone(), cx);
});
- // Update the recorded read time after a successful edit so consecutive edits work
- if let Some(abs_path) = abs_path.as_ref() {
- if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- self.thread.update(cx, |thread, _| {
- thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
- })?;
- }
- }
-
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
@@ -2212,14 +2201,18 @@ mod tests {
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
// Initially, file_read_times should be empty
- let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
+ let is_empty = action_log.read_with(cx, |action_log, _| {
+ action_log
+ .file_read_time(path!("/root/test.txt").as_ref())
+ .is_none()
+ });
assert!(is_empty, "file_read_times should start empty");
// Create read tool
let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
project.clone(),
- action_log,
+ action_log.clone(),
+ true,
));
// Read the file to record the read time
@@ -2238,12 +2231,9 @@ mod tests {
.unwrap();
// Verify that file_read_times now contains an entry for the file
- let has_entry = thread.read_with(cx, |thread, _| {
- thread.file_read_times.len() == 1
- && thread
- .file_read_times
- .keys()
- .any(|path| path.ends_with("test.txt"))
+ let has_entry = action_log.read_with(cx, |log, _| {
+ log.file_read_time(path!("/root/test.txt").as_ref())
+ .is_some()
});
assert!(
has_entry,
@@ -2265,11 +2255,14 @@ mod tests {
.await
.unwrap();
- // Should still have exactly one entry
- let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
+ // Should still have an entry after re-reading
+ let has_entry = action_log.read_with(cx, |log, _| {
+ log.file_read_time(path!("/root/test.txt").as_ref())
+ .is_some()
+ });
assert!(
- has_one_entry,
- "file_read_times should still have one entry after re-reading"
+ has_entry,
+ "file_read_times should still have an entry after re-reading"
);
}
@@ -2309,11 +2302,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2423,11 +2412,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2534,11 +2519,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2,7 +2,7 @@ use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
use futures::FutureExt as _;
-use gpui::{App, Entity, SharedString, Task, WeakEntity};
+use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::Point;
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
@@ -21,7 +21,7 @@ use super::tool_permissions::{
ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots,
resolve_project_path,
};
-use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline};
+use crate::{AgentTool, ToolCallEventStream, ToolInput, outline};
/// Reads the content of the given file in the project.
///
@@ -56,21 +56,21 @@ pub struct ReadFileToolInput {
}
pub struct ReadFileTool {
- thread: WeakEntity<Thread>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
+ update_agent_location: bool,
}
impl ReadFileTool {
pub fn new(
- thread: WeakEntity<Thread>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
+ update_agent_location: bool,
) -> Self {
Self {
- thread,
project,
action_log,
+ update_agent_location,
}
}
}
@@ -119,7 +119,6 @@ impl AgentTool for ReadFileTool {
cx: &mut App,
) -> Task<Result<LanguageModelToolResultContent, LanguageModelToolResultContent>> {
let project = self.project.clone();
- let thread = self.thread.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
let input = input
@@ -257,20 +256,6 @@ impl AgentTool for ReadFileTool {
return Err(tool_content_err(format!("{file_path} not found")));
}
- // Record the file read time and mtime
- if let Some(mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- thread
- .update(cx, |thread, _| {
- thread.file_read_times.insert(abs_path.to_path_buf(), mtime);
- })
- .ok();
- }
-
-
- let update_agent_location = self.thread.read_with(cx, |thread, _cx| !thread.is_subagent()).unwrap_or_default();
-
let mut anchor = None;
// Check if specific line ranges are provided
@@ -330,7 +315,7 @@ impl AgentTool for ReadFileTool {
};
project.update(cx, |project, cx| {
- if update_agent_location {
+ if self.update_agent_location {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
@@ -362,13 +347,10 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
use super::*;
- use crate::{ContextServerRegistry, Templates, Thread};
use agent_client_protocol as acp;
use fs::Fs as _;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
- use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
- use prompt_store::ProjectContext;
use serde_json::json;
use settings::SettingsStore;
use std::path::PathBuf;
@@ -383,20 +365,7 @@ mod test {
fs.insert_tree(path!("/root"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
@@ -429,20 +398,7 @@ mod test {
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -476,20 +432,7 @@ mod test {
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(language::rust_lang());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -569,20 +512,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -614,20 +544,7 @@ mod test {
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
// start_line of 0 should be treated as 1
let result = cx
@@ -757,20 +674,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
// Reading a file outside the project worktree should fail
let result = cx
@@ -965,20 +869,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let read_task = cx.update(|cx| {
@@ -1084,24 +975,7 @@ mod test {
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log.clone(),
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone(), true));
// Test reading allowed files in worktree1
let result = cx
@@ -1288,24 +1162,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
@@ -1364,24 +1221,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
@@ -1444,24 +1284,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let result = cx
@@ -483,7 +483,12 @@ impl EditSession {
.await
.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
- ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
+ let action_log = tool
+ .thread
+ .read_with(cx, |thread, _cx| thread.action_log().clone())
+ .ok();
+
+ ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?;
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
event_stream.update_diff(diff.clone());
@@ -495,13 +500,9 @@ impl EditSession {
}
}) as Box<dyn FnOnce()>);
- tool.thread
- .update(cx, |thread, cx| {
- thread
- .action_log()
- .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))
- })
- .ok();
+ if let Some(action_log) = &action_log {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ }
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = cx
@@ -637,18 +638,6 @@ impl EditSession {
log.buffer_edited(buffer.clone(), cx);
});
- if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- tool.thread
- .update(cx, |thread, _| {
- thread
- .file_read_times
- .insert(abs_path.to_path_buf(), new_mtime);
- })
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
- }
-
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
@@ -1018,10 +1007,12 @@ fn ensure_buffer_saved(
buffer: &Entity<Buffer>,
abs_path: &PathBuf,
tool: &StreamingEditFileTool,
+ action_log: Option<&Entity<ActionLog>>,
cx: &mut AsyncApp,
) -> Result<(), StreamingEditFileToolOutput> {
- let check_result = tool.thread.update(cx, |thread, cx| {
- let last_read = thread.file_read_times.get(abs_path).copied();
+ let last_read_mtime =
+ action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path)));
+ let check_result = tool.thread.read_with(cx, |thread, cx| {
let current = buffer
.read(cx)
.file()
@@ -1029,12 +1020,10 @@ fn ensure_buffer_saved(
let dirty = buffer.read(cx).is_dirty();
let has_save = thread.has_tool(SaveFileTool::NAME);
let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
- (last_read, current, dirty, has_save, has_restore)
+ (current, dirty, has_save, has_restore)
});
- let Ok((last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool)) =
- check_result
- else {
+ let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else {
return Ok(());
};
@@ -4006,11 +3995,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -4112,11 +4097,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -4225,11 +4206,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2,15 +2,12 @@
/// The tests in this file assume that server_cx is running on Windows too.
/// We neead to find a way to test Windows-Non-Windows interactions.
use crate::headless_project::HeadlessProject;
-use agent::{
- AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput,
-};
+use agent::{AgentTool, ReadFileTool, ReadFileToolInput, ToolCallEventStream, ToolInput};
use client::{Client, UserStore};
use clock::FakeSystemClock;
use collections::{HashMap, HashSet};
use git::repository::DiffType;
-use language_model::{LanguageModelToolResultContent, fake_provider::FakeLanguageModel};
-use prompt_store::ProjectContext;
+use language_model::LanguageModelToolResultContent;
use extension::ExtensionHostProxy;
use fs::{FakeFs, Fs};
@@ -2065,27 +2062,12 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let action_log = cx.new(|_| action_log::ActionLog::new(project.clone()));
- // Create a minimal thread for the ReadFileTool
- let context_server_registry =
- cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
let input = ReadFileToolInput {
path: "project/b.txt".into(),
start_line: None,
end_line: None,
};
- let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let read_tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, _) = ToolCallEventStream::test();
let exists_result = cx.update(|cx| {
@@ -2032,32 +2032,9 @@ fn run_agent_thread_view_test(
// Create the necessary entities for the ReadFileTool
let action_log = cx.update(|cx| cx.new(|_| action_log::ActionLog::new(project.clone())));
- let context_server_registry = cx.update(|cx| {
- cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx))
- });
- let fake_model = Arc::new(language_model::fake_provider::FakeLanguageModel::default());
- let project_context = cx.update(|cx| cx.new(|_| prompt_store::ProjectContext::default()));
-
- // Create the agent Thread
- let thread = cx.update(|cx| {
- cx.new(|cx| {
- agent::Thread::new(
- project.clone(),
- project_context,
- context_server_registry,
- agent::Templates::new(),
- Some(fake_model),
- cx,
- )
- })
- });
// Create the ReadFileTool
- let tool = Arc::new(agent::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(agent::ReadFileTool::new(project.clone(), action_log, true));
// Create a test event stream to capture tool output
let (event_stream, mut event_receiver) = agent::ToolCallEventStream::test();