tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
  4
  5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
  6#[derive(Clone, PartialEq, Eq)]
  7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  8where
  9    K: Clone + Ord,
 10    V: Clone;
 11
/// A single key/value pair, stored as one item of the tree backing [`TreeMap`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// Key the entry is stored under; also the source of the entry's summary.
    key: K,
    /// Value associated with `key`.
    value: V,
}
 17
/// Owned summary/key type for [`MapEntry`]: wraps the entry's key, with `None`
/// serving as the default ("zero") value.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);
 20
 21impl<K> Default for MapKey<K> {
 22    fn default() -> Self {
 23        Self(None)
 24    }
 25}
 26
/// Borrowed counterpart of [`MapKey`]: holds an optional reference to a key.
/// It is the dimension type used for cursors over the map and as a seek target.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
 29
 30impl<K> Default for MapKeyRef<'_, K> {
 31    fn default() -> Self {
 32        Self(None)
 33    }
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Eq)]
 37pub struct TreeSet<K>(TreeMap<K, ()>)
 38where
 39    K: Clone + Ord;
 40
 41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 42    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 43        let tree = SumTree::from_iter(
 44            entries
 45                .into_iter()
 46                .map(|(key, value)| MapEntry { key, value }),
 47            (),
 48        );
 49        Self(tree)
 50    }
 51
 52    pub fn is_empty(&self) -> bool {
 53        self.0.is_empty()
 54    }
 55
 56    pub fn get(&self, key: &K) -> Option<&V> {
 57        let (.., item) = self
 58            .0
 59            .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
 60        if let Some(item) = item {
 61            if Some(key) == item.key().0.as_ref() {
 62                Some(&item.value)
 63            } else {
 64                None
 65            }
 66        } else {
 67            None
 68        }
 69    }
 70
 71    pub fn insert(&mut self, key: K, value: V) {
 72        self.0.insert_or_replace(MapEntry { key, value }, ());
 73    }
 74
 75    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 76        let edits: Vec<_> = iter
 77            .into_iter()
 78            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 79            .collect();
 80        self.0.edit(edits, ());
 81    }
 82
 83    pub fn clear(&mut self) {
 84        self.0 = SumTree::default();
 85    }
 86
 87    pub fn remove(&mut self, key: &K) -> Option<V> {
 88        let mut removed = None;
 89        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
 90        let key = MapKeyRef(Some(key));
 91        let mut new_tree = cursor.slice(&key, Bias::Left);
 92        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
 93            removed = Some(cursor.item().unwrap().value.clone());
 94            cursor.next();
 95        }
 96        new_tree.append(cursor.suffix(), ());
 97        drop(cursor);
 98        self.0 = new_tree;
 99        removed
100    }
101
102    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
103        let start = MapSeekTargetAdaptor(start);
104        let end = MapSeekTargetAdaptor(end);
105        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
106        let mut new_tree = cursor.slice(&start, Bias::Left);
107        cursor.seek(&end, Bias::Left);
108        new_tree.append(cursor.suffix(), ());
109        drop(cursor);
110        self.0 = new_tree;
111    }
112
113    /// Returns the key-value pair with the greatest key less than or equal to the given key.
114    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
115        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
116        let key = MapKeyRef(Some(key));
117        cursor.seek(&key, Bias::Right);
118        cursor.prev();
119        cursor.item().map(|item| (&item.key, &item.value))
120    }
121
122    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
123        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
124        let from_key = MapKeyRef(Some(from));
125        cursor.seek(&from_key, Bias::Left);
126
127        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
128    }
129
130    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
131    where
132        F: FnOnce(&mut V) -> T,
133    {
134        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
135        let key = MapKeyRef(Some(key));
136        let mut new_tree = cursor.slice(&key, Bias::Left);
137        let mut result = None;
138        if key.cmp(&cursor.end(), ()) == Ordering::Equal {
139            let mut updated = cursor.item().unwrap().clone();
140            result = Some(f(&mut updated.value));
141            new_tree.push(updated, ());
142            cursor.next();
143        }
144        new_tree.append(cursor.suffix(), ());
145        drop(cursor);
146        self.0 = new_tree;
147        result
148    }
149
150    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
151        let mut new_map = SumTree::<MapEntry<K, V>>::default();
152
153        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
154        cursor.next();
155        while let Some(item) = cursor.item() {
156            if predicate(&item.key, &item.value) {
157                new_map.push(item.clone(), ());
158            }
159            cursor.next();
160        }
161        drop(cursor);
162
163        self.0 = new_map;
164    }
165
166    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
167        self.0.iter().map(|entry| (&entry.key, &entry.value))
168    }
169
170    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
171        self.0.iter().map(|entry| &entry.value)
172    }
173
174    pub fn first(&self) -> Option<(&K, &V)> {
175        self.0.first().map(|entry| (&entry.key, &entry.value))
176    }
177
178    pub fn last(&self) -> Option<(&K, &V)> {
179        self.0.last().map(|entry| (&entry.key, &entry.value))
180    }
181
182    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
183        let edits = other
184            .iter()
185            .map(|(key, value)| {
186                Edit::Insert(MapEntry {
187                    key: key.to_owned(),
188                    value: value.to_owned(),
189                })
190            })
191            .collect();
192
193        self.0.edit(edits, ());
194    }
195}
196
197impl<K, V> Debug for TreeMap<K, V>
198where
199    K: Clone + Debug + Ord,
200    V: Clone + Debug,
201{
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        f.debug_map().entries(self.iter()).finish()
204    }
205}
206
/// Wraps a reference to a [`MapSeekTarget`] so that it can be used as a cursor
/// seek target over [`MapKeyRef`] positions.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'a, T>(&'a T);
209
210impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
211    for MapSeekTargetAdaptor<'_, T>
212{
213    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
214        if let Some(key) = &cursor_location.0 {
215            MapSeekTarget::cmp_cursor(self.0, key)
216        } else {
217            Ordering::Greater
218        }
219    }
220}
221
/// A target that can be positioned relative to keys of type `K`, used as a bound
/// for [`TreeMap::remove_range`].
pub trait MapSeekTarget<K> {
    /// Returns how `self` orders relative to the key at `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}
225
226impl<K: Ord> MapSeekTarget<K> for K {
227    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
228        self.cmp(cursor_location)
229    }
230}
231
232impl<K, V> Default for TreeMap<K, V>
233where
234    K: Clone + Ord,
235    V: Clone,
236{
237    fn default() -> Self {
238        Self(Default::default())
239    }
240}
241
242impl<K, V> Item for MapEntry<K, V>
243where
244    K: Clone + Ord,
245    V: Clone,
246{
247    type Summary = MapKey<K>;
248
249    fn summary(&self, _cx: ()) -> Self::Summary {
250        self.key()
251    }
252}
253
254impl<K, V> KeyedItem for MapEntry<K, V>
255where
256    K: Clone + Ord,
257    V: Clone,
258{
259    type Key = MapKey<K>;
260
261    fn key(&self) -> Self::Key {
262        MapKey(Some(self.key.clone()))
263    }
264}
265
266impl<K> ContextLessSummary for MapKey<K>
267where
268    K: Clone,
269{
270    fn zero() -> Self {
271        Default::default()
272    }
273
274    fn add_summary(&mut self, summary: &Self) {
275        *self = summary.clone()
276    }
277}
278
279impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
280where
281    K: Clone + Ord,
282{
283    fn zero(_cx: ()) -> Self {
284        Default::default()
285    }
286
287    fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
288        self.0 = summary.0.as_ref();
289    }
290}
291
292impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
293where
294    K: Clone + Ord,
295{
296    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
297        Ord::cmp(&self.0, &cursor_location.0)
298    }
299}
300
301impl<K> Default for TreeSet<K>
302where
303    K: Clone + Ord,
304{
305    fn default() -> Self {
306        Self(Default::default())
307    }
308}
309
310impl<K> TreeSet<K>
311where
312    K: Clone + Ord,
313{
314    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
315        Self(TreeMap::from_ordered_entries(
316            entries.into_iter().map(|key| (key, ())),
317        ))
318    }
319
320    pub fn is_empty(&self) -> bool {
321        self.0.is_empty()
322    }
323
324    pub fn insert(&mut self, key: K) {
325        self.0.insert(key, ());
326    }
327
328    pub fn remove(&mut self, key: &K) -> bool {
329        self.0.remove(key).is_some()
330    }
331
332    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
333        self.0.extend(iter.into_iter().map(|key| (key, ())));
334    }
335
336    pub fn contains(&self, key: &K) -> bool {
337        self.0.get(key).is_some()
338    }
339
340    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
341        self.0.iter().map(|(k, _)| k)
342    }
343
344    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
345        self.0.iter_from(key).map(move |(k, _)| k)
346    }
347}
348
// Unit tests for `TreeMap`: basic insertion/lookup/removal, ranged iteration,
// bulk insertion, and range removal with a custom seek target.
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert, get, iter, closest, remove and retain on a small integer map.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        // Keys inserted out of order must still iterate in ascending order.
        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` returns the greatest key <= the query: nothing below the smallest
        // key, an exact match, and the last entry for a query past the end.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 removed, the closest entry at or below 2 is now 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only the entries with even keys.
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    /// Checks that `iter_from` starts at the right position for a key that is absent
    /// ("ba") and for a key that is present ("c").
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // "ba" itself is not a key; iteration must begin at "baa" and skip "b".
        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    /// Checks that `insert_tree` overwrites values for shared keys and adds new keys.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        // "a" takes the value from `other`; "c" is untouched; "d" is new.
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    /// Checks that `extend` behaves like `insert_tree` when given plain pairs.
    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    /// Checks `remove_range` with a custom end bound that covers a path and all of
    /// its descendants.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        // Seek target that orders after every key starting with the wrapped path,
        // so that seeking to it lands just past that path's descendants.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // "b/a" is not itself a key; only its descendants should be removed.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Removing at the tail of the map: "c" and its descendant go away.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Removing at the head of the map.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Removing the only remaining entry.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}