tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
  4
  5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
  6#[derive(Clone, PartialEq, Eq)]
  7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  8where
  9    K: Clone + Ord,
 10    V: Clone;
 11
/// A single key/value pair stored as an item of a [`TreeMap`]'s underlying tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    // The key this entry is ordered and looked up by.
    key: K,
    // The value associated with `key`.
    value: V,
}
 17
 18#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 19pub struct MapKey<K>(Option<K>);
 20
 21impl<K> Default for MapKey<K> {
 22    fn default() -> Self {
 23        Self(None)
 24    }
 25}
 26
 27#[derive(Clone, Debug)]
 28pub struct MapKeyRef<'a, K>(Option<&'a K>);
 29
 30impl<K> Default for MapKeyRef<'_, K> {
 31    fn default() -> Self {
 32        Self(None)
 33    }
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Eq)]
 37pub struct TreeSet<K>(TreeMap<K, ()>)
 38where
 39    K: Clone + Ord;
 40
 41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 42    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 43        let tree = SumTree::from_iter(
 44            entries
 45                .into_iter()
 46                .map(|(key, value)| MapEntry { key, value }),
 47            (),
 48        );
 49        Self(tree)
 50    }
 51
 52    pub fn is_empty(&self) -> bool {
 53        self.0.is_empty()
 54    }
 55
 56    pub fn get(&self, key: &K) -> Option<&V> {
 57        let (.., item) = self
 58            .0
 59            .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
 60        if let Some(item) = item {
 61            if Some(key) == item.key().0.as_ref() {
 62                Some(&item.value)
 63            } else {
 64                None
 65            }
 66        } else {
 67            None
 68        }
 69    }
 70
 71    pub fn insert(&mut self, key: K, value: V) {
 72        self.0.insert_or_replace(MapEntry { key, value }, ());
 73    }
 74
 75    pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
 76        self.0
 77            .insert_or_replace(MapEntry { key, value }, ())
 78            .map(|it| it.value)
 79    }
 80
 81    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 82        let edits: Vec<_> = iter
 83            .into_iter()
 84            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 85            .collect();
 86        if edits.is_empty() {
 87            return;
 88        }
 89        self.0.edit(edits, ());
 90    }
 91
 92    pub fn clear(&mut self) {
 93        self.0 = SumTree::default();
 94    }
 95
 96    pub fn remove(&mut self, key: &K) -> Option<V> {
 97        let mut removed = None;
 98        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
 99        let key = MapKeyRef(Some(key));
100        let mut new_tree = cursor.slice(&key, Bias::Left);
101        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
102            removed = Some(cursor.item().unwrap().value.clone());
103            cursor.next();
104        }
105        new_tree.append(cursor.suffix(), ());
106        drop(cursor);
107        self.0 = new_tree;
108        removed
109    }
110
111    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
112        let start = MapSeekTargetAdaptor(start);
113        let end = MapSeekTargetAdaptor(end);
114        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
115        let mut new_tree = cursor.slice(&start, Bias::Left);
116        cursor.seek(&end, Bias::Left);
117        new_tree.append(cursor.suffix(), ());
118        drop(cursor);
119        self.0 = new_tree;
120    }
121
122    /// Returns the key-value pair with the greatest key less than or equal to the given key.
123    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
124        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
125        let key = MapKeyRef(Some(key));
126        cursor.seek(&key, Bias::Right);
127        cursor.prev();
128        cursor.item().map(|item| (&item.key, &item.value))
129    }
130
131    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
132        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
133        let from_key = MapKeyRef(Some(from));
134        cursor.seek(&from_key, Bias::Left);
135
136        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
137    }
138
139    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
140    where
141        F: FnOnce(&mut V) -> T,
142    {
143        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
144        let key = MapKeyRef(Some(key));
145        let mut new_tree = cursor.slice(&key, Bias::Left);
146        let mut result = None;
147        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
148            let mut updated = cursor.item().unwrap().clone();
149            result = Some(f(&mut updated.value));
150            new_tree.push(updated, ());
151            cursor.next();
152        }
153        new_tree.append(cursor.suffix(), ());
154        drop(cursor);
155        self.0 = new_tree;
156        result
157    }
158
159    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
160        let mut new_map = SumTree::<MapEntry<K, V>>::default();
161
162        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
163        cursor.next();
164        while let Some(item) = cursor.item() {
165            if predicate(&item.key, &item.value) {
166                new_map.push(item.clone(), ());
167            }
168            cursor.next();
169        }
170        drop(cursor);
171
172        self.0 = new_map;
173    }
174
175    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
176        self.0.iter().map(|entry| (&entry.key, &entry.value))
177    }
178
179    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
180        self.0.iter().map(|entry| &entry.value)
181    }
182
183    pub fn first(&self) -> Option<(&K, &V)> {
184        self.0.first().map(|entry| (&entry.key, &entry.value))
185    }
186
187    pub fn last(&self) -> Option<(&K, &V)> {
188        self.0.last().map(|entry| (&entry.key, &entry.value))
189    }
190
191    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
192        let edits = other
193            .iter()
194            .map(|(key, value)| {
195                Edit::Insert(MapEntry {
196                    key: key.to_owned(),
197                    value: value.to_owned(),
198                })
199            })
200            .collect();
201
202        self.0.edit(edits, ());
203    }
204}
205
206impl<K, V> Debug for TreeMap<K, V>
207where
208    K: Clone + Debug + Ord,
209    V: Clone + Debug,
210{
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        f.debug_map().entries(self.iter()).finish()
213    }
214}
215
216#[derive(Debug)]
217struct MapSeekTargetAdaptor<'a, T>(&'a T);
218
219impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
220    for MapSeekTargetAdaptor<'_, T>
221{
222    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
223        if let Some(key) = &cursor_location.0 {
224            MapSeekTarget::cmp_cursor(self.0, key)
225        } else {
226            Ordering::Greater
227        }
228    }
229}
230
/// A position that can be located in a [`TreeMap`] by comparing it with the map's keys.
///
/// `cmp_cursor` reports how `self` orders relative to the key at `cursor_location`.
/// NOTE: range seeks presumably rely on this being monotonic over the key order.
pub trait MapSeekTarget<K> {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Every ordered key type is a seek target for itself, compared with [`Ord::cmp`].
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
240
241impl<K, V> Default for TreeMap<K, V>
242where
243    K: Clone + Ord,
244    V: Clone,
245{
246    fn default() -> Self {
247        Self(Default::default())
248    }
249}
250
251impl<K, V> Item for MapEntry<K, V>
252where
253    K: Clone + Ord,
254    V: Clone,
255{
256    type Summary = MapKey<K>;
257
258    fn summary(&self, _cx: ()) -> Self::Summary {
259        self.key()
260    }
261}
262
263impl<K, V> KeyedItem for MapEntry<K, V>
264where
265    K: Clone + Ord,
266    V: Clone,
267{
268    type Key = MapKey<K>;
269
270    fn key(&self) -> Self::Key {
271        MapKey(Some(self.key.clone()))
272    }
273}
274
275impl<K> ContextLessSummary for MapKey<K>
276where
277    K: Clone,
278{
279    fn zero() -> Self {
280        Default::default()
281    }
282
283    fn add_summary(&mut self, summary: &Self) {
284        *self = summary.clone()
285    }
286}
287
288impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
289where
290    K: Clone + Ord,
291{
292    fn zero(_cx: ()) -> Self {
293        Default::default()
294    }
295
296    fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
297        self.0 = summary.0.as_ref();
298    }
299}
300
301impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
302where
303    K: Clone + Ord,
304{
305    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
306        Ord::cmp(&self.0, &cursor_location.0)
307    }
308}
309
310impl<K> Default for TreeSet<K>
311where
312    K: Clone + Ord,
313{
314    fn default() -> Self {
315        Self(Default::default())
316    }
317}
318
319impl<K> TreeSet<K>
320where
321    K: Clone + Ord,
322{
323    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
324        Self(TreeMap::from_ordered_entries(
325            entries.into_iter().map(|key| (key, ())),
326        ))
327    }
328
329    pub fn is_empty(&self) -> bool {
330        self.0.is_empty()
331    }
332
333    pub fn insert(&mut self, key: K) {
334        self.0.insert(key, ());
335    }
336
337    pub fn remove(&mut self, key: &K) -> bool {
338        self.0.remove(key).is_some()
339    }
340
341    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
342        self.0.extend(iter.into_iter().map(|key| (key, ())));
343    }
344
345    pub fn contains(&self, key: &K) -> bool {
346        self.0.get(key).is_some()
347    }
348
349    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
350        self.0.iter().map(|(k, _)| k)
351    }
352
353    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
354        self.0.iter_from(key).map(move |(k, _)| k)
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    // Walks a small integer map through insert, get, closest, remove and retain,
    // checking the full ordered contents after each step.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` returns the greatest key <= the query, or None below the smallest key.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 removed, the closest key at or below 2 is 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only the even keys.
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    // `iter_from` starts at the first key >= the query, which enables prefix scans
    // when combined with `take_while`.
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    // Merging another map: values from `other` win on shared keys, and new keys are added.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // `extend` overwrites existing keys and adds new ones in a single batch.
    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // `remove_range` with a custom seek target removes a whole path subtree
    // while keeping siblings and the parent.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        // Compares as Greater than `self.0` itself and every path under it, so a
        // Left-biased seek to this target ends just past that subtree.
        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // "b/a" itself is not in the map; only its descendants are removed.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Removing the last subtree: the range runs to the end of the map.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Removing the first subtree: the range starts at the beginning of the map.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Removing the only remaining entry empties the map.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}