diff --git a/crates/sum_tree/src/cursor.rs b/crates/sum_tree/src/cursor.rs index 324276aeedd2307c1f8024d3243536462593fd3a..31f573a895696c3f6549787c35eb96c4e98fa9bd 100644 --- a/crates/sum_tree/src/cursor.rs +++ b/crates/sum_tree/src/cursor.rs @@ -205,13 +205,13 @@ where #[track_caller] pub fn prev(&mut self) { - self.search_backward(|_| true) + self.search_backward(|_| Ordering::Greater) } #[track_caller] pub fn search_backward(&mut self, mut filter_node: F) where - F: FnMut(&T::Summary) -> bool, + F: FnMut(&T::Summary) -> Ordering, { if !self.did_seek { self.did_seek = true; @@ -253,13 +253,14 @@ where } } - for summary in &entry.tree.0.child_summaries()[..entry.index] { - self.position.add_summary(summary, self.cx); + if entry.index != 0 { + self.position + .add_summary(&entry.tree.0.child_summaries()[entry.index - 1], self.cx); } entry.position = self.position.clone(); - descending = filter_node(&entry.tree.0.child_summaries()[entry.index]); + descending = filter_node(&entry.tree.0.child_summaries()[entry.index]).is_ge(); match entry.tree.0.as_ref() { Node::Internal { child_trees, .. } => { if descending { @@ -282,13 +283,13 @@ where #[track_caller] pub fn next(&mut self) { - self.search_forward(|_| true) + self.search_forward(|_| Ordering::Less) } #[track_caller] pub fn search_forward(&mut self, mut filter_node: F) where - F: FnMut(&T::Summary) -> bool, + F: FnMut(&T::Summary) -> Ordering, { let mut descend = false; @@ -318,14 +319,17 @@ where entry.position = self.position.clone(); } - while entry.index < child_summaries.len() { - let next_summary = &child_summaries[entry.index]; - if filter_node(next_summary) { - break; - } else { - entry.index += 1; - entry.position.add_summary(next_summary, self.cx); - self.position.add_summary(next_summary, self.cx); + if entry.index < child_summaries.len() { + let index = child_summaries[entry.index..] + .partition_point(|item| filter_node(item).is_lt()); + entry.index += index; + let position = Some(entry.index) + .filter(|index| *index < child_summaries.len()) + .unwrap_or(child_summaries.len()); + + if let Some(summary) = child_summaries.get(position) { + entry.position.add_summary(summary, self.cx); + self.position.add_summary(summary, self.cx); } } @@ -340,13 +344,17 @@ where } loop { - if let Some(next_item_summary) = item_summaries.get(entry.index) { - if filter_node(next_item_summary) { - return; - } else { - entry.index += 1; - entry.position.add_summary(next_item_summary, self.cx); - self.position.add_summary(next_item_summary, self.cx); + if entry.index < item_summaries.len() { + let index = item_summaries[entry.index..] + .partition_point(|item| filter_node(item).is_lt()); + entry.index += index; + let position = Some(entry.index) + .filter(|index| *index < item_summaries.len()) + .unwrap_or(item_summaries.len()); + + if let Some(summary) = item_summaries.get(position) { + entry.position.add_summary(summary, self.cx); + self.position.add_summary(summary, self.cx); } } else { break None; @@ -638,7 +646,7 @@ pub struct FilterCursor<'a, F, T: Item, D> { impl<'a, F, T: Item, D> FilterCursor<'a, F, T, D> where - F: FnMut(&T::Summary) -> bool, + F: FnMut(&T::Summary) -> Ordering, T: Item, D: Dimension<'a, T::Summary>, { @@ -681,7 +689,7 @@ where impl<'a, F, T: Item, U> Iterator for FilterCursor<'a, F, T, U> where - F: FnMut(&T::Summary) -> bool, + F: FnMut(&T::Summary) -> Ordering, U: Dimension<'a, T::Summary>, { type Item = &'a T; diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 184148fa1b729102bd0b5b4e24885ee3ebb26103..565cfceaf26c553a0940e632d89014be55b05498 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -375,7 +375,7 @@ impl SumTree { filter_node: F, ) -> FilterCursor<'a, F, T, U> where - F: FnMut(&T::Summary) -> bool, + F: FnMut(&T::Summary) -> Ordering, U: Dimension<'a, T::Summary>, { FilterCursor::new(self, cx, filter_node) @@ -1026,7 +1026,7 @@ mod tests { log::info!("tree items: {:?}", tree.items(&())); let mut filter_cursor = - tree.filter::<_, Count>(&(), |summary| summary.contains_even); + tree.filter::<_, Count>(&(), |summary| summary.contains_even.cmp(&false)); let expected_filtered_items = tree .items(&()) .into_iter() diff --git a/crates/sum_tree/src/tree_map.rs b/crates/sum_tree/src/tree_map.rs index 54e8ae8343f4778e04a37a7ebd3dbe2b6da587cd..682f1e8f65b3f9cd6bec0c2d190264637d0e0234 100644 --- a/crates/sum_tree/src/tree_map.rs +++ b/crates/sum_tree/src/tree_map.rs @@ -412,7 +412,7 @@ mod tests { .take_while(|(key, _)| key.starts_with("ba")) .collect::>(); - assert_eq!(result.len(), 2); + assert_eq!(result.len(), 2, "{result:?}"); assert!(result.iter().any(|(k, _)| k == &&"baa")); assert!(result.iter().any(|(k, _)| k == &&"baaab"));