text: Speed up `offset_for_anchor` and `fragment_id_for_anchor` conversions (#46989)

Lukas Wirth created

And most importantly, speed up `Anchor::cmp` by doing so.

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/editor/src/selections_collection.rs | 27 +++++-
crates/sum_tree/src/cursor.rs              |  2 
crates/sum_tree/src/sum_tree.rs            | 91 ++++++++++++++++++++++++
crates/text/src/text.rs                    | 84 ++++++---------------
crates/vim/src/vim.rs                      | 18 ++--
5 files changed, 150 insertions(+), 72 deletions(-)

Detailed changes

crates/editor/src/selections_collection.rs 🔗

@@ -109,10 +109,6 @@ impl SelectionsCollection {
         self.pending.as_ref().map(|pending| &pending.selection)
     }
 
-    pub fn pending_anchor_mut(&mut self) -> Option<&mut Selection<Anchor>> {
-        self.pending.as_mut().map(|pending| &mut pending.selection)
-    }
-
     pub fn pending<D>(&self, snapshot: &DisplaySnapshot) -> Option<Selection<D>>
     where
         D: MultiBufferDimension + Sub + AddAssign<<D as Sub>::Output> + Ord,
@@ -545,6 +541,11 @@ impl SelectionsCollection {
         );
         if cfg!(debug_assertions) {
             mutable_collection.disjoint.iter().for_each(|selection| {
+                assert!(
+                     selection.start.cmp(&selection.end, &snapshot).is_le(),
+                    "disjoint selection has start > end: {:?}",
+                    mutable_collection.disjoint
+                );
                 assert!(
                     snapshot.can_resolve(&selection.start),
                     "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}, {excerpt:?}",
@@ -556,8 +557,20 @@ impl SelectionsCollection {
                     excerpt = snapshot.buffer_for_excerpt(selection.end.excerpt_id).map(|snapshot| snapshot.remote_id()),
                 );
             });
+            assert!(
+                mutable_collection
+                    .disjoint
+                    .is_sorted_by(|first, second| first.end.cmp(&second.start, &snapshot).is_le()),
+                "disjoint selections are not sorted: {:?}",
+                mutable_collection.disjoint
+            );
             if let Some(pending) = &mutable_collection.pending {
                 let selection = &pending.selection;
+                assert!(
+                    selection.start.cmp(&selection.end, &snapshot).is_le(),
+                    "pending selection has start > end: {:?}",
+                    selection
+                );
                 assert!(
                     snapshot.can_resolve(&selection.start),
                     "pending selection start is not resolvable for the given snapshot: {pending:?}, {excerpt:?}",
@@ -933,7 +946,6 @@ impl<'snap, 'a> MutableSelectionsCollection<'snap, 'a> {
         for selection in disjoint
             .iter()
             .sorted_by(|first, second| Ord::cmp(&second.id, &first.id))
-            .collect::<Vec<&Selection<Anchor>>>()
         {
             new_selections.push(Selection {
                 id: self.new_selection_id(),
@@ -1062,6 +1074,11 @@ impl<'snap, 'a> MutableSelectionsCollection<'snap, 'a> {
         self.select(new_selections);
     }
 
+    pub fn pending_anchor_mut(&mut self) -> Option<&mut Selection<Anchor>> {
+        self.selections_changed = true; // set unconditionally, even when there is no pending selection
+        self.pending.as_mut().map(|pending| &mut pending.selection) // mutable access to the pending selection, if one exists
+    }
+
     /// Compute new ranges for any selections that were located in excerpts that have
     /// since been removed.
     ///

crates/sum_tree/src/cursor.rs 🔗

@@ -30,7 +30,7 @@ impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
 pub struct Cursor<'a, 'b, T: Item, D> {
     tree: &'a SumTree<T>,
     stack: ArrayVec<StackEntry<'a, T, D>, 16>,
-    position: D,
+    pub position: D,
     did_seek: bool,
     at_end: bool,
     cx: <T::Summary as Summary>::Context<'b>,

crates/sum_tree/src/sum_tree.rs 🔗

@@ -491,6 +491,97 @@ impl<T: Item> SumTree<T> {
         None
     }
 
+    /// Like `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`, but more efficient; also yields the item preceding the found one.
+    #[instrument(skip_all)]
+    pub fn find_with_prev<'a, 'slf, D, Target>(
+        &'slf self,
+        cx: <T::Summary as Summary>::Context<'a>,
+        target: &Target,
+        bias: Bias,
+    ) -> (D, D, Option<(Option<&'slf T>, &'slf T)>) // (dimension before the item, dimension at its end, (preceding item, found item))
+    where
+        D: Dimension<'slf, T::Summary>,
+        Target: SeekTarget<'slf, T::Summary, D>,
+    {
+        let tree_end = D::zero(cx).with_added_summary(self.summary(), cx); // dimension of the whole tree
+        let comparison = target.cmp(&tree_end, cx);
+        if comparison == Ordering::Greater || (comparison == Ordering::Equal && bias == Bias::Right)
+        {
+            return (tree_end.clone(), tree_end, None); // target is past the end of the tree: no item to report
+        }
+
+        let mut pos = D::zero(cx); // advanced by the recursion past every skipped subtree/item
+        return match Self::find_recurse_with_prev::<_, _, false>(
+            cx, target, bias, &mut pos, self, None, // the root starts without a preceding item
+        ) {
+            Some((prev, item, end)) => (pos, end, Some((prev, item))),
+            None => (pos.clone(), pos, None), // nothing matched: both dimensions are the position reached
+        };
+    }
+
+    fn find_recurse_with_prev<'tree, 'a, D, Target, const EXACT: bool>(
+        cx: <T::Summary as Summary>::Context<'a>,
+        target: &Target,
+        bias: Bias,
+        position: &mut D, // running dimension; advanced past every child/item that is skipped
+        this: &'tree SumTree<T>,
+        prev: Option<&'tree T>, // item immediately before this subtree, if the caller has seen one
+    ) -> Option<(Option<&'tree T>, &'tree T, D)> // (item before the match, matched item, dimension at the match's end)
+    where
+        D: Dimension<'tree, T::Summary>,
+        Target: SeekTarget<'tree, T::Summary, D>,
+    {
+        match &*this.0 {
+            Node::Internal {
+                child_summaries,
+                child_trees,
+                ..
+            } => {
+                let mut prev = prev;
+                for (child_tree, child_summary) in child_trees.iter().zip(child_summaries) {
+                    let child_end = position.clone().with_added_summary(child_summary, cx);
+
+                    let comparison = target.cmp(&child_end, cx);
+                    let target_in_child = comparison == Ordering::Less
+                        || (comparison == Ordering::Equal && bias == Bias::Left); // a tie selects this child only under left bias
+                    if target_in_child {
+                        return Self::find_recurse_with_prev::<D, Target, EXACT>(
+                            cx, target, bias, position, child_tree, prev,
+                        );
+                    }
+                    prev = child_tree.last(); // skipped child: its last item becomes the candidate predecessor
+                    *position = child_end;
+                }
+            }
+            Node::Leaf {
+                items,
+                item_summaries,
+                ..
+            } => {
+                let mut prev = prev;
+                for (item, item_summary) in items.iter().zip(item_summaries) {
+                    let mut child_end = position.clone();
+                    child_end.add_summary(item_summary, cx);
+
+                    let comparison = target.cmp(&child_end, cx);
+                    let entry_found = if EXACT { // `find_with_prev` instantiates this with EXACT = false
+                        comparison == Ordering::Equal
+                    } else {
+                        comparison == Ordering::Less
+                            || (comparison == Ordering::Equal && bias == Bias::Left)
+                    };
+                    if entry_found {
+                        return Some((prev, item, child_end)); // `position` is left at the dimension before `item`
+                    }
+
+                    prev = Some(item);
+                    *position = child_end;
+                }
+            }
+        }
+        None // no child or item in this subtree contained the target
+    }
+
     pub fn cursor<'a, 'b, D>(
         &'a self,
         cx: <T::Summary as Summary>::Context<'b>,

crates/text/src/text.rs 🔗

@@ -2248,7 +2248,6 @@ impl BufferSnapshot {
         A: 'a + IntoIterator<Item = (&'a Anchor, T)>,
     {
         let anchors = anchors.into_iter();
-        let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(());
         let mut fragment_cursor = self
             .fragments
             .cursor::<Dimensions<Option<&Locator>, usize>>(&None);
@@ -2262,24 +2261,7 @@ impl BufferSnapshot {
                 return (D::from_text_summary(&self.visible_text.summary()), payload);
             }
 
-            let anchor_key = InsertionFragmentKey {
-                timestamp: anchor.timestamp,
-                split_offset: anchor.offset,
-            };
-            insertion_cursor.seek(&anchor_key, anchor.bias);
-            if let Some(insertion) = insertion_cursor.item() {
-                let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key);
-                if comparison == Ordering::Greater
-                    || (anchor.bias == Bias::Left
-                        && comparison == Ordering::Equal
-                        && anchor.offset > 0)
-                {
-                    insertion_cursor.prev();
-                }
-            } else {
-                insertion_cursor.prev();
-            }
-            let Some(insertion) = insertion_cursor.item() else {
+            let Some(insertion) = self.try_find_fragment(anchor) else {
                 panic!(
                     "invalid insertion for buffer {}@{:?} with anchor {:?}",
                     self.remote_id(),
@@ -2328,28 +2310,8 @@ impl BufferSnapshot {
                 anchor.timestamp,
                 self.version
             );
-            let anchor_key = InsertionFragmentKey {
-                timestamp: anchor.timestamp,
-                split_offset: anchor.offset,
-            };
-            let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(());
-            insertion_cursor.seek(&anchor_key, anchor.bias);
-            if let Some(insertion) = insertion_cursor.item() {
-                let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key);
-                if comparison == Ordering::Greater
-                    || (anchor.bias == Bias::Left
-                        && comparison == Ordering::Equal
-                        && anchor.offset > 0)
-                {
-                    insertion_cursor.prev();
-                }
-            } else {
-                insertion_cursor.prev();
-            }
-
-            let Some(insertion) = insertion_cursor
-                .item()
-                .filter(|insertion| insertion.timestamp == anchor.timestamp)
+            let item = self.try_find_fragment(anchor);
+            let Some(insertion) = item.filter(|insertion| insertion.timestamp == anchor.timestamp)
             else {
                 self.panic_bad_anchor(anchor);
             };
@@ -2401,31 +2363,37 @@ impl BufferSnapshot {
         } else if anchor.is_max() {
             Some(Locator::max_ref())
         } else {
-            let anchor_key = InsertionFragmentKey {
-                timestamp: anchor.timestamp,
-                split_offset: anchor.offset,
-            };
-            let mut insertion_cursor = self.insertions.cursor::<InsertionFragmentKey>(());
-            insertion_cursor.seek(&anchor_key, anchor.bias);
-            if let Some(insertion) = insertion_cursor.item() {
+            let item = self.try_find_fragment(anchor);
+            item.filter(|insertion| {
+                !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp
+            })
+            .map(|insertion| &insertion.fragment_id)
+        }
+    }
+
+    fn try_find_fragment(&self, anchor: &Anchor) -> Option<&InsertionFragment> {
+        let anchor_key = InsertionFragmentKey {
+            timestamp: anchor.timestamp,
+            split_offset: anchor.offset,
+        };
+        match self.insertions.find_with_prev::<InsertionFragmentKey, _>( // one lookup returns the hit together with the item before it
+            (),
+            &anchor_key,
+            anchor.bias,
+        ) {
+            (_, _, Some((prev, insertion))) => {
                 let comparison = sum_tree::KeyedItem::key(insertion).cmp(&anchor_key);
                 if comparison == Ordering::Greater
                     || (anchor.bias == Bias::Left
                         && comparison == Ordering::Equal
                         && anchor.offset > 0)
                 {
-                    insertion_cursor.prev();
+                    prev // step back to the preceding insertion fragment (may be None)
+                } else {
+                    Some(insertion)
                 }
-            } else {
-                insertion_cursor.prev();
             }
-
-            insertion_cursor
-                .item()
-                .filter(|insertion| {
-                    !cfg!(debug_assertions) || insertion.timestamp == anchor.timestamp
-                })
-                .map(|insertion| &insertion.fragment_id)
+            _ => self.insertions.last(), // lookup reported no item: fall back to the last insertion
         }
     }
 

crates/vim/src/vim.rs 🔗

@@ -1245,14 +1245,16 @@ impl Vim {
 
                 if should_extend_pending {
                     let snapshot = s.display_snapshot();
-                    if let Some(pending) = s.pending_anchor_mut() {
-                        let end = pending.end.to_point(&snapshot.buffer_snapshot());
-                        let end = end.to_display_point(&snapshot);
-                        let new_end = movement::right(&snapshot, end);
-                        pending.end = snapshot
-                            .buffer_snapshot()
-                            .anchor_before(new_end.to_point(&snapshot));
-                    }
+                    s.change_with(&snapshot, |map| {
+                        if let Some(pending) = map.pending_anchor_mut() {
+                            let end = pending.end.to_point(&snapshot.buffer_snapshot());
+                            let end = end.to_display_point(&snapshot);
+                            let new_end = movement::right(&snapshot, end);
+                            pending.end = snapshot
+                                .buffer_snapshot()
+                                .anchor_before(new_end.to_point(&snapshot));
+                        }
+                    });
                     vim.extended_pending_selection_id = s.pending_anchor().map(|p| p.id)
                 }