tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, Debug, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.append(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 77        let start = MapSeekTargetAdaptor(start);
 78        let end = MapSeekTargetAdaptor(end);
 79        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 80        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 81        cursor.seek(&end, Bias::Left, &());
 82        new_tree.append(cursor.suffix(&()), &());
 83        drop(cursor);
 84        self.0 = new_tree;
 85    }
 86
 87    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 88    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 89        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 90        let key = MapKeyRef(Some(key));
 91        cursor.seek(&key, Bias::Right, &());
 92        cursor.prev(&());
 93        cursor.item().map(|item| (&item.key, &item.value))
 94    }
 95
 96    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
 97        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 98        let from_key = MapKeyRef(Some(from));
 99        cursor.seek(&from_key, Bias::Left, &());
100
101        cursor
102            .into_iter()
103            .map(|map_entry| (&map_entry.key, &map_entry.value))
104    }
105
106    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
107    where
108        F: FnOnce(&mut V) -> T,
109    {
110        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
111        let key = MapKeyRef(Some(key));
112        let mut new_tree = cursor.slice(&key, Bias::Left, &());
113        let mut result = None;
114        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
115            let mut updated = cursor.item().unwrap().clone();
116            result = Some(f(&mut updated.value));
117            new_tree.push(updated, &());
118            cursor.next(&());
119        }
120        new_tree.append(cursor.suffix(&()), &());
121        drop(cursor);
122        self.0 = new_tree;
123        result
124    }
125
126    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
127        let mut new_map = SumTree::<MapEntry<K, V>>::default();
128
129        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
130        cursor.next(&());
131        while let Some(item) = cursor.item() {
132            if predicate(&item.key, &item.value) {
133                new_map.push(item.clone(), &());
134            }
135            cursor.next(&());
136        }
137        drop(cursor);
138
139        self.0 = new_map;
140    }
141
142    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
143        self.0.iter().map(|entry| (&entry.key, &entry.value))
144    }
145
146    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
147        self.0.iter().map(|entry| &entry.value)
148    }
149
150    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
151        let edits = other
152            .iter()
153            .map(|(key, value)| {
154                Edit::Insert(MapEntry {
155                    key: key.to_owned(),
156                    value: value.to_owned(),
157                })
158            })
159            .collect();
160
161        self.0.edit(edits, &());
162    }
163}
164
165#[derive(Debug)]
166struct MapSeekTargetAdaptor<'a, T>(&'a T);
167
168impl<'a, K: Debug + Clone + Default + Ord, T: MapSeekTarget<K>>
169    SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapSeekTargetAdaptor<'_, T>
170{
171    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
172        if let Some(key) = &cursor_location.0 {
173            MapSeekTarget::cmp_cursor(self.0, key)
174        } else {
175            Ordering::Greater
176        }
177    }
178}
179
/// A target that can be located in a [`TreeMap`] by comparing it against stored keys.
///
/// A target need not be an exact key. It may, for example, order itself after every
/// key that shares some prefix, so that [`TreeMap::remove_range`] can drop a whole
/// family of keys at once.
pub trait MapSeekTarget<K>: Debug {
    /// Orders `self` relative to the key at the cursor's current location.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Every ordered key is a seek target for keys of its own type, using its `Ord` ordering.
impl<K: Debug + Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
189
190impl<K, V> Default for TreeMap<K, V>
191where
192    K: Clone + Debug + Default + Ord,
193    V: Clone + Debug,
194{
195    fn default() -> Self {
196        Self(Default::default())
197    }
198}
199
200impl<K, V> Item for MapEntry<K, V>
201where
202    K: Clone + Debug + Default + Ord,
203    V: Clone,
204{
205    type Summary = MapKey<K>;
206
207    fn summary(&self) -> Self::Summary {
208        self.key()
209    }
210}
211
212impl<K, V> KeyedItem for MapEntry<K, V>
213where
214    K: Clone + Debug + Default + Ord,
215    V: Clone,
216{
217    type Key = MapKey<K>;
218
219    fn key(&self) -> Self::Key {
220        MapKey(self.key.clone())
221    }
222}
223
224impl<K> Summary for MapKey<K>
225where
226    K: Clone + Debug + Default,
227{
228    type Context = ();
229
230    fn add_summary(&mut self, summary: &Self, _: &()) {
231        *self = summary.clone()
232    }
233}
234
235impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
236where
237    K: Clone + Debug + Default + Ord,
238{
239    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
240        self.0 = Some(&summary.0)
241    }
242}
243
244impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
245where
246    K: Clone + Debug + Default + Ord,
247{
248    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
249        Ord::cmp(&self.0, &cursor_location.0)
250    }
251}
252
253impl<K> Default for TreeSet<K>
254where
255    K: Clone + Debug + Default + Ord,
256{
257    fn default() -> Self {
258        Self(Default::default())
259    }
260}
261
262impl<K> TreeSet<K>
263where
264    K: Clone + Debug + Default + Ord,
265{
266    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
267        Self(TreeMap::from_ordered_entries(
268            entries.into_iter().map(|key| (key, ())),
269        ))
270    }
271
272    pub fn insert(&mut self, key: K) {
273        self.0.insert(key, ());
274    }
275
276    pub fn contains(&self, key: &K) -> bool {
277        self.0.get(key).is_some()
278    }
279
280    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
281        self.0.iter().map(|(k, _)| k)
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert!(!map.is_empty());
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);

        // Inserting an existing key replaces its value without adding an entry.
        map.insert(2, "B");
        assert_eq!(map.get(&2), Some(&"B"));
        assert_eq!(map.iter().count(), 3);
        map.insert(2, "b");

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        let result = map.update(&2, |value| {
            *value += 1;
            *value
        });
        assert_eq!(result, Some(21));
        assert_eq!(map.get(&2), Some(&21));

        // Updating a missing key neither calls the closure's effect nor inserts anything.
        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &10), (&2, &21)]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(2);
        set.insert(1);

        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(&self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}