tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Ord,
  9    V: Clone;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(Option<K>);
 19
 20impl<K> Default for MapKey<K> {
 21    fn default() -> Self {
 22        Self(None)
 23    }
 24}
 25
 26#[derive(Clone, Debug)]
 27pub struct MapKeyRef<'a, K>(Option<&'a K>);
 28
 29impl<K> Default for MapKeyRef<'_, K> {
 30    fn default() -> Self {
 31        Self(None)
 32    }
 33}
 34
 35#[derive(Clone, Debug, PartialEq, Eq)]
 36pub struct TreeSet<K>(TreeMap<K, ()>)
 37where
 38    K: Clone + Ord;
 39
 40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 41    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 42        let tree = SumTree::from_iter(
 43            entries
 44                .into_iter()
 45                .map(|(key, value)| MapEntry { key, value }),
 46            &(),
 47        );
 48        Self(tree)
 49    }
 50
 51    pub fn is_empty(&self) -> bool {
 52        self.0.is_empty()
 53    }
 54
 55    pub fn get(&self, key: &K) -> Option<&V> {
 56        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 57        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 58        if let Some(item) = cursor.item() {
 59            if Some(key) == item.key().0.as_ref() {
 60                Some(&item.value)
 61            } else {
 62                None
 63            }
 64        } else {
 65            None
 66        }
 67    }
 68
 69    pub fn insert(&mut self, key: K, value: V) {
 70        self.0.insert_or_replace(MapEntry { key, value }, &());
 71    }
 72
 73    pub fn clear(&mut self) {
 74        self.0 = SumTree::default();
 75    }
 76
 77    pub fn remove(&mut self, key: &K) -> Option<V> {
 78        let mut removed = None;
 79        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 80        let key = MapKeyRef(Some(key));
 81        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 82        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 83            removed = Some(cursor.item().unwrap().value.clone());
 84            cursor.next(&());
 85        }
 86        new_tree.append(cursor.suffix(&()), &());
 87        drop(cursor);
 88        self.0 = new_tree;
 89        removed
 90    }
 91
 92    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
 93        let start = MapSeekTargetAdaptor(start);
 94        let end = MapSeekTargetAdaptor(end);
 95        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 96        let mut new_tree = cursor.slice(&start, Bias::Left, &());
 97        cursor.seek(&end, Bias::Left, &());
 98        new_tree.append(cursor.suffix(&()), &());
 99        drop(cursor);
100        self.0 = new_tree;
101    }
102
103    /// Returns the key-value pair with the greatest key less than or equal to the given key.
104    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
105        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
106        let key = MapKeyRef(Some(key));
107        cursor.seek(&key, Bias::Right, &());
108        cursor.prev(&());
109        cursor.item().map(|item| (&item.key, &item.value))
110    }
111
112    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
113        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
114        let from_key = MapKeyRef(Some(from));
115        cursor.seek(&from_key, Bias::Left, &());
116
117        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
118    }
119
120    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
121    where
122        F: FnOnce(&mut V) -> T,
123    {
124        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
125        let key = MapKeyRef(Some(key));
126        let mut new_tree = cursor.slice(&key, Bias::Left, &());
127        let mut result = None;
128        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
129            let mut updated = cursor.item().unwrap().clone();
130            result = Some(f(&mut updated.value));
131            new_tree.push(updated, &());
132            cursor.next(&());
133        }
134        new_tree.append(cursor.suffix(&()), &());
135        drop(cursor);
136        self.0 = new_tree;
137        result
138    }
139
140    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
141        let mut new_map = SumTree::<MapEntry<K, V>>::default();
142
143        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
144        cursor.next(&());
145        while let Some(item) = cursor.item() {
146            if predicate(&item.key, &item.value) {
147                new_map.push(item.clone(), &());
148            }
149            cursor.next(&());
150        }
151        drop(cursor);
152
153        self.0 = new_map;
154    }
155
156    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
157        self.0.iter().map(|entry| (&entry.key, &entry.value))
158    }
159
160    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
161        self.0.iter().map(|entry| &entry.value)
162    }
163
164    pub fn first(&self) -> Option<(&K, &V)> {
165        self.0.first().map(|entry| (&entry.key, &entry.value))
166    }
167
168    pub fn last(&self) -> Option<(&K, &V)> {
169        self.0.last().map(|entry| (&entry.key, &entry.value))
170    }
171
172    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
173        let edits = other
174            .iter()
175            .map(|(key, value)| {
176                Edit::Insert(MapEntry {
177                    key: key.to_owned(),
178                    value: value.to_owned(),
179                })
180            })
181            .collect();
182
183        self.0.edit(edits, &());
184    }
185}
186
187impl<K, V> Debug for TreeMap<K, V>
188where
189    K: Clone + Debug + Ord,
190    V: Clone + Debug,
191{
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.debug_map().entries(self.iter()).finish()
194    }
195}
196
197#[derive(Debug)]
198struct MapSeekTargetAdaptor<'a, T>(&'a T);
199
200impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
201    for MapSeekTargetAdaptor<'_, T>
202{
203    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
204        if let Some(key) = &cursor_location.0 {
205            MapSeekTarget::cmp_cursor(self.0, key)
206        } else {
207            Ordering::Greater
208        }
209    }
210}
211
/// A value that can be compared against the keys of a [`TreeMap`] to position a
/// cursor, e.g. for [`TreeMap::remove_range`].
///
/// `cmp_cursor` reports where `self` falls relative to `cursor_location`: `Less` if
/// `self` sorts before it, `Equal` if they coincide, and `Greater` if after.
pub trait MapSeekTarget<K> {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for maps keyed by its own type, using its natural order.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
221
222impl<K, V> Default for TreeMap<K, V>
223where
224    K: Clone + Ord,
225    V: Clone,
226{
227    fn default() -> Self {
228        Self(Default::default())
229    }
230}
231
232impl<K, V> Item for MapEntry<K, V>
233where
234    K: Clone + Ord,
235    V: Clone,
236{
237    type Summary = MapKey<K>;
238
239    fn summary(&self, _cx: &()) -> Self::Summary {
240        self.key()
241    }
242}
243
244impl<K, V> KeyedItem for MapEntry<K, V>
245where
246    K: Clone + Ord,
247    V: Clone,
248{
249    type Key = MapKey<K>;
250
251    fn key(&self) -> Self::Key {
252        MapKey(Some(self.key.clone()))
253    }
254}
255
256impl<K> Summary for MapKey<K>
257where
258    K: Clone,
259{
260    type Context = ();
261
262    fn zero(_cx: &()) -> Self {
263        Default::default()
264    }
265
266    fn add_summary(&mut self, summary: &Self, _: &()) {
267        *self = summary.clone()
268    }
269}
270
271impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
272where
273    K: Clone + Ord,
274{
275    fn zero(_cx: &()) -> Self {
276        Default::default()
277    }
278
279    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
280        self.0 = summary.0.as_ref();
281    }
282}
283
284impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
285where
286    K: Clone + Ord,
287{
288    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
289        Ord::cmp(&self.0, &cursor_location.0)
290    }
291}
292
293impl<K> Default for TreeSet<K>
294where
295    K: Clone + Ord,
296{
297    fn default() -> Self {
298        Self(Default::default())
299    }
300}
301
302impl<K> TreeSet<K>
303where
304    K: Clone + Ord,
305{
306    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
307        Self(TreeMap::from_ordered_entries(
308            entries.into_iter().map(|key| (key, ())),
309        ))
310    }
311
312    pub fn insert(&mut self, key: K) {
313        self.0.insert(key, ());
314    }
315
316    pub fn contains(&self, key: &K) -> bool {
317        self.0.get(key).is_some()
318    }
319
320    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
321        self.0.iter().map(|(k, _)| k)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert!(!map.is_empty());
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.get(&4), None);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);
        assert_eq!(map.first(), Some((&1, &"a")));
        assert_eq!(map.last(), Some((&3, &"c")));

        // Inserting an existing key replaces its value without adding an entry.
        map.insert(2, "B");
        assert_eq!(map.get(&2), Some(&"B"));
        assert_eq!(map.iter().count(), 3);
        map.insert(2, "b");

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
        assert_eq!(format!("{:?}", map), r#"{1: "a", 3: "c"}"#);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);

        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.get(&4), None);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // Iteration is ordered, so the exact sequence is pinned, not just membership.
        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);

        assert_eq!(map.iter_from(&"d").next(), None);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(2, "b");

        let result = map.update(&1, |value| {
            *value = "z";
            10
        });
        assert_eq!(result, Some(10));
        assert_eq!(map.get(&1), Some(&"z"));

        // A missing key yields `None` and leaves the map untouched.
        assert_eq!(map.update(&3, |_| 0), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z"), (&2, &"b")]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(3);
        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);

        let ordered = TreeSet::from_ordered_entries(vec![1, 2, 4]);
        assert_eq!(ordered.iter().copied().collect::<Vec<_>>(), vec![1, 2, 4]);
        assert!(ordered.contains(&4));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}