cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, sync::Arc};
  4
  5#[derive(Clone)]
  6struct StackEntry<'a, T: Item, S, U> {
  7    tree: &'a SumTree<T>,
  8    index: usize,
  9    seek_dimension: S,
 10    sum_dimension: U,
 11}
 12
 13impl<'a, T, S, U> StackEntry<'a, T, S, U>
 14where
 15    T: Item,
 16    S: SeekDimension<'a, T::Summary>,
 17    U: SeekDimension<'a, T::Summary>,
 18{
 19    fn swap_dimensions(self) -> StackEntry<'a, T, U, S> {
 20        StackEntry {
 21            tree: self.tree,
 22            index: self.index,
 23            seek_dimension: self.sum_dimension,
 24            sum_dimension: self.seek_dimension,
 25        }
 26    }
 27}
 28
 29#[derive(Clone)]
 30pub struct Cursor<'a, T: Item, S, U> {
 31    tree: &'a SumTree<T>,
 32    stack: ArrayVec<StackEntry<'a, T, S, U>, 16>,
 33    seek_dimension: S,
 34    sum_dimension: U,
 35    did_seek: bool,
 36    at_end: bool,
 37}
 38
 39impl<'a, T, S, U> Cursor<'a, T, S, U>
 40where
 41    T: Item,
 42    S: Dimension<'a, T::Summary>,
 43    U: Dimension<'a, T::Summary>,
 44{
 45    pub fn new(tree: &'a SumTree<T>) -> Self {
 46        Self {
 47            tree,
 48            stack: ArrayVec::new(),
 49            seek_dimension: S::default(),
 50            sum_dimension: U::default(),
 51            did_seek: false,
 52            at_end: false,
 53        }
 54    }
 55
 56    fn reset(&mut self) {
 57        self.did_seek = false;
 58        self.at_end = false;
 59        self.stack.truncate(0);
 60        self.seek_dimension = S::default();
 61        self.sum_dimension = U::default();
 62    }
 63
 64    pub fn seek_start(&self) -> &S {
 65        &self.seek_dimension
 66    }
 67
 68    pub fn seek_end(&self, cx: &<T::Summary as Summary>::Context) -> S {
 69        if let Some(item_summary) = self.item_summary() {
 70            let mut end = self.seek_start().clone();
 71            end.add_summary(item_summary, cx);
 72            end
 73        } else {
 74            self.seek_start().clone()
 75        }
 76    }
 77
 78    pub fn sum_start(&self) -> &U {
 79        &self.sum_dimension
 80    }
 81
 82    pub fn sum_end(&self, cx: &<T::Summary as Summary>::Context) -> U {
 83        if let Some(item_summary) = self.item_summary() {
 84            let mut end = self.sum_start().clone();
 85            end.add_summary(item_summary, cx);
 86            end
 87        } else {
 88            self.sum_start().clone()
 89        }
 90    }
 91
 92    pub fn item(&self) -> Option<&'a T> {
 93        assert!(self.did_seek, "Must seek before calling this method");
 94        if let Some(entry) = self.stack.last() {
 95            match *entry.tree.0 {
 96                Node::Leaf { ref items, .. } => {
 97                    if entry.index == items.len() {
 98                        None
 99                    } else {
100                        Some(&items[entry.index])
101                    }
102                }
103                _ => unreachable!(),
104            }
105        } else {
106            None
107        }
108    }
109
110    pub fn item_summary(&self) -> Option<&'a T::Summary> {
111        assert!(self.did_seek, "Must seek before calling this method");
112        if let Some(entry) = self.stack.last() {
113            match *entry.tree.0 {
114                Node::Leaf {
115                    ref item_summaries, ..
116                } => {
117                    if entry.index == item_summaries.len() {
118                        None
119                    } else {
120                        Some(&item_summaries[entry.index])
121                    }
122                }
123                _ => unreachable!(),
124            }
125        } else {
126            None
127        }
128    }
129
130    pub fn prev_item(&self) -> Option<&'a T> {
131        assert!(self.did_seek, "Must seek before calling this method");
132        if let Some(entry) = self.stack.last() {
133            if entry.index == 0 {
134                if let Some(prev_leaf) = self.prev_leaf() {
135                    Some(prev_leaf.0.items().last().unwrap())
136                } else {
137                    None
138                }
139            } else {
140                match *entry.tree.0 {
141                    Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
142                    _ => unreachable!(),
143                }
144            }
145        } else if self.at_end {
146            self.tree.last()
147        } else {
148            None
149        }
150    }
151
152    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
153        for entry in self.stack.iter().rev().skip(1) {
154            if entry.index != 0 {
155                match *entry.tree.0 {
156                    Node::Internal {
157                        ref child_trees, ..
158                    } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
159                    Node::Leaf { .. } => unreachable!(),
160                };
161            }
162        }
163        None
164    }
165
166    pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) {
167        assert!(self.did_seek, "Must seek before calling this method");
168
169        if self.at_end {
170            self.seek_dimension = S::default();
171            self.sum_dimension = U::default();
172            self.descend_to_last_item(self.tree, cx);
173            self.at_end = false;
174        } else {
175            while let Some(entry) = self.stack.pop() {
176                if entry.index > 0 {
177                    let new_index = entry.index - 1;
178
179                    if let Some(StackEntry {
180                        seek_dimension,
181                        sum_dimension,
182                        ..
183                    }) = self.stack.last()
184                    {
185                        self.seek_dimension = seek_dimension.clone();
186                        self.sum_dimension = sum_dimension.clone();
187                    } else {
188                        self.seek_dimension = S::default();
189                        self.sum_dimension = U::default();
190                    }
191
192                    match entry.tree.0.as_ref() {
193                        Node::Internal {
194                            child_trees,
195                            child_summaries,
196                            ..
197                        } => {
198                            for summary in &child_summaries[0..new_index] {
199                                self.seek_dimension.add_summary(summary, cx);
200                                self.sum_dimension.add_summary(summary, cx);
201                            }
202                            self.stack.push(StackEntry {
203                                tree: entry.tree,
204                                index: new_index,
205                                seek_dimension: self.seek_dimension.clone(),
206                                sum_dimension: self.sum_dimension.clone(),
207                            });
208                            self.descend_to_last_item(&child_trees[new_index], cx);
209                        }
210                        Node::Leaf { item_summaries, .. } => {
211                            for item_summary in &item_summaries[0..new_index] {
212                                self.seek_dimension.add_summary(item_summary, cx);
213                                self.sum_dimension.add_summary(item_summary, cx);
214                            }
215                            self.stack.push(StackEntry {
216                                tree: entry.tree,
217                                index: new_index,
218                                seek_dimension: self.seek_dimension.clone(),
219                                sum_dimension: self.sum_dimension.clone(),
220                            });
221                        }
222                    }
223
224                    break;
225                }
226            }
227        }
228    }
229
230    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
231        self.next_internal(|_| true, cx)
232    }
233
234    fn next_internal<F>(&mut self, filter_node: F, cx: &<T::Summary as Summary>::Context)
235    where
236        F: Fn(&T::Summary) -> bool,
237    {
238        let mut descend = false;
239
240        if self.stack.is_empty() && !self.at_end {
241            self.stack.push(StackEntry {
242                tree: self.tree,
243                index: 0,
244                seek_dimension: S::default(),
245                sum_dimension: U::default(),
246            });
247            descend = true;
248            self.did_seek = true;
249        }
250
251        while self.stack.len() > 0 {
252            let new_subtree = {
253                let entry = self.stack.last_mut().unwrap();
254                match entry.tree.0.as_ref() {
255                    Node::Internal {
256                        child_trees,
257                        child_summaries,
258                        ..
259                    } => {
260                        if !descend {
261                            entry.seek_dimension = self.seek_dimension.clone();
262                            entry.sum_dimension = self.sum_dimension.clone();
263                            entry.index += 1;
264                        }
265
266                        while entry.index < child_summaries.len() {
267                            let next_summary = &child_summaries[entry.index];
268                            if filter_node(next_summary) {
269                                break;
270                            } else {
271                                self.seek_dimension.add_summary(next_summary, cx);
272                                self.sum_dimension.add_summary(next_summary, cx);
273                            }
274                            entry.index += 1;
275                        }
276
277                        child_trees.get(entry.index)
278                    }
279                    Node::Leaf { item_summaries, .. } => {
280                        if !descend {
281                            let item_summary = &item_summaries[entry.index];
282                            self.seek_dimension.add_summary(item_summary, cx);
283                            entry.seek_dimension.add_summary(item_summary, cx);
284                            self.sum_dimension.add_summary(item_summary, cx);
285                            entry.sum_dimension.add_summary(item_summary, cx);
286                            entry.index += 1;
287                        }
288
289                        loop {
290                            if let Some(next_item_summary) = item_summaries.get(entry.index) {
291                                if filter_node(next_item_summary) {
292                                    return;
293                                } else {
294                                    self.seek_dimension.add_summary(next_item_summary, cx);
295                                    entry.seek_dimension.add_summary(next_item_summary, cx);
296                                    self.sum_dimension.add_summary(next_item_summary, cx);
297                                    entry.sum_dimension.add_summary(next_item_summary, cx);
298                                    entry.index += 1;
299                                }
300                            } else {
301                                break None;
302                            }
303                        }
304                    }
305                }
306            };
307
308            if let Some(subtree) = new_subtree {
309                descend = true;
310                self.stack.push(StackEntry {
311                    tree: subtree,
312                    index: 0,
313                    seek_dimension: self.seek_dimension.clone(),
314                    sum_dimension: self.sum_dimension.clone(),
315                });
316            } else {
317                descend = false;
318                self.stack.pop();
319            }
320        }
321
322        self.at_end = self.stack.is_empty();
323        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
324    }
325
326    fn descend_to_last_item(
327        &mut self,
328        mut subtree: &'a SumTree<T>,
329        cx: &<T::Summary as Summary>::Context,
330    ) {
331        self.did_seek = true;
332        loop {
333            match subtree.0.as_ref() {
334                Node::Internal {
335                    child_trees,
336                    child_summaries,
337                    ..
338                } => {
339                    for summary in &child_summaries[0..child_summaries.len() - 1] {
340                        self.seek_dimension.add_summary(summary, cx);
341                        self.sum_dimension.add_summary(summary, cx);
342                    }
343
344                    self.stack.push(StackEntry {
345                        tree: subtree,
346                        index: child_trees.len() - 1,
347                        seek_dimension: self.seek_dimension.clone(),
348                        sum_dimension: self.sum_dimension.clone(),
349                    });
350                    subtree = child_trees.last().unwrap();
351                }
352                Node::Leaf { item_summaries, .. } => {
353                    let last_index = item_summaries.len().saturating_sub(1);
354                    for item_summary in &item_summaries[0..last_index] {
355                        self.seek_dimension.add_summary(item_summary, cx);
356                        self.sum_dimension.add_summary(item_summary, cx);
357                    }
358                    self.stack.push(StackEntry {
359                        tree: subtree,
360                        index: last_index,
361                        seek_dimension: self.seek_dimension.clone(),
362                        sum_dimension: self.sum_dimension.clone(),
363                    });
364                    break;
365                }
366            }
367        }
368    }
369}
370
371impl<'a, T, S, U> Cursor<'a, T, S, U>
372where
373    T: Item,
374    S: SeekDimension<'a, T::Summary>,
375    U: Dimension<'a, T::Summary>,
376{
377    pub fn seek(&mut self, pos: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> bool {
378        self.reset();
379        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
380    }
381
382    pub fn seek_forward(
383        &mut self,
384        pos: &S,
385        bias: Bias,
386        cx: &<T::Summary as Summary>::Context,
387    ) -> bool {
388        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
389    }
390
391    pub fn slice(
392        &mut self,
393        end: &S,
394        bias: Bias,
395        cx: &<T::Summary as Summary>::Context,
396    ) -> SumTree<T> {
397        let mut slice = SeekAggregate::Slice(SumTree::new());
398        self.seek_internal::<()>(Some(end), bias, &mut slice, cx);
399        if let SeekAggregate::Slice(slice) = slice {
400            slice
401        } else {
402            unreachable!()
403        }
404    }
405
406    pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
407        let mut slice = SeekAggregate::Slice(SumTree::new());
408        self.seek_internal::<()>(None, Bias::Right, &mut slice, cx);
409        if let SeekAggregate::Slice(slice) = slice {
410            slice
411        } else {
412            unreachable!()
413        }
414    }
415
416    pub fn summary<D>(&mut self, end: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> D
417    where
418        D: Dimension<'a, T::Summary>,
419    {
420        let mut summary = SeekAggregate::Summary(D::default());
421        self.seek_internal(Some(end), bias, &mut summary, cx);
422        if let SeekAggregate::Summary(summary) = summary {
423            summary
424        } else {
425            unreachable!()
426        }
427    }
428
429    fn seek_internal<D>(
430        &mut self,
431        target: Option<&S>,
432        bias: Bias,
433        aggregate: &mut SeekAggregate<T, D>,
434        cx: &<T::Summary as Summary>::Context,
435    ) -> bool
436    where
437        D: Dimension<'a, T::Summary>,
438    {
439        if let Some(target) = target {
440            debug_assert!(
441                target.cmp(&self.seek_dimension, cx) >= Ordering::Equal,
442                "cannot seek backward from {:?} to {:?}",
443                self.seek_dimension,
444                target
445            );
446        }
447
448        if !self.did_seek {
449            self.did_seek = true;
450            self.stack.push(StackEntry {
451                tree: self.tree,
452                index: 0,
453                seek_dimension: Default::default(),
454                sum_dimension: Default::default(),
455            });
456        }
457
458        let mut ascending = false;
459        'outer: while let Some(entry) = self.stack.last_mut() {
460            match *entry.tree.0 {
461                Node::Internal {
462                    ref child_summaries,
463                    ref child_trees,
464                    ..
465                } => {
466                    if ascending {
467                        entry.index += 1;
468                    }
469
470                    for (child_tree, child_summary) in child_trees[entry.index..]
471                        .iter()
472                        .zip(&child_summaries[entry.index..])
473                    {
474                        let mut child_end = self.seek_dimension.clone();
475                        child_end.add_summary(&child_summary, cx);
476
477                        let comparison =
478                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
479                        if comparison == Ordering::Greater
480                            || (comparison == Ordering::Equal && bias == Bias::Right)
481                        {
482                            self.seek_dimension = child_end;
483                            self.sum_dimension.add_summary(child_summary, cx);
484                            match aggregate {
485                                SeekAggregate::None => {}
486                                SeekAggregate::Slice(slice) => {
487                                    slice.push_tree(child_tree.clone(), cx);
488                                }
489                                SeekAggregate::Summary(summary) => {
490                                    summary.add_summary(child_summary, cx);
491                                }
492                            }
493                            entry.index += 1;
494                            entry.seek_dimension = self.seek_dimension.clone();
495                            entry.sum_dimension = self.sum_dimension.clone();
496                        } else {
497                            self.stack.push(StackEntry {
498                                tree: child_tree,
499                                index: 0,
500                                seek_dimension: self.seek_dimension.clone(),
501                                sum_dimension: self.sum_dimension.clone(),
502                            });
503                            ascending = false;
504                            continue 'outer;
505                        }
506                    }
507                }
508                Node::Leaf {
509                    ref items,
510                    ref item_summaries,
511                    ..
512                } => {
513                    let mut slice_items = ArrayVec::<T, { 2 * TREE_BASE }>::new();
514                    let mut slice_item_summaries = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
515                    let mut slice_items_summary = match aggregate {
516                        SeekAggregate::Slice(_) => Some(T::Summary::default()),
517                        _ => None,
518                    };
519
520                    for (item, item_summary) in items[entry.index..]
521                        .iter()
522                        .zip(&item_summaries[entry.index..])
523                    {
524                        let mut child_end = self.seek_dimension.clone();
525                        child_end.add_summary(item_summary, cx);
526
527                        let comparison =
528                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
529                        if comparison == Ordering::Greater
530                            || (comparison == Ordering::Equal && bias == Bias::Right)
531                        {
532                            self.seek_dimension = child_end;
533                            self.sum_dimension.add_summary(item_summary, cx);
534                            match aggregate {
535                                SeekAggregate::None => {}
536                                SeekAggregate::Slice(_) => {
537                                    slice_items.push(item.clone());
538                                    slice_item_summaries.push(item_summary.clone());
539                                    slice_items_summary
540                                        .as_mut()
541                                        .unwrap()
542                                        .add_summary(item_summary, cx);
543                                }
544                                SeekAggregate::Summary(summary) => {
545                                    summary.add_summary(item_summary, cx);
546                                }
547                            }
548                            entry.index += 1;
549                        } else {
550                            if let SeekAggregate::Slice(slice) = aggregate {
551                                slice.push_tree(
552                                    SumTree(Arc::new(Node::Leaf {
553                                        summary: slice_items_summary.unwrap(),
554                                        items: slice_items,
555                                        item_summaries: slice_item_summaries,
556                                    })),
557                                    cx,
558                                );
559                            }
560                            break 'outer;
561                        }
562                    }
563
564                    if let SeekAggregate::Slice(slice) = aggregate {
565                        if !slice_items.is_empty() {
566                            slice.push_tree(
567                                SumTree(Arc::new(Node::Leaf {
568                                    summary: slice_items_summary.unwrap(),
569                                    items: slice_items,
570                                    item_summaries: slice_item_summaries,
571                                })),
572                                cx,
573                            );
574                        }
575                    }
576                }
577            }
578
579            self.stack.pop();
580            ascending = true;
581        }
582
583        self.at_end = self.stack.is_empty();
584        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
585
586        let mut end = self.seek_dimension.clone();
587        if bias == Bias::Left {
588            if let Some(summary) = self.item_summary() {
589                end.add_summary(summary, cx);
590            }
591        }
592
593        target.map_or(false, |t| t.cmp(&end, cx) == Ordering::Equal)
594    }
595}
596
597impl<'a, T, S, Seek, Sum> Iterator for Cursor<'a, T, Seek, Sum>
598where
599    T: Item<Summary = S>,
600    S: Summary<Context = ()>,
601    Seek: Dimension<'a, T::Summary>,
602    Sum: Dimension<'a, T::Summary>,
603{
604    type Item = &'a T;
605
606    fn next(&mut self) -> Option<Self::Item> {
607        if !self.did_seek {
608            self.next(&());
609        }
610
611        if let Some(item) = self.item() {
612            self.next(&());
613            Some(item)
614        } else {
615            None
616        }
617    }
618}
619
620impl<'a, T, S, U> Cursor<'a, T, S, U>
621where
622    T: Item,
623    S: SeekDimension<'a, T::Summary>,
624    U: SeekDimension<'a, T::Summary>,
625{
626    pub fn swap_dimensions(self) -> Cursor<'a, T, U, S> {
627        Cursor {
628            tree: self.tree,
629            stack: self
630                .stack
631                .into_iter()
632                .map(StackEntry::swap_dimensions)
633                .collect(),
634            seek_dimension: self.sum_dimension,
635            sum_dimension: self.seek_dimension,
636            did_seek: self.did_seek,
637            at_end: self.at_end,
638        }
639    }
640}
641
642pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
643    cursor: Cursor<'a, T, (), U>,
644    filter_node: F,
645}
646
647impl<'a, F, T, U> FilterCursor<'a, F, T, U>
648where
649    F: Fn(&T::Summary) -> bool,
650    T: Item,
651    U: Dimension<'a, T::Summary>,
652{
653    pub fn new(
654        tree: &'a SumTree<T>,
655        filter_node: F,
656        cx: &<T::Summary as Summary>::Context,
657    ) -> Self {
658        let mut cursor = tree.cursor::<(), U>();
659        cursor.next_internal(&filter_node, cx);
660        Self {
661            cursor,
662            filter_node,
663        }
664    }
665
666    pub fn start(&self) -> &U {
667        self.cursor.sum_start()
668    }
669
670    pub fn item(&self) -> Option<&'a T> {
671        self.cursor.item()
672    }
673
674    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
675        self.cursor.next_internal(&self.filter_node, cx);
676    }
677}
678
679impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U>
680where
681    F: Fn(&T::Summary) -> bool,
682    T: Item<Summary = S>,
683    S: Summary<Context = ()>,
684    U: Dimension<'a, T::Summary>,
685{
686    type Item = &'a T;
687
688    fn next(&mut self) -> Option<Self::Item> {
689        if let Some(item) = self.item() {
690            self.cursor.next_internal(&self.filter_node, &());
691            Some(item)
692        } else {
693            None
694        }
695    }
696}
697
698enum SeekAggregate<T: Item, D> {
699    None,
700    Slice(SumTree<T>),
701    Summary(D),
702}