tools: Send stale file notifications only once (#34026)

Oleksiy Syvokon created

Previously, we sent notifications repeatedly until the agent read a
file, which was often inefficient. With this change, we now send a
notification only once (unless the files are modified again, in which
case we'll send another notification).

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs                               | 55 +++++++--
crates/assistant_tool/src/action_log.rs                  | 40 +++++++
crates/assistant_tools/src/project_notifications_tool.rs | 41 ++++++
3 files changed, 117 insertions(+), 19 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -1516,7 +1516,7 @@ impl Thread {
     ) -> Option<PendingToolUse> {
         let action_log = self.action_log.read(cx);
 
-        action_log.stale_buffers(cx).next()?;
+        action_log.unnotified_stale_buffers(cx).next()?;
 
         // Represent notification as a simulated `project_notifications` tool call
         let tool_name = Arc::from("project_notifications");
@@ -3631,11 +3631,11 @@ fn main() {{
         });
 
         // We shouldn't have a stale buffer notification yet
-        let notification = thread.read_with(cx, |thread, _| {
-            find_tool_use(thread, "project_notifications")
+        let notifications = thread.read_with(cx, |thread, _| {
+            find_tool_uses(thread, "project_notifications")
         });
         assert!(
-            notification.is_none(),
+            notifications.is_empty(),
             "Should not have stale buffer notification before buffer is modified"
         );
 
@@ -3664,13 +3664,15 @@ fn main() {{
             thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
         });
 
-        let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
-            find_tool_use(thread, "project_notifications")
-        }) else {
+        let notifications = thread.read_with(cx, |thread, _cx| {
+            find_tool_uses(thread, "project_notifications")
+        });
+
+        let [notification] = notifications.as_slice() else {
             panic!("Should have a `project_notifications` tool use");
         };
 
-        let Some(notification_content) = notification_result.content.to_str() else {
+        let Some(notification_content) = notification.content.to_str() else {
             panic!("`project_notifications` should return text");
         };
 
@@ -3680,19 +3682,46 @@ fn main() {{
         - code.rs
         "};
         assert_eq!(notification_content, expected_content);
+
+        // Insert another user message and flush notifications again
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message(
+                "Can you tell me more?",
+                ContextLoadResult::default(),
+                None,
+                Vec::new(),
+                cx,
+            )
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
+        });
+
+        // There should be no new notifications (we already flushed one)
+        let notifications = thread.read_with(cx, |thread, _cx| {
+            find_tool_uses(thread, "project_notifications")
+        });
+
+        assert_eq!(
+            notifications.len(),
+            1,
+            "Should still have only one notification after second flush - no duplicates"
+        );
     }
 
-    fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
+    fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
         thread
             .messages()
-            .filter_map(|message| {
+            .flat_map(|message| {
                 thread
                     .tool_results_for_message(message.id)
                     .into_iter()
-                    .find(|result| result.tool_name == tool_name.into())
+                    .filter(|result| result.tool_name == tool_name.into())
+                    .cloned()
+                    .collect::<Vec<_>>()
             })
-            .next()
-            .cloned()
+            .collect()
     }
 
     #[gpui::test]

crates/assistant_tool/src/action_log.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::{Context as _, Result};
 use buffer_diff::BufferDiff;
+use clock;
 use collections::BTreeMap;
 use futures::{FutureExt, StreamExt, channel::mpsc};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
@@ -17,6 +18,8 @@ pub struct ActionLog {
     edited_since_project_diagnostics_check: bool,
     /// The project this action log is associated with
     project: Entity<Project>,
+    /// Tracks which buffer versions have already been notified as changed externally
+    notified_versions: BTreeMap<Entity<Buffer>, clock::Global>,
 }
 
 impl ActionLog {
@@ -26,6 +29,7 @@ impl ActionLog {
             tracked_buffers: BTreeMap::default(),
             edited_since_project_diagnostics_check: false,
             project,
+            notified_versions: BTreeMap::default(),
         }
     }
 
@@ -51,6 +55,7 @@ impl ActionLog {
     ) -> &mut TrackedBuffer {
         let status = if is_created {
             if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
+                self.notified_versions.remove(&buffer);
                 match tracked.status {
                     TrackedBufferStatus::Created {
                         existing_file_content,
@@ -106,7 +111,7 @@ impl ActionLog {
                 TrackedBuffer {
                     buffer: buffer.clone(),
                     diff_base,
-                    unreviewed_edits: unreviewed_edits,
+                    unreviewed_edits,
                     snapshot: text_snapshot.clone(),
                     status,
                     version: buffer.read(cx).version(),
@@ -165,6 +170,7 @@ impl ActionLog {
                     // If the buffer had been edited by a tool, but it got
                     // deleted externally, we want to stop tracking it.
                     self.tracked_buffers.remove(&buffer);
+                    self.notified_versions.remove(&buffer);
                 }
                 cx.notify();
             }
@@ -178,6 +184,7 @@ impl ActionLog {
                     // resurrected externally, we want to clear the edits we
                     // were tracking and reset the buffer's state.
                     self.tracked_buffers.remove(&buffer);
+                    self.notified_versions.remove(&buffer);
                     self.track_buffer_internal(buffer, false, cx);
                 }
                 cx.notify();
@@ -483,6 +490,7 @@ impl ActionLog {
         match tracked_buffer.status {
             TrackedBufferStatus::Created { .. } => {
                 self.tracked_buffers.remove(&buffer);
+                self.notified_versions.remove(&buffer);
                 cx.notify();
             }
             TrackedBufferStatus::Modified => {
@@ -508,6 +516,7 @@ impl ActionLog {
         match tracked_buffer.status {
             TrackedBufferStatus::Deleted => {
                 self.tracked_buffers.remove(&buffer);
+                self.notified_versions.remove(&buffer);
                 cx.notify();
             }
             _ => {
@@ -616,6 +625,7 @@ impl ActionLog {
                 };
 
                 self.tracked_buffers.remove(&buffer);
+                self.notified_versions.remove(&buffer);
                 cx.notify();
                 task
             }
@@ -629,6 +639,7 @@ impl ActionLog {
 
                 // Clear all tracked edits for this buffer and start over as if we just read it.
                 self.tracked_buffers.remove(&buffer);
+                self.notified_versions.remove(&buffer);
                 self.buffer_read(buffer.clone(), cx);
                 cx.notify();
                 save
@@ -713,6 +724,33 @@ impl ActionLog {
             .collect()
     }
 
+    /// Returns stale buffers that haven't been notified yet
+    pub fn unnotified_stale_buffers<'a>(
+        &'a self,
+        cx: &'a App,
+    ) -> impl Iterator<Item = &'a Entity<Buffer>> {
+        self.stale_buffers(cx).filter(|buffer| {
+            let buffer_entity = buffer.read(cx);
+            self.notified_versions
+                .get(buffer)
+                .map_or(true, |notified_version| {
+                    *notified_version != buffer_entity.version
+                })
+        })
+    }
+
+    /// Marks the given buffers as notified at their current versions
+    pub fn mark_buffers_as_notified(
+        &mut self,
+        buffers: impl IntoIterator<Item = Entity<Buffer>>,
+        cx: &App,
+    ) {
+        for buffer in buffers {
+            let version = buffer.read(cx).version.clone();
+            self.notified_versions.insert(buffer, version);
+        }
+    }
+
     /// Iterate over buffers changed since last read or edited by the model
     pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
         self.tracked_buffers

crates/assistant_tools/src/project_notifications_tool.rs 🔗

@@ -53,15 +53,21 @@ impl Tool for ProjectNotificationsTool {
         cx: &mut App,
     ) -> ToolResult {
         let mut stale_files = String::new();
+        let mut notified_buffers = Vec::new();
 
-        let action_log = action_log.read(cx);
-
-        for stale_file in action_log.stale_buffers(cx) {
+        for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) {
             if let Some(file) = stale_file.read(cx).file() {
                 writeln!(&mut stale_files, "- {}", file.path().display()).ok();
+                notified_buffers.push(stale_file.clone());
             }
         }
 
+        if !notified_buffers.is_empty() {
+            action_log.update(cx, |log, cx| {
+                log.mark_buffers_as_notified(notified_buffers, cx);
+            });
+        }
+
         let response = if stale_files.is_empty() {
             "No new notifications".to_string()
         } else {
@@ -155,11 +161,11 @@ mod tests {
 
         // Run the tool again
         let result = cx.update(|cx| {
-            tool.run(
+            tool.clone().run(
                 tool_input.clone(),
                 request.clone(),
                 project.clone(),
-                action_log,
+                action_log.clone(),
                 model.clone(),
                 None,
                 cx,
@@ -179,6 +185,31 @@ mod tests {
             expected_content,
             "Tool should return the stale buffer notification"
         );
+
+        // Run the tool once more without any changes - should get no new notifications
+        let result = cx.update(|cx| {
+            tool.run(
+                tool_input.clone(),
+                request.clone(),
+                project.clone(),
+                action_log,
+                model.clone(),
+                None,
+                cx,
+            )
+        });
+
+        let response = result.output.await.unwrap();
+        let response_text = match &response.content {
+            ToolResultContent::Text(text) => text.clone(),
+            _ => panic!("Expected text response"),
+        };
+
+        assert_eq!(
+            response_text.as_str(),
+            "No new notifications",
+            "Tool should return 'No new notifications' when running again without changes"
+        );
     }
 
     fn init_test(cx: &mut TestAppContext) {