tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
/// A key-value map whose entries are stored as [`MapEntry`] items inside a `SumTree`.
///
/// Each entry is summarized by its key (see [`MapKey`]), and lookups seek through the tree by
/// key using [`MapKeyRef`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
where
    K: Clone + Debug + Default + Ord,
    V: Clone + Debug;
 10
/// A single key/value pair stored in a [`TreeMap`]'s tree.
///
/// Both its `Item` summary and its `KeyedItem` key are a clone of `key` wrapped in [`MapKey`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    // The key this entry is stored and looked up under.
    key: K,
    // The value associated with `key`.
    value: V,
}
 16
/// Owned key wrapper used as the tree summary for [`MapEntry`].
///
/// Its `Summary::add_summary` overwrites `self` with the summary being added, so an accumulated
/// summary holds whichever key was added last.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
 19
/// Borrowed counterpart of [`MapKey`], used both as a cursor dimension and as a seek target.
///
/// The default value holds `None`; adding a summary replaces it with a reference to that
/// summary's key.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.push_tree(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 77    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 78        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 79        let key = MapKeyRef(Some(key));
 80        cursor.seek(&key, Bias::Right, &());
 81        cursor.prev(&());
 82        cursor.item().map(|item| (&item.key, &item.value))
 83    }
 84
 85    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
 86    where
 87        F: FnOnce(&mut V) -> T,
 88    {
 89        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 90        let key = MapKeyRef(Some(key));
 91        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 92        let mut result = None;
 93        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 94            let mut updated = cursor.item().unwrap().clone();
 95            result = Some(f(&mut updated.value));
 96            new_tree.push(updated, &());
 97            cursor.next(&());
 98        }
 99        new_tree.push_tree(cursor.suffix(&()), &());
100        drop(cursor);
101        self.0 = new_tree;
102        result
103    }
104
105    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
106        let mut new_map = SumTree::<MapEntry<K, V>>::default();
107
108        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
109        cursor.next(&());
110        while let Some(item) = cursor.item() {
111            if predicate(&item.key, &item.value) {
112                new_map.push(item.clone(), &());
113            }
114            cursor.next(&());
115        }
116        drop(cursor);
117
118        self.0 = new_map;
119    }
120
121    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
122        self.0.iter().map(|entry| (&entry.key, &entry.value))
123    }
124
125    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
126        self.0.iter().map(|entry| &entry.value)
127    }
128}
129
130impl<K, V> Default for TreeMap<K, V>
131where
132    K: Clone + Debug + Default + Ord,
133    V: Clone + Debug,
134{
135    fn default() -> Self {
136        Self(Default::default())
137    }
138}
139
140impl<K, V> Item for MapEntry<K, V>
141where
142    K: Clone + Debug + Default + Ord,
143    V: Clone,
144{
145    type Summary = MapKey<K>;
146
147    fn summary(&self) -> Self::Summary {
148        self.key()
149    }
150}
151
152impl<K, V> KeyedItem for MapEntry<K, V>
153where
154    K: Clone + Debug + Default + Ord,
155    V: Clone,
156{
157    type Key = MapKey<K>;
158
159    fn key(&self) -> Self::Key {
160        MapKey(self.key.clone())
161    }
162}
163
164impl<K> Summary for MapKey<K>
165where
166    K: Clone + Debug + Default,
167{
168    type Context = ();
169
170    fn add_summary(&mut self, summary: &Self, _: &()) {
171        *self = summary.clone()
172    }
173}
174
175impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
176where
177    K: Clone + Debug + Default + Ord,
178{
179    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
180        self.0 = Some(&summary.0)
181    }
182}
183
184impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
185where
186    K: Clone + Debug + Default + Ord,
187{
188    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
189        self.0.cmp(&cursor_location.0)
190    }
191}
192
193impl<K> Default for TreeSet<K>
194where
195    K: Clone + Debug + Default + Ord,
196{
197    fn default() -> Self {
198        Self(Default::default())
199    }
200}
201
202impl<K> TreeSet<K>
203where
204    K: Clone + Debug + Default + Ord,
205{
206    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
207        Self(TreeMap::from_ordered_entries(
208            entries.into_iter().map(|key| (key, ())),
209        ))
210    }
211
212    pub fn insert(&mut self, key: K) {
213        self.0.insert(key, ());
214    }
215
216    pub fn contains(&self, key: &K) -> bool {
217        self.0.get(key).is_some()
218    }
219
220    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
221        self.0.iter().map(|(k, _)| k)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshots the map's entries, in iteration order, for use in assertions.
    fn entries<'a>(map: &'a TreeMap<i32, &'static str>) -> Vec<(&'a i32, &'a &'static str)> {
        map.iter().collect()
    }

    #[test]
    fn test_basic() {
        let mut tree_map: TreeMap<i32, &'static str> = TreeMap::default();
        assert_eq!(entries(&tree_map), vec![]);

        // Insert out of order; iteration must still come back sorted by key.
        tree_map.insert(3, "c");
        assert_eq!(tree_map.get(&3), Some(&"c"));
        assert_eq!(entries(&tree_map), vec![(&3, &"c")]);

        tree_map.insert(1, "a");
        assert_eq!(tree_map.get(&1), Some(&"a"));
        assert_eq!(entries(&tree_map), vec![(&1, &"a"), (&3, &"c")]);

        tree_map.insert(2, "b");
        assert_eq!(tree_map.get(&2), Some(&"b"));
        assert_eq!(tree_map.get(&1), Some(&"a"));
        assert_eq!(tree_map.get(&3), Some(&"c"));
        assert_eq!(
            entries(&tree_map),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the greatest key <= the probe, or nothing below the smallest key.
        assert_eq!(tree_map.closest(&0), None);
        assert_eq!(tree_map.closest(&1), Some((&1, &"a")));
        assert_eq!(tree_map.closest(&10), Some((&3, &"c")));

        // Remove entries one at a time until the map is empty again.
        tree_map.remove(&2);
        assert_eq!(tree_map.get(&2), None);
        assert_eq!(entries(&tree_map), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(tree_map.closest(&2), Some((&1, &"a")));

        tree_map.remove(&3);
        assert_eq!(tree_map.get(&3), None);
        assert_eq!(entries(&tree_map), vec![(&1, &"a")]);

        tree_map.remove(&1);
        assert_eq!(tree_map.get(&1), None);
        assert_eq!(entries(&tree_map), vec![]);

        // `retain` keeps only the entries the predicate accepts (even keys here).
        tree_map.insert(4, "d");
        tree_map.insert(5, "e");
        tree_map.insert(6, "f");
        tree_map.retain(|&key, _| key % 2 == 0);
        assert_eq!(entries(&tree_map), vec![(&4, &"d"), (&6, &"f")]);
    }
}
275}