agent: Push diffs of user edits to the agent (#34487)

Oleksiy Syvokon created

This change improves user/agent collaborative editing.

When the user edits files that are used by the agent, the
`project_notification` tool now pushes *diffs* of the changes, not just
file names. This helps the agent to stay up to date without needing to
re-read files.

Release Notes:

- Improved user/agent collaborative editing: agent now receives diffs of
user edits

Change summary

Cargo.lock                                               |   1 
crates/agent/src/thread.rs                               |  17 
crates/assistant_tool/Cargo.toml                         |   1 
crates/assistant_tool/src/action_log.rs                  | 282 ++++++++-
crates/assistant_tools/src/project_notifications_tool.rs |  50 
5 files changed, 274 insertions(+), 77 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -745,6 +745,7 @@ dependencies = [
  "futures 0.3.31",
  "gpui",
  "icons",
+ "indoc",
  "language",
  "language_model",
  "log",

crates/agent/src/thread.rs 🔗

@@ -1532,7 +1532,9 @@ impl Thread {
     ) -> Option<PendingToolUse> {
         let action_log = self.action_log.read(cx);
 
-        action_log.unnotified_stale_buffers(cx).next()?;
+        if !action_log.has_unnotified_user_edits() {
+            return None;
+        }
 
         // Represent notification as a simulated `project_notifications` tool call
         let tool_name = Arc::from("project_notifications");
@@ -3253,7 +3255,6 @@ mod tests {
     use futures::stream::BoxStream;
     use gpui::TestAppContext;
     use http_client;
-    use indoc::indoc;
     use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
     use language_model::{
         LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
@@ -3614,6 +3615,7 @@ fn main() {{
                 cx,
             );
         });
+        cx.run_until_parked();
 
         // We shouldn't have a stale buffer notification yet
         let notifications = thread.read_with(cx, |thread, _| {
@@ -3643,11 +3645,13 @@ fn main() {{
                 cx,
             )
         });
+        cx.run_until_parked();
 
         // Check for the stale buffer warning
         thread.update(cx, |thread, cx| {
             thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
         });
+        cx.run_until_parked();
 
         let notifications = thread.read_with(cx, |thread, _cx| {
             find_tool_uses(thread, "project_notifications")
@@ -3661,12 +3665,8 @@ fn main() {{
             panic!("`project_notifications` should return text");
         };
 
-        let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
-
-        These files have changed since the last read:
-        - code.rs
-        "};
-        assert_eq!(notification_content, expected_content);
+        assert!(notification_content.contains("These files have changed since the last read:"));
+        assert!(notification_content.contains("code.rs"));
 
         // Insert another user message and flush notifications again
         thread.update(cx, |thread, cx| {
@@ -3682,6 +3682,7 @@ fn main() {{
         thread.update(cx, |thread, cx| {
             thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
         });
+        cx.run_until_parked();
 
         // There should be no new notifications (we already flushed one)
         let notifications = thread.read_with(cx, |thread, _cx| {

crates/assistant_tool/Cargo.toml 🔗

@@ -40,6 +40,7 @@ collections = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
 language_model = { workspace = true, features = ["test-support"] }
 log.workspace = true

crates/assistant_tool/src/action_log.rs 🔗

@@ -8,7 +8,10 @@ use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
 use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
 use std::{cmp, ops::Range, sync::Arc};
 use text::{Edit, Patch, Rope};
-use util::{RangeExt, ResultExt as _};
+use util::{
+    RangeExt, ResultExt as _,
+    paths::{PathStyle, RemotePathBuf},
+};
 
 /// Tracks actions performed by tools in a thread
 pub struct ActionLog {
@@ -18,8 +21,6 @@ pub struct ActionLog {
     edited_since_project_diagnostics_check: bool,
     /// The project this action log is associated with
     project: Entity<Project>,
-    /// Tracks which buffer versions have already been notified as changed externally
-    notified_versions: BTreeMap<Entity<Buffer>, clock::Global>,
 }
 
 impl ActionLog {
@@ -29,7 +30,6 @@ impl ActionLog {
             tracked_buffers: BTreeMap::default(),
             edited_since_project_diagnostics_check: false,
             project,
-            notified_versions: BTreeMap::default(),
         }
     }
 
@@ -51,6 +51,67 @@ impl ActionLog {
         Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
     }
 
+    pub fn has_unnotified_user_edits(&self) -> bool {
+        self.tracked_buffers
+            .values()
+            .any(|tracked| tracked.has_unnotified_user_edits)
+    }
+
+    /// Return a unified diff patch with user edits made since last read or notification
+    pub fn unnotified_user_edits(&self, cx: &Context<Self>) -> Option<String> {
+        if !self.has_unnotified_user_edits() {
+            return None;
+        }
+
+        let unified_diff = self
+            .tracked_buffers
+            .values()
+            .filter_map(|tracked| {
+                if !tracked.has_unnotified_user_edits {
+                    return None;
+                }
+
+                let text_with_latest_user_edits = tracked.diff_base.to_string();
+                let text_with_last_seen_user_edits = tracked.last_seen_base.to_string();
+                if text_with_latest_user_edits == text_with_last_seen_user_edits {
+                    return None;
+                }
+                let patch = language::unified_diff(
+                    &text_with_last_seen_user_edits,
+                    &text_with_latest_user_edits,
+                );
+
+                let buffer = tracked.buffer.clone();
+                let file_path = buffer
+                    .read(cx)
+                    .file()
+                    .map(|file| RemotePathBuf::new(file.full_path(cx), PathStyle::Posix).to_proto())
+                    .unwrap_or_else(|| format!("buffer_{}", buffer.entity_id()));
+
+                let mut result = String::new();
+                result.push_str(&format!("--- a/{}\n", file_path));
+                result.push_str(&format!("+++ b/{}\n", file_path));
+                result.push_str(&patch);
+
+                Some(result)
+            })
+            .collect::<Vec<_>>()
+            .join("\n\n");
+
+        Some(unified_diff)
+    }
+
+    /// Return a unified diff patch with user edits made since last read/notification
+    /// and mark them as notified
+    pub fn flush_unnotified_user_edits(&mut self, cx: &Context<Self>) -> Option<String> {
+        let patch = self.unnotified_user_edits(cx);
+        self.tracked_buffers.values_mut().for_each(|tracked| {
+            tracked.has_unnotified_user_edits = false;
+            tracked.last_seen_base = tracked.diff_base.clone();
+        });
+        patch
+    }
+
     fn track_buffer_internal(
         &mut self,
         buffer: Entity<Buffer>,
@@ -59,7 +120,6 @@ impl ActionLog {
     ) -> &mut TrackedBuffer {
         let status = if is_created {
             if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
-                self.notified_versions.remove(&buffer);
                 match tracked.status {
                     TrackedBufferStatus::Created {
                         existing_file_content,
@@ -101,26 +161,31 @@ impl ActionLog {
                 let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
                 let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
                 let diff_base;
+                let last_seen_base;
                 let unreviewed_edits;
                 if is_created {
                     diff_base = Rope::default();
+                    last_seen_base = Rope::default();
                     unreviewed_edits = Patch::new(vec![Edit {
                         old: 0..1,
                         new: 0..text_snapshot.max_point().row + 1,
                     }])
                 } else {
                     diff_base = buffer.read(cx).as_rope().clone();
+                    last_seen_base = diff_base.clone();
                     unreviewed_edits = Patch::default();
                 }
                 TrackedBuffer {
                     buffer: buffer.clone(),
                     diff_base,
+                    last_seen_base,
                     unreviewed_edits,
                     snapshot: text_snapshot.clone(),
                     status,
                     version: buffer.read(cx).version(),
                     diff,
                     diff_update: diff_update_tx,
+                    has_unnotified_user_edits: false,
                     _open_lsp_handle: open_lsp_handle,
                     _maintain_diff: cx.spawn({
                         let buffer = buffer.clone();
@@ -174,7 +239,6 @@ impl ActionLog {
                     // If the buffer had been edited by a tool, but it got
                     // deleted externally, we want to stop tracking it.
                     self.tracked_buffers.remove(&buffer);
-                    self.notified_versions.remove(&buffer);
                 }
                 cx.notify();
             }
@@ -188,7 +252,6 @@ impl ActionLog {
                     // resurrected externally, we want to clear the edits we
                     // were tracking and reset the buffer's state.
                     self.tracked_buffers.remove(&buffer);
-                    self.notified_versions.remove(&buffer);
                     self.track_buffer_internal(buffer, false, cx);
                 }
                 cx.notify();
@@ -262,19 +325,23 @@ impl ActionLog {
         buffer_snapshot: text::BufferSnapshot,
         cx: &mut AsyncApp,
     ) -> Result<()> {
-        let rebase = this.read_with(cx, |this, cx| {
+        let rebase = this.update(cx, |this, cx| {
             let tracked_buffer = this
                 .tracked_buffers
-                .get(buffer)
+                .get_mut(buffer)
                 .context("buffer not tracked")?;
 
+            if let ChangeAuthor::User = author {
+                tracked_buffer.has_unnotified_user_edits = true;
+            }
+
             let rebase = cx.background_spawn({
                 let mut base_text = tracked_buffer.diff_base.clone();
                 let old_snapshot = tracked_buffer.snapshot.clone();
                 let new_snapshot = buffer_snapshot.clone();
                 let unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
+                let edits = diff_snapshots(&old_snapshot, &new_snapshot);
                 async move {
-                    let edits = diff_snapshots(&old_snapshot, &new_snapshot);
                     if let ChangeAuthor::User = author {
                         apply_non_conflicting_edits(
                             &unreviewed_edits,
@@ -494,7 +561,6 @@ impl ActionLog {
         match tracked_buffer.status {
             TrackedBufferStatus::Created { .. } => {
                 self.tracked_buffers.remove(&buffer);
-                self.notified_versions.remove(&buffer);
                 cx.notify();
             }
             TrackedBufferStatus::Modified => {
@@ -520,7 +586,6 @@ impl ActionLog {
         match tracked_buffer.status {
             TrackedBufferStatus::Deleted => {
                 self.tracked_buffers.remove(&buffer);
-                self.notified_versions.remove(&buffer);
                 cx.notify();
             }
             _ => {
@@ -629,7 +694,6 @@ impl ActionLog {
                 };
 
                 self.tracked_buffers.remove(&buffer);
-                self.notified_versions.remove(&buffer);
                 cx.notify();
                 task
             }
@@ -643,7 +707,6 @@ impl ActionLog {
 
                 // Clear all tracked edits for this buffer and start over as if we just read it.
                 self.tracked_buffers.remove(&buffer);
-                self.notified_versions.remove(&buffer);
                 self.buffer_read(buffer.clone(), cx);
                 cx.notify();
                 save
@@ -744,33 +807,6 @@ impl ActionLog {
             .collect()
     }
 
-    /// Returns stale buffers that haven't been notified yet
-    pub fn unnotified_stale_buffers<'a>(
-        &'a self,
-        cx: &'a App,
-    ) -> impl Iterator<Item = &'a Entity<Buffer>> {
-        self.stale_buffers(cx).filter(|buffer| {
-            let buffer_entity = buffer.read(cx);
-            self.notified_versions
-                .get(buffer)
-                .map_or(true, |notified_version| {
-                    *notified_version != buffer_entity.version
-                })
-        })
-    }
-
-    /// Marks the given buffers as notified at their current versions
-    pub fn mark_buffers_as_notified(
-        &mut self,
-        buffers: impl IntoIterator<Item = Entity<Buffer>>,
-        cx: &App,
-    ) {
-        for buffer in buffers {
-            let version = buffer.read(cx).version.clone();
-            self.notified_versions.insert(buffer, version);
-        }
-    }
-
     /// Iterate over buffers changed since last read or edited by the model
     pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
         self.tracked_buffers
@@ -914,12 +950,14 @@ enum TrackedBufferStatus {
 struct TrackedBuffer {
     buffer: Entity<Buffer>,
     diff_base: Rope,
+    last_seen_base: Rope,
     unreviewed_edits: Patch<u32>,
     status: TrackedBufferStatus,
     version: clock::Global,
     diff: Entity<BufferDiff>,
     snapshot: text::BufferSnapshot,
     diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
+    has_unnotified_user_edits: bool,
     _open_lsp_handle: OpenLspBufferHandle,
     _maintain_diff: Task<()>,
     _subscription: Subscription,
@@ -950,6 +988,7 @@ mod tests {
     use super::*;
     use buffer_diff::DiffHunkStatusKind;
     use gpui::TestAppContext;
+    use indoc::indoc;
     use language::Point;
     use project::{FakeFs, Fs, Project, RemoveOptions};
     use rand::prelude::*;
@@ -1232,6 +1271,110 @@ mod tests {
         assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_user_edits_notifications(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/dir"),
+            json!({"file": indoc! {"
+            abc
+            def
+            ghi
+            jkl
+            mno"}}),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        // Agent edits
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
+                    .unwrap()
+            });
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            indoc! {"
+                abc
+                deF
+                GHI
+                jkl
+                mno"}
+        );
+        assert_eq!(
+            unreviewed_hunks(&action_log, cx),
+            vec![(
+                buffer.clone(),
+                vec![HunkStatus {
+                    range: Point::new(1, 0)..Point::new(3, 0),
+                    diff_status: DiffHunkStatusKind::Modified,
+                    old_text: "def\nghi\n".into(),
+                }],
+            )]
+        );
+
+        // User edits
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit(
+                [
+                    (Point::new(0, 2)..Point::new(0, 2), "X"),
+                    (Point::new(3, 0)..Point::new(3, 0), "Y"),
+                ],
+                None,
+                cx,
+            )
+        });
+        cx.run_until_parked();
+        assert_eq!(
+            buffer.read_with(cx, |buffer, _| buffer.text()),
+            indoc! {"
+                abXc
+                deF
+                GHI
+                Yjkl
+                mno"}
+        );
+
+        // User edits should be stored separately from agent's
+        let user_edits = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx));
+        assert_eq!(
+            user_edits.expect("should have some user edits"),
+            indoc! {"
+                --- a/dir/file
+                +++ b/dir/file
+                @@ -1,5 +1,5 @@
+                -abc
+                +abXc
+                 def
+                 ghi
+                -jkl
+                +Yjkl
+                 mno
+            "}
+        );
+
+        action_log.update(cx, |log, cx| {
+            log.keep_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
+        });
+        cx.run_until_parked();
+        assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_creating_files(cx: &mut TestAppContext) {
         init_test(cx);
@@ -2221,4 +2364,61 @@ mod tests {
                 .collect()
         })
     }
+
+    #[gpui::test]
+    async fn test_format_patch(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/dir"),
+            json!({"test.txt": "line 1\nline 2\nline 3\n"}),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+        let file_path = project
+            .read_with(cx, |project, cx| {
+                project.find_project_path("dir/test.txt", cx)
+            })
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            // Track the buffer and mark it as read first
+            action_log.update(cx, |log, cx| {
+                log.buffer_read(buffer.clone(), cx);
+            });
+
+            // Make some edits to create a patch
+            buffer.update(cx, |buffer, cx| {
+                buffer
+                    .edit([(Point::new(1, 0)..Point::new(1, 6), "CHANGED")], None, cx)
+                    .unwrap(); // Replace "line2" with "CHANGED"
+            });
+        });
+
+        cx.run_until_parked();
+
+        // Get the patch
+        let patch = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx));
+
+        // Verify the patch format contains expected unified diff elements
+        assert_eq!(
+            patch.unwrap(),
+            indoc! {"
+            --- a/dir/test.txt
+            +++ b/dir/test.txt
+            @@ -1,3 +1,3 @@
+             line 1
+            -line 2
+            +CHANGED
+             line 3
+            "}
+        );
+    }
 }

crates/assistant_tools/src/project_notifications_tool.rs 🔗

@@ -6,7 +6,6 @@ use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchem
 use project::Project;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::fmt::Write as _;
 use std::sync::Arc;
 use ui::IconName;
 
@@ -52,34 +51,22 @@ impl Tool for ProjectNotificationsTool {
         _window: Option<AnyWindowHandle>,
         cx: &mut App,
     ) -> ToolResult {
-        let mut stale_files = String::new();
-        let mut notified_buffers = Vec::new();
-
-        for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) {
-            if let Some(file) = stale_file.read(cx).file() {
-                writeln!(&mut stale_files, "- {}", file.path().display()).ok();
-                notified_buffers.push(stale_file.clone());
-            }
-        }
-
-        if !notified_buffers.is_empty() {
-            action_log.update(cx, |log, cx| {
-                log.mark_buffers_as_notified(notified_buffers, cx);
-            });
-        }
-
-        let response = if stale_files.is_empty() {
-            "No new notifications".to_string()
-        } else {
-            // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
-            const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
-            format!("{HEADER}{stale_files}").replace("\r\n", "\n")
+        let Some(user_edits_diff) =
+            action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx))
+        else {
+            return result("No new notifications");
         };
 
-        Task::ready(Ok(response.into())).into()
+        // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
+        const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
+        result(&format!("{HEADER}\n\n```diff\n{user_edits_diff}\n```\n").replace("\r\n", "\n"))
     }
 }
 
+fn result(response: &str) -> ToolResult {
+    Task::ready(Ok(response.to_string().into())).into()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -123,6 +110,7 @@ mod tests {
         action_log.update(cx, |log, cx| {
             log.buffer_read(buffer.clone(), cx);
         });
+        cx.run_until_parked();
 
         // Run the tool before any changes
         let tool = Arc::new(ProjectNotificationsTool);
@@ -142,6 +130,7 @@ mod tests {
                 cx,
             )
         });
+        cx.run_until_parked();
 
         let response = result.output.await.unwrap();
         let response_text = match &response.content {
@@ -158,6 +147,7 @@ mod tests {
         buffer.update(cx, |buffer, cx| {
             buffer.edit([(1..1, "\nChange!\n")], None, cx);
         });
+        cx.run_until_parked();
 
         // Run the tool again
         let result = cx.update(|cx| {
@@ -171,6 +161,7 @@ mod tests {
                 cx,
             )
         });
+        cx.run_until_parked();
 
         // This time the buffer is stale, so the tool should return a notification
         let response = result.output.await.unwrap();
@@ -179,10 +170,12 @@ mod tests {
             _ => panic!("Expected text response"),
         };
 
-        let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
-        assert_eq!(
-            response_text.as_str(),
-            expected_content,
+        assert!(
+            response_text.contains("These files have changed"),
+            "Tool should return the stale buffer notification"
+        );
+        assert!(
+            response_text.contains("test/code.rs"),
             "Tool should return the stale buffer notification"
         );
 
@@ -198,6 +191,7 @@ mod tests {
                 cx,
             )
         });
+        cx.run_until_parked();
 
         let response = result.output.await.unwrap();
         let response_text = match &response.content {