tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 24    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 25        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 26        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 27        if let Some(item) = cursor.item() {
 28            if *key == item.key().0 {
 29                Some(&item.value)
 30            } else {
 31                None
 32            }
 33        } else {
 34            None
 35        }
 36    }
 37
 38    pub fn insert(&mut self, key: K, value: V) {
 39        self.0.insert_or_replace(MapEntry { key, value }, &());
 40    }
 41
 42    pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
 43        let mut removed = None;
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        let key = MapKeyRef(Some(key));
 46        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 47        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 48            removed = Some(cursor.item().unwrap().value.clone());
 49            cursor.next(&());
 50        }
 51        new_tree.push_tree(cursor.suffix(&()), &());
 52        drop(cursor);
 53        self.0 = new_tree;
 54        removed
 55    }
 56
 57    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
 58        self.0.iter().map(|entry| (&entry.key, &entry.value))
 59    }
 60}
 61
 62impl<K, V> Default for TreeMap<K, V>
 63where
 64    K: Clone + Debug + Default + Ord,
 65    V: Clone + Debug,
 66{
 67    fn default() -> Self {
 68        Self(Default::default())
 69    }
 70}
 71
 72impl<K, V> Item for MapEntry<K, V>
 73where
 74    K: Clone + Debug + Default + Ord,
 75    V: Clone,
 76{
 77    type Summary = MapKey<K>;
 78
 79    fn summary(&self) -> Self::Summary {
 80        self.key()
 81    }
 82}
 83
 84impl<K, V> KeyedItem for MapEntry<K, V>
 85where
 86    K: Clone + Debug + Default + Ord,
 87    V: Clone,
 88{
 89    type Key = MapKey<K>;
 90
 91    fn key(&self) -> Self::Key {
 92        MapKey(self.key.clone())
 93    }
 94}
 95
 96impl<K> Summary for MapKey<K>
 97where
 98    K: Clone + Debug + Default,
 99{
100    type Context = ();
101
102    fn add_summary(&mut self, summary: &Self, _: &()) {
103        *self = summary.clone()
104    }
105}
106
107impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
108where
109    K: Clone + Debug + Default + Ord,
110{
111    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
112        self.0 = Some(&summary.0)
113    }
114}
115
116impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
117where
118    K: Clone + Debug + Default + Ord,
119{
120    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
121        self.0.cmp(&cursor_location.0)
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `remove` must hand back the value that was stored under the key.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    #[test]
    fn test_missing_keys() {
        let mut map: TreeMap<i32, &str> = TreeMap::default();

        // Lookups and removals on an empty map find nothing.
        assert_eq!(map.get(&1), None);
        assert_eq!(map.remove(&1), None);

        map.insert(1, "a");
        map.insert(3, "c");

        // Keys below, between, and above the stored keys are all absent.
        assert_eq!(map.get(&0), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.get(&4), None);
        assert_eq!(map.remove(&0), None);
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.remove(&4), None);

        // Failed removals leave the map untouched.
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
    }

    #[test]
    fn test_insert_replaces_existing_value() {
        let mut map: TreeMap<i32, &str> = TreeMap::default();
        map.insert(1, "a");
        map.insert(2, "b");
        map.insert(1, "z");

        assert_eq!(map.get(&1), Some(&"z"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z"), (&2, &"b")]);
    }
}