Randomize test `FilterCursor::prev`

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                      |  3 ++
crates/sum_tree/Cargo.toml      |  3 ++
crates/sum_tree/src/cursor.rs   | 21 ++++++++-------
crates/sum_tree/src/sum_tree.rs | 45 ++++++++++++++++++++++++++++++----
4 files changed, 56 insertions(+), 16 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4929,6 +4929,9 @@ name = "sum_tree"
 version = "0.1.0"
 dependencies = [
  "arrayvec 0.7.1",
+ "ctor",
+ "env_logger",
+ "log",
  "rand 0.8.3",
 ]
 

crates/sum_tree/Cargo.toml 🔗

@@ -9,6 +9,9 @@ doctest = false
 
 [dependencies]
 arrayvec = "0.7.1"
+log = "0.4"
 
 [dev-dependencies]
+ctor = "0.1"
+env_logger = "0.8"
 rand = "0.8.3"

crates/sum_tree/src/cursor.rs 🔗

@@ -197,10 +197,6 @@ where
                 }
             }
         }
-
-        if self.stack.is_empty() {
-            self.position = D::default();
-        }
     }
 
     pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
@@ -235,8 +231,8 @@ where
                         ..
                     } => {
                         if !descend {
-                            entry.position = self.position.clone();
                             entry.index += 1;
+                            entry.position = self.position.clone();
                         }
 
                         while entry.index < child_summaries.len() {
@@ -244,9 +240,10 @@ where
                             if filter_node(next_summary) {
                                 break;
                             } else {
+                                entry.index += 1;
+                                entry.position.add_summary(next_summary, cx);
                                 self.position.add_summary(next_summary, cx);
                             }
-                            entry.index += 1;
                         }
 
                         child_trees.get(entry.index)
@@ -254,9 +251,9 @@ where
                     Node::Leaf { item_summaries, .. } => {
                         if !descend {
                             let item_summary = &item_summaries[entry.index];
-                            self.position.add_summary(item_summary, cx);
-                            entry.position.add_summary(item_summary, cx);
                             entry.index += 1;
+                            entry.position.add_summary(item_summary, cx);
+                            self.position.add_summary(item_summary, cx);
                         }
 
                         loop {
@@ -264,9 +261,9 @@ where
                                 if filter_node(next_item_summary) {
                                     return;
                                 } else {
-                                    self.position.add_summary(next_item_summary, cx);
-                                    entry.position.add_summary(next_item_summary, cx);
                                     entry.index += 1;
+                                    entry.position.add_summary(next_item_summary, cx);
+                                    self.position.add_summary(next_item_summary, cx);
                                 }
                             } else {
                                 break None;
@@ -598,6 +595,10 @@ where
     pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
         self.cursor.next_internal(&mut self.filter_node, cx);
     }
+
+    pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) {
+        self.cursor.prev_internal(&mut self.filter_node, cx);
+    }
 }
 
 impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U>

crates/sum_tree/src/sum_tree.rs 🔗

@@ -678,6 +678,13 @@ mod tests {
     use rand::{distributions, prelude::*};
     use std::cmp;
 
+    #[ctor::ctor]
+    fn init_logger() {
+        if std::env::var("RUST_LOG").is_ok() {
+            env_logger::init();
+        }
+    }
+
     #[test]
     fn test_extend_and_push_tree() {
         let mut tree1 = SumTree::new();
@@ -703,8 +710,11 @@ mod tests {
         if let Ok(value) = std::env::var("ITERATIONS") {
             num_iterations = value.parse().expect("invalid ITERATIONS variable");
         }
+        let num_operations = std::env::var("OPERATIONS")
+            .map_or(5, |o| o.parse().expect("invalid OPERATIONS variable"));
 
         for seed in starting_seed..(starting_seed + num_iterations) {
+            dbg!(seed);
             let mut rng = StdRng::seed_from_u64(seed);
 
             let rng = &mut rng;
@@ -712,7 +722,7 @@ mod tests {
             let count = rng.gen_range(0..10);
             tree.extend(rng.sample_iter(distributions::Standard).take(count), &());
 
-            for _ in 0..5 {
+            for _ in 0..num_operations {
                 let splice_end = rng.gen_range(0..tree.extent::<Count>(&()).0 + 1);
                 let splice_start = rng.gen_range(0..splice_end + 1);
                 let count = rng.gen_range(0..3);
@@ -740,20 +750,43 @@ mod tests {
                     tree.cursor::<()>().collect::<Vec<_>>()
                 );
 
+                log::info!("tree items: {:?}", tree.items(&()));
+
                 let mut filter_cursor =
                     tree.filter::<_, Count>(|summary| summary.contains_even, &());
-                let mut reference_filter = tree
+                let expected_filtered_items = tree
                     .items(&())
                     .into_iter()
                     .enumerate()
-                    .filter(|(_, item)| (item & 1) == 0);
-                while let Some(actual_item) = filter_cursor.item() {
-                    let (reference_index, reference_item) = reference_filter.next().unwrap();
+                    .filter(|(_, item)| (item & 1) == 0)
+                    .collect::<Vec<_>>();
+
+                let mut item_ix = 0;
+                while item_ix < expected_filtered_items.len() {
+                    log::info!("filter_cursor, item_ix: {}", item_ix);
+                    let actual_item = filter_cursor.item().unwrap();
+                    let (reference_index, reference_item) =
+                        expected_filtered_items[item_ix].clone();
                     assert_eq!(actual_item, &reference_item);
                     assert_eq!(filter_cursor.start().0, reference_index);
+                    log::info!("next");
                     filter_cursor.next(&());
+                    item_ix += 1;
+
+                    while item_ix > 0 && rng.gen_bool(0.2) {
+                        log::info!("prev");
+                        filter_cursor.prev(&());
+                        item_ix -= 1;
+
+                        if item_ix == 0 && rng.gen_bool(0.2) {
+                            filter_cursor.prev(&());
+                            assert_eq!(filter_cursor.item(), None);
+                            assert_eq!(filter_cursor.start().0, 0);
+                            filter_cursor.next(&());
+                        }
+                    }
                 }
-                assert!(reference_filter.next().is_none());
+                assert_eq!(filter_cursor.item(), None);
 
                 let mut pos = rng.gen_range(0..tree.extent::<Count>(&()).0 + 1);
                 let mut before_start = false;