sum_tree.rs

   1mod cursor;
   2mod tree_map;
   3
   4use arrayvec::ArrayVec;
   5pub use cursor::{Cursor, FilterCursor, Iter};
   6use futures::{StreamExt, stream};
   7use futures_lite::future::yield_now;
   8use std::marker::PhantomData;
   9use std::mem;
  10use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
  11pub use tree_map::{MapSeekTarget, TreeMap, TreeSet};
  12
  13#[cfg(all(test, not(rust_analyzer)))]
  14pub const TREE_BASE: usize = 2;
  15#[cfg(not(all(test, not(rust_analyzer))))]
  16pub const TREE_BASE: usize = 6;
  17
  18pub trait BackgroundSpawn {
  19    type Task<R>: Future<Output = R> + Send + Sync
  20    where
  21        R: Send + Sync;
  22    fn background_spawn<R>(
  23        &self,
  24        future: impl Future<Output = R> + Send + Sync + 'static,
  25    ) -> Self::Task<R>
  26    where
  27        R: Send + Sync + 'static;
  28}
  29
  30/// An item that can be stored in a [`SumTree`]
  31///
  32/// Must be summarized by a type that implements [`Summary`]
  33pub trait Item: Clone {
  34    type Summary: Summary;
  35
  36    fn summary(&self, cx: <Self::Summary as Summary>::Context<'_>) -> Self::Summary;
  37}
  38
  39/// An [`Item`] whose summary has a specific key that can be used to identify it
  40pub trait KeyedItem: Item {
  41    type Key: for<'a> Dimension<'a, Self::Summary> + Ord;
  42
  43    fn key(&self) -> Self::Key;
  44}
  45
  46/// A type that describes the Sum of all [`Item`]s in a subtree of the [`SumTree`]
  47///
  48/// Each Summary type can have multiple [`Dimension`]s that it measures,
  49/// which can be used to navigate the tree
  50pub trait Summary: Clone {
  51    type Context<'a>: Copy;
  52    fn zero<'a>(cx: Self::Context<'a>) -> Self;
  53    fn add_summary<'a>(&mut self, summary: &Self, cx: Self::Context<'a>);
  54}
  55
  56pub trait ContextLessSummary: Clone {
  57    fn zero() -> Self;
  58    fn add_summary(&mut self, summary: &Self);
  59}
  60
  61impl<T: ContextLessSummary> Summary for T {
  62    type Context<'a> = ();
  63
  64    fn zero<'a>((): ()) -> Self {
  65        T::zero()
  66    }
  67
  68    fn add_summary<'a>(&mut self, summary: &Self, (): ()) {
  69        T::add_summary(self, summary)
  70    }
  71}
  72
  73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
  74pub struct NoSummary;
  75
  76/// Catch-all implementation for when you need something that implements [`Summary`] without a specific type.
  77/// We implement it on a `NoSummary` instead of re-using `()`, as that avoids blanket impl collisions with `impl<T: Summary> Dimension for T`
  78/// (as we also need unit type to be a fill-in dimension)
  79impl ContextLessSummary for NoSummary {
  80    fn zero() -> Self {
  81        NoSummary
  82    }
  83
  84    fn add_summary(&mut self, _: &Self) {}
  85}
  86
  87/// Each [`Summary`] type can have more than one [`Dimension`] type that it measures.
  88///
  89/// You can use dimensions to seek to a specific location in the [`SumTree`]
  90///
  91/// # Example:
  92/// Zed's rope has a `TextSummary` type that summarizes lines, characters, and bytes.
  93/// Each of these are different dimensions we may want to seek to
  94pub trait Dimension<'a, S: Summary>: Clone {
  95    fn zero(cx: S::Context<'_>) -> Self;
  96
  97    fn add_summary(&mut self, summary: &'a S, cx: S::Context<'_>);
  98    #[must_use]
  99    fn with_added_summary(mut self, summary: &'a S, cx: S::Context<'_>) -> Self {
 100        self.add_summary(summary, cx);
 101        self
 102    }
 103
 104    fn from_summary(summary: &'a S, cx: S::Context<'_>) -> Self {
 105        let mut dimension = Self::zero(cx);
 106        dimension.add_summary(summary, cx);
 107        dimension
 108    }
 109}
 110
 111impl<'a, T: Summary> Dimension<'a, T> for T {
 112    fn zero(cx: T::Context<'_>) -> Self {
 113        Summary::zero(cx)
 114    }
 115
 116    fn add_summary(&mut self, summary: &'a T, cx: T::Context<'_>) {
 117        Summary::add_summary(self, summary, cx);
 118    }
 119}
 120
 121pub trait SeekTarget<'a, S: Summary, D: Dimension<'a, S>> {
 122    fn cmp(&self, cursor_location: &D, cx: S::Context<'_>) -> Ordering;
 123}
 124
 125impl<'a, S: Summary, D: Dimension<'a, S> + Ord> SeekTarget<'a, S, D> for D {
 126    fn cmp(&self, cursor_location: &Self, _: S::Context<'_>) -> Ordering {
 127        Ord::cmp(self, cursor_location)
 128    }
 129}
 130
 131impl<'a, T: Summary> Dimension<'a, T> for () {
 132    fn zero(_: T::Context<'_>) -> Self {}
 133
 134    fn add_summary(&mut self, _: &'a T, _: T::Context<'_>) {}
 135}
 136
 137#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 138pub struct Dimensions<D1, D2, D3 = ()>(pub D1, pub D2, pub D3);
 139
 140impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>, D3: Dimension<'a, T>>
 141    Dimension<'a, T> for Dimensions<D1, D2, D3>
 142{
 143    fn zero(cx: T::Context<'_>) -> Self {
 144        Dimensions(D1::zero(cx), D2::zero(cx), D3::zero(cx))
 145    }
 146
 147    fn add_summary(&mut self, summary: &'a T, cx: T::Context<'_>) {
 148        self.0.add_summary(summary, cx);
 149        self.1.add_summary(summary, cx);
 150        self.2.add_summary(summary, cx);
 151    }
 152}
 153
 154impl<'a, S, D1, D2, D3> SeekTarget<'a, S, Dimensions<D1, D2, D3>> for D1
 155where
 156    S: Summary,
 157    D1: SeekTarget<'a, S, D1> + Dimension<'a, S>,
 158    D2: Dimension<'a, S>,
 159    D3: Dimension<'a, S>,
 160{
 161    fn cmp(&self, cursor_location: &Dimensions<D1, D2, D3>, cx: S::Context<'_>) -> Ordering {
 162        self.cmp(&cursor_location.0, cx)
 163    }
 164}
 165
 166/// Bias is used to settle ambiguities when determining positions in an ordered sequence.
 167///
 168/// The primary use case is for text, where Bias influences
 169/// which character an offset or anchor is associated with.
 170///
 171/// # Examples
 172/// Given the buffer `AˇBCD`:
 173/// - The offset of the cursor is 1
 174/// - [Bias::Left] would attach the cursor to the character `A`
 175/// - [Bias::Right] would attach the cursor to the character `B`
 176///
 177/// Given the buffer `A«BCˇ»D`:
 178/// - The offset of the cursor is 3, and the selection is from 1 to 3
 179/// - The left anchor of the selection has [Bias::Right], attaching it to the character `B`
 180/// - The right anchor of the selection has [Bias::Left], attaching it to the character `C`
 181///
 182/// Given the buffer `{ˇ<...>`, where `<...>` is a folded region:
 183/// - The display offset of the cursor is 1, but the offset in the buffer is determined by the bias
 184/// - [Bias::Left] would attach the cursor to the character `{`, with a buffer offset of 1
 185/// - [Bias::Right] would attach the cursor to the first character of the folded region,
 186///   and the buffer offset would be the offset of the first character of the folded region
 187#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Hash, Default)]
 188pub enum Bias {
 189    /// Attach to the character on the left
 190    #[default]
 191    Left,
 192    /// Attach to the character on the right
 193    Right,
 194}
 195
 196impl Bias {
 197    pub fn invert(self) -> Self {
 198        match self {
 199            Self::Left => Self::Right,
 200            Self::Right => Self::Left,
 201        }
 202    }
 203}
 204
 205/// A B+ tree in which each leaf node contains `Item`s of type `T` and a `Summary`s for each `Item`.
 206/// Each internal node contains a `Summary` of the items in its subtree.
 207///
 208/// The maximum number of items per node is `TREE_BASE * 2`.
 209///
 210/// Any [`Dimension`] supported by the [`Summary`] type can be used to seek to a specific location in the tree.
 211#[derive(Clone)]
 212pub struct SumTree<T: Item>(Arc<Node<T>>);
 213
 214impl<T> fmt::Debug for SumTree<T>
 215where
 216    T: fmt::Debug + Item,
 217    T::Summary: fmt::Debug,
 218{
 219    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 220        f.debug_tuple("SumTree").field(&self.0).finish()
 221    }
 222}
 223
 224impl<T: Item> SumTree<T> {
 225    pub fn new(cx: <T::Summary as Summary>::Context<'_>) -> Self {
 226        SumTree(Arc::new(Node::Leaf {
 227            summary: <T::Summary as Summary>::zero(cx),
 228            items: ArrayVec::new(),
 229            item_summaries: ArrayVec::new(),
 230        }))
 231    }
 232
 233    /// Useful in cases where the item type has a non-trivial context type, but the zero value of the summary type doesn't depend on that context.
 234    pub fn from_summary(summary: T::Summary) -> Self {
 235        SumTree(Arc::new(Node::Leaf {
 236            summary,
 237            items: ArrayVec::new(),
 238            item_summaries: ArrayVec::new(),
 239        }))
 240    }
 241
 242    pub fn from_item(item: T, cx: <T::Summary as Summary>::Context<'_>) -> Self {
 243        let mut tree = Self::new(cx);
 244        tree.push(item, cx);
 245        tree
 246    }
 247
 248    pub fn from_iter<I: IntoIterator<Item = T>>(
 249        iter: I,
 250        cx: <T::Summary as Summary>::Context<'_>,
 251    ) -> Self {
 252        let mut nodes = Vec::new();
 253
 254        let mut iter = iter.into_iter().fuse().peekable();
 255        while iter.peek().is_some() {
 256            let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
 257            let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
 258                items.iter().map(|item| item.summary(cx)).collect();
 259
 260            let mut summary = item_summaries[0].clone();
 261            for item_summary in &item_summaries[1..] {
 262                <T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
 263            }
 264
 265            nodes.push(Node::Leaf {
 266                summary,
 267                items,
 268                item_summaries,
 269            });
 270        }
 271
 272        let mut parent_nodes = Vec::new();
 273        let mut height = 0;
 274        while nodes.len() > 1 {
 275            height += 1;
 276            let mut current_parent_node = None;
 277            for child_node in nodes.drain(..) {
 278                let parent_node = current_parent_node.get_or_insert_with(|| Node::Internal {
 279                    summary: <T::Summary as Summary>::zero(cx),
 280                    height,
 281                    child_summaries: ArrayVec::new(),
 282                    child_trees: ArrayVec::new(),
 283                });
 284                let Node::Internal {
 285                    summary,
 286                    child_summaries,
 287                    child_trees,
 288                    ..
 289                } = parent_node
 290                else {
 291                    unreachable!()
 292                };
 293                let child_summary = child_node.summary();
 294                <T::Summary as Summary>::add_summary(summary, child_summary, cx);
 295                child_summaries.push(child_summary.clone());
 296                child_trees.push(Self(Arc::new(child_node)));
 297
 298                if child_trees.len() == 2 * TREE_BASE {
 299                    parent_nodes.extend(current_parent_node.take());
 300                }
 301            }
 302            parent_nodes.extend(current_parent_node.take());
 303            mem::swap(&mut nodes, &mut parent_nodes);
 304        }
 305
 306        if nodes.is_empty() {
 307            Self::new(cx)
 308        } else {
 309            debug_assert_eq!(nodes.len(), 1);
 310            Self(Arc::new(nodes.pop().unwrap()))
 311        }
 312    }
 313
 314    pub async fn from_iter_async<I, S>(iter: I, spawn: S) -> Self
 315    where
 316        T: 'static + Send + Sync,
 317        for<'a> T::Summary: Summary<Context<'a> = ()> + Send + Sync,
 318        S: BackgroundSpawn,
 319        I: IntoIterator<Item = T, IntoIter: ExactSizeIterator>,
 320    {
 321        let iter = iter.into_iter();
 322        let num_leaves = iter.len().div_ceil(2 * TREE_BASE);
 323
 324        if num_leaves == 0 {
 325            return Self::new(());
 326        }
 327
 328        let mut nodes = stream::iter(iter)
 329            .chunks(num_leaves.div_ceil(4))
 330            .map(|chunk| async move {
 331                let mut chunk = chunk.into_iter();
 332                let mut leaves = vec![];
 333                loop {
 334                    let items: ArrayVec<T, { 2 * TREE_BASE }> =
 335                        chunk.by_ref().take(2 * TREE_BASE).collect();
 336                    if items.is_empty() {
 337                        break;
 338                    }
 339                    let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
 340                        items.iter().map(|item| item.summary(())).collect();
 341                    let mut summary = item_summaries[0].clone();
 342                    for item_summary in &item_summaries[1..] {
 343                        <T::Summary as Summary>::add_summary(&mut summary, item_summary, ());
 344                    }
 345                    leaves.push(SumTree(Arc::new(Node::Leaf {
 346                        summary,
 347                        items,
 348                        item_summaries,
 349                    })));
 350                    yield_now().await;
 351                }
 352                leaves
 353            })
 354            .map(|future| spawn.background_spawn(future))
 355            .buffered(4)
 356            .flat_map(|it| stream::iter(it.into_iter()))
 357            .collect::<Vec<_>>()
 358            .await;
 359
 360        let mut height = 0;
 361        while nodes.len() > 1 {
 362            height += 1;
 363            let current_nodes = mem::take(&mut nodes);
 364            nodes = stream::iter(current_nodes)
 365                .chunks(2 * TREE_BASE)
 366                .map(|chunk| {
 367                    spawn.background_spawn(async move {
 368                        let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }> =
 369                            chunk.into_iter().collect();
 370                        let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
 371                            .iter()
 372                            .map(|child_tree| child_tree.summary().clone())
 373                            .collect();
 374                        let mut summary = child_summaries[0].clone();
 375                        for child_summary in &child_summaries[1..] {
 376                            <T::Summary as Summary>::add_summary(&mut summary, child_summary, ());
 377                        }
 378                        SumTree(Arc::new(Node::Internal {
 379                            height,
 380                            summary,
 381                            child_summaries,
 382                            child_trees,
 383                        }))
 384                    })
 385                })
 386                .buffered(4)
 387                .collect::<Vec<_>>()
 388                .await;
 389        }
 390
 391        if nodes.is_empty() {
 392            Self::new(())
 393        } else {
 394            debug_assert_eq!(nodes.len(), 1);
 395            nodes.pop().unwrap()
 396        }
 397    }
 398
 399    #[allow(unused)]
 400    pub fn items<'a>(&'a self, cx: <T::Summary as Summary>::Context<'a>) -> Vec<T> {
 401        let mut items = Vec::new();
 402        let mut cursor = self.cursor::<()>(cx);
 403        cursor.next();
 404        while let Some(item) = cursor.item() {
 405            items.push(item.clone());
 406            cursor.next();
 407        }
 408        items
 409    }
 410
 411    pub fn iter(&self) -> Iter<'_, T> {
 412        Iter::new(self)
 413    }
 414
 415    /// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`.
 416    ///
 417    /// Only returns the item that exactly has the target match.
 418    pub fn find_exact<'a, 'slf, D, Target>(
 419        &'slf self,
 420        cx: <T::Summary as Summary>::Context<'a>,
 421        target: &Target,
 422        bias: Bias,
 423    ) -> (D, D, Option<&'slf T>)
 424    where
 425        D: Dimension<'slf, T::Summary>,
 426        Target: SeekTarget<'slf, T::Summary, D>,
 427    {
 428        let tree_end = D::zero(cx).with_added_summary(self.summary(), cx);
 429        let comparison = target.cmp(&tree_end, cx);
 430        if comparison == Ordering::Greater || (comparison == Ordering::Equal && bias == Bias::Right)
 431        {
 432            return (tree_end.clone(), tree_end, None);
 433        }
 434
 435        let mut pos = D::zero(cx);
 436        return match Self::find_recurse::<_, _, true>(cx, target, bias, &mut pos, self) {
 437            Some((item, end)) => (pos, end, Some(item)),
 438            None => (pos.clone(), pos, None),
 439        };
 440    }
 441
 442    /// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`
 443    pub fn find<'a, 'slf, D, Target>(
 444        &'slf self,
 445        cx: <T::Summary as Summary>::Context<'a>,
 446        target: &Target,
 447        bias: Bias,
 448    ) -> (D, D, Option<&'slf T>)
 449    where
 450        D: Dimension<'slf, T::Summary>,
 451        Target: SeekTarget<'slf, T::Summary, D>,
 452    {
 453        let tree_end = D::zero(cx).with_added_summary(self.summary(), cx);
 454        let comparison = target.cmp(&tree_end, cx);
 455        if comparison == Ordering::Greater || (comparison == Ordering::Equal && bias == Bias::Right)
 456        {
 457            return (tree_end.clone(), tree_end, None);
 458        }
 459
 460        let mut pos = D::zero(cx);
 461        return match Self::find_recurse::<_, _, false>(cx, target, bias, &mut pos, self) {
 462            Some((item, end)) => (pos, end, Some(item)),
 463            None => (pos.clone(), pos, None),
 464        };
 465    }
 466
 467    fn find_recurse<'tree, 'a, D, Target, const EXACT: bool>(
 468        cx: <T::Summary as Summary>::Context<'a>,
 469        target: &Target,
 470        bias: Bias,
 471        position: &mut D,
 472        this: &'tree SumTree<T>,
 473    ) -> Option<(&'tree T, D)>
 474    where
 475        D: Dimension<'tree, T::Summary>,
 476        Target: SeekTarget<'tree, T::Summary, D>,
 477    {
 478        match &*this.0 {
 479            Node::Internal {
 480                child_summaries,
 481                child_trees,
 482                ..
 483            } => {
 484                for (child_tree, child_summary) in child_trees.iter().zip(child_summaries) {
 485                    let child_end = position.clone().with_added_summary(child_summary, cx);
 486
 487                    let comparison = target.cmp(&child_end, cx);
 488                    let target_in_child = comparison == Ordering::Less
 489                        || (comparison == Ordering::Equal && bias == Bias::Left);
 490                    if target_in_child {
 491                        return Self::find_recurse::<D, Target, EXACT>(
 492                            cx, target, bias, position, child_tree,
 493                        );
 494                    }
 495                    *position = child_end;
 496                }
 497            }
 498            Node::Leaf {
 499                items,
 500                item_summaries,
 501                ..
 502            } => {
 503                for (item, item_summary) in items.iter().zip(item_summaries) {
 504                    let mut child_end = position.clone();
 505                    child_end.add_summary(item_summary, cx);
 506
 507                    let comparison = target.cmp(&child_end, cx);
 508                    let entry_found = if EXACT {
 509                        comparison == Ordering::Equal
 510                    } else {
 511                        comparison == Ordering::Less
 512                            || (comparison == Ordering::Equal && bias == Bias::Left)
 513                    };
 514                    if entry_found {
 515                        return Some((item, child_end));
 516                    }
 517
 518                    *position = child_end;
 519                }
 520            }
 521        }
 522        None
 523    }
 524
 525    pub fn cursor<'a, 'b, D>(
 526        &'a self,
 527        cx: <T::Summary as Summary>::Context<'b>,
 528    ) -> Cursor<'a, 'b, T, D>
 529    where
 530        D: Dimension<'a, T::Summary>,
 531    {
 532        Cursor::new(self, cx)
 533    }
 534
 535    /// Note: If the summary type requires a non `()` context, then the filter cursor
 536    /// that is returned cannot be used with Rust's iterators.
 537    pub fn filter<'a, 'b, F, U>(
 538        &'a self,
 539        cx: <T::Summary as Summary>::Context<'b>,
 540        filter_node: F,
 541    ) -> FilterCursor<'a, 'b, F, T, U>
 542    where
 543        F: FnMut(&T::Summary) -> bool,
 544        U: Dimension<'a, T::Summary>,
 545    {
 546        FilterCursor::new(self, cx, filter_node)
 547    }
 548
 549    #[allow(dead_code)]
 550    pub fn first(&self) -> Option<&T> {
 551        self.leftmost_leaf().0.items().first()
 552    }
 553
 554    pub fn last(&self) -> Option<&T> {
 555        self.rightmost_leaf().0.items().last()
 556    }
 557
 558    pub fn update_last(
 559        &mut self,
 560        f: impl FnOnce(&mut T),
 561        cx: <T::Summary as Summary>::Context<'_>,
 562    ) {
 563        self.update_last_recursive(f, cx);
 564    }
 565
 566    fn update_last_recursive(
 567        &mut self,
 568        f: impl FnOnce(&mut T),
 569        cx: <T::Summary as Summary>::Context<'_>,
 570    ) -> Option<T::Summary> {
 571        match Arc::make_mut(&mut self.0) {
 572            Node::Internal {
 573                summary,
 574                child_summaries,
 575                child_trees,
 576                ..
 577            } => {
 578                let last_summary = child_summaries.last_mut().unwrap();
 579                let last_child = child_trees.last_mut().unwrap();
 580                *last_summary = last_child.update_last_recursive(f, cx).unwrap();
 581                *summary = sum(child_summaries.iter(), cx);
 582                Some(summary.clone())
 583            }
 584            Node::Leaf {
 585                summary,
 586                items,
 587                item_summaries,
 588            } => {
 589                if let Some((item, item_summary)) = items.last_mut().zip(item_summaries.last_mut())
 590                {
 591                    (f)(item);
 592                    *item_summary = item.summary(cx);
 593                    *summary = sum(item_summaries.iter(), cx);
 594                    Some(summary.clone())
 595                } else {
 596                    None
 597                }
 598            }
 599        }
 600    }
 601
 602    pub fn extent<'a, D: Dimension<'a, T::Summary>>(
 603        &'a self,
 604        cx: <T::Summary as Summary>::Context<'_>,
 605    ) -> D {
 606        let mut extent = D::zero(cx);
 607        match self.0.as_ref() {
 608            Node::Internal { summary, .. } | Node::Leaf { summary, .. } => {
 609                extent.add_summary(summary, cx);
 610            }
 611        }
 612        extent
 613    }
 614
 615    pub fn summary(&self) -> &T::Summary {
 616        match self.0.as_ref() {
 617            Node::Internal { summary, .. } => summary,
 618            Node::Leaf { summary, .. } => summary,
 619        }
 620    }
 621
 622    pub fn is_empty(&self) -> bool {
 623        match self.0.as_ref() {
 624            Node::Internal { .. } => false,
 625            Node::Leaf { items, .. } => items.is_empty(),
 626        }
 627    }
 628
 629    pub fn extend<I>(&mut self, iter: I, cx: <T::Summary as Summary>::Context<'_>)
 630    where
 631        I: IntoIterator<Item = T>,
 632    {
 633        self.append(Self::from_iter(iter, cx), cx);
 634    }
 635
 636    pub async fn async_extend<S, I>(&mut self, iter: I, spawn: S)
 637    where
 638        S: BackgroundSpawn,
 639        I: IntoIterator<Item = T, IntoIter: ExactSizeIterator>,
 640        T: 'static + Send + Sync,
 641        for<'b> T::Summary: Summary<Context<'b> = ()> + Send + Sync,
 642    {
 643        let other = Self::from_iter_async(iter, spawn);
 644        self.append(other.await, ());
 645    }
 646
 647    pub fn push(&mut self, item: T, cx: <T::Summary as Summary>::Context<'_>) {
 648        let summary = item.summary(cx);
 649        self.append(
 650            SumTree(Arc::new(Node::Leaf {
 651                summary: summary.clone(),
 652                items: ArrayVec::from_iter(Some(item)),
 653                item_summaries: ArrayVec::from_iter(Some(summary)),
 654            })),
 655            cx,
 656        );
 657    }
 658
 659    pub fn append(&mut self, other: Self, cx: <T::Summary as Summary>::Context<'_>) {
 660        if self.is_empty() {
 661            *self = other;
 662        } else if !other.0.is_leaf() || !other.0.items().is_empty() {
 663            if self.0.height() < other.0.height() {
 664                for tree in other.0.child_trees() {
 665                    self.append(tree.clone(), cx);
 666                }
 667            } else if let Some(split_tree) = self.push_tree_recursive(other, cx) {
 668                *self = Self::from_child_trees(self.clone(), split_tree, cx);
 669            }
 670        }
 671    }
 672
 673    fn push_tree_recursive(
 674        &mut self,
 675        other: SumTree<T>,
 676        cx: <T::Summary as Summary>::Context<'_>,
 677    ) -> Option<SumTree<T>> {
 678        match Arc::make_mut(&mut self.0) {
 679            Node::Internal {
 680                height,
 681                summary,
 682                child_summaries,
 683                child_trees,
 684                ..
 685            } => {
 686                let other_node = other.0.clone();
 687                <T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
 688
 689                let height_delta = *height - other_node.height();
 690                let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
 691                let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }>::new();
 692                if height_delta == 0 {
 693                    summaries_to_append.extend(other_node.child_summaries().iter().cloned());
 694                    trees_to_append.extend(other_node.child_trees().iter().cloned());
 695                } else if height_delta == 1 && !other_node.is_underflowing() {
 696                    summaries_to_append.push(other_node.summary().clone());
 697                    trees_to_append.push(other)
 698                } else {
 699                    let tree_to_append = child_trees
 700                        .last_mut()
 701                        .unwrap()
 702                        .push_tree_recursive(other, cx);
 703                    *child_summaries.last_mut().unwrap() =
 704                        child_trees.last().unwrap().0.summary().clone();
 705
 706                    if let Some(split_tree) = tree_to_append {
 707                        summaries_to_append.push(split_tree.0.summary().clone());
 708                        trees_to_append.push(split_tree);
 709                    }
 710                }
 711
 712                let child_count = child_trees.len() + trees_to_append.len();
 713                if child_count > 2 * TREE_BASE {
 714                    let left_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
 715                    let right_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
 716                    let left_trees;
 717                    let right_trees;
 718
 719                    let midpoint = (child_count + child_count % 2) / 2;
 720                    {
 721                        let mut all_summaries = child_summaries
 722                            .iter()
 723                            .chain(summaries_to_append.iter())
 724                            .cloned();
 725                        left_summaries = all_summaries.by_ref().take(midpoint).collect();
 726                        right_summaries = all_summaries.collect();
 727                        let mut all_trees =
 728                            child_trees.iter().chain(trees_to_append.iter()).cloned();
 729                        left_trees = all_trees.by_ref().take(midpoint).collect();
 730                        right_trees = all_trees.collect();
 731                    }
 732                    *summary = sum(left_summaries.iter(), cx);
 733                    *child_summaries = left_summaries;
 734                    *child_trees = left_trees;
 735
 736                    Some(SumTree(Arc::new(Node::Internal {
 737                        height: *height,
 738                        summary: sum(right_summaries.iter(), cx),
 739                        child_summaries: right_summaries,
 740                        child_trees: right_trees,
 741                    })))
 742                } else {
 743                    child_summaries.extend(summaries_to_append);
 744                    child_trees.extend(trees_to_append);
 745                    None
 746                }
 747            }
 748            Node::Leaf {
 749                summary,
 750                items,
 751                item_summaries,
 752            } => {
 753                let other_node = other.0;
 754
 755                let child_count = items.len() + other_node.items().len();
 756                if child_count > 2 * TREE_BASE {
 757                    let left_items;
 758                    let right_items;
 759                    let left_summaries;
 760                    let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>;
 761
 762                    let midpoint = (child_count + child_count % 2) / 2;
 763                    {
 764                        let mut all_items = items.iter().chain(other_node.items().iter()).cloned();
 765                        left_items = all_items.by_ref().take(midpoint).collect();
 766                        right_items = all_items.collect();
 767
 768                        let mut all_summaries = item_summaries
 769                            .iter()
 770                            .chain(other_node.child_summaries())
 771                            .cloned();
 772                        left_summaries = all_summaries.by_ref().take(midpoint).collect();
 773                        right_summaries = all_summaries.collect();
 774                    }
 775                    *items = left_items;
 776                    *item_summaries = left_summaries;
 777                    *summary = sum(item_summaries.iter(), cx);
 778                    Some(SumTree(Arc::new(Node::Leaf {
 779                        items: right_items,
 780                        summary: sum(right_summaries.iter(), cx),
 781                        item_summaries: right_summaries,
 782                    })))
 783                } else {
 784                    <T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
 785                    items.extend(other_node.items().iter().cloned());
 786                    item_summaries.extend(other_node.child_summaries().iter().cloned());
 787                    None
 788                }
 789            }
 790        }
 791    }
 792
 793    fn from_child_trees(
 794        left: SumTree<T>,
 795        right: SumTree<T>,
 796        cx: <T::Summary as Summary>::Context<'_>,
 797    ) -> Self {
 798        let height = left.0.height() + 1;
 799        let mut child_summaries = ArrayVec::new();
 800        child_summaries.push(left.0.summary().clone());
 801        child_summaries.push(right.0.summary().clone());
 802        let mut child_trees = ArrayVec::new();
 803        child_trees.push(left);
 804        child_trees.push(right);
 805        SumTree(Arc::new(Node::Internal {
 806            height,
 807            summary: sum(child_summaries.iter(), cx),
 808            child_summaries,
 809            child_trees,
 810        }))
 811    }
 812
 813    fn leftmost_leaf(&self) -> &Self {
 814        match *self.0 {
 815            Node::Leaf { .. } => self,
 816            Node::Internal {
 817                ref child_trees, ..
 818            } => child_trees.first().unwrap().leftmost_leaf(),
 819        }
 820    }
 821
 822    fn rightmost_leaf(&self) -> &Self {
 823        match *self.0 {
 824            Node::Leaf { .. } => self,
 825            Node::Internal {
 826                ref child_trees, ..
 827            } => child_trees.last().unwrap().rightmost_leaf(),
 828        }
 829    }
 830}
 831
 832impl<T: Item + PartialEq> PartialEq for SumTree<T> {
 833    fn eq(&self, other: &Self) -> bool {
 834        self.iter().eq(other.iter())
 835    }
 836}
 837
 838impl<T: Item + Eq> Eq for SumTree<T> {}
 839
 840impl<T: KeyedItem> SumTree<T> {
 841    pub fn insert_or_replace<'a, 'b>(
 842        &'a mut self,
 843        item: T,
 844        cx: <T::Summary as Summary>::Context<'b>,
 845    ) -> Option<T> {
 846        let mut replaced = None;
 847        {
 848            let mut cursor = self.cursor::<T::Key>(cx);
 849            let mut new_tree = cursor.slice(&item.key(), Bias::Left);
 850            if let Some(cursor_item) = cursor.item()
 851                && cursor_item.key() == item.key()
 852            {
 853                replaced = Some(cursor_item.clone());
 854                cursor.next();
 855            }
 856            new_tree.push(item, cx);
 857            new_tree.append(cursor.suffix(), cx);
 858            drop(cursor);
 859            *self = new_tree
 860        };
 861        replaced
 862    }
 863
 864    pub fn remove(&mut self, key: &T::Key, cx: <T::Summary as Summary>::Context<'_>) -> Option<T> {
 865        let mut removed = None;
 866        *self = {
 867            let mut cursor = self.cursor::<T::Key>(cx);
 868            let mut new_tree = cursor.slice(key, Bias::Left);
 869            if let Some(item) = cursor.item()
 870                && item.key() == *key
 871            {
 872                removed = Some(item.clone());
 873                cursor.next();
 874            }
 875            new_tree.append(cursor.suffix(), cx);
 876            new_tree
 877        };
 878        removed
 879    }
 880
 881    pub fn edit(
 882        &mut self,
 883        mut edits: Vec<Edit<T>>,
 884        cx: <T::Summary as Summary>::Context<'_>,
 885    ) -> Vec<T> {
 886        if edits.is_empty() {
 887            return Vec::new();
 888        }
 889
 890        let mut removed = Vec::new();
 891        edits.sort_unstable_by_key(|item| item.key());
 892
 893        *self = {
 894            let mut cursor = self.cursor::<T::Key>(cx);
 895            let mut new_tree = SumTree::new(cx);
 896            let mut buffered_items = Vec::new();
 897
 898            cursor.seek(&T::Key::zero(cx), Bias::Left);
 899            for edit in edits {
 900                let new_key = edit.key();
 901                let mut old_item = cursor.item();
 902
 903                if old_item
 904                    .as_ref()
 905                    .is_some_and(|old_item| old_item.key() < new_key)
 906                {
 907                    new_tree.extend(buffered_items.drain(..), cx);
 908                    let slice = cursor.slice(&new_key, Bias::Left);
 909                    new_tree.append(slice, cx);
 910                    old_item = cursor.item();
 911                }
 912
 913                if let Some(old_item) = old_item
 914                    && old_item.key() == new_key
 915                {
 916                    removed.push(old_item.clone());
 917                    cursor.next();
 918                }
 919
 920                match edit {
 921                    Edit::Insert(item) => {
 922                        buffered_items.push(item);
 923                    }
 924                    Edit::Remove(_) => {}
 925                }
 926            }
 927
 928            new_tree.extend(buffered_items, cx);
 929            new_tree.append(cursor.suffix(), cx);
 930            new_tree
 931        };
 932
 933        removed
 934    }
 935
 936    pub fn get<'a>(
 937        &'a self,
 938        key: &T::Key,
 939        cx: <T::Summary as Summary>::Context<'a>,
 940    ) -> Option<&'a T> {
 941        if let (_, _, Some(item)) = self.find_exact::<T::Key, _>(cx, key, Bias::Left) {
 942            Some(item)
 943        } else {
 944            None
 945        }
 946    }
 947}
 948
 949impl<T, S> Default for SumTree<T>
 950where
 951    T: Item<Summary = S>,
 952    S: for<'a> Summary<Context<'a> = ()>,
 953{
 954    fn default() -> Self {
 955        Self::new(())
 956    }
 957}
 958
/// A node of a [`SumTree`].
///
/// Both variants cache `summary`, which is the combination of the summaries of
/// everything directly below the node. Each variant holds at most
/// `2 * TREE_BASE` children or items.
#[derive(Clone)]
pub enum Node<T: Item> {
    /// A node whose children are themselves subtrees.
    Internal {
        /// Distance from the leaf level (leaves report a height of 0).
        height: u8,
        /// Combined summary of all child summaries.
        summary: T::Summary,
        /// Summaries of the children, index-aligned with `child_trees`.
        child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
        /// The child subtrees, in order.
        child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }>,
    },
    /// A bottom-level node that stores items directly.
    Leaf {
        /// Combined summary of all item summaries.
        summary: T::Summary,
        /// The items, in order.
        items: ArrayVec<T, { 2 * TREE_BASE }>,
        /// Summaries of the items, index-aligned with `items`.
        item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
    },
}
 973
 974impl<T> fmt::Debug for Node<T>
 975where
 976    T: Item + fmt::Debug,
 977    T::Summary: fmt::Debug,
 978{
 979    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 980        match self {
 981            Node::Internal {
 982                height,
 983                summary,
 984                child_summaries,
 985                child_trees,
 986            } => f
 987                .debug_struct("Internal")
 988                .field("height", height)
 989                .field("summary", summary)
 990                .field("child_summaries", child_summaries)
 991                .field("child_trees", child_trees)
 992                .finish(),
 993            Node::Leaf {
 994                summary,
 995                items,
 996                item_summaries,
 997            } => f
 998                .debug_struct("Leaf")
 999                .field("summary", summary)
1000                .field("items", items)
1001                .field("item_summaries", item_summaries)
1002                .finish(),
1003        }
1004    }
1005}
1006
1007impl<T: Item> Node<T> {
1008    fn is_leaf(&self) -> bool {
1009        matches!(self, Node::Leaf { .. })
1010    }
1011
1012    fn height(&self) -> u8 {
1013        match self {
1014            Node::Internal { height, .. } => *height,
1015            Node::Leaf { .. } => 0,
1016        }
1017    }
1018
1019    fn summary(&self) -> &T::Summary {
1020        match self {
1021            Node::Internal { summary, .. } => summary,
1022            Node::Leaf { summary, .. } => summary,
1023        }
1024    }
1025
1026    fn child_summaries(&self) -> &[T::Summary] {
1027        match self {
1028            Node::Internal {
1029                child_summaries, ..
1030            } => child_summaries.as_slice(),
1031            Node::Leaf { item_summaries, .. } => item_summaries.as_slice(),
1032        }
1033    }
1034
1035    fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }> {
1036        match self {
1037            Node::Internal { child_trees, .. } => child_trees,
1038            Node::Leaf { .. } => panic!("Leaf nodes have no child trees"),
1039        }
1040    }
1041
1042    fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }> {
1043        match self {
1044            Node::Leaf { items, .. } => items,
1045            Node::Internal { .. } => panic!("Internal nodes have no items"),
1046        }
1047    }
1048
1049    fn is_underflowing(&self) -> bool {
1050        match self {
1051            Node::Internal { child_trees, .. } => child_trees.len() < TREE_BASE,
1052            Node::Leaf { items, .. } => items.len() < TREE_BASE,
1053        }
1054    }
1055}
1056
/// A single keyed change for [`SumTree::edit`].
#[derive(Debug)]
pub enum Edit<T: KeyedItem> {
    /// Insert this item. Any existing item with the same key is displaced.
    Insert(T),
    /// Remove the item with this key, if one is present.
    Remove(T::Key),
}
1062
1063impl<T: KeyedItem> Edit<T> {
1064    fn key(&self) -> T::Key {
1065        match self {
1066            Edit::Insert(item) => item.key(),
1067            Edit::Remove(key) => key.clone(),
1068        }
1069    }
1070}
1071
1072fn sum<'a, T, I>(iter: I, cx: T::Context<'_>) -> T
1073where
1074    T: 'a + Summary,
1075    I: Iterator<Item = &'a T>,
1076{
1077    let mut sum = T::zero(cx);
1078    for value in iter {
1079        sum.add_summary(value, cx);
1080    }
1081    sum
1082}
1083
1084#[cfg(test)]
1085mod tests {
1086    use super::*;
1087    use rand::{distr::StandardUniform, prelude::*};
1088    use std::cmp;
1089
    // Registered through `ctor`, so it runs when the test binary loads, before
    // any test. It calls zlog's test initializer. Presumably this is what makes
    // the `log::info!` output below visible when tests run (NOTE(review):
    // confirm against zlog).
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
1094
1095    #[test]
1096    fn test_extend_and_push_tree() {
1097        let mut tree1 = SumTree::default();
1098        tree1.extend(0..20, ());
1099
1100        let mut tree2 = SumTree::default();
1101        tree2.extend(50..100, ());
1102
1103        tree1.append(tree2, ());
1104        assert_eq!(tree1.items(()), (0..20).chain(50..100).collect::<Vec<u8>>());
1105    }
1106
    /// Randomized test.
    ///
    /// Builds trees of seeded random `u8` values and checks them against a plain
    /// `Vec` reference model. It covers splicing, the filter cursor, step-by-step
    /// cursor navigation, and `Cursor::summary`.
    ///
    /// Environment overrides:
    /// - `SEED`: first seed (default 0).
    /// - `ITERATIONS`: number of seeds (default 100).
    /// - `OPERATIONS`: splice rounds per seed (default 5).
    ///
    /// Each seed is printed to stderr so a failing run can be replayed.
    #[test]
    fn test_random() {
        // Spawner stand-in that spawns nothing. The future is only boxed, and it
        // runs when the caller awaits it (here through `pollster::block_on`).
        struct NoSpawn;
        impl BackgroundSpawn for NoSpawn {
            type Task<R>
                = std::pin::Pin<Box<dyn Future<Output = R> + Sync + Send>>
            where
                R: Send + Sync;
            fn background_spawn<R>(
                &self,
                future: impl Future<Output = R> + Send + Sync + 'static,
            ) -> Self::Task<R>
            where
                R: Send + Sync + 'static,
            {
                Box::pin(future)
            }
        }

        let mut starting_seed = 0;
        if let Ok(value) = std::env::var("SEED") {
            starting_seed = value.parse().expect("invalid SEED variable");
        }
        let mut num_iterations = 100;
        if let Ok(value) = std::env::var("ITERATIONS") {
            num_iterations = value.parse().expect("invalid ITERATIONS variable");
        }
        let num_operations = std::env::var("OPERATIONS")
            .map_or(5, |o| o.parse().expect("invalid OPERATIONS variable"));

        for seed in starting_seed..(starting_seed + num_iterations) {
            eprintln!("seed = {}", seed);
            let mut rng = StdRng::seed_from_u64(seed);

            let rng = &mut rng;
            // Seed the tree with 0..=127 random bytes. The sync or async extend
            // path is chosen at random.
            let mut tree = SumTree::<u8>::default();
            let count = rng.random_range(0..128);
            if rng.random() {
                tree.extend(rng.sample_iter(StandardUniform).take(count), ());
            } else {
                let items = rng
                    .sample_iter(StandardUniform)
                    .take(count)
                    .collect::<Vec<_>>();
                pollster::block_on(tree.async_extend(items, NoSpawn));
            }

            for _ in 0..num_operations {
                // Replace a random range with fresh random items. The same splice
                // is applied to the reference `Vec`.
                let splice_end = rng.random_range(0..tree.extent::<Count>(()).0 + 1);
                let splice_start = rng.random_range(0..splice_end + 1);
                let count = rng.random_range(0..128);
                let tree_end = tree.extent::<Count>(());
                let new_items = rng
                    .sample_iter(StandardUniform)
                    .take(count)
                    .collect::<Vec<u8>>();

                let mut reference_items = tree.items(());
                reference_items.splice(splice_start..splice_end, new_items.clone());

                // Rebuild the tree in three steps:
                // - copy the prefix before the splice,
                // - add the new items,
                // - copy everything after the splice.
                tree = {
                    let mut cursor = tree.cursor::<Count>(());
                    let mut new_tree = cursor.slice(&Count(splice_start), Bias::Right);
                    if rng.random() {
                        new_tree.extend(new_items, ());
                    } else {
                        pollster::block_on(new_tree.async_extend(new_items, NoSpawn));
                    }
                    cursor.seek(&Count(splice_end), Bias::Right);
                    new_tree.append(cursor.slice(&tree_end, Bias::Right), ());
                    new_tree
                };

                assert_eq!(tree.items(()), reference_items);
                // Plain iteration and a cursor over the unit dimension must agree.
                assert_eq!(
                    tree.iter().collect::<Vec<_>>(),
                    tree.cursor::<()>(()).collect::<Vec<_>>()
                );

                log::info!("tree items: {:?}", tree.items(()));

                // The filter descends only into subtrees whose summary contains an
                // even value. The cursor should therefore visit exactly the even
                // items, in order, reporting each one's index in the full tree.
                let mut filter_cursor =
                    tree.filter::<_, Count>((), |summary| summary.contains_even);
                let expected_filtered_items = tree
                    .items(())
                    .into_iter()
                    .enumerate()
                    .filter(|(_, item)| (item & 1) == 0)
                    .collect::<Vec<_>>();

                // Enter the filtered sequence from either end. The assertions below
                // expect `prev()` on a fresh filter cursor to land on the last match.
                let mut item_ix = if rng.random() {
                    filter_cursor.next();
                    0
                } else {
                    filter_cursor.prev();
                    expected_filtered_items.len().saturating_sub(1)
                };
                while item_ix < expected_filtered_items.len() {
                    log::info!("filter_cursor, item_ix: {}", item_ix);
                    let actual_item = filter_cursor.item().unwrap();
                    let (reference_index, reference_item) = expected_filtered_items[item_ix];
                    assert_eq!(actual_item, &reference_item);
                    assert_eq!(filter_cursor.start().0, reference_index);
                    log::info!("next");
                    filter_cursor.next();
                    item_ix += 1;

                    // Occasionally step backwards to exercise `prev()`.
                    while item_ix > 0 && rng.random_bool(0.2) {
                        log::info!("prev");
                        filter_cursor.prev();
                        item_ix -= 1;

                        // Sometimes also step before the first match and back again.
                        if item_ix == 0 && rng.random_bool(0.2) {
                            filter_cursor.prev();
                            assert_eq!(filter_cursor.item(), None);
                            assert_eq!(filter_cursor.start().0, 0);
                            filter_cursor.next();
                        }
                    }
                }
                assert_eq!(filter_cursor.item(), None);

                // Seek to a random position, then take five steps forward and five
                // back. At every step, check item / prev_item / next_item against
                // the reference. `before_start` records that `prev()` was called
                // at position 0, which moves the cursor before the first item.
                let mut before_start = false;
                let mut cursor = tree.cursor::<Count>(());
                let start_pos = rng.random_range(0..=reference_items.len());
                cursor.seek(&Count(start_pos), Bias::Right);
                let mut pos = rng.random_range(start_pos..=reference_items.len());
                cursor.seek_forward(&Count(pos), Bias::Right);

                for i in 0..10 {
                    assert_eq!(cursor.start().0, pos);

                    if pos > 0 {
                        assert_eq!(cursor.prev_item().unwrap(), &reference_items[pos - 1]);
                    } else {
                        assert_eq!(cursor.prev_item(), None);
                    }

                    if pos < reference_items.len() && !before_start {
                        assert_eq!(cursor.item().unwrap(), &reference_items[pos]);
                    } else {
                        assert_eq!(cursor.item(), None);
                    }

                    if before_start {
                        assert_eq!(cursor.next_item(), reference_items.first());
                    } else if pos + 1 < reference_items.len() {
                        assert_eq!(cursor.next_item().unwrap(), &reference_items[pos + 1]);
                    } else {
                        assert_eq!(cursor.next_item(), None);
                    }

                    if i < 5 {
                        cursor.next();
                        if pos < reference_items.len() {
                            pos += 1;
                            before_start = false;
                        }
                    } else {
                        cursor.prev();
                        if pos == 0 {
                            before_start = true;
                        }
                        pos = pos.saturating_sub(1);
                    }
                }
            }

            // `Cursor::summary` over a range must equal the summary of the slice
            // taken over the same range with the same biases.
            for _ in 0..10 {
                let end = rng.random_range(0..tree.extent::<Count>(()).0 + 1);
                let start = rng.random_range(0..end + 1);
                let start_bias = if rng.random() {
                    Bias::Left
                } else {
                    Bias::Right
                };
                let end_bias = if rng.random() {
                    Bias::Left
                } else {
                    Bias::Right
                };

                let mut cursor = tree.cursor::<Count>(());
                cursor.seek(&Count(start), start_bias);
                let slice = cursor.slice(&Count(end), end_bias);

                cursor.seek(&Count(start), start_bias);
                let summary = cursor.summary::<_, Sum>(&Count(end), end_bias);

                assert_eq!(summary.0, slice.summary().sum);
            }
        }
    }
1300
1301    #[test]
1302    fn test_cursor() {
1303        // Empty tree
1304        let tree = SumTree::<u8>::default();
1305        let mut cursor = tree.cursor::<IntegersSummary>(());
1306        assert_eq!(
1307            cursor.slice(&Count(0), Bias::Right).items(()),
1308            Vec::<u8>::new()
1309        );
1310        assert_eq!(cursor.item(), None);
1311        assert_eq!(cursor.prev_item(), None);
1312        assert_eq!(cursor.next_item(), None);
1313        assert_eq!(cursor.start().sum, 0);
1314        cursor.prev();
1315        assert_eq!(cursor.item(), None);
1316        assert_eq!(cursor.prev_item(), None);
1317        assert_eq!(cursor.next_item(), None);
1318        assert_eq!(cursor.start().sum, 0);
1319        cursor.next();
1320        assert_eq!(cursor.item(), None);
1321        assert_eq!(cursor.prev_item(), None);
1322        assert_eq!(cursor.next_item(), None);
1323        assert_eq!(cursor.start().sum, 0);
1324
1325        // Single-element tree
1326        let mut tree = SumTree::<u8>::default();
1327        tree.extend(vec![1], ());
1328        let mut cursor = tree.cursor::<IntegersSummary>(());
1329        assert_eq!(
1330            cursor.slice(&Count(0), Bias::Right).items(()),
1331            Vec::<u8>::new()
1332        );
1333        assert_eq!(cursor.item(), Some(&1));
1334        assert_eq!(cursor.prev_item(), None);
1335        assert_eq!(cursor.next_item(), None);
1336        assert_eq!(cursor.start().sum, 0);
1337
1338        cursor.next();
1339        assert_eq!(cursor.item(), None);
1340        assert_eq!(cursor.prev_item(), Some(&1));
1341        assert_eq!(cursor.next_item(), None);
1342        assert_eq!(cursor.start().sum, 1);
1343
1344        cursor.prev();
1345        assert_eq!(cursor.item(), Some(&1));
1346        assert_eq!(cursor.prev_item(), None);
1347        assert_eq!(cursor.next_item(), None);
1348        assert_eq!(cursor.start().sum, 0);
1349
1350        let mut cursor = tree.cursor::<IntegersSummary>(());
1351        assert_eq!(cursor.slice(&Count(1), Bias::Right).items(()), [1]);
1352        assert_eq!(cursor.item(), None);
1353        assert_eq!(cursor.prev_item(), Some(&1));
1354        assert_eq!(cursor.next_item(), None);
1355        assert_eq!(cursor.start().sum, 1);
1356
1357        cursor.seek(&Count(0), Bias::Right);
1358        assert_eq!(
1359            cursor
1360                .slice(&tree.extent::<Count>(()), Bias::Right)
1361                .items(()),
1362            [1]
1363        );
1364        assert_eq!(cursor.item(), None);
1365        assert_eq!(cursor.prev_item(), Some(&1));
1366        assert_eq!(cursor.next_item(), None);
1367        assert_eq!(cursor.start().sum, 1);
1368
1369        // Multiple-element tree
1370        let mut tree = SumTree::default();
1371        tree.extend(vec![1, 2, 3, 4, 5, 6], ());
1372        let mut cursor = tree.cursor::<IntegersSummary>(());
1373
1374        assert_eq!(cursor.slice(&Count(2), Bias::Right).items(()), [1, 2]);
1375        assert_eq!(cursor.item(), Some(&3));
1376        assert_eq!(cursor.prev_item(), Some(&2));
1377        assert_eq!(cursor.next_item(), Some(&4));
1378        assert_eq!(cursor.start().sum, 3);
1379
1380        cursor.next();
1381        assert_eq!(cursor.item(), Some(&4));
1382        assert_eq!(cursor.prev_item(), Some(&3));
1383        assert_eq!(cursor.next_item(), Some(&5));
1384        assert_eq!(cursor.start().sum, 6);
1385
1386        cursor.next();
1387        assert_eq!(cursor.item(), Some(&5));
1388        assert_eq!(cursor.prev_item(), Some(&4));
1389        assert_eq!(cursor.next_item(), Some(&6));
1390        assert_eq!(cursor.start().sum, 10);
1391
1392        cursor.next();
1393        assert_eq!(cursor.item(), Some(&6));
1394        assert_eq!(cursor.prev_item(), Some(&5));
1395        assert_eq!(cursor.next_item(), None);
1396        assert_eq!(cursor.start().sum, 15);
1397
1398        cursor.next();
1399        cursor.next();
1400        assert_eq!(cursor.item(), None);
1401        assert_eq!(cursor.prev_item(), Some(&6));
1402        assert_eq!(cursor.next_item(), None);
1403        assert_eq!(cursor.start().sum, 21);
1404
1405        cursor.prev();
1406        assert_eq!(cursor.item(), Some(&6));
1407        assert_eq!(cursor.prev_item(), Some(&5));
1408        assert_eq!(cursor.next_item(), None);
1409        assert_eq!(cursor.start().sum, 15);
1410
1411        cursor.prev();
1412        assert_eq!(cursor.item(), Some(&5));
1413        assert_eq!(cursor.prev_item(), Some(&4));
1414        assert_eq!(cursor.next_item(), Some(&6));
1415        assert_eq!(cursor.start().sum, 10);
1416
1417        cursor.prev();
1418        assert_eq!(cursor.item(), Some(&4));
1419        assert_eq!(cursor.prev_item(), Some(&3));
1420        assert_eq!(cursor.next_item(), Some(&5));
1421        assert_eq!(cursor.start().sum, 6);
1422
1423        cursor.prev();
1424        assert_eq!(cursor.item(), Some(&3));
1425        assert_eq!(cursor.prev_item(), Some(&2));
1426        assert_eq!(cursor.next_item(), Some(&4));
1427        assert_eq!(cursor.start().sum, 3);
1428
1429        cursor.prev();
1430        assert_eq!(cursor.item(), Some(&2));
1431        assert_eq!(cursor.prev_item(), Some(&1));
1432        assert_eq!(cursor.next_item(), Some(&3));
1433        assert_eq!(cursor.start().sum, 1);
1434
1435        cursor.prev();
1436        assert_eq!(cursor.item(), Some(&1));
1437        assert_eq!(cursor.prev_item(), None);
1438        assert_eq!(cursor.next_item(), Some(&2));
1439        assert_eq!(cursor.start().sum, 0);
1440
1441        cursor.prev();
1442        assert_eq!(cursor.item(), None);
1443        assert_eq!(cursor.prev_item(), None);
1444        assert_eq!(cursor.next_item(), Some(&1));
1445        assert_eq!(cursor.start().sum, 0);
1446
1447        cursor.next();
1448        assert_eq!(cursor.item(), Some(&1));
1449        assert_eq!(cursor.prev_item(), None);
1450        assert_eq!(cursor.next_item(), Some(&2));
1451        assert_eq!(cursor.start().sum, 0);
1452
1453        let mut cursor = tree.cursor::<IntegersSummary>(());
1454        assert_eq!(
1455            cursor
1456                .slice(&tree.extent::<Count>(()), Bias::Right)
1457                .items(()),
1458            tree.items(())
1459        );
1460        assert_eq!(cursor.item(), None);
1461        assert_eq!(cursor.prev_item(), Some(&6));
1462        assert_eq!(cursor.next_item(), None);
1463        assert_eq!(cursor.start().sum, 21);
1464
1465        cursor.seek(&Count(3), Bias::Right);
1466        assert_eq!(
1467            cursor
1468                .slice(&tree.extent::<Count>(()), Bias::Right)
1469                .items(()),
1470            [4, 5, 6]
1471        );
1472        assert_eq!(cursor.item(), None);
1473        assert_eq!(cursor.prev_item(), Some(&6));
1474        assert_eq!(cursor.next_item(), None);
1475        assert_eq!(cursor.start().sum, 21);
1476
1477        // Seeking can bias left or right
1478        cursor.seek(&Count(1), Bias::Left);
1479        assert_eq!(cursor.item(), Some(&1));
1480        cursor.seek(&Count(1), Bias::Right);
1481        assert_eq!(cursor.item(), Some(&2));
1482
1483        // Slicing without resetting starts from where the cursor is parked at.
1484        cursor.seek(&Count(1), Bias::Right);
1485        assert_eq!(cursor.slice(&Count(3), Bias::Right).items(()), vec![2, 3]);
1486        assert_eq!(cursor.slice(&Count(6), Bias::Left).items(()), vec![4, 5]);
1487        assert_eq!(cursor.slice(&Count(6), Bias::Right).items(()), vec![6]);
1488    }
1489
1490    #[test]
1491    fn test_edit() {
1492        let mut tree = SumTree::<u8>::default();
1493
1494        let removed = tree.edit(vec![Edit::Insert(1), Edit::Insert(2), Edit::Insert(0)], ());
1495        assert_eq!(tree.items(()), vec![0, 1, 2]);
1496        assert_eq!(removed, Vec::<u8>::new());
1497        assert_eq!(tree.get(&0, ()), Some(&0));
1498        assert_eq!(tree.get(&1, ()), Some(&1));
1499        assert_eq!(tree.get(&2, ()), Some(&2));
1500        assert_eq!(tree.get(&4, ()), None);
1501
1502        let removed = tree.edit(vec![Edit::Insert(2), Edit::Insert(4), Edit::Remove(0)], ());
1503        assert_eq!(tree.items(()), vec![1, 2, 4]);
1504        assert_eq!(removed, vec![0, 2]);
1505        assert_eq!(tree.get(&0, ()), None);
1506        assert_eq!(tree.get(&1, ()), Some(&1));
1507        assert_eq!(tree.get(&2, ()), Some(&2));
1508        assert_eq!(tree.get(&4, ()), Some(&4));
1509    }
1510
1511    #[test]
1512    fn test_from_iter() {
1513        assert_eq!(
1514            SumTree::from_iter(0..100, ()).items(()),
1515            (0..100).collect::<Vec<_>>()
1516        );
1517
1518        // Ensure `from_iter` works correctly when the given iterator restarts
1519        // after calling `next` if `None` was already returned.
1520        let mut ix = 0;
1521        let iterator = std::iter::from_fn(|| {
1522            ix = (ix + 1) % 2;
1523            if ix == 1 { Some(1) } else { None }
1524        });
1525        assert_eq!(SumTree::from_iter(iterator, ()).items(()), vec![1]);
1526    }
1527
    /// Summary for the test trees of `u8` values. It records:
    /// - the number of items,
    /// - the sum of their values,
    /// - whether any value is even,
    /// - the largest value.
    #[derive(Clone, Default, Debug)]
    pub struct IntegersSummary {
        count: usize,
        sum: usize,
        contains_even: bool,
        max: u8,
    }

    /// Dimension counting items. It also serves as a seek target against the
    /// full `IntegersSummary`.
    #[derive(Ord, PartialOrd, Default, Eq, PartialEq, Clone, Debug)]
    struct Count(usize);

    /// Dimension accumulating the sum of item values.
    #[derive(Ord, PartialOrd, Default, Eq, PartialEq, Clone, Debug)]
    struct Sum(usize);
1541
1542    impl Item for u8 {
1543        type Summary = IntegersSummary;
1544
1545        fn summary(&self, _cx: ()) -> Self::Summary {
1546            IntegersSummary {
1547                count: 1,
1548                sum: *self as usize,
1549                contains_even: (*self & 1) == 0,
1550                max: *self,
1551            }
1552        }
1553    }
1554
1555    impl KeyedItem for u8 {
1556        type Key = u8;
1557
1558        fn key(&self) -> Self::Key {
1559            *self
1560        }
1561    }
1562
1563    impl ContextLessSummary for IntegersSummary {
1564        fn zero() -> Self {
1565            Default::default()
1566        }
1567
1568        fn add_summary(&mut self, other: &Self) {
1569            self.count += other.count;
1570            self.sum += other.sum;
1571            self.contains_even |= other.contains_even;
1572            self.max = cmp::max(self.max, other.max);
1573        }
1574    }
1575
1576    impl Dimension<'_, IntegersSummary> for u8 {
1577        fn zero(_cx: ()) -> Self {
1578            Default::default()
1579        }
1580
1581        fn add_summary(&mut self, summary: &IntegersSummary, _: ()) {
1582            *self = summary.max;
1583        }
1584    }
1585
1586    impl Dimension<'_, IntegersSummary> for Count {
1587        fn zero(_cx: ()) -> Self {
1588            Default::default()
1589        }
1590
1591        fn add_summary(&mut self, summary: &IntegersSummary, _: ()) {
1592            self.0 += summary.count;
1593        }
1594    }
1595
1596    impl SeekTarget<'_, IntegersSummary, IntegersSummary> for Count {
1597        fn cmp(&self, cursor_location: &IntegersSummary, _: ()) -> Ordering {
1598            self.0.cmp(&cursor_location.count)
1599        }
1600    }
1601
1602    impl Dimension<'_, IntegersSummary> for Sum {
1603        fn zero(_cx: ()) -> Self {
1604            Default::default()
1605        }
1606
1607        fn add_summary(&mut self, summary: &IntegersSummary, _: ()) {
1608            self.0 += summary.sum;
1609        }
1610    }
1611}