tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Ord,
  9    V: Clone;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(Option<K>);
 19
 20impl<K> Default for MapKey<K> {
 21    fn default() -> Self {
 22        Self(None)
 23    }
 24}
 25
 26#[derive(Clone, Debug)]
 27pub struct MapKeyRef<'a, K>(Option<&'a K>);
 28
 29impl<K> Default for MapKeyRef<'_, K> {
 30    fn default() -> Self {
 31        Self(None)
 32    }
 33}
 34
 35#[derive(Clone, Debug, PartialEq, Eq)]
 36pub struct TreeSet<K>(TreeMap<K, ()>)
 37where
 38    K: Clone + Ord;
 39
 40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 41    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 42        let tree = SumTree::from_iter(
 43            entries
 44                .into_iter()
 45                .map(|(key, value)| MapEntry { key, value }),
 46            &(),
 47        );
 48        Self(tree)
 49    }
 50
 51    pub fn is_empty(&self) -> bool {
 52        self.0.is_empty()
 53    }
 54
 55    pub fn get(&self, key: &K) -> Option<&V> {
 56        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 57        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 58        if let Some(item) = cursor.item() {
 59            if Some(key) == item.key().0.as_ref() {
 60                Some(&item.value)
 61            } else {
 62                None
 63            }
 64        } else {
 65            None
 66        }
 67    }
 68
 69    pub fn insert(&mut self, key: K, value: V) {
 70        self.0.insert_or_replace(MapEntry { key, value }, &());
 71    }
 72
 73    pub fn remove(&mut self, key: &K) -> Option<V> {
 74        let mut removed = None;
 75        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 76        let key = MapKeyRef(Some(key));
 77        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 78        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 79            removed = Some(cursor.item().unwrap().value.clone());
 80            cursor.next(&());
 81        }
 82        new_tree.append(cursor.suffix(&()), &());
 83        drop(cursor);
 84        self.0 = new_tree;
 85        removed
 86    }
 87
 88    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 89        let start = MapSeekTargetAdaptor(start);
 90        let end = MapSeekTargetAdaptor(end);
 91        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 92        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 93        cursor.seek(&end, Bias::Left, &());
 94        new_tree.append(cursor.suffix(&()), &());
 95        drop(cursor);
 96        self.0 = new_tree;
 97    }
 98
 99    /// Returns the key-value pair with the greatest key less than or equal to the given key.
100    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
101        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
102        let key = MapKeyRef(Some(key));
103        cursor.seek(&key, Bias::Right, &());
104        cursor.prev(&());
105        cursor.item().map(|item| (&item.key, &item.value))
106    }
107
108    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
109        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
110        let from_key = MapKeyRef(Some(from));
111        cursor.seek(&from_key, Bias::Left, &());
112
113        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
114    }
115
116    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
117    where
118        F: FnOnce(&mut V) -> T,
119    {
120        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
121        let key = MapKeyRef(Some(key));
122        let mut new_tree = cursor.slice(&key, Bias::Left, &());
123        let mut result = None;
124        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
125            let mut updated = cursor.item().unwrap().clone();
126            result = Some(f(&mut updated.value));
127            new_tree.push(updated, &());
128            cursor.next(&());
129        }
130        new_tree.append(cursor.suffix(&()), &());
131        drop(cursor);
132        self.0 = new_tree;
133        result
134    }
135
136    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
137        let mut new_map = SumTree::<MapEntry<K, V>>::default();
138
139        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
140        cursor.next(&());
141        while let Some(item) = cursor.item() {
142            if predicate(&item.key, &item.value) {
143                new_map.push(item.clone(), &());
144            }
145            cursor.next(&());
146        }
147        drop(cursor);
148
149        self.0 = new_map;
150    }
151
152    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
153        self.0.iter().map(|entry| (&entry.key, &entry.value))
154    }
155
156    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
157        self.0.iter().map(|entry| &entry.value)
158    }
159
160    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
161        let edits = other
162            .iter()
163            .map(|(key, value)| {
164                Edit::Insert(MapEntry {
165                    key: key.to_owned(),
166                    value: value.to_owned(),
167                })
168            })
169            .collect();
170
171        self.0.edit(edits, &());
172    }
173}
174
175impl<K, V> Debug for TreeMap<K, V>
176where
177    K: Clone + Debug + Ord,
178    V: Clone + Debug,
179{
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.debug_map().entries(self.iter()).finish()
182    }
183}
184
/// Wraps a borrowed `MapSeekTarget` so it can drive a tree seek over `MapKeyRef` positions.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'a, T>(&'a T);
187
188impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
189    for MapSeekTargetAdaptor<'_, T>
190{
191    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
192        if let Some(key) = &cursor_location.0 {
193            MapSeekTarget::cmp_cursor(self.0, key)
194        } else {
195            Ordering::Greater
196        }
197    }
198}
199
/// Something that can be positioned among a map's keys by comparing it against them.
///
/// Lets range bounds (e.g. for `TreeMap::remove_range`) be expressed without being keys.
pub trait MapSeekTarget<K> {
    /// Returns how this target orders relative to the stored key `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}
203
204impl<K: Ord> MapSeekTarget<K> for K {
205    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
206        self.cmp(cursor_location)
207    }
208}
209
210impl<K, V> Default for TreeMap<K, V>
211where
212    K: Clone + Ord,
213    V: Clone,
214{
215    fn default() -> Self {
216        Self(Default::default())
217    }
218}
219
220impl<K, V> Item for MapEntry<K, V>
221where
222    K: Clone + Ord,
223    V: Clone,
224{
225    type Summary = MapKey<K>;
226
227    fn summary(&self, _cx: &()) -> Self::Summary {
228        self.key()
229    }
230}
231
232impl<K, V> KeyedItem for MapEntry<K, V>
233where
234    K: Clone + Ord,
235    V: Clone,
236{
237    type Key = MapKey<K>;
238
239    fn key(&self) -> Self::Key {
240        MapKey(Some(self.key.clone()))
241    }
242}
243
244impl<K> Summary for MapKey<K>
245where
246    K: Clone,
247{
248    type Context = ();
249
250    fn zero(_cx: &()) -> Self {
251        Default::default()
252    }
253
254    fn add_summary(&mut self, summary: &Self, _: &()) {
255        *self = summary.clone()
256    }
257}
258
259impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
260where
261    K: Clone + Ord,
262{
263    fn zero(_cx: &()) -> Self {
264        Default::default()
265    }
266
267    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
268        self.0 = summary.0.as_ref();
269    }
270}
271
272impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
273where
274    K: Clone + Ord,
275{
276    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
277        Ord::cmp(&self.0, &cursor_location.0)
278    }
279}
280
281impl<K> Default for TreeSet<K>
282where
283    K: Clone + Ord,
284{
285    fn default() -> Self {
286        Self(Default::default())
287    }
288}
289
290impl<K> TreeSet<K>
291where
292    K: Clone + Ord,
293{
294    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
295        Self(TreeMap::from_ordered_entries(
296            entries.into_iter().map(|key| (key, ())),
297        ))
298    }
299
300    pub fn insert(&mut self, key: K) {
301        self.0.insert(key, ());
302    }
303
304    pub fn contains(&self, key: &K) -> bool {
305        self.0.get(key).is_some()
306    }
307
308    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
309        self.0.iter().map(|(k, _)| k)
310    }
311}
312
// Unit tests for `TreeMap`.
#[cfg(test)]
mod tests {
    use super::*;

    // Walks a small integer map through inserts (out of key order), lookups, `closest`,
    // removals down to an empty map, and finally `retain`.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        // A smaller key inserted later must still iterate first.
        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // Key 2 is inserted last but must iterate between 1 and 3.
        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the greatest key <= the query: nothing below 1, an exact hit at 1,
        // and the largest key for a query past the end.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 removed, a query for 2 falls back to the next smaller key.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only the even keys.
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    // `iter_from` starts at the first key >= its argument, so combined with `take_while`
    // it performs a prefix scan ("b" sorts before "ba" and is therefore skipped).
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        // Starting exactly at the last key yields just that key.
        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    // Entries from `other` overwrite matching keys ("a": 1 -> 2) and add new ones ("d");
    // keys present only in `map` ("c") are kept.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // Uses `remove_range` with a custom `MapSeekTarget` to delete whole path subtrees.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        // Seek target that orders after `self.0` and every path beneath it: keys that start
        // with `self.0` compare as `Greater` (target is after them); all other keys use
        // normal path ordering. Serves as the exclusive end bound of `remove_range`.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // Remove the `b/a` subtree (which has no entry of its own); its ancestor `b` and
        // unrelated paths must survive.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Remove the last subtree in the map, including its root entry `c`.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Remove the first subtree in the map.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Remove the only remaining subtree.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}