tree_map.rs

  1use std::{
  2    cmp::Ordering,
  3    fmt::Debug,
  4    path::{Path, PathBuf},
  5};
  6
  7use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  8
  9#[derive(Clone, Debug, PartialEq, Eq)]
 10pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
 11where
 12    K: Clone + Debug + Default + Ord,
 13    V: Clone + Debug;
 14
 15#[derive(Clone, Debug, PartialEq, Eq)]
 16pub struct MapEntry<K, V> {
 17    key: K,
 18    value: V,
 19}
 20
 21#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 22pub struct MapKey<K>(K);
 23
 24#[derive(Clone, Debug, Default)]
 25pub struct MapKeyRef<'a, K>(Option<&'a K>);
 26
 27#[derive(Clone)]
 28pub struct TreeSet<K>(TreeMap<K, ()>)
 29where
 30    K: Clone + Debug + Default + Ord;
 31
 32impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 33    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 34        let tree = SumTree::from_iter(
 35            entries
 36                .into_iter()
 37                .map(|(key, value)| MapEntry { key, value }),
 38            &(),
 39        );
 40        Self(tree)
 41    }
 42
 43    pub fn is_empty(&self) -> bool {
 44        self.0.is_empty()
 45    }
 46
 47    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 48        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 49        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 50        if let Some(item) = cursor.item() {
 51            if *key == item.key().0 {
 52                Some(&item.value)
 53            } else {
 54                None
 55            }
 56        } else {
 57            None
 58        }
 59    }
 60
 61    pub fn insert(&mut self, key: K, value: V) {
 62        self.0.insert_or_replace(MapEntry { key, value }, &());
 63    }
 64
 65    pub fn remove(&mut self, key: &K) -> Option<V> {
 66        let mut removed = None;
 67        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 68        let key = MapKeyRef(Some(key));
 69        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 70        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 71            removed = Some(cursor.item().unwrap().value.clone());
 72            cursor.next(&());
 73        }
 74        new_tree.push_tree(cursor.suffix(&()), &());
 75        drop(cursor);
 76        self.0 = new_tree;
 77        removed
 78    }
 79
 80    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 81        let start = MapSeekTargetAdaptor(start);
 82        let end = MapSeekTargetAdaptor(end);
 83        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 84        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 85        cursor.seek(&end, Bias::Left, &());
 86        new_tree.push_tree(cursor.suffix(&()), &());
 87        drop(cursor);
 88        self.0 = new_tree;
 89    }
 90
 91    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 92    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 93        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 94        let key = MapKeyRef(Some(key));
 95        cursor.seek(&key, Bias::Right, &());
 96        cursor.prev(&());
 97        cursor.item().map(|item| (&item.key, &item.value))
 98    }
 99
100    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
101        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
102        let from_key = MapKeyRef(Some(from));
103        cursor.seek(&from_key, Bias::Left, &());
104
105        cursor
106            .into_iter()
107            .map(|map_entry| (&map_entry.key, &map_entry.value))
108    }
109
110    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
111    where
112        F: FnOnce(&mut V) -> T,
113    {
114        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
115        let key = MapKeyRef(Some(key));
116        let mut new_tree = cursor.slice(&key, Bias::Left, &());
117        let mut result = None;
118        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
119            let mut updated = cursor.item().unwrap().clone();
120            result = Some(f(&mut updated.value));
121            new_tree.push(updated, &());
122            cursor.next(&());
123        }
124        new_tree.push_tree(cursor.suffix(&()), &());
125        drop(cursor);
126        self.0 = new_tree;
127        result
128    }
129
130    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
131        let mut new_map = SumTree::<MapEntry<K, V>>::default();
132
133        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
134        cursor.next(&());
135        while let Some(item) = cursor.item() {
136            if predicate(&item.key, &item.value) {
137                new_map.push(item.clone(), &());
138            }
139            cursor.next(&());
140        }
141        drop(cursor);
142
143        self.0 = new_map;
144    }
145
146    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
147        self.0.iter().map(|entry| (&entry.key, &entry.value))
148    }
149
150    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
151        self.0.iter().map(|entry| &entry.value)
152    }
153
154    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
155        let edits = other
156            .iter()
157            .map(|(key, value)| {
158                Edit::Insert(MapEntry {
159                    key: key.to_owned(),
160                    value: value.to_owned(),
161                })
162            })
163            .collect();
164
165        self.0.edit(edits, &());
166    }
167}
168
/// Wraps a borrowed [`MapSeekTarget`] so it can drive the underlying tree's seek operations
/// as a `SeekTarget` over [`MapKeyRef`] positions.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'t, T>(&'t T);
171
172impl<'a, K: Debug + Clone + Default + Ord, T: MapSeekTarget<K>>
173    SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapSeekTargetAdaptor<'_, T>
174{
175    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
176        MapSeekTarget::cmp(self.0, cursor_location)
177    }
178}
179
180pub trait MapSeekTarget<K>: Debug {
181    fn cmp(&self, cursor_location: &MapKeyRef<K>) -> Ordering;
182}
183
184impl<K: Debug + Ord> MapSeekTarget<K> for K {
185    fn cmp(&self, cursor_location: &MapKeyRef<K>) -> Ordering {
186        if let Some(key) = &cursor_location.0 {
187            self.cmp(key)
188        } else {
189            Ordering::Greater
190        }
191    }
192}
193
/// A [`MapSeekTarget`] for `PathBuf`-keyed maps that positions itself after a path and
/// every key nested beneath it.
///
/// Nesting is judged component-wise by [`Path::starts_with`], so the path itself counts as
/// well. Build one with [`PathDescendants::new`].
#[derive(Debug)]
pub struct PathDescendants<'a>(&'a Path);

impl<'a> PathDescendants<'a> {
    /// Creates a target covering `path` and all of its descendants.
    ///
    /// The tuple field is private, so without this constructor the type could not be built
    /// outside this module.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }
}
196
197impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
198    fn cmp(&self, cursor_location: &MapKeyRef<PathBuf>) -> Ordering {
199        if let Some(key) = &cursor_location.0 {
200            if key.starts_with(&self.0) {
201                Ordering::Greater
202            } else {
203                self.0.cmp(key)
204            }
205        } else {
206            Ordering::Greater
207        }
208    }
209}
210
211impl<K, V> Default for TreeMap<K, V>
212where
213    K: Clone + Debug + Default + Ord,
214    V: Clone + Debug,
215{
216    fn default() -> Self {
217        Self(Default::default())
218    }
219}
220
221impl<K, V> Item for MapEntry<K, V>
222where
223    K: Clone + Debug + Default + Ord,
224    V: Clone,
225{
226    type Summary = MapKey<K>;
227
228    fn summary(&self) -> Self::Summary {
229        self.key()
230    }
231}
232
233impl<K, V> KeyedItem for MapEntry<K, V>
234where
235    K: Clone + Debug + Default + Ord,
236    V: Clone,
237{
238    type Key = MapKey<K>;
239
240    fn key(&self) -> Self::Key {
241        MapKey(self.key.clone())
242    }
243}
244
245impl<K> Summary for MapKey<K>
246where
247    K: Clone + Debug + Default,
248{
249    type Context = ();
250
251    fn add_summary(&mut self, summary: &Self, _: &()) {
252        *self = summary.clone()
253    }
254}
255
256impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
257where
258    K: Clone + Debug + Default + Ord,
259{
260    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
261        self.0 = Some(&summary.0)
262    }
263}
264
265impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
266where
267    K: Clone + Debug + Default + Ord,
268{
269    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
270        Ord::cmp(&self.0, &cursor_location.0)
271    }
272}
273
274impl<K> Default for TreeSet<K>
275where
276    K: Clone + Debug + Default + Ord,
277{
278    fn default() -> Self {
279        Self(Default::default())
280    }
281}
282
283impl<K> TreeSet<K>
284where
285    K: Clone + Debug + Default + Ord,
286{
287    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
288        Self(TreeMap::from_ordered_entries(
289            entries.into_iter().map(|key| (key, ())),
290        ))
291    }
292
293    pub fn insert(&mut self, key: K) {
294        self.0.insert(key, ());
295    }
296
297    pub fn contains(&self, key: &K) -> bool {
298        self.0.get(key).is_some()
299    }
300
301    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
302        self.0.iter().map(|(k, _)| k)
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
        // Removing a key that is no longer present is a no-op.
        assert_eq!(map.remove(&1), None);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(3, "c");

        let result = map.update(&1, |value| {
            *value = "z";
            10
        });
        assert_eq!(result, Some(10));
        assert_eq!(map.get(&1), Some(&"z"));

        // A missing key leaves the map untouched and never calls the closure.
        assert_eq!(map.update(&2, |_| 20), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z"), (&3, &"c")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);

        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);

        let set = TreeSet::from_ordered_entries(vec![1, 2, 5]);
        assert!(set.contains(&5));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &5]);
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(&PathBuf::from("b/a"), &PathDescendants(&PathBuf::from("b/a")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}