cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, mem, sync::Arc};
  4use ztracing::instrument;
  5
  6#[derive(Clone)]
  7struct StackEntry<'a, T: Item, D> {
  8    tree: &'a SumTree<T>,
  9    index: u32,
 10    position: D,
 11}
 12
 13impl<'a, T: Item, D> StackEntry<'a, T, D> {
 14    #[inline]
 15    fn index(&self) -> usize {
 16        self.index as usize
 17    }
 18}
 19
 20impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
 21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 22        f.debug_struct("StackEntry")
 23            .field("index", &self.index)
 24            .field("position", &self.position)
 25            .finish()
 26    }
 27}
 28
 29#[derive(Clone)]
 30pub struct Cursor<'a, 'b, T: Item, D> {
 31    tree: &'a SumTree<T>,
 32    stack: ArrayVec<StackEntry<'a, T, D>, 16>,
 33    position: D,
 34    did_seek: bool,
 35    at_end: bool,
 36    cx: <T::Summary as Summary>::Context<'b>,
 37}
 38
 39impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for Cursor<'_, '_, T, D>
 40where
 41    T::Summary: fmt::Debug,
 42{
 43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 44        f.debug_struct("Cursor")
 45            .field("tree", &self.tree)
 46            .field("stack", &self.stack)
 47            .field("position", &self.position)
 48            .field("did_seek", &self.did_seek)
 49            .field("at_end", &self.at_end)
 50            .finish()
 51    }
 52}
 53
 54pub struct Iter<'a, T: Item> {
 55    tree: &'a SumTree<T>,
 56    stack: ArrayVec<StackEntry<'a, T, ()>, 16>,
 57}
 58
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
where
    T: Item,
    D: Dimension<'a, T::Summary>,
{
    /// Creates a cursor over `tree` that has not been positioned yet.
    ///
    /// The stack starts empty and the position starts at `D::zero`. Call `seek`,
    /// `next` or `prev` (or one of their variants) before using `item`,
    /// `item_summary`, `next_item`, `prev_item` or `end`, which assert that the
    /// cursor has been positioned.
    pub fn new(tree: &'a SumTree<T>, cx: <T::Summary as Summary>::Context<'b>) -> Self {
        Self {
            tree,
            stack: ArrayVec::new(),
            position: D::zero(cx),
            did_seek: false,
            at_end: tree.is_empty(),
            cx,
        }
    }

    /// Returns the cursor to the same unpositioned state that `new` produces.
    fn reset(&mut self) {
        self.did_seek = false;
        self.at_end = self.tree.is_empty();
        self.stack.truncate(0);
        self.position = D::zero(self.cx);
    }

    /// Position at the start of the current item, i.e. the dimension accumulated
    /// over everything before it.
    pub fn start(&self) -> &D {
        &self.position
    }

    /// Position just past the current item: `start()` plus the current item's
    /// summary, or `start()` itself when there is no current item.
    ///
    /// # Panics
    ///
    /// If the cursor has not been positioned yet (checked by `item_summary`).
    #[track_caller]
    pub fn end(&self) -> D {
        if let Some(item_summary) = self.item_summary() {
            let mut end = self.start().clone();
            end.add_summary(item_summary, self.cx);
            end
        } else {
            self.start().clone()
        }
    }

    /// Item is None, when the list is empty, or this cursor is at the end of the list.
    ///
    /// More precisely, the result is `None` whenever the stack is empty. That covers an
    /// empty tree, a cursor past the last item, and a cursor that stepped back before
    /// the first item.
    ///
    /// # Panics
    ///
    /// If the cursor has not been positioned yet.
    #[track_caller]
    pub fn item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            // Once positioned, the top of the stack is always a leaf.
            match *entry.tree.0 {
                Node::Leaf { ref items, .. } => {
                    // An index one past the leaf's last item means "no current item".
                    if entry.index() == items.len() {
                        None
                    } else {
                        Some(&items[entry.index()])
                    }
                }
                _ => unreachable!(),
            }
        } else {
            None
        }
    }

    /// Summary of the current item; `None` in the same cases as [`Self::item`].
    ///
    /// # Panics
    ///
    /// If the cursor has not been positioned yet.
    #[track_caller]
    pub fn item_summary(&self) -> Option<&'a T::Summary> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            match *entry.tree.0 {
                Node::Leaf {
                    ref item_summaries, ..
                } => {
                    if entry.index() == item_summaries.len() {
                        None
                    } else {
                        Some(&item_summaries[entry.index()])
                    }
                }
                _ => unreachable!(),
            }
        } else {
            None
        }
    }

    /// Returns the item after the current one without moving the cursor.
    ///
    /// With an empty stack the answer depends on `at_end`. Past the end there is no next
    /// item. Otherwise the cursor is not on any item yet, and the next item is the tree's
    /// first item.
    ///
    /// # Panics
    ///
    /// If the cursor has not been positioned yet.
    #[track_caller]
    pub fn next_item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            // On the leaf's last item, the next item is the first one of the next leaf.
            if entry.index() == entry.tree.0.items().len() - 1 {
                if let Some(next_leaf) = self.next_leaf() {
                    Some(next_leaf.0.items().first().unwrap())
                } else {
                    None
                }
            } else {
                match *entry.tree.0 {
                    Node::Leaf { ref items, .. } => Some(&items[entry.index() + 1]),
                    _ => unreachable!(),
                }
            }
        } else if self.at_end {
            None
        } else {
            self.tree.first()
        }
    }

    /// Finds the leaf that follows the current one.
    ///
    /// Walks the ancestors of the current leaf from the innermost outward (the leaf
    /// itself is skipped). The first ancestor with a child to the right of the current
    /// path yields that sibling's leftmost leaf. Returns `None` when the current leaf is
    /// the tree's last.
    #[track_caller]
    fn next_leaf(&self) -> Option<&'a SumTree<T>> {
        for entry in self.stack.iter().rev().skip(1) {
            if entry.index() < entry.tree.0.child_trees().len() - 1 {
                match *entry.tree.0 {
                    Node::Internal {
                        ref child_trees, ..
                    } => return Some(child_trees[entry.index() + 1].leftmost_leaf()),
                    Node::Leaf { .. } => unreachable!(),
                };
            }
        }
        None
    }

    /// Returns the item before the current one without moving the cursor.
    ///
    /// With an empty stack the answer depends on `at_end`. Past the end, the previous
    /// item is the tree's last item. Otherwise there is none.
    ///
    /// # Panics
    ///
    /// If the cursor has not been positioned yet.
    #[track_caller]
    pub fn prev_item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            // On the leaf's first item, the previous item is the last one of the
            // previous leaf.
            if entry.index() == 0 {
                if let Some(prev_leaf) = self.prev_leaf() {
                    Some(prev_leaf.0.items().last().unwrap())
                } else {
                    None
                }
            } else {
                match *entry.tree.0 {
                    Node::Leaf { ref items, .. } => Some(&items[entry.index() - 1]),
                    _ => unreachable!(),
                }
            }
        } else if self.at_end {
            self.tree.last()
        } else {
            None
        }
    }

    /// Finds the leaf that precedes the current one.
    ///
    /// Mirror image of `next_leaf`: the first ancestor (innermost outward) with a child
    /// to the left of the current path yields that sibling's rightmost leaf. Returns
    /// `None` when the current leaf is the tree's first.
    #[track_caller]
    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
        for entry in self.stack.iter().rev().skip(1) {
            if entry.index() != 0 {
                match *entry.tree.0 {
                    Node::Internal {
                        ref child_trees, ..
                    } => return Some(child_trees[entry.index() - 1].rightmost_leaf()),
                    Node::Leaf { .. } => unreachable!(),
                };
            }
        }
        None
    }

    /// Moves to the previous item; `search_backward` with a filter that accepts
    /// every node.
    #[track_caller]
    #[instrument(skip_all)]
    pub fn prev(&mut self) {
        self.search_backward(|_| true)
    }

    /// Moves backward to the nearest preceding item that `filter_node` accepts.
    ///
    /// An item qualifies when the filter accepts its own summary and the summary of
    /// every subtree below the root that contains it. Rejected subtrees are skipped
    /// without being entered.
    ///
    /// An unpositioned cursor, or one at the end, starts the search from past the last
    /// item. If no accepted item precedes the cursor, the stack empties, the position
    /// ends at `D::zero`, and `at_end` stays false.
    #[track_caller]
    pub fn search_backward<F>(&mut self, mut filter_node: F)
    where
        F: FnMut(&T::Summary) -> bool,
    {
        if !self.did_seek {
            self.did_seek = true;
            self.at_end = true;
        }

        // Start just past the root's last child. The root entry records the whole
        // tree's summary as its position.
        if self.at_end {
            self.position = D::zero(self.cx);
            self.at_end = self.tree.is_empty();
            if !self.tree.is_empty() {
                self.stack.push(StackEntry {
                    tree: self.tree,
                    index: self.tree.0.child_summaries().len() as u32,
                    position: D::from_summary(self.tree.summary(), self.cx),
                });
            }
        }

        let mut descending = false;
        while !self.stack.is_empty() {
            // Restart `position` from the parent's recorded position (zero at the root).
            if let Some(StackEntry { position, .. }) = self.stack.iter().rev().nth(1) {
                self.position = position.clone();
            } else {
                self.position = D::zero(self.cx);
            }

            let entry = self.stack.last_mut().unwrap();
            // A node we did not just enter steps to its previous child. A node with
            // no children left is finished, so pop it and retry in the parent.
            if !descending {
                if entry.index() == 0 {
                    self.stack.pop();
                    continue;
                } else {
                    entry.index -= 1;
                }
            }

            // Add the summaries of the siblings before `index` to reach its start.
            for summary in &entry.tree.0.child_summaries()[..entry.index()] {
                self.position.add_summary(summary, self.cx);
            }
            entry.position = self.position.clone();

            // Only children/items the filter accepts are entered. On a leaf, acceptance
            // means the target item has been found.
            descending = filter_node(&entry.tree.0.child_summaries()[entry.index()]);
            match entry.tree.0.as_ref() {
                Node::Internal { child_trees, .. } => {
                    if descending {
                        // Enter the child at its last entry; its position is filled in
                        // on the next iteration.
                        let tree = &child_trees[entry.index()];
                        self.stack.push(StackEntry {
                            position: D::zero(self.cx),
                            tree,
                            index: tree.0.child_summaries().len() as u32 - 1,
                        })
                    }
                }
                Node::Leaf { .. } => {
                    if descending {
                        break;
                    }
                }
            }
        }
    }

    /// Moves to the next item; `search_forward` with a filter that accepts every node.
    /// On a cursor that is not on any item yet (and not at the end), this moves onto
    /// the first item.
    #[track_caller]
    pub fn next(&mut self) {
        self.search_forward(|_| true)
    }

    /// Moves forward to the nearest following item that `filter_node` accepts.
    ///
    /// An item qualifies when the filter accepts its own summary and the summary of
    /// every subtree below the root that contains it. Rejected subtrees and items are
    /// skipped, but their summaries are still added to the position. If no accepted
    /// item remains, the stack empties and `at_end` becomes true.
    ///
    /// A cursor not on any item yet starts from the root. Its first accepted item
    /// becomes current; it is not stepped over.
    #[track_caller]
    pub fn search_forward<F>(&mut self, mut filter_node: F)
    where
        F: FnMut(&T::Summary) -> bool,
    {
        // `descend` is true right after entering a node, so its first child/item is
        // examined rather than stepped over.
        let mut descend = false;

        // An empty stack means the cursor is either not on any item yet (start from
        // the root) or already at the end (nothing to do).
        if self.stack.is_empty() {
            if !self.at_end {
                self.stack.push(StackEntry {
                    tree: self.tree,
                    index: 0,
                    position: D::zero(self.cx),
                });
                descend = true;
            }
            self.did_seek = true;
        }

        while !self.stack.is_empty() {
            let new_subtree = {
                let entry = self.stack.last_mut().unwrap();
                match entry.tree.0.as_ref() {
                    Node::Internal {
                        child_trees,
                        child_summaries,
                        ..
                    } => {
                        // Coming back up from a finished child: move to its sibling.
                        if !descend {
                            entry.index += 1;
                            entry.position = self.position.clone();
                        }

                        // Skip children the filter rejects, accumulating their summaries.
                        while entry.index() < child_summaries.len() {
                            let next_summary = &child_summaries[entry.index()];
                            if filter_node(next_summary) {
                                break;
                            } else {
                                entry.index += 1;
                                entry.position.add_summary(next_summary, self.cx);
                                self.position.add_summary(next_summary, self.cx);
                            }
                        }

                        // `None` once every child has been visited.
                        child_trees.get(entry.index())
                    }
                    Node::Leaf { item_summaries, .. } => {
                        // Resuming inside this leaf: step past the current item first.
                        if !descend {
                            let item_summary = &item_summaries[entry.index()];
                            entry.index += 1;
                            entry.position.add_summary(item_summary, self.cx);
                            self.position.add_summary(item_summary, self.cx);
                        }

                        // Stop on the first accepted item, leaving it current; a leaf
                        // with none left yields `None` so it gets popped.
                        loop {
                            if let Some(next_item_summary) = item_summaries.get(entry.index()) {
                                if filter_node(next_item_summary) {
                                    return;
                                } else {
                                    entry.index += 1;
                                    entry.position.add_summary(next_item_summary, self.cx);
                                    self.position.add_summary(next_item_summary, self.cx);
                                }
                            } else {
                                break None;
                            }
                        }
                    }
                }
            };

            // Enter the chosen child, or pop the finished node and resume its parent.
            if let Some(subtree) = new_subtree {
                descend = true;
                self.stack.push(StackEntry {
                    tree: subtree,
                    index: 0,
                    position: self.position.clone(),
                });
            } else {
                descend = false;
                self.stack.pop();
            }
        }

        self.at_end = self.stack.is_empty();
        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
    }

    /// Panics (reporting the caller's location) unless the cursor has been positioned.
    #[track_caller]
    fn assert_did_seek(&self) {
        assert!(
            self.did_seek,
            "Must call `seek`, `next` or `prev` before calling this method"
        );
    }

    /// Whether a seek/next/prev has positioned this cursor since creation or the
    /// last reset.
    pub fn did_seek(&self) -> bool {
        self.did_seek
    }
}
391
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
where
    T: Item,
    D: Dimension<'a, T::Summary>,
{
    /// Resets the cursor, then moves it forward from the start of the tree to `pos`.
    ///
    /// Returns whether we found the item you were seeking for. The test is `pos`
    /// comparing equal to the resulting position; with `Bias::Left`, the position used
    /// is the end of the item the cursor stops on.
    #[track_caller]
    #[instrument(skip_all)]
    pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        self.reset();
        self.seek_internal(pos, bias, &mut ())
    }

    /// Moves the cursor forward from its current position to `pos` without
    /// resetting it first.
    ///
    /// Returns whether we found the item you were seeking for, using the same rule
    /// as `seek`.
    ///
    /// # Panics
    ///
    /// With "cannot seek backward" if `pos` is before the cursor's current position.
    /// On a cursor that has not been positioned yet, the traversal simply starts from
    /// the root of the tree.
    #[track_caller]
    #[instrument(skip_all)]
    pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        self.seek_internal(pos, bias, &mut ())
    }

    /// Advances the cursor to `end` and returns the traversed items as a new tree.
    ///
    /// Whole subtrees passed over are appended as-is. Items from leaves the seek walks
    /// through are copied into freshly built leaves.
    #[track_caller]
    pub fn slice<Target>(&mut self, end: &Target, bias: Bias) -> SumTree<T>
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        let mut slice = SliceSeekAggregate {
            tree: SumTree::new(self.cx),
            leaf_items: ArrayVec::new(),
            leaf_item_summaries: ArrayVec::new(),
            leaf_summary: <T::Summary as Summary>::zero(self.cx),
        };
        self.seek_internal(end, bias, &mut slice);
        slice.tree
    }

    /// Returns every item from the current position to the end of the tree as a new
    /// tree, leaving the cursor at the end.
    #[track_caller]
    pub fn suffix(&mut self) -> SumTree<T> {
        self.slice(&End::new(), Bias::Right)
    }

    /// Advances the cursor to `end` and returns the summary of the traversed region,
    /// measured in the dimension `Output`.
    #[track_caller]
    pub fn summary<Target, Output>(&mut self, end: &Target, bias: Bias) -> Output
    where
        Target: SeekTarget<'a, T::Summary, D>,
        Output: Dimension<'a, T::Summary>,
    {
        let mut summary = SummarySeekAggregate(Output::zero(self.cx));
        self.seek_internal(end, bias, &mut summary);
        summary.0
    }

    /// Returns whether we found the item you were seeking for.
    ///
    /// Shared engine behind `seek`, `seek_forward`, `slice` and `summary`. It moves
    /// forward past every child/item that ends before `target`, and also past one that
    /// ends exactly at `target` when `bias` is `Bias::Right`. Everything passed over is
    /// reported to `aggregate`.
    ///
    /// # Panics
    ///
    /// With "cannot seek backward" if `target` is before the current position.
    #[track_caller]
    #[instrument(skip_all)]
    fn seek_internal(
        &mut self,
        target: &dyn SeekTarget<'a, T::Summary, D>,
        bias: Bias,
        aggregate: &mut dyn SeekAggregate<'a, T>,
    ) -> bool {
        assert!(
            target.cmp(&self.position, self.cx).is_ge(),
            "cannot seek backward",
        );

        // An unpositioned cursor starts the traversal at the root.
        if !self.did_seek {
            self.did_seek = true;
            self.stack.push(StackEntry {
                tree: self.tree,
                index: 0,
                position: D::zero(self.cx),
            });
        }

        // `ascending` is true right after a finished child was popped, so the parent
        // must first step past that child.
        let mut ascending = false;
        'outer: while let Some(entry) = self.stack.last_mut() {
            match *entry.tree.0 {
                Node::Internal {
                    ref child_summaries,
                    ref child_trees,
                    ..
                } => {
                    if ascending {
                        entry.index += 1;
                        entry.position = self.position.clone();
                    }

                    // Skip whole children that end before the target (or exactly at it
                    // with a right bias); descend into the first one that doesn't.
                    for (child_tree, child_summary) in child_trees[entry.index()..]
                        .iter()
                        .zip(&child_summaries[entry.index()..])
                    {
                        let mut child_end = self.position.clone();
                        child_end.add_summary(child_summary, self.cx);

                        let comparison = target.cmp(&child_end, self.cx);
                        if comparison == Ordering::Greater
                            || (comparison == Ordering::Equal && bias == Bias::Right)
                        {
                            self.position = child_end;
                            aggregate.push_tree(child_tree, child_summary, self.cx);
                            entry.index += 1;
                            entry.position = self.position.clone();
                        } else {
                            self.stack.push(StackEntry {
                                tree: child_tree,
                                index: 0,
                                position: self.position.clone(),
                            });
                            ascending = false;
                            continue 'outer;
                        }
                    }
                }
                Node::Leaf {
                    ref items,
                    ref item_summaries,
                    ..
                } => {
                    aggregate.begin_leaf();

                    // Same rule per item. Stopping inside the leaf leaves the stopping
                    // item as the current item and ends the whole seek.
                    for (item, item_summary) in items[entry.index()..]
                        .iter()
                        .zip(&item_summaries[entry.index()..])
                    {
                        let mut child_end = self.position.clone();
                        child_end.add_summary(item_summary, self.cx);

                        let comparison = target.cmp(&child_end, self.cx);
                        if comparison == Ordering::Greater
                            || (comparison == Ordering::Equal && bias == Bias::Right)
                        {
                            self.position = child_end;
                            aggregate.push_item(item, item_summary, self.cx);
                            entry.index += 1;
                        } else {
                            aggregate.end_leaf(self.cx);
                            break 'outer;
                        }
                    }

                    aggregate.end_leaf(self.cx);
                }
            }

            // This node was passed entirely; resume in its parent.
            self.stack.pop();
            ascending = true;
        }

        // An empty stack means every remaining item was passed.
        self.at_end = self.stack.is_empty();
        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());

        // With a left bias the match is tested against the end of the current item.
        let mut end = self.position.clone();
        if bias == Bias::Left
            && let Some(summary) = self.item_summary()
        {
            end.add_summary(summary, self.cx);
        }

        target.cmp(&end, self.cx) == Ordering::Equal
    }
}
564
565impl<'a, T: Item> Iter<'a, T> {
566    pub(crate) fn new(tree: &'a SumTree<T>) -> Self {
567        Self {
568            tree,
569            stack: Default::default(),
570        }
571    }
572}
573
574impl<'a, T: Item> Iterator for Iter<'a, T> {
575    type Item = &'a T;
576
577    fn next(&mut self) -> Option<Self::Item> {
578        let mut descend = false;
579
580        if self.stack.is_empty() {
581            self.stack.push(StackEntry {
582                tree: self.tree,
583                index: 0,
584                position: (),
585            });
586            descend = true;
587        }
588
589        while !self.stack.is_empty() {
590            let new_subtree = {
591                let entry = self.stack.last_mut().unwrap();
592                match entry.tree.0.as_ref() {
593                    Node::Internal { child_trees, .. } => {
594                        if !descend {
595                            entry.index += 1;
596                        }
597                        child_trees.get(entry.index())
598                    }
599                    Node::Leaf { items, .. } => {
600                        if !descend {
601                            entry.index += 1;
602                        }
603
604                        if let Some(next_item) = items.get(entry.index()) {
605                            return Some(next_item);
606                        } else {
607                            None
608                        }
609                    }
610                }
611            };
612
613            if let Some(subtree) = new_subtree {
614                descend = true;
615                self.stack.push(StackEntry {
616                    tree: subtree,
617                    index: 0,
618                    position: (),
619                });
620            } else {
621                descend = false;
622                self.stack.pop();
623            }
624        }
625
626        None
627    }
628}
629
630impl<'a, 'b, T: Item, D> Iterator for Cursor<'a, 'b, T, D>
631where
632    D: Dimension<'a, T::Summary>,
633{
634    type Item = &'a T;
635
636    fn next(&mut self) -> Option<Self::Item> {
637        if !self.did_seek {
638            self.next();
639        }
640
641        if let Some(item) = self.item() {
642            self.next();
643            Some(item)
644        } else {
645            None
646        }
647    }
648}
649
650pub struct FilterCursor<'a, 'b, F, T: Item, D> {
651    cursor: Cursor<'a, 'b, T, D>,
652    filter_node: F,
653}
654
655impl<'a, 'b, F, T: Item, D> FilterCursor<'a, 'b, F, T, D>
656where
657    F: FnMut(&T::Summary) -> bool,
658    T: Item,
659    D: Dimension<'a, T::Summary>,
660{
661    pub fn new(
662        tree: &'a SumTree<T>,
663        cx: <T::Summary as Summary>::Context<'b>,
664        filter_node: F,
665    ) -> Self {
666        let cursor = tree.cursor::<D>(cx);
667        Self {
668            cursor,
669            filter_node,
670        }
671    }
672
673    pub fn start(&self) -> &D {
674        self.cursor.start()
675    }
676
677    pub fn end(&self) -> D {
678        self.cursor.end()
679    }
680
681    pub fn item(&self) -> Option<&'a T> {
682        self.cursor.item()
683    }
684
685    pub fn item_summary(&self) -> Option<&'a T::Summary> {
686        self.cursor.item_summary()
687    }
688
689    pub fn next(&mut self) {
690        self.cursor.search_forward(&mut self.filter_node);
691    }
692
693    pub fn prev(&mut self) {
694        self.cursor.search_backward(&mut self.filter_node);
695    }
696}
697
698impl<'a, 'b, F, T: Item, U> Iterator for FilterCursor<'a, 'b, F, T, U>
699where
700    F: FnMut(&T::Summary) -> bool,
701    U: Dimension<'a, T::Summary>,
702{
703    type Item = &'a T;
704
705    fn next(&mut self) -> Option<Self::Item> {
706        if !self.cursor.did_seek {
707            self.next();
708        }
709
710        if let Some(item) = self.item() {
711            self.cursor.search_forward(&mut self.filter_node);
712            Some(item)
713        } else {
714            None
715        }
716    }
717}
718
719trait SeekAggregate<'a, T: Item> {
720    fn begin_leaf(&mut self);
721    fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>);
722    fn push_item(
723        &mut self,
724        item: &'a T,
725        summary: &'a T::Summary,
726        cx: <T::Summary as Summary>::Context<'_>,
727    );
728    fn push_tree(
729        &mut self,
730        tree: &'a SumTree<T>,
731        summary: &'a T::Summary,
732        cx: <T::Summary as Summary>::Context<'_>,
733    );
734}
735
736struct SliceSeekAggregate<T: Item> {
737    tree: SumTree<T>,
738    leaf_items: ArrayVec<T, { 2 * TREE_BASE }>,
739    leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
740    leaf_summary: T::Summary,
741}
742
743struct SummarySeekAggregate<D>(D);
744
745impl<T: Item> SeekAggregate<'_, T> for () {
746    fn begin_leaf(&mut self) {}
747    fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
748    fn push_item(&mut self, _: &T, _: &T::Summary, _: <T::Summary as Summary>::Context<'_>) {}
749    fn push_tree(
750        &mut self,
751        _: &SumTree<T>,
752        _: &T::Summary,
753        _: <T::Summary as Summary>::Context<'_>,
754    ) {
755    }
756}
757
758impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
759    fn begin_leaf(&mut self) {}
760    fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>) {
761        self.tree.append(
762            SumTree(Arc::new(Node::Leaf {
763                summary: mem::replace(&mut self.leaf_summary, <T::Summary as Summary>::zero(cx)),
764                items: mem::take(&mut self.leaf_items),
765                item_summaries: mem::take(&mut self.leaf_item_summaries),
766            })),
767            cx,
768        );
769    }
770    fn push_item(
771        &mut self,
772        item: &T,
773        summary: &T::Summary,
774        cx: <T::Summary as Summary>::Context<'_>,
775    ) {
776        self.leaf_items.push(item.clone());
777        self.leaf_item_summaries.push(summary.clone());
778        Summary::add_summary(&mut self.leaf_summary, summary, cx);
779    }
780    fn push_tree(
781        &mut self,
782        tree: &SumTree<T>,
783        _: &T::Summary,
784        cx: <T::Summary as Summary>::Context<'_>,
785    ) {
786        self.tree.append(tree.clone(), cx);
787    }
788}
789
790impl<'a, T: Item, D> SeekAggregate<'a, T> for SummarySeekAggregate<D>
791where
792    D: Dimension<'a, T::Summary>,
793{
794    fn begin_leaf(&mut self) {}
795    fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
796    fn push_item(
797        &mut self,
798        _: &T,
799        summary: &'a T::Summary,
800        cx: <T::Summary as Summary>::Context<'_>,
801    ) {
802        self.0.add_summary(summary, cx);
803    }
804    fn push_tree(
805        &mut self,
806        _: &SumTree<T>,
807        summary: &'a T::Summary,
808        cx: <T::Summary as Summary>::Context<'_>,
809    ) {
810        self.0.add_summary(summary, cx);
811    }
812}
813
814struct End<D>(PhantomData<D>);
815
816impl<D> End<D> {
817    fn new() -> Self {
818        Self(PhantomData)
819    }
820}
821
822impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> {
823    fn cmp(&self, _: &D, _: S::Context<'_>) -> Ordering {
824        Ordering::Greater
825    }
826}
827
828impl<D> fmt::Debug for End<D> {
829    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
830        f.debug_tuple("End").finish()
831    }
832}