Start tracking diffs in `ScriptingSession` (#26463)

Antonio Scandurra created

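`ScriptingSession` now keeps a `BufferDiff` for every buffer a script modifies. The diff is not exposed through the session's public API yet, but we'll take care of that next.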

Release Notes:

- N/A

Change summary

Cargo.lock                                     |   2 
crates/scripting_tool/Cargo.toml               |   4 
crates/scripting_tool/src/scripting_session.rs | 160 +++++++++++++++----
crates/sum_tree/src/tree_map.rs                |  32 +++
crates/text/src/text.rs                        |  73 ++++++---
5 files changed, 207 insertions(+), 64 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -11915,6 +11915,8 @@ name = "scripting_tool"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "buffer_diff",
+ "clock",
  "collections",
  "futures 0.3.31",
  "gpui",

crates/scripting_tool/Cargo.toml 🔗

@@ -14,6 +14,8 @@ doctest = false
 
 [dependencies]
 anyhow.workspace = true
+buffer_diff.workspace = true
+clock.workspace = true
 collections.workspace = true
 futures.workspace = true
 gpui.workspace = true
@@ -30,6 +32,8 @@ settings.workspace = true
 util.workspace = true
 
 [dev-dependencies]
+buffer_diff = { workspace = true, features = ["test-support"] }
+clock = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 language = { workspace = true, features = ["test-support"] }

crates/scripting_tool/src/scripting_session.rs 🔗

@@ -1,5 +1,6 @@
 use anyhow::anyhow;
-use collections::HashSet;
+use buffer_diff::BufferDiff;
+use collections::{HashMap, HashSet};
 use futures::{
     channel::{mpsc, oneshot},
     pin_mut, SinkExt, StreamExt,
@@ -18,10 +19,15 @@ use util::{paths::PathMatcher, ResultExt};
 
 struct ForegroundFn(Box<dyn FnOnce(WeakEntity<ScriptingSession>, AsyncApp) + Send>);
 
+struct BufferChanges {
+    diff: Entity<BufferDiff>,
+    edit_ids: Vec<clock::Lamport>,
+}
+
 pub struct ScriptingSession {
     project: Entity<Project>,
     scripts: Vec<Script>,
-    changed_buffers: HashSet<Entity<Buffer>>,
+    changes_by_buffer: HashMap<Entity<Buffer>, BufferChanges>,
     foreground_fns_tx: mpsc::Sender<ForegroundFn>,
     _invoke_foreground_fns: Task<()>,
 }
@@ -32,7 +38,7 @@ impl ScriptingSession {
         ScriptingSession {
             project,
             scripts: Vec::new(),
-            changed_buffers: HashSet::default(),
+            changes_by_buffer: HashMap::default(),
             foreground_fns_tx,
             _invoke_foreground_fns: cx.spawn(|this, cx| async move {
                 while let Some(foreground_fn) = foreground_fns_rx.next().await {
@@ -43,7 +49,7 @@ impl ScriptingSession {
     }
 
     pub fn changed_buffers(&self) -> impl ExactSizeIterator<Item = &Entity<Buffer>> {
-        self.changed_buffers.iter()
+        self.changes_by_buffer.keys()
     }
 
     pub fn run_script(
@@ -188,9 +194,6 @@ impl ScriptingSession {
 
                 lua.load(SANDBOX_PREAMBLE).exec_async().await?;
 
-                // Drop Lua instance to decrement reference count.
-                drop(lua);
-
                 anyhow::Ok(())
             }
         });
@@ -384,8 +387,12 @@ impl ScriptingSession {
                             .update(&mut cx, |buffer, cx| buffer.diff(text, cx))?
                             .await;
 
-                        buffer.update(&mut cx, |buffer, cx| {
+                        let edit_ids = buffer.update(&mut cx, |buffer, cx| {
+                            buffer.finalize_last_transaction();
                             buffer.apply_diff(diff, cx);
+                            let transaction = buffer.finalize_last_transaction();
+                            transaction
+                                .map_or(Vec::new(), |transaction| transaction.edit_ids.clone())
                         })?;
 
                         session
@@ -400,10 +407,36 @@ impl ScriptingSession {
                             })?
                             .await?;
 
+                        let snapshot = buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
+
                         // If we saved successfully, mark buffer as changed
-                        session.update(&mut cx, |session, _cx| {
-                            session.changed_buffers.insert(buffer);
-                        })
+                        let buffer_without_changes =
+                            buffer.update(&mut cx, |buffer, cx| buffer.branch(cx))?;
+                        session
+                            .update(&mut cx, |session, cx| {
+                                let changed_buffer = session
+                                    .changes_by_buffer
+                                    .entry(buffer)
+                                    .or_insert_with(|| BufferChanges {
+                                        diff: cx.new(|cx| BufferDiff::new(&snapshot, cx)),
+                                        edit_ids: Vec::new(),
+                                    });
+                                changed_buffer.edit_ids.extend(edit_ids);
+                                let operations_to_undo = changed_buffer
+                                    .edit_ids
+                                    .iter()
+                                    .map(|edit_id| (*edit_id, u32::MAX))
+                                    .collect::<HashMap<_, _>>();
+                                buffer_without_changes.update(cx, |buffer, cx| {
+                                    buffer.undo_operations(operations_to_undo, cx);
+                                });
+                                changed_buffer.diff.update(cx, |diff, cx| {
+                                    diff.set_base_text(buffer_without_changes, snapshot.text, cx)
+                                })
+                            })?
+                            .await?;
+
+                        Ok(())
                     })
                 })
             }),
@@ -895,6 +928,7 @@ impl Script {
         }
     }
 }
+
 #[cfg(test)]
 mod tests {
     use gpui::TestAppContext;
@@ -954,9 +988,7 @@ mod tests {
         let test_session = TestSession::init(cx).await;
         let output = test_session.test_success(script, cx).await;
         assert_eq!(output, "Content:\tHello world!\n");
-
-        // Only read, should not be marked as changed
-        assert!(!test_session.was_marked_changed("file1.txt", cx));
+        assert_eq!(test_session.diff(cx), Vec::new());
     }
 
     #[gpui::test]
@@ -976,7 +1008,16 @@ mod tests {
         let test_session = TestSession::init(cx).await;
         let output = test_session.test_success(script, cx).await;
         assert_eq!(output, "Written content:\tThis is new content\n");
-        assert!(test_session.was_marked_changed("file1.txt", cx));
+        assert_eq!(
+            test_session.diff(cx),
+            vec![(
+                PathBuf::from("file1.txt"),
+                vec![(
+                    "Hello world!\n".to_string(),
+                    "This is new content".to_string()
+                )]
+            )]
+        );
     }
 
     #[gpui::test]
@@ -1004,7 +1045,16 @@ mod tests {
             output,
             "Full content:\tFirst line\nSecond line\nThird line\n"
         );
-        assert!(test_session.was_marked_changed("multiwrite.txt", cx));
+        assert_eq!(
+            test_session.diff(cx),
+            vec![(
+                PathBuf::from("multiwrite.txt"),
+                vec![(
+                    "".to_string(),
+                    "First line\nSecond line\nThird line".to_string()
+                )]
+            )]
+        );
     }
 
     #[gpui::test]
@@ -1033,29 +1083,33 @@ mod tests {
             output,
             "Final content:\tContent written by second handle\n\n"
         );
-        assert!(test_session.was_marked_changed("multi_open.txt", cx));
+        assert_eq!(
+            test_session.diff(cx),
+            vec![(
+                PathBuf::from("multi_open.txt"),
+                vec![(
+                    "".to_string(),
+                    "Content written by second handle\n".to_string()
+                )]
+            )]
+        );
     }
 
     #[gpui::test]
     async fn test_append_mode(cx: &mut TestAppContext) {
         let script = r#"
-            -- Test append mode
-            local file = io.open("append.txt", "w")
-            file:write("Initial content\n")
-            file:close()
-
             -- Append more content
-            file = io.open("append.txt", "a")
+            file = io.open("file1.txt", "a")
             file:write("Appended content\n")
             file:close()
 
             -- Add even more
-            file = io.open("append.txt", "a")
+            file = io.open("file1.txt", "a")
             file:write("More appended content")
             file:close()
 
             -- Read back to verify
-            local read_file = io.open("append.txt", "r")
+            local read_file = io.open("file1.txt", "r")
             local content = read_file:read("*a")
             print("Content after appends:", content)
             read_file:close()
@@ -1065,9 +1119,18 @@ mod tests {
         let output = test_session.test_success(script, cx).await;
         assert_eq!(
             output,
-            "Content after appends:\tInitial content\nAppended content\nMore appended content\n"
+            "Content after appends:\tHello world!\nAppended content\nMore appended content\n"
+        );
+        assert_eq!(
+            test_session.diff(cx),
+            vec![(
+                PathBuf::from("file1.txt"),
+                vec![(
+                    "".to_string(),
+                    "Appended content\nMore appended content".to_string()
+                )]
+            )]
         );
-        assert!(test_session.was_marked_changed("append.txt", cx));
     }
 
     #[gpui::test]
@@ -1117,7 +1180,13 @@ mod tests {
         assert!(output.contains("Line with newline length:\t7"));
         assert!(output.contains("Last char:\t10")); // LF
         assert!(output.contains("5 bytes:\tLine "));
-        assert!(test_session.was_marked_changed("multiline.txt", cx));
+        assert_eq!(
+            test_session.diff(cx),
+            vec![(
+                PathBuf::from("multiline.txt"),
+                vec![("".to_string(), "Line 1\nLine 2\nLine 3".to_string())]
+            )]
+        );
     }
 
     // helpers
@@ -1137,8 +1206,8 @@ mod tests {
             fs.insert_tree(
                 path!("/"),
                 json!({
-                    "file1.txt": "Hello world!",
-                    "file2.txt": "Goodbye moon!"
+                    "file1.txt": "Hello world!\n",
+                    "file2.txt": "Goodbye moon!\n"
                 }),
             )
             .await;
@@ -1164,17 +1233,30 @@ mod tests {
             })
         }
 
-        fn was_marked_changed(&self, path_str: &str, cx: &mut TestAppContext) -> bool {
+        fn diff(&self, cx: &mut TestAppContext) -> Vec<(PathBuf, Vec<(String, String)>)> {
             self.session.read_with(cx, |session, cx| {
-                let count_changed = session
-                    .changed_buffers
+                session
+                    .changes_by_buffer
                     .iter()
-                    .filter(|buffer| buffer.read(cx).file().unwrap().path().ends_with(path_str))
-                    .count();
-
-                assert!(count_changed < 2, "Multiple buffers matched for same path");
-
-                count_changed > 0
+                    .map(|(buffer, changes)| {
+                        let snapshot = buffer.read(cx).snapshot();
+                        let diff = changes.diff.read(cx);
+                        let hunks = diff.hunks(&snapshot, cx);
+                        let path = buffer.read(cx).file().unwrap().path().clone();
+                        let diffs = hunks
+                            .map(|hunk| {
+                                let old_text = diff
+                                    .base_text()
+                                    .text_for_range(hunk.diff_base_byte_range)
+                                    .collect::<String>();
+                                let new_text =
+                                    snapshot.text_for_range(hunk.range).collect::<String>();
+                                (old_text, new_text)
+                            })
+                            .collect();
+                        (path.to_path_buf(), diffs)
+                    })
+                    .collect()
             })
         }
 

crates/sum_tree/src/tree_map.rs 🔗

@@ -70,6 +70,14 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
         self.0.insert_or_replace(MapEntry { key, value }, &());
     }
 
+    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
+        let mut edits = Vec::new();
+        for (key, value) in iter {
+            edits.push(Edit::Insert(MapEntry { key, value }));
+        }
+        self.0.edit(edits, &());
+    }
+
     pub fn clear(&mut self) {
         self.0 = SumTree::default();
     }
@@ -109,7 +117,7 @@ impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
         cursor.item().map(|item| (&item.key, &item.value))
     }
 
-    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
+    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
         let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
         let from_key = MapKeyRef(Some(from));
         cursor.seek(&from_key, Bias::Left, &());
@@ -313,6 +321,10 @@ where
         self.0.insert(key, ());
     }
 
+    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
+        self.0.extend(iter.into_iter().map(|key| (key, ())));
+    }
+
     pub fn contains(&self, key: &K) -> bool {
         self.0.get(key).is_some()
     }
@@ -320,6 +332,10 @@ where
     pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
         self.0.iter().map(|(k, _)| k)
     }
+
+    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
+        self.0.iter_from(key).map(move |(k, _)| k)
+    }
 }
 
 #[cfg(test)]
@@ -422,6 +438,20 @@ mod tests {
         assert_eq!(map.get(&"d"), Some(&4));
     }
 
+    #[test]
+    fn test_extend() {
+        let mut map = TreeMap::default();
+        map.insert("a", 1);
+        map.insert("b", 2);
+        map.insert("c", 3);
+        map.extend([("a", 2), ("b", 2), ("d", 4)]);
+        assert_eq!(map.iter().count(), 4);
+        assert_eq!(map.get(&"a"), Some(&2));
+        assert_eq!(map.get(&"b"), Some(&2));
+        assert_eq!(map.get(&"c"), Some(&3));
+        assert_eq!(map.get(&"d"), Some(&4));
+    }
+
     #[test]
     fn test_remove_between_and_path_successor() {
         use std::path::{Path, PathBuf};

crates/text/src/text.rs 🔗

@@ -37,7 +37,7 @@ use std::{
 };
 pub use subscription::*;
 pub use sum_tree::Bias;
-use sum_tree::{FilterCursor, SumTree, TreeMap};
+use sum_tree::{FilterCursor, SumTree, TreeMap, TreeSet};
 use undo_map::UndoMap;
 
 #[cfg(any(test, feature = "test-support"))]
@@ -110,6 +110,7 @@ pub struct BufferSnapshot {
     undo_map: UndoMap,
     fragments: SumTree<Fragment>,
     insertions: SumTree<InsertionFragment>,
+    insertion_slices: TreeSet<InsertionSlice>,
     pub version: clock::Global,
 }
 
@@ -137,25 +138,50 @@ impl HistoryEntry {
 struct History {
     base_text: Rope,
     operations: TreeMap<clock::Lamport, Operation>,
-    insertion_slices: HashMap<clock::Lamport, Vec<InsertionSlice>>,
     undo_stack: Vec<HistoryEntry>,
     redo_stack: Vec<HistoryEntry>,
     transaction_depth: usize,
     group_interval: Duration,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Eq, PartialEq)]
 struct InsertionSlice {
+    edit_id: clock::Lamport,
     insertion_id: clock::Lamport,
     range: Range<usize>,
 }
 
+impl Ord for InsertionSlice {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.edit_id
+            .cmp(&other.edit_id)
+            .then_with(|| self.insertion_id.cmp(&other.insertion_id))
+            .then_with(|| self.range.start.cmp(&other.range.start))
+            .then_with(|| self.range.end.cmp(&other.range.end))
+    }
+}
+
+impl PartialOrd for InsertionSlice {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl InsertionSlice {
+    fn from_fragment(edit_id: clock::Lamport, fragment: &Fragment) -> Self {
+        Self {
+            edit_id,
+            insertion_id: fragment.timestamp,
+            range: fragment.insertion_offset..fragment.insertion_offset + fragment.len,
+        }
+    }
+}
+
 impl History {
     pub fn new(base_text: Rope) -> Self {
         Self {
             base_text,
             operations: Default::default(),
-            insertion_slices: Default::default(),
             undo_stack: Vec::new(),
             redo_stack: Vec::new(),
             transaction_depth: 0,
@@ -696,6 +722,7 @@ impl Buffer {
                 insertions,
                 version,
                 undo_map: Default::default(),
+                insertion_slices: Default::default(),
             },
             history,
             deferred_ops: OperationQueue::new(),
@@ -857,7 +884,7 @@ impl Buffer {
                     old: fragment_start..fragment_start,
                     new: new_start..new_start + new_text.len(),
                 });
-                insertion_slices.push(fragment.insertion_slice());
+                insertion_slices.push(InsertionSlice::from_fragment(timestamp, &fragment));
                 new_insertions.push(InsertionFragment::insert_new(&fragment));
                 new_ropes.push_str(new_text.as_ref());
                 new_fragments.push(fragment, &None);
@@ -886,7 +913,8 @@ impl Buffer {
                             old: fragment_start..intersection_end,
                             new: new_start..new_start,
                         });
-                        insertion_slices.push(intersection.insertion_slice());
+                        insertion_slices
+                            .push(InsertionSlice::from_fragment(timestamp, &intersection));
                     }
                     new_insertions.push(InsertionFragment::insert_new(&intersection));
                     new_ropes.push_fragment(&intersection, fragment.visible);
@@ -929,9 +957,7 @@ impl Buffer {
         self.snapshot.visible_text = visible_text;
         self.snapshot.deleted_text = deleted_text;
         self.subscriptions.publish_mut(&edits_patch);
-        self.history
-            .insertion_slices
-            .insert(timestamp, insertion_slices);
+        self.snapshot.insertion_slices.extend(insertion_slices);
         edit_op
     }
 
@@ -1107,7 +1133,7 @@ impl Buffer {
                     old: old_start..old_start,
                     new: new_start..new_start + new_text.len(),
                 });
-                insertion_slices.push(fragment.insertion_slice());
+                insertion_slices.push(InsertionSlice::from_fragment(timestamp, &fragment));
                 new_insertions.push(InsertionFragment::insert_new(&fragment));
                 new_ropes.push_str(new_text);
                 new_fragments.push(fragment, &None);
@@ -1129,7 +1155,7 @@ impl Buffer {
                         Locator::between(&new_fragments.summary().max_id, &intersection.id);
                     intersection.deletions.insert(timestamp);
                     intersection.visible = false;
-                    insertion_slices.push(intersection.insertion_slice());
+                    insertion_slices.push(InsertionSlice::from_fragment(timestamp, &intersection));
                 }
                 if intersection.len > 0 {
                     if fragment.visible && !intersection.visible {
@@ -1177,9 +1203,7 @@ impl Buffer {
         self.snapshot.visible_text = visible_text;
         self.snapshot.deleted_text = deleted_text;
         self.snapshot.insertions.edit(new_insertions, &());
-        self.history
-            .insertion_slices
-            .insert(timestamp, insertion_slices);
+        self.snapshot.insertion_slices.extend(insertion_slices);
         self.subscriptions.publish_mut(&edits_patch)
     }
 
@@ -1190,9 +1214,17 @@ impl Buffer {
         // Get all of the insertion slices changed by the given edits.
         let mut insertion_slices = Vec::new();
         for edit_id in edit_ids {
-            if let Some(slices) = self.history.insertion_slices.get(edit_id) {
-                insertion_slices.extend_from_slice(slices)
-            }
+            let insertion_slice = InsertionSlice {
+                edit_id: *edit_id,
+                insertion_id: clock::Lamport::default(),
+                range: 0..0,
+            };
+            let slices = self
+                .snapshot
+                .insertion_slices
+                .iter_from(&insertion_slice)
+                .take_while(|slice| slice.edit_id == *edit_id);
+            insertion_slices.extend(slices)
         }
         insertion_slices
             .sort_unstable_by_key(|s| (s.insertion_id, s.range.start, Reverse(s.range.end)));
@@ -2639,13 +2671,6 @@ impl<D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator for Ed
 }
 
 impl Fragment {
-    fn insertion_slice(&self) -> InsertionSlice {
-        InsertionSlice {
-            insertion_id: self.timestamp,
-            range: self.insertion_offset..self.insertion_offset + self.len,
-        }
-    }
-
     fn is_visible(&self, undos: &UndoMap) -> bool {
         !undos.is_undone(self.timestamp) && self.deletions.iter().all(|d| undos.is_undone(*d))
     }