Improve distinguishing user from agent edits (#34716)

Oleksiy Syvokon created

We no longer rely on the `author` field to tell if a change was made by
the user or the agent. The `author` can be set to `User` in many
situations that are not really user-made edits, such as saving a file,
accepting a change, auto-formatting, and more. I started tracking and
fixing some of these cases, but found that inspecting changes in
`diff_base` is a more reliable method.

Also, we no longer show empty diffs. For example, if the user adds a
line and then removes the same line, the final diff is empty, even
though the buffer is marked as user-changed. Now we won't show such
edit.

There are still some issues to address:

- When a user edits within an unaccepted agent-written block, this
change becomes a part of the agent's edit. Rejecting this block will
lose user edits. It won't be displayed in project notifications, either.

- Accepting an agent block counts as a user-made edit.

- Agent start to call `project_notifications` tool after seeing enough
auto-calls.

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs                   | 20 +++----
crates/assistant_tool/src/action_log.rs      | 56 +++++++++++----------
crates/assistant_tools/src/edit_file_tool.rs |  3 +
3 files changed, 42 insertions(+), 37 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -47,7 +47,7 @@ use std::{
     time::{Duration, Instant},
 };
 use thiserror::Error;
-use util::{ResultExt as _, debug_panic, post_inc};
+use util::{ResultExt as _, post_inc};
 use uuid::Uuid;
 use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
 
@@ -1582,23 +1582,21 @@ impl Thread {
         model: Arc<dyn LanguageModel>,
         cx: &mut App,
     ) -> Option<PendingToolUse> {
-        let action_log = self.action_log.read(cx);
-
-        if !action_log.has_unnotified_user_edits() {
-            return None;
-        }
-
         // Represent notification as a simulated `project_notifications` tool call
         let tool_name = Arc::from("project_notifications");
-        let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
-            debug_panic!("`project_notifications` tool not found");
-            return None;
-        };
+        let tool = self.tools.read(cx).tool(&tool_name, cx)?;
 
         if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
             return None;
         }
 
+        if self
+            .action_log
+            .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none())
+        {
+            return None;
+        }
+
         let input = serde_json::json!({});
         let request = Arc::new(LanguageModelRequest::default()); // unused
         let window = None;

crates/assistant_tool/src/action_log.rs 🔗

@@ -51,23 +51,13 @@ impl ActionLog {
         Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
     }
 
-    pub fn has_unnotified_user_edits(&self) -> bool {
-        self.tracked_buffers
-            .values()
-            .any(|tracked| tracked.has_unnotified_user_edits)
-    }
-
     /// Return a unified diff patch with user edits made since last read or notification
     pub fn unnotified_user_edits(&self, cx: &Context<Self>) -> Option<String> {
-        if !self.has_unnotified_user_edits() {
-            return None;
-        }
-
-        let unified_diff = self
+        let diffs = self
             .tracked_buffers
             .values()
             .filter_map(|tracked| {
-                if !tracked.has_unnotified_user_edits {
+                if !tracked.may_have_unnotified_user_edits {
                     return None;
                 }
 
@@ -95,9 +85,13 @@ impl ActionLog {
 
                 Some(result)
             })
-            .collect::<Vec<_>>()
-            .join("\n\n");
+            .collect::<Vec<_>>();
+
+        if diffs.is_empty() {
+            return None;
+        }
 
+        let unified_diff = diffs.join("\n\n");
         Some(unified_diff)
     }
 
@@ -106,7 +100,7 @@ impl ActionLog {
     pub fn flush_unnotified_user_edits(&mut self, cx: &Context<Self>) -> Option<String> {
         let patch = self.unnotified_user_edits(cx);
         self.tracked_buffers.values_mut().for_each(|tracked| {
-            tracked.has_unnotified_user_edits = false;
+            tracked.may_have_unnotified_user_edits = false;
             tracked.last_seen_base = tracked.diff_base.clone();
         });
         patch
@@ -185,7 +179,7 @@ impl ActionLog {
                     version: buffer.read(cx).version(),
                     diff,
                     diff_update: diff_update_tx,
-                    has_unnotified_user_edits: false,
+                    may_have_unnotified_user_edits: false,
                     _open_lsp_handle: open_lsp_handle,
                     _maintain_diff: cx.spawn({
                         let buffer = buffer.clone();
@@ -337,27 +331,34 @@ impl ActionLog {
                 let new_snapshot = buffer_snapshot.clone();
                 let unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
                 let edits = diff_snapshots(&old_snapshot, &new_snapshot);
-                if let ChangeAuthor::User = author
-                    && !edits.is_empty()
-                {
-                    tracked_buffer.has_unnotified_user_edits = true;
-                }
+                let mut has_user_changes = false;
                 async move {
                     if let ChangeAuthor::User = author {
-                        apply_non_conflicting_edits(
+                        has_user_changes = apply_non_conflicting_edits(
                             &unreviewed_edits,
                             edits,
                             &mut base_text,
                             new_snapshot.as_rope(),
                         );
                     }
-                    (Arc::new(base_text.to_string()), base_text)
+
+                    (Arc::new(base_text.to_string()), base_text, has_user_changes)
                 }
             });
 
             anyhow::Ok(rebase)
         })??;
-        let (new_base_text, new_diff_base) = rebase.await;
+        let (new_base_text, new_diff_base, has_user_changes) = rebase.await;
+
+        this.update(cx, |this, _| {
+            let tracked_buffer = this
+                .tracked_buffers
+                .get_mut(buffer)
+                .context("buffer not tracked")
+                .unwrap();
+            tracked_buffer.may_have_unnotified_user_edits |= has_user_changes;
+        })?;
+
         Self::update_diff(
             this,
             buffer,
@@ -829,11 +830,12 @@ fn apply_non_conflicting_edits(
     edits: Vec<Edit<u32>>,
     old_text: &mut Rope,
     new_text: &Rope,
-) {
+) -> bool {
     let mut old_edits = patch.edits().iter().cloned().peekable();
     let mut new_edits = edits.into_iter().peekable();
     let mut applied_delta = 0i32;
     let mut rebased_delta = 0i32;
+    let mut has_made_changes = false;
 
     while let Some(mut new_edit) = new_edits.next() {
         let mut conflict = false;
@@ -883,8 +885,10 @@ fn apply_non_conflicting_edits(
                 &new_text.chunks_in_range(new_bytes).collect::<String>(),
             );
             applied_delta += new_edit.new_len() as i32 - new_edit.old_len() as i32;
+            has_made_changes = true;
         }
     }
+    has_made_changes
 }
 
 fn diff_snapshots(
@@ -958,7 +962,7 @@ struct TrackedBuffer {
     diff: Entity<BufferDiff>,
     snapshot: text::BufferSnapshot,
     diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
-    has_unnotified_user_edits: bool,
+    may_have_unnotified_user_edits: bool,
     _open_lsp_handle: OpenLspBufferHandle,
     _maintain_diff: Task<()>,
     _subscription: Subscription,

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -278,6 +278,9 @@ impl Tool for EditFileTool {
                 .unwrap_or(false);
 
             if format_on_save_enabled {
+                action_log.update(cx, |log, cx| {
+                    log.buffer_edited(buffer.clone(), cx);
+                })?;
                 let format_task = project.update(cx, |project, cx| {
                     project.format(
                         HashSet::from_iter([buffer.clone()]),