tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
  4
  5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
  6#[derive(Clone, PartialEq, Eq)]
  7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  8where
  9    K: Clone + Ord,
 10    V: Clone;
 11
 12#[derive(Clone, Debug, PartialEq, Eq)]
 13pub struct MapEntry<K, V> {
 14    key: K,
 15    value: V,
 16}
 17
 18#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 19pub struct MapKey<K>(Option<K>);
 20
 21impl<K> Default for MapKey<K> {
 22    fn default() -> Self {
 23        Self(None)
 24    }
 25}
 26
 27#[derive(Clone, Debug)]
 28pub struct MapKeyRef<'a, K>(Option<&'a K>);
 29
 30impl<K> Default for MapKeyRef<'_, K> {
 31    fn default() -> Self {
 32        Self(None)
 33    }
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Eq)]
 37pub struct TreeSet<K>(TreeMap<K, ()>)
 38where
 39    K: Clone + Ord;
 40
 41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 42    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 43        let tree = SumTree::from_iter(
 44            entries
 45                .into_iter()
 46                .map(|(key, value)| MapEntry { key, value }),
 47            (),
 48        );
 49        Self(tree)
 50    }
 51
 52    pub fn is_empty(&self) -> bool {
 53        self.0.is_empty()
 54    }
 55
 56    pub fn contains_key(&self, key: &K) -> bool {
 57        self.get(key).is_some()
 58    }
 59
 60    pub fn get(&self, key: &K) -> Option<&V> {
 61        let (.., item) = self
 62            .0
 63            .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
 64        if let Some(item) = item {
 65            if Some(key) == item.key().0.as_ref() {
 66                Some(&item.value)
 67            } else {
 68                None
 69            }
 70        } else {
 71            None
 72        }
 73    }
 74
 75    pub fn insert(&mut self, key: K, value: V) {
 76        self.0.insert_or_replace(MapEntry { key, value }, ());
 77    }
 78
 79    pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
 80        self.0
 81            .insert_or_replace(MapEntry { key, value }, ())
 82            .map(|it| it.value)
 83    }
 84
 85    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 86        let edits: Vec<_> = iter
 87            .into_iter()
 88            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 89            .collect();
 90        self.0.edit(edits, ());
 91    }
 92
 93    pub fn clear(&mut self) {
 94        self.0 = SumTree::default();
 95    }
 96
 97    pub fn remove(&mut self, key: &K) -> Option<V> {
 98        let mut removed = None;
 99        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
100        let key = MapKeyRef(Some(key));
101        let mut new_tree = cursor.slice(&key, Bias::Left);
102        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
103            removed = Some(cursor.item().unwrap().value.clone());
104            cursor.next();
105        }
106        new_tree.append(cursor.suffix(), ());
107        drop(cursor);
108        self.0 = new_tree;
109        removed
110    }
111
112    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
113        let start = MapSeekTargetAdaptor(start);
114        let end = MapSeekTargetAdaptor(end);
115        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
116        let mut new_tree = cursor.slice(&start, Bias::Left);
117        cursor.seek(&end, Bias::Left);
118        new_tree.append(cursor.suffix(), ());
119        drop(cursor);
120        self.0 = new_tree;
121    }
122
123    /// Returns the key-value pair with the greatest key less than or equal to the given key.
124    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
125        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
126        let key = MapKeyRef(Some(key));
127        cursor.seek(&key, Bias::Right);
128        cursor.prev();
129        cursor.item().map(|item| (&item.key, &item.value))
130    }
131
132    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
133        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
134        let from_key = MapKeyRef(Some(from));
135        cursor.seek(&from_key, Bias::Left);
136
137        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
138    }
139
140    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
141    where
142        F: FnOnce(&mut V) -> T,
143    {
144        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
145        let key = MapKeyRef(Some(key));
146        let mut new_tree = cursor.slice(&key, Bias::Left);
147        let mut result = None;
148        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
149            let mut updated = cursor.item().unwrap().clone();
150            result = Some(f(&mut updated.value));
151            new_tree.push(updated, ());
152            cursor.next();
153        }
154        new_tree.append(cursor.suffix(), ());
155        drop(cursor);
156        self.0 = new_tree;
157        result
158    }
159
160    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
161        let mut new_map = SumTree::<MapEntry<K, V>>::default();
162
163        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
164        cursor.next();
165        while let Some(item) = cursor.item() {
166            if predicate(&item.key, &item.value) {
167                new_map.push(item.clone(), ());
168            }
169            cursor.next();
170        }
171        drop(cursor);
172
173        self.0 = new_map;
174    }
175
176    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
177        self.0.iter().map(|entry| (&entry.key, &entry.value))
178    }
179
180    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
181        self.0.iter().map(|entry| &entry.value)
182    }
183
184    pub fn first(&self) -> Option<(&K, &V)> {
185        self.0.first().map(|entry| (&entry.key, &entry.value))
186    }
187
188    pub fn last(&self) -> Option<(&K, &V)> {
189        self.0.last().map(|entry| (&entry.key, &entry.value))
190    }
191
192    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
193        let edits = other
194            .iter()
195            .map(|(key, value)| {
196                Edit::Insert(MapEntry {
197                    key: key.to_owned(),
198                    value: value.to_owned(),
199                })
200            })
201            .collect();
202
203        self.0.edit(edits, ());
204    }
205}
206
207impl<K, V> Debug for TreeMap<K, V>
208where
209    K: Clone + Debug + Ord,
210    V: Clone + Debug,
211{
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_map().entries(self.iter()).finish()
214    }
215}
216
217#[derive(Debug)]
218struct MapSeekTargetAdaptor<'a, T>(&'a T);
219
220impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
221    for MapSeekTargetAdaptor<'_, T>
222{
223    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
224        if let Some(key) = &cursor_location.0 {
225            MapSeekTarget::cmp_cursor(self.0, key)
226        } else {
227            Ordering::Greater
228        }
229    }
230}
231
232pub trait MapSeekTarget<K> {
233    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
234}
235
236impl<K: Ord> MapSeekTarget<K> for K {
237    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
238        self.cmp(cursor_location)
239    }
240}
241
242impl<K, V> Default for TreeMap<K, V>
243where
244    K: Clone + Ord,
245    V: Clone,
246{
247    fn default() -> Self {
248        Self(Default::default())
249    }
250}
251
252impl<K, V> Item for MapEntry<K, V>
253where
254    K: Clone + Ord,
255    V: Clone,
256{
257    type Summary = MapKey<K>;
258
259    fn summary(&self, _cx: ()) -> Self::Summary {
260        self.key()
261    }
262}
263
264impl<K, V> KeyedItem for MapEntry<K, V>
265where
266    K: Clone + Ord,
267    V: Clone,
268{
269    type Key = MapKey<K>;
270
271    fn key(&self) -> Self::Key {
272        MapKey(Some(self.key.clone()))
273    }
274}
275
276impl<K> ContextLessSummary for MapKey<K>
277where
278    K: Clone,
279{
280    fn zero() -> Self {
281        Default::default()
282    }
283
284    fn add_summary(&mut self, summary: &Self) {
285        *self = summary.clone()
286    }
287}
288
289impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
290where
291    K: Clone + Ord,
292{
293    fn zero(_cx: ()) -> Self {
294        Default::default()
295    }
296
297    fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
298        self.0 = summary.0.as_ref();
299    }
300}
301
302impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
303where
304    K: Clone + Ord,
305{
306    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
307        Ord::cmp(&self.0, &cursor_location.0)
308    }
309}
310
311impl<K> Default for TreeSet<K>
312where
313    K: Clone + Ord,
314{
315    fn default() -> Self {
316        Self(Default::default())
317    }
318}
319
320impl<K> TreeSet<K>
321where
322    K: Clone + Ord,
323{
324    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
325        Self(TreeMap::from_ordered_entries(
326            entries.into_iter().map(|key| (key, ())),
327        ))
328    }
329
330    pub fn is_empty(&self) -> bool {
331        self.0.is_empty()
332    }
333
334    pub fn insert(&mut self, key: K) {
335        self.0.insert(key, ());
336    }
337
338    pub fn remove(&mut self, key: &K) -> bool {
339        self.0.remove(key).is_some()
340    }
341
342    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
343        self.0.extend(iter.into_iter().map(|key| (key, ())));
344    }
345
346    pub fn contains(&self, key: &K) -> bool {
347        self.0.get(key).is_some()
348    }
349
350    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
351        self.0.iter().map(|(k, _)| k)
352    }
353
354    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
355        self.0.iter_from(key).map(move |(k, _)| k)
356    }
357}
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362
363    #[test]
364    fn test_basic() {
365        let mut map = TreeMap::default();
366        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
367
368        map.insert(3, "c");
369        assert_eq!(map.get(&3), Some(&"c"));
370        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);
371
372        map.insert(1, "a");
373        assert_eq!(map.get(&1), Some(&"a"));
374        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
375
376        map.insert(2, "b");
377        assert_eq!(map.get(&2), Some(&"b"));
378        assert_eq!(map.get(&1), Some(&"a"));
379        assert_eq!(map.get(&3), Some(&"c"));
380        assert_eq!(
381            map.iter().collect::<Vec<_>>(),
382            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
383        );
384
385        assert_eq!(map.closest(&0), None);
386        assert_eq!(map.closest(&1), Some((&1, &"a")));
387        assert_eq!(map.closest(&10), Some((&3, &"c")));
388
389        map.remove(&2);
390        assert_eq!(map.get(&2), None);
391        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
392
393        assert_eq!(map.closest(&2), Some((&1, &"a")));
394
395        map.remove(&3);
396        assert_eq!(map.get(&3), None);
397        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);
398
399        map.remove(&1);
400        assert_eq!(map.get(&1), None);
401        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
402
403        map.insert(4, "d");
404        map.insert(5, "e");
405        map.insert(6, "f");
406        map.retain(|key, _| *key % 2 == 0);
407        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
408    }
409
410    #[test]
411    fn test_iter_from() {
412        let mut map = TreeMap::default();
413
414        map.insert("a", 1);
415        map.insert("b", 2);
416        map.insert("baa", 3);
417        map.insert("baaab", 4);
418        map.insert("c", 5);
419
420        let result = map
421            .iter_from(&"ba")
422            .take_while(|(key, _)| key.starts_with("ba"))
423            .collect::<Vec<_>>();
424
425        assert_eq!(result.len(), 2);
426        assert!(result.iter().any(|(k, _)| k == &&"baa"));
427        assert!(result.iter().any(|(k, _)| k == &&"baaab"));
428
429        let result = map
430            .iter_from(&"c")
431            .take_while(|(key, _)| key.starts_with("c"))
432            .collect::<Vec<_>>();
433
434        assert_eq!(result.len(), 1);
435        assert!(result.iter().any(|(k, _)| k == &&"c"));
436    }
437
438    #[test]
439    fn test_insert_tree() {
440        let mut map = TreeMap::default();
441        map.insert("a", 1);
442        map.insert("b", 2);
443        map.insert("c", 3);
444
445        let mut other = TreeMap::default();
446        other.insert("a", 2);
447        other.insert("b", 2);
448        other.insert("d", 4);
449
450        map.insert_tree(other);
451
452        assert_eq!(map.iter().count(), 4);
453        assert_eq!(map.get(&"a"), Some(&2));
454        assert_eq!(map.get(&"b"), Some(&2));
455        assert_eq!(map.get(&"c"), Some(&3));
456        assert_eq!(map.get(&"d"), Some(&4));
457    }
458
459    #[test]
460    fn test_extend() {
461        let mut map = TreeMap::default();
462        map.insert("a", 1);
463        map.insert("b", 2);
464        map.insert("c", 3);
465        map.extend([("a", 2), ("b", 2), ("d", 4)]);
466        assert_eq!(map.iter().count(), 4);
467        assert_eq!(map.get(&"a"), Some(&2));
468        assert_eq!(map.get(&"b"), Some(&2));
469        assert_eq!(map.get(&"c"), Some(&3));
470        assert_eq!(map.get(&"d"), Some(&4));
471    }
472
473    #[test]
474    fn test_remove_between_and_path_successor() {
475        use std::path::{Path, PathBuf};
476
477        #[derive(Debug)]
478        pub struct PathDescendants<'a>(&'a Path);
479
480        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
481            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
482                if key.starts_with(self.0) {
483                    Ordering::Greater
484                } else {
485                    self.0.cmp(key)
486                }
487            }
488        }
489
490        let mut map = TreeMap::default();
491
492        map.insert(PathBuf::from("a"), 1);
493        map.insert(PathBuf::from("a/a"), 1);
494        map.insert(PathBuf::from("b"), 2);
495        map.insert(PathBuf::from("b/a/a"), 3);
496        map.insert(PathBuf::from("b/a/a/a/b"), 4);
497        map.insert(PathBuf::from("c"), 5);
498        map.insert(PathBuf::from("c/a"), 6);
499
500        map.remove_range(
501            &PathBuf::from("b/a"),
502            &PathDescendants(&PathBuf::from("b/a")),
503        );
504
505        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
506        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
507        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
508        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
509        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
510        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
511        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));
512
513        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));
514
515        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
516        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
517        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
518        assert_eq!(map.get(&PathBuf::from("c")), None);
519        assert_eq!(map.get(&PathBuf::from("c/a")), None);
520
521        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));
522
523        assert_eq!(map.get(&PathBuf::from("a")), None);
524        assert_eq!(map.get(&PathBuf::from("a/a")), None);
525        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
526
527        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));
528
529        assert_eq!(map.get(&PathBuf::from("b")), None);
530    }
531}