Make summaries a running sum

Piotr Osiewicz created

Change summary

crates/sum_tree/src/cursor.rs   | 56 ++++++++++++++++++++---------------
crates/sum_tree/src/sum_tree.rs |  4 +-
crates/sum_tree/src/tree_map.rs |  2 
3 files changed, 35 insertions(+), 27 deletions(-)

Detailed changes

crates/sum_tree/src/cursor.rs 🔗

@@ -205,13 +205,13 @@ where
 
     #[track_caller]
     pub fn prev(&mut self) {
-        self.search_backward(|_| true)
+        self.search_backward(|_| Ordering::Greater)
     }
 
     #[track_caller]
     pub fn search_backward<F>(&mut self, mut filter_node: F)
     where
-        F: FnMut(&T::Summary) -> bool,
+        F: FnMut(&T::Summary) -> Ordering,
     {
         if !self.did_seek {
             self.did_seek = true;
@@ -253,13 +253,14 @@ where
                 }
             }
 
-            for summary in &entry.tree.0.child_summaries()[..entry.index] {
-                self.position.add_summary(summary, self.cx);
+            if entry.index != 0 {
+                self.position
+                    .add_summary(&entry.tree.0.child_summaries()[entry.index - 1], self.cx);
             }
 
             entry.position = self.position.clone();
 
-            descending = filter_node(&entry.tree.0.child_summaries()[entry.index]);
+            descending = filter_node(&entry.tree.0.child_summaries()[entry.index]).is_ge();
             match entry.tree.0.as_ref() {
                 Node::Internal { child_trees, .. } => {
                     if descending {
@@ -282,13 +283,13 @@ where
 
     #[track_caller]
     pub fn next(&mut self) {
-        self.search_forward(|_| true)
+        self.search_forward(|_| Ordering::Less)
     }
 
     #[track_caller]
     pub fn search_forward<F>(&mut self, mut filter_node: F)
     where
-        F: FnMut(&T::Summary) -> bool,
+        F: FnMut(&T::Summary) -> Ordering,
     {
         let mut descend = false;
 
@@ -318,14 +319,17 @@ where
                             entry.position = self.position.clone();
                         }
 
-                        while entry.index < child_summaries.len() {
-                            let next_summary = &child_summaries[entry.index];
-                            if filter_node(next_summary) {
-                                break;
-                            } else {
-                                entry.index += 1;
-                                entry.position.add_summary(next_summary, self.cx);
-                                self.position.add_summary(next_summary, self.cx);
+                        if entry.index < child_summaries.len() {
+                            let index = child_summaries[entry.index..]
+                                .partition_point(|item| filter_node(item).is_lt());
+                            entry.index += index;
+                            let position = Some(entry.index)
+                                .filter(|index| *index < child_summaries.len())
+                                .unwrap_or(child_summaries.len());
+
+                            if let Some(summary) = child_summaries.get(position) {
+                                entry.position.add_summary(summary, self.cx);
+                                self.position.add_summary(summary, self.cx);
                             }
                         }
 
@@ -340,13 +344,17 @@ where
                         }
 
                         loop {
-                            if let Some(next_item_summary) = item_summaries.get(entry.index) {
-                                if filter_node(next_item_summary) {
-                                    return;
-                                } else {
-                                    entry.index += 1;
-                                    entry.position.add_summary(next_item_summary, self.cx);
-                                    self.position.add_summary(next_item_summary, self.cx);
+                            if entry.index < item_summaries.len() {
+                                let index = item_summaries[entry.index..]
+                                    .partition_point(|item| filter_node(item).is_lt());
+                                entry.index += index;
+                                let position = Some(entry.index)
+                                    .filter(|index| *index < item_summaries.len())
+                                    .unwrap_or(item_summaries.len());
+
+                                if let Some(summary) = item_summaries.get(position) {
+                                    entry.position.add_summary(summary, self.cx);
+                                    self.position.add_summary(summary, self.cx);
                                 }
                             } else {
                                 break None;
@@ -638,7 +646,7 @@ pub struct FilterCursor<'a, F, T: Item, D> {
 
 impl<'a, F, T: Item, D> FilterCursor<'a, F, T, D>
 where
-    F: FnMut(&T::Summary) -> bool,
+    F: FnMut(&T::Summary) -> Ordering,
     T: Item,
     D: Dimension<'a, T::Summary>,
 {
@@ -681,7 +689,7 @@ where
 
 impl<'a, F, T: Item, U> Iterator for FilterCursor<'a, F, T, U>
 where
-    F: FnMut(&T::Summary) -> bool,
+    F: FnMut(&T::Summary) -> Ordering,
     U: Dimension<'a, T::Summary>,
 {
     type Item = &'a T;

crates/sum_tree/src/sum_tree.rs 🔗

@@ -375,7 +375,7 @@ impl<T: Item> SumTree<T> {
         filter_node: F,
     ) -> FilterCursor<'a, F, T, U>
     where
-        F: FnMut(&T::Summary) -> bool,
+        F: FnMut(&T::Summary) -> Ordering,
         U: Dimension<'a, T::Summary>,
     {
         FilterCursor::new(self, cx, filter_node)
@@ -1026,7 +1026,7 @@ mod tests {
                 log::info!("tree items: {:?}", tree.items(&()));
 
                 let mut filter_cursor =
-                    tree.filter::<_, Count>(&(), |summary| summary.contains_even);
+                    tree.filter::<_, Count>(&(), |summary| summary.contains_even.cmp(&false));
                 let expected_filtered_items = tree
                     .items(&())
                     .into_iter()

crates/sum_tree/src/tree_map.rs 🔗

@@ -412,7 +412,7 @@ mod tests {
             .take_while(|(key, _)| key.starts_with("ba"))
             .collect::<Vec<_>>();
 
-        assert_eq!(result.len(), 2);
+        assert_eq!(result.len(), 2, "{result:?}");
         assert!(result.iter().any(|(k, _)| k == &&"baa"));
         assert!(result.iter().any(|(k, _)| k == &&"baaab"));