1mod cursor;
2mod tree_map;
3
4use arrayvec::ArrayVec;
5pub use cursor::{Cursor, FilterCursor, Iter};
6use rayon::prelude::*;
7use std::marker::PhantomData;
8use std::mem;
9use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
10pub use tree_map::{MapSeekTarget, TreeMap, TreeSet};
11
/// Branching parameter of the tree: node storage is sized for `2 * TREE_BASE` entries.
///
/// Test builds use `2` and every other build uses `6`
/// (presumably so that small test inputs still produce multi-level trees).
pub const TREE_BASE: usize = if cfg!(test) { 2 } else { 6 };
16
17/// An item that can be stored in a [`SumTree`]
18///
19/// Must be summarized by a type that implements [`Summary`]
20pub trait Item: Clone {
21 type Summary: Summary;
22
23 fn summary(&self) -> Self::Summary;
24}
25
26/// An [`Item`] whose summary has a specific key that can be used to identify it
27pub trait KeyedItem: Item {
28 type Key: for<'a> Dimension<'a, Self::Summary> + Ord;
29
30 fn key(&self) -> Self::Key;
31}
32
/// An aggregate describing all of the [`Item`]s in one subtree of a [`SumTree`].
///
/// A summary type can be measured along several [`Dimension`]s,
/// any of which can be used to navigate the tree.
pub trait Summary: Clone + fmt::Debug {
    /// Additional state passed to every summary operation.
    type Context;

    /// Returns the value an accumulation starts from.
    fn zero(cx: &Self::Context) -> Self;

    /// Folds `summary` into `self`.
    fn add_summary(&mut self, summary: &Self, cx: &Self::Context);
}
44
45/// Each [`Summary`] type can have more than one [`Dimension`] type that it measures.
46///
47/// You can use dimensions to seek to a specific location in the [`SumTree`]
48///
49/// # Example:
50/// Zed's rope has a `TextSummary` type that summarizes lines, characters, and bytes.
51/// Each of these are different dimensions we may want to seek to
52pub trait Dimension<'a, S: Summary>: Clone + fmt::Debug {
53 fn zero(cx: &S::Context) -> Self;
54
55 fn add_summary(&mut self, summary: &'a S, cx: &S::Context);
56
57 fn from_summary(summary: &'a S, cx: &S::Context) -> Self {
58 let mut dimension = Self::zero(cx);
59 dimension.add_summary(summary, cx);
60 dimension
61 }
62}
63
64impl<'a, T: Summary> Dimension<'a, T> for T {
65 fn zero(cx: &T::Context) -> Self {
66 Summary::zero(cx)
67 }
68
69 fn add_summary(&mut self, summary: &'a T, cx: &T::Context) {
70 Summary::add_summary(self, summary, cx);
71 }
72}
73
74pub trait SeekTarget<'a, S: Summary, D: Dimension<'a, S>>: fmt::Debug {
75 fn cmp(&self, cursor_location: &D, cx: &S::Context) -> Ordering;
76}
77
78impl<'a, S: Summary, D: Dimension<'a, S> + Ord> SeekTarget<'a, S, D> for D {
79 fn cmp(&self, cursor_location: &Self, _: &S::Context) -> Ordering {
80 Ord::cmp(self, cursor_location)
81 }
82}
83
84impl<'a, T: Summary> Dimension<'a, T> for () {
85 fn zero(_: &T::Context) -> Self {
86 ()
87 }
88
89 fn add_summary(&mut self, _: &'a T, _: &T::Context) {}
90}
91
92impl<'a, T: Summary, D1: Dimension<'a, T>, D2: Dimension<'a, T>> Dimension<'a, T> for (D1, D2) {
93 fn zero(cx: &T::Context) -> Self {
94 (D1::zero(cx), D2::zero(cx))
95 }
96
97 fn add_summary(&mut self, summary: &'a T, cx: &T::Context) {
98 self.0.add_summary(summary, cx);
99 self.1.add_summary(summary, cx);
100 }
101}
102
103impl<'a, S: Summary, D1: SeekTarget<'a, S, D1> + Dimension<'a, S>, D2: Dimension<'a, S>>
104 SeekTarget<'a, S, (D1, D2)> for D1
105{
106 fn cmp(&self, cursor_location: &(D1, D2), cx: &S::Context) -> Ordering {
107 self.cmp(&cursor_location.0, cx)
108 }
109}
110
/// Marker value parameterized by a dimension type `D`; it stores no data.
struct End<D>(PhantomData<D>);

impl<D> End<D> {
    /// Creates the marker.
    fn new() -> Self {
        End(PhantomData::<D>)
    }
}
118
119impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> {
120 fn cmp(&self, _: &D, _: &S::Context) -> Ordering {
121 Ordering::Greater
122 }
123}
124
125impl<D> fmt::Debug for End<D> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_tuple("End").finish()
128 }
129}
130
/// Resolves ambiguity about which neighbor a position in an ordered sequence belongs to.
///
/// The main use case is text, where the bias decides which character
/// an offset or anchor is associated with.
///
/// # Examples
/// Given the buffer `AˇBCD`:
/// - The offset of the cursor is 1
/// - [Bias::Left] would attach the cursor to the character `A`
/// - [Bias::Right] would attach the cursor to the character `B`
///
/// Given the buffer `A«BCˇ»D`:
/// - The offset of the cursor is 3, and the selection is from 1 to 3
/// - The left anchor of the selection has [Bias::Right], attaching it to the character `B`
/// - The right anchor of the selection has [Bias::Left], attaching it to the character `C`
///
/// Given the buffer `{ˇ<...>`, where `<...>` is a folded region:
/// - The display offset of the cursor is 1, but the offset in the buffer is determined by the bias
/// - [Bias::Left] would attach the cursor to the character `{`, with a buffer offset of 1
/// - [Bias::Right] would attach the cursor to the first character of the folded region,
///   and the buffer offset would be the offset of the first character of the folded region
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Hash, Default)]
pub enum Bias {
    /// Attach to the character on the left
    #[default]
    Left,
    /// Attach to the character on the right
    Right,
}

impl Bias {
    /// Returns the opposite bias.
    pub fn invert(self) -> Self {
        match self {
            Bias::Right => Bias::Left,
            Bias::Left => Bias::Right,
        }
    }
}
169
170/// A B+ tree in which each leaf node contains `Item`s of type `T` and a `Summary`s for each `Item`.
171/// Each internal node contains a `Summary` of the items in its subtree.
172///
173/// The maximum number of items per node is `TREE_BASE * 2`.
174///
175/// Any [`Dimension`] supported by the [`Summary`] type can be used to seek to a specific location in the tree.
176#[derive(Debug, Clone)]
177pub struct SumTree<T: Item>(Arc<Node<T>>);
178
179impl<T: Item> SumTree<T> {
180 pub fn new(cx: &<T::Summary as Summary>::Context) -> Self {
181 SumTree(Arc::new(Node::Leaf {
182 summary: <T::Summary as Summary>::zero(cx),
183 items: ArrayVec::new(),
184 item_summaries: ArrayVec::new(),
185 }))
186 }
187
188 pub fn from_item(item: T, cx: &<T::Summary as Summary>::Context) -> Self {
189 let mut tree = Self::new(cx);
190 tree.push(item, cx);
191 tree
192 }
193
194 pub fn from_iter<I: IntoIterator<Item = T>>(
195 iter: I,
196 cx: &<T::Summary as Summary>::Context,
197 ) -> Self {
198 let mut nodes = Vec::new();
199
200 let mut iter = iter.into_iter().fuse().peekable();
201 while iter.peek().is_some() {
202 let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
203 let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
204 items.iter().map(|item| item.summary()).collect();
205
206 let mut summary = item_summaries[0].clone();
207 for item_summary in &item_summaries[1..] {
208 <T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
209 }
210
211 nodes.push(Node::Leaf {
212 summary,
213 items,
214 item_summaries,
215 });
216 }
217
218 let mut parent_nodes = Vec::new();
219 let mut height = 0;
220 while nodes.len() > 1 {
221 height += 1;
222 let mut current_parent_node = None;
223 for child_node in nodes.drain(..) {
224 let parent_node = current_parent_node.get_or_insert_with(|| Node::Internal {
225 summary: <T::Summary as Summary>::zero(cx),
226 height,
227 child_summaries: ArrayVec::new(),
228 child_trees: ArrayVec::new(),
229 });
230 let Node::Internal {
231 summary,
232 child_summaries,
233 child_trees,
234 ..
235 } = parent_node
236 else {
237 unreachable!()
238 };
239 let child_summary = child_node.summary();
240 <T::Summary as Summary>::add_summary(summary, child_summary, cx);
241 child_summaries.push(child_summary.clone());
242 child_trees.push(Self(Arc::new(child_node)));
243
244 if child_trees.len() == 2 * TREE_BASE {
245 parent_nodes.extend(current_parent_node.take());
246 }
247 }
248 parent_nodes.extend(current_parent_node.take());
249 mem::swap(&mut nodes, &mut parent_nodes);
250 }
251
252 if nodes.is_empty() {
253 Self::new(cx)
254 } else {
255 debug_assert_eq!(nodes.len(), 1);
256 Self(Arc::new(nodes.pop().unwrap()))
257 }
258 }
259
260 pub fn from_par_iter<I, Iter>(iter: I, cx: &<T::Summary as Summary>::Context) -> Self
261 where
262 I: IntoParallelIterator<Iter = Iter>,
263 Iter: IndexedParallelIterator<Item = T>,
264 T: Send + Sync,
265 T::Summary: Send + Sync,
266 <T::Summary as Summary>::Context: Sync,
267 {
268 let mut nodes = iter
269 .into_par_iter()
270 .chunks(2 * TREE_BASE)
271 .map(|items| {
272 let items: ArrayVec<T, { 2 * TREE_BASE }> = items.into_iter().collect();
273 let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
274 items.iter().map(|item| item.summary()).collect();
275 let mut summary = item_summaries[0].clone();
276 for item_summary in &item_summaries[1..] {
277 <T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
278 }
279 SumTree(Arc::new(Node::Leaf {
280 summary,
281 items,
282 item_summaries,
283 }))
284 })
285 .collect::<Vec<_>>();
286
287 let mut height = 0;
288 while nodes.len() > 1 {
289 height += 1;
290 nodes = nodes
291 .into_par_iter()
292 .chunks(2 * TREE_BASE)
293 .map(|child_nodes| {
294 let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }> =
295 child_nodes.into_iter().collect();
296 let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
297 .iter()
298 .map(|child_tree| child_tree.summary().clone())
299 .collect();
300 let mut summary = child_summaries[0].clone();
301 for child_summary in &child_summaries[1..] {
302 <T::Summary as Summary>::add_summary(&mut summary, child_summary, cx);
303 }
304 SumTree(Arc::new(Node::Internal {
305 height,
306 summary,
307 child_summaries,
308 child_trees,
309 }))
310 })
311 .collect::<Vec<_>>();
312 }
313
314 if nodes.is_empty() {
315 Self::new(cx)
316 } else {
317 debug_assert_eq!(nodes.len(), 1);
318 nodes.pop().unwrap()
319 }
320 }
321
322 #[allow(unused)]
323 pub fn items(&self, cx: &<T::Summary as Summary>::Context) -> Vec<T> {
324 let mut items = Vec::new();
325 let mut cursor = self.cursor::<()>(cx);
326 cursor.next(cx);
327 while let Some(item) = cursor.item() {
328 items.push(item.clone());
329 cursor.next(cx);
330 }
331 items
332 }
333
334 pub fn iter(&self) -> Iter<T> {
335 Iter::new(self)
336 }
337
338 pub fn cursor<'a, S>(&'a self, cx: &<T::Summary as Summary>::Context) -> Cursor<T, S>
339 where
340 S: Dimension<'a, T::Summary>,
341 {
342 Cursor::new(self, cx)
343 }
344
345 /// Note: If the summary type requires a non `()` context, then the filter cursor
346 /// that is returned cannot be used with Rust's iterators.
347 pub fn filter<'a, F, U>(
348 &'a self,
349 cx: &<T::Summary as Summary>::Context,
350 filter_node: F,
351 ) -> FilterCursor<F, T, U>
352 where
353 F: FnMut(&T::Summary) -> bool,
354 U: Dimension<'a, T::Summary>,
355 {
356 FilterCursor::new(self, cx, filter_node)
357 }
358
359 #[allow(dead_code)]
360 pub fn first(&self) -> Option<&T> {
361 self.leftmost_leaf().0.items().first()
362 }
363
364 pub fn last(&self) -> Option<&T> {
365 self.rightmost_leaf().0.items().last()
366 }
367
368 pub fn update_last(&mut self, f: impl FnOnce(&mut T), cx: &<T::Summary as Summary>::Context) {
369 self.update_last_recursive(f, cx);
370 }
371
372 fn update_last_recursive(
373 &mut self,
374 f: impl FnOnce(&mut T),
375 cx: &<T::Summary as Summary>::Context,
376 ) -> Option<T::Summary> {
377 match Arc::make_mut(&mut self.0) {
378 Node::Internal {
379 summary,
380 child_summaries,
381 child_trees,
382 ..
383 } => {
384 let last_summary = child_summaries.last_mut().unwrap();
385 let last_child = child_trees.last_mut().unwrap();
386 *last_summary = last_child.update_last_recursive(f, cx).unwrap();
387 *summary = sum(child_summaries.iter(), cx);
388 Some(summary.clone())
389 }
390 Node::Leaf {
391 summary,
392 items,
393 item_summaries,
394 } => {
395 if let Some((item, item_summary)) = items.last_mut().zip(item_summaries.last_mut())
396 {
397 (f)(item);
398 *item_summary = item.summary();
399 *summary = sum(item_summaries.iter(), cx);
400 Some(summary.clone())
401 } else {
402 None
403 }
404 }
405 }
406 }
407
408 pub fn extent<'a, D: Dimension<'a, T::Summary>>(
409 &'a self,
410 cx: &<T::Summary as Summary>::Context,
411 ) -> D {
412 let mut extent = D::zero(cx);
413 match self.0.as_ref() {
414 Node::Internal { summary, .. } | Node::Leaf { summary, .. } => {
415 extent.add_summary(summary, cx);
416 }
417 }
418 extent
419 }
420
421 pub fn summary(&self) -> &T::Summary {
422 match self.0.as_ref() {
423 Node::Internal { summary, .. } => summary,
424 Node::Leaf { summary, .. } => summary,
425 }
426 }
427
428 pub fn is_empty(&self) -> bool {
429 match self.0.as_ref() {
430 Node::Internal { .. } => false,
431 Node::Leaf { items, .. } => items.is_empty(),
432 }
433 }
434
435 pub fn extend<I>(&mut self, iter: I, cx: &<T::Summary as Summary>::Context)
436 where
437 I: IntoIterator<Item = T>,
438 {
439 self.append(Self::from_iter(iter, cx), cx);
440 }
441
442 pub fn par_extend<I, Iter>(&mut self, iter: I, cx: &<T::Summary as Summary>::Context)
443 where
444 I: IntoParallelIterator<Iter = Iter>,
445 Iter: IndexedParallelIterator<Item = T>,
446 T: Send + Sync,
447 T::Summary: Send + Sync,
448 <T::Summary as Summary>::Context: Sync,
449 {
450 self.append(Self::from_par_iter(iter, cx), cx);
451 }
452
453 pub fn push(&mut self, item: T, cx: &<T::Summary as Summary>::Context) {
454 let summary = item.summary();
455 self.append(
456 SumTree(Arc::new(Node::Leaf {
457 summary: summary.clone(),
458 items: ArrayVec::from_iter(Some(item)),
459 item_summaries: ArrayVec::from_iter(Some(summary)),
460 })),
461 cx,
462 );
463 }
464
465 pub fn append(&mut self, other: Self, cx: &<T::Summary as Summary>::Context) {
466 if self.is_empty() {
467 *self = other;
468 } else if !other.0.is_leaf() || !other.0.items().is_empty() {
469 if self.0.height() < other.0.height() {
470 for tree in other.0.child_trees() {
471 self.append(tree.clone(), cx);
472 }
473 } else if let Some(split_tree) = self.push_tree_recursive(other, cx) {
474 *self = Self::from_child_trees(self.clone(), split_tree, cx);
475 }
476 }
477 }
478
479 fn push_tree_recursive(
480 &mut self,
481 other: SumTree<T>,
482 cx: &<T::Summary as Summary>::Context,
483 ) -> Option<SumTree<T>> {
484 match Arc::make_mut(&mut self.0) {
485 Node::Internal {
486 height,
487 summary,
488 child_summaries,
489 child_trees,
490 ..
491 } => {
492 let other_node = other.0.clone();
493 <T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
494
495 let height_delta = *height - other_node.height();
496 let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
497 let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }>::new();
498 if height_delta == 0 {
499 summaries_to_append.extend(other_node.child_summaries().iter().cloned());
500 trees_to_append.extend(other_node.child_trees().iter().cloned());
501 } else if height_delta == 1 && !other_node.is_underflowing() {
502 summaries_to_append.push(other_node.summary().clone());
503 trees_to_append.push(other)
504 } else {
505 let tree_to_append = child_trees
506 .last_mut()
507 .unwrap()
508 .push_tree_recursive(other, cx);
509 *child_summaries.last_mut().unwrap() =
510 child_trees.last().unwrap().0.summary().clone();
511
512 if let Some(split_tree) = tree_to_append {
513 summaries_to_append.push(split_tree.0.summary().clone());
514 trees_to_append.push(split_tree);
515 }
516 }
517
518 let child_count = child_trees.len() + trees_to_append.len();
519 if child_count > 2 * TREE_BASE {
520 let left_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
521 let right_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
522 let left_trees;
523 let right_trees;
524
525 let midpoint = (child_count + child_count % 2) / 2;
526 {
527 let mut all_summaries = child_summaries
528 .iter()
529 .chain(summaries_to_append.iter())
530 .cloned();
531 left_summaries = all_summaries.by_ref().take(midpoint).collect();
532 right_summaries = all_summaries.collect();
533 let mut all_trees =
534 child_trees.iter().chain(trees_to_append.iter()).cloned();
535 left_trees = all_trees.by_ref().take(midpoint).collect();
536 right_trees = all_trees.collect();
537 }
538 *summary = sum(left_summaries.iter(), cx);
539 *child_summaries = left_summaries;
540 *child_trees = left_trees;
541
542 Some(SumTree(Arc::new(Node::Internal {
543 height: *height,
544 summary: sum(right_summaries.iter(), cx),
545 child_summaries: right_summaries,
546 child_trees: right_trees,
547 })))
548 } else {
549 child_summaries.extend(summaries_to_append);
550 child_trees.extend(trees_to_append);
551 None
552 }
553 }
554 Node::Leaf {
555 summary,
556 items,
557 item_summaries,
558 } => {
559 let other_node = other.0;
560
561 let child_count = items.len() + other_node.items().len();
562 if child_count > 2 * TREE_BASE {
563 let left_items;
564 let right_items;
565 let left_summaries;
566 let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>;
567
568 let midpoint = (child_count + child_count % 2) / 2;
569 {
570 let mut all_items = items.iter().chain(other_node.items().iter()).cloned();
571 left_items = all_items.by_ref().take(midpoint).collect();
572 right_items = all_items.collect();
573
574 let mut all_summaries = item_summaries
575 .iter()
576 .chain(other_node.child_summaries())
577 .cloned();
578 left_summaries = all_summaries.by_ref().take(midpoint).collect();
579 right_summaries = all_summaries.collect();
580 }
581 *items = left_items;
582 *item_summaries = left_summaries;
583 *summary = sum(item_summaries.iter(), cx);
584 Some(SumTree(Arc::new(Node::Leaf {
585 items: right_items,
586 summary: sum(right_summaries.iter(), cx),
587 item_summaries: right_summaries,
588 })))
589 } else {
590 <T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
591 items.extend(other_node.items().iter().cloned());
592 item_summaries.extend(other_node.child_summaries().iter().cloned());
593 None
594 }
595 }
596 }
597 }
598
599 fn from_child_trees(
600 left: SumTree<T>,
601 right: SumTree<T>,
602 cx: &<T::Summary as Summary>::Context,
603 ) -> Self {
604 let height = left.0.height() + 1;
605 let mut child_summaries = ArrayVec::new();
606 child_summaries.push(left.0.summary().clone());
607 child_summaries.push(right.0.summary().clone());
608 let mut child_trees = ArrayVec::new();
609 child_trees.push(left);
610 child_trees.push(right);
611 SumTree(Arc::new(Node::Internal {
612 height,
613 summary: sum(child_summaries.iter(), cx),
614 child_summaries,
615 child_trees,
616 }))
617 }
618
619 fn leftmost_leaf(&self) -> &Self {
620 match *self.0 {
621 Node::Leaf { .. } => self,
622 Node::Internal {
623 ref child_trees, ..
624 } => child_trees.first().unwrap().leftmost_leaf(),
625 }
626 }
627
628 fn rightmost_leaf(&self) -> &Self {
629 match *self.0 {
630 Node::Leaf { .. } => self,
631 Node::Internal {
632 ref child_trees, ..
633 } => child_trees.last().unwrap().rightmost_leaf(),
634 }
635 }
636
637 #[cfg(debug_assertions)]
638 pub fn _debug_entries(&self) -> Vec<&T> {
639 self.iter().collect::<Vec<_>>()
640 }
641}
642
643impl<T: Item + PartialEq> PartialEq for SumTree<T> {
644 fn eq(&self, other: &Self) -> bool {
645 self.iter().eq(other.iter())
646 }
647}
648
649impl<T: Item + Eq> Eq for SumTree<T> {}
650
651impl<T: KeyedItem> SumTree<T> {
652 pub fn insert_or_replace(
653 &mut self,
654 item: T,
655 cx: &<T::Summary as Summary>::Context,
656 ) -> Option<T> {
657 let mut replaced = None;
658 *self = {
659 let mut cursor = self.cursor::<T::Key>(cx);
660 let mut new_tree = cursor.slice(&item.key(), Bias::Left, cx);
661 if let Some(cursor_item) = cursor.item() {
662 if cursor_item.key() == item.key() {
663 replaced = Some(cursor_item.clone());
664 cursor.next(cx);
665 }
666 }
667 new_tree.push(item, cx);
668 new_tree.append(cursor.suffix(cx), cx);
669 new_tree
670 };
671 replaced
672 }
673
674 pub fn remove(&mut self, key: &T::Key, cx: &<T::Summary as Summary>::Context) -> Option<T> {
675 let mut removed = None;
676 *self = {
677 let mut cursor = self.cursor::<T::Key>(cx);
678 let mut new_tree = cursor.slice(key, Bias::Left, cx);
679 if let Some(item) = cursor.item() {
680 if item.key() == *key {
681 removed = Some(item.clone());
682 cursor.next(cx);
683 }
684 }
685 new_tree.append(cursor.suffix(cx), cx);
686 new_tree
687 };
688 removed
689 }
690
691 pub fn edit(
692 &mut self,
693 mut edits: Vec<Edit<T>>,
694 cx: &<T::Summary as Summary>::Context,
695 ) -> Vec<T> {
696 if edits.is_empty() {
697 return Vec::new();
698 }
699
700 let mut removed = Vec::new();
701 edits.sort_unstable_by_key(|item| item.key());
702
703 *self = {
704 let mut cursor = self.cursor::<T::Key>(cx);
705 let mut new_tree = SumTree::new(cx);
706 let mut buffered_items = Vec::new();
707
708 cursor.seek(&T::Key::zero(cx), Bias::Left, cx);
709 for edit in edits {
710 let new_key = edit.key();
711 let mut old_item = cursor.item();
712
713 if old_item
714 .as_ref()
715 .map_or(false, |old_item| old_item.key() < new_key)
716 {
717 new_tree.extend(buffered_items.drain(..), cx);
718 let slice = cursor.slice(&new_key, Bias::Left, cx);
719 new_tree.append(slice, cx);
720 old_item = cursor.item();
721 }
722
723 if let Some(old_item) = old_item {
724 if old_item.key() == new_key {
725 removed.push(old_item.clone());
726 cursor.next(cx);
727 }
728 }
729
730 match edit {
731 Edit::Insert(item) => {
732 buffered_items.push(item);
733 }
734 Edit::Remove(_) => {}
735 }
736 }
737
738 new_tree.extend(buffered_items, cx);
739 new_tree.append(cursor.suffix(cx), cx);
740 new_tree
741 };
742
743 removed
744 }
745
746 pub fn get(&self, key: &T::Key, cx: &<T::Summary as Summary>::Context) -> Option<&T> {
747 let mut cursor = self.cursor::<T::Key>(cx);
748 if cursor.seek(key, Bias::Left, cx) {
749 cursor.item()
750 } else {
751 None
752 }
753 }
754}
755
756impl<T, S> Default for SumTree<T>
757where
758 T: Item<Summary = S>,
759 S: Summary<Context = ()>,
760{
761 fn default() -> Self {
762 Self::new(&())
763 }
764}
765
766#[derive(Clone, Debug)]
767pub enum Node<T: Item> {
768 Internal {
769 height: u8,
770 summary: T::Summary,
771 child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
772 child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }>,
773 },
774 Leaf {
775 summary: T::Summary,
776 items: ArrayVec<T, { 2 * TREE_BASE }>,
777 item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
778 },
779}
780
781impl<T: Item> Node<T> {
782 fn is_leaf(&self) -> bool {
783 matches!(self, Node::Leaf { .. })
784 }
785
786 fn height(&self) -> u8 {
787 match self {
788 Node::Internal { height, .. } => *height,
789 Node::Leaf { .. } => 0,
790 }
791 }
792
793 fn summary(&self) -> &T::Summary {
794 match self {
795 Node::Internal { summary, .. } => summary,
796 Node::Leaf { summary, .. } => summary,
797 }
798 }
799
800 fn child_summaries(&self) -> &[T::Summary] {
801 match self {
802 Node::Internal {
803 child_summaries, ..
804 } => child_summaries.as_slice(),
805 Node::Leaf { item_summaries, .. } => item_summaries.as_slice(),
806 }
807 }
808
809 fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }> {
810 match self {
811 Node::Internal { child_trees, .. } => child_trees,
812 Node::Leaf { .. } => panic!("Leaf nodes have no child trees"),
813 }
814 }
815
816 fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }> {
817 match self {
818 Node::Leaf { items, .. } => items,
819 Node::Internal { .. } => panic!("Internal nodes have no items"),
820 }
821 }
822
823 fn is_underflowing(&self) -> bool {
824 match self {
825 Node::Internal { child_trees, .. } => child_trees.len() < TREE_BASE,
826 Node::Leaf { items, .. } => items.len() < TREE_BASE,
827 }
828 }
829}
830
831#[derive(Debug)]
832pub enum Edit<T: KeyedItem> {
833 Insert(T),
834 Remove(T::Key),
835}
836
837impl<T: KeyedItem> Edit<T> {
838 fn key(&self) -> T::Key {
839 match self {
840 Edit::Insert(item) => item.key(),
841 Edit::Remove(key) => key.clone(),
842 }
843 }
844}
845
846fn sum<'a, T, I>(iter: I, cx: &T::Context) -> T
847where
848 T: 'a + Summary,
849 I: Iterator<Item = &'a T>,
850{
851 let mut sum = T::zero(cx);
852 for value in iter {
853 sum.add_summary(value, cx);
854 }
855 sum
856}
857
858#[cfg(test)]
859mod tests {
860 use super::*;
861 use rand::{distributions, prelude::*};
862 use std::cmp;
863
864 #[ctor::ctor]
865 fn init_logger() {
866 if std::env::var("RUST_LOG").is_ok() {
867 env_logger::init();
868 }
869 }
870
871 #[test]
872 fn test_extend_and_push_tree() {
873 let mut tree1 = SumTree::default();
874 tree1.extend(0..20, &());
875
876 let mut tree2 = SumTree::default();
877 tree2.extend(50..100, &());
878
879 tree1.append(tree2, &());
880 assert_eq!(
881 tree1.items(&()),
882 (0..20).chain(50..100).collect::<Vec<u8>>()
883 );
884 }
885
    /// Randomized test that compares a `SumTree<u8>` against a plain `Vec<u8>`
    /// reference while splicing, filtering, stepping cursors, and slicing.
    ///
    /// Controlled by the `SEED`, `ITERATIONS`, and `OPERATIONS` environment
    /// variables (defaults: 0, 100, and 5).
    #[test]
    fn test_random() {
        let mut starting_seed = 0;
        if let Ok(value) = std::env::var("SEED") {
            starting_seed = value.parse().expect("invalid SEED variable");
        }
        let mut num_iterations = 100;
        if let Ok(value) = std::env::var("ITERATIONS") {
            num_iterations = value.parse().expect("invalid ITERATIONS variable");
        }
        let num_operations = std::env::var("OPERATIONS")
            .map_or(5, |o| o.parse().expect("invalid OPERATIONS variable"));

        for seed in starting_seed..(starting_seed + num_iterations) {
            // Print the seed so a failing iteration can be reproduced.
            eprintln!("seed = {}", seed);
            let mut rng = StdRng::seed_from_u64(seed);

            let rng = &mut rng;
            let mut tree = SumTree::<u8>::default();
            let count = rng.gen_range(0..10);
            // Randomly pick the sequential or the parallel construction path.
            if rng.gen() {
                tree.extend(rng.sample_iter(distributions::Standard).take(count), &());
            } else {
                let items = rng
                    .sample_iter(distributions::Standard)
                    .take(count)
                    .collect::<Vec<_>>();
                tree.par_extend(items, &());
            }

            for _ in 0..num_operations {
                // Choose a random range to replace and random replacement items.
                let splice_end = rng.gen_range(0..tree.extent::<Count>(&()).0 + 1);
                let splice_start = rng.gen_range(0..splice_end + 1);
                let count = rng.gen_range(0..10);
                let tree_end = tree.extent::<Count>(&());
                let new_items = rng
                    .sample_iter(distributions::Standard)
                    .take(count)
                    .collect::<Vec<u8>>();

                // Apply the splice to the reference vector.
                let mut reference_items = tree.items(&());
                reference_items.splice(splice_start..splice_end, new_items.clone());

                // Apply the same splice to the tree: prefix + new items + suffix.
                tree = {
                    let mut cursor = tree.cursor::<Count>(&());
                    let mut new_tree = cursor.slice(&Count(splice_start), Bias::Right, &());
                    if rng.gen() {
                        new_tree.extend(new_items, &());
                    } else {
                        new_tree.par_extend(new_items, &());
                    }
                    cursor.seek(&Count(splice_end), Bias::Right, &());
                    new_tree.append(cursor.slice(&tree_end, Bias::Right, &()), &());
                    new_tree
                };

                assert_eq!(tree.items(&()), reference_items);
                // `iter()` and a cursor used as an iterator must yield the same items.
                assert_eq!(
                    tree.iter().collect::<Vec<_>>(),
                    tree.cursor::<()>(&()).collect::<Vec<_>>()
                );

                log::info!("tree items: {:?}", tree.items(&()));

                // The filter cursor is compared against the even items of the tree,
                // paired with their indices.
                let mut filter_cursor =
                    tree.filter::<_, Count>(&(), |summary| summary.contains_even);
                let expected_filtered_items = tree
                    .items(&())
                    .into_iter()
                    .enumerate()
                    .filter(|(_, item)| (item & 1) == 0)
                    .collect::<Vec<_>>();

                // Start either at the first match (via `next`) or the last one (via `prev`).
                let mut item_ix = if rng.gen() {
                    filter_cursor.next(&());
                    0
                } else {
                    filter_cursor.prev(&());
                    expected_filtered_items.len().saturating_sub(1)
                };
                while item_ix < expected_filtered_items.len() {
                    log::info!("filter_cursor, item_ix: {}", item_ix);
                    let actual_item = filter_cursor.item().unwrap();
                    let (reference_index, reference_item) = expected_filtered_items[item_ix];
                    assert_eq!(actual_item, &reference_item);
                    assert_eq!(filter_cursor.start().0, reference_index);
                    log::info!("next");
                    filter_cursor.next(&());
                    item_ix += 1;

                    // Occasionally step backwards to exercise `prev` as well.
                    while item_ix > 0 && rng.gen_bool(0.2) {
                        log::info!("prev");
                        filter_cursor.prev(&());
                        item_ix -= 1;

                        // At the first match, sometimes step before it and back again.
                        if item_ix == 0 && rng.gen_bool(0.2) {
                            filter_cursor.prev(&());
                            assert_eq!(filter_cursor.item(), None);
                            assert_eq!(filter_cursor.start().0, 0);
                            filter_cursor.next(&());
                        }
                    }
                }
                assert_eq!(filter_cursor.item(), None);

                // Plain cursor: seek to a random position, then step forward five
                // times and backward five times, checking item/prev_item/next_item.
                let mut before_start = false;
                let mut cursor = tree.cursor::<Count>(&());
                let start_pos = rng.gen_range(0..=reference_items.len());
                cursor.seek(&Count(start_pos), Bias::Right, &());
                let mut pos = rng.gen_range(start_pos..=reference_items.len());
                cursor.seek_forward(&Count(pos), Bias::Right, &());

                for i in 0..10 {
                    assert_eq!(cursor.start().0, pos);

                    if pos > 0 {
                        assert_eq!(cursor.prev_item().unwrap(), &reference_items[pos - 1]);
                    } else {
                        assert_eq!(cursor.prev_item(), None);
                    }

                    if pos < reference_items.len() && !before_start {
                        assert_eq!(cursor.item().unwrap(), &reference_items[pos]);
                    } else {
                        assert_eq!(cursor.item(), None);
                    }

                    if before_start {
                        assert_eq!(cursor.next_item(), reference_items.first());
                    } else if pos + 1 < reference_items.len() {
                        assert_eq!(cursor.next_item().unwrap(), &reference_items[pos + 1]);
                    } else {
                        assert_eq!(cursor.next_item(), None);
                    }

                    if i < 5 {
                        cursor.next(&());
                        if pos < reference_items.len() {
                            pos += 1;
                            before_start = false;
                        }
                    } else {
                        cursor.prev(&());
                        // `prev` at position 0 moves the cursor before the first item.
                        if pos == 0 {
                            before_start = true;
                        }
                        pos = pos.saturating_sub(1);
                    }
                }
            }

            // The summary computed over a range must match the summary of the
            // slice taken over the same range.
            for _ in 0..10 {
                let end = rng.gen_range(0..tree.extent::<Count>(&()).0 + 1);
                let start = rng.gen_range(0..end + 1);
                let start_bias = if rng.gen() { Bias::Left } else { Bias::Right };
                let end_bias = if rng.gen() { Bias::Left } else { Bias::Right };

                let mut cursor = tree.cursor::<Count>(&());
                cursor.seek(&Count(start), start_bias, &());
                let slice = cursor.slice(&Count(end), end_bias, &());

                cursor.seek(&Count(start), start_bias, &());
                let summary = cursor.summary::<_, Sum>(&Count(end), end_bias, &());

                assert_eq!(summary.0, slice.summary().sum);
            }
        }
    }
1054
1055 #[test]
1056 fn test_cursor() {
1057 // Empty tree
1058 let tree = SumTree::<u8>::default();
1059 let mut cursor = tree.cursor::<IntegersSummary>(&());
1060 assert_eq!(
1061 cursor.slice(&Count(0), Bias::Right, &()).items(&()),
1062 Vec::<u8>::new()
1063 );
1064 assert_eq!(cursor.item(), None);
1065 assert_eq!(cursor.prev_item(), None);
1066 assert_eq!(cursor.next_item(), None);
1067 assert_eq!(cursor.start().sum, 0);
1068 cursor.prev(&());
1069 assert_eq!(cursor.item(), None);
1070 assert_eq!(cursor.prev_item(), None);
1071 assert_eq!(cursor.next_item(), None);
1072 assert_eq!(cursor.start().sum, 0);
1073 cursor.next(&());
1074 assert_eq!(cursor.item(), None);
1075 assert_eq!(cursor.prev_item(), None);
1076 assert_eq!(cursor.next_item(), None);
1077 assert_eq!(cursor.start().sum, 0);
1078
1079 // Single-element tree
1080 let mut tree = SumTree::<u8>::default();
1081 tree.extend(vec![1], &());
1082 let mut cursor = tree.cursor::<IntegersSummary>(&());
1083 assert_eq!(
1084 cursor.slice(&Count(0), Bias::Right, &()).items(&()),
1085 Vec::<u8>::new()
1086 );
1087 assert_eq!(cursor.item(), Some(&1));
1088 assert_eq!(cursor.prev_item(), None);
1089 assert_eq!(cursor.next_item(), None);
1090 assert_eq!(cursor.start().sum, 0);
1091
1092 cursor.next(&());
1093 assert_eq!(cursor.item(), None);
1094 assert_eq!(cursor.prev_item(), Some(&1));
1095 assert_eq!(cursor.next_item(), None);
1096 assert_eq!(cursor.start().sum, 1);
1097
1098 cursor.prev(&());
1099 assert_eq!(cursor.item(), Some(&1));
1100 assert_eq!(cursor.prev_item(), None);
1101 assert_eq!(cursor.next_item(), None);
1102 assert_eq!(cursor.start().sum, 0);
1103
1104 let mut cursor = tree.cursor::<IntegersSummary>(&());
1105 assert_eq!(cursor.slice(&Count(1), Bias::Right, &()).items(&()), [1]);
1106 assert_eq!(cursor.item(), None);
1107 assert_eq!(cursor.prev_item(), Some(&1));
1108 assert_eq!(cursor.next_item(), None);
1109 assert_eq!(cursor.start().sum, 1);
1110
1111 cursor.seek(&Count(0), Bias::Right, &());
1112 assert_eq!(
1113 cursor
1114 .slice(&tree.extent::<Count>(&()), Bias::Right, &())
1115 .items(&()),
1116 [1]
1117 );
1118 assert_eq!(cursor.item(), None);
1119 assert_eq!(cursor.prev_item(), Some(&1));
1120 assert_eq!(cursor.next_item(), None);
1121 assert_eq!(cursor.start().sum, 1);
1122
1123 // Multiple-element tree
1124 let mut tree = SumTree::default();
1125 tree.extend(vec![1, 2, 3, 4, 5, 6], &());
1126 let mut cursor = tree.cursor::<IntegersSummary>(&());
1127
1128 assert_eq!(cursor.slice(&Count(2), Bias::Right, &()).items(&()), [1, 2]);
1129 assert_eq!(cursor.item(), Some(&3));
1130 assert_eq!(cursor.prev_item(), Some(&2));
1131 assert_eq!(cursor.next_item(), Some(&4));
1132 assert_eq!(cursor.start().sum, 3);
1133
1134 cursor.next(&());
1135 assert_eq!(cursor.item(), Some(&4));
1136 assert_eq!(cursor.prev_item(), Some(&3));
1137 assert_eq!(cursor.next_item(), Some(&5));
1138 assert_eq!(cursor.start().sum, 6);
1139
1140 cursor.next(&());
1141 assert_eq!(cursor.item(), Some(&5));
1142 assert_eq!(cursor.prev_item(), Some(&4));
1143 assert_eq!(cursor.next_item(), Some(&6));
1144 assert_eq!(cursor.start().sum, 10);
1145
1146 cursor.next(&());
1147 assert_eq!(cursor.item(), Some(&6));
1148 assert_eq!(cursor.prev_item(), Some(&5));
1149 assert_eq!(cursor.next_item(), None);
1150 assert_eq!(cursor.start().sum, 15);
1151
1152 cursor.next(&());
1153 cursor.next(&());
1154 assert_eq!(cursor.item(), None);
1155 assert_eq!(cursor.prev_item(), Some(&6));
1156 assert_eq!(cursor.next_item(), None);
1157 assert_eq!(cursor.start().sum, 21);
1158
1159 cursor.prev(&());
1160 assert_eq!(cursor.item(), Some(&6));
1161 assert_eq!(cursor.prev_item(), Some(&5));
1162 assert_eq!(cursor.next_item(), None);
1163 assert_eq!(cursor.start().sum, 15);
1164
1165 cursor.prev(&());
1166 assert_eq!(cursor.item(), Some(&5));
1167 assert_eq!(cursor.prev_item(), Some(&4));
1168 assert_eq!(cursor.next_item(), Some(&6));
1169 assert_eq!(cursor.start().sum, 10);
1170
1171 cursor.prev(&());
1172 assert_eq!(cursor.item(), Some(&4));
1173 assert_eq!(cursor.prev_item(), Some(&3));
1174 assert_eq!(cursor.next_item(), Some(&5));
1175 assert_eq!(cursor.start().sum, 6);
1176
1177 cursor.prev(&());
1178 assert_eq!(cursor.item(), Some(&3));
1179 assert_eq!(cursor.prev_item(), Some(&2));
1180 assert_eq!(cursor.next_item(), Some(&4));
1181 assert_eq!(cursor.start().sum, 3);
1182
1183 cursor.prev(&());
1184 assert_eq!(cursor.item(), Some(&2));
1185 assert_eq!(cursor.prev_item(), Some(&1));
1186 assert_eq!(cursor.next_item(), Some(&3));
1187 assert_eq!(cursor.start().sum, 1);
1188
1189 cursor.prev(&());
1190 assert_eq!(cursor.item(), Some(&1));
1191 assert_eq!(cursor.prev_item(), None);
1192 assert_eq!(cursor.next_item(), Some(&2));
1193 assert_eq!(cursor.start().sum, 0);
1194
1195 cursor.prev(&());
1196 assert_eq!(cursor.item(), None);
1197 assert_eq!(cursor.prev_item(), None);
1198 assert_eq!(cursor.next_item(), Some(&1));
1199 assert_eq!(cursor.start().sum, 0);
1200
1201 cursor.next(&());
1202 assert_eq!(cursor.item(), Some(&1));
1203 assert_eq!(cursor.prev_item(), None);
1204 assert_eq!(cursor.next_item(), Some(&2));
1205 assert_eq!(cursor.start().sum, 0);
1206
1207 let mut cursor = tree.cursor::<IntegersSummary>(&());
1208 assert_eq!(
1209 cursor
1210 .slice(&tree.extent::<Count>(&()), Bias::Right, &())
1211 .items(&()),
1212 tree.items(&())
1213 );
1214 assert_eq!(cursor.item(), None);
1215 assert_eq!(cursor.prev_item(), Some(&6));
1216 assert_eq!(cursor.next_item(), None);
1217 assert_eq!(cursor.start().sum, 21);
1218
1219 cursor.seek(&Count(3), Bias::Right, &());
1220 assert_eq!(
1221 cursor
1222 .slice(&tree.extent::<Count>(&()), Bias::Right, &())
1223 .items(&()),
1224 [4, 5, 6]
1225 );
1226 assert_eq!(cursor.item(), None);
1227 assert_eq!(cursor.prev_item(), Some(&6));
1228 assert_eq!(cursor.next_item(), None);
1229 assert_eq!(cursor.start().sum, 21);
1230
1231 // Seeking can bias left or right
1232 cursor.seek(&Count(1), Bias::Left, &());
1233 assert_eq!(cursor.item(), Some(&1));
1234 cursor.seek(&Count(1), Bias::Right, &());
1235 assert_eq!(cursor.item(), Some(&2));
1236
1237 // Slicing without resetting starts from where the cursor is parked at.
1238 cursor.seek(&Count(1), Bias::Right, &());
1239 assert_eq!(
1240 cursor.slice(&Count(3), Bias::Right, &()).items(&()),
1241 vec![2, 3]
1242 );
1243 assert_eq!(
1244 cursor.slice(&Count(6), Bias::Left, &()).items(&()),
1245 vec![4, 5]
1246 );
1247 assert_eq!(
1248 cursor.slice(&Count(6), Bias::Right, &()).items(&()),
1249 vec![6]
1250 );
1251 }
1252
/// Exercises `SumTree::edit` with batches of inserts and removals, checking
/// the resulting item order, the items handed back by `edit`, and keyed
/// lookups through `get`.
#[test]
fn test_edit() {
    let mut tree = SumTree::<u8>::default();

    // First batch: inserts into an empty tree, supplied out of key order.
    let displaced = tree.edit(vec![Edit::Insert(1), Edit::Insert(2), Edit::Insert(0)], &());
    assert_eq!(tree.items(&()), vec![0, 1, 2]);
    assert_eq!(displaced, Vec::<u8>::new());
    for key in [0u8, 1, 2] {
        assert_eq!(tree.get(&key, &()), Some(&key));
    }
    assert_eq!(tree.get(&4, &()), None);

    // Second batch: re-insert an existing key, add a new one, remove another.
    // Both the overwritten `2` and the removed `0` are handed back.
    let displaced = tree.edit(vec![Edit::Insert(2), Edit::Insert(4), Edit::Remove(0)], &());
    assert_eq!(tree.items(&()), vec![1, 2, 4]);
    assert_eq!(displaced, vec![0, 2]);
    assert_eq!(tree.get(&0, &()), None);
    for key in [1u8, 2, 4] {
        assert_eq!(tree.get(&key, &()), Some(&key));
    }
}
1273
/// Checks that `SumTree::from_iter` preserves iteration order and stops
/// consuming its input at the first `None`.
#[test]
fn test_from_iter() {
    assert_eq!(
        SumTree::from_iter(0..100, &()).items(&()),
        (0..100).collect::<Vec<_>>()
    );

    // Ensure `from_iter` works correctly when the given iterator restarts
    // after calling `next` if `None` was already returned: the closure below
    // alternates between `Some(1)` and `None` forever.
    let mut phase = 0;
    let restarting = std::iter::from_fn(|| {
        phase = (phase + 1) % 2;
        match phase {
            1 => Some(1),
            _ => None,
        }
    });
    assert_eq!(SumTree::from_iter(restarting, &()).items(&()), vec![1]);
}
1294
/// Aggregate describing a run of `u8` items in the test tree.
///
/// The default value (all zeros / `false`) is the summary of an empty run.
#[derive(Debug, Default, Clone)]
pub struct IntegersSummary {
    /// Number of items covered by this summary.
    count: usize,
    /// Total of the covered item values.
    sum: usize,
    /// Whether at least one covered item is even.
    contains_even: bool,
    /// Largest covered item value.
    max: u8,
}
1302
/// Dimension measuring a position as a number of items.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
1305
/// Dimension measuring a position as the running total of item values.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Sum(usize);
1308
1309 impl Item for u8 {
1310 type Summary = IntegersSummary;
1311
1312 fn summary(&self) -> Self::Summary {
1313 IntegersSummary {
1314 count: 1,
1315 sum: *self as usize,
1316 contains_even: (*self & 1) == 0,
1317 max: *self,
1318 }
1319 }
1320 }
1321
1322 impl KeyedItem for u8 {
1323 type Key = u8;
1324
1325 fn key(&self) -> Self::Key {
1326 *self
1327 }
1328 }
1329
1330 impl Summary for IntegersSummary {
1331 type Context = ();
1332
1333 fn zero(_cx: &()) -> Self {
1334 Default::default()
1335 }
1336
1337 fn add_summary(&mut self, other: &Self, _: &()) {
1338 self.count += other.count;
1339 self.sum += other.sum;
1340 self.contains_even |= other.contains_even;
1341 self.max = cmp::max(self.max, other.max);
1342 }
1343 }
1344
1345 impl<'a> Dimension<'a, IntegersSummary> for u8 {
1346 fn zero(_cx: &()) -> Self {
1347 Default::default()
1348 }
1349
1350 fn add_summary(&mut self, summary: &IntegersSummary, _: &()) {
1351 *self = summary.max;
1352 }
1353 }
1354
1355 impl<'a> Dimension<'a, IntegersSummary> for Count {
1356 fn zero(_cx: &()) -> Self {
1357 Default::default()
1358 }
1359
1360 fn add_summary(&mut self, summary: &IntegersSummary, _: &()) {
1361 self.0 += summary.count;
1362 }
1363 }
1364
1365 impl<'a> SeekTarget<'a, IntegersSummary, IntegersSummary> for Count {
1366 fn cmp(&self, cursor_location: &IntegersSummary, _: &()) -> Ordering {
1367 self.0.cmp(&cursor_location.count)
1368 }
1369 }
1370
1371 impl<'a> Dimension<'a, IntegersSummary> for Sum {
1372 fn zero(_cx: &()) -> Self {
1373 Default::default()
1374 }
1375
1376 fn add_summary(&mut self, summary: &IntegersSummary, _: &()) {
1377 self.0 += summary.sum;
1378 }
1379 }
1380}