cursor.rs

  1use super::*;
  2use arrayvec::ArrayVec;
  3use std::{cmp::Ordering, sync::Arc};
  4
  5#[derive(Clone)]
  6struct StackEntry<'a, T: Item, S, U> {
  7    tree: &'a SumTree<T>,
  8    index: usize,
  9    seek_dimension: S,
 10    sum_dimension: U,
 11}
 12
 13#[derive(Clone)]
 14pub struct Cursor<'a, T: Item, S, U> {
 15    tree: &'a SumTree<T>,
 16    stack: ArrayVec<StackEntry<'a, T, S, U>, 16>,
 17    seek_dimension: S,
 18    sum_dimension: U,
 19    did_seek: bool,
 20    at_end: bool,
 21}
 22
 23impl<'a, T, S, U> Cursor<'a, T, S, U>
 24where
 25    T: Item,
 26    S: Dimension<'a, T::Summary>,
 27    U: Dimension<'a, T::Summary>,
 28{
 29    pub fn new(tree: &'a SumTree<T>) -> Self {
 30        Self {
 31            tree,
 32            stack: ArrayVec::new(),
 33            seek_dimension: S::default(),
 34            sum_dimension: U::default(),
 35            did_seek: false,
 36            at_end: false,
 37        }
 38    }
 39
 40    fn reset(&mut self) {
 41        self.did_seek = false;
 42        self.at_end = false;
 43        self.stack.truncate(0);
 44        self.seek_dimension = S::default();
 45        self.sum_dimension = U::default();
 46    }
 47
 48    pub fn seek_start(&self) -> &S {
 49        &self.seek_dimension
 50    }
 51
 52    pub fn seek_end(&self, cx: &<T::Summary as Summary>::Context) -> S {
 53        if let Some(item_summary) = self.item_summary() {
 54            let mut end = self.seek_start().clone();
 55            end.add_summary(item_summary, cx);
 56            end
 57        } else {
 58            self.seek_start().clone()
 59        }
 60    }
 61
 62    pub fn sum_start(&self) -> &U {
 63        &self.sum_dimension
 64    }
 65
 66    pub fn sum_end(&self, cx: &<T::Summary as Summary>::Context) -> U {
 67        if let Some(item_summary) = self.item_summary() {
 68            let mut end = self.sum_start().clone();
 69            end.add_summary(item_summary, cx);
 70            end
 71        } else {
 72            self.sum_start().clone()
 73        }
 74    }
 75
 76    pub fn item(&self) -> Option<&'a T> {
 77        assert!(self.did_seek, "Must seek before calling this method");
 78        if let Some(entry) = self.stack.last() {
 79            match *entry.tree.0 {
 80                Node::Leaf { ref items, .. } => {
 81                    if entry.index == items.len() {
 82                        None
 83                    } else {
 84                        Some(&items[entry.index])
 85                    }
 86                }
 87                _ => unreachable!(),
 88            }
 89        } else {
 90            None
 91        }
 92    }
 93
 94    pub fn item_summary(&self) -> Option<&'a T::Summary> {
 95        assert!(self.did_seek, "Must seek before calling this method");
 96        if let Some(entry) = self.stack.last() {
 97            match *entry.tree.0 {
 98                Node::Leaf {
 99                    ref item_summaries, ..
100                } => {
101                    if entry.index == item_summaries.len() {
102                        None
103                    } else {
104                        Some(&item_summaries[entry.index])
105                    }
106                }
107                _ => unreachable!(),
108            }
109        } else {
110            None
111        }
112    }
113
114    pub fn prev_item(&self) -> Option<&'a T> {
115        assert!(self.did_seek, "Must seek before calling this method");
116        if let Some(entry) = self.stack.last() {
117            if entry.index == 0 {
118                if let Some(prev_leaf) = self.prev_leaf() {
119                    Some(prev_leaf.0.items().last().unwrap())
120                } else {
121                    None
122                }
123            } else {
124                match *entry.tree.0 {
125                    Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
126                    _ => unreachable!(),
127                }
128            }
129        } else if self.at_end {
130            self.tree.last()
131        } else {
132            None
133        }
134    }
135
136    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
137        for entry in self.stack.iter().rev().skip(1) {
138            if entry.index != 0 {
139                match *entry.tree.0 {
140                    Node::Internal {
141                        ref child_trees, ..
142                    } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
143                    Node::Leaf { .. } => unreachable!(),
144                };
145            }
146        }
147        None
148    }
149
150    #[allow(unused)]
151    pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) {
152        assert!(self.did_seek, "Must seek before calling this method");
153
154        if self.at_end {
155            self.seek_dimension = S::default();
156            self.sum_dimension = U::default();
157            self.descend_to_last_item(self.tree, cx);
158            self.at_end = false;
159        } else {
160            while let Some(entry) = self.stack.pop() {
161                if entry.index > 0 {
162                    let new_index = entry.index - 1;
163
164                    if let Some(StackEntry {
165                        seek_dimension,
166                        sum_dimension,
167                        ..
168                    }) = self.stack.last()
169                    {
170                        self.seek_dimension = seek_dimension.clone();
171                        self.sum_dimension = sum_dimension.clone();
172                    } else {
173                        self.seek_dimension = S::default();
174                        self.sum_dimension = U::default();
175                    }
176
177                    match entry.tree.0.as_ref() {
178                        Node::Internal {
179                            child_trees,
180                            child_summaries,
181                            ..
182                        } => {
183                            for summary in &child_summaries[0..new_index] {
184                                self.seek_dimension.add_summary(summary, cx);
185                                self.sum_dimension.add_summary(summary, cx);
186                            }
187                            self.stack.push(StackEntry {
188                                tree: entry.tree,
189                                index: new_index,
190                                seek_dimension: self.seek_dimension.clone(),
191                                sum_dimension: self.sum_dimension.clone(),
192                            });
193                            self.descend_to_last_item(&child_trees[new_index], cx);
194                        }
195                        Node::Leaf { item_summaries, .. } => {
196                            for item_summary in &item_summaries[0..new_index] {
197                                self.seek_dimension.add_summary(item_summary, cx);
198                                self.sum_dimension.add_summary(item_summary, cx);
199                            }
200                            self.stack.push(StackEntry {
201                                tree: entry.tree,
202                                index: new_index,
203                                seek_dimension: self.seek_dimension.clone(),
204                                sum_dimension: self.sum_dimension.clone(),
205                            });
206                        }
207                    }
208
209                    break;
210                }
211            }
212        }
213    }
214
215    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
216        self.next_internal(|_| true, cx)
217    }
218
219    fn next_internal<F>(&mut self, filter_node: F, cx: &<T::Summary as Summary>::Context)
220    where
221        F: Fn(&T::Summary) -> bool,
222    {
223        let mut descend = false;
224
225        if self.stack.is_empty() && !self.at_end {
226            self.stack.push(StackEntry {
227                tree: self.tree,
228                index: 0,
229                seek_dimension: S::default(),
230                sum_dimension: U::default(),
231            });
232            descend = true;
233            self.did_seek = true;
234        }
235
236        while self.stack.len() > 0 {
237            let new_subtree = {
238                let entry = self.stack.last_mut().unwrap();
239                match entry.tree.0.as_ref() {
240                    Node::Internal {
241                        child_trees,
242                        child_summaries,
243                        ..
244                    } => {
245                        if !descend {
246                            entry.seek_dimension = self.seek_dimension.clone();
247                            entry.sum_dimension = self.sum_dimension.clone();
248                            entry.index += 1;
249                        }
250
251                        while entry.index < child_summaries.len() {
252                            let next_summary = &child_summaries[entry.index];
253                            if filter_node(next_summary) {
254                                break;
255                            } else {
256                                self.seek_dimension.add_summary(next_summary, cx);
257                                self.sum_dimension.add_summary(next_summary, cx);
258                            }
259                            entry.index += 1;
260                        }
261
262                        child_trees.get(entry.index)
263                    }
264                    Node::Leaf { item_summaries, .. } => {
265                        if !descend {
266                            let item_summary = &item_summaries[entry.index];
267                            self.seek_dimension.add_summary(item_summary, cx);
268                            entry.seek_dimension.add_summary(item_summary, cx);
269                            self.sum_dimension.add_summary(item_summary, cx);
270                            entry.sum_dimension.add_summary(item_summary, cx);
271                            entry.index += 1;
272                        }
273
274                        loop {
275                            if let Some(next_item_summary) = item_summaries.get(entry.index) {
276                                if filter_node(next_item_summary) {
277                                    return;
278                                } else {
279                                    self.seek_dimension.add_summary(next_item_summary, cx);
280                                    entry.seek_dimension.add_summary(next_item_summary, cx);
281                                    self.sum_dimension.add_summary(next_item_summary, cx);
282                                    entry.sum_dimension.add_summary(next_item_summary, cx);
283                                    entry.index += 1;
284                                }
285                            } else {
286                                break None;
287                            }
288                        }
289                    }
290                }
291            };
292
293            if let Some(subtree) = new_subtree {
294                descend = true;
295                self.stack.push(StackEntry {
296                    tree: subtree,
297                    index: 0,
298                    seek_dimension: self.seek_dimension.clone(),
299                    sum_dimension: self.sum_dimension.clone(),
300                });
301            } else {
302                descend = false;
303                self.stack.pop();
304            }
305        }
306
307        self.at_end = self.stack.is_empty();
308        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
309    }
310
311    fn descend_to_last_item(
312        &mut self,
313        mut subtree: &'a SumTree<T>,
314        cx: &<T::Summary as Summary>::Context,
315    ) {
316        self.did_seek = true;
317        loop {
318            match subtree.0.as_ref() {
319                Node::Internal {
320                    child_trees,
321                    child_summaries,
322                    ..
323                } => {
324                    for summary in &child_summaries[0..child_summaries.len() - 1] {
325                        self.seek_dimension.add_summary(summary, cx);
326                        self.sum_dimension.add_summary(summary, cx);
327                    }
328
329                    self.stack.push(StackEntry {
330                        tree: subtree,
331                        index: child_trees.len() - 1,
332                        seek_dimension: self.seek_dimension.clone(),
333                        sum_dimension: self.sum_dimension.clone(),
334                    });
335                    subtree = child_trees.last().unwrap();
336                }
337                Node::Leaf { item_summaries, .. } => {
338                    let last_index = item_summaries.len().saturating_sub(1);
339                    for item_summary in &item_summaries[0..last_index] {
340                        self.seek_dimension.add_summary(item_summary, cx);
341                        self.sum_dimension.add_summary(item_summary, cx);
342                    }
343                    self.stack.push(StackEntry {
344                        tree: subtree,
345                        index: last_index,
346                        seek_dimension: self.seek_dimension.clone(),
347                        sum_dimension: self.sum_dimension.clone(),
348                    });
349                    break;
350                }
351            }
352        }
353    }
354}
355
356impl<'a, T, S, U> Cursor<'a, T, S, U>
357where
358    T: Item,
359    S: SeekDimension<'a, T::Summary>,
360    U: Dimension<'a, T::Summary>,
361{
362    pub fn seek(&mut self, pos: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> bool {
363        self.reset();
364        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
365    }
366
367    pub fn seek_forward(
368        &mut self,
369        pos: &S,
370        bias: Bias,
371        cx: &<T::Summary as Summary>::Context,
372    ) -> bool {
373        self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
374    }
375
376    pub fn slice(
377        &mut self,
378        end: &S,
379        bias: Bias,
380        cx: &<T::Summary as Summary>::Context,
381    ) -> SumTree<T> {
382        let mut slice = SeekAggregate::Slice(SumTree::new());
383        self.seek_internal::<()>(Some(end), bias, &mut slice, cx);
384        if let SeekAggregate::Slice(slice) = slice {
385            slice
386        } else {
387            unreachable!()
388        }
389    }
390
391    pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
392        let mut slice = SeekAggregate::Slice(SumTree::new());
393        self.seek_internal::<()>(None, Bias::Right, &mut slice, cx);
394        if let SeekAggregate::Slice(slice) = slice {
395            slice
396        } else {
397            unreachable!()
398        }
399    }
400
401    pub fn summary<D>(&mut self, end: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> D
402    where
403        D: Dimension<'a, T::Summary>,
404    {
405        let mut summary = SeekAggregate::Summary(D::default());
406        self.seek_internal(Some(end), bias, &mut summary, cx);
407        if let SeekAggregate::Summary(summary) = summary {
408            summary
409        } else {
410            unreachable!()
411        }
412    }
413
414    fn seek_internal<D>(
415        &mut self,
416        target: Option<&S>,
417        bias: Bias,
418        aggregate: &mut SeekAggregate<T, D>,
419        cx: &<T::Summary as Summary>::Context,
420    ) -> bool
421    where
422        D: Dimension<'a, T::Summary>,
423    {
424        if let Some(target) = target {
425            debug_assert!(
426                target.cmp(&self.seek_dimension, cx) >= Ordering::Equal,
427                "cannot seek backward from {:?} to {:?}",
428                self.seek_dimension,
429                target
430            );
431        }
432
433        if !self.did_seek {
434            self.did_seek = true;
435            self.stack.push(StackEntry {
436                tree: self.tree,
437                index: 0,
438                seek_dimension: Default::default(),
439                sum_dimension: Default::default(),
440            });
441        }
442
443        let mut ascending = false;
444        'outer: while let Some(entry) = self.stack.last_mut() {
445            match *entry.tree.0 {
446                Node::Internal {
447                    ref child_summaries,
448                    ref child_trees,
449                    ..
450                } => {
451                    if ascending {
452                        entry.index += 1;
453                    }
454
455                    for (child_tree, child_summary) in child_trees[entry.index..]
456                        .iter()
457                        .zip(&child_summaries[entry.index..])
458                    {
459                        let mut child_end = self.seek_dimension.clone();
460                        child_end.add_summary(&child_summary, cx);
461
462                        let comparison =
463                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
464                        if comparison == Ordering::Greater
465                            || (comparison == Ordering::Equal && bias == Bias::Right)
466                        {
467                            self.seek_dimension = child_end;
468                            self.sum_dimension.add_summary(child_summary, cx);
469                            match aggregate {
470                                SeekAggregate::None => {}
471                                SeekAggregate::Slice(slice) => {
472                                    slice.push_tree(child_tree.clone(), cx);
473                                }
474                                SeekAggregate::Summary(summary) => {
475                                    summary.add_summary(child_summary, cx);
476                                }
477                            }
478                            entry.index += 1;
479                            entry.seek_dimension = self.seek_dimension.clone();
480                            entry.sum_dimension = self.sum_dimension.clone();
481                        } else {
482                            self.stack.push(StackEntry {
483                                tree: child_tree,
484                                index: 0,
485                                seek_dimension: self.seek_dimension.clone(),
486                                sum_dimension: self.sum_dimension.clone(),
487                            });
488                            ascending = false;
489                            continue 'outer;
490                        }
491                    }
492                }
493                Node::Leaf {
494                    ref items,
495                    ref item_summaries,
496                    ..
497                } => {
498                    let mut slice_items = ArrayVec::<T, { 2 * TREE_BASE }>::new();
499                    let mut slice_item_summaries = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
500                    let mut slice_items_summary = match aggregate {
501                        SeekAggregate::Slice(_) => Some(T::Summary::default()),
502                        _ => None,
503                    };
504
505                    for (item, item_summary) in items[entry.index..]
506                        .iter()
507                        .zip(&item_summaries[entry.index..])
508                    {
509                        let mut child_end = self.seek_dimension.clone();
510                        child_end.add_summary(item_summary, cx);
511
512                        let comparison =
513                            target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
514                        if comparison == Ordering::Greater
515                            || (comparison == Ordering::Equal && bias == Bias::Right)
516                        {
517                            self.seek_dimension = child_end;
518                            self.sum_dimension.add_summary(item_summary, cx);
519                            match aggregate {
520                                SeekAggregate::None => {}
521                                SeekAggregate::Slice(_) => {
522                                    slice_items.push(item.clone());
523                                    slice_item_summaries.push(item_summary.clone());
524                                    slice_items_summary
525                                        .as_mut()
526                                        .unwrap()
527                                        .add_summary(item_summary, cx);
528                                }
529                                SeekAggregate::Summary(summary) => {
530                                    summary.add_summary(item_summary, cx);
531                                }
532                            }
533                            entry.index += 1;
534                        } else {
535                            if let SeekAggregate::Slice(slice) = aggregate {
536                                slice.push_tree(
537                                    SumTree(Arc::new(Node::Leaf {
538                                        summary: slice_items_summary.unwrap(),
539                                        items: slice_items,
540                                        item_summaries: slice_item_summaries,
541                                    })),
542                                    cx,
543                                );
544                            }
545                            break 'outer;
546                        }
547                    }
548
549                    if let SeekAggregate::Slice(slice) = aggregate {
550                        if !slice_items.is_empty() {
551                            slice.push_tree(
552                                SumTree(Arc::new(Node::Leaf {
553                                    summary: slice_items_summary.unwrap(),
554                                    items: slice_items,
555                                    item_summaries: slice_item_summaries,
556                                })),
557                                cx,
558                            );
559                        }
560                    }
561                }
562            }
563
564            self.stack.pop();
565            ascending = true;
566        }
567
568        self.at_end = self.stack.is_empty();
569        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
570
571        let mut end = self.seek_dimension.clone();
572        if bias == Bias::Left {
573            if let Some(summary) = self.item_summary() {
574                end.add_summary(summary, cx);
575            }
576        }
577
578        target.map_or(false, |t| t.cmp(&end, cx) == Ordering::Equal)
579    }
580}
581
582impl<'a, T, S, Seek, Sum> Iterator for Cursor<'a, T, Seek, Sum>
583where
584    T: Item<Summary = S>,
585    S: Summary<Context = ()>,
586    Seek: Dimension<'a, T::Summary>,
587    Sum: Dimension<'a, T::Summary>,
588{
589    type Item = &'a T;
590
591    fn next(&mut self) -> Option<Self::Item> {
592        if !self.did_seek {
593            self.next(&());
594        }
595
596        if let Some(item) = self.item() {
597            self.next(&());
598            Some(item)
599        } else {
600            None
601        }
602    }
603}
604
605pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
606    cursor: Cursor<'a, T, (), U>,
607    filter_node: F,
608}
609
610impl<'a, F, T, U> FilterCursor<'a, F, T, U>
611where
612    F: Fn(&T::Summary) -> bool,
613    T: Item,
614    U: Dimension<'a, T::Summary>,
615{
616    pub fn new(
617        tree: &'a SumTree<T>,
618        filter_node: F,
619        cx: &<T::Summary as Summary>::Context,
620    ) -> Self {
621        let mut cursor = tree.cursor::<(), U>();
622        cursor.next_internal(&filter_node, cx);
623        Self {
624            cursor,
625            filter_node,
626        }
627    }
628
629    pub fn start(&self) -> &U {
630        self.cursor.sum_start()
631    }
632
633    pub fn item(&self) -> Option<&'a T> {
634        self.cursor.item()
635    }
636
637    pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
638        self.cursor.next_internal(&self.filter_node, cx);
639    }
640}
641
642impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U>
643where
644    F: Fn(&T::Summary) -> bool,
645    T: Item<Summary = S>,
646    S: Summary<Context = ()>,
647    U: Dimension<'a, T::Summary>,
648{
649    type Item = &'a T;
650
651    fn next(&mut self) -> Option<Self::Item> {
652        if let Some(item) = self.item() {
653            self.cursor.next_internal(&self.filter_node, &());
654            Some(item)
655        } else {
656            None
657        }
658    }
659}
660
661enum SeekAggregate<T: Item, D> {
662    None,
663    Slice(SumTree<T>),
664    Summary(D),
665}