Cargo.lock 🔗
@@ -745,6 +745,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"icons",
+ "indoc",
"language",
"language_model",
"log",
Oleksiy Syvokon created
This change improves user/agent collaborative editing.
When the user edits files that are used by the agent, the
`project_notification` tool now pushes *diffs* of the changes, not just
file names. This helps the agent to stay up to date without needing to
re-read files.
Release Notes:
- Improved user/agent collaborative editing: agent now receives diffs of
user edits
Cargo.lock | 1
crates/agent/src/thread.rs | 17
crates/assistant_tool/Cargo.toml | 1
crates/assistant_tool/src/action_log.rs | 282 ++++++++-
crates/assistant_tools/src/project_notifications_tool.rs | 50
5 files changed, 274 insertions(+), 77 deletions(-)
@@ -745,6 +745,7 @@ dependencies = [
"futures 0.3.31",
"gpui",
"icons",
+ "indoc",
"language",
"language_model",
"log",
@@ -1532,7 +1532,9 @@ impl Thread {
) -> Option<PendingToolUse> {
let action_log = self.action_log.read(cx);
- action_log.unnotified_stale_buffers(cx).next()?;
+ if !action_log.has_unnotified_user_edits() {
+ return None;
+ }
// Represent notification as a simulated `project_notifications` tool call
let tool_name = Arc::from("project_notifications");
@@ -3253,7 +3255,6 @@ mod tests {
use futures::stream::BoxStream;
use gpui::TestAppContext;
use http_client;
- use indoc::indoc;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
use language_model::{
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
@@ -3614,6 +3615,7 @@ fn main() {{
cx,
);
});
+ cx.run_until_parked();
// We shouldn't have a stale buffer notification yet
let notifications = thread.read_with(cx, |thread, _| {
@@ -3643,11 +3645,13 @@ fn main() {{
cx,
)
});
+ cx.run_until_parked();
// Check for the stale buffer warning
thread.update(cx, |thread, cx| {
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
});
+ cx.run_until_parked();
let notifications = thread.read_with(cx, |thread, _cx| {
find_tool_uses(thread, "project_notifications")
@@ -3661,12 +3665,8 @@ fn main() {{
panic!("`project_notifications` should return text");
};
- let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
-
- These files have changed since the last read:
- - code.rs
- "};
- assert_eq!(notification_content, expected_content);
+ assert!(notification_content.contains("These files have changed since the last read:"));
+ assert!(notification_content.contains("code.rs"));
// Insert another user message and flush notifications again
thread.update(cx, |thread, cx| {
@@ -3682,6 +3682,7 @@ fn main() {{
thread.update(cx, |thread, cx| {
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
});
+ cx.run_until_parked();
// There should be no new notifications (we already flushed one)
let notifications = thread.read_with(cx, |thread, _cx| {
@@ -40,6 +40,7 @@ collections = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
ctor.workspace = true
gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
log.workspace = true
@@ -8,7 +8,10 @@ use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
-use util::{RangeExt, ResultExt as _};
+use util::{
+ RangeExt, ResultExt as _,
+ paths::{PathStyle, RemotePathBuf},
+};
/// Tracks actions performed by tools in a thread
pub struct ActionLog {
@@ -18,8 +21,6 @@ pub struct ActionLog {
edited_since_project_diagnostics_check: bool,
/// The project this action log is associated with
project: Entity<Project>,
- /// Tracks which buffer versions have already been notified as changed externally
- notified_versions: BTreeMap<Entity<Buffer>, clock::Global>,
}
impl ActionLog {
@@ -29,7 +30,6 @@ impl ActionLog {
tracked_buffers: BTreeMap::default(),
edited_since_project_diagnostics_check: false,
project,
- notified_versions: BTreeMap::default(),
}
}
@@ -51,6 +51,67 @@ impl ActionLog {
Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
}
+ pub fn has_unnotified_user_edits(&self) -> bool {
+ self.tracked_buffers
+ .values()
+ .any(|tracked| tracked.has_unnotified_user_edits)
+ }
+
+ /// Return a unified diff patch with user edits made since last read or notification
+ pub fn unnotified_user_edits(&self, cx: &Context<Self>) -> Option<String> {
+ if !self.has_unnotified_user_edits() {
+ return None;
+ }
+
+ let unified_diff = self
+ .tracked_buffers
+ .values()
+ .filter_map(|tracked| {
+ if !tracked.has_unnotified_user_edits {
+ return None;
+ }
+
+ let text_with_latest_user_edits = tracked.diff_base.to_string();
+ let text_with_last_seen_user_edits = tracked.last_seen_base.to_string();
+ if text_with_latest_user_edits == text_with_last_seen_user_edits {
+ return None;
+ }
+ let patch = language::unified_diff(
+ &text_with_last_seen_user_edits,
+ &text_with_latest_user_edits,
+ );
+
+ let buffer = tracked.buffer.clone();
+ let file_path = buffer
+ .read(cx)
+ .file()
+ .map(|file| RemotePathBuf::new(file.full_path(cx), PathStyle::Posix).to_proto())
+ .unwrap_or_else(|| format!("buffer_{}", buffer.entity_id()));
+
+ let mut result = String::new();
+ result.push_str(&format!("--- a/{}\n", file_path));
+ result.push_str(&format!("+++ b/{}\n", file_path));
+ result.push_str(&patch);
+
+ Some(result)
+ })
+ .collect::<Vec<_>>()
+ .join("\n\n");
+
+ Some(unified_diff)
+ }
+
+ /// Return a unified diff patch with user edits made since last read/notification
+ /// and mark them as notified
+ pub fn flush_unnotified_user_edits(&mut self, cx: &Context<Self>) -> Option<String> {
+ let patch = self.unnotified_user_edits(cx);
+ self.tracked_buffers.values_mut().for_each(|tracked| {
+ tracked.has_unnotified_user_edits = false;
+ tracked.last_seen_base = tracked.diff_base.clone();
+ });
+ patch
+ }
+
fn track_buffer_internal(
&mut self,
buffer: Entity<Buffer>,
@@ -59,7 +120,6 @@ impl ActionLog {
) -> &mut TrackedBuffer {
let status = if is_created {
if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
- self.notified_versions.remove(&buffer);
match tracked.status {
TrackedBufferStatus::Created {
existing_file_content,
@@ -101,26 +161,31 @@ impl ActionLog {
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
let diff_base;
+ let last_seen_base;
let unreviewed_edits;
if is_created {
diff_base = Rope::default();
+ last_seen_base = Rope::default();
unreviewed_edits = Patch::new(vec![Edit {
old: 0..1,
new: 0..text_snapshot.max_point().row + 1,
}])
} else {
diff_base = buffer.read(cx).as_rope().clone();
+ last_seen_base = diff_base.clone();
unreviewed_edits = Patch::default();
}
TrackedBuffer {
buffer: buffer.clone(),
diff_base,
+ last_seen_base,
unreviewed_edits,
snapshot: text_snapshot.clone(),
status,
version: buffer.read(cx).version(),
diff,
diff_update: diff_update_tx,
+ has_unnotified_user_edits: false,
_open_lsp_handle: open_lsp_handle,
_maintain_diff: cx.spawn({
let buffer = buffer.clone();
@@ -174,7 +239,6 @@ impl ActionLog {
// If the buffer had been edited by a tool, but it got
// deleted externally, we want to stop tracking it.
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
}
cx.notify();
}
@@ -188,7 +252,6 @@ impl ActionLog {
// resurrected externally, we want to clear the edits we
// were tracking and reset the buffer's state.
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
self.track_buffer_internal(buffer, false, cx);
}
cx.notify();
@@ -262,19 +325,23 @@ impl ActionLog {
buffer_snapshot: text::BufferSnapshot,
cx: &mut AsyncApp,
) -> Result<()> {
- let rebase = this.read_with(cx, |this, cx| {
+ let rebase = this.update(cx, |this, cx| {
let tracked_buffer = this
.tracked_buffers
- .get(buffer)
+ .get_mut(buffer)
.context("buffer not tracked")?;
+ if let ChangeAuthor::User = author {
+ tracked_buffer.has_unnotified_user_edits = true;
+ }
+
let rebase = cx.background_spawn({
let mut base_text = tracked_buffer.diff_base.clone();
let old_snapshot = tracked_buffer.snapshot.clone();
let new_snapshot = buffer_snapshot.clone();
let unreviewed_edits = tracked_buffer.unreviewed_edits.clone();
+ let edits = diff_snapshots(&old_snapshot, &new_snapshot);
async move {
- let edits = diff_snapshots(&old_snapshot, &new_snapshot);
if let ChangeAuthor::User = author {
apply_non_conflicting_edits(
&unreviewed_edits,
@@ -494,7 +561,6 @@ impl ActionLog {
match tracked_buffer.status {
TrackedBufferStatus::Created { .. } => {
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
cx.notify();
}
TrackedBufferStatus::Modified => {
@@ -520,7 +586,6 @@ impl ActionLog {
match tracked_buffer.status {
TrackedBufferStatus::Deleted => {
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
cx.notify();
}
_ => {
@@ -629,7 +694,6 @@ impl ActionLog {
};
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
cx.notify();
task
}
@@ -643,7 +707,6 @@ impl ActionLog {
// Clear all tracked edits for this buffer and start over as if we just read it.
self.tracked_buffers.remove(&buffer);
- self.notified_versions.remove(&buffer);
self.buffer_read(buffer.clone(), cx);
cx.notify();
save
@@ -744,33 +807,6 @@ impl ActionLog {
.collect()
}
- /// Returns stale buffers that haven't been notified yet
- pub fn unnotified_stale_buffers<'a>(
- &'a self,
- cx: &'a App,
- ) -> impl Iterator<Item = &'a Entity<Buffer>> {
- self.stale_buffers(cx).filter(|buffer| {
- let buffer_entity = buffer.read(cx);
- self.notified_versions
- .get(buffer)
- .map_or(true, |notified_version| {
- *notified_version != buffer_entity.version
- })
- })
- }
-
- /// Marks the given buffers as notified at their current versions
- pub fn mark_buffers_as_notified(
- &mut self,
- buffers: impl IntoIterator<Item = Entity<Buffer>>,
- cx: &App,
- ) {
- for buffer in buffers {
- let version = buffer.read(cx).version.clone();
- self.notified_versions.insert(buffer, version);
- }
- }
-
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
@@ -914,12 +950,14 @@ enum TrackedBufferStatus {
struct TrackedBuffer {
buffer: Entity<Buffer>,
diff_base: Rope,
+ last_seen_base: Rope,
unreviewed_edits: Patch<u32>,
status: TrackedBufferStatus,
version: clock::Global,
diff: Entity<BufferDiff>,
snapshot: text::BufferSnapshot,
diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>,
+ has_unnotified_user_edits: bool,
_open_lsp_handle: OpenLspBufferHandle,
_maintain_diff: Task<()>,
_subscription: Subscription,
@@ -950,6 +988,7 @@ mod tests {
use super::*;
use buffer_diff::DiffHunkStatusKind;
use gpui::TestAppContext;
+ use indoc::indoc;
use language::Point;
use project::{FakeFs, Fs, Project, RemoveOptions};
use rand::prelude::*;
@@ -1232,6 +1271,110 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
+ #[gpui::test(iterations = 10)]
+ async fn test_user_edits_notifications(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/dir"),
+ json!({"file": indoc! {"
+ abc
+ def
+ ghi
+ jkl
+ mno"}}),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ // Agent edits
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ buffer.update(cx, |buffer, cx| {
+ buffer
+ .edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
+ .unwrap()
+ });
+ action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+ cx.run_until_parked();
+ assert_eq!(
+ buffer.read_with(cx, |buffer, _| buffer.text()),
+ indoc! {"
+ abc
+ deF
+ GHI
+ jkl
+ mno"}
+ );
+ assert_eq!(
+ unreviewed_hunks(&action_log, cx),
+ vec![(
+ buffer.clone(),
+ vec![HunkStatus {
+ range: Point::new(1, 0)..Point::new(3, 0),
+ diff_status: DiffHunkStatusKind::Modified,
+ old_text: "def\nghi\n".into(),
+ }],
+ )]
+ );
+
+ // User edits
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(
+ [
+ (Point::new(0, 2)..Point::new(0, 2), "X"),
+ (Point::new(3, 0)..Point::new(3, 0), "Y"),
+ ],
+ None,
+ cx,
+ )
+ });
+ cx.run_until_parked();
+ assert_eq!(
+ buffer.read_with(cx, |buffer, _| buffer.text()),
+ indoc! {"
+ abXc
+ deF
+ GHI
+ Yjkl
+ mno"}
+ );
+
+ // User edits should be stored separately from agent's
+ let user_edits = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx));
+ assert_eq!(
+ user_edits.expect("should have some user edits"),
+ indoc! {"
+ --- a/dir/file
+ +++ b/dir/file
+ @@ -1,5 +1,5 @@
+ -abc
+ +abXc
+ def
+ ghi
+ -jkl
+ +Yjkl
+ mno
+ "}
+ );
+
+ action_log.update(cx, |log, cx| {
+ log.keep_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
+ });
+ cx.run_until_parked();
+ assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
+ }
+
#[gpui::test(iterations = 10)]
async fn test_creating_files(cx: &mut TestAppContext) {
init_test(cx);
@@ -2221,4 +2364,61 @@ mod tests {
.collect()
})
}
+
+ #[gpui::test]
+ async fn test_format_patch(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/dir"),
+ json!({"test.txt": "line 1\nline 2\nline 3\n"}),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| {
+ project.find_project_path("dir/test.txt", cx)
+ })
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ cx.update(|cx| {
+ // Track the buffer and mark it as read first
+ action_log.update(cx, |log, cx| {
+ log.buffer_read(buffer.clone(), cx);
+ });
+
+ // Make some edits to create a patch
+ buffer.update(cx, |buffer, cx| {
+ buffer
+ .edit([(Point::new(1, 0)..Point::new(1, 6), "CHANGED")], None, cx)
+ .unwrap(); // Replace "line2" with "CHANGED"
+ });
+ });
+
+ cx.run_until_parked();
+
+ // Get the patch
+ let patch = action_log.update(cx, |log, cx| log.unnotified_user_edits(cx));
+
+ // Verify the patch format contains expected unified diff elements
+ assert_eq!(
+ patch.unwrap(),
+ indoc! {"
+ --- a/dir/test.txt
+ +++ b/dir/test.txt
+ @@ -1,3 +1,3 @@
+ line 1
+ -line 2
+ +CHANGED
+ line 3
+ "}
+ );
+ }
}
@@ -6,7 +6,6 @@ use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchem
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::fmt::Write as _;
use std::sync::Arc;
use ui::IconName;
@@ -52,34 +51,22 @@ impl Tool for ProjectNotificationsTool {
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
- let mut stale_files = String::new();
- let mut notified_buffers = Vec::new();
-
- for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) {
- if let Some(file) = stale_file.read(cx).file() {
- writeln!(&mut stale_files, "- {}", file.path().display()).ok();
- notified_buffers.push(stale_file.clone());
- }
- }
-
- if !notified_buffers.is_empty() {
- action_log.update(cx, |log, cx| {
- log.mark_buffers_as_notified(notified_buffers, cx);
- });
- }
-
- let response = if stale_files.is_empty() {
- "No new notifications".to_string()
- } else {
- // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
- const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
- format!("{HEADER}{stale_files}").replace("\r\n", "\n")
+ let Some(user_edits_diff) =
+ action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx))
+ else {
+ return result("No new notifications");
};
- Task::ready(Ok(response.into())).into()
+ // NOTE: Changes to this prompt require a symmetric update in the LLM Worker
+ const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
+ result(&format!("{HEADER}\n\n```diff\n{user_edits_diff}\n```\n").replace("\r\n", "\n"))
}
}
+fn result(response: &str) -> ToolResult {
+ Task::ready(Ok(response.to_string().into())).into()
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -123,6 +110,7 @@ mod tests {
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx);
});
+ cx.run_until_parked();
// Run the tool before any changes
let tool = Arc::new(ProjectNotificationsTool);
@@ -142,6 +130,7 @@ mod tests {
cx,
)
});
+ cx.run_until_parked();
let response = result.output.await.unwrap();
let response_text = match &response.content {
@@ -158,6 +147,7 @@ mod tests {
buffer.update(cx, |buffer, cx| {
buffer.edit([(1..1, "\nChange!\n")], None, cx);
});
+ cx.run_until_parked();
// Run the tool again
let result = cx.update(|cx| {
@@ -171,6 +161,7 @@ mod tests {
cx,
)
});
+ cx.run_until_parked();
// This time the buffer is stale, so the tool should return a notification
let response = result.output.await.unwrap();
@@ -179,10 +170,12 @@ mod tests {
_ => panic!("Expected text response"),
};
- let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
- assert_eq!(
- response_text.as_str(),
- expected_content,
+ assert!(
+ response_text.contains("These files have changed"),
+ "Tool should return the stale buffer notification"
+ );
+ assert!(
+ response_text.contains("test/code.rs"),
"Tool should return the stale buffer notification"
);
@@ -198,6 +191,7 @@ mod tests {
cx,
)
});
+ cx.run_until_parked();
let response = result.output.await.unwrap();
let response_text = match &response.content {