tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 24    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 25        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 26        let key = MapKeyRef(Some(key));
 27        cursor.seek(&key, Bias::Left, &());
 28        if key.cmp(cursor.start(), &()) == Ordering::Equal {
 29            Some(&cursor.item().unwrap().value)
 30        } else {
 31            None
 32        }
 33    }
 34
 35    pub fn insert(&mut self, key: K, value: V) {
 36        self.0.insert_or_replace(MapEntry { key, value }, &());
 37    }
 38
 39    pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
 40        let mut removed = None;
 41        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 42        let key = MapKeyRef(Some(key));
 43        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 44        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 45            removed = Some(cursor.item().unwrap().value.clone());
 46            cursor.next(&());
 47        }
 48        new_tree.push_tree(cursor.suffix(&()), &());
 49        drop(cursor);
 50        self.0 = new_tree;
 51        removed
 52    }
 53
 54    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
 55        self.0.iter().map(|entry| (&entry.key, &entry.value))
 56    }
 57}
 58
 59impl<K, V> Default for TreeMap<K, V>
 60where
 61    K: Clone + Debug + Default + Ord,
 62    V: Clone + Debug,
 63{
 64    fn default() -> Self {
 65        Self(Default::default())
 66    }
 67}
 68
 69impl<K, V> Item for MapEntry<K, V>
 70where
 71    K: Clone + Debug + Default + Ord,
 72    V: Clone,
 73{
 74    type Summary = MapKey<K>;
 75
 76    fn summary(&self) -> Self::Summary {
 77        self.key()
 78    }
 79}
 80
 81impl<K, V> KeyedItem for MapEntry<K, V>
 82where
 83    K: Clone + Debug + Default + Ord,
 84    V: Clone,
 85{
 86    type Key = MapKey<K>;
 87
 88    fn key(&self) -> Self::Key {
 89        MapKey(self.key.clone())
 90    }
 91}
 92
 93impl<K> Summary for MapKey<K>
 94where
 95    K: Clone + Debug + Default,
 96{
 97    type Context = ();
 98
 99    fn add_summary(&mut self, summary: &Self, _: &()) {
100        *self = summary.clone()
101    }
102}
103
104impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
105where
106    K: Clone + Debug + Default + Ord,
107{
108    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
109        self.0 = Some(&summary.0)
110    }
111}
112
113impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
114where
115    K: Clone + Debug + Default + Ord,
116{
117    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
118        self.0.cmp(&cursor_location.0)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the map's entries, in iteration order, for comparison.
    fn entries<'a>(map: &'a TreeMap<i32, &'static str>) -> Vec<(&'a i32, &'a &'static str)> {
        map.iter().collect()
    }

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(entries(&map), vec![]);

        // Insertions in arbitrary order come back sorted by key.
        map.insert(3, "c");
        assert_eq!(entries(&map), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(entries(&map), vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]);

        // Remove from the middle, then the end, then the last survivor.
        map.remove(&2);
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.remove(&3);
        assert_eq!(entries(&map), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(entries(&map), vec![]);
    }
}