tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 24    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 25        let tree = SumTree::from_iter(
 26            entries
 27                .into_iter()
 28                .map(|(key, value)| MapEntry { key, value }),
 29            &(),
 30        );
 31        Self(tree)
 32    }
 33
 34    pub fn is_empty(&self) -> bool {
 35        self.0.is_empty()
 36    }
 37
 38    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 39        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 40        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 41        if let Some(item) = cursor.item() {
 42            if *key == item.key().0 {
 43                Some(&item.value)
 44            } else {
 45                None
 46            }
 47        } else {
 48            None
 49        }
 50    }
 51
 52    pub fn insert(&mut self, key: K, value: V) {
 53        self.0.insert_or_replace(MapEntry { key, value }, &());
 54    }
 55
 56    pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
 57        let mut removed = None;
 58        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 59        let key = MapKeyRef(Some(key));
 60        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 61        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 62            removed = Some(cursor.item().unwrap().value.clone());
 63            cursor.next(&());
 64        }
 65        new_tree.push_tree(cursor.suffix(&()), &());
 66        drop(cursor);
 67        self.0 = new_tree;
 68        removed
 69    }
 70
 71    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
 72        self.0.iter().map(|entry| (&entry.key, &entry.value))
 73    }
 74}
 75
 76impl<K, V> Default for TreeMap<K, V>
 77where
 78    K: Clone + Debug + Default + Ord,
 79    V: Clone + Debug,
 80{
 81    fn default() -> Self {
 82        Self(Default::default())
 83    }
 84}
 85
 86impl<K, V> Item for MapEntry<K, V>
 87where
 88    K: Clone + Debug + Default + Ord,
 89    V: Clone,
 90{
 91    type Summary = MapKey<K>;
 92
 93    fn summary(&self) -> Self::Summary {
 94        self.key()
 95    }
 96}
 97
 98impl<K, V> KeyedItem for MapEntry<K, V>
 99where
100    K: Clone + Debug + Default + Ord,
101    V: Clone,
102{
103    type Key = MapKey<K>;
104
105    fn key(&self) -> Self::Key {
106        MapKey(self.key.clone())
107    }
108}
109
110impl<K> Summary for MapKey<K>
111where
112    K: Clone + Debug + Default,
113{
114    type Context = ();
115
116    fn add_summary(&mut self, summary: &Self, _: &()) {
117        *self = summary.clone()
118    }
119}
120
121impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
122where
123    K: Clone + Debug + Default + Ord,
124{
125    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
126        self.0 = Some(&summary.0)
127    }
128}
129
130impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
131where
132    K: Clone + Debug + Default + Ord,
133{
134    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
135        self.0.cmp(&cursor_location.0)
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert, look up, iterate and remove keys in mixed order.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // Removing from an emptied map is a no-op.
        assert_eq!(map.remove(&1), None);
    }

    /// Inserting an existing key overwrites its value instead of adding a second entry.
    #[test]
    fn test_insert_replaces_existing_value() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(1, "z");
        assert_eq!(map.get(&1), Some(&"z"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z")]);
    }

    /// Lookups and removals of absent keys (before, between and after the
    /// stored keys) leave the map unchanged.
    #[test]
    fn test_missing_keys() {
        let mut map = TreeMap::from_ordered_entries(vec![(2, "b"), (4, "d")]);
        assert!(!map.is_empty());

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.get(&5), None);

        assert_eq!(map.remove(&1), None);
        assert_eq!(map.remove(&3), None);
        assert_eq!(map.remove(&5), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&2, &"b"), (&4, &"d")]);

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&4), Some("d"));
        assert!(map.is_empty());
    }
}