Don't seek `FilterCursor` upon creation

Antonio Scandurra and Nathan Sobo created

This lets us use `next` or `prev` to decide whether to park the cursor
at the first or last filtered item.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/editor/src/display_map/fold_map.rs | 23 ++++++++---------
crates/language/src/diagnostic_set.rs     | 24 ++++++++---------
crates/sum_tree/src/cursor.rs             | 33 +++++++++++++++---------
crates/sum_tree/src/sum_tree.rs           | 19 +++++++------
crates/text/src/text.rs                   |  9 +++---
5 files changed, 58 insertions(+), 50 deletions(-)

Detailed changes

crates/editor/src/display_map/fold_map.rs 🔗

@@ -819,19 +819,18 @@ where
 {
     let start = buffer.anchor_before(range.start.to_offset(buffer));
     let end = buffer.anchor_after(range.end.to_offset(buffer));
-    folds.filter::<_, usize>(
-        move |summary| {
-            let start_cmp = start.cmp(&summary.max_end, buffer).unwrap();
-            let end_cmp = end.cmp(&summary.min_start, buffer).unwrap();
+    let mut cursor = folds.filter::<_, usize>(move |summary| {
+        let start_cmp = start.cmp(&summary.max_end, buffer).unwrap();
+        let end_cmp = end.cmp(&summary.min_start, buffer).unwrap();
 
-            if inclusive {
-                start_cmp <= Ordering::Equal && end_cmp >= Ordering::Equal
-            } else {
-                start_cmp == Ordering::Less && end_cmp == Ordering::Greater
-            }
-        },
-        buffer,
-    )
+        if inclusive {
+            start_cmp <= Ordering::Equal && end_cmp >= Ordering::Equal
+        } else {
+            start_cmp == Ordering::Less && end_cmp == Ordering::Greater
+        }
+    });
+    cursor.next(buffer);
+    cursor
 }
 
 fn consolidate_buffer_edits(edits: &mut Vec<text::Edit<usize>>) {

crates/language/src/diagnostic_set.rs 🔗

@@ -78,21 +78,19 @@ impl DiagnosticSet {
     {
         let end_bias = if inclusive { Bias::Right } else { Bias::Left };
         let range = buffer.anchor_before(range.start)..buffer.anchor_at(range.end, end_bias);
-        let mut cursor = self.diagnostics.filter::<_, ()>(
-            {
-                move |summary: &Summary| {
-                    let start_cmp = range.start.cmp(&summary.max_end, buffer).unwrap();
-                    let end_cmp = range.end.cmp(&summary.min_start, buffer).unwrap();
-                    if inclusive {
-                        start_cmp <= Ordering::Equal && end_cmp >= Ordering::Equal
-                    } else {
-                        start_cmp == Ordering::Less && end_cmp == Ordering::Greater
-                    }
+        let mut cursor = self.diagnostics.filter::<_, ()>({
+            move |summary: &Summary| {
+                let start_cmp = range.start.cmp(&summary.max_end, buffer).unwrap();
+                let end_cmp = range.end.cmp(&summary.min_start, buffer).unwrap();
+                if inclusive {
+                    start_cmp <= Ordering::Equal && end_cmp >= Ordering::Equal
+                } else {
+                    start_cmp == Ordering::Less && end_cmp == Ordering::Greater
                 }
-            },
-            buffer,
-        );
+            }
+        });
 
+        cursor.next(buffer);
         iter::from_fn({
             move || {
                 if let Some(diagnostic) = cursor.item() {

crates/sum_tree/src/cursor.rs 🔗

@@ -60,7 +60,7 @@ where
     }
 
     pub fn item(&self) -> Option<&'a T> {
-        assert!(self.did_seek, "Must seek before calling this method");
+        self.assert_did_seek();
         if let Some(entry) = self.stack.last() {
             match *entry.tree.0 {
                 Node::Leaf { ref items, .. } => {
@@ -78,7 +78,7 @@ where
     }
 
     pub fn item_summary(&self) -> Option<&'a T::Summary> {
-        assert!(self.did_seek, "Must seek before calling this method");
+        self.assert_did_seek();
         if let Some(entry) = self.stack.last() {
             match *entry.tree.0 {
                 Node::Leaf {
@@ -98,7 +98,7 @@ where
     }
 
     pub fn prev_item(&self) -> Option<&'a T> {
-        assert!(self.did_seek, "Must seek before calling this method");
+        self.assert_did_seek();
         if let Some(entry) = self.stack.last() {
             if entry.index == 0 {
                 if let Some(prev_leaf) = self.prev_leaf() {
@@ -141,6 +141,11 @@ where
     where
         F: FnMut(&T::Summary) -> bool,
     {
+        if !self.did_seek {
+            self.did_seek = true;
+            self.at_end = true;
+        }
+
         if self.at_end {
             self.position = D::default();
             self.at_end = self.tree.is_empty();
@@ -151,8 +156,6 @@ where
                     position: D::from_summary(self.tree.summary(), cx),
                 });
             }
-        } else {
-            assert!(self.did_seek, "Must seek before calling this method");
         }
 
         let mut descending = false;
@@ -289,6 +292,13 @@ where
         self.at_end = self.stack.is_empty();
         debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
     }
+
+    fn assert_did_seek(&self) {
+        assert!(
+            self.did_seek,
+            "Must call `seek`, `next` or `prev` before calling this method"
+        );
+    }
 }
 
 impl<'a, T, D> Cursor<'a, T, D>
@@ -567,13 +577,8 @@ where
     T: Item,
     D: Dimension<'a, T::Summary>,
 {
-    pub fn new(
-        tree: &'a SumTree<T>,
-        mut filter_node: F,
-        cx: &<T::Summary as Summary>::Context,
-    ) -> Self {
-        let mut cursor = tree.cursor::<D>();
-        cursor.next_internal(&mut filter_node, cx);
+    pub fn new(tree: &'a SumTree<T>, filter_node: F) -> Self {
+        let cursor = tree.cursor::<D>();
         Self {
             cursor,
             filter_node,
@@ -611,6 +616,10 @@ where
     type Item = &'a T;
 
     fn next(&mut self) -> Option<Self::Item> {
+        if !self.cursor.did_seek {
+            self.next(&());
+        }
+
         if let Some(item) = self.item() {
             self.cursor.next_internal(&self.filter_node, &());
             Some(item)

crates/sum_tree/src/sum_tree.rs 🔗

@@ -168,16 +168,12 @@ impl<T: Item> SumTree<T> {
         Cursor::new(self)
     }
 
-    pub fn filter<'a, F, U>(
-        &'a self,
-        filter_node: F,
-        cx: &<T::Summary as Summary>::Context,
-    ) -> FilterCursor<F, T, U>
+    pub fn filter<'a, F, U>(&'a self, filter_node: F) -> FilterCursor<F, T, U>
     where
         F: FnMut(&T::Summary) -> bool,
         U: Dimension<'a, T::Summary>,
     {
-        FilterCursor::new(self, filter_node, cx)
+        FilterCursor::new(self, filter_node)
     }
 
     #[allow(dead_code)]
@@ -752,8 +748,7 @@ mod tests {
 
                 log::info!("tree items: {:?}", tree.items(&()));
 
-                let mut filter_cursor =
-                    tree.filter::<_, Count>(|summary| summary.contains_even, &());
+                let mut filter_cursor = tree.filter::<_, Count>(|summary| summary.contains_even);
                 let expected_filtered_items = tree
                     .items(&())
                     .into_iter()
@@ -761,7 +756,13 @@ mod tests {
                     .filter(|(_, item)| (item & 1) == 0)
                     .collect::<Vec<_>>();
 
-                let mut item_ix = 0;
+                let mut item_ix = if rng.gen() {
+                    filter_cursor.next(&());
+                    0
+                } else {
+                    filter_cursor.prev(&());
+                    expected_filtered_items.len().saturating_sub(1)
+                };
                 while item_ix < expected_filtered_items.len() {
                     log::info!("filter_cursor, item_ix: {}", item_ix);
                     let actual_item = filter_cursor.item().unwrap();

crates/text/src/text.rs 🔗

@@ -1849,10 +1849,11 @@ impl BufferSnapshot {
         let fragments_cursor = if *since == self.version {
             None
         } else {
-            Some(self.fragments.filter(
-                move |summary| !since.observed_all(&summary.max_version),
-                &None,
-            ))
+            let mut cursor = self
+                .fragments
+                .filter(move |summary| !since.observed_all(&summary.max_version));
+            cursor.next(&None);
+            Some(cursor)
         };
         let mut cursor = self
             .fragments