cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, sync::Arc};
  4
  5#[derive(Clone)]
  6struct StackEntry<'a, T: Item, S, U> {
  7    tree: &'a SumTree<T>,
  8    index: usize,
  9    seek_dimension: S,
 10    sum_dimension: U,
 11}
 12
 13#[derive(Clone)]
 14pub struct Cursor<'a, T: Item, S, U> {
 15    tree: &'a SumTree<T>,
 16    stack: ArrayVec<[StackEntry<'a, T, S, U>; 16]>,
 17    seek_dimension: S,
 18    sum_dimension: U,
 19    did_seek: bool,
 20    at_end: bool,
 21}
 22
 23impl<'a, T, S, U> Cursor<'a, T, S, U>
 24where
 25    T: Item,
 26    S: Dimension<'a, T::Summary>,
 27    U: Dimension<'a, T::Summary>,
 28{
 29    pub fn new(tree: &'a SumTree<T>) -> Self {
 30        Self {
 31            tree,
 32            stack: ArrayVec::new(),
 33            seek_dimension: S::default(),
 34            sum_dimension: U::default(),
 35            did_seek: false,
 36            at_end: false,
 37        }
 38    }
 39
 40    fn reset(&mut self) {
 41        self.did_seek = false;
 42        self.at_end = false;
 43        self.stack.truncate(0);
 44        self.seek_dimension = S::default();
 45        self.sum_dimension = U::default();
 46    }
 47
 48    pub fn start(&self) -> &U {
 49        &self.sum_dimension
 50    }
 51
 52    pub fn end(&self) -> U {
 53        if let Some(item_summary) = self.item_summary() {
 54            let mut end = self.start().clone();
 55            end.add_summary(item_summary);
 56            end
 57        } else {
 58            self.start().clone()
 59        }
 60    }
 61
 62    pub fn item(&self) -> Option<&'a T> {
 63        assert!(self.did_seek, "Must seek before calling this method");
 64        if let Some(entry) = self.stack.last() {
 65            match *entry.tree.0 {
 66                Node::Leaf { ref items, .. } => {
 67                    if entry.index == items.len() {
 68                        None
 69                    } else {
 70                        Some(&items[entry.index])
 71                    }
 72                }
 73                _ => unreachable!(),
 74            }
 75        } else {
 76            None
 77        }
 78    }
 79
 80    pub fn item_summary(&self) -> Option<&'a T::Summary> {
 81        assert!(self.did_seek, "Must seek before calling this method");
 82        if let Some(entry) = self.stack.last() {
 83            match *entry.tree.0 {
 84                Node::Leaf {
 85                    ref item_summaries, ..
 86                } => {
 87                    if entry.index == item_summaries.len() {
 88                        None
 89                    } else {
 90                        Some(&item_summaries[entry.index])
 91                    }
 92                }
 93                _ => unreachable!(),
 94            }
 95        } else {
 96            None
 97        }
 98    }
 99
100    pub fn prev_item(&self) -> Option<&'a T> {
101        assert!(self.did_seek, "Must seek before calling this method");
102        if let Some(entry) = self.stack.last() {
103            if entry.index == 0 {
104                if let Some(prev_leaf) = self.prev_leaf() {
105                    Some(prev_leaf.0.items().last().unwrap())
106                } else {
107                    None
108                }
109            } else {
110                match *entry.tree.0 {
111                    Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
112                    _ => unreachable!(),
113                }
114            }
115        } else if self.at_end {
116            self.tree.last()
117        } else {
118            None
119        }
120    }
121
122    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
123        for entry in self.stack.iter().rev().skip(1) {
124            if entry.index != 0 {
125                match *entry.tree.0 {
126                    Node::Internal {
127                        ref child_trees, ..
128                    } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
129                    Node::Leaf { .. } => unreachable!(),
130                };
131            }
132        }
133        None
134    }
135
136    #[allow(unused)]
137    pub fn prev(&mut self) {
138        assert!(self.did_seek, "Must seek before calling this method");
139
140        if self.at_end {
141            self.seek_dimension = S::default();
142            self.sum_dimension = U::default();
143            self.descend_to_last_item(self.tree);
144            self.at_end = false;
145        } else {
146            while let Some(entry) = self.stack.pop() {
147                if entry.index > 0 {
148                    let new_index = entry.index - 1;
149
150                    if let Some(StackEntry {
151                        seek_dimension,
152                        sum_dimension,
153                        ..
154                    }) = self.stack.last()
155                    {
156                        self.seek_dimension = seek_dimension.clone();
157                        self.sum_dimension = sum_dimension.clone();
158                    } else {
159                        self.seek_dimension = S::default();
160                        self.sum_dimension = U::default();
161                    }
162
163                    match entry.tree.0.as_ref() {
164                        Node::Internal {
165                            child_trees,
166                            child_summaries,
167                            ..
168                        } => {
169                            for summary in &child_summaries[0..new_index] {
170                                self.seek_dimension.add_summary(summary);
171                                self.sum_dimension.add_summary(summary);
172                            }
173                            self.stack.push(StackEntry {
174                                tree: entry.tree,
175                                index: new_index,
176                                seek_dimension: self.seek_dimension.clone(),
177                                sum_dimension: self.sum_dimension.clone(),
178                            });
179                            self.descend_to_last_item(&child_trees[new_index]);
180                        }
181                        Node::Leaf { item_summaries, .. } => {
182                            for item_summary in &item_summaries[0..new_index] {
183                                self.seek_dimension.add_summary(item_summary);
184                                self.sum_dimension.add_summary(item_summary);
185                            }
186                            self.stack.push(StackEntry {
187                                tree: entry.tree,
188                                index: new_index,
189                                seek_dimension: self.seek_dimension.clone(),
190                                sum_dimension: self.sum_dimension.clone(),
191                            });
192                        }
193                    }
194
195                    break;
196                }
197            }
198        }
199    }
200
201    pub fn next(&mut self) {
202        self.next_internal(|_| true)
203    }
204
205    fn next_internal<F>(&mut self, filter_node: F)
206    where
207        F: Fn(&T::Summary) -> bool,
208    {
209        let mut descend = false;
210
211        if self.stack.is_empty() && !self.at_end {
212            self.stack.push(StackEntry {
213                tree: self.tree,
214                index: 0,
215                seek_dimension: S::default(),
216                sum_dimension: U::default(),
217            });
218            descend = true;
219            self.did_seek = true;
220        }
221
222        while self.stack.len() > 0 {
223            let new_subtree = {
224                let entry = self.stack.last_mut().unwrap();
225                match entry.tree.0.as_ref() {
226                    Node::Internal {
227                        child_trees,
228                        child_summaries,
229                        ..
230                    } => {
231                        if !descend {
232                            let summary = &child_summaries[entry.index];
233                            entry.seek_dimension.add_summary(summary);
234                            entry.sum_dimension.add_summary(summary);
235                            entry.index += 1;
236                        }
237
238                        while entry.index < child_summaries.len() {
239                            let next_summary = &child_summaries[entry.index];
240                            if filter_node(next_summary) {
241                                break;
242                            } else {
243                                self.seek_dimension.add_summary(next_summary);
244                                self.sum_dimension.add_summary(next_summary);
245                            }
246                            entry.index += 1;
247                        }
248
249                        child_trees.get(entry.index)
250                    }
251                    Node::Leaf { item_summaries, .. } => {
252                        if !descend {
253                            let item_summary = &item_summaries[entry.index];
254                            self.seek_dimension.add_summary(item_summary);
255                            entry.seek_dimension.add_summary(item_summary);
256                            self.sum_dimension.add_summary(item_summary);
257                            entry.sum_dimension.add_summary(item_summary);
258                            entry.index += 1;
259                        }
260
261                        loop {
262                            if let Some(next_item_summary) = item_summaries.get(entry.index) {
263                                if filter_node(next_item_summary) {
264                                    return;
265                                } else {
266                                    self.seek_dimension.add_summary(next_item_summary);
267                                    entry.seek_dimension.add_summary(next_item_summary);
268                                    self.sum_dimension.add_summary(next_item_summary);
269                                    entry.sum_dimension.add_summary(next_item_summary);
270                                    entry.index += 1;
271                                }
272                            } else {
273                                break None;
274                            }
275                        }
276                    }
277                }
278            };
279
280            if let Some(subtree) = new_subtree {
281                descend = true;
282                self.stack.push(StackEntry {
283                    tree: subtree,
284                    index: 0,
285                    seek_dimension: self.seek_dimension.clone(),
286                    sum_dimension: self.sum_dimension.clone(),
287                });
288            } else {
289                descend = false;
290                self.stack.pop();
291            }
292        }
293
294        self.at_end = self.stack.is_empty();
295        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
296    }
297
298    fn descend_to_last_item(&mut self, mut subtree: &'a SumTree<T>) {
299        self.did_seek = true;
300        loop {
301            match subtree.0.as_ref() {
302                Node::Internal {
303                    child_trees,
304                    child_summaries,
305                    ..
306                } => {
307                    for summary in &child_summaries[0..child_summaries.len() - 1] {
308                        self.seek_dimension.add_summary(summary);
309                        self.sum_dimension.add_summary(summary);
310                    }
311
312                    self.stack.push(StackEntry {
313                        tree: subtree,
314                        index: child_trees.len() - 1,
315                        seek_dimension: self.seek_dimension.clone(),
316                        sum_dimension: self.sum_dimension.clone(),
317                    });
318                    subtree = child_trees.last().unwrap();
319                }
320                Node::Leaf { item_summaries, .. } => {
321                    let last_index = item_summaries.len().saturating_sub(1);
322                    for item_summary in &item_summaries[0..last_index] {
323                        self.seek_dimension.add_summary(item_summary);
324                        self.sum_dimension.add_summary(item_summary);
325                    }
326                    self.stack.push(StackEntry {
327                        tree: subtree,
328                        index: last_index,
329                        seek_dimension: self.seek_dimension.clone(),
330                        sum_dimension: self.sum_dimension.clone(),
331                    });
332                    break;
333                }
334            }
335        }
336    }
337}
338
339impl<'a, T, S, U> Cursor<'a, T, S, U>
340where
341    T: Item,
342    S: SeekDimension<'a, T::Summary>,
343    U: Dimension<'a, T::Summary>,
344{
345    pub fn seek(
346        &mut self,
347        pos: &S,
348        bias: SeekBias,
349        ctx: &<T::Summary as Summary>::Context,
350    ) -> bool {
351        self.reset();
352        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None, ctx)
353    }
354
355    pub fn seek_forward(
356        &mut self,
357        pos: &S,
358        bias: SeekBias,
359        ctx: &<T::Summary as Summary>::Context,
360    ) -> bool {
361        self.seek_internal::<()>(pos, bias, &mut SeekAggregate::None, ctx)
362    }
363
364    pub fn slice(
365        &mut self,
366        end: &S,
367        bias: SeekBias,
368        ctx: &<T::Summary as Summary>::Context,
369    ) -> SumTree<T> {
370        let mut slice = SeekAggregate::Slice(SumTree::new());
371        self.seek_internal::<()>(end, bias, &mut slice, ctx);
372        if let SeekAggregate::Slice(slice) = slice {
373            slice
374        } else {
375            unreachable!()
376        }
377    }
378
379    pub fn suffix(&mut self, ctx: &<T::Summary as Summary>::Context) -> SumTree<T> {
380        let extent = self.tree.extent::<S>();
381        let mut slice = SeekAggregate::Slice(SumTree::new());
382        self.seek_internal::<()>(&extent, SeekBias::Right, &mut slice, ctx);
383        if let SeekAggregate::Slice(slice) = slice {
384            slice
385        } else {
386            unreachable!()
387        }
388    }
389
390    pub fn summary<D>(
391        &mut self,
392        end: &S,
393        bias: SeekBias,
394        ctx: &<T::Summary as Summary>::Context,
395    ) -> D
396    where
397        D: Dimension<'a, T::Summary>,
398    {
399        let mut summary = SeekAggregate::Summary(D::default());
400        self.seek_internal(end, bias, &mut summary, ctx);
401        if let SeekAggregate::Summary(summary) = summary {
402            summary
403        } else {
404            unreachable!()
405        }
406    }
407
408    fn seek_internal<D>(
409        &mut self,
410        target: &S,
411        bias: SeekBias,
412        aggregate: &mut SeekAggregate<T, D>,
413        ctx: &<T::Summary as Summary>::Context,
414    ) -> bool
415    where
416        D: Dimension<'a, T::Summary>,
417    {
418        debug_assert!(target.cmp(&self.seek_dimension, ctx) >= Ordering::Equal);
419        let mut containing_subtree = None;
420
421        if self.did_seek {
422            'outer: while let Some(entry) = self.stack.last_mut() {
423                {
424                    match *entry.tree.0 {
425                        Node::Internal {
426                            ref child_summaries,
427                            ref child_trees,
428                            ..
429                        } => {
430                            entry.index += 1;
431                            for (child_tree, child_summary) in child_trees[entry.index..]
432                                .iter()
433                                .zip(&child_summaries[entry.index..])
434                            {
435                                let mut child_end = self.seek_dimension.clone();
436                                child_end.add_summary(&child_summary);
437
438                                let comparison = target.cmp(&child_end, ctx);
439                                if comparison == Ordering::Greater
440                                    || (comparison == Ordering::Equal && bias == SeekBias::Right)
441                                {
442                                    self.seek_dimension.add_summary(child_summary);
443                                    self.sum_dimension.add_summary(child_summary);
444                                    match aggregate {
445                                        SeekAggregate::None => {}
446                                        SeekAggregate::Slice(slice) => {
447                                            slice.push_tree(child_tree.clone(), ctx);
448                                        }
449                                        SeekAggregate::Summary(summary) => {
450                                            summary.add_summary(child_summary);
451                                        }
452                                    }
453                                    entry.index += 1;
454                                } else {
455                                    containing_subtree = Some(child_tree);
456                                    break 'outer;
457                                }
458                            }
459                        }
460                        Node::Leaf {
461                            ref items,
462                            ref item_summaries,
463                            ..
464                        } => {
465                            let mut slice_items = ArrayVec::<[T; 2 * TREE_BASE]>::new();
466                            let mut slice_item_summaries =
467                                ArrayVec::<[T::Summary; 2 * TREE_BASE]>::new();
468                            let mut slice_items_summary = match aggregate {
469                                SeekAggregate::Slice(_) => Some(T::Summary::default()),
470                                _ => None,
471                            };
472
473                            for (item, item_summary) in items[entry.index..]
474                                .iter()
475                                .zip(&item_summaries[entry.index..])
476                            {
477                                let mut item_end = self.seek_dimension.clone();
478                                item_end.add_summary(item_summary);
479
480                                let comparison = target.cmp(&item_end, ctx);
481                                if comparison == Ordering::Greater
482                                    || (comparison == Ordering::Equal && bias == SeekBias::Right)
483                                {
484                                    self.seek_dimension.add_summary(item_summary);
485                                    self.sum_dimension.add_summary(item_summary);
486                                    match aggregate {
487                                        SeekAggregate::None => {}
488                                        SeekAggregate::Slice(_) => {
489                                            slice_items.push(item.clone());
490                                            slice_item_summaries.push(item_summary.clone());
491                                            slice_items_summary
492                                                .as_mut()
493                                                .unwrap()
494                                                .add_summary(item_summary, ctx);
495                                        }
496                                        SeekAggregate::Summary(summary) => {
497                                            summary.add_summary(item_summary);
498                                        }
499                                    }
500                                    entry.index += 1;
501                                } else {
502                                    if let SeekAggregate::Slice(slice) = aggregate {
503                                        slice.push_tree(
504                                            SumTree(Arc::new(Node::Leaf {
505                                                summary: slice_items_summary.unwrap(),
506                                                items: slice_items,
507                                                item_summaries: slice_item_summaries,
508                                            })),
509                                            ctx,
510                                        );
511                                    }
512                                    break 'outer;
513                                }
514                            }
515
516                            if let SeekAggregate::Slice(slice) = aggregate {
517                                if !slice_items.is_empty() {
518                                    slice.push_tree(
519                                        SumTree(Arc::new(Node::Leaf {
520                                            summary: slice_items_summary.unwrap(),
521                                            items: slice_items,
522                                            item_summaries: slice_item_summaries,
523                                        })),
524                                        ctx,
525                                    );
526                                }
527                            }
528                        }
529                    }
530                }
531
532                self.stack.pop();
533            }
534        } else {
535            self.did_seek = true;
536            containing_subtree = Some(self.tree);
537        }
538
539        if let Some(mut subtree) = containing_subtree {
540            loop {
541                let mut next_subtree = None;
542                match *subtree.0 {
543                    Node::Internal {
544                        ref child_summaries,
545                        ref child_trees,
546                        ..
547                    } => {
548                        for (index, (child_tree, child_summary)) in
549                            child_trees.iter().zip(child_summaries).enumerate()
550                        {
551                            let mut child_end = self.seek_dimension.clone();
552                            child_end.add_summary(child_summary);
553
554                            let comparison = target.cmp(&child_end, ctx);
555                            if comparison == Ordering::Greater
556                                || (comparison == Ordering::Equal && bias == SeekBias::Right)
557                            {
558                                self.seek_dimension.add_summary(child_summary);
559                                self.sum_dimension.add_summary(child_summary);
560                                match aggregate {
561                                    SeekAggregate::None => {}
562                                    SeekAggregate::Slice(slice) => {
563                                        slice.push_tree(child_trees[index].clone(), ctx);
564                                    }
565                                    SeekAggregate::Summary(summary) => {
566                                        summary.add_summary(child_summary);
567                                    }
568                                }
569                            } else {
570                                self.stack.push(StackEntry {
571                                    tree: subtree,
572                                    index,
573                                    seek_dimension: self.seek_dimension.clone(),
574                                    sum_dimension: self.sum_dimension.clone(),
575                                });
576                                next_subtree = Some(child_tree);
577                                break;
578                            }
579                        }
580                    }
581                    Node::Leaf {
582                        ref items,
583                        ref item_summaries,
584                        ..
585                    } => {
586                        let mut slice_items = ArrayVec::<[T; 2 * TREE_BASE]>::new();
587                        let mut slice_item_summaries =
588                            ArrayVec::<[T::Summary; 2 * TREE_BASE]>::new();
589                        let mut slice_items_summary = match aggregate {
590                            SeekAggregate::Slice(_) => Some(T::Summary::default()),
591                            _ => None,
592                        };
593
594                        for (index, (item, item_summary)) in
595                            items.iter().zip(item_summaries).enumerate()
596                        {
597                            let mut child_end = self.seek_dimension.clone();
598                            child_end.add_summary(item_summary);
599
600                            let comparison = target.cmp(&child_end, ctx);
601                            if comparison == Ordering::Greater
602                                || (comparison == Ordering::Equal && bias == SeekBias::Right)
603                            {
604                                self.seek_dimension.add_summary(item_summary);
605                                self.sum_dimension.add_summary(item_summary);
606                                match aggregate {
607                                    SeekAggregate::None => {}
608                                    SeekAggregate::Slice(_) => {
609                                        slice_items.push(item.clone());
610                                        slice_items_summary
611                                            .as_mut()
612                                            .unwrap()
613                                            .add_summary(item_summary, ctx);
614                                        slice_item_summaries.push(item_summary.clone());
615                                    }
616                                    SeekAggregate::Summary(summary) => {
617                                        summary.add_summary(item_summary);
618                                    }
619                                }
620                            } else {
621                                self.stack.push(StackEntry {
622                                    tree: subtree,
623                                    index,
624                                    seek_dimension: self.seek_dimension.clone(),
625                                    sum_dimension: self.sum_dimension.clone(),
626                                });
627                                break;
628                            }
629                        }
630
631                        if let SeekAggregate::Slice(slice) = aggregate {
632                            if !slice_items.is_empty() {
633                                slice.push_tree(
634                                    SumTree(Arc::new(Node::Leaf {
635                                        summary: slice_items_summary.unwrap(),
636                                        items: slice_items,
637                                        item_summaries: slice_item_summaries,
638                                    })),
639                                    ctx,
640                                );
641                            }
642                        }
643                    }
644                };
645
646                if let Some(next_subtree) = next_subtree {
647                    subtree = next_subtree;
648                } else {
649                    break;
650                }
651            }
652        }
653
654        self.at_end = self.stack.is_empty();
655        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
656        if bias == SeekBias::Left {
657            let mut end = self.seek_dimension.clone();
658            if let Some(summary) = self.item_summary() {
659                end.add_summary(summary);
660            }
661            target.cmp(&end, ctx) == Ordering::Equal
662        } else {
663            target.cmp(&self.seek_dimension, ctx) == Ordering::Equal
664        }
665    }
666}
667
668impl<'a, T, S, U> Iterator for Cursor<'a, T, S, U>
669where
670    T: Item,
671    S: Dimension<'a, T::Summary>,
672    U: Dimension<'a, T::Summary>,
673{
674    type Item = &'a T;
675
676    fn next(&mut self) -> Option<Self::Item> {
677        if !self.did_seek {
678            self.next();
679        }
680
681        if let Some(item) = self.item() {
682            self.next();
683            Some(item)
684        } else {
685            None
686        }
687    }
688}
689
690pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
691    cursor: Cursor<'a, T, (), U>,
692    filter_node: F,
693}
694
695impl<'a, F, T, U> FilterCursor<'a, F, T, U>
696where
697    F: Fn(&T::Summary) -> bool,
698    T: Item,
699    U: Dimension<'a, T::Summary>,
700{
701    pub fn new(tree: &'a SumTree<T>, filter_node: F) -> Self {
702        let mut cursor = tree.cursor::<(), U>();
703        cursor.next_internal(&filter_node);
704        Self {
705            cursor,
706            filter_node,
707        }
708    }
709
710    pub fn start(&self) -> &U {
711        self.cursor.start()
712    }
713
714    pub fn item(&self) -> Option<&'a T> {
715        self.cursor.item()
716    }
717
718    pub fn next(&mut self) {
719        self.cursor.next_internal(&self.filter_node);
720    }
721}
722
723impl<'a, F, T, U> Iterator for FilterCursor<'a, F, T, U>
724where
725    F: Fn(&T::Summary) -> bool,
726    T: Item,
727    U: Dimension<'a, T::Summary>,
728{
729    type Item = &'a T;
730
731    fn next(&mut self) -> Option<Self::Item> {
732        if let Some(item) = self.item() {
733            self.cursor.next_internal(&self.filter_node);
734            Some(item)
735        } else {
736            None
737        }
738    }
739}
740
741enum SeekAggregate<T: Item, D> {
742    None,
743    Slice(SumTree<T>),
744    Summary(D),
745}