tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
  6#[derive(Clone, PartialEq, Eq)]
  7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  8where
  9    K: Clone + Ord,
 10    V: Clone;
 11
 12#[derive(Clone, Debug, PartialEq, Eq)]
 13pub struct MapEntry<K, V> {
 14    key: K,
 15    value: V,
 16}
 17
 18#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 19pub struct MapKey<K>(Option<K>);
 20
 21impl<K> Default for MapKey<K> {
 22    fn default() -> Self {
 23        Self(None)
 24    }
 25}
 26
 27#[derive(Clone, Debug)]
 28pub struct MapKeyRef<'a, K>(Option<&'a K>);
 29
 30impl<K> Default for MapKeyRef<'_, K> {
 31    fn default() -> Self {
 32        Self(None)
 33    }
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Eq)]
 37pub struct TreeSet<K>(TreeMap<K, ()>)
 38where
 39    K: Clone + Ord;
 40
 41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 42    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 43        let tree = SumTree::from_iter(
 44            entries
 45                .into_iter()
 46                .map(|(key, value)| MapEntry { key, value }),
 47            &(),
 48        );
 49        Self(tree)
 50    }
 51
 52    pub fn is_empty(&self) -> bool {
 53        self.0.is_empty()
 54    }
 55
 56    pub fn get(&self, key: &K) -> Option<&V> {
 57        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 58        cursor.seek(&MapKeyRef(Some(key)), Bias::Left);
 59        if let Some(item) = cursor.item() {
 60            if Some(key) == item.key().0.as_ref() {
 61                Some(&item.value)
 62            } else {
 63                None
 64            }
 65        } else {
 66            None
 67        }
 68    }
 69
 70    pub fn insert(&mut self, key: K, value: V) {
 71        self.0.insert_or_replace(MapEntry { key, value }, &());
 72    }
 73
 74    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 75        let edits: Vec<_> = iter
 76            .into_iter()
 77            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 78            .collect();
 79        self.0.edit(edits, &());
 80    }
 81
 82    pub fn clear(&mut self) {
 83        self.0 = SumTree::default();
 84    }
 85
 86    pub fn remove(&mut self, key: &K) -> Option<V> {
 87        let mut removed = None;
 88        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 89        let key = MapKeyRef(Some(key));
 90        let mut new_tree = cursor.slice(&key, Bias::Left);
 91        if key.cmp(&cursor.end(), &()) == Ordering::Equal {
 92            removed = Some(cursor.item().unwrap().value.clone());
 93            cursor.next();
 94        }
 95        new_tree.append(cursor.suffix(), &());
 96        drop(cursor);
 97        self.0 = new_tree;
 98        removed
 99    }
100
101    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
102        let start = MapSeekTargetAdaptor(start);
103        let end = MapSeekTargetAdaptor(end);
104        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
105        let mut new_tree = cursor.slice(&start, Bias::Left);
106        cursor.seek(&end, Bias::Left);
107        new_tree.append(cursor.suffix(), &());
108        drop(cursor);
109        self.0 = new_tree;
110    }
111
112    /// Returns the key-value pair with the greatest key less than or equal to the given key.
113    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
114        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
115        let key = MapKeyRef(Some(key));
116        cursor.seek(&key, Bias::Right);
117        cursor.prev();
118        cursor.item().map(|item| (&item.key, &item.value))
119    }
120
121    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
122        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
123        let from_key = MapKeyRef(Some(from));
124        cursor.seek(&from_key, Bias::Left);
125
126        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
127    }
128
129    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
130    where
131        F: FnOnce(&mut V) -> T,
132    {
133        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
134        let key = MapKeyRef(Some(key));
135        let mut new_tree = cursor.slice(&key, Bias::Left);
136        let mut result = None;
137        if key.cmp(&cursor.end(), &()) == Ordering::Equal {
138            let mut updated = cursor.item().unwrap().clone();
139            result = Some(f(&mut updated.value));
140            new_tree.push(updated, &());
141            cursor.next();
142        }
143        new_tree.append(cursor.suffix(), &());
144        drop(cursor);
145        self.0 = new_tree;
146        result
147    }
148
149    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
150        let mut new_map = SumTree::<MapEntry<K, V>>::default();
151
152        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
153        cursor.next();
154        while let Some(item) = cursor.item() {
155            if predicate(&item.key, &item.value) {
156                new_map.push(item.clone(), &());
157            }
158            cursor.next();
159        }
160        drop(cursor);
161
162        self.0 = new_map;
163    }
164
165    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
166        self.0.iter().map(|entry| (&entry.key, &entry.value))
167    }
168
169    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
170        self.0.iter().map(|entry| &entry.value)
171    }
172
173    pub fn first(&self) -> Option<(&K, &V)> {
174        self.0.first().map(|entry| (&entry.key, &entry.value))
175    }
176
177    pub fn last(&self) -> Option<(&K, &V)> {
178        self.0.last().map(|entry| (&entry.key, &entry.value))
179    }
180
181    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
182        let edits = other
183            .iter()
184            .map(|(key, value)| {
185                Edit::Insert(MapEntry {
186                    key: key.to_owned(),
187                    value: value.to_owned(),
188                })
189            })
190            .collect();
191
192        self.0.edit(edits, &());
193    }
194}
195
196impl<K, V> Debug for TreeMap<K, V>
197where
198    K: Clone + Debug + Ord,
199    V: Clone + Debug,
200{
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        f.debug_map().entries(self.iter()).finish()
203    }
204}
205
206#[derive(Debug)]
207struct MapSeekTargetAdaptor<'a, T>(&'a T);
208
209impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
210    for MapSeekTargetAdaptor<'_, T>
211{
212    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
213        if let Some(key) = &cursor_location.0 {
214            MapSeekTarget::cmp_cursor(self.0, key)
215        } else {
216            Ordering::Greater
217        }
218    }
219}
220
/// Something that can be compared with map keys to pick a cursor position,
/// e.g. the bounds passed to `TreeMap::remove_range`.
pub trait MapSeekTarget<K> {
    /// Orders `self` relative to the key at the cursor's current position.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any totally ordered key can serve as a seek target for its own key type.
impl<K> MapSeekTarget<K> for K
where
    K: Ord,
{
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
230
231impl<K, V> Default for TreeMap<K, V>
232where
233    K: Clone + Ord,
234    V: Clone,
235{
236    fn default() -> Self {
237        Self(Default::default())
238    }
239}
240
241impl<K, V> Item for MapEntry<K, V>
242where
243    K: Clone + Ord,
244    V: Clone,
245{
246    type Summary = MapKey<K>;
247
248    fn summary(&self, _cx: &()) -> Self::Summary {
249        self.key()
250    }
251}
252
253impl<K, V> KeyedItem for MapEntry<K, V>
254where
255    K: Clone + Ord,
256    V: Clone,
257{
258    type Key = MapKey<K>;
259
260    fn key(&self) -> Self::Key {
261        MapKey(Some(self.key.clone()))
262    }
263}
264
265impl<K> Summary for MapKey<K>
266where
267    K: Clone,
268{
269    type Context = ();
270
271    fn zero(_cx: &()) -> Self {
272        Default::default()
273    }
274
275    fn add_summary(&mut self, summary: &Self, _: &()) {
276        *self = summary.clone()
277    }
278}
279
280impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
281where
282    K: Clone + Ord,
283{
284    fn zero(_cx: &()) -> Self {
285        Default::default()
286    }
287
288    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
289        self.0 = summary.0.as_ref();
290    }
291}
292
293impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
294where
295    K: Clone + Ord,
296{
297    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
298        Ord::cmp(&self.0, &cursor_location.0)
299    }
300}
301
302impl<K> Default for TreeSet<K>
303where
304    K: Clone + Ord,
305{
306    fn default() -> Self {
307        Self(Default::default())
308    }
309}
310
311impl<K> TreeSet<K>
312where
313    K: Clone + Ord,
314{
315    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
316        Self(TreeMap::from_ordered_entries(
317            entries.into_iter().map(|key| (key, ())),
318        ))
319    }
320
321    pub fn is_empty(&self) -> bool {
322        self.0.is_empty()
323    }
324
325    pub fn insert(&mut self, key: K) {
326        self.0.insert(key, ());
327    }
328
329    pub fn remove(&mut self, key: &K) -> bool {
330        self.0.remove(key).is_some()
331    }
332
333    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
334        self.0.extend(iter.into_iter().map(|key| (key, ())));
335    }
336
337    pub fn contains(&self, key: &K) -> bool {
338        self.0.get(key).is_some()
339    }
340
341    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
342        self.0.iter().map(|(k, _)| k)
343    }
344
345    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
346        self.0.iter_from(key).map(move |(k, _)| k)
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }

    /// `update` mutates only existing entries, and `remove` reports the removed value.
    #[test]
    fn test_update_and_remove_return_values() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        let bumped = map.update(&1, |value| {
            *value += 1;
            *value
        });
        assert_eq!(bumped, Some(11));
        assert_eq!(map.get(&1), Some(&11));

        // A missing key leaves the map untouched and never calls the closure.
        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().count(), 2);

        assert_eq!(map.remove(&2), Some(20));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &11)]);
    }

    /// `first`/`last` track the smallest and greatest keys; `clear` empties the map.
    #[test]
    fn test_first_last_and_clear() {
        let mut map = TreeMap::<i32, &str>::default();
        assert!(map.is_empty());
        assert_eq!(map.first(), None);
        assert_eq!(map.last(), None);

        map.insert(2, "b");
        map.insert(1, "a");
        map.insert(3, "c");
        assert!(!map.is_empty());
        assert_eq!(map.first(), Some((&1, &"a")));
        assert_eq!(map.last(), Some((&3, &"c")));
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);

        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.iter().count(), 0);
    }

    /// Exercises the `TreeSet` wrapper end to end.
    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        assert!(set.is_empty());

        set.insert(3);
        set.insert(1);
        set.extend([2, 5]);
        set.insert(3);
        assert!(!set.is_empty());
        assert!(set.contains(&3));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3, 5]);
        assert_eq!(set.iter_from(&3).copied().collect::<Vec<_>>(), vec![3, 5]);

        assert!(set.remove(&2));
        assert!(!set.remove(&2));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 3, 5]);

        let ordered = TreeSet::from_ordered_entries([1, 2, 3]);
        assert_eq!(ordered.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }
}