diff --git a/Cargo.lock b/Cargo.lock index d79134c6145d3a6644f780097f7dd8f69eeae295..4e4d86b947be1f68d03b225d4a62747659c99bf8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,7 @@ dependencies = [ "clock", "collections", "ctor", + "fs", "futures 0.3.31", "gpui", "indoc", diff --git a/crates/action_log/Cargo.toml b/crates/action_log/Cargo.toml index 8488df691e40ea3bcfc04f4f6f74964fba7863dd..b1a1bf824fb770b8378e596fd0c799a7cf98b13d 100644 --- a/crates/action_log/Cargo.toml +++ b/crates/action_log/Cargo.toml @@ -20,6 +20,7 @@ buffer_diff.workspace = true log.workspace = true clock.workspace = true collections.workspace = true +fs.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index 5f8a639c0559c10546fc5640dc240aeba9dde487..5679f3c58fe52057f7a4a0faa24d5b5db2b5e497 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -1,14 +1,20 @@ use anyhow::{Context as _, Result}; use buffer_diff::BufferDiff; use clock; -use collections::BTreeMap; +use collections::{BTreeMap, HashMap}; +use fs::MTime; use futures::{FutureExt, StreamExt, channel::mpsc}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; use language::{Anchor, Buffer, BufferEvent, Point, ToOffset, ToPoint}; use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; -use std::{cmp, ops::Range, sync::Arc}; +use std::{ + cmp, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; use text::{Edit, Patch, Rope}; use util::{RangeExt, ResultExt as _}; @@ -54,6 +60,8 @@ pub struct ActionLog { linked_action_log: Option>, /// Stores undo information for the most recent reject operation last_reject_undo: Option, + /// Tracks the last time files were read by the agent, to detect external modifications + file_read_times: HashMap, } impl ActionLog { @@ -64,6 +72,7 @@ impl ActionLog { project, linked_action_log: None, last_reject_undo: None, + file_read_times: HashMap::default(), } } @@ -76,6 +85,32 @@ impl ActionLog { &self.project } + pub fn file_read_time(&self, path: &Path) -> Option { + self.file_read_times.get(path).copied() + } + + fn update_file_read_time(&mut self, buffer: &Entity, cx: &App) { + let buffer = buffer.read(cx); + if let Some(file) = buffer.file() { + if let Some(local_file) = file.as_local() { + if let Some(mtime) = file.disk_state().mtime() { + let abs_path = local_file.abs_path(cx); + self.file_read_times.insert(abs_path, mtime); + } + } + } + } + + fn remove_file_read_time(&mut self, buffer: &Entity, cx: &App) { + let buffer = buffer.read(cx); + if let Some(file) = buffer.file() { + if let Some(local_file) = file.as_local() { + let abs_path = local_file.abs_path(cx); + self.file_read_times.remove(&abs_path); + } + } + } + fn track_buffer_internal( &mut self, buffer: Entity, @@ -506,24 +541,69 @@ impl ActionLog { /// Track a buffer as read by agent, so we can notify the model about user edits. pub fn buffer_read(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + self.buffer_read_impl(buffer, true, cx); + } + + fn buffer_read_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_read_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } self.track_buffer_internal(buffer, false, cx); } /// Mark a buffer as created by agent, so we can refresh it in the context pub fn buffer_created(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + self.buffer_created_impl(buffer, true, cx); + } + + fn buffer_created_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_created_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } self.track_buffer_internal(buffer, true, cx); } /// Mark a buffer as edited by agent, so we can refresh it in the context pub fn buffer_edited(&mut self, buffer: Entity, cx: &mut Context) { - if let Some(linked_action_log) = &mut self.linked_action_log { - linked_action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + self.buffer_edited_impl(buffer, true, cx); + } + + fn buffer_edited_impl( + &mut self, + buffer: Entity, + record_file_read_time: bool, + cx: &mut Context, + ) { + if let Some(linked_action_log) = &self.linked_action_log { + // We don't want to share read times since the other agent hasn't read it necessarily + linked_action_log.update(cx, |log, cx| { + log.buffer_edited_impl(buffer.clone(), false, cx); + }); + } + if record_file_read_time { + self.update_file_read_time(&buffer, cx); } let new_version = buffer.read(cx).version(); let tracked_buffer = self.track_buffer_internal(buffer, false, cx); @@ -536,6 +616,8 @@ impl ActionLog { } pub fn will_delete_buffer(&mut self, buffer: Entity, cx: &mut Context) { + // Ok to propagate file read time removal to linked action log + self.remove_file_read_time(&buffer, cx); let has_linked_action_log = self.linked_action_log.is_some(); let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); match tracked_buffer.status { @@ -2976,6 +3058,196 @@ mod tests { ); } + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_read(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_read" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_read" + ); + } + + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_edited(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_edited" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_edited" + ); + } + + #[gpui::test] + async fn test_file_read_time_recorded_on_buffer_created(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "existing content"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be None before buffer_created" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + }); + + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should be recorded after buffer_created" + ); + } + + #[gpui::test] + async fn test_file_read_time_removed_on_delete(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "file_read_time should exist after buffer_read" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx)); + }); + assert!( + action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "file_read_time should be removed after will_delete_buffer" + ); + } + + #[gpui::test] + async fn test_file_read_time_not_forwarded_to_linked_action_log(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "hello world"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let parent_log = cx.new(|_| ActionLog::new(project.clone())); + let child_log = + cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + let abs_path = PathBuf::from(path!("/dir/file")); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + assert!( + child_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()), + "child should record file_read_time on buffer_read" + ); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_read" + ); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_edited" + ); + + cx.update(|cx| { + child_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx)); + }); + assert!( + parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()), + "parent should NOT get file_read_time from child's buffer_created" + ); + } + #[derive(Debug, PartialEq)] struct HunkStatus { range: Range, diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 069bf0349299e6f4952f673cbf7607e52d48d9c5..3beb5cb0d51abc55fbf3cf0849ced248a9d1fa5c 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -50,9 +50,9 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { // Add just the tools we need for this test let language_registry = project.read(cx).languages().clone(); thread.add_tool(crate::ReadFileTool::new( - cx.weak_entity(), project.clone(), thread.action_log().clone(), + true, )); thread.add_tool(crate::EditFileTool::new( project.clone(), diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 99d77456e3822ae12c65c0a419ceea18f13f41e8..616ae414d4d51a384a18460e8339fd07770fa6b9 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -893,8 +893,6 @@ pub struct Thread { pub(crate) prompt_capabilities_rx: watch::Receiver, pub(crate) project: Entity, pub(crate) action_log: Entity, - /// Tracks the last time files were read by the agent, to detect external modifications - pub(crate) file_read_times: HashMap, /// True if this thread was imported from a shared thread and can be synced. imported: bool, /// If this is a subagent thread, contains context about the parent @@ -1014,7 +1012,6 @@ impl Thread { prompt_capabilities_rx, project, action_log, - file_read_times: HashMap::default(), imported: false, subagent_context: None, draft_prompt: None, @@ -1231,7 +1228,6 @@ impl Thread { updated_at: db_thread.updated_at, prompt_capabilities_tx, prompt_capabilities_rx, - file_read_times: HashMap::default(), imported: db_thread.imported, subagent_context: db_thread.subagent_context, draft_prompt: db_thread.draft_prompt, @@ -1436,6 +1432,9 @@ impl Thread { environment: Rc, cx: &mut Context, ) { + // Only update the agent location for the root thread, not for subagents. + let update_agent_location = self.parent_thread_id().is_none(); + let language_registry = self.project.read(cx).languages().clone(); self.add_tool(CopyPathTool::new(self.project.clone())); self.add_tool(CreateDirectoryTool::new(self.project.clone())); @@ -1463,9 +1462,9 @@ impl Thread { self.add_tool(NowTool); self.add_tool(OpenTool::new(self.project.clone())); self.add_tool(ReadFileTool::new( - cx.weak_entity(), self.project.clone(), self.action_log.clone(), + update_agent_location, )); self.add_tool(SaveFileTool::new(self.project.clone())); self.add_tool(RestoreFileFromDiskTool::new(self.project.clone())); diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index d8c380eba326d089b848563cca04557e903ba0f4..29b08ac09db4417123403fd3915b8575791b2a4e 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -305,13 +305,13 @@ impl AgentTool for EditFileTool { // Check if the file has been modified since the agent last read it if let Some(abs_path) = abs_path.as_ref() { - let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| { - let last_read = thread.file_read_times.get(abs_path).copied(); + let last_read_mtime = action_log.read_with(cx, |log, _| log.file_read_time(abs_path)); + let (current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.read_with(cx, |thread, cx| { let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime()); let dirty = buffer.read(cx).is_dirty(); let has_save = thread.has_tool(SaveFileTool::NAME); let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (last_read, current, dirty, has_save, has_restore) + (current, dirty, has_save, has_restore) })?; // Check for unsaved changes first - these indicate modifications we don't know about @@ -470,17 +470,6 @@ impl AgentTool for EditFileTool { log.buffer_edited(buffer.clone(), cx); }); - // Update the recorded read time after a successful edit so consecutive edits work - if let Some(abs_path) = abs_path.as_ref() { - if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - self.thread.update(cx, |thread, _| { - thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime); - })?; - } - } - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ @@ -2212,14 +2201,18 @@ mod tests { let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); // Initially, file_read_times should be empty - let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty()); + let is_empty = action_log.read_with(cx, |action_log, _| { + action_log + .file_read_time(path!("/root/test.txt").as_ref()) + .is_none() + }); assert!(is_empty, "file_read_times should start empty"); // Create read tool let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), project.clone(), - action_log, + action_log.clone(), + true, )); // Read the file to record the read time @@ -2238,12 +2231,9 @@ mod tests { .unwrap(); // Verify that file_read_times now contains an entry for the file - let has_entry = thread.read_with(cx, |thread, _| { - thread.file_read_times.len() == 1 - && thread - .file_read_times - .keys() - .any(|path| path.ends_with("test.txt")) + let has_entry = action_log.read_with(cx, |log, _| { + log.file_read_time(path!("/root/test.txt").as_ref()) + .is_some() }); assert!( has_entry, @@ -2265,11 +2255,14 @@ mod tests { .await .unwrap(); - // Should still have exactly one entry - let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1); + // Should still have an entry after re-reading + let has_entry = action_log.read_with(cx, |log, _| { + log.file_read_time(path!("/root/test.txt").as_ref()) + .is_some() + }); assert!( - has_one_entry, - "file_read_times should still have one entry after re-reading" + has_entry, + "file_read_times should still have an entry after re-reading" ); } @@ -2309,11 +2302,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), @@ -2423,11 +2412,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), @@ -2534,11 +2519,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 8cfc16ddf6174a190ffe7cc11921dc204b05b79d..f7a75bc63a1c461b65c3a2e6f74f2c70e0ca15f6 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -2,7 +2,7 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; use futures::FutureExt as _; -use gpui::{App, Entity, SharedString, Task, WeakEntity}; +use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; use language_model::{LanguageModelImage, LanguageModelToolResultContent}; @@ -21,7 +21,7 @@ use super::tool_permissions::{ ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots, resolve_project_path, }; -use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline}; +use crate::{AgentTool, ToolCallEventStream, ToolInput, outline}; /// Reads the content of the given file in the project. /// @@ -56,21 +56,21 @@ pub struct ReadFileToolInput { } pub struct ReadFileTool { - thread: WeakEntity, project: Entity, action_log: Entity, + update_agent_location: bool, } impl ReadFileTool { pub fn new( - thread: WeakEntity, project: Entity, action_log: Entity, + update_agent_location: bool, ) -> Self { Self { - thread, project, action_log, + update_agent_location, } } } @@ -119,7 +119,6 @@ impl AgentTool for ReadFileTool { cx: &mut App, ) -> Task> { let project = self.project.clone(); - let thread = self.thread.clone(); let action_log = self.action_log.clone(); cx.spawn(async move |cx| { let input = input @@ -257,20 +256,6 @@ impl AgentTool for ReadFileTool { return Err(tool_content_err(format!("{file_path} not found"))); } - // Record the file read time and mtime - if let Some(mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - thread - .update(cx, |thread, _| { - thread.file_read_times.insert(abs_path.to_path_buf(), mtime); - }) - .ok(); - } - - - let update_agent_location = self.thread.read_with(cx, |thread, _cx| !thread.is_subagent()).unwrap_or_default(); - let mut anchor = None; // Check if specific line ranges are provided @@ -330,7 +315,7 @@ impl AgentTool for ReadFileTool { }; project.update(cx, |project, cx| { - if update_agent_location { + if self.update_agent_location { project.set_agent_location( Some(AgentLocation { buffer: buffer.downgrade(), @@ -362,13 +347,10 @@ impl AgentTool for ReadFileTool { #[cfg(test)] mod test { use super::*; - use crate::{ContextServerRegistry, Templates, Thread}; use agent_client_protocol as acp; use fs::Fs as _; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; - use language_model::fake_provider::FakeLanguageModel; use project::{FakeFs, Project}; - use prompt_store::ProjectContext; use serde_json::json; use settings::SettingsStore; use std::path::PathBuf; @@ -383,20 +365,7 @@ mod test { fs.insert_tree(path!("/root"), json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, _) = ToolCallEventStream::test(); let result = cx @@ -429,20 +398,7 @@ mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -476,20 +432,7 @@ mod test { let language_registry = project.read_with(cx, |project, _| project.languages().clone()); language_registry.add(language::rust_lang()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -569,20 +512,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -614,20 +544,7 @@ mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); // start_line of 0 should be treated as 1 let result = cx @@ -757,20 +674,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); // Reading a file outside the project worktree should fail let result = cx @@ -965,20 +869,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let read_task = cx.update(|cx| { @@ -1084,24 +975,7 @@ mod test { .await; let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log.clone(), - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone(), true)); // Test reading allowed files in worktree1 let result = cx @@ -1288,24 +1162,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { @@ -1364,24 +1221,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { @@ -1444,24 +1284,7 @@ mod test { cx.executor().run_until_parked(); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true)); let (event_stream, mut event_rx) = ToolCallEventStream::test(); let result = cx diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index 7e023d7d7e00c2eb13ea78467776816b13151796..62b96d569f34d65889abee6be803674dfa42e709 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -483,7 +483,12 @@ impl EditSession { .await .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; - ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; + let action_log = tool + .thread + .read_with(cx, |thread, _cx| thread.action_log().clone()) + .ok(); + + ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?; let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); event_stream.update_diff(diff.clone()); @@ -495,13 +500,9 @@ impl EditSession { } }) as Box); - tool.thread - .update(cx, |thread, cx| { - thread - .action_log() - .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)) - }) - .ok(); + if let Some(action_log) = &action_log { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + } let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let old_text = cx @@ -637,18 +638,6 @@ impl EditSession { log.buffer_edited(buffer.clone(), cx); }); - if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - tool.thread - .update(cx, |thread, _| { - thread - .file_read_times - .insert(abs_path.to_path_buf(), new_mtime); - }) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; - } - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ @@ -1018,10 +1007,12 @@ fn ensure_buffer_saved( buffer: &Entity, abs_path: &PathBuf, tool: &StreamingEditFileTool, + action_log: Option<&Entity>, cx: &mut AsyncApp, ) -> Result<(), StreamingEditFileToolOutput> { - let check_result = tool.thread.update(cx, |thread, cx| { - let last_read = thread.file_read_times.get(abs_path).copied(); + let last_read_mtime = + action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path))); + let check_result = tool.thread.read_with(cx, |thread, cx| { let current = buffer .read(cx) .file() @@ -1029,12 +1020,10 @@ fn ensure_buffer_saved( let dirty = buffer.read(cx).is_dirty(); let has_save = thread.has_tool(SaveFileTool::NAME); let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (last_read, current, dirty, has_save, has_restore) + (current, dirty, has_save, has_restore) }); - let Ok((last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool)) = - check_result - else { + let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else { return Ok(()); }; @@ -4006,11 +3995,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), @@ -4112,11 +4097,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), @@ -4225,11 +4206,7 @@ mod tests { let languages = project.read_with(cx, |project, _| project.languages().clone()); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let read_tool = Arc::new(crate::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); let edit_tool = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 778f7292d2a032df6995169852deeecee6fa76a7..9b9fe9948ace530d7e55d2843952ca5c9efb3749 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,15 +2,12 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. use crate::headless_project::HeadlessProject; -use agent::{ - AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput, -}; +use agent::{AgentTool, ReadFileTool, ReadFileToolInput, ToolCallEventStream, ToolInput}; use client::{Client, UserStore}; use clock::FakeSystemClock; use collections::{HashMap, HashSet}; use git::repository::DiffType; -use language_model::{LanguageModelToolResultContent, fake_provider::FakeLanguageModel}; -use prompt_store::ProjectContext; +use language_model::LanguageModelToolResultContent; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs}; @@ -2065,27 +2062,12 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); - // Create a minimal thread for the ReadFileTool - let context_server_registry = - cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let input = ReadFileToolInput { path: "project/b.txt".into(), start_line: None, end_line: None, }; - let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log)); + let read_tool = Arc::new(ReadFileTool::new(project, action_log, true)); let (event_stream, _) = ToolCallEventStream::test(); let exists_result = cx.update(|cx| { diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index df673f0b4869af8fa55b0e83af10553df8afb4d8..8f005fa68b6accb5cf5686157bbb065e33bb1b0c 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -2032,32 +2032,9 @@ fn run_agent_thread_view_test( // Create the necessary entities for the ReadFileTool let action_log = cx.update(|cx| cx.new(|_| action_log::ActionLog::new(project.clone()))); - let context_server_registry = cx.update(|cx| { - cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx)) - }); - let fake_model = Arc::new(language_model::fake_provider::FakeLanguageModel::default()); - let project_context = cx.update(|cx| cx.new(|_| prompt_store::ProjectContext::default())); - - // Create the agent Thread - let thread = cx.update(|cx| { - cx.new(|cx| { - agent::Thread::new( - project.clone(), - project_context, - context_server_registry, - agent::Templates::new(), - Some(fake_model), - cx, - ) - }) - }); // Create the ReadFileTool - let tool = Arc::new(agent::ReadFileTool::new( - thread.downgrade(), - project.clone(), - action_log, - )); + let tool = Arc::new(agent::ReadFileTool::new(project.clone(), action_log, true)); // Create a test event stream to capture tool output let (event_stream, mut event_receiver) = agent::ToolCallEventStream::test();