Remove type parameters from Cursor::seek_internal

Max Brunsfeld created

Instead, use trait objects for the target dimension and aggregation

Change summary

gpui/src/sum_tree/cursor.rs | 187 ++++++++++++++++++++------------------
1 file changed, 100 insertions(+), 87 deletions(-)

Detailed changes

gpui/src/sum_tree/cursor.rs 🔗

@@ -1,6 +1,6 @@
 use super::*;
 use arrayvec::ArrayVec;
-use std::{cmp::Ordering, sync::Arc};
+use std::{cmp::Ordering, mem, sync::Arc};
 
 #[derive(Clone)]
 struct StackEntry<'a, T: Item, D> {
@@ -324,7 +324,7 @@ where
         Target: SeekTarget<'a, T::Summary, D>,
     {
         self.reset();
-        self.seek_internal::<_, ()>(pos, bias, &mut SeekAggregate::None, cx)
+        self.seek_internal(pos, bias, &mut (), cx)
     }
 
     pub fn seek_forward<Target>(
@@ -336,7 +336,7 @@ where
     where
         Target: SeekTarget<'a, T::Summary, D>,
     {
-        self.seek_internal::<_, ()>(pos, bias, &mut SeekAggregate::None, cx)
+        self.seek_internal(pos, bias, &mut (), cx)
     }
 
     pub fn slice<Target>(
@@ -348,23 +348,18 @@ where
     where
         Target: SeekTarget<'a, T::Summary, D>,
     {
-        let mut slice = SeekAggregate::Slice(SumTree::new());
-        self.seek_internal::<_, ()>(end, bias, &mut slice, cx);
-        if let SeekAggregate::Slice(slice) = slice {
-            slice
-        } else {
-            unreachable!()
-        }
+        let mut slice = SliceSeekAggregate {
+            tree: SumTree::new(),
+            leaf_items: ArrayVec::new(),
+            leaf_item_summaries: ArrayVec::new(),
+            leaf_summary: T::Summary::default(),
+        };
+        self.seek_internal(end, bias, &mut slice, cx);
+        slice.tree
     }
 
     pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
-        let mut slice = SeekAggregate::Slice(SumTree::new());
-        self.seek_internal::<_, ()>(&End::new(), Bias::Right, &mut slice, cx);
-        if let SeekAggregate::Slice(slice) = slice {
-            slice
-        } else {
-            unreachable!()
-        }
+        self.slice(&End::new(), Bias::Right, cx)
     }
 
     pub fn summary<Target, Output>(
@@ -377,26 +372,18 @@ where
         Target: SeekTarget<'a, T::Summary, D>,
         Output: Dimension<'a, T::Summary>,
     {
-        let mut summary = SeekAggregate::Summary(Output::default());
+        let mut summary = SummarySeekAggregate(Output::default());
         self.seek_internal(end, bias, &mut summary, cx);
-        if let SeekAggregate::Summary(summary) = summary {
-            summary
-        } else {
-            unreachable!()
-        }
+        summary.0
     }
 
-    fn seek_internal<Target, Output>(
+    fn seek_internal(
         &mut self,
-        target: &Target,
+        target: &dyn SeekTarget<'a, T::Summary, D>,
         bias: Bias,
-        aggregate: &mut SeekAggregate<T, Output>,
+        aggregate: &mut dyn SeekAggregate<'a, T>,
         cx: &<T::Summary as Summary>::Context,
-    ) -> bool
-    where
-        Target: SeekTarget<'a, T::Summary, D>,
-        Output: Dimension<'a, T::Summary>,
-    {
+    ) -> bool {
         debug_assert!(
             target.cmp(&self.position, cx) >= Ordering::Equal,
             "cannot seek backward from {:?} to {:?}",
@@ -437,15 +424,7 @@ where
                             || (comparison == Ordering::Equal && bias == Bias::Right)
                         {
                             self.position = child_end;
-                            match aggregate {
-                                SeekAggregate::None => {}
-                                SeekAggregate::Slice(slice) => {
-                                    slice.push_tree(child_tree.clone(), cx);
-                                }
-                                SeekAggregate::Summary(summary) => {
-                                    summary.add_summary(child_summary, cx);
-                                }
-                            }
+                            aggregate.push_tree(child_tree, child_summary, cx);
                             entry.index += 1;
                             entry.position = self.position.clone();
                         } else {
@@ -464,12 +443,7 @@ where
                     ref item_summaries,
                     ..
                 } => {
-                    let mut slice_items = ArrayVec::<T, { 2 * TREE_BASE }>::new();
-                    let mut slice_item_summaries = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
-                    let mut slice_items_summary = match aggregate {
-                        SeekAggregate::Slice(_) => Some(T::Summary::default()),
-                        _ => None,
-                    };
+                    aggregate.begin_leaf();
 
                     for (item, item_summary) in items[entry.index..]
                         .iter()
@@ -483,49 +457,15 @@ where
                             || (comparison == Ordering::Equal && bias == Bias::Right)
                         {
                             self.position = child_end;
-                            match aggregate {
-                                SeekAggregate::None => {}
-                                SeekAggregate::Slice(_) => {
-                                    slice_items.push(item.clone());
-                                    slice_item_summaries.push(item_summary.clone());
-                                    <T::Summary as Summary>::add_summary(
-                                        slice_items_summary.as_mut().unwrap(),
-                                        item_summary,
-                                        cx,
-                                    );
-                                }
-                                SeekAggregate::Summary(summary) => {
-                                    summary.add_summary(item_summary, cx);
-                                }
-                            }
+                            aggregate.push_item(item, item_summary, cx);
                             entry.index += 1;
                         } else {
-                            if let SeekAggregate::Slice(slice) = aggregate {
-                                slice.push_tree(
-                                    SumTree(Arc::new(Node::Leaf {
-                                        summary: slice_items_summary.unwrap(),
-                                        items: slice_items,
-                                        item_summaries: slice_item_summaries,
-                                    })),
-                                    cx,
-                                );
-                            }
+                            aggregate.end_leaf(cx);
                             break 'outer;
                         }
                     }
 
-                    if let SeekAggregate::Slice(slice) = aggregate {
-                        if !slice_items.is_empty() {
-                            slice.push_tree(
-                                SumTree(Arc::new(Node::Leaf {
-                                    summary: slice_items_summary.unwrap(),
-                                    items: slice_items,
-                                    item_summaries: slice_item_summaries,
-                                })),
-                                cx,
-                            );
-                        }
-                    }
+                    aggregate.end_leaf(cx);
                 }
             }
 
@@ -625,8 +565,81 @@ where
     }
 }
 
-enum SeekAggregate<T: Item, D> {
-    None,
-    Slice(SumTree<T>),
-    Summary(D),
+trait SeekAggregate<'a, T: Item> {
+    fn begin_leaf(&mut self);
+    fn end_leaf(&mut self, cx: &<T::Summary as Summary>::Context);
+    fn push_item(
+        &mut self,
+        item: &'a T,
+        summary: &'a T::Summary,
+        cx: &<T::Summary as Summary>::Context,
+    );
+    fn push_tree(
+        &mut self,
+        tree: &'a SumTree<T>,
+        summary: &'a T::Summary,
+        cx: &<T::Summary as Summary>::Context,
+    );
+}
+
+struct SliceSeekAggregate<T: Item> {
+    tree: SumTree<T>,
+    leaf_items: ArrayVec<T, { 2 * TREE_BASE }>,
+    leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
+    leaf_summary: T::Summary,
+}
+
+struct SummarySeekAggregate<D>(D);
+
+impl<'a, T: Item> SeekAggregate<'a, T> for () {
+    fn begin_leaf(&mut self) {}
+    fn end_leaf(&mut self, _: &<T::Summary as Summary>::Context) {}
+    fn push_item(&mut self, _: &T, _: &T::Summary, _: &<T::Summary as Summary>::Context) {}
+    fn push_tree(&mut self, _: &SumTree<T>, _: &T::Summary, _: &<T::Summary as Summary>::Context) {}
+}
+
+impl<'a, T: Item> SeekAggregate<'a, T> for SliceSeekAggregate<T> {
+    fn begin_leaf(&mut self) {}
+    fn end_leaf(&mut self, cx: &<T::Summary as Summary>::Context) {
+        self.tree.push_tree(
+            SumTree(Arc::new(Node::Leaf {
+                summary: mem::take(&mut self.leaf_summary),
+                items: mem::take(&mut self.leaf_items),
+                item_summaries: mem::take(&mut self.leaf_item_summaries),
+            })),
+            cx,
+        );
+    }
+    fn push_item(&mut self, item: &T, summary: &T::Summary, cx: &<T::Summary as Summary>::Context) {
+        self.leaf_items.push(item.clone());
+        self.leaf_item_summaries.push(summary.clone());
+        Summary::add_summary(&mut self.leaf_summary, summary, cx);
+    }
+    fn push_tree(
+        &mut self,
+        tree: &SumTree<T>,
+        _: &T::Summary,
+        cx: &<T::Summary as Summary>::Context,
+    ) {
+        self.tree.push_tree(tree.clone(), cx);
+    }
+}
+
+impl<'a, T: Item, D> SeekAggregate<'a, T> for SummarySeekAggregate<D>
+where
+    D: Dimension<'a, T::Summary>,
+{
+    fn begin_leaf(&mut self) {}
+    fn end_leaf(&mut self, _: &<T::Summary as Summary>::Context) {}
+    fn push_item(&mut self, _: &T, summary: &'a T::Summary, cx: &<T::Summary as Summary>::Context) {
+        self.0.add_summary(summary, cx);
+    }
+    fn push_tree(
+        &mut self,
+        _: &SumTree<T>,
+        summary: &'a T::Summary,
+        cx: &<T::Summary as Summary>::Context,
+    ) {
+        self.0.add_summary(summary, cx);
+    }
 }