// tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 24    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 25        let tree = SumTree::from_iter(
 26            entries
 27                .into_iter()
 28                .map(|(key, value)| MapEntry { key, value }),
 29            &(),
 30        );
 31        Self(tree)
 32    }
 33
 34    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 35        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 36        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 37        if let Some(item) = cursor.item() {
 38            if *key == item.key().0 {
 39                Some(&item.value)
 40            } else {
 41                None
 42            }
 43        } else {
 44            None
 45        }
 46    }
 47
 48    pub fn insert(&mut self, key: K, value: V) {
 49        self.0.insert_or_replace(MapEntry { key, value }, &());
 50    }
 51
 52    pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
 53        let mut removed = None;
 54        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 55        let key = MapKeyRef(Some(key));
 56        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 57        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 58            removed = Some(cursor.item().unwrap().value.clone());
 59            cursor.next(&());
 60        }
 61        new_tree.push_tree(cursor.suffix(&()), &());
 62        drop(cursor);
 63        self.0 = new_tree;
 64        removed
 65    }
 66
 67    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
 68        self.0.iter().map(|entry| (&entry.key, &entry.value))
 69    }
 70}
 71
 72impl<K, V> Default for TreeMap<K, V>
 73where
 74    K: Clone + Debug + Default + Ord,
 75    V: Clone + Debug,
 76{
 77    fn default() -> Self {
 78        Self(Default::default())
 79    }
 80}
 81
 82impl<K, V> Item for MapEntry<K, V>
 83where
 84    K: Clone + Debug + Default + Ord,
 85    V: Clone,
 86{
 87    type Summary = MapKey<K>;
 88
 89    fn summary(&self) -> Self::Summary {
 90        self.key()
 91    }
 92}
 93
 94impl<K, V> KeyedItem for MapEntry<K, V>
 95where
 96    K: Clone + Debug + Default + Ord,
 97    V: Clone,
 98{
 99    type Key = MapKey<K>;
100
101    fn key(&self) -> Self::Key {
102        MapKey(self.key.clone())
103    }
104}
105
106impl<K> Summary for MapKey<K>
107where
108    K: Clone + Debug + Default,
109{
110    type Context = ();
111
112    fn add_summary(&mut self, summary: &Self, _: &()) {
113        *self = summary.clone()
114    }
115}
116
117impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
118where
119    K: Clone + Debug + Default + Ord,
120{
121    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
122        self.0 = Some(&summary.0)
123    }
124}
125
126impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
127where
128    K: Clone + Debug + Default + Ord,
129{
130    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
131        self.0.cmp(&cursor_location.0)
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
        // Lookups and removals on an empty map find nothing.
        assert_eq!(map.get(&1), None);
        assert_eq!(map.remove(&1), None);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // Keys below and above the stored range are not found.
        assert_eq!(map.get(&0), None);
        assert_eq!(map.get(&4), None);

        // `remove` returns the removed value; removing again finds nothing.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    #[test]
    fn test_insert_replaces_existing_key() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(1, "b");
        assert_eq!(map.get(&1), Some(&"b"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"b")]);
    }

    #[test]
    fn test_from_ordered_entries() {
        let map = TreeMap::from_ordered_entries(vec![(1, "a"), (2, "b"), (3, "c")]);
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&4), None);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );
    }
}