tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get(&self, key: &K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.append(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 77        let start = MapSeekTargetAdaptor(start);
 78        let end = MapSeekTargetAdaptor(end);
 79        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 80        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 81        cursor.seek(&end, Bias::Left, &());
 82        new_tree.append(cursor.suffix(&()), &());
 83        drop(cursor);
 84        self.0 = new_tree;
 85    }
 86
 87    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 88    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 89        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 90        let key = MapKeyRef(Some(key));
 91        cursor.seek(&key, Bias::Right, &());
 92        cursor.prev(&());
 93        cursor.item().map(|item| (&item.key, &item.value))
 94    }
 95
 96    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
 97        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 98        let from_key = MapKeyRef(Some(from));
 99        cursor.seek(&from_key, Bias::Left, &());
100
101        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
102    }
103
104    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
105    where
106        F: FnOnce(&mut V) -> T,
107    {
108        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
109        let key = MapKeyRef(Some(key));
110        let mut new_tree = cursor.slice(&key, Bias::Left, &());
111        let mut result = None;
112        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
113            let mut updated = cursor.item().unwrap().clone();
114            result = Some(f(&mut updated.value));
115            new_tree.push(updated, &());
116            cursor.next(&());
117        }
118        new_tree.append(cursor.suffix(&()), &());
119        drop(cursor);
120        self.0 = new_tree;
121        result
122    }
123
124    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
125        let mut new_map = SumTree::<MapEntry<K, V>>::default();
126
127        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
128        cursor.next(&());
129        while let Some(item) = cursor.item() {
130            if predicate(&item.key, &item.value) {
131                new_map.push(item.clone(), &());
132            }
133            cursor.next(&());
134        }
135        drop(cursor);
136
137        self.0 = new_map;
138    }
139
140    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
141        self.0.iter().map(|entry| (&entry.key, &entry.value))
142    }
143
144    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
145        self.0.iter().map(|entry| &entry.value)
146    }
147
148    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
149        let edits = other
150            .iter()
151            .map(|(key, value)| {
152                Edit::Insert(MapEntry {
153                    key: key.to_owned(),
154                    value: value.to_owned(),
155                })
156            })
157            .collect();
158
159        self.0.edit(edits, &());
160    }
161}
162
163impl<K: Debug, V: Debug> Debug for TreeMap<K, V>
164where
165    K: Clone + Debug + Default + Ord,
166    V: Clone + Debug,
167{
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_map().entries(self.iter()).finish()
170    }
171}
172
173#[derive(Debug)]
174struct MapSeekTargetAdaptor<'a, T>(&'a T);
175
176impl<'a, K: Debug + Clone + Default + Ord, T: MapSeekTarget<K>>
177    SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapSeekTargetAdaptor<'_, T>
178{
179    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
180        if let Some(key) = &cursor_location.0 {
181            MapSeekTarget::cmp_cursor(self.0, key)
182        } else {
183            Ordering::Greater
184        }
185    }
186}
187
/// A value that can locate a position among a [`TreeMap`]'s keys.
///
/// `cmp_cursor` reports how `self` orders relative to the key at `cursor_location`.
/// `Greater` means the target lies after that key.
pub trait MapSeekTarget<K>: Debug {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is its own seek target, compared with [`Ord::cmp`].
impl<K: Debug + Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
197
198impl<K, V> Default for TreeMap<K, V>
199where
200    K: Clone + Debug + Default + Ord,
201    V: Clone + Debug,
202{
203    fn default() -> Self {
204        Self(Default::default())
205    }
206}
207
208impl<K, V> Item for MapEntry<K, V>
209where
210    K: Clone + Debug + Default + Ord,
211    V: Clone,
212{
213    type Summary = MapKey<K>;
214
215    fn summary(&self) -> Self::Summary {
216        self.key()
217    }
218}
219
220impl<K, V> KeyedItem for MapEntry<K, V>
221where
222    K: Clone + Debug + Default + Ord,
223    V: Clone,
224{
225    type Key = MapKey<K>;
226
227    fn key(&self) -> Self::Key {
228        MapKey(self.key.clone())
229    }
230}
231
232impl<K> Summary for MapKey<K>
233where
234    K: Clone + Debug + Default,
235{
236    type Context = ();
237
238    fn add_summary(&mut self, summary: &Self, _: &()) {
239        *self = summary.clone()
240    }
241}
242
243impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
244where
245    K: Clone + Debug + Default + Ord,
246{
247    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
248        self.0 = Some(&summary.0)
249    }
250}
251
252impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
253where
254    K: Clone + Debug + Default + Ord,
255{
256    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
257        Ord::cmp(&self.0, &cursor_location.0)
258    }
259}
260
261impl<K> Default for TreeSet<K>
262where
263    K: Clone + Debug + Default + Ord,
264{
265    fn default() -> Self {
266        Self(Default::default())
267    }
268}
269
270impl<K> TreeSet<K>
271where
272    K: Clone + Debug + Default + Ord,
273{
274    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
275        Self(TreeMap::from_ordered_entries(
276            entries.into_iter().map(|key| (key, ())),
277        ))
278    }
279
280    pub fn insert(&mut self, key: K) {
281        self.0.insert(key, ());
282    }
283
284    pub fn contains(&self, key: &K) -> bool {
285        self.0.get(key).is_some()
286    }
287
288    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
289        self.0.iter().map(|(k, _)| k)
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(&self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }

    #[test]
    fn test_remove_returns_value() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, "a"), (2, "b"), (3, "c")]);

        assert_eq!(map.remove(&2), Some("b"));
        // Removing again, or removing a key that was never present, yields `None`.
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.remove(&10), None);
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"c"]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.remove(&3), Some("c"));
        assert!(map.is_empty());
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, 10), (3, 30)]);

        let doubled = map.update(&3, |value| {
            *value *= 2;
            *value
        });
        assert_eq!(doubled, Some(60));
        assert_eq!(map.get(&3), Some(&60));

        // A missing key leaves the map untouched and yields `None`.
        assert_eq!(map.update(&2, |value| *value += 1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &10), (&3, &60)]);
    }

    #[test]
    fn test_remove_range_with_plain_keys() {
        let mut map: TreeMap<i32, &str> = TreeMap::from_ordered_entries(vec![
            (1, "a"),
            (2, "b"),
            (3, "c"),
            (4, "d"),
            (5, "e"),
        ]);

        // `start` is inclusive and `end` is exclusive.
        map.remove_range(&2_i32, &4_i32);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&4, &"d"), (&5, &"e")]
        );
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::from_ordered_entries(vec![1, 3]);
        set.insert(2);
        // Inserting an existing key does not create a duplicate.
        set.insert(3);

        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
    }
}