tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
  4
  5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
  6#[derive(Clone, PartialEq, Eq)]
  7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  8where
  9    K: Clone + Ord,
 10    V: Clone;
 11
 12#[derive(Clone, Debug, PartialEq, Eq)]
 13pub struct MapEntry<K, V> {
 14    key: K,
 15    value: V,
 16}
 17
 18#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 19pub struct MapKey<K>(Option<K>);
 20
 21impl<K> Default for MapKey<K> {
 22    fn default() -> Self {
 23        Self(None)
 24    }
 25}
 26
 27#[derive(Clone, Debug)]
 28pub struct MapKeyRef<'a, K>(Option<&'a K>);
 29
 30impl<K> Default for MapKeyRef<'_, K> {
 31    fn default() -> Self {
 32        Self(None)
 33    }
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Eq)]
 37pub struct TreeSet<K>(TreeMap<K, ()>)
 38where
 39    K: Clone + Ord;
 40
 41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 42    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 43        let tree = SumTree::from_iter(
 44            entries
 45                .into_iter()
 46                .map(|(key, value)| MapEntry { key, value }),
 47            (),
 48        );
 49        Self(tree)
 50    }
 51
 52    pub fn is_empty(&self) -> bool {
 53        self.0.is_empty()
 54    }
 55
 56    pub fn get(&self, key: &K) -> Option<&V> {
 57        let (.., item) = self
 58            .0
 59            .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
 60        if let Some(item) = item {
 61            if Some(key) == item.key().0.as_ref() {
 62                Some(&item.value)
 63            } else {
 64                None
 65            }
 66        } else {
 67            None
 68        }
 69    }
 70
 71    pub fn insert(&mut self, key: K, value: V) {
 72        self.0.insert_or_replace(MapEntry { key, value }, ());
 73    }
 74
 75    pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
 76        self.0
 77            .insert_or_replace(MapEntry { key, value }, ())
 78            .map(|it| it.value)
 79    }
 80
 81    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 82        let edits: Vec<_> = iter
 83            .into_iter()
 84            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 85            .collect();
 86        self.0.edit(edits, ());
 87    }
 88
 89    pub fn clear(&mut self) {
 90        self.0 = SumTree::default();
 91    }
 92
 93    pub fn remove(&mut self, key: &K) -> Option<V> {
 94        let mut removed = None;
 95        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
 96        let key = MapKeyRef(Some(key));
 97        let mut new_tree = cursor.slice(&key, Bias::Left);
 98        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
 99            removed = Some(cursor.item().unwrap().value.clone());
100            cursor.next();
101        }
102        new_tree.append(cursor.suffix(), ());
103        drop(cursor);
104        self.0 = new_tree;
105        removed
106    }
107
108    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
109        let start = MapSeekTargetAdaptor(start);
110        let end = MapSeekTargetAdaptor(end);
111        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
112        let mut new_tree = cursor.slice(&start, Bias::Left);
113        cursor.seek(&end, Bias::Left);
114        new_tree.append(cursor.suffix(), ());
115        drop(cursor);
116        self.0 = new_tree;
117    }
118
119    /// Returns the key-value pair with the greatest key less than or equal to the given key.
120    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
121        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
122        let key = MapKeyRef(Some(key));
123        cursor.seek(&key, Bias::Right);
124        cursor.prev();
125        cursor.item().map(|item| (&item.key, &item.value))
126    }
127
128    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
129        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
130        let from_key = MapKeyRef(Some(from));
131        cursor.seek(&from_key, Bias::Left);
132
133        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
134    }
135
136    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
137    where
138        F: FnOnce(&mut V) -> T,
139    {
140        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
141        let key = MapKeyRef(Some(key));
142        let mut new_tree = cursor.slice(&key, Bias::Left);
143        let mut result = None;
144        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
145            let mut updated = cursor.item().unwrap().clone();
146            result = Some(f(&mut updated.value));
147            new_tree.push(updated, ());
148            cursor.next();
149        }
150        new_tree.append(cursor.suffix(), ());
151        drop(cursor);
152        self.0 = new_tree;
153        result
154    }
155
156    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
157        let mut new_map = SumTree::<MapEntry<K, V>>::default();
158
159        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
160        cursor.next();
161        while let Some(item) = cursor.item() {
162            if predicate(&item.key, &item.value) {
163                new_map.push(item.clone(), ());
164            }
165            cursor.next();
166        }
167        drop(cursor);
168
169        self.0 = new_map;
170    }
171
172    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
173        self.0.iter().map(|entry| (&entry.key, &entry.value))
174    }
175
176    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
177        self.0.iter().map(|entry| &entry.value)
178    }
179
180    pub fn first(&self) -> Option<(&K, &V)> {
181        self.0.first().map(|entry| (&entry.key, &entry.value))
182    }
183
184    pub fn last(&self) -> Option<(&K, &V)> {
185        self.0.last().map(|entry| (&entry.key, &entry.value))
186    }
187
188    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
189        let edits = other
190            .iter()
191            .map(|(key, value)| {
192                Edit::Insert(MapEntry {
193                    key: key.to_owned(),
194                    value: value.to_owned(),
195                })
196            })
197            .collect();
198
199        self.0.edit(edits, ());
200    }
201}
202
203impl<K, V> Debug for TreeMap<K, V>
204where
205    K: Clone + Debug + Ord,
206    V: Clone + Debug,
207{
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        f.debug_map().entries(self.iter()).finish()
210    }
211}
212
213#[derive(Debug)]
214struct MapSeekTargetAdaptor<'a, T>(&'a T);
215
216impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
217    for MapSeekTargetAdaptor<'_, T>
218{
219    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
220        if let Some(key) = &cursor_location.0 {
221            MapSeekTarget::cmp_cursor(self.0, key)
222        } else {
223            Ordering::Greater
224        }
225    }
226}
227
/// A place a [`TreeMap`] cursor can seek to, expressed as a comparison against the keys
/// stored in the map.
pub trait MapSeekTarget<K> {
    /// Orders this target relative to the key at `cursor_location`. `Less` means the target
    /// comes before that key and `Greater` means it comes after it.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key can be used as a seek target, compared by its natural ordering.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
237
238impl<K, V> Default for TreeMap<K, V>
239where
240    K: Clone + Ord,
241    V: Clone,
242{
243    fn default() -> Self {
244        Self(Default::default())
245    }
246}
247
248impl<K, V> Item for MapEntry<K, V>
249where
250    K: Clone + Ord,
251    V: Clone,
252{
253    type Summary = MapKey<K>;
254
255    fn summary(&self, _cx: ()) -> Self::Summary {
256        self.key()
257    }
258}
259
260impl<K, V> KeyedItem for MapEntry<K, V>
261where
262    K: Clone + Ord,
263    V: Clone,
264{
265    type Key = MapKey<K>;
266
267    fn key(&self) -> Self::Key {
268        MapKey(Some(self.key.clone()))
269    }
270}
271
272impl<K> ContextLessSummary for MapKey<K>
273where
274    K: Clone,
275{
276    fn zero() -> Self {
277        Default::default()
278    }
279
280    fn add_summary(&mut self, summary: &Self) {
281        *self = summary.clone()
282    }
283}
284
285impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
286where
287    K: Clone + Ord,
288{
289    fn zero(_cx: ()) -> Self {
290        Default::default()
291    }
292
293    fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
294        self.0 = summary.0.as_ref();
295    }
296}
297
298impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
299where
300    K: Clone + Ord,
301{
302    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
303        Ord::cmp(&self.0, &cursor_location.0)
304    }
305}
306
307impl<K> Default for TreeSet<K>
308where
309    K: Clone + Ord,
310{
311    fn default() -> Self {
312        Self(Default::default())
313    }
314}
315
316impl<K> TreeSet<K>
317where
318    K: Clone + Ord,
319{
320    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
321        Self(TreeMap::from_ordered_entries(
322            entries.into_iter().map(|key| (key, ())),
323        ))
324    }
325
326    pub fn is_empty(&self) -> bool {
327        self.0.is_empty()
328    }
329
330    pub fn insert(&mut self, key: K) {
331        self.0.insert(key, ());
332    }
333
334    pub fn remove(&mut self, key: &K) -> bool {
335        self.0.remove(key).is_some()
336    }
337
338    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
339        self.0.extend(iter.into_iter().map(|key| (key, ())));
340    }
341
342    pub fn contains(&self, key: &K) -> bool {
343        self.0.get(key).is_some()
344    }
345
346    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
347        self.0.iter().map(|(k, _)| k)
348    }
349
350    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
351        self.0.iter_from(key).map(move |(k, _)| k)
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // Entries must come back in ascending key order, starting at the first key >= "ba".
        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_update_and_remove_return_values() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        // `update` reports the closure's result for a present key and leaves the change in place.
        assert_eq!(
            map.update(&1, |value| {
                *value += 1;
                *value
            }),
            Some(11)
        );
        assert_eq!(map.get(&1), Some(&11));

        // A missing key neither calls the closure nor inserts anything.
        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.get(&3), None);

        assert_eq!(map.remove(&2), Some(20));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &11)]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        assert!(set.is_empty());

        set.insert(3);
        set.insert(1);
        set.insert(2);
        assert!(!set.is_empty());
        assert!(set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
        assert_eq!(set.iter_from(&2).collect::<Vec<_>>(), vec![&2, &3]);

        assert!(set.remove(&2));
        assert!(!set.remove(&2));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}