From a05066cd834da5cebe021fa85f8bd2d23228a70e Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 17 Mar 2025 18:50:16 -0300 Subject: [PATCH] assistant edit tool: Track read buffers and notify model of user edits (#26952) When the model reads file, we'll track the version it read, and let it know if the user makes edits to the buffer. This helps prevent edit failures because it'll know to re-read the file before. Release Notes: - N/A --- Cargo.lock | 1 + crates/assistant2/src/active_thread.rs | 2 +- crates/assistant2/src/thread.rs | 31 +++++++++++ crates/assistant_tool/Cargo.toml | 1 + crates/assistant_tool/src/assistant_tool.rs | 52 +++++++++++++------ crates/assistant_tools/src/edit_files_tool.rs | 4 +- crates/assistant_tools/src/read_file_tool.rs | 15 ++++-- 7 files changed, 83 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 36182efd3dfbafa305d85c7bf931b992cf1ea914..6b0b113fdd9a94f31c132d9824ef6b63c71359ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -692,6 +692,7 @@ name = "assistant_tool" version = "0.1.0" dependencies = [ "anyhow", + "clock", "collections", "derive_more", "gpui", diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 3ab8e0fabf1e8224326dbf982c431cc26c6e01d2..a9797166f30ea84da5a6a4d8d86d6f3b7b92eb59 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -361,7 +361,7 @@ impl ActiveThread { if self.thread.read(cx).all_tools_finished() { let pending_refresh_buffers = self.thread.update(cx, |thread, cx| { thread.action_log().update(cx, |action_log, _cx| { - action_log.take_pending_refresh_buffers() + action_log.take_stale_buffers_in_context() }) }); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index ddd6732446e86bb3436d0b0f47a5852d436165dd..9e97a9d7c41282cec68cfdd7f8f398af699cfa82 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,3 +1,4 @@ +use std::fmt::Write as _; use std::io::Write; use std::sync::Arc; @@ -560,9 +561,39 @@ impl Thread { request.messages.push(context_message); } + self.attach_stale_files(&mut request.messages, cx); + request } + fn attach_stale_files(&self, messages: &mut Vec, cx: &App) { + const STALE_FILES_HEADER: &str = "These files changed since last read:"; + + let mut stale_message = String::new(); + + for stale_file in self.action_log.read(cx).stale_buffers(cx) { + let Some(file) = stale_file.read(cx).file() else { + continue; + }; + + if stale_message.is_empty() { + write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok(); + } + + writeln!(&mut stale_message, "- {}", file.path().display()).ok(); + } + + if !stale_message.is_empty() { + let context_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![stale_message.into()], + cache: false, + }; + + messages.push(context_message); + } + } + pub fn stream_completion( &mut self, request: LanguageModelRequest, diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index 70022fc02ca618709c010000153d00c8b580bddc..040a906bf3cbcae70389c075a9048aa0273e319a 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -14,6 +14,7 @@ path = "src/assistant_tool.rs" [dependencies] anyhow.workspace = true collections.workspace = true +clock.workspace = true derive_more.workspace = true gpui.workspace = true language.workspace = true diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index b466931d89ab116bd2ca3163f8798f6617ee2b16..22564bc37f7bdc918d8947c69bacfc0645b9177e 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -4,7 +4,7 @@ mod tool_working_set; use std::sync::Arc; use anyhow::Result; -use collections::HashSet; +use collections::{HashMap, HashSet}; use gpui::Context; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; @@ -58,31 +58,53 @@ pub trait Tool: 'static + Send + Sync { /// Tracks actions performed by tools in a thread #[derive(Debug)] pub struct ActionLog { - changed_buffers: HashSet>, - pending_refresh: HashSet>, + /// Buffers that user manually added to the context, and whose content has + /// changed since the model last saw them. + stale_buffers_in_context: HashSet>, + /// Buffers that we want to notify the model about when they change. + tracked_buffers: HashMap, TrackedBuffer>, +} + +#[derive(Debug, Default)] +struct TrackedBuffer { + version: clock::Global, } impl ActionLog { /// Creates a new, empty action log. pub fn new() -> Self { Self { - changed_buffers: HashSet::default(), - pending_refresh: HashSet::default(), + stale_buffers_in_context: HashSet::default(), + tracked_buffers: HashMap::default(), + } + } + + /// Track a buffer as read, so we can notify the model about user edits. + pub fn buffer_read(&mut self, buffer: Entity, cx: &mut Context) { + let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); + tracked_buffer.version = buffer.read(cx).version(); + } + + /// Mark a buffer as edited, so we can refresh it in the context + pub fn buffer_edited(&mut self, buffers: HashSet>, cx: &mut Context) { + for buffer in &buffers { + let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); + tracked_buffer.version = buffer.read(cx).version(); } + + self.stale_buffers_in_context.extend(buffers); } - /// Registers buffers that have changed and need refreshing. - pub fn notify_buffers_changed( - &mut self, - buffers: HashSet>, - _cx: &mut Context, - ) { - self.changed_buffers.extend(buffers.clone()); - self.pending_refresh.extend(buffers); + /// Iterate over buffers changed since last read or edited by the model + pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator> { + self.tracked_buffers + .iter() + .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version) + .map(|(buffer, _)| buffer) } /// Takes and returns the set of buffers pending refresh, clearing internal state. - pub fn take_pending_refresh_buffers(&mut self) -> HashSet> { - std::mem::take(&mut self.pending_refresh) + pub fn take_stale_buffers_in_context(&mut self) -> HashSet> { + std::mem::take(&mut self.stale_buffers_in_context) } } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 5cc81eda37fc80da525f72b427d1a82130007bf9..10a2454c3d8d0475141013a1d0bcee9dcbe6b2de 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -309,9 +309,7 @@ impl EditToolRequest { } self.action_log - .update(cx, |log, cx| { - log.notify_buffers_changed(self.changed_buffers, cx) - }) + .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx)) .log_err(); let errors = self.parser.errors(); diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index d4c69df8e552909c11f0b5fa7d14a094dce70218..e1f012be44fa23b56d57649661c611c4e84b983c 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -49,7 +49,7 @@ impl Tool for ReadFileTool { input: serde_json::Value, _messages: &[LanguageModelRequestMessage], project: Entity, - _action_log: Entity, + action_log: Entity, cx: &mut App, ) -> Task> { let input = match serde_json::from_value::(input) { @@ -60,14 +60,15 @@ impl Tool for ReadFileTool { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path not found in project"))); }; - cx.spawn(|cx| async move { + + cx.spawn(|mut cx| async move { let buffer = cx .update(|cx| { project.update(cx, |project, cx| project.open_buffer(project_path, cx)) })? .await?; - buffer.read_with(&cx, |buffer, _cx| { + let result = buffer.read_with(&cx, |buffer, _cx| { if buffer .file() .map_or(false, |file| file.disk_state().exists()) @@ -76,7 +77,13 @@ impl Tool for ReadFileTool { } else { Err(anyhow!("File does not exist")) } - })? + })??; + + action_log.update(&mut cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + anyhow::Ok(result) }) } }