bounds_tree.rs

  1use crate::{Bounds, Half};
  2use std::{
  3    cmp,
  4    fmt::Debug,
  5    ops::{Add, Sub},
  6};
  7
  8#[derive(Debug)]
  9pub(crate) struct BoundsTree<U>
 10where
 11    U: Clone + Debug + Default + PartialEq,
 12{
 13    root: Option<usize>,
 14    nodes: Vec<Node<U>>,
 15    stack: Vec<usize>,
 16}
 17
 18impl<U> BoundsTree<U>
 19where
 20    U: Clone
 21        + Debug
 22        + PartialEq
 23        + PartialOrd
 24        + Add<U, Output = U>
 25        + Sub<Output = U>
 26        + Half
 27        + Default,
 28{
 29    pub fn clear(&mut self) {
 30        self.root = None;
 31        self.nodes.clear();
 32        self.stack.clear();
 33    }
 34
 35    pub fn insert(&mut self, new_bounds: Bounds<U>) -> u32 {
 36        // If the tree is empty, make the root the new leaf.
 37        if self.root.is_none() {
 38            let new_node = self.push_leaf(new_bounds, 1);
 39            self.root = Some(new_node);
 40            return 1;
 41        }
 42
 43        // Search for the best place to add the new leaf based on heuristics.
 44        let mut index = self.root.unwrap();
 45        let mut max_intersecting_ordering = self.find_max_ordering(index, &new_bounds, 0);
 46
 47        while let Node::Internal {
 48            left,
 49            right,
 50            bounds: node_bounds,
 51            ..
 52        } = &mut self.nodes[index]
 53        {
 54            let left = *left;
 55            let right = *right;
 56            *node_bounds = node_bounds.union(&new_bounds);
 57            self.stack.push(index);
 58
 59            // Descend to the best-fit child, based on which one would increase
 60            // the surface area the least. This attempts to keep the tree balanced
 61            // in terms of surface area. If there is an intersection with the other child,
 62            // add its keys to the intersections vector.
 63            let left_cost = new_bounds.union(self.nodes[left].bounds()).half_perimeter();
 64            let right_cost = new_bounds
 65                .union(self.nodes[right].bounds())
 66                .half_perimeter();
 67            index = if left_cost < right_cost { left } else { right };
 68        }
 69
 70        // We've found a leaf ('index' now refers to a leaf node).
 71        // We'll insert a new parent node above the leaf and attach our new leaf to it.
 72        let sibling = index;
 73
 74        // Check for collision with the located leaf node
 75        let Node::Leaf {
 76            bounds: sibling_bounds,
 77            order: sibling_ordering,
 78            ..
 79        } = &self.nodes[index]
 80        else {
 81            unreachable!();
 82        };
 83        if sibling_bounds.intersects(&new_bounds) {
 84            max_intersecting_ordering = cmp::max(max_intersecting_ordering, *sibling_ordering);
 85        }
 86
 87        let ordering = max_intersecting_ordering + 1;
 88        let new_node = self.push_leaf(new_bounds, ordering);
 89        let new_parent = self.push_internal(sibling, new_node);
 90
 91        // If there was an old parent, we need to update its children indices.
 92        if let Some(old_parent) = self.stack.last().copied() {
 93            let Node::Internal { left, right, .. } = &mut self.nodes[old_parent] else {
 94                unreachable!();
 95            };
 96
 97            if *left == sibling {
 98                *left = new_parent;
 99            } else {
100                *right = new_parent;
101            }
102        } else {
103            // If the old parent was the root, the new parent is the new root.
104            self.root = Some(new_parent);
105        }
106
107        for node_index in self.stack.drain(..).rev() {
108            let Node::Internal {
109                max_order: max_ordering,
110                ..
111            } = &mut self.nodes[node_index]
112            else {
113                unreachable!()
114            };
115            if *max_ordering >= ordering {
116                break;
117            }
118            *max_ordering = ordering;
119        }
120
121        ordering
122    }
123
124    fn find_max_ordering(&self, index: usize, bounds: &Bounds<U>, mut max_ordering: u32) -> u32 {
125        match &self.nodes[index] {
126            Node::Leaf {
127                bounds: node_bounds,
128                order: ordering,
129                ..
130            } => {
131                if bounds.intersects(node_bounds) {
132                    max_ordering = cmp::max(*ordering, max_ordering);
133                }
134            }
135            Node::Internal {
136                left,
137                right,
138                bounds: node_bounds,
139                max_order: node_max_ordering,
140                ..
141            } => {
142                if bounds.intersects(node_bounds) && max_ordering < *node_max_ordering {
143                    let left_max_ordering = self.nodes[*left].max_ordering();
144                    let right_max_ordering = self.nodes[*right].max_ordering();
145                    if left_max_ordering > right_max_ordering {
146                        max_ordering = self.find_max_ordering(*left, bounds, max_ordering);
147                        max_ordering = self.find_max_ordering(*right, bounds, max_ordering);
148                    } else {
149                        max_ordering = self.find_max_ordering(*right, bounds, max_ordering);
150                        max_ordering = self.find_max_ordering(*left, bounds, max_ordering);
151                    }
152                }
153            }
154        }
155        max_ordering
156    }
157
158    fn push_leaf(&mut self, bounds: Bounds<U>, order: u32) -> usize {
159        self.nodes.push(Node::Leaf { bounds, order });
160        self.nodes.len() - 1
161    }
162
163    fn push_internal(&mut self, left: usize, right: usize) -> usize {
164        let left_node = &self.nodes[left];
165        let right_node = &self.nodes[right];
166        let new_bounds = left_node.bounds().union(right_node.bounds());
167        let max_ordering = cmp::max(left_node.max_ordering(), right_node.max_ordering());
168        self.nodes.push(Node::Internal {
169            bounds: new_bounds,
170            left,
171            right,
172            max_order: max_ordering,
173        });
174        self.nodes.len() - 1
175    }
176}
177
178impl<U> Default for BoundsTree<U>
179where
180    U: Clone + Debug + Default + PartialEq,
181{
182    fn default() -> Self {
183        BoundsTree {
184            root: None,
185            nodes: Vec::new(),
186            stack: Vec::new(),
187        }
188    }
189}
190
191#[derive(Debug, Clone)]
192enum Node<U>
193where
194    U: Clone + Debug + Default + PartialEq,
195{
196    Leaf {
197        order: u32,
198        bounds: Bounds<U>,
199    },
200    Internal {
201        max_order: u32,
202        bounds: Bounds<U>,
203        left: usize,
204        right: usize,
205    },
206}
207
208impl<U> Node<U>
209where
210    U: Clone + Debug + Default + PartialEq,
211{
212    fn bounds(&self) -> &Bounds<U> {
213        match self {
214            Node::Leaf { bounds, .. } => bounds,
215            Node::Internal { bounds, .. } => bounds,
216        }
217    }
218
219    fn max_ordering(&self) -> u32 {
220        match self {
221            &Node::Leaf {
222                order: ordering, ..
223            }
224            | &Node::Internal {
225                max_order: ordering,
226                ..
227            } => ordering,
228        }
229    }
230}
231
232#[cfg(test)]
233mod tests {
234    use super::*;
235    use crate::{Bounds, Point, Size};
236    use rand::{Rng, SeedableRng};
237
238    #[test]
239    fn test_insert() {
240        let mut tree = BoundsTree::<f32>::default();
241        let bounds1 = Bounds {
242            origin: Point { x: 0.0, y: 0.0 },
243            size: Size {
244                width: 10.0,
245                height: 10.0,
246            },
247        };
248        let bounds2 = Bounds {
249            origin: Point { x: 5.0, y: 5.0 },
250            size: Size {
251                width: 10.0,
252                height: 10.0,
253            },
254        };
255        let bounds3 = Bounds {
256            origin: Point { x: 10.0, y: 10.0 },
257            size: Size {
258                width: 10.0,
259                height: 10.0,
260            },
261        };
262
263        // Insert the bounds into the tree and verify the order is correct
264        assert_eq!(tree.insert(bounds1), 1);
265        assert_eq!(tree.insert(bounds2), 2);
266        assert_eq!(tree.insert(bounds3), 3);
267
268        // Insert non-overlapping bounds and verify they can reuse orders
269        let bounds4 = Bounds {
270            origin: Point { x: 20.0, y: 20.0 },
271            size: Size {
272                width: 10.0,
273                height: 10.0,
274            },
275        };
276        let bounds5 = Bounds {
277            origin: Point { x: 40.0, y: 40.0 },
278            size: Size {
279                width: 10.0,
280                height: 10.0,
281            },
282        };
283        let bounds6 = Bounds {
284            origin: Point { x: 25.0, y: 25.0 },
285            size: Size {
286                width: 10.0,
287                height: 10.0,
288            },
289        };
290        assert_eq!(tree.insert(bounds4), 1); // bounds4 does not overlap with bounds1, bounds2, or bounds3
291        assert_eq!(tree.insert(bounds5), 1); // bounds5 does not overlap with any other bounds
292        assert_eq!(tree.insert(bounds6), 2); // bounds6 overlaps with bounds4, so it should have a different order
293    }
294
295    #[test]
296    fn test_random_iterations() {
297        let max_bounds = 100;
298        for seed in 1..=1000 {
299            // let seed = 44;
300            let mut tree = BoundsTree::default();
301            let mut rng = rand::rngs::StdRng::seed_from_u64(seed as u64);
302            let mut expected_quads: Vec<(Bounds<f32>, u32)> = Vec::new();
303
304            // Insert a random number of random AABBs into the tree.
305            let num_bounds = rng.random_range(1..=max_bounds);
306            for _ in 0..num_bounds {
307                let min_x: f32 = rng.random_range(-100.0..100.0);
308                let min_y: f32 = rng.random_range(-100.0..100.0);
309                let width: f32 = rng.random_range(0.0..50.0);
310                let height: f32 = rng.random_range(0.0..50.0);
311                let bounds = Bounds {
312                    origin: Point { x: min_x, y: min_y },
313                    size: Size { width, height },
314                };
315
316                let expected_ordering = expected_quads
317                    .iter()
318                    .filter_map(|quad| quad.0.intersects(&bounds).then_some(quad.1))
319                    .max()
320                    .unwrap_or(0)
321                    + 1;
322                expected_quads.push((bounds, expected_ordering));
323
324                // Insert the AABB into the tree and collect intersections.
325                let actual_ordering = tree.insert(bounds);
326                assert_eq!(actual_ordering, expected_ordering);
327            }
328        }
329    }
330}