// tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(Option<K>);
 19
 20impl<K> Default for MapKey<K> {
 21    fn default() -> Self {
 22        Self(None)
 23    }
 24}
 25
 26#[derive(Clone, Debug)]
 27pub struct MapKeyRef<'a, K>(Option<&'a K>);
 28
 29impl<'a, K> Default for MapKeyRef<'a, K> {
 30    fn default() -> Self {
 31        Self(None)
 32    }
 33}
 34
 35#[derive(Clone)]
 36pub struct TreeSet<K>(TreeMap<K, ()>)
 37where
 38    K: Clone + Debug + Ord;
 39
 40impl<K: Clone + Debug + Ord, V: Clone + Debug> TreeMap<K, V> {
 41    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 42        let tree = SumTree::from_iter(
 43            entries
 44                .into_iter()
 45                .map(|(key, value)| MapEntry { key, value }),
 46            &(),
 47        );
 48        Self(tree)
 49    }
 50
 51    pub fn is_empty(&self) -> bool {
 52        self.0.is_empty()
 53    }
 54
 55    pub fn get(&self, key: &K) -> Option<&V> {
 56        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 57        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 58        if let Some(item) = cursor.item() {
 59            if Some(key) == item.key().0.as_ref() {
 60                Some(&item.value)
 61            } else {
 62                None
 63            }
 64        } else {
 65            None
 66        }
 67    }
 68
 69    pub fn insert(&mut self, key: K, value: V) {
 70        self.0.insert_or_replace(MapEntry { key, value }, &());
 71    }
 72
 73    pub fn remove(&mut self, key: &K) -> Option<V> {
 74        let mut removed = None;
 75        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 76        let key = MapKeyRef(Some(key));
 77        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 78        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 79            removed = Some(cursor.item().unwrap().value.clone());
 80            cursor.next(&());
 81        }
 82        new_tree.append(cursor.suffix(&()), &());
 83        drop(cursor);
 84        self.0 = new_tree;
 85        removed
 86    }
 87
 88    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 89        let start = MapSeekTargetAdaptor(start);
 90        let end = MapSeekTargetAdaptor(end);
 91        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 92        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 93        cursor.seek(&end, Bias::Left, &());
 94        new_tree.append(cursor.suffix(&()), &());
 95        drop(cursor);
 96        self.0 = new_tree;
 97    }
 98
 99    /// Returns the key-value pair with the greatest key less than or equal to the given key.
100    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
101        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
102        let key = MapKeyRef(Some(key));
103        cursor.seek(&key, Bias::Right, &());
104        cursor.prev(&());
105        cursor.item().map(|item| (&item.key, &item.value))
106    }
107
108    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
109        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
110        let from_key = MapKeyRef(Some(from));
111        cursor.seek(&from_key, Bias::Left, &());
112
113        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
114    }
115
116    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
117    where
118        F: FnOnce(&mut V) -> T,
119    {
120        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
121        let key = MapKeyRef(Some(key));
122        let mut new_tree = cursor.slice(&key, Bias::Left, &());
123        let mut result = None;
124        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
125            let mut updated = cursor.item().unwrap().clone();
126            result = Some(f(&mut updated.value));
127            new_tree.push(updated, &());
128            cursor.next(&());
129        }
130        new_tree.append(cursor.suffix(&()), &());
131        drop(cursor);
132        self.0 = new_tree;
133        result
134    }
135
136    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
137        let mut new_map = SumTree::<MapEntry<K, V>>::default();
138
139        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
140        cursor.next(&());
141        while let Some(item) = cursor.item() {
142            if predicate(&item.key, &item.value) {
143                new_map.push(item.clone(), &());
144            }
145            cursor.next(&());
146        }
147        drop(cursor);
148
149        self.0 = new_map;
150    }
151
152    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
153        self.0.iter().map(|entry| (&entry.key, &entry.value))
154    }
155
156    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
157        self.0.iter().map(|entry| &entry.value)
158    }
159
160    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
161        let edits = other
162            .iter()
163            .map(|(key, value)| {
164                Edit::Insert(MapEntry {
165                    key: key.to_owned(),
166                    value: value.to_owned(),
167                })
168            })
169            .collect();
170
171        self.0.edit(edits, &());
172    }
173}
174
175impl<K: Debug, V: Debug> Debug for TreeMap<K, V>
176where
177    K: Clone + Debug + Ord,
178    V: Clone + Debug,
179{
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_map().entries(self.iter()).finish()
182    }
183}
184
/// Wraps a borrowed [`MapSeekTarget`] so it can be used as a `SeekTarget` over a
/// [`TreeMap`]'s cursor.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'t, T>(&'t T);
187
188impl<'a, K: Debug + Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
189    for MapSeekTargetAdaptor<'_, T>
190{
191    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
192        if let Some(key) = &cursor_location.0 {
193            MapSeekTarget::cmp_cursor(self.0, key)
194        } else {
195            Ordering::Greater
196        }
197    }
198}
199
/// Something a [`TreeMap`] can seek toward: it reports how it orders relative to a key
/// the cursor is currently at.
pub trait MapSeekTarget<K>: Debug {
    /// Returns how `self` orders relative to `cursor_location`. `Greater` means the
    /// target lies further along than that key.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is itself a seek target, using its natural ordering.
impl<K: Debug + Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
209
210impl<K, V> Default for TreeMap<K, V>
211where
212    K: Clone + Debug + Ord,
213    V: Clone + Debug,
214{
215    fn default() -> Self {
216        Self(Default::default())
217    }
218}
219
220impl<K, V> Item for MapEntry<K, V>
221where
222    K: Clone + Debug + Ord,
223    V: Clone,
224{
225    type Summary = MapKey<K>;
226
227    fn summary(&self) -> Self::Summary {
228        self.key()
229    }
230}
231
232impl<K, V> KeyedItem for MapEntry<K, V>
233where
234    K: Clone + Debug + Ord,
235    V: Clone,
236{
237    type Key = MapKey<K>;
238
239    fn key(&self) -> Self::Key {
240        MapKey(Some(self.key.clone()))
241    }
242}
243
244impl<K> Summary for MapKey<K>
245where
246    K: Clone + Debug,
247{
248    type Context = ();
249
250    fn add_summary(&mut self, summary: &Self, _: &()) {
251        *self = summary.clone()
252    }
253}
254
255impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
256where
257    K: Clone + Debug + Ord,
258{
259    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
260        self.0 = summary.0.as_ref();
261    }
262}
263
264impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
265where
266    K: Clone + Debug + Ord,
267{
268    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
269        Ord::cmp(&self.0, &cursor_location.0)
270    }
271}
272
273impl<K> Default for TreeSet<K>
274where
275    K: Clone + Debug + Ord,
276{
277    fn default() -> Self {
278        Self(Default::default())
279    }
280}
281
282impl<K> TreeSet<K>
283where
284    K: Clone + Debug + Ord,
285{
286    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
287        Self(TreeMap::from_ordered_entries(
288            entries.into_iter().map(|key| (key, ())),
289        ))
290    }
291
292    pub fn insert(&mut self, key: K) {
293        self.0.insert(key, ());
294    }
295
296    pub fn contains(&self, key: &K) -> bool {
297        self.0.get(key).is_some()
298    }
299
300    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
301        self.0.iter().map(|(k, _)| k)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_update_and_remove_return_values() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        // `update` on a present key returns the closure's result and persists the change.
        assert_eq!(
            map.update(&1, |value| {
                *value += 1;
                *value
            }),
            Some(11)
        );
        assert_eq!(map.get(&1), Some(&11));

        // `update` on an absent key never invokes the closure and inserts nothing.
        let mut called = false;
        assert_eq!(map.update(&3, |_| called = true), None);
        assert!(!called);
        assert_eq!(map.get(&3), None);

        // `remove` hands back the removed value exactly once.
        assert_eq!(map.remove(&2), Some(20));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.remove(&5), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &11)]);
    }

    #[test]
    fn test_retain_removing_everything() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, "a"), (2, "b")]);
        map.retain(|_, _| false);
        assert!(map.is_empty());
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::from_ordered_entries(vec![1, 3]);
        assert!(set.contains(&1));
        assert!(!set.contains(&2));

        set.insert(2);
        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(&self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}