1use super::*;
2use arrayvec::ArrayVec;
3use std::{cmp::Ordering, sync::Arc};
4
5#[derive(Clone)]
6struct StackEntry<'a, T: Item, S, U> {
7 tree: &'a SumTree<T>,
8 index: usize,
9 seek_dimension: S,
10 sum_dimension: U,
11}
12
13impl<'a, T, S, U> StackEntry<'a, T, S, U>
14where
15 T: Item,
16 S: SeekDimension<'a, T::Summary>,
17 U: SeekDimension<'a, T::Summary>,
18{
19 fn swap_dimensions(self) -> StackEntry<'a, T, U, S> {
20 StackEntry {
21 tree: self.tree,
22 index: self.index,
23 seek_dimension: self.sum_dimension,
24 sum_dimension: self.seek_dimension,
25 }
26 }
27}
28
29#[derive(Clone)]
30pub struct Cursor<'a, T: Item, S, U> {
31 tree: &'a SumTree<T>,
32 stack: ArrayVec<StackEntry<'a, T, S, U>, 16>,
33 seek_dimension: S,
34 sum_dimension: U,
35 did_seek: bool,
36 at_end: bool,
37}
38
39impl<'a, T, S, U> Cursor<'a, T, S, U>
40where
41 T: Item,
42 S: Dimension<'a, T::Summary>,
43 U: Dimension<'a, T::Summary>,
44{
45 pub fn new(tree: &'a SumTree<T>) -> Self {
46 Self {
47 tree,
48 stack: ArrayVec::new(),
49 seek_dimension: S::default(),
50 sum_dimension: U::default(),
51 did_seek: false,
52 at_end: false,
53 }
54 }
55
56 fn reset(&mut self) {
57 self.did_seek = false;
58 self.at_end = false;
59 self.stack.truncate(0);
60 self.seek_dimension = S::default();
61 self.sum_dimension = U::default();
62 }
63
64 pub fn seek_start(&self) -> &S {
65 &self.seek_dimension
66 }
67
68 pub fn seek_end(&self, cx: &<T::Summary as Summary>::Context) -> S {
69 if let Some(item_summary) = self.item_summary() {
70 let mut end = self.seek_start().clone();
71 end.add_summary(item_summary, cx);
72 end
73 } else {
74 self.seek_start().clone()
75 }
76 }
77
78 pub fn sum_start(&self) -> &U {
79 &self.sum_dimension
80 }
81
82 pub fn sum_end(&self, cx: &<T::Summary as Summary>::Context) -> U {
83 if let Some(item_summary) = self.item_summary() {
84 let mut end = self.sum_start().clone();
85 end.add_summary(item_summary, cx);
86 end
87 } else {
88 self.sum_start().clone()
89 }
90 }
91
92 pub fn item(&self) -> Option<&'a T> {
93 assert!(self.did_seek, "Must seek before calling this method");
94 if let Some(entry) = self.stack.last() {
95 match *entry.tree.0 {
96 Node::Leaf { ref items, .. } => {
97 if entry.index == items.len() {
98 None
99 } else {
100 Some(&items[entry.index])
101 }
102 }
103 _ => unreachable!(),
104 }
105 } else {
106 None
107 }
108 }
109
110 pub fn item_summary(&self) -> Option<&'a T::Summary> {
111 assert!(self.did_seek, "Must seek before calling this method");
112 if let Some(entry) = self.stack.last() {
113 match *entry.tree.0 {
114 Node::Leaf {
115 ref item_summaries, ..
116 } => {
117 if entry.index == item_summaries.len() {
118 None
119 } else {
120 Some(&item_summaries[entry.index])
121 }
122 }
123 _ => unreachable!(),
124 }
125 } else {
126 None
127 }
128 }
129
130 pub fn prev_item(&self) -> Option<&'a T> {
131 assert!(self.did_seek, "Must seek before calling this method");
132 if let Some(entry) = self.stack.last() {
133 if entry.index == 0 {
134 if let Some(prev_leaf) = self.prev_leaf() {
135 Some(prev_leaf.0.items().last().unwrap())
136 } else {
137 None
138 }
139 } else {
140 match *entry.tree.0 {
141 Node::Leaf { ref items, .. } => Some(&items[entry.index - 1]),
142 _ => unreachable!(),
143 }
144 }
145 } else if self.at_end {
146 self.tree.last()
147 } else {
148 None
149 }
150 }
151
152 fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
153 for entry in self.stack.iter().rev().skip(1) {
154 if entry.index != 0 {
155 match *entry.tree.0 {
156 Node::Internal {
157 ref child_trees, ..
158 } => return Some(child_trees[entry.index - 1].rightmost_leaf()),
159 Node::Leaf { .. } => unreachable!(),
160 };
161 }
162 }
163 None
164 }
165
166 pub fn prev(&mut self, cx: &<T::Summary as Summary>::Context) {
167 assert!(self.did_seek, "Must seek before calling this method");
168
169 if self.at_end {
170 self.seek_dimension = S::default();
171 self.sum_dimension = U::default();
172 self.descend_to_last_item(self.tree, cx);
173 self.at_end = false;
174 } else {
175 while let Some(entry) = self.stack.pop() {
176 if entry.index > 0 {
177 let new_index = entry.index - 1;
178
179 if let Some(StackEntry {
180 seek_dimension,
181 sum_dimension,
182 ..
183 }) = self.stack.last()
184 {
185 self.seek_dimension = seek_dimension.clone();
186 self.sum_dimension = sum_dimension.clone();
187 } else {
188 self.seek_dimension = S::default();
189 self.sum_dimension = U::default();
190 }
191
192 match entry.tree.0.as_ref() {
193 Node::Internal {
194 child_trees,
195 child_summaries,
196 ..
197 } => {
198 for summary in &child_summaries[0..new_index] {
199 self.seek_dimension.add_summary(summary, cx);
200 self.sum_dimension.add_summary(summary, cx);
201 }
202 self.stack.push(StackEntry {
203 tree: entry.tree,
204 index: new_index,
205 seek_dimension: self.seek_dimension.clone(),
206 sum_dimension: self.sum_dimension.clone(),
207 });
208 self.descend_to_last_item(&child_trees[new_index], cx);
209 }
210 Node::Leaf { item_summaries, .. } => {
211 for item_summary in &item_summaries[0..new_index] {
212 self.seek_dimension.add_summary(item_summary, cx);
213 self.sum_dimension.add_summary(item_summary, cx);
214 }
215 self.stack.push(StackEntry {
216 tree: entry.tree,
217 index: new_index,
218 seek_dimension: self.seek_dimension.clone(),
219 sum_dimension: self.sum_dimension.clone(),
220 });
221 }
222 }
223
224 break;
225 }
226 }
227 }
228 }
229
230 pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
231 self.next_internal(|_| true, cx)
232 }
233
234 fn next_internal<F>(&mut self, filter_node: F, cx: &<T::Summary as Summary>::Context)
235 where
236 F: Fn(&T::Summary) -> bool,
237 {
238 let mut descend = false;
239
240 if self.stack.is_empty() && !self.at_end {
241 self.stack.push(StackEntry {
242 tree: self.tree,
243 index: 0,
244 seek_dimension: S::default(),
245 sum_dimension: U::default(),
246 });
247 descend = true;
248 self.did_seek = true;
249 }
250
251 while self.stack.len() > 0 {
252 let new_subtree = {
253 let entry = self.stack.last_mut().unwrap();
254 match entry.tree.0.as_ref() {
255 Node::Internal {
256 child_trees,
257 child_summaries,
258 ..
259 } => {
260 if !descend {
261 entry.seek_dimension = self.seek_dimension.clone();
262 entry.sum_dimension = self.sum_dimension.clone();
263 entry.index += 1;
264 }
265
266 while entry.index < child_summaries.len() {
267 let next_summary = &child_summaries[entry.index];
268 if filter_node(next_summary) {
269 break;
270 } else {
271 self.seek_dimension.add_summary(next_summary, cx);
272 self.sum_dimension.add_summary(next_summary, cx);
273 }
274 entry.index += 1;
275 }
276
277 child_trees.get(entry.index)
278 }
279 Node::Leaf { item_summaries, .. } => {
280 if !descend {
281 let item_summary = &item_summaries[entry.index];
282 self.seek_dimension.add_summary(item_summary, cx);
283 entry.seek_dimension.add_summary(item_summary, cx);
284 self.sum_dimension.add_summary(item_summary, cx);
285 entry.sum_dimension.add_summary(item_summary, cx);
286 entry.index += 1;
287 }
288
289 loop {
290 if let Some(next_item_summary) = item_summaries.get(entry.index) {
291 if filter_node(next_item_summary) {
292 return;
293 } else {
294 self.seek_dimension.add_summary(next_item_summary, cx);
295 entry.seek_dimension.add_summary(next_item_summary, cx);
296 self.sum_dimension.add_summary(next_item_summary, cx);
297 entry.sum_dimension.add_summary(next_item_summary, cx);
298 entry.index += 1;
299 }
300 } else {
301 break None;
302 }
303 }
304 }
305 }
306 };
307
308 if let Some(subtree) = new_subtree {
309 descend = true;
310 self.stack.push(StackEntry {
311 tree: subtree,
312 index: 0,
313 seek_dimension: self.seek_dimension.clone(),
314 sum_dimension: self.sum_dimension.clone(),
315 });
316 } else {
317 descend = false;
318 self.stack.pop();
319 }
320 }
321
322 self.at_end = self.stack.is_empty();
323 debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
324 }
325
326 fn descend_to_last_item(
327 &mut self,
328 mut subtree: &'a SumTree<T>,
329 cx: &<T::Summary as Summary>::Context,
330 ) {
331 self.did_seek = true;
332 loop {
333 match subtree.0.as_ref() {
334 Node::Internal {
335 child_trees,
336 child_summaries,
337 ..
338 } => {
339 for summary in &child_summaries[0..child_summaries.len() - 1] {
340 self.seek_dimension.add_summary(summary, cx);
341 self.sum_dimension.add_summary(summary, cx);
342 }
343
344 self.stack.push(StackEntry {
345 tree: subtree,
346 index: child_trees.len() - 1,
347 seek_dimension: self.seek_dimension.clone(),
348 sum_dimension: self.sum_dimension.clone(),
349 });
350 subtree = child_trees.last().unwrap();
351 }
352 Node::Leaf { item_summaries, .. } => {
353 let last_index = item_summaries.len().saturating_sub(1);
354 for item_summary in &item_summaries[0..last_index] {
355 self.seek_dimension.add_summary(item_summary, cx);
356 self.sum_dimension.add_summary(item_summary, cx);
357 }
358 self.stack.push(StackEntry {
359 tree: subtree,
360 index: last_index,
361 seek_dimension: self.seek_dimension.clone(),
362 sum_dimension: self.sum_dimension.clone(),
363 });
364 break;
365 }
366 }
367 }
368 }
369}
370
371impl<'a, T, S, U> Cursor<'a, T, S, U>
372where
373 T: Item,
374 S: SeekDimension<'a, T::Summary>,
375 U: Dimension<'a, T::Summary>,
376{
377 pub fn seek(&mut self, pos: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> bool {
378 self.reset();
379 self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
380 }
381
382 pub fn seek_forward(
383 &mut self,
384 pos: &S,
385 bias: Bias,
386 cx: &<T::Summary as Summary>::Context,
387 ) -> bool {
388 self.seek_internal::<()>(Some(pos), bias, &mut SeekAggregate::None, cx)
389 }
390
391 pub fn slice(
392 &mut self,
393 end: &S,
394 bias: Bias,
395 cx: &<T::Summary as Summary>::Context,
396 ) -> SumTree<T> {
397 let mut slice = SeekAggregate::Slice(SumTree::new());
398 self.seek_internal::<()>(Some(end), bias, &mut slice, cx);
399 if let SeekAggregate::Slice(slice) = slice {
400 slice
401 } else {
402 unreachable!()
403 }
404 }
405
406 pub fn suffix(&mut self, cx: &<T::Summary as Summary>::Context) -> SumTree<T> {
407 let mut slice = SeekAggregate::Slice(SumTree::new());
408 self.seek_internal::<()>(None, Bias::Right, &mut slice, cx);
409 if let SeekAggregate::Slice(slice) = slice {
410 slice
411 } else {
412 unreachable!()
413 }
414 }
415
416 pub fn summary<D>(&mut self, end: &S, bias: Bias, cx: &<T::Summary as Summary>::Context) -> D
417 where
418 D: Dimension<'a, T::Summary>,
419 {
420 let mut summary = SeekAggregate::Summary(D::default());
421 self.seek_internal(Some(end), bias, &mut summary, cx);
422 if let SeekAggregate::Summary(summary) = summary {
423 summary
424 } else {
425 unreachable!()
426 }
427 }
428
429 fn seek_internal<D>(
430 &mut self,
431 target: Option<&S>,
432 bias: Bias,
433 aggregate: &mut SeekAggregate<T, D>,
434 cx: &<T::Summary as Summary>::Context,
435 ) -> bool
436 where
437 D: Dimension<'a, T::Summary>,
438 {
439 if let Some(target) = target {
440 debug_assert!(
441 target.cmp(&self.seek_dimension, cx) >= Ordering::Equal,
442 "cannot seek backward from {:?} to {:?}",
443 self.seek_dimension,
444 target
445 );
446 }
447
448 if !self.did_seek {
449 self.did_seek = true;
450 self.stack.push(StackEntry {
451 tree: self.tree,
452 index: 0,
453 seek_dimension: Default::default(),
454 sum_dimension: Default::default(),
455 });
456 }
457
458 let mut ascending = false;
459 'outer: while let Some(entry) = self.stack.last_mut() {
460 match *entry.tree.0 {
461 Node::Internal {
462 ref child_summaries,
463 ref child_trees,
464 ..
465 } => {
466 if ascending {
467 entry.index += 1;
468 }
469
470 for (child_tree, child_summary) in child_trees[entry.index..]
471 .iter()
472 .zip(&child_summaries[entry.index..])
473 {
474 let mut child_end = self.seek_dimension.clone();
475 child_end.add_summary(&child_summary, cx);
476
477 let comparison =
478 target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
479 if comparison == Ordering::Greater
480 || (comparison == Ordering::Equal && bias == Bias::Right)
481 {
482 self.seek_dimension = child_end;
483 self.sum_dimension.add_summary(child_summary, cx);
484 match aggregate {
485 SeekAggregate::None => {}
486 SeekAggregate::Slice(slice) => {
487 slice.push_tree(child_tree.clone(), cx);
488 }
489 SeekAggregate::Summary(summary) => {
490 summary.add_summary(child_summary, cx);
491 }
492 }
493 entry.index += 1;
494 entry.seek_dimension = self.seek_dimension.clone();
495 entry.sum_dimension = self.sum_dimension.clone();
496 } else {
497 self.stack.push(StackEntry {
498 tree: child_tree,
499 index: 0,
500 seek_dimension: self.seek_dimension.clone(),
501 sum_dimension: self.sum_dimension.clone(),
502 });
503 ascending = false;
504 continue 'outer;
505 }
506 }
507 }
508 Node::Leaf {
509 ref items,
510 ref item_summaries,
511 ..
512 } => {
513 let mut slice_items = ArrayVec::<T, { 2 * TREE_BASE }>::new();
514 let mut slice_item_summaries = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
515 let mut slice_items_summary = match aggregate {
516 SeekAggregate::Slice(_) => Some(T::Summary::default()),
517 _ => None,
518 };
519
520 for (item, item_summary) in items[entry.index..]
521 .iter()
522 .zip(&item_summaries[entry.index..])
523 {
524 let mut child_end = self.seek_dimension.clone();
525 child_end.add_summary(item_summary, cx);
526
527 let comparison =
528 target.map_or(Ordering::Greater, |t| t.cmp(&child_end, cx));
529 if comparison == Ordering::Greater
530 || (comparison == Ordering::Equal && bias == Bias::Right)
531 {
532 self.seek_dimension = child_end;
533 self.sum_dimension.add_summary(item_summary, cx);
534 match aggregate {
535 SeekAggregate::None => {}
536 SeekAggregate::Slice(_) => {
537 slice_items.push(item.clone());
538 slice_item_summaries.push(item_summary.clone());
539 slice_items_summary
540 .as_mut()
541 .unwrap()
542 .add_summary(item_summary, cx);
543 }
544 SeekAggregate::Summary(summary) => {
545 summary.add_summary(item_summary, cx);
546 }
547 }
548 entry.index += 1;
549 } else {
550 if let SeekAggregate::Slice(slice) = aggregate {
551 slice.push_tree(
552 SumTree(Arc::new(Node::Leaf {
553 summary: slice_items_summary.unwrap(),
554 items: slice_items,
555 item_summaries: slice_item_summaries,
556 })),
557 cx,
558 );
559 }
560 break 'outer;
561 }
562 }
563
564 if let SeekAggregate::Slice(slice) = aggregate {
565 if !slice_items.is_empty() {
566 slice.push_tree(
567 SumTree(Arc::new(Node::Leaf {
568 summary: slice_items_summary.unwrap(),
569 items: slice_items,
570 item_summaries: slice_item_summaries,
571 })),
572 cx,
573 );
574 }
575 }
576 }
577 }
578
579 self.stack.pop();
580 ascending = true;
581 }
582
583 self.at_end = self.stack.is_empty();
584 debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
585
586 let mut end = self.seek_dimension.clone();
587 if bias == Bias::Left {
588 if let Some(summary) = self.item_summary() {
589 end.add_summary(summary, cx);
590 }
591 }
592
593 target.map_or(false, |t| t.cmp(&end, cx) == Ordering::Equal)
594 }
595}
596
597impl<'a, T, S, Seek, Sum> Iterator for Cursor<'a, T, Seek, Sum>
598where
599 T: Item<Summary = S>,
600 S: Summary<Context = ()>,
601 Seek: Dimension<'a, T::Summary>,
602 Sum: Dimension<'a, T::Summary>,
603{
604 type Item = &'a T;
605
606 fn next(&mut self) -> Option<Self::Item> {
607 if !self.did_seek {
608 self.next(&());
609 }
610
611 if let Some(item) = self.item() {
612 self.next(&());
613 Some(item)
614 } else {
615 None
616 }
617 }
618}
619
620impl<'a, T, S, U> Cursor<'a, T, S, U>
621where
622 T: Item,
623 S: SeekDimension<'a, T::Summary>,
624 U: SeekDimension<'a, T::Summary>,
625{
626 pub fn swap_dimensions(self) -> Cursor<'a, T, U, S> {
627 Cursor {
628 tree: self.tree,
629 stack: self
630 .stack
631 .into_iter()
632 .map(StackEntry::swap_dimensions)
633 .collect(),
634 seek_dimension: self.sum_dimension,
635 sum_dimension: self.seek_dimension,
636 did_seek: self.did_seek,
637 at_end: self.at_end,
638 }
639 }
640}
641
642pub struct FilterCursor<'a, F: Fn(&T::Summary) -> bool, T: Item, U> {
643 cursor: Cursor<'a, T, (), U>,
644 filter_node: F,
645}
646
647impl<'a, F, T, U> FilterCursor<'a, F, T, U>
648where
649 F: Fn(&T::Summary) -> bool,
650 T: Item,
651 U: Dimension<'a, T::Summary>,
652{
653 pub fn new(
654 tree: &'a SumTree<T>,
655 filter_node: F,
656 cx: &<T::Summary as Summary>::Context,
657 ) -> Self {
658 let mut cursor = tree.cursor::<(), U>();
659 cursor.next_internal(&filter_node, cx);
660 Self {
661 cursor,
662 filter_node,
663 }
664 }
665
666 pub fn start(&self) -> &U {
667 self.cursor.sum_start()
668 }
669
670 pub fn item(&self) -> Option<&'a T> {
671 self.cursor.item()
672 }
673
674 pub fn next(&mut self, cx: &<T::Summary as Summary>::Context) {
675 self.cursor.next_internal(&self.filter_node, cx);
676 }
677}
678
679impl<'a, F, T, S, U> Iterator for FilterCursor<'a, F, T, U>
680where
681 F: Fn(&T::Summary) -> bool,
682 T: Item<Summary = S>,
683 S: Summary<Context = ()>,
684 U: Dimension<'a, T::Summary>,
685{
686 type Item = &'a T;
687
688 fn next(&mut self) -> Option<Self::Item> {
689 if let Some(item) = self.item() {
690 self.cursor.next_internal(&self.filter_node, &());
691 Some(item)
692 } else {
693 None
694 }
695 }
696}
697
698enum SeekAggregate<T: Item, D> {
699 None,
700 Slice(SumTree<T>),
701 Summary(D),
702}