sum_tree: Spawn fewer tasks in `SumTree::from_iter_async` (#41793)

Lukas Wirth created

Release Notes:

- N/A

Change summary

Cargo.lock                      |  2 
crates/rope/src/rope.rs         | 22 +++--------
crates/sum_tree/Cargo.toml      |  2 
crates/sum_tree/src/sum_tree.rs | 64 +++++++++++++++++++++-------------
4 files changed, 48 insertions(+), 42 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -16367,7 +16367,7 @@ dependencies = [
  "arrayvec",
  "ctor",
  "futures 0.3.31",
- "itertools 0.14.0",
+ "futures-lite 1.13.0",
  "log",
  "pollster 0.4.0",
  "rand 0.9.2",

crates/rope/src/rope.rs 🔗

@@ -323,21 +323,13 @@ impl Rope {
         const PARALLEL_THRESHOLD: usize = 4 * (2 * sum_tree::TREE_BASE);
 
         if new_chunks.len() >= PARALLEL_THRESHOLD {
-            let cx2 = executor.clone();
-            executor
-                .scoped(|scope| {
-                    // SAFETY: transmuting to 'static is safe because the future is scoped
-                    // and the underlying string data cannot go out of scope because dropping the scope
-                    // will wait for the task to finish
-                    let new_chunks =
-                        unsafe { std::mem::transmute::<Vec<&str>, Vec<&'static str>>(new_chunks) };
-
-                    let async_extend = self
-                        .chunks
-                        .async_extend(new_chunks.into_iter().map(Chunk::new), cx2);
-
-                    scope.spawn(async_extend);
-                })
+            // SAFETY: transmuting to 'static is sound here. We await the future that makes use
+            // of these references before returning, and we know that the result of this
+            // computation does not stash the static references away.
+            let new_chunks =
+                unsafe { std::mem::transmute::<Vec<&str>, Vec<&'static str>>(new_chunks) };
+            self.chunks
+                .async_extend(new_chunks.into_iter().map(Chunk::new), executor)
                 .await;
         } else {
             self.chunks

crates/sum_tree/Cargo.toml 🔗

@@ -17,7 +17,7 @@ doctest = false
 arrayvec = "0.7.1"
 log.workspace = true
 futures.workspace = true
-itertools.workspace = true
+futures-lite.workspace = true
 
 [dev-dependencies]
 ctor.workspace = true

crates/sum_tree/src/sum_tree.rs 🔗

@@ -4,15 +4,15 @@ mod tree_map;
 use arrayvec::ArrayVec;
 pub use cursor::{Cursor, FilterCursor, Iter};
 use futures::{StreamExt, stream};
-use itertools::Itertools as _;
+use futures_lite::future::yield_now;
 use std::marker::PhantomData;
 use std::mem;
 use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
 pub use tree_map::{MapSeekTarget, TreeMap, TreeSet};
 
-#[cfg(test)]
+#[cfg(all(test, not(rust_analyzer)))]
 pub const TREE_BASE: usize = 2;
-#[cfg(not(test))]
+#[cfg(not(all(test, not(rust_analyzer))))]
 pub const TREE_BASE: usize = 6;
 
 pub trait BackgroundSpawn {
@@ -316,30 +316,44 @@ impl<T: Item> SumTree<T> {
         T: 'static + Send + Sync,
         for<'a> T::Summary: Summary<Context<'a> = ()> + Send + Sync,
         S: BackgroundSpawn,
-        I: IntoIterator<Item = T>,
+        I: IntoIterator<Item = T, IntoIter: ExactSizeIterator>,
     {
-        let mut futures = vec![];
-        let chunks = iter.into_iter().chunks(2 * TREE_BASE);
-        for chunk in chunks.into_iter() {
-            let items: ArrayVec<T, { 2 * TREE_BASE }> = chunk.into_iter().collect();
-            futures.push(async move {
-                let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
-                    items.iter().map(|item| item.summary(())).collect();
-                let mut summary = item_summaries[0].clone();
-                for item_summary in &item_summaries[1..] {
-                    <T::Summary as Summary>::add_summary(&mut summary, item_summary, ());
-                }
-                SumTree(Arc::new(Node::Leaf {
-                    summary,
-                    items,
-                    item_summaries,
-                }))
-            });
+        let iter = iter.into_iter();
+        let num_leaves = iter.len().div_ceil(2 * TREE_BASE);
+
+        if num_leaves == 0 {
+            return Self::new(());
         }
 
-        let mut nodes = futures::stream::iter(futures)
+        let mut nodes = stream::iter(iter)
+            .chunks(num_leaves.div_ceil(4))
+            .map(|chunk| async move {
+                let mut chunk = chunk.into_iter();
+                let mut leaves = vec![];
+                loop {
+                    let items: ArrayVec<T, { 2 * TREE_BASE }> =
+                        chunk.by_ref().take(2 * TREE_BASE).collect();
+                    if items.is_empty() {
+                        break;
+                    }
+                    let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+                        items.iter().map(|item| item.summary(())).collect();
+                    let mut summary = item_summaries[0].clone();
+                    for item_summary in &item_summaries[1..] {
+                        <T::Summary as Summary>::add_summary(&mut summary, item_summary, ());
+                    }
+                    leaves.push(SumTree(Arc::new(Node::Leaf {
+                        summary,
+                        items,
+                        item_summaries,
+                    })));
+                    yield_now().await;
+                }
+                leaves
+            })
             .map(|future| spawn.background_spawn(future))
             .buffered(4)
+            .flat_map(|it| stream::iter(it.into_iter()))
             .collect::<Vec<_>>()
             .await;
 
@@ -622,7 +636,7 @@ impl<T: Item> SumTree<T> {
     pub async fn async_extend<S, I>(&mut self, iter: I, spawn: S)
     where
         S: BackgroundSpawn,
-        I: IntoIterator<Item = T> + 'static,
+        I: IntoIterator<Item = T, IntoIter: ExactSizeIterator>,
         T: 'static + Send + Sync,
         for<'b> T::Summary: Summary<Context<'b> = ()> + Send + Sync,
     {
@@ -1126,7 +1140,7 @@ mod tests {
 
             let rng = &mut rng;
             let mut tree = SumTree::<u8>::default();
-            let count = rng.random_range(0..10);
+            let count = rng.random_range(0..128);
             if rng.random() {
                 tree.extend(rng.sample_iter(StandardUniform).take(count), ());
             } else {
@@ -1140,7 +1154,7 @@ mod tests {
             for _ in 0..num_operations {
                 let splice_end = rng.random_range(0..tree.extent::<Count>(()).0 + 1);
                 let splice_start = rng.random_range(0..splice_end + 1);
-                let count = rng.random_range(0..10);
+                let count = rng.random_range(0..128);
                 let tree_end = tree.extent::<Count>(());
                 let new_items = rng
                     .sample_iter(StandardUniform)