Compute minimal `version_in_range` on edit and account for undo

Antonio Scandurra created

Change summary

zed/src/editor/buffer/mod.rs | 49 ++++++++++++++++++++-----------------
zed/src/sum_tree/cursor.rs   |  2 
2 files changed, 27 insertions(+), 24 deletions(-)

Detailed changes

zed/src/editor/buffer/mod.rs 🔗

@@ -56,7 +56,7 @@ type HashMap<K, V> = std::collections::HashMap<K, V>;
 #[cfg(not(test))]
 type HashSet<T> = std::collections::HashSet<T>;
 
-#[derive(Clone, Default)]
+#[derive(Clone, Default, Debug)]
 struct UndoMap(HashMap<time::Local, Vec<UndoOperation>>);
 
 impl UndoMap {
@@ -177,6 +177,7 @@ struct Fragment {
     insertion: Insertion,
     text: Text,
     deletions: HashSet<time::Local>,
+    undos: HashSet<time::Local>,
     visible: bool,
 }
 
@@ -272,6 +273,7 @@ impl Buffer {
             insertion: base_insertion.clone(),
             text: base_insertion.text.slice(0..0),
             deletions: HashSet::default(),
+            undos: HashSet::default(),
             visible: true,
         });
 
@@ -291,6 +293,7 @@ impl Buffer {
                 text: base_insertion.text.clone(),
                 insertion: base_insertion,
                 deletions: HashSet::default(),
+                undos: HashSet::default(),
                 visible: true,
             });
         }
@@ -787,9 +790,11 @@ impl Buffer {
                 undo,
                 lamport_timestamp,
             } => {
-                self.apply_undo(undo)?;
-                self.version.observe(undo.id);
-                self.lamport_clock.observe(lamport_timestamp);
+                if !self.version.observed(undo.id) {
+                    self.apply_undo(undo)?;
+                    self.version.observe(undo.id);
+                    self.lamport_clock.observe(lamport_timestamp);
+                }
             }
             Operation::UpdateSelections {
                 set_id,
@@ -877,7 +882,7 @@ impl Buffer {
                     new_fragments.push(fragment);
                 }
                 if let Some(mut fragment) = within_range {
-                    if version_in_range.observed(fragment.insertion.id) {
+                    if fragment.was_visible(&version_in_range, &self.undo_map) {
                         fragment.deletions.insert(local_timestamp);
                         fragment.visible = false;
                     }
@@ -897,7 +902,8 @@ impl Buffer {
                     ));
                 }
 
-                if fragment.id < end_fragment_id && version_in_range.observed(fragment.insertion.id)
+                if fragment.id < end_fragment_id
+                    && fragment.was_visible(&version_in_range, &self.undo_map)
                 {
                     fragment.deletions.insert(local_timestamp);
                     fragment.visible = false;
@@ -955,17 +961,6 @@ impl Buffer {
     }
 
     fn apply_undo(&mut self, undo: UndoOperation) -> Result<()> {
-        // let mut new_fragments = SumTree::new();
-
-        // self.undos.insert(undo);
-        // let edit = &self.edit_ops[&undo.edit_id];
-        // let start_fragment_id = self.resolve_fragment_id(edit.start_id, edit.start_offset)?;
-        // let end_fragment_id = self.resolve_fragment_id(edit.end_id, edit.end_offset)?;
-        // let mut cursor = self.fragments.cursor::<FragmentIdRef, ()>();
-
-        // for fragment in cursor {}
-        // self.fragments = new_fragments;
-
         let mut new_fragments;
 
         self.undo_map.insert(undo);
@@ -984,6 +979,7 @@ impl Buffer {
             loop {
                 let mut fragment = cursor.item().unwrap().clone();
                 fragment.visible = fragment.is_visible(&self.undo_map);
+                fragment.undos.insert(undo.id);
                 new_fragments.push(fragment);
                 cursor.next();
                 if let Some(split_id) = insertion_splits.next() {
@@ -1004,6 +1000,7 @@ impl Buffer {
                         || fragment.insertion.id == undo.edit_id
                     {
                         fragment.visible = fragment.is_visible(&self.undo_map);
+                        fragment.undos.insert(undo.id);
                     }
                     new_fragments.push(fragment);
                     cursor.next();
@@ -1110,6 +1107,7 @@ impl Buffer {
 
         while cur_range.is_some() && cursor.item().is_some() {
             let mut fragment = cursor.item().unwrap().clone();
+            let fragment_summary = cursor.item_summary().unwrap();
             let mut fragment_start = *cursor.start();
             let mut fragment_end = fragment_start + fragment.visible_len();
 
@@ -1169,6 +1167,7 @@ impl Buffer {
                         prefix.set_end_offset(prefix.start_offset() + (range.end - fragment_start));
                         prefix.id =
                             FragmentId::between(&new_fragments.last().unwrap().id, &fragment.id);
+                        version_in_range.observe_all(&fragment_summary.max_version);
                         if fragment.visible {
                             prefix.deletions.insert(local_timestamp);
                             prefix.visible = false;
@@ -1182,10 +1181,9 @@ impl Buffer {
                         fragment_start = range.end;
                         end_id = Some(fragment.insertion.id);
                         end_offset = Some(fragment.start_offset());
-                        version_in_range.observe(fragment.insertion.id);
                     }
                 } else {
-                    version_in_range.observe(fragment.insertion.id);
+                    version_in_range.observe_all(&fragment_summary.max_version);
                     if fragment.visible {
                         fragment.deletions.insert(local_timestamp);
                         fragment.visible = false;
@@ -1238,15 +1236,16 @@ impl Buffer {
             cursor.next();
             if let Some(range) = cur_range.clone() {
                 while let Some(fragment) = cursor.item() {
+                    let fragment_summary = cursor.item_summary().unwrap();
                     fragment_start = *cursor.start();
                     fragment_end = fragment_start + fragment.visible_len();
                     if range.start < fragment_start && range.end >= fragment_end {
                         let mut new_fragment = fragment.clone();
+                        version_in_range.observe_all(&fragment_summary.max_version);
                         if new_fragment.visible {
                             new_fragment.deletions.insert(local_timestamp);
                             new_fragment.visible = false;
                         }
-                        version_in_range.observe(new_fragment.insertion.id);
                         new_fragments.push(new_fragment);
                         cursor.next();
 
@@ -1927,6 +1926,7 @@ impl Fragment {
             text: insertion.text.clone(),
             insertion,
             deletions: HashSet::default(),
+            undos: HashSet::default(),
             visible: true,
         }
     }
@@ -1989,6 +1989,9 @@ impl sum_tree::Item for Fragment {
         for deletion in &self.deletions {
             max_version.observe(*deletion);
         }
+        for undo in &self.undos {
+            max_version.observe(*undo);
+        }
 
         if self.visible {
             FragmentSummary {
@@ -2845,9 +2848,9 @@ mod tests {
     fn test_random_concurrent_edits() {
         use crate::test::Network;
 
-        const PEERS: usize = 2;
+        const PEERS: usize = 5;
 
-        for seed in 0..1000 {
+        for seed in 0..100 {
             println!("{:?}", seed);
             let mut rng = &mut StdRng::seed_from_u64(seed);
 
@@ -2865,7 +2868,7 @@ mod tests {
                 network.add_peer(i as u16);
             }
 
-            let mut mutation_count = 3;
+            let mut mutation_count = 10;
             loop {
                 let replica_index = rng.gen_range(0..PEERS);
                 let replica_id = replica_ids[replica_index];

zed/src/sum_tree/cursor.rs 🔗

@@ -77,7 +77,7 @@ where
         }
     }
 
-    fn item_summary(&self) -> Option<&'a T::Summary> {
+    pub fn item_summary(&self) -> Option<&'a T::Summary> {
         assert!(self.did_seek, "Must seek before calling this method");
         if let Some(entry) = self.stack.last() {
             match *entry.tree.0 {