Optimize construction and insertion of large `SumTree`s (#7731)

Thorsten Ball, Antonio Scandurra, and Julia created

This does two things:

1. It optimizes the construction of `SumTree`s so that nodes are not
inserted one-by-one, but level-by-level. That makes it more
efficient to construct large `SumTree`s.
2. It adds a `from_par_iter` constructor that parallelizes the
construction of `SumTree`s.

In combination, **loading a 500MB plain text file went from
~18 seconds down to ~2 seconds**.

Disclaimer: I didn't write any of this code, lol! It's all @as-cii and
@nathansobo.

Release Notes:

- Improved performance when opening very large files.

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Julia <julia@zed.dev>

Change summary

Cargo.lock                      |   1 
crates/rope/src/rope.rs         |  87 ++++++---------
crates/sum_tree/Cargo.toml      |   1 
crates/sum_tree/src/sum_tree.rs | 194 +++++++++++++++++++++++++++-------
4 files changed, 192 insertions(+), 91 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8257,6 +8257,7 @@ dependencies = [
  "env_logger",
  "log",
  "rand 0.8.5",
+ "rayon",
 ]
 
 [[package]]

crates/rope/src/rope.rs 🔗

@@ -84,45 +84,49 @@ impl Rope {
         self.slice(start..end)
     }
 
-    pub fn push(&mut self, text: &str) {
-        let mut new_chunks = SmallVec::<[_; 16]>::new();
-        let mut new_chunk = ArrayString::new();
-        for ch in text.chars() {
-            if new_chunk.len() + ch.len_utf8() > 2 * CHUNK_BASE {
-                new_chunks.push(Chunk(new_chunk));
-                new_chunk = ArrayString::new();
-            }
-
-            new_chunk.push(ch);
-        }
-        if !new_chunk.is_empty() {
-            new_chunks.push(Chunk(new_chunk));
-        }
-
-        let mut new_chunks = new_chunks.into_iter();
-        let mut first_new_chunk = new_chunks.next();
+    pub fn push(&mut self, mut text: &str) {
         self.chunks.update_last(
             |last_chunk| {
-                if let Some(first_new_chunk_ref) = first_new_chunk.as_mut() {
-                    if last_chunk.0.len() + first_new_chunk_ref.0.len() <= 2 * CHUNK_BASE {
-                        last_chunk.0.push_str(&first_new_chunk.take().unwrap().0);
-                    } else {
-                        let mut text = ArrayString::<{ 4 * CHUNK_BASE }>::new();
-                        text.push_str(&last_chunk.0);
-                        text.push_str(&first_new_chunk_ref.0);
-                        let (left, right) = text.split_at(find_split_ix(&text));
-                        last_chunk.0.clear();
-                        last_chunk.0.push_str(left);
-                        first_new_chunk_ref.0.clear();
-                        first_new_chunk_ref.0.push_str(right);
+                let split_ix = if last_chunk.0.len() + text.len() <= 2 * CHUNK_BASE {
+                    text.len()
+                } else {
+                    let mut split_ix =
+                        cmp::min(CHUNK_BASE.saturating_sub(last_chunk.0.len()), text.len());
+                    while !text.is_char_boundary(split_ix) {
+                        split_ix += 1;
                     }
-                }
+                    split_ix
+                };
+
+                let (suffix, remainder) = text.split_at(split_ix);
+                last_chunk.0.push_str(suffix);
+                text = remainder;
             },
             &(),
         );
 
-        self.chunks
-            .extend(first_new_chunk.into_iter().chain(new_chunks), &());
+        let mut new_chunks = SmallVec::<[_; 16]>::new();
+        while !text.is_empty() {
+            let mut split_ix = cmp::min(2 * CHUNK_BASE, text.len());
+            while !text.is_char_boundary(split_ix) {
+                split_ix -= 1;
+            }
+            let (chunk, remainder) = text.split_at(split_ix);
+            new_chunks.push(Chunk(ArrayString::from(chunk).unwrap()));
+            text = remainder;
+        }
+
+        #[cfg(test)]
+        const PARALLEL_THRESHOLD: usize = 4;
+        #[cfg(not(test))]
+        const PARALLEL_THRESHOLD: usize = 4 * (2 * sum_tree::TREE_BASE);
+
+        if new_chunks.len() >= PARALLEL_THRESHOLD {
+            self.chunks.par_extend(new_chunks.into_vec(), &());
+        } else {
+            self.chunks.extend(new_chunks, &());
+        }
+
         self.check_invariants();
     }
 
@@ -1167,25 +1171,6 @@ impl TextDimension for PointUtf16 {
     }
 }
 
-fn find_split_ix(text: &str) -> usize {
-    let mut ix = text.len() / 2;
-    while !text.is_char_boundary(ix) {
-        if ix < 2 * CHUNK_BASE {
-            ix += 1;
-        } else {
-            ix = (text.len() / 2) - 1;
-            break;
-        }
-    }
-    while !text.is_char_boundary(ix) {
-        ix -= 1;
-    }
-
-    debug_assert!(ix <= 2 * CHUNK_BASE);
-    debug_assert!(text.len() - ix <= 2 * CHUNK_BASE);
-    ix
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

crates/sum_tree/Cargo.toml 🔗

@@ -11,6 +11,7 @@ doctest = false
 
 [dependencies]
 arrayvec = "0.7.1"
+rayon = "1.8"
 log.workspace = true
 
 [dev-dependencies]

crates/sum_tree/src/sum_tree.rs 🔗

@@ -3,14 +3,16 @@ mod tree_map;
 
 use arrayvec::ArrayVec;
 pub use cursor::{Cursor, FilterCursor, Iter};
+use rayon::prelude::*;
 use std::marker::PhantomData;
+use std::mem;
 use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
 pub use tree_map::{MapSeekTarget, TreeMap, TreeSet};
 
 #[cfg(test)]
-const TREE_BASE: usize = 2;
+pub const TREE_BASE: usize = 2;
 #[cfg(not(test))]
-const TREE_BASE: usize = 6;
+pub const TREE_BASE: usize = 6;
 
 pub trait Item: Clone {
     type Summary: Summary;
@@ -133,9 +135,128 @@ impl<T: Item> SumTree<T> {
         iter: I,
         cx: &<T::Summary as Summary>::Context,
     ) -> Self {
-        let mut tree = Self::new();
-        tree.extend(iter, cx);
-        tree
+        let mut nodes = Vec::new();
+
+        let mut iter = iter.into_iter().peekable();
+        while iter.peek().is_some() {
+            let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
+            let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+                items.iter().map(|item| item.summary()).collect();
+
+            let mut summary = item_summaries[0].clone();
+            for item_summary in &item_summaries[1..] {
+                <T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
+            }
+
+            nodes.push(Node::Leaf {
+                summary,
+                items,
+                item_summaries,
+            });
+        }
+
+        let mut parent_nodes = Vec::new();
+        let mut height = 0;
+        while nodes.len() > 1 {
+            height += 1;
+            let mut current_parent_node = None;
+            for child_node in nodes.drain(..) {
+                let parent_node = current_parent_node.get_or_insert_with(|| Node::Internal {
+                    summary: T::Summary::default(),
+                    height,
+                    child_summaries: ArrayVec::new(),
+                    child_trees: ArrayVec::new(),
+                });
+                let Node::Internal {
+                    summary,
+                    child_summaries,
+                    child_trees,
+                    ..
+                } = parent_node
+                else {
+                    unreachable!()
+                };
+                let child_summary = child_node.summary();
+                <T::Summary as Summary>::add_summary(summary, child_summary, cx);
+                child_summaries.push(child_summary.clone());
+                child_trees.push(Self(Arc::new(child_node)));
+
+                if child_trees.len() == 2 * TREE_BASE {
+                    parent_nodes.extend(current_parent_node.take());
+                }
+            }
+            parent_nodes.extend(current_parent_node.take());
+            mem::swap(&mut nodes, &mut parent_nodes);
+        }
+
+        if nodes.is_empty() {
+            Self::new()
+        } else {
+            debug_assert_eq!(nodes.len(), 1);
+            Self(Arc::new(nodes.pop().unwrap()))
+        }
+    }
+
+    pub fn from_par_iter<I, Iter>(iter: I, cx: &<T::Summary as Summary>::Context) -> Self
+    where
+        I: IntoParallelIterator<Iter = Iter>,
+        Iter: IndexedParallelIterator<Item = T>,
+        T: Send + Sync,
+        T::Summary: Send + Sync,
+        <T::Summary as Summary>::Context: Sync,
+    {
+        let mut nodes = iter
+            .into_par_iter()
+            .chunks(2 * TREE_BASE)
+            .map(|items| {
+                let items: ArrayVec<T, { 2 * TREE_BASE }> = items.into_iter().collect();
+                let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+                    items.iter().map(|item| item.summary()).collect();
+                let mut summary = item_summaries[0].clone();
+                for item_summary in &item_summaries[1..] {
+                    <T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
+                }
+                SumTree(Arc::new(Node::Leaf {
+                    summary,
+                    items,
+                    item_summaries,
+                }))
+            })
+            .collect::<Vec<_>>();
+
+        let mut height = 0;
+        while nodes.len() > 1 {
+            height += 1;
+            nodes = nodes
+                .into_par_iter()
+                .chunks(2 * TREE_BASE)
+                .map(|child_nodes| {
+                    let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }> =
+                        child_nodes.into_iter().collect();
+                    let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
+                        .iter()
+                        .map(|child_tree| child_tree.summary().clone())
+                        .collect();
+                    let mut summary = child_summaries[0].clone();
+                    for child_summary in &child_summaries[1..] {
+                        <T::Summary as Summary>::add_summary(&mut summary, child_summary, cx);
+                    }
+                    SumTree(Arc::new(Node::Internal {
+                        height,
+                        summary,
+                        child_summaries,
+                        child_trees,
+                    }))
+                })
+                .collect::<Vec<_>>();
+        }
+
+        if nodes.is_empty() {
+            Self::new()
+        } else {
+            debug_assert_eq!(nodes.len(), 1);
+            nodes.pop().unwrap()
+        }
     }
 
     #[allow(unused)]
@@ -251,39 +372,18 @@ impl<T: Item> SumTree<T> {
     where
         I: IntoIterator<Item = T>,
     {
-        let mut leaf: Option<Node<T>> = None;
-
-        for item in iter {
-            if leaf.is_some() && leaf.as_ref().unwrap().items().len() == 2 * TREE_BASE {
-                self.append(SumTree(Arc::new(leaf.take().unwrap())), cx);
-            }
-
-            if leaf.is_none() {
-                leaf = Some(Node::Leaf::<T> {
-                    summary: T::Summary::default(),
-                    items: ArrayVec::new(),
-                    item_summaries: ArrayVec::new(),
-                });
-            }
-
-            if let Some(Node::Leaf {
-                summary,
-                items,
-                item_summaries,
-            }) = leaf.as_mut()
-            {
-                let item_summary = item.summary();
-                <T::Summary as Summary>::add_summary(summary, &item_summary, cx);
-                items.push(item);
-                item_summaries.push(item_summary);
-            } else {
-                unreachable!()
-            }
-        }
+        self.append(Self::from_iter(iter, cx), cx);
+    }
 
-        if leaf.is_some() {
-            self.append(SumTree(Arc::new(leaf.take().unwrap())), cx);
-        }
+    pub fn par_extend<I, Iter>(&mut self, iter: I, cx: &<T::Summary as Summary>::Context)
+    where
+        I: IntoParallelIterator<Iter = Iter>,
+        Iter: IndexedParallelIterator<Item = T>,
+        T: Send + Sync,
+        T::Summary: Send + Sync,
+        <T::Summary as Summary>::Context: Sync,
+    {
+        self.append(Self::from_par_iter(iter, cx), cx);
     }
 
     pub fn push(&mut self, item: T, cx: &<T::Summary as Summary>::Context) {
@@ -299,7 +399,9 @@ impl<T: Item> SumTree<T> {
     }
 
     pub fn append(&mut self, other: Self, cx: &<T::Summary as Summary>::Context) {
-        if !other.0.is_leaf() || !other.0.items().is_empty() {
+        if self.is_empty() {
+            *self = other;
+        } else if !other.0.is_leaf() || !other.0.items().is_empty() {
             if self.0.height() < other.0.height() {
                 for tree in other.0.child_trees() {
                     self.append(tree.clone(), cx);
@@ -733,7 +835,15 @@ mod tests {
             let rng = &mut rng;
             let mut tree = SumTree::<u8>::new();
             let count = rng.gen_range(0..10);
-            tree.extend(rng.sample_iter(distributions::Standard).take(count), &());
+            if rng.gen() {
+                tree.extend(rng.sample_iter(distributions::Standard).take(count), &());
+            } else {
+                let items = rng
+                    .sample_iter(distributions::Standard)
+                    .take(count)
+                    .collect::<Vec<_>>();
+                tree.par_extend(items, &());
+            }
 
             for _ in 0..num_operations {
                 let splice_end = rng.gen_range(0..tree.extent::<Count>(&()).0 + 1);
@@ -751,7 +861,11 @@ mod tests {
                 tree = {
                     let mut cursor = tree.cursor::<Count>();
                     let mut new_tree = cursor.slice(&Count(splice_start), Bias::Right, &());
-                    new_tree.extend(new_items, &());
+                    if rng.gen() {
+                        new_tree.extend(new_items, &());
+                    } else {
+                        new_tree.par_extend(new_items, &());
+                    }
                     cursor.seek(&Count(splice_end), Bias::Right, &());
                     new_tree.append(cursor.slice(&tree_end, Bias::Right, &()), &());
                     new_tree