cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, sync::Arc};
  4
  5#[derive(Clone)]
  6struct StackEntry<'a, T: Item, S, U> {
  7    tree: &'a SumTree<T>,
  8    index: usize,
  9    seek_dimension: S,
 10    sum_dimension: U,
 11}
 12
 13impl<'a, T, S, U> StackEntry<'a, T, S, U>
 14where
 15    T: Item,
 16    S: SeekDimension<'a, T::Summary>,
 17    U: SeekDimension<'a, T::Summary>,
 18{
 19    fn swap_dimensions(self) -> StackEntry<'a, T, U, S> {
 20        StackEntry {
 21            tree: self.tree,
 22            index: self.index,
 23            seek_dimension: self.sum_dimension,
 24            sum_dimension: self.seek_dimension,
 25        }
 26    }
 27}
 28
 29#[derive(Clone)]
 30pub struct Cursor<'a, T: Item, S, U> {
 31    tree: &'a SumTree<T>,
 32    stack: ArrayVec<StackEntry<'a, T, S, U>, 16>,
 33    seek_dimension: S,
 34    sum_dimension: U,
 35    did_seek: bool,
 36    at_end: bool,
 37}
 38
 39impl<'a, T, S, U> Cursor<'a, T, S, U>
 40where
 41    T: Item,
 42    S: Dimension<'a, T::Summary>,
 43    U: Dimension<'a, T::Summary>,
 44{
 45    pub fn new(tree: &'a SumTree<T>) -> Self {
 46        Self {
 47            tree,
 48            stack: ArrayVec::new(),
 49            seek_dimension: S::default(),
 50            sum_dimension: U::default(),
 51            did_seek: false,
 52            at_end: false,
 53        }
 54    }
 55
 56    fn reset(&mut self) {
 57        self.did_seek = false;
 58        self.at_end = false;
 59        self.stack.truncate(0);
 60        self.seek_dimension = S::default();
 61        self.sum_dimension = U::default();
 62    }
 63
 64    pub fn seek_start(&self) -> &S {
 65        &self.seek_dimension
 66    }
 67
 68    pub fn seek_end(&self, cx: &<T::Summary as Summary>::Context) -> S {
 69        if let Some(item_summary) = self.item_summary() {
 70            let mut end = self.seek_start().clone();
 71            end.add_summary(item_summary, cx);
 72            end
 73        } else {
 74            self.seek_start().clone()
 75        }
 76    }
 77
 78    pub fn sum_start(&self) -> &U {
 79        &self.sum_dimension
 80    }
 81
 82    pub fn sum_end(&self, cx: &<T::Summary as Summary>::Context) -> U {
 83        if let Some(item_summary) = self.item_summary() {
 84            let mut end = self.sum_start().clone();
 85            end.add_summary(item_summary, cx);
 86            end
 87        } else {
 88            self.sum_start().clone()
 89        }
 90    }
 91
 92    pub fn item(&self) -> Option<&'a T> {
 93        assert!(self.did_seek, "Must seek before calling this method");
 94        if let Some(entry) = self.stack.last() {
 95            match *entry.tree.0 {
 96                Node::Leaf { ref items, .. } => {
 97                    if entry.index == items.len() {
 98                        None
 99                    } else {
100                        Some(&items[entry.index])
101                    }
102                }
103                _ => unreachable!(),
104            }
105        } else {
106            None
107        }
108    }
109
110    pub fn item_summary(&self) -> Option<&'a T::Summary> {
111        assert!(self.did_seek, "Must seek before calling this method");
112        if let Some(entry) = self.stack.last() {
113            match *entry.tree.0 {
114                Node::Leaf {
115                    ref item_summaries, ..
116                } => {
117                    if entry.index == item_summaries.len() {
118                        None
119                    } else {
120                        Some(&item_summaries[entry.index])
121                    }
122                }
123                _ => unreachable!(),
124            }
125        } else {
126            None
127        }
128    }
129
130    pub fn prev_item(&self) -> Option<&'a T> {
131        assert!(self.did_seek, "Must seek before calling this method");
132        if let Some(entry) = self.stack.last() {
133            if entry.index == 0 {
134                if let Some(prev_leaf) = self.prev_leaf() {
135                    Some(prev_leaf.0.items().last().unwrap())
136                } else {
137                    None
138                }
139            } else {
140                match *entry.tree.0 {
141                    Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
142                    _ => unreachable!(),
143                }
144            }
145        } else if self.at_end {
146            self.tree.last()
147        } else {
148            None
149        }
150    }
151
152    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
153        for entry in self.stack.iter().rev().skip(1) {
154            if entry.index != 0 {
155                match *entry.tree.0 {
156                    Node::Internal {
157                        ref child_trees, ..
158                    } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
159                    Node::Leaf { .. } => unreachable!(),
160                };
161            }
162        }
163        None
164    }
165
166    #[allow(unused)]
167    pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) {
168        assert!(self.did_seek, "Must seek before calling this method");
169
170        if self.at_end {
171            self.seek_dimension = S::default();
172            self.sum_dimension = U::default();
173            self.descend_to_last_item(self.tree, cx);
174            self.at_end = false;
175        } else {
176            while let Some(entry) = self.stack.pop() {
177                if entry.index > 0 {
178                    let new_index = entry.index - 1;
179
180                    if let Some(StackEntry {
181                        seek_dimension,
182                        sum_dimension,
183                        ..
184                    }) = self.stack.last()
185                    {
186                        self.seek_dimension = seek_dimension.clone();
187                        self.sum_dimension = sum_dimension.clone();
188                    } else {
189                        self.seek_dimension = S::default();
190                        self.sum_dimension = U::default();
191                    }
192
193                    match entry.tree.0.as_ref() {
194                        Node::Internal {
195                            child_trees,
196                            child_summaries,
197                            ..
198                        } => {
199                            for summary in &child_summaries[0..new_index] {
200                                self.seek_dimension.add_summary(summary, cx);
201                                self.sum_dimension.add_summary(summary, cx);
202                            }
203                            self.stack.push(StackEntry {
204                                tree: entry.tree,
205                                index: new_index,
206                                seek_dimension: self.seek_dimension.clone(),
207                                sum_dimension: self.sum_dimension.clone(),
208                            });
209                            self.descend_to_last_item(&child_trees[new_index], cx);
210                        }
211                        Node::Leaf { item_summaries, .. } => {
212                            for item_summary in &item_summaries[0..new_index] {
213                                self.seek_dimension.add_summary(item_summary, cx);
214                                self.sum_dimension.add_summary(item_summary, cx);
215                            }
216                            self.stack.push(StackEntry {
217                                tree: entry.tree,
218                                index: new_index,
219                                seek_dimension: self.seek_dimension.clone(),
220                                sum_dimension: self.sum_dimension.clone(),
221                            });
222                        }
223                    }
224
225                    break;
226                }
227            }
228        }
229    }
230
231    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
232        self.next_internal(|_| true, cx)
233    }
234
235    fn next_internal<F>(&mut self, filter_node: F, cx: &<T::Summary as Summary>::Context)
236    where
237        F: Fn(&T::Summary) -> bool,
238    {
239        let mut descend = false;
240
241        if self.stack.is_empty() && !self.at_end {
242            self.stack.push(StackEntry {
243                tree: self.tree,
244                index: 0,
245                seek_dimension: S::default(),
246                sum_dimension: U::default(),
247            });
248            descend = true;
249            self.did_seek = true;
250        }
251
252        while self.stack.len() > 0 {
253            let new_subtree = {
254                let entry = self.stack.last_mut().unwrap();
255                match entry.tree.0.as_ref() {
256                    Node::Internal {
257                        child_trees,
258                        child_summaries,
259                        ..
260                    } => {
261                        if !descend {
262                            entry.seek_dimension = self.seek_dimension.clone();
263                            entry.sum_dimension = self.sum_dimension.clone();
264                            entry.index += 1;
265                        }
266
267                        while entry.index < child_summaries.len() {
268                            let next_summary = &child_summaries[entry.index];
269                            if filter_node(next_summary) {
270                                break;
271                            } else {
272                                self.seek_dimension.add_summary(next_summary, cx);
273                                self.sum_dimension.add_summary(next_summary, cx);
274                            }
275                            entry.index += 1;
276                        }
277
278                        child_trees.get(entry.index)
279                    }
280                    Node::Leaf { item_summaries, .. } => {
281                        if !descend {
282                            let item_summary = &item_summaries[entry.index];
283                            self.seek_dimension.add_summary(item_summary, cx);
284                            entry.seek_dimension.add_summary(item_summary, cx);
285                            self.sum_dimension.add_summary(item_summary, cx);
286                            entry.sum_dimension.add_summary(item_summary, cx);
287                            entry.index += 1;
288                        }
289
290                        loop {
291                            if let Some(next_item_summary) = item_summaries.get(entry.index) {
292                                if filter_node(next_item_summary) {
293                                    return;
294                                } else {
295                                    self.seek_dimension.add_summary(next_item_summary, cx);
296                                    entry.seek_dimension.add_summary(next_item_summary, cx);
297                                    self.sum_dimension.add_summary(next_item_summary, cx);
298                                    entry.sum_dimension.add_summary(next_item_summary, cx);
299                                    entry.index += 1;
300                                }
301                            } else {
302                                break None;
303                            }
304                        }
305                    }
306                }
307            };
308
309            if let Some(subtree) = new_subtree {
310                descend = true;
311                self.stack.push(StackEntry {
312                    tree: subtree,
313                    index: 0,
314                    seek_dimension: self.seek_dimension.clone(),
315                    sum_dimension: self.sum_dimension.clone(),
316                });
317            } else {
318                descend = false;
319                self.stack.pop();
320            }
321        }
322
323        self.at_end = self.stack.is_empty();
324        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
325    }
326
327    fn descend_to_last_item(
328        &mut self,
329        mut subtree: &'a SumTree<T>,
330        cx: &<T::Summary as Summary>::Context,
331    ) {
332        self.did_seek = true;
333        loop {
334            match subtree.0.as_ref() {
335                Node::Internal {
336                    child_trees,
337                    child_summaries,
338                    ..
339                } => {
340                    for summary in &child_summaries[0..child_summaries.len() - 1] {
341                        self.seek_dimension.add_summary(summary, cx);
342                        self.sum_dimension.add_summary(summary, cx);
343                    }
344
345                    self.stack.push(StackEntry {
346                        tree: subtree,
347                        index: child_trees.len() - 1,
348                        seek_dimension: self.seek_dimension.clone(),
349                        sum_dimension: self.sum_dimension.clone(),
350                    });
351                    subtree = child_trees.last().unwrap();
352                }
353                Node::Leaf { item_summaries, .. } => {
354                    let last_index = item_summaries.len().saturating_sub(1);
355                    for item_summary in &item_summaries[0..last_index] {
356                        self.seek_dimension.add_summary(item_summary, cx);
357                        self.sum_dimension.add_summary(item_summary, cx);
358                    }
359                    self.stack.push(StackEntry {
360                        tree: subtree,
361                        index: last_index,
362                        seek_dimension: self.seek_dimension.clone(),
363                        sum_dimension: self.sum_dimension.clone(),
364                    });
365                    break;
366                }
367            }
368        }
369    }
370}
371
372impl<'a, T, S, U> Cursor<'a, T, S, U>
373where
374    T: Item,
375    S: SeekDimension<'a, T::Summary>,
376    U: Dimension<'a, T::Summary>,
377{
378    pub fn seek(&mut self, pos: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> bool {
379        self.reset();
380        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
381    }
382
383    pub fn seek_forward(
384        &mut self,
385        pos: &S,
386        bias: Bias,
387        cx: &<T::Summary as Summary>::Context,
388    ) -> bool {
389        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
390    }
391
392    pub fn slice(
393        &mut self,
394        end: &S,
395        bias: Bias,
396        cx: &<T::Summary as Summary>::Context,
397    ) -> SumTree<T> {
398        let mut slice = SeekAggregate::Slice(SumTree::new());
399        self.seek_internal::<()>(Some(end), bias, &mut slice, cx);
400        if let SeekAggregate::Slice(slice) = slice {
401            slice
402        } else {
403            unreachable!()
404        }
405    }
406
407    pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
408        let mut slice = SeekAggregate::Slice(SumTree::new());
409        self.seek_internal::<()>(None, Bias::Right, &mut slice, cx);
410        if let SeekAggregate::Slice(slice) = slice {
411            slice
412        } else {
413            unreachable!()
414        }
415    }
416
417    pub fn summary<D>(&mut self, end: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> D
418    where
419        D: Dimension<'a, T::Summary>,
420    {
421        let mut summary = SeekAggregate::Summary(D::default());
422        self.seek_internal(Some(end), bias, &mut summary, cx);
423        if let SeekAggregate::Summary(summary) = summary {
424            summary
425        } else {
426            unreachable!()
427        }
428    }
429
430    fn seek_internal<D>(
431        &mut self,
432        target: Option<&S>,
433        bias: Bias,
434        aggregate: &mut SeekAggregate<T, D>,
435        cx: &<T::Summary as Summary>::Context,
436    ) -> bool
437    where
438        D: Dimension<'a, T::Summary>,
439    {
440        if let Some(target) = target {
441            debug_assert!(
442                target.cmp(&self.seek_dimension, cx) >= Ordering::Equal,
443                "cannot seek backward from {:?} to {:?}",
444                self.seek_dimension,
445                target
446            );
447        }
448
449        if !self.did_seek {
450            self.did_seek = true;
451            self.stack.push(StackEntry {
452                tree: self.tree,
453                index: 0,
454                seek_dimension: Default::default(),
455                sum_dimension: Default::default(),
456            });
457        }
458
459        let mut ascending = false;
460        'outer: while let Some(entry) = self.stack.last_mut() {
461            match *entry.tree.0 {
462                Node::Internal {
463                    ref child_summaries,
464                    ref child_trees,
465                    ..
466                } => {
467                    if ascending {
468                        entry.index += 1;
469                    }
470
471                    for (child_tree, child_summary) in child_trees[entry.index..]
472                        .iter()
473                        .zip(&child_summaries[entry.index..])
474                    {
475                        let mut child_end = self.seek_dimension.clone();
476                        child_end.add_summary(&child_summary, cx);
477
478                        let comparison =
479                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
480                        if comparison == Ordering::Greater
481                            || (comparison == Ordering::Equal && bias == Bias::Right)
482                        {
483                            self.seek_dimension = child_end;
484                            self.sum_dimension.add_summary(child_summary, cx);
485                            match aggregate {
486                                SeekAggregate::None => {}
487                                SeekAggregate::Slice(slice) => {
488                                    slice.push_tree(child_tree.clone(), cx);
489                                }
490                                SeekAggregate::Summary(summary) => {
491                                    summary.add_summary(child_summary, cx);
492                                }
493                            }
494                            entry.index += 1;
495                            entry.seek_dimension = self.seek_dimension.clone();
496                            entry.sum_dimension = self.sum_dimension.clone();
497                        } else {
498                            self.stack.push(StackEntry {
499                                tree: child_tree,
500                                index: 0,
501                                seek_dimension: self.seek_dimension.clone(),
502                                sum_dimension: self.sum_dimension.clone(),
503                            });
504                            ascending = false;
505                            continue 'outer;
506                        }
507                    }
508                }
509                Node::Leaf {
510                    ref items,
511                    ref item_summaries,
512                    ..
513                } => {
514                    let mut slice_items = ArrayVec::<T, { 2 * TREE_BASE }>::new();
515                    let mut slice_item_summaries = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
516                    let mut slice_items_summary = match aggregate {
517                        SeekAggregate::Slice(_) => Some(T::Summary::default()),
518                        _ => None,
519                    };
520
521                    for (item, item_summary) in items[entry.index..]
522                        .iter()
523                        .zip(&item_summaries[entry.index..])
524                    {
525                        let mut child_end = self.seek_dimension.clone();
526                        child_end.add_summary(item_summary, cx);
527
528                        let comparison =
529                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
530                        if comparison == Ordering::Greater
531                            || (comparison == Ordering::Equal && bias == Bias::Right)
532                        {
533                            self.seek_dimension = child_end;
534                            self.sum_dimension.add_summary(item_summary, cx);
535                            match aggregate {
536                                SeekAggregate::None => {}
537                                SeekAggregate::Slice(_) => {
538                                    slice_items.push(item.clone());
539                                    slice_item_summaries.push(item_summary.clone());
540                                    slice_items_summary
541                                        .as_mut()
542                                        .unwrap()
543                                        .add_summary(item_summary, cx);
544                                }
545                                SeekAggregate::Summary(summary) => {
546                                    summary.add_summary(item_summary, cx);
547                                }
548                            }
549                            entry.index += 1;
550                        } else {
551                            if let SeekAggregate::Slice(slice) = aggregate {
552                                slice.push_tree(
553                                    SumTree(Arc::new(Node::Leaf {
554                                        summary: slice_items_summary.unwrap(),
555                                        items: slice_items,
556                                        item_summaries: slice_item_summaries,
557                                    })),
558                                    cx,
559                                );
560                            }
561                            break 'outer;
562                        }
563                    }
564
565                    if let SeekAggregate::Slice(slice) = aggregate {
566                        if !slice_items.is_empty() {
567                            slice.push_tree(
568                                SumTree(Arc::new(Node::Leaf {
569                                    summary: slice_items_summary.unwrap(),
570                                    items: slice_items,
571                                    item_summaries: slice_item_summaries,
572                                })),
573                                cx,
574                            );
575                        }
576                    }
577                }
578            }
579
580            self.stack.pop();
581            ascending = true;
582        }
583
584        self.at_end = self.stack.is_empty();
585        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
586
587        let mut end = self.seek_dimension.clone();
588        if bias == Bias::Left {
589            if let Some(summary) = self.item_summary() {
590                end.add_summary(summary, cx);
591            }
592        }
593
594        target.map_or(false, |t| t.cmp(&end, cx) == Ordering::Equal)
595    }
596}
597
598impl<'a, T, S, Seek, Sum> Iterator for Cursor<'a, T, Seek, Sum>
599where
600    T: Item<Summary = S>,
601    S: Summary<Context = ()>,
602    Seek: Dimension<'a, T::Summary>,
603    Sum: Dimension<'a, T::Summary>,
604{
605    type Item = &'a T;
606
607    fn next(&mut self) -> Option<Self::Item> {
608        if !self.did_seek {
609            self.next(&());
610        }
611
612        if let Some(item) = self.item() {
613            self.next(&());
614            Some(item)
615        } else {
616            None
617        }
618    }
619}
620
621impl<'a, T, S, U> Cursor<'a, T, S, U>
622where
623    T: Item,
624    S: SeekDimension<'a, T::Summary>,
625    U: SeekDimension<'a, T::Summary>,
626{
627    pub fn swap_dimensions(self) -> Cursor<'a, T, U, S> {
628        Cursor {
629            tree: self.tree,
630            stack: self
631                .stack
632                .into_iter()
633                .map(StackEntry::swap_dimensions)
634                .collect(),
635            seek_dimension: self.sum_dimension,
636            sum_dimension: self.seek_dimension,
637            did_seek: self.did_seek,
638            at_end: self.at_end,
639        }
640    }
641}
642
643pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
644    cursor: Cursor<'a, T, (), U>,
645    filter_node: F,
646}
647
648impl<'a, F, T, U> FilterCursor<'a, F, T, U>
649where
650    F: Fn(&T::Summary) -> bool,
651    T: Item,
652    U: Dimension<'a, T::Summary>,
653{
654    pub fn new(
655        tree: &'a SumTree<T>,
656        filter_node: F,
657        cx: &<T::Summary as Summary>::Context,
658    ) -> Self {
659        let mut cursor = tree.cursor::<(), U>();
660        cursor.next_internal(&filter_node, cx);
661        Self {
662            cursor,
663            filter_node,
664        }
665    }
666
667    pub fn start(&self) -> &U {
668        self.cursor.sum_start()
669    }
670
671    pub fn item(&self) -> Option<&'a T> {
672        self.cursor.item()
673    }
674
675    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
676        self.cursor.next_internal(&self.filter_node, cx);
677    }
678}
679
680impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U>
681where
682    F: Fn(&T::Summary) -> bool,
683    T: Item<Summary = S>,
684    S: Summary<Context = ()>,
685    U: Dimension<'a, T::Summary>,
686{
687    type Item = &'a T;
688
689    fn next(&mut self) -> Option<Self::Item> {
690        if let Some(item) = self.item() {
691            self.cursor.next_internal(&self.filter_node, &());
692            Some(item)
693        } else {
694            None
695        }
696    }
697}
698
699enum SeekAggregate<T: Item, D> {
700    None,
701    Slice(SumTree<T>),
702    Summary(D),
703}