cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, mem, sync::Arc};
  4
  5#[derive(Clone)]
  6struct StackEntry<'a, T: Item, D> {
  7    tree: &'a SumTree<T>,
  8    index: usize,
  9    position: D,
 10}
 11
 12impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
 13    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 14        f.debug_struct("StackEntry")
 15            .field("index", &self.index)
 16            .field("position", &self.position)
 17            .finish()
 18    }
 19}
 20
 21#[derive(Clone)]
 22pub struct Cursor<'a, T: Item, D> {
 23    tree: &'a SumTree<T>,
 24    stack: ArrayVec<StackEntry<'a, T, D>, 16>,
 25    position: D,
 26    did_seek: bool,
 27    at_end: bool,
 28    cx: &'a <T::Summary as Summary>::Context,
 29}
 30
 31impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for Cursor<'_, T, D>
 32where
 33    T::Summary: fmt::Debug,
 34{
 35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 36        f.debug_struct("Cursor")
 37            .field("tree", &self.tree)
 38            .field("stack", &self.stack)
 39            .field("position", &self.position)
 40            .field("did_seek", &self.did_seek)
 41            .field("at_end", &self.at_end)
 42            .finish()
 43    }
 44}
 45
 46pub struct Iter<'a, T: Item> {
 47    tree: &'a SumTree<T>,
 48    stack: ArrayVec<StackEntry<'a, T, ()>, 16>,
 49}
 50
 51impl<'a, T, D> Cursor<'a, T, D>
 52where
 53    T: Item,
 54    D: Dimension<'a, T::Summary>,
 55{
 56    pub fn new(tree: &'a SumTree<T>, cx: &'a <T::Summary as Summary>::Context) -> Self {
 57        Self {
 58            tree,
 59            stack: ArrayVec::new(),
 60            position: D::zero(cx),
 61            did_seek: false,
 62            at_end: tree.is_empty(),
 63            cx,
 64        }
 65    }
 66
 67    fn reset(&mut self) {
 68        self.did_seek = false;
 69        self.at_end = self.tree.is_empty();
 70        self.stack.truncate(0);
 71        self.position = D::zero(self.cx);
 72    }
 73
 74    pub fn start(&self) -> &D {
 75        &self.position
 76    }
 77
 78    #[track_caller]
 79    pub fn end(&self) -> D {
 80        if let Some(item_summary) = self.item_summary() {
 81            let mut end = self.start().clone();
 82            end.add_summary(item_summary, self.cx);
 83            end
 84        } else {
 85            self.start().clone()
 86        }
 87    }
 88
 89    /// Item is None, when the list is empty, or this cursor is at the end of the list.
 90    #[track_caller]
 91    pub fn item(&self) -> Option<&'a T> {
 92        self.assert_did_seek();
 93        if let Some(entry) = self.stack.last() {
 94            match *entry.tree.0 {
 95                Node::Leaf { ref items, .. } => {
 96                    if entry.index == items.len() {
 97                        None
 98                    } else {
 99                        Some(&items[entry.index])
100                    }
101                }
102                _ => unreachable!(),
103            }
104        } else {
105            None
106        }
107    }
108
109    #[track_caller]
110    pub fn item_summary(&self) -> Option<&'a T::Summary> {
111        self.assert_did_seek();
112        if let Some(entry) = self.stack.last() {
113            match *entry.tree.0 {
114                Node::Leaf {
115                    ref item_summaries, ..
116                } => {
117                    if entry.index == item_summaries.len() {
118                        None
119                    } else {
120                        Some(&item_summaries[entry.index])
121                    }
122                }
123                _ => unreachable!(),
124            }
125        } else {
126            None
127        }
128    }
129
130    #[track_caller]
131    pub fn next_item(&self) -> Option<&'a T> {
132        self.assert_did_seek();
133        if let Some(entry) = self.stack.last() {
134            if entry.index == entry.tree.0.items().len() - 1 {
135                if let Some(next_leaf) = self.next_leaf() {
136                    Some(next_leaf.0.items().first().unwrap())
137                } else {
138                    None
139                }
140            } else {
141                match *entry.tree.0 {
142                    Node::Leaf { ref items, .. } => Some(&items[entry.index + 1]),
143                    _ => unreachable!(),
144                }
145            }
146        } else if self.at_end {
147            None
148        } else {
149            self.tree.first()
150        }
151    }
152
153    #[track_caller]
154    fn next_leaf(&self) -> Option<&'a SumTree<T>> {
155        for entry in self.stack.iter().rev().skip(1) {
156            if entry.index < entry.tree.0.child_trees().len() - 1 {
157                match *entry.tree.0 {
158                    Node::Internal {
159                        ref child_trees, ..
160                    } => return Some(child_trees[entry.index + 1].leftmost_leaf()),
161                    Node::Leaf { .. } => unreachable!(),
162                };
163            }
164        }
165        None
166    }
167
168    #[track_caller]
169    pub fn prev_item(&self) -> Option<&'a T> {
170        self.assert_did_seek();
171        if let Some(entry) = self.stack.last() {
172            if entry.index == 0 {
173                if let Some(prev_leaf) = self.prev_leaf() {
174                    Some(prev_leaf.0.items().last().unwrap())
175                } else {
176                    None
177                }
178            } else {
179                match *entry.tree.0 {
180                    Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
181                    _ => unreachable!(),
182                }
183            }
184        } else if self.at_end {
185            self.tree.last()
186        } else {
187            None
188        }
189    }
190
191    #[track_caller]
192    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
193        for entry in self.stack.iter().rev().skip(1) {
194            if entry.index != 0 {
195                match *entry.tree.0 {
196                    Node::Internal {
197                        ref child_trees, ..
198                    } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
199                    Node::Leaf { .. } => unreachable!(),
200                };
201            }
202        }
203        None
204    }
205
206    #[track_caller]
207    pub fn prev(&mut self) {
208        self.search_backward(|_| Ordering::Greater)
209    }
210
211    #[track_caller]
212    pub fn search_backward<F>(&mut self, mut filter_node: F)
213    where
214        F: FnMut(&T::Summary) -> Ordering,
215    {
216        if !self.did_seek {
217            self.did_seek = true;
218            self.at_end = true;
219        }
220
221        if self.at_end {
222            self.position = D::zero(self.cx);
223            self.at_end = self.tree.is_empty();
224            if !self.tree.is_empty() {
225                let position = if let Some(summary) = self.tree.0.summary() {
226                    D::from_summary(summary, self.cx)
227                } else {
228                    D::zero(self.cx)
229                };
230                self.stack.push(StackEntry {
231                    tree: self.tree,
232                    index: self.tree.0.child_summaries().len(),
233                    position,
234                });
235            }
236        }
237
238        let mut descending = false;
239        while !self.stack.is_empty() {
240            if let Some(StackEntry { position, .. }) = self.stack.iter().rev().nth(1) {
241                self.position = position.clone();
242            } else {
243                self.position = D::zero(self.cx);
244            }
245
246            let entry = self.stack.last_mut().unwrap();
247            if !descending {
248                if entry.index == 0 {
249                    self.stack.pop();
250                    continue;
251                } else {
252                    entry.index -= 1;
253                }
254            }
255
256            if entry.index != 0 {
257                self.position
258                    .add_summary(&entry.tree.0.child_summaries()[entry.index - 1], self.cx);
259            }
260
261            entry.position = self.position.clone();
262
263            descending = filter_node(&entry.tree.0.child_summaries()[entry.index]).is_ge();
264            match entry.tree.0.as_ref() {
265                Node::Internal { child_trees, .. } => {
266                    if descending {
267                        let tree = &child_trees[entry.index];
268                        self.stack.push(StackEntry {
269                            position: D::zero(self.cx),
270                            tree,
271                            index: tree.0.child_summaries().len() - 1,
272                        })
273                    }
274                }
275                Node::Leaf { .. } => {
276                    if descending {
277                        break;
278                    }
279                }
280            }
281        }
282    }
283
284    #[track_caller]
285    pub fn next(&mut self) {
286        self.search_forward(|_| Ordering::Less)
287    }
288
289    #[track_caller]
290    pub fn search_forward<F>(&mut self, mut filter_node: F)
291    where
292        F: FnMut(&T::Summary) -> Ordering,
293    {
294        let mut descend = false;
295
296        if self.stack.is_empty() {
297            if !self.at_end {
298                self.stack.push(StackEntry {
299                    tree: self.tree,
300                    index: 0,
301                    position: D::zero(self.cx),
302                });
303                descend = true;
304            }
305            self.did_seek = true;
306        }
307
308        while !self.stack.is_empty() {
309            let new_subtree = {
310                let entry = self.stack.last_mut().unwrap();
311                match entry.tree.0.as_ref() {
312                    Node::Internal {
313                        child_trees,
314                        child_summaries,
315                        ..
316                    } => {
317                        if !descend {
318                            entry.index += 1;
319                            entry.position = self.position.clone();
320                        }
321
322                        if entry.index < child_summaries.len() {
323                            let index = child_summaries[entry.index..]
324                                .partition_point(|item| filter_node(item).is_lt());
325                            if index < child_summaries.len() - entry.index {
326                                entry.index += index;
327                            }
328
329                            let position = Some(entry.index)
330                                .filter(|index| *index < child_summaries.len())
331                                .unwrap_or(child_summaries.len());
332
333                            if let Some(summary) = child_summaries.get(position) {
334                                entry.position.add_summary(summary, self.cx);
335                                self.position.add_summary(summary, self.cx);
336                            }
337                        }
338                        dbg!((entry.index, child_trees.len()));
339
340                        child_trees.get(entry.index)
341                    }
342                    Node::Leaf { item_summaries, .. } => {
343                        dbg!("Ayo");
344                        if !descend {
345                            let item_summary = &item_summaries[entry.index];
346                            entry.index += 1;
347                            entry.position.add_summary(item_summary, self.cx);
348                            self.position.add_summary(item_summary, self.cx);
349                        }
350
351                        if entry.index < item_summaries.len() {
352                            let index = item_summaries[entry.index..]
353                                .partition_point(|item| filter_node(item).is_lt());
354                            if index < item_summaries.len() - entry.index {
355                                entry.index += index;
356                            }
357                            entry.index += index;
358                            let position = Some(entry.index)
359                                .filter(|index| *index < item_summaries.len())
360                                .unwrap_or(item_summaries.len());
361
362                            if let Some(summary) = item_summaries.get(position) {
363                                entry.position.add_summary(summary, self.cx);
364                                self.position.add_summary(summary, self.cx);
365                            }
366                            return;
367                        } else {
368                            None
369                        }
370                    }
371                }
372            };
373
374            if let Some(subtree) = new_subtree {
375                descend = true;
376                self.stack.push(StackEntry {
377                    tree: subtree,
378                    index: 0,
379                    position: self.position.clone(),
380                });
381            } else {
382                descend = false;
383                self.stack.pop();
384            }
385        }
386
387        self.at_end = self.stack.is_empty();
388        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
389    }
390
391    #[track_caller]
392    fn assert_did_seek(&self) {
393        assert!(
394            self.did_seek,
395            "Must call `seek`, `next` or `prev` before calling this method"
396        );
397    }
398}
399
400impl<'a, T, D> Cursor<'a, T, D>
401where
402    T: Item,
403    D: Dimension<'a, T::Summary>,
404{
405    #[track_caller]
406    pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool
407    where
408        Target: SeekTarget<'a, T::Summary, D>,
409    {
410        self.reset();
411        self.seek_internal(pos, bias, &mut ())
412    }
413
414    #[track_caller]
415    pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool
416    where
417        Target: SeekTarget<'a, T::Summary, D>,
418    {
419        self.seek_internal(pos, bias, &mut ())
420    }
421
422    /// Advances the cursor and returns traversed items as a tree.
423    #[track_caller]
424    pub fn slice<Target>(&mut self, end: &Target, bias: Bias) -> SumTree<T>
425    where
426        Target: SeekTarget<'a, T::Summary, D>,
427    {
428        let mut slice = SliceSeekAggregate {
429            tree: SumTree::new(),
430            leaf_items: ArrayVec::new(),
431            leaf_item_summaries: ArrayVec::new(),
432            leaf_summary: <T::Summary as Summary>::zero(self.cx),
433        };
434        self.seek_internal(end, bias, &mut slice);
435        slice.tree
436    }
437
438    #[track_caller]
439    pub fn suffix(&mut self) -> SumTree<T> {
440        self.slice(&End::new(), Bias::Right)
441    }
442
443    #[track_caller]
444    pub fn summary<Target, Output>(&mut self, end: &Target, bias: Bias) -> Output
445    where
446        Target: SeekTarget<'a, T::Summary, D>,
447        Output: Dimension<'a, T::Summary>,
448    {
449        let mut summary = SummarySeekAggregate(Output::zero(self.cx));
450        self.seek_internal(end, bias, &mut summary);
451        summary.0
452    }
453
454    /// Returns whether we found the item you were seeking for
455    #[track_caller]
456    fn seek_internal(
457        &mut self,
458        target: &dyn SeekTarget<'a, T::Summary, D>,
459        bias: Bias,
460        aggregate: &mut dyn SeekAggregate<'a, T>,
461    ) -> bool {
462        assert!(
463            target.cmp(&self.position, self.cx) >= Ordering::Equal,
464            "cannot seek backward",
465        );
466
467        if !self.did_seek {
468            self.did_seek = true;
469            self.stack.push(StackEntry {
470                tree: self.tree,
471                index: 0,
472                position: D::zero(self.cx),
473            });
474        }
475
476        let mut ascending = false;
477        'outer: while let Some(entry) = self.stack.last_mut() {
478            match *entry.tree.0 {
479                Node::Internal {
480                    ref child_summaries,
481                    ref child_trees,
482                    ..
483                } => {
484                    if ascending {
485                        entry.index += 1;
486                        entry.position = self.position.clone();
487                    }
488
489                    for (child_tree, child_summary) in child_trees[entry.index..]
490                        .iter()
491                        .zip(&child_summaries[entry.index..])
492                    {
493                        let mut child_end = self.position.clone();
494                        child_end.add_summary(child_summary, self.cx);
495
496                        let comparison = target.cmp(&child_end, self.cx);
497                        if comparison == Ordering::Greater
498                            || (comparison == Ordering::Equal && bias == Bias::Right)
499                        {
500                            self.position = child_end;
501                            aggregate.push_tree(child_tree, child_summary, self.cx);
502                            entry.index += 1;
503                            entry.position = self.position.clone();
504                        } else {
505                            self.stack.push(StackEntry {
506                                tree: child_tree,
507                                index: 0,
508                                position: self.position.clone(),
509                            });
510                            ascending = false;
511                            continue 'outer;
512                        }
513                    }
514                }
515                Node::Leaf {
516                    ref items,
517                    ref item_summaries,
518                    ..
519                } => {
520                    aggregate.begin_leaf();
521
522                    for (item, item_summary) in items[entry.index..]
523                        .iter()
524                        .zip(&item_summaries[entry.index..])
525                    {
526                        let mut child_end = self.position.clone();
527                        child_end.add_summary(item_summary, self.cx);
528
529                        let comparison = target.cmp(&child_end, self.cx);
530                        if comparison == Ordering::Greater
531                            || (comparison == Ordering::Equal && bias == Bias::Right)
532                        {
533                            self.position = child_end;
534                            aggregate.push_item(item, item_summary, self.cx);
535                            entry.index += 1;
536                        } else {
537                            aggregate.end_leaf(self.cx);
538                            break 'outer;
539                        }
540                    }
541
542                    aggregate.end_leaf(self.cx);
543                }
544            }
545
546            self.stack.pop();
547            ascending = true;
548        }
549
550        self.at_end = self.stack.is_empty();
551        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
552
553        let mut end = self.position.clone();
554        if bias == Bias::Left {
555            if let Some(summary) = self.item_summary() {
556                end.add_summary(summary, self.cx);
557            }
558        }
559
560        target.cmp(&end, self.cx) == Ordering::Equal
561    }
562}
563
564impl<'a, T: Item> Iter<'a, T> {
565    pub(crate) fn new(tree: &'a SumTree<T>) -> Self {
566        Self {
567            tree,
568            stack: Default::default(),
569        }
570    }
571}
572
573impl<'a, T: Item> Iterator for Iter<'a, T> {
574    type Item = &'a T;
575
576    fn next(&mut self) -> Option<Self::Item> {
577        let mut descend = false;
578
579        if self.stack.is_empty() {
580            self.stack.push(StackEntry {
581                tree: self.tree,
582                index: 0,
583                position: (),
584            });
585            descend = true;
586        }
587
588        while !self.stack.is_empty() {
589            let new_subtree = {
590                let entry = self.stack.last_mut().unwrap();
591                match entry.tree.0.as_ref() {
592                    Node::Internal { child_trees, .. } => {
593                        if !descend {
594                            entry.index += 1;
595                        }
596                        child_trees.get(entry.index)
597                    }
598                    Node::Leaf { items, .. } => {
599                        if !descend {
600                            entry.index += 1;
601                        }
602
603                        if let Some(next_item) = items.get(entry.index) {
604                            return Some(next_item);
605                        } else {
606                            None
607                        }
608                    }
609                }
610            };
611
612            if let Some(subtree) = new_subtree {
613                descend = true;
614                self.stack.push(StackEntry {
615                    tree: subtree,
616                    index: 0,
617                    position: (),
618                });
619            } else {
620                descend = false;
621                self.stack.pop();
622            }
623        }
624
625        None
626    }
627}
628
629impl<'a, T: Item, D> Iterator for Cursor<'a, T, D>
630where
631    D: Dimension<'a, T::Summary>,
632{
633    type Item = &'a T;
634
635    fn next(&mut self) -> Option<Self::Item> {
636        if !self.did_seek {
637            self.next();
638        }
639
640        if let Some(item) = self.item() {
641            self.next();
642            Some(item)
643        } else {
644            None
645        }
646    }
647}
648
649pub struct FilterCursor<'a, F, T: Item, D> {
650    cursor: Cursor<'a, T, D>,
651    filter_node: F,
652}
653
654impl<'a, F, T: Item, D> FilterCursor<'a, F, T, D>
655where
656    F: FnMut(&T::Summary) -> Ordering,
657    T: Item,
658    D: Dimension<'a, T::Summary>,
659{
660    pub fn new(
661        tree: &'a SumTree<T>,
662        cx: &'a <T::Summary as Summary>::Context,
663        filter_node: F,
664    ) -> Self {
665        let cursor = tree.cursor::<D>(cx);
666        Self {
667            cursor,
668            filter_node,
669        }
670    }
671
672    pub fn start(&self) -> &D {
673        self.cursor.start()
674    }
675
676    pub fn end(&self) -> D {
677        self.cursor.end()
678    }
679
680    pub fn item(&self) -> Option<&'a T> {
681        self.cursor.item()
682    }
683
684    pub fn item_summary(&self) -> Option<&'a T::Summary> {
685        self.cursor.item_summary()
686    }
687
688    pub fn next(&mut self) {
689        self.cursor.search_forward(&mut self.filter_node);
690    }
691
692    pub fn prev(&mut self) {
693        self.cursor.search_backward(&mut self.filter_node);
694    }
695}
696
697impl<'a, F, T: Item, U> Iterator for FilterCursor<'a, F, T, U>
698where
699    F: FnMut(&T::Summary) -> Ordering,
700    U: Dimension<'a, T::Summary>,
701{
702    type Item = &'a T;
703
704    fn next(&mut self) -> Option<Self::Item> {
705        if !self.cursor.did_seek {
706            self.next();
707        }
708
709        if let Some(item) = self.item() {
710            self.cursor.search_forward(&mut self.filter_node);
711            Some(item)
712        } else {
713            None
714        }
715    }
716}
717
718trait SeekAggregate<'a, T: Item> {
719    fn begin_leaf(&mut self);
720    fn end_leaf(&mut self, cx: &<T::Summary as Summary>::Context);
721    fn push_item(
722        &mut self,
723        item: &'a T,
724        summary: &'a T::Summary,
725        cx: &<T::Summary as Summary>::Context,
726    );
727    fn push_tree(
728        &mut self,
729        tree: &'a SumTree<T>,
730        summary: &'a T::Summary,
731        cx: &<T::Summary as Summary>::Context,
732    );
733}
734
735struct SliceSeekAggregate<T: Item> {
736    tree: SumTree<T>,
737    leaf_items: ArrayVec<T, { 2 * TREE_BASE }>,
738    leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
739    leaf_summary: T::Summary,
740}
741
/// Seek aggregate that folds every passed-over summary into the wrapped dimension `D`.
struct SummarySeekAggregate<D>(
    /// The running total.
    D,
);
743
744impl<T: Item> SeekAggregate<'_, T> for () {
745    fn begin_leaf(&mut self) {}
746    fn end_leaf(&mut self, _: &<T::Summary as Summary>::Context) {}
747    fn push_item(&mut self, _: &T, _: &T::Summary, _: &<T::Summary as Summary>::Context) {}
748    fn push_tree(&mut self, _: &SumTree<T>, _: &T::Summary, _: &<T::Summary as Summary>::Context) {}
749}
750
751impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
752    fn begin_leaf(&mut self) {}
753    fn end_leaf(&mut self, cx: &<T::Summary as Summary>::Context) {
754        self.tree.append(
755            SumTree(Arc::new(Node::Leaf {
756                items: mem::take(&mut self.leaf_items),
757                item_summaries: mem::take(&mut self.leaf_item_summaries),
758            })),
759            cx,
760        );
761    }
762    fn push_item(&mut self, item: &T, summary: &T::Summary, cx: &<T::Summary as Summary>::Context) {
763        self.leaf_items.push(item.clone());
764        self.leaf_item_summaries.push(summary.clone());
765        Summary::add_summary(&mut self.leaf_summary, summary, cx);
766    }
767    fn push_tree(
768        &mut self,
769        tree: &SumTree<T>,
770        _: &T::Summary,
771        cx: &<T::Summary as Summary>::Context,
772    ) {
773        self.tree.append(tree.clone(), cx);
774    }
775}
776
777impl<'a, T: Item, D> SeekAggregate<'a, T> for SummarySeekAggregate<D>
778where
779    D: Dimension<'a, T::Summary>,
780{
781    fn begin_leaf(&mut self) {}
782    fn end_leaf(&mut self, _: &<T::Summary as Summary>::Context) {}
783    fn push_item(&mut self, _: &T, summary: &'a T::Summary, cx: &<T::Summary as Summary>::Context) {
784        self.0.add_summary(summary, cx);
785    }
786    fn push_tree(
787        &mut self,
788        _: &SumTree<T>,
789        summary: &'a T::Summary,
790        cx: &<T::Summary as Summary>::Context,
791    ) {
792        self.0.add_summary(summary, cx);
793    }
794}
795
/// Seek target that compares greater than every position; used to seek to the end of a tree.
struct End<D>(
    /// Ties the target to dimension `D` without storing one.
    PhantomData<D>,
);
797
798impl<D> End<D> {
799    fn new() -> Self {
800        Self(PhantomData)
801    }
802}
803
804impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> {
805    fn cmp(&self, _: &D, _: &S::Context) -> Ordering {
806        Ordering::Greater
807    }
808}
809
810impl<D> fmt::Debug for End<D> {
811    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
812        f.debug_tuple("End").finish()
813    }
814}