Implement `FoldMap`'s folds using a `SumTree`

Antonio Scandurra created

This required passing a `Context` object to `Summary` and introducing a
new `SeekDimension` trait that allows comparing two dimensions and pass
an additional context object.

Change summary

zed/src/editor/buffer/mod.rs           |   4 
zed/src/editor/buffer/text.rs          |   2 
zed/src/editor/display_map/fold_map.rs | 229 +++++++++++++++++++++------
zed/src/editor/display_map/mod.rs      |   2 
zed/src/operation_queue.rs             |   2 
zed/src/sum_tree/cursor.rs             | 121 ++++++++++----
zed/src/sum_tree/mod.rs                | 104 +++++++++---
zed/src/util.rs                        |  39 ----
zed/src/worktree.rs                    |   2 
9 files changed, 354 insertions(+), 151 deletions(-)

Detailed changes

zed/src/editor/buffer/mod.rs 🔗

@@ -2103,6 +2103,8 @@ impl sum_tree::Item for Fragment {
 }
 
 impl sum_tree::Summary for FragmentSummary {
+    type Context = ();
+
     fn add_summary(&mut self, other: &Self) {
         self.text_summary += &other.text_summary;
         debug_assert!(self.max_fragment_id <= other.max_fragment_id);
@@ -2167,6 +2169,8 @@ impl sum_tree::Item for InsertionSplit {
 }
 
 impl sum_tree::Summary for InsertionSplitSummary {
+    type Context = ();
+
     fn add_summary(&mut self, other: &Self) {
         self.extent += other.extent;
     }

zed/src/editor/buffer/text.rs 🔗

@@ -59,6 +59,8 @@ pub struct TextSummary {
 }
 
 impl sum_tree::Summary for TextSummary {
+    type Context = ();
+
     fn add_summary(&mut self, other: &Self) {
         *self += other;
     }

zed/src/editor/display_map/fold_map.rs 🔗

@@ -1,10 +1,10 @@
 use super::{
-    buffer, Anchor, AnchorRangeExt, Buffer, DisplayPoint, Edit, Point, TextSummary, ToOffset,
+    buffer::{self, AnchorRangeExt},
+    Anchor, Buffer, DisplayPoint, Edit, Point, TextSummary, ToOffset,
 };
 use crate::{
-    sum_tree::{self, Cursor, SumTree},
+    sum_tree::{self, Cursor, SeekBias, SumTree},
     time,
-    util::find_insertion_index,
 };
 use anyhow::{anyhow, Result};
 use gpui::{AppContext, ModelHandle};
@@ -14,12 +14,11 @@ use std::{
     iter::Take,
     ops::Range,
 };
-use sum_tree::{Dimension, SeekBias};
 
 pub struct FoldMap {
     buffer: ModelHandle<Buffer>,
     transforms: Mutex<SumTree<Transform>>,
-    folds: Vec<Range<Anchor>>,
+    folds: SumTree<Fold>,
     last_sync: Mutex<time::Global>,
 }
 
@@ -29,7 +28,7 @@ impl FoldMap {
         let text_summary = buffer.text_summary();
         Self {
             buffer: buffer_handle,
-            folds: Vec::new(),
+            folds: Default::default(),
             transforms: Mutex::new(SumTree::from_item(Transform {
                 summary: TransformSummary {
                     buffer: text_summary.clone(),
@@ -76,17 +75,20 @@ impl FoldMap {
     pub fn folds_in_range<'a, T>(
         &'a self,
         range: Range<T>,
-        app: &'a AppContext,
+        ctx: &'a AppContext,
     ) -> Result<impl Iterator<Item = &'a Range<Anchor>>>
     where
         T: ToOffset,
     {
-        let buffer = self.buffer.read(app);
+        let buffer = self.buffer.read(ctx);
         let range = buffer.anchor_before(range.start)?..buffer.anchor_before(range.end)?;
-        Ok(self.folds.iter().filter(move |fold| {
-            range.start.cmp(&fold.end, buffer).unwrap() == Ordering::Less
-                && range.end.cmp(&fold.start, buffer).unwrap() == Ordering::Greater
-        }))
+        Ok(self
+            .folds
+            .filter::<_, usize>(move |summary| {
+                range.start.cmp(&summary.max_end, buffer).unwrap() < Ordering::Equal
+                    && range.end.cmp(&summary.min_start, buffer).unwrap() >= Ordering::Equal
+            })
+            .map(|f| &f.0))
     }
 
     pub fn fold<T: ToOffset>(
@@ -99,16 +101,31 @@ impl FoldMap {
         let mut edits = Vec::new();
         let buffer = self.buffer.read(ctx);
         for range in ranges.into_iter() {
-            let start = range.start.to_offset(buffer)?;
-            let end = range.end.to_offset(buffer)?;
+            let range = range.start.to_offset(buffer)?..range.end.to_offset(buffer)?;
+            let fold = if range.start == range.end {
+                Fold(buffer.anchor_after(range.start)?..buffer.anchor_after(range.end)?)
+            } else {
+                Fold(buffer.anchor_after(range.start)?..buffer.anchor_before(range.end)?)
+            };
             edits.push(Edit {
-                old_range: start..end,
-                new_range: start..end,
+                old_range: range.clone(),
+                new_range: range.clone(),
             });
-
-            let fold = buffer.anchor_after(start)?..buffer.anchor_before(end)?;
-            let ix = find_insertion_index(&self.folds, |probe| probe.cmp(&fold, buffer))?;
-            self.folds.insert(ix, fold);
+            self.folds = {
+                let mut new_tree = SumTree::new();
+                let mut cursor = self.folds.cursor::<_, ()>();
+                new_tree.push_tree_with_ctx(
+                    cursor.slice_with_ctx(
+                        &FoldRange(fold.0.clone()),
+                        SeekBias::Right,
+                        Some(buffer),
+                    ),
+                    Some(buffer),
+                );
+                new_tree.push_with_ctx(fold, Some(buffer));
+                new_tree.push_tree_with_ctx(cursor.suffix_with_ctx(Some(buffer)), Some(buffer));
+                new_tree
+            };
         }
         edits.sort_unstable_by(|a, b| {
             a.old_range
@@ -131,27 +148,45 @@ impl FoldMap {
         let buffer = self.buffer.read(ctx);
 
         let mut edits = Vec::new();
+        let mut fold_ixs_to_delete = Vec::new();
         for range in ranges.into_iter() {
             let start = buffer.anchor_before(range.start.to_offset(buffer)?)?;
             let end = buffer.anchor_after(range.end.to_offset(buffer)?)?;
+            let range = start..end;
 
             // Remove intersecting folds and add their ranges to edits that are passed to apply_edits
-            self.folds.retain(|fold| {
-                if fold.start.cmp(&end, buffer).unwrap() > Ordering::Equal
-                    || fold.end.cmp(&start, buffer).unwrap() < Ordering::Equal
-                {
-                    true
-                } else {
-                    let offset_range =
-                        fold.start.to_offset(buffer).unwrap()..fold.end.to_offset(buffer).unwrap();
-                    edits.push(Edit {
-                        old_range: offset_range.clone(),
-                        new_range: offset_range,
-                    });
-                    false
-                }
+            let mut cursor = self.folds.filter::<_, usize>(|summary| {
+                range.start.cmp(&summary.max_end, buffer).unwrap() < Ordering::Equal
+                    && range.end.cmp(&summary.min_start, buffer).unwrap() >= Ordering::Equal
             });
+
+            while let Some(fold) = cursor.item() {
+                let offset_range =
+                    fold.0.start.to_offset(buffer).unwrap()..fold.0.end.to_offset(buffer).unwrap();
+                edits.push(Edit {
+                    old_range: offset_range.clone(),
+                    new_range: offset_range,
+                });
+                fold_ixs_to_delete.push(*cursor.start());
+                cursor.next();
+            }
         }
+        fold_ixs_to_delete.sort_unstable();
+        fold_ixs_to_delete.dedup();
+
+        self.folds = {
+            let mut cursor = self.folds.cursor::<_, ()>();
+            let mut folds = SumTree::new();
+            for fold_ix in fold_ixs_to_delete {
+                folds.push_tree_with_ctx(
+                    cursor.slice_with_ctx(&fold_ix, SeekBias::Right, Some(buffer)),
+                    Some(buffer),
+                );
+                cursor.next();
+            }
+            folds.push_tree_with_ctx(cursor.suffix_with_ctx(Some(buffer)), Some(buffer));
+            folds
+        };
 
         self.apply_edits(edits, ctx);
         Ok(())
@@ -262,14 +297,14 @@ impl FoldMap {
                 ((edit.new_range.start + edit.old_extent()) as isize + delta) as usize;
 
             let anchor = buffer.anchor_before(edit.new_range.start).unwrap();
-            let folds_start =
-                find_insertion_index(&self.folds, |probe| probe.start.cmp(&anchor, buffer))
-                    .unwrap();
-            let mut folds = self.folds[folds_start..]
-                .iter()
-                .map(|fold| {
-                    fold.start.to_offset(buffer).unwrap()..fold.end.to_offset(buffer).unwrap()
-                })
+            let mut folds_cursor = self.folds.cursor::<_, ()>();
+            folds_cursor.seek_with_ctx(
+                &FoldRange(anchor..Anchor::End),
+                SeekBias::Left,
+                Some(buffer),
+            );
+            let mut folds = folds_cursor
+                .map(|f| f.0.start.to_offset(buffer).unwrap()..f.0.end.to_offset(buffer).unwrap())
                 .peekable();
 
             while folds
@@ -422,18 +457,112 @@ impl sum_tree::Item for Transform {
 }
 
 impl sum_tree::Summary for TransformSummary {
+    type Context = ();
+
     fn add_summary(&mut self, other: &Self) {
         self.buffer += &other.buffer;
         self.display += &other.display;
     }
 }
 
-impl<'a> Dimension<'a, TransformSummary> for TransformSummary {
+impl<'a> sum_tree::Dimension<'a, TransformSummary> for TransformSummary {
     fn add_summary(&mut self, summary: &'a TransformSummary) {
         sum_tree::Summary::add_summary(self, summary);
     }
 }
 
+#[derive(Clone, Debug)]
+struct Fold(Range<Anchor>);
+
+impl sum_tree::Item for Fold {
+    type Summary = FoldSummary;
+
+    fn summary(&self) -> Self::Summary {
+        FoldSummary {
+            start: self.0.start.clone(),
+            end: self.0.end.clone(),
+            min_start: self.0.start.clone(),
+            max_end: self.0.end.clone(),
+            count: 1,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+struct FoldSummary {
+    start: Anchor,
+    end: Anchor,
+    min_start: Anchor,
+    max_end: Anchor,
+    count: usize,
+}
+
+impl Default for FoldSummary {
+    fn default() -> Self {
+        Self {
+            start: Anchor::Start,
+            end: Anchor::End,
+            min_start: Anchor::Start,
+            max_end: Anchor::Start,
+            count: 0,
+        }
+    }
+}
+
+impl sum_tree::Summary for FoldSummary {
+    type Context = Buffer;
+
+    fn add_summary_with_ctx(&mut self, other: &Self, buffer: Option<&Self::Context>) {
+        let buffer = buffer.unwrap();
+        if other.min_start.cmp(&self.min_start, buffer).unwrap() == Ordering::Less {
+            self.min_start = other.min_start.clone();
+        }
+        if other.max_end.cmp(&self.max_end, buffer).unwrap() == Ordering::Greater {
+            self.max_end = other.max_end.clone();
+        }
+
+        if cfg!(debug_assertions) {
+            let start_comparison = self.start.cmp(&other.start, buffer).unwrap();
+            let end_comparison = self.end.cmp(&other.end, buffer).unwrap();
+            assert!(start_comparison <= Ordering::Equal);
+            if start_comparison == Ordering::Equal {
+                assert!(end_comparison >= Ordering::Equal);
+            }
+        }
+        self.start = other.start.clone();
+        self.end = other.end.clone();
+        self.count += other.count;
+    }
+}
+
+#[derive(Clone, Debug)]
+struct FoldRange(Range<Anchor>);
+
+impl Default for FoldRange {
+    fn default() -> Self {
+        Self(Anchor::Start..Anchor::End)
+    }
+}
+
+impl<'a> sum_tree::Dimension<'a, FoldSummary> for FoldRange {
+    fn add_summary(&mut self, summary: &'a FoldSummary) {
+        self.0.start = summary.start.clone();
+        self.0.end = summary.end.clone();
+    }
+}
+
+impl<'a> sum_tree::SeekDimension<'a, FoldSummary> for FoldRange {
+    fn cmp(&self, other: &Self, buffer: Option<&Buffer>) -> Ordering {
+        self.0.cmp(&other.0, buffer.unwrap()).unwrap()
+    }
+}
+
+impl<'a> sum_tree::Dimension<'a, FoldSummary> for usize {
+    fn add_summary(&mut self, summary: &'a FoldSummary) {
+        *self += summary.count;
+    }
+}
+
 pub struct BufferRows<'a> {
     cursor: Cursor<'a, Transform, DisplayPoint, TransformSummary>,
     display_point: Point,
@@ -498,7 +627,7 @@ impl<'a> Iterator for Chars<'a> {
     }
 }
 
-impl<'a> Dimension<'a, TransformSummary> for DisplayPoint {
+impl<'a> sum_tree::Dimension<'a, TransformSummary> for DisplayPoint {
     fn add_summary(&mut self, summary: &'a TransformSummary) {
         self.0 += &summary.display.lines;
     }
@@ -507,19 +636,19 @@ impl<'a> Dimension<'a, TransformSummary> for DisplayPoint {
 #[derive(Copy, Clone, Debug, Default, Eq, Ord, PartialOrd, PartialEq)]
 pub struct DisplayOffset(usize);
 
-impl<'a> Dimension<'a, TransformSummary> for DisplayOffset {
+impl<'a> sum_tree::Dimension<'a, TransformSummary> for DisplayOffset {
     fn add_summary(&mut self, summary: &'a TransformSummary) {
         self.0 += &summary.display.chars;
     }
 }
 
-impl<'a> Dimension<'a, TransformSummary> for Point {
+impl<'a> sum_tree::Dimension<'a, TransformSummary> for Point {
     fn add_summary(&mut self, summary: &'a TransformSummary) {
         *self += &summary.buffer.lines;
     }
 }
 
-impl<'a> Dimension<'a, TransformSummary> for usize {
+impl<'a> sum_tree::Dimension<'a, TransformSummary> for usize {
     fn add_summary(&mut self, summary: &'a TransformSummary) {
         *self += &summary.buffer.chars;
     }
@@ -894,13 +1023,13 @@ mod tests {
 
         fn merged_fold_ranges(&self, app: &AppContext) -> Vec<Range<usize>> {
             let buffer = self.buffer.read(app);
-            let mut folds = self.folds.clone();
+            let mut folds = self.folds.items();
             // Ensure sorting doesn't change how folds get merged and displayed.
-            folds.sort_by(|a, b| a.cmp(b, buffer).unwrap());
+            folds.sort_by(|a, b| a.0.cmp(&b.0, buffer).unwrap());
             let mut fold_ranges = folds
                 .iter()
                 .map(|fold| {
-                    fold.start.to_offset(buffer).unwrap()..fold.end.to_offset(buffer).unwrap()
+                    fold.0.start.to_offset(buffer).unwrap()..fold.0.end.to_offset(buffer).unwrap()
                 })
                 .peekable();
 

zed/src/editor/display_map/mod.rs 🔗

@@ -1,6 +1,6 @@
 mod fold_map;
 
-use super::{buffer, Anchor, AnchorRangeExt, Buffer, Edit, Point, TextSummary, ToOffset, ToPoint};
+use super::{buffer, Anchor, Buffer, Edit, Point, TextSummary, ToOffset, ToPoint};
 use anyhow::Result;
 pub use fold_map::BufferRows;
 use fold_map::{FoldMap, FoldMapSnapshot};

zed/src/operation_queue.rs 🔗

@@ -66,6 +66,8 @@ impl<T: Operation> KeyedItem for T {
 }
 
 impl Summary for OperationSummary {
+    type Context = ();
+
     fn add_summary(&mut self, other: &Self) {
         assert!(self.key < other.key);
         self.key = other.key;

zed/src/sum_tree/cursor.rs 🔗

@@ -382,22 +382,48 @@ where
 impl<'a, T, S, U> Cursor<'a, T, S, U>
 where
     T: Item,
-    S: Dimension<'a, T::Summary> + Ord,
+    S: SeekDimension<'a, T::Summary>,
     U: Dimension<'a, T::Summary>,
 {
     pub fn seek(&mut self, pos: &S, bias: SeekBias) -> bool {
+        self.seek_with_ctx(pos, bias, None)
+    }
+
+    pub fn seek_with_ctx(
+        &mut self,
+        pos: &S,
+        bias: SeekBias,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
+    ) -> bool {
         self.reset();
-        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None)
+        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None, ctx)
     }
 
-    #[allow(unused)]
     pub fn seek_forward(&mut self, pos: &S, bias: SeekBias) -> bool {
-        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None)
+        self.seek_forward_with_ctx(pos, bias, None)
+    }
+
+    pub fn seek_forward_with_ctx(
+        &mut self,
+        pos: &S,
+        bias: SeekBias,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
+    ) -> bool {
+        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None, ctx)
     }
 
     pub fn slice(&mut self, end: &S, bias: SeekBias) -> SumTree<T> {
+        self.slice_with_ctx(end, bias, None)
+    }
+
+    pub fn slice_with_ctx(
+        &mut self,
+        end: &S,
+        bias: SeekBias,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
+    ) -> SumTree<T> {
         let mut slice = SeekAggregate::Slice(SumTree::new());
-        self.seek_internal::<()>(end, bias, &mut slice);
+        self.seek_internal::<()>(end, bias, &mut slice, ctx);
         if let SeekAggregate::Slice(slice) = slice {
             slice
         } else {
@@ -406,9 +432,16 @@ where
     }
 
     pub fn suffix(&mut self) -> SumTree<T> {
+        self.suffix_with_ctx(None)
+    }
+
+    pub fn suffix_with_ctx(
+        &mut self,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
+    ) -> SumTree<T> {
         let extent = self.tree.extent::<S>();
         let mut slice = SeekAggregate::Slice(SumTree::new());
-        self.seek_internal::<()>(&extent, SeekBias::Right, &mut slice);
+        self.seek_internal::<()>(&extent, SeekBias::Right, &mut slice, ctx);
         if let SeekAggregate::Slice(slice) = slice {
             slice
         } else {
@@ -417,11 +450,23 @@ where
     }
 
     pub fn summary<D>(&mut self, end: &S, bias: SeekBias) -> D
+    where
+        D: Dimension<'a, T::Summary>,
+    {
+        self.summary_with_ctx(end, bias, None)
+    }
+
+    pub fn summary_with_ctx<D>(
+        &mut self,
+        end: &S,
+        bias: SeekBias,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
+    ) -> D
     where
         D: Dimension<'a, T::Summary>,
     {
         let mut summary = SeekAggregate::Summary(D::default());
-        self.seek_internal(end, bias, &mut summary);
+        self.seek_internal(end, bias, &mut summary, ctx);
         if let SeekAggregate::Summary(summary) = summary {
             summary
         } else {
@@ -434,11 +479,12 @@ where
         target: &S,
         bias: SeekBias,
         aggregate: &mut SeekAggregate<T, D>,
+        ctx: Option<&'a <T::Summary as Summary>::Context>,
     ) -> bool
     where
         D: Dimension<'a, T::Summary>,
     {
-        debug_assert!(target >= &self.seek_dimension);
+        debug_assert!(target.cmp(&self.seek_dimension, ctx) >= Ordering::Equal);
         let mut containing_subtree = None;
 
         if self.did_seek {
@@ -458,7 +504,7 @@ where
                                 let mut child_end = self.seek_dimension.clone();
                                 child_end.add_summary(&child_summary);
 
-                                let comparison = target.cmp(&child_end);
+                                let comparison = target.cmp(&child_end, ctx);
                                 if comparison == Ordering::Greater
                                     || (comparison == Ordering::Equal && bias == SeekBias::Right)
                                 {
@@ -467,7 +513,7 @@ where
                                     match aggregate {
                                         SeekAggregate::None => {}
                                         SeekAggregate::Slice(slice) => {
-                                            slice.push_tree(child_tree.clone());
+                                            slice.push_tree_with_ctx(child_tree.clone(), ctx);
                                         }
                                         SeekAggregate::Summary(summary) => {
                                             summary.add_summary(child_summary);
@@ -500,7 +546,7 @@ where
                                 let mut item_end = self.seek_dimension.clone();
                                 item_end.add_summary(item_summary);
 
-                                let comparison = target.cmp(&item_end);
+                                let comparison = target.cmp(&item_end, ctx);
                                 if comparison == Ordering::Greater
                                     || (comparison == Ordering::Equal && bias == SeekBias::Right)
                                 {
@@ -514,7 +560,7 @@ where
                                             slice_items_summary
                                                 .as_mut()
                                                 .unwrap()
-                                                .add_summary(item_summary);
+                                                .add_summary_with_ctx(item_summary, ctx);
                                         }
                                         SeekAggregate::Summary(summary) => {
                                             summary.add_summary(item_summary);
@@ -523,11 +569,14 @@ where
                                     entry.index += 1;
                                 } else {
                                     if let SeekAggregate::Slice(slice) = aggregate {
-                                        slice.push_tree(SumTree(Arc::new(Node::Leaf {
-                                            summary: slice_items_summary.unwrap(),
-                                            items: slice_items,
-                                            item_summaries: slice_item_summaries,
-                                        })));
+                                        slice.push_tree_with_ctx(
+                                            SumTree(Arc::new(Node::Leaf {
+                                                summary: slice_items_summary.unwrap(),
+                                                items: slice_items,
+                                                item_summaries: slice_item_summaries,
+                                            })),
+                                            ctx,
+                                        );
                                     }
                                     break 'outer;
                                 }
@@ -535,11 +584,14 @@ where
 
                             if let SeekAggregate::Slice(slice) = aggregate {
                                 if !slice_items.is_empty() {
-                                    slice.push_tree(SumTree(Arc::new(Node::Leaf {
-                                        summary: slice_items_summary.unwrap(),
-                                        items: slice_items,
-                                        item_summaries: slice_item_summaries,
-                                    })));
+                                    slice.push_tree_with_ctx(
+                                        SumTree(Arc::new(Node::Leaf {
+                                            summary: slice_items_summary.unwrap(),
+                                            items: slice_items,
+                                            item_summaries: slice_item_summaries,
+                                        })),
+                                        ctx,
+                                    );
                                 }
                             }
                         }
@@ -568,7 +620,7 @@ where
                             let mut child_end = self.seek_dimension.clone();
                             child_end.add_summary(child_summary);
 
-                            let comparison = target.cmp(&child_end);
+                            let comparison = target.cmp(&child_end, ctx);
                             if comparison == Ordering::Greater
                                 || (comparison == Ordering::Equal && bias == SeekBias::Right)
                             {
@@ -577,7 +629,7 @@ where
                                 match aggregate {
                                     SeekAggregate::None => {}
                                     SeekAggregate::Slice(slice) => {
-                                        slice.push_tree(child_trees[index].clone());
+                                        slice.push_tree_with_ctx(child_trees[index].clone(), ctx);
                                     }
                                     SeekAggregate::Summary(summary) => {
                                         summary.add_summary(child_summary);
@@ -614,7 +666,7 @@ where
                             let mut child_end = self.seek_dimension.clone();
                             child_end.add_summary(item_summary);
 
-                            let comparison = target.cmp(&child_end);
+                            let comparison = target.cmp(&child_end, ctx);
                             if comparison == Ordering::Greater
                                 || (comparison == Ordering::Equal && bias == SeekBias::Right)
                             {
@@ -627,7 +679,7 @@ where
                                         slice_items_summary
                                             .as_mut()
                                             .unwrap()
-                                            .add_summary(item_summary);
+                                            .add_summary_with_ctx(item_summary, ctx);
                                         slice_item_summaries.push(item_summary.clone());
                                     }
                                     SeekAggregate::Summary(summary) => {
@@ -647,11 +699,14 @@ where
 
                         if let SeekAggregate::Slice(slice) = aggregate {
                             if !slice_items.is_empty() {
-                                slice.push_tree(SumTree(Arc::new(Node::Leaf {
-                                    summary: slice_items_summary.unwrap(),
-                                    items: slice_items,
-                                    item_summaries: slice_item_summaries,
-                                })));
+                                slice.push_tree_with_ctx(
+                                    SumTree(Arc::new(Node::Leaf {
+                                        summary: slice_items_summary.unwrap(),
+                                        items: slice_items,
+                                        item_summaries: slice_item_summaries,
+                                    })),
+                                    ctx,
+                                );
                             }
                         }
                     }
@@ -672,9 +727,9 @@ where
             if let Some(summary) = self.item_summary() {
                 end.add_summary(summary);
             }
-            *target == end
+            target.cmp(&end, ctx) == Ordering::Equal
         } else {
-            *target == self.seek_dimension
+            target.cmp(&self.seek_dimension, ctx) == Ordering::Equal
         }
     }
 }

zed/src/sum_tree/mod.rs 🔗

@@ -3,7 +3,7 @@ mod cursor;
 use arrayvec::ArrayVec;
 pub use cursor::Cursor;
 pub use cursor::FilterCursor;
-use std::{fmt, iter::FromIterator, sync::Arc};
+use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
 
 #[cfg(test)]
 const TREE_BASE: usize = 2;
@@ -23,17 +23,36 @@ pub trait KeyedItem: Item {
 }
 
 pub trait Summary: Default + Clone + fmt::Debug {
-    fn add_summary(&mut self, summary: &Self);
+    type Context;
+
+    fn add_summary(&mut self, _summary: &Self) {
+        unimplemented!();
+    }
+
+    fn add_summary_with_ctx(&mut self, summary: &Self, ctx: Option<&Self::Context>) {
+        assert!(ctx.is_none());
+        self.add_summary(summary);
+    }
 }
 
-pub trait Dimension<'a, Summary: Default>: Clone + fmt::Debug + Default {
-    fn add_summary(&mut self, summary: &'a Summary);
+pub trait Dimension<'a, S: Summary>: Clone + fmt::Debug + Default {
+    fn add_summary(&mut self, _summary: &'a S);
 }
 
-impl<'a, T: Default> Dimension<'a, T> for () {
+impl<'a, T: Summary> Dimension<'a, T> for () {
     fn add_summary(&mut self, _: &'a T) {}
 }
 
+pub trait SeekDimension<'a, T: Summary>: Dimension<'a, T> {
+    fn cmp(&self, other: &Self, ctx: Option<&T::Context>) -> Ordering;
+}
+
+impl<'a, S: Summary, T: Dimension<'a, S> + Ord> SeekDimension<'a, S> for T {
+    fn cmp(&self, other: &Self, _ctx: Option<&S::Context>) -> Ordering {
+        Ord::cmp(self, other)
+    }
+}
+
 #[derive(Copy, Clone, Eq, PartialEq)]
 pub enum SeekBias {
     Left,
@@ -94,7 +113,7 @@ impl<T: Item> SumTree<T> {
         let mut extent = D::default();
         match self.0.as_ref() {
             Node::Internal { summary, .. } | Node::Leaf { summary, .. } => {
-                extent.add_summary(summary)
+                extent.add_summary(summary);
             }
         }
         extent
@@ -154,30 +173,50 @@ impl<T: Item> SumTree<T> {
     }
 
     pub fn push(&mut self, item: T) {
+        self.push_with_ctx(item, None);
+    }
+
+    pub fn push_with_ctx(&mut self, item: T, ctx: Option<&<T::Summary as Summary>::Context>) {
         let summary = item.summary();
-        self.push_tree(SumTree::from_child_trees(vec![SumTree(Arc::new(
-            Node::Leaf {
-                summary: summary.clone(),
-                items: ArrayVec::from_iter(Some(item)),
-                item_summaries: ArrayVec::from_iter(Some(summary)),
-            },
-        ))]))
+        self.push_tree_with_ctx(
+            SumTree::from_child_trees(
+                vec![SumTree(Arc::new(Node::Leaf {
+                    summary: summary.clone(),
+                    items: ArrayVec::from_iter(Some(item)),
+                    item_summaries: ArrayVec::from_iter(Some(summary)),
+                }))],
+                ctx,
+            ),
+            ctx,
+        )
     }
 
     pub fn push_tree(&mut self, other: Self) {
+        self.push_tree_with_ctx(other, None);
+    }
+
+    pub fn push_tree_with_ctx(
+        &mut self,
+        other: Self,
+        ctx: Option<&<T::Summary as Summary>::Context>,
+    ) {
         let other_node = other.0.clone();
         if !other_node.is_leaf() || other_node.items().len() > 0 {
             if self.0.height() < other_node.height() {
                 for tree in other_node.child_trees() {
-                    self.push_tree(tree.clone());
+                    self.push_tree_with_ctx(tree.clone(), ctx);
                 }
-            } else if let Some(split_tree) = self.push_tree_recursive(other) {
-                *self = Self::from_child_trees(vec![self.clone(), split_tree]);
+            } else if let Some(split_tree) = self.push_tree_recursive(other, ctx) {
+                *self = Self::from_child_trees(vec![self.clone(), split_tree], ctx);
             }
         }
     }
 
-    fn push_tree_recursive(&mut self, other: SumTree<T>) -> Option<SumTree<T>> {
+    fn push_tree_recursive(
+        &mut self,
+        other: SumTree<T>,
+        ctx: Option<&<T::Summary as Summary>::Context>,
+    ) -> Option<SumTree<T>> {
         match Arc::make_mut(&mut self.0) {
             Node::Internal {
                 height,
@@ -187,7 +226,7 @@ impl<T: Item> SumTree<T> {
                 ..
             } => {
                 let other_node = other.0.clone();
-                summary.add_summary(other_node.summary());
+                summary.add_summary_with_ctx(other_node.summary(), ctx);
 
                 let height_delta = *height - other_node.height();
                 let mut summaries_to_append = ArrayVec::<[T::Summary; 2 * TREE_BASE]>::new();
@@ -199,7 +238,10 @@ impl<T: Item> SumTree<T> {
                     summaries_to_append.push(other_node.summary().clone());
                     trees_to_append.push(other)
                 } else {
-                    let tree_to_append = child_trees.last_mut().unwrap().push_tree_recursive(other);
+                    let tree_to_append = child_trees
+                        .last_mut()
+                        .unwrap()
+                        .push_tree_recursive(other, ctx);
                     *child_summaries.last_mut().unwrap() =
                         child_trees.last().unwrap().0.summary().clone();
 
@@ -229,13 +271,13 @@ impl<T: Item> SumTree<T> {
                         left_trees = all_trees.by_ref().take(midpoint).collect();
                         right_trees = all_trees.collect();
                     }
-                    *summary = sum(left_summaries.iter());
+                    *summary = sum(left_summaries.iter(), ctx);
                     *child_summaries = left_summaries;
                     *child_trees = left_trees;
 
                     Some(SumTree(Arc::new(Node::Internal {
                         height: *height,
-                        summary: sum(right_summaries.iter()),
+                        summary: sum(right_summaries.iter(), ctx),
                         child_summaries: right_summaries,
                         child_trees: right_trees,
                     })))
@@ -274,14 +316,14 @@ impl<T: Item> SumTree<T> {
                     }
                     *items = left_items;
                     *item_summaries = left_summaries;
-                    *summary = sum(item_summaries.iter());
+                    *summary = sum(item_summaries.iter(), ctx);
                     Some(SumTree(Arc::new(Node::Leaf {
                         items: right_items,
-                        summary: sum(right_summaries.iter()),
+                        summary: sum(right_summaries.iter(), ctx),
                         item_summaries: right_summaries,
                     })))
                 } else {
-                    summary.add_summary(other_node.summary());
+                    summary.add_summary_with_ctx(other_node.summary(), ctx);
                     items.extend(other_node.items().iter().cloned());
                     item_summaries.extend(other_node.child_summaries().iter().cloned());
                     None
@@ -290,13 +332,16 @@ impl<T: Item> SumTree<T> {
         }
     }
 
-    fn from_child_trees(child_trees: Vec<SumTree<T>>) -> Self {
+    fn from_child_trees(
+        child_trees: Vec<SumTree<T>>,
+        ctx: Option<&<T::Summary as Summary>::Context>,
+    ) -> Self {
         let height = child_trees[0].0.height() + 1;
         let mut child_summaries = ArrayVec::new();
         for child in &child_trees {
             child_summaries.push(child.0.summary().clone());
         }
-        let summary = sum(child_summaries.iter());
+        let summary = sum(child_summaries.iter(), ctx);
         SumTree(Arc::new(Node::Internal {
             height,
             summary,
@@ -486,14 +531,14 @@ impl<T: KeyedItem> Edit<T> {
     }
 }
 
-fn sum<'a, T, I>(iter: I) -> T
+fn sum<'a, T, I>(iter: I, ctx: Option<&T::Context>) -> T
 where
     T: 'a + Summary,
     I: Iterator<Item = &'a T>,
 {
     let mut sum = T::default();
     for value in iter {
-        sum.add_summary(value);
+        sum.add_summary_with_ctx(value, ctx);
     }
     sum
 }
@@ -840,7 +885,10 @@ mod tests {
             *self
         }
     }
+
     impl Summary for IntegersSummary {
+        type Context = ();
+
         fn add_summary(&mut self, other: &Self) {
             self.count.0 += &other.count.0;
             self.sum.0 += &other.sum.0;

zed/src/util.rs 🔗

@@ -30,37 +30,6 @@ where
     }
 }
 
-pub fn find_insertion_index<'a, F, T, E>(slice: &'a [T], mut f: F) -> Result<usize, E>
-where
-    F: FnMut(&'a T) -> Result<Ordering, E>,
-{
-    use Ordering::*;
-
-    let s = slice;
-    let mut size = s.len();
-    if size == 0 {
-        return Ok(0);
-    }
-    let mut base = 0usize;
-    while size > 1 {
-        let half = size / 2;
-        let mid = base + half;
-        // mid is always in [0, size), that means mid is >= 0 and < size.
-        // mid >= 0: by definition
-        // mid < size: mid = size / 2 + size / 4 + size / 8 ...
-        let cmp = f(unsafe { s.get_unchecked(mid) })?;
-        base = if cmp == Greater { base } else { mid };
-        size -= half;
-    }
-    // base is always in [0, size) because base <= mid.
-    let cmp = f(unsafe { s.get_unchecked(base) })?;
-    if cmp == Equal {
-        Ok(base)
-    } else {
-        Ok(base + (cmp == Less) as usize)
-    }
-}
-
 pub struct RandomCharIter<T: Rng>(T);
 
 impl<T: Rng> RandomCharIter<T> {
@@ -85,14 +54,6 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
 mod tests {
     use super::*;
 
-    #[test]
-    fn test_find_insertion_index() {
-        assert_eq!(
-            find_insertion_index(&[0, 4, 8], |probe| Ok::<Ordering, ()>(probe.cmp(&2))),
-            Ok(1)
-        );
-    }
-
     #[test]
     fn test_extend_sorted() {
         let mut vec = vec![];

zed/src/worktree.rs 🔗

@@ -536,6 +536,8 @@ impl Default for EntrySummary {
 }
 
 impl sum_tree::Summary for EntrySummary {
+    type Context = ();
+
     fn add_summary(&mut self, rhs: &Self) {
         self.max_path = rhs.max_path.clone();
         self.file_count += rhs.file_count;