1use super::*;
2use heapless::Vec as ArrayVec;
3use std::{cmp::Ordering, mem, sync::Arc};
4use ztracing::instrument;
5
6#[derive(Clone)]
7struct StackEntry<'a, T: Item, D> {
8 tree: &'a SumTree<T>,
9 index: u32,
10 position: D,
11}
12
13impl<'a, T: Item, D> StackEntry<'a, T, D> {
14 #[inline]
15 fn index(&self) -> usize {
16 self.index as usize
17 }
18}
19
20impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 f.debug_struct("StackEntry")
23 .field("index", &self.index)
24 .field("position", &self.position)
25 .finish()
26 }
27}
28
29#[derive(Clone)]
30pub struct Cursor<'a, 'b, T: Item, D> {
31 tree: &'a SumTree<T>,
32 stack: ArrayVec<StackEntry<'a, T, D>, 16, u8>,
33 pub position: D,
34 did_seek: bool,
35 at_end: bool,
36 cx: <T::Summary as Summary>::Context<'b>,
37}
38
39impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for Cursor<'_, '_, T, D>
40where
41 T::Summary: fmt::Debug,
42{
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("Cursor")
45 .field("tree", &self.tree)
46 .field("stack", &self.stack)
47 .field("position", &self.position)
48 .field("did_seek", &self.did_seek)
49 .field("at_end", &self.at_end)
50 .finish()
51 }
52}
53
54pub struct Iter<'a, T: Item> {
55 tree: &'a SumTree<T>,
56 stack: ArrayVec<StackEntry<'a, T, ()>, 16, u8>,
57}
58
59impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
60where
61 T: Item,
62 D: Dimension<'a, T::Summary>,
63{
64 pub fn new(tree: &'a SumTree<T>, cx: <T::Summary as Summary>::Context<'b>) -> Self {
65 Self {
66 tree,
67 stack: ArrayVec::new(),
68 position: D::zero(cx),
69 did_seek: false,
70 at_end: tree.is_empty(),
71 cx,
72 }
73 }
74
75 pub fn reset(&mut self) {
76 self.did_seek = false;
77 self.at_end = self.tree.is_empty();
78 self.stack.truncate(0);
79 self.position = D::zero(self.cx);
80 }
81
82 pub fn start(&self) -> &D {
83 &self.position
84 }
85
86 #[track_caller]
87 pub fn end(&self) -> D {
88 if let Some(item_summary) = self.item_summary() {
89 let mut end = self.start().clone();
90 end.add_summary(item_summary, self.cx);
91 end
92 } else {
93 self.start().clone()
94 }
95 }
96
97 /// Item is None, when the list is empty, or this cursor is at the end of the list.
98 #[track_caller]
99 pub fn item(&self) -> Option<&'a T> {
100 self.assert_did_seek();
101 if let Some(entry) = self.stack.last() {
102 match *entry.tree.0 {
103 Node::Leaf { ref items, .. } => {
104 if entry.index() == items.len() {
105 None
106 } else {
107 Some(&items[entry.index()])
108 }
109 }
110 _ => unreachable!(),
111 }
112 } else {
113 None
114 }
115 }
116
117 #[track_caller]
118 pub fn item_summary(&self) -> Option<&'a T::Summary> {
119 self.assert_did_seek();
120 if let Some(entry) = self.stack.last() {
121 match *entry.tree.0 {
122 Node::Leaf {
123 ref item_summaries, ..
124 } => {
125 if entry.index() == item_summaries.len() {
126 None
127 } else {
128 Some(&item_summaries[entry.index()])
129 }
130 }
131 _ => unreachable!(),
132 }
133 } else {
134 None
135 }
136 }
137
138 #[track_caller]
139 pub fn next_item(&self) -> Option<&'a T> {
140 self.assert_did_seek();
141 if let Some(entry) = self.stack.last() {
142 if entry.index() == entry.tree.0.items().len() - 1 {
143 if let Some(next_leaf) = self.next_leaf() {
144 Some(next_leaf.0.items().first().unwrap())
145 } else {
146 None
147 }
148 } else {
149 match *entry.tree.0 {
150 Node::Leaf { ref items, .. } => Some(&items[entry.index() + 1]),
151 _ => unreachable!(),
152 }
153 }
154 } else if self.at_end {
155 None
156 } else {
157 self.tree.first()
158 }
159 }
160
161 #[track_caller]
162 fn next_leaf(&self) -> Option<&'a SumTree<T>> {
163 for entry in self.stack.iter().rev().skip(1) {
164 if entry.index() < entry.tree.0.child_trees().len() - 1 {
165 match *entry.tree.0 {
166 Node::Internal {
167 ref child_trees, ..
168 } => return Some(child_trees[entry.index() + 1].leftmost_leaf()),
169 Node::Leaf { .. } => unreachable!(),
170 };
171 }
172 }
173 None
174 }
175
176 #[track_caller]
177 pub fn prev_item(&self) -> Option<&'a T> {
178 self.assert_did_seek();
179 if let Some(entry) = self.stack.last() {
180 if entry.index() == 0 {
181 if let Some(prev_leaf) = self.prev_leaf() {
182 Some(prev_leaf.0.items().last().unwrap())
183 } else {
184 None
185 }
186 } else {
187 match *entry.tree.0 {
188 Node::Leaf { ref items, .. } => Some(&items[entry.index() - 1]),
189 _ => unreachable!(),
190 }
191 }
192 } else if self.at_end {
193 self.tree.last()
194 } else {
195 None
196 }
197 }
198
199 #[track_caller]
200 fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
201 for entry in self.stack.iter().rev().skip(1) {
202 if entry.index() != 0 {
203 match *entry.tree.0 {
204 Node::Internal {
205 ref child_trees, ..
206 } => return Some(child_trees[entry.index() - 1].rightmost_leaf()),
207 Node::Leaf { .. } => unreachable!(),
208 };
209 }
210 }
211 None
212 }
213
214 #[track_caller]
215 #[instrument(skip_all)]
216 pub fn prev(&mut self) {
217 self.search_backward(|_| true)
218 }
219
220 #[track_caller]
221 pub fn search_backward<F>(&mut self, mut filter_node: F)
222 where
223 F: FnMut(&T::Summary) -> bool,
224 {
225 if !self.did_seek {
226 self.did_seek = true;
227 self.at_end = true;
228 }
229
230 if self.at_end {
231 self.position = D::zero(self.cx);
232 self.at_end = self.tree.is_empty();
233 if !self.tree.is_empty() {
234 self.stack
235 .push(StackEntry {
236 tree: self.tree,
237 index: self.tree.0.child_summaries().len() as u32,
238 position: D::from_summary(self.tree.summary(), self.cx),
239 })
240 .unwrap_oob();
241 }
242 }
243
244 let mut descending = false;
245 while !self.stack.is_empty() {
246 if let Some(StackEntry { position, .. }) = self.stack.iter().rev().nth(1) {
247 self.position = position.clone();
248 } else {
249 self.position = D::zero(self.cx);
250 }
251
252 let entry = self.stack.last_mut().unwrap();
253 if !descending {
254 if entry.index() == 0 {
255 self.stack.pop();
256 continue;
257 } else {
258 entry.index -= 1;
259 }
260 }
261
262 for summary in &entry.tree.0.child_summaries()[..entry.index()] {
263 self.position.add_summary(summary, self.cx);
264 }
265 entry.position = self.position.clone();
266
267 descending = filter_node(&entry.tree.0.child_summaries()[entry.index()]);
268 match entry.tree.0.as_ref() {
269 Node::Internal { child_trees, .. } => {
270 if descending {
271 let tree = &child_trees[entry.index()];
272 self.stack
273 .push(StackEntry {
274 position: D::zero(self.cx),
275 tree,
276 index: tree.0.child_summaries().len() as u32 - 1,
277 })
278 .unwrap_oob();
279 }
280 }
281 Node::Leaf { .. } => {
282 if descending {
283 break;
284 }
285 }
286 }
287 }
288 }
289
290 #[track_caller]
291 pub fn next(&mut self) {
292 self.search_forward(|_| true)
293 }
294
295 #[track_caller]
296 pub fn search_forward<F>(&mut self, mut filter_node: F)
297 where
298 F: FnMut(&T::Summary) -> bool,
299 {
300 let mut descend = false;
301
302 if self.stack.is_empty() {
303 if !self.at_end {
304 self.stack
305 .push(StackEntry {
306 tree: self.tree,
307 index: 0,
308 position: D::zero(self.cx),
309 })
310 .unwrap_oob();
311 descend = true;
312 }
313 self.did_seek = true;
314 }
315
316 while !self.stack.is_empty() {
317 let new_subtree = {
318 let entry = self.stack.last_mut().unwrap();
319 match entry.tree.0.as_ref() {
320 Node::Internal {
321 child_trees,
322 child_summaries,
323 ..
324 } => {
325 if !descend {
326 entry.index += 1;
327 entry.position = self.position.clone();
328 }
329
330 while entry.index() < child_summaries.len() {
331 let next_summary = &child_summaries[entry.index()];
332 if filter_node(next_summary) {
333 break;
334 } else {
335 entry.index += 1;
336 entry.position.add_summary(next_summary, self.cx);
337 self.position.add_summary(next_summary, self.cx);
338 }
339 }
340
341 child_trees.get(entry.index())
342 }
343 Node::Leaf { item_summaries, .. } => {
344 if !descend {
345 let item_summary = &item_summaries[entry.index()];
346 entry.index += 1;
347 entry.position.add_summary(item_summary, self.cx);
348 self.position.add_summary(item_summary, self.cx);
349 }
350
351 loop {
352 if let Some(next_item_summary) = item_summaries.get(entry.index()) {
353 if filter_node(next_item_summary) {
354 return;
355 } else {
356 entry.index += 1;
357 entry.position.add_summary(next_item_summary, self.cx);
358 self.position.add_summary(next_item_summary, self.cx);
359 }
360 } else {
361 break None;
362 }
363 }
364 }
365 }
366 };
367
368 if let Some(subtree) = new_subtree {
369 descend = true;
370 self.stack
371 .push(StackEntry {
372 tree: subtree,
373 index: 0,
374 position: self.position.clone(),
375 })
376 .unwrap_oob();
377 } else {
378 descend = false;
379 self.stack.pop();
380 }
381 }
382
383 self.at_end = self.stack.is_empty();
384 debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
385 }
386
387 #[track_caller]
388 fn assert_did_seek(&self) {
389 assert!(
390 self.did_seek,
391 "Must call `seek`, `next` or `prev` before calling this method"
392 );
393 }
394
395 pub fn did_seek(&self) -> bool {
396 self.did_seek
397 }
398}
399
400impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
401where
402 T: Item,
403 D: Dimension<'a, T::Summary>,
404{
405 /// Returns whether we found the item you were seeking for.
406 #[track_caller]
407 #[instrument(skip_all)]
408 pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool
409 where
410 Target: SeekTarget<'a, T::Summary, D>,
411 {
412 self.reset();
413 self.seek_internal(pos, bias, &mut ())
414 }
415
416 /// Returns whether we found the item you were seeking for.
417 ///
418 /// # Panics
419 ///
420 /// If we did not seek before, use seek instead in that case.
421 #[track_caller]
422 #[instrument(skip_all)]
423 pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool
424 where
425 Target: SeekTarget<'a, T::Summary, D>,
426 {
427 self.seek_internal(pos, bias, &mut ())
428 }
429
430 /// Advances the cursor and returns traversed items as a tree.
431 #[track_caller]
432 pub fn slice<Target>(&mut self, end: &Target, bias: Bias) -> SumTree<T>
433 where
434 Target: SeekTarget<'a, T::Summary, D>,
435 {
436 let mut slice = SliceSeekAggregate {
437 tree: SumTree::new(self.cx),
438 leaf_items: ArrayVec::new(),
439 leaf_item_summaries: ArrayVec::new(),
440 leaf_summary: <T::Summary as Summary>::zero(self.cx),
441 };
442 self.seek_internal(end, bias, &mut slice);
443 slice.tree
444 }
445
446 #[track_caller]
447 pub fn suffix(&mut self) -> SumTree<T> {
448 self.slice(&End::new(), Bias::Right)
449 }
450
451 #[track_caller]
452 pub fn summary<Target, Output>(&mut self, end: &Target, bias: Bias) -> Output
453 where
454 Target: SeekTarget<'a, T::Summary, D>,
455 Output: Dimension<'a, T::Summary>,
456 {
457 let mut summary = SummarySeekAggregate(Output::zero(self.cx));
458 self.seek_internal(end, bias, &mut summary);
459 summary.0
460 }
461
462 /// Returns whether we found the item you were seeking for.
463 #[track_caller]
464 #[instrument(skip_all)]
465 fn seek_internal(
466 &mut self,
467 target: &dyn SeekTarget<'a, T::Summary, D>,
468 bias: Bias,
469 aggregate: &mut dyn SeekAggregate<'a, T>,
470 ) -> bool {
471 assert!(
472 target.cmp(&self.position, self.cx).is_ge(),
473 "cannot seek backward",
474 );
475
476 if !self.did_seek {
477 self.did_seek = true;
478 self.stack
479 .push(StackEntry {
480 tree: self.tree,
481 index: 0,
482 position: D::zero(self.cx),
483 })
484 .unwrap_oob();
485 }
486
487 let mut ascending = false;
488 'outer: while let Some(entry) = self.stack.last_mut() {
489 match *entry.tree.0 {
490 Node::Internal {
491 ref child_summaries,
492 ref child_trees,
493 ..
494 } => {
495 if ascending {
496 entry.index += 1;
497 entry.position = self.position.clone();
498 }
499
500 for (child_tree, child_summary) in child_trees[entry.index()..]
501 .iter()
502 .zip(&child_summaries[entry.index()..])
503 {
504 let mut child_end = self.position.clone();
505 child_end.add_summary(child_summary, self.cx);
506
507 let comparison = target.cmp(&child_end, self.cx);
508 if comparison == Ordering::Greater
509 || (comparison == Ordering::Equal && bias == Bias::Right)
510 {
511 self.position = child_end;
512 aggregate.push_tree(child_tree, child_summary, self.cx);
513 entry.index += 1;
514 entry.position = self.position.clone();
515 } else {
516 self.stack
517 .push(StackEntry {
518 tree: child_tree,
519 index: 0,
520 position: self.position.clone(),
521 })
522 .unwrap_oob();
523 ascending = false;
524 continue 'outer;
525 }
526 }
527 }
528 Node::Leaf {
529 ref items,
530 ref item_summaries,
531 ..
532 } => {
533 aggregate.begin_leaf();
534
535 for (item, item_summary) in items[entry.index()..]
536 .iter()
537 .zip(&item_summaries[entry.index()..])
538 {
539 let mut child_end = self.position.clone();
540 child_end.add_summary(item_summary, self.cx);
541
542 let comparison = target.cmp(&child_end, self.cx);
543 if comparison == Ordering::Greater
544 || (comparison == Ordering::Equal && bias == Bias::Right)
545 {
546 self.position = child_end;
547 aggregate.push_item(item, item_summary, self.cx);
548 entry.index += 1;
549 } else {
550 aggregate.end_leaf(self.cx);
551 break 'outer;
552 }
553 }
554
555 aggregate.end_leaf(self.cx);
556 }
557 }
558
559 self.stack.pop();
560 ascending = true;
561 }
562
563 self.at_end = self.stack.is_empty();
564 debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
565
566 let mut end = self.position.clone();
567 if bias == Bias::Left
568 && let Some(summary) = self.item_summary()
569 {
570 end.add_summary(summary, self.cx);
571 }
572
573 target.cmp(&end, self.cx) == Ordering::Equal
574 }
575}
576
577impl<'a, T: Item> Iter<'a, T> {
578 pub(crate) fn new(tree: &'a SumTree<T>) -> Self {
579 Self {
580 tree,
581 stack: Default::default(),
582 }
583 }
584}
585
586impl<'a, T: Item> Iterator for Iter<'a, T> {
587 type Item = &'a T;
588
589 fn next(&mut self) -> Option<Self::Item> {
590 let mut descend = false;
591
592 if self.stack.is_empty() {
593 self.stack
594 .push(StackEntry {
595 tree: self.tree,
596 index: 0,
597 position: (),
598 })
599 .unwrap_oob();
600 descend = true;
601 }
602
603 while let Some(entry) = self.stack.last_mut() {
604 let new_subtree = {
605 match entry.tree.0.as_ref() {
606 Node::Internal { child_trees, .. } => {
607 if !descend {
608 entry.index += 1;
609 }
610 child_trees.get(entry.index())
611 }
612 Node::Leaf { items, .. } => {
613 if !descend {
614 entry.index += 1;
615 }
616
617 if let Some(next_item) = items.get(entry.index()) {
618 return Some(next_item);
619 } else {
620 None
621 }
622 }
623 }
624 };
625
626 if let Some(subtree) = new_subtree {
627 descend = true;
628 self.stack
629 .push(StackEntry {
630 tree: subtree,
631 index: 0,
632 position: (),
633 })
634 .unwrap_oob();
635 } else {
636 descend = false;
637 self.stack.pop();
638 }
639 }
640
641 None
642 }
643
644 fn last(mut self) -> Option<Self::Item> {
645 self.stack.clear();
646 self.tree.rightmost_leaf().last()
647 }
648
649 fn size_hint(&self) -> (usize, Option<usize>) {
650 let lower_bound = match self.stack.last() {
651 Some(top) => top.tree.0.child_summaries().len() - top.index as usize,
652 None => self.tree.0.child_summaries().len(),
653 };
654
655 (lower_bound, None)
656 }
657}
658
659impl<'a, 'b, T: Item, D> Iterator for Cursor<'a, 'b, T, D>
660where
661 D: Dimension<'a, T::Summary>,
662{
663 type Item = &'a T;
664
665 fn next(&mut self) -> Option<Self::Item> {
666 if !self.did_seek {
667 self.next();
668 }
669
670 if let Some(item) = self.item() {
671 self.next();
672 Some(item)
673 } else {
674 None
675 }
676 }
677}
678
679pub struct FilterCursor<'a, 'b, F, T: Item, D> {
680 cursor: Cursor<'a, 'b, T, D>,
681 filter_node: F,
682}
683
684impl<'a, 'b, F, T: Item, D> FilterCursor<'a, 'b, F, T, D>
685where
686 F: FnMut(&T::Summary) -> bool,
687 T: Item,
688 D: Dimension<'a, T::Summary>,
689{
690 pub fn new(
691 tree: &'a SumTree<T>,
692 cx: <T::Summary as Summary>::Context<'b>,
693 filter_node: F,
694 ) -> Self {
695 let cursor = tree.cursor::<D>(cx);
696 Self {
697 cursor,
698 filter_node,
699 }
700 }
701
702 pub fn start(&self) -> &D {
703 self.cursor.start()
704 }
705
706 pub fn end(&self) -> D {
707 self.cursor.end()
708 }
709
710 pub fn item(&self) -> Option<&'a T> {
711 self.cursor.item()
712 }
713
714 pub fn item_summary(&self) -> Option<&'a T::Summary> {
715 self.cursor.item_summary()
716 }
717
718 pub fn next(&mut self) {
719 self.cursor.search_forward(&mut self.filter_node);
720 }
721
722 pub fn prev(&mut self) {
723 self.cursor.search_backward(&mut self.filter_node);
724 }
725}
726
727impl<'a, 'b, F, T: Item, U> Iterator for FilterCursor<'a, 'b, F, T, U>
728where
729 F: FnMut(&T::Summary) -> bool,
730 U: Dimension<'a, T::Summary>,
731{
732 type Item = &'a T;
733
734 fn next(&mut self) -> Option<Self::Item> {
735 if !self.cursor.did_seek {
736 self.next();
737 }
738
739 if let Some(item) = self.item() {
740 self.cursor.search_forward(&mut self.filter_node);
741 Some(item)
742 } else {
743 None
744 }
745 }
746}
747
748trait SeekAggregate<'a, T: Item> {
749 fn begin_leaf(&mut self);
750 fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>);
751 fn push_item(
752 &mut self,
753 item: &'a T,
754 summary: &'a T::Summary,
755 cx: <T::Summary as Summary>::Context<'_>,
756 );
757 fn push_tree(
758 &mut self,
759 tree: &'a SumTree<T>,
760 summary: &'a T::Summary,
761 cx: <T::Summary as Summary>::Context<'_>,
762 );
763}
764
765struct SliceSeekAggregate<T: Item> {
766 tree: SumTree<T>,
767 leaf_items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
768 leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
769 leaf_summary: T::Summary,
770}
771
/// Accumulates the summaries of everything a seek moves past into `D`.
struct SummarySeekAggregate<D>(D);
773
774impl<T: Item> SeekAggregate<'_, T> for () {
775 fn begin_leaf(&mut self) {}
776 fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
777 fn push_item(&mut self, _: &T, _: &T::Summary, _: <T::Summary as Summary>::Context<'_>) {}
778 fn push_tree(
779 &mut self,
780 _: &SumTree<T>,
781 _: &T::Summary,
782 _: <T::Summary as Summary>::Context<'_>,
783 ) {
784 }
785}
786
787impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
788 fn begin_leaf(&mut self) {}
789 fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>) {
790 self.tree.append(
791 SumTree(Arc::new(Node::Leaf {
792 summary: mem::replace(&mut self.leaf_summary, <T::Summary as Summary>::zero(cx)),
793 items: mem::take(&mut self.leaf_items),
794 item_summaries: mem::take(&mut self.leaf_item_summaries),
795 })),
796 cx,
797 );
798 }
799 fn push_item(
800 &mut self,
801 item: &T,
802 summary: &T::Summary,
803 cx: <T::Summary as Summary>::Context<'_>,
804 ) {
805 self.leaf_items.push(item.clone()).unwrap_oob();
806 self.leaf_item_summaries.push(summary.clone()).unwrap_oob();
807 Summary::add_summary(&mut self.leaf_summary, summary, cx);
808 }
809 fn push_tree(
810 &mut self,
811 tree: &SumTree<T>,
812 _: &T::Summary,
813 cx: <T::Summary as Summary>::Context<'_>,
814 ) {
815 self.tree.append(tree.clone(), cx);
816 }
817}
818
819impl<'a, T: Item, D> SeekAggregate<'a, T> for SummarySeekAggregate<D>
820where
821 D: Dimension<'a, T::Summary>,
822{
823 fn begin_leaf(&mut self) {}
824 fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
825 fn push_item(
826 &mut self,
827 _: &T,
828 summary: &'a T::Summary,
829 cx: <T::Summary as Summary>::Context<'_>,
830 ) {
831 self.0.add_summary(summary, cx);
832 }
833 fn push_tree(
834 &mut self,
835 _: &SumTree<T>,
836 summary: &'a T::Summary,
837 cx: <T::Summary as Summary>::Context<'_>,
838 ) {
839 self.0.add_summary(summary, cx);
840 }
841}
842
/// Seek target that compares greater than every position, so seeking to it
/// runs to the end of the tree.
struct End<D>(PhantomData<D>);
844
845impl<D> End<D> {
846 fn new() -> Self {
847 Self(PhantomData)
848 }
849}
850
851impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> {
852 fn cmp(&self, _: &D, _: S::Context<'_>) -> Ordering {
853 Ordering::Greater
854 }
855}
856
857impl<D> fmt::Debug for End<D> {
858 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
859 f.debug_tuple("End").finish()
860 }
861}