tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, Debug, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 /// One key/value pair stored in a [`TreeMap`]'s underlying tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    key: K,
    value: V,
}

/// Owned key used both as a [`MapEntry`]'s tree key and as its summary.
///
/// Summaries are combined by overwriting, so the summary of a run of entries
/// is the key of the last entry in that run.
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapKey<K>(K);

/// Borrowed key used as a cursor dimension and seek target.
///
/// `None` (the default) compares less than every `Some(key)`, so it denotes
/// the position before any entry.
#[derive(Debug, Clone, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.push_tree(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 77    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 78        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 79        let key = MapKeyRef(Some(key));
 80        cursor.seek(&key, Bias::Right, &());
 81        cursor.prev(&());
 82        cursor.item().map(|item| (&item.key, &item.value))
 83    }
 84
 85    pub fn remove_between(&mut self, from: &K, until: &K)
 86    {
 87        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 88        let from_key = MapKeyRef(Some(from));
 89        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
 90        let until_key = MapKeyRef(Some(until));
 91        cursor.seek_forward(&until_key, Bias::Left, &());
 92        new_tree.push_tree(cursor.suffix(&()), &());
 93        drop(cursor);
 94        self.0 = new_tree;
 95    }
 96
 97    pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
 98    where F: FnMut(&K, &V) -> bool
 99    {
100        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101        let from_key = MapKeyRef(Some(from));
102        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103        while let Some(item) = cursor.item() {
104            if !f(&item.key, &item.value) {
105                break;
106            }
107            cursor.next(&());
108        }
109        new_tree.push_tree(cursor.suffix(&()), &());
110        drop(cursor);
111        self.0 = new_tree;
112    }
113
114
115    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
116    where
117        F: FnOnce(&mut V) -> T,
118    {
119        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
120        let key = MapKeyRef(Some(key));
121        let mut new_tree = cursor.slice(&key, Bias::Left, &());
122        let mut result = None;
123        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
124            let mut updated = cursor.item().unwrap().clone();
125            result = Some(f(&mut updated.value));
126            new_tree.push(updated, &());
127            cursor.next(&());
128        }
129        new_tree.push_tree(cursor.suffix(&()), &());
130        drop(cursor);
131        self.0 = new_tree;
132        result
133    }
134
135    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
136        let mut new_map = SumTree::<MapEntry<K, V>>::default();
137
138        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
139        cursor.next(&());
140        while let Some(item) = cursor.item() {
141            if predicate(&item.key, &item.value) {
142                new_map.push(item.clone(), &());
143            }
144            cursor.next(&());
145        }
146        drop(cursor);
147
148        self.0 = new_map;
149    }
150
151    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
152        self.0.iter().map(|entry| (&entry.key, &entry.value))
153    }
154
155    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
156        self.0.iter().map(|entry| &entry.value)
157    }
158}
159
160impl<K, V> Default for TreeMap<K, V>
161where
162    K: Clone + Debug + Default + Ord,
163    V: Clone + Debug,
164{
165    fn default() -> Self {
166        Self(Default::default())
167    }
168}
169
170impl<K, V> Item for MapEntry<K, V>
171where
172    K: Clone + Debug + Default + Ord,
173    V: Clone,
174{
175    type Summary = MapKey<K>;
176
177    fn summary(&self) -> Self::Summary {
178        self.key()
179    }
180}
181
182impl<K, V> KeyedItem for MapEntry<K, V>
183where
184    K: Clone + Debug + Default + Ord,
185    V: Clone,
186{
187    type Key = MapKey<K>;
188
189    fn key(&self) -> Self::Key {
190        MapKey(self.key.clone())
191    }
192}
193
194impl<K> Summary for MapKey<K>
195where
196    K: Clone + Debug + Default,
197{
198    type Context = ();
199
200    fn add_summary(&mut self, summary: &Self, _: &()) {
201        *self = summary.clone()
202    }
203}
204
205impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
206where
207    K: Clone + Debug + Default + Ord,
208{
209    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
210        self.0 = Some(&summary.0)
211    }
212}
213
214impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
215where
216    K: Clone + Debug + Default + Ord,
217{
218    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
219        self.0.cmp(&cursor_location.0)
220    }
221}
222
223impl<K> Default for TreeSet<K>
224where
225    K: Clone + Debug + Default + Ord,
226{
227    fn default() -> Self {
228        Self(Default::default())
229    }
230}
231
232impl<K> TreeSet<K>
233where
234    K: Clone + Debug + Default + Ord,
235{
236    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
237        Self(TreeMap::from_ordered_entries(
238            entries.into_iter().map(|key| (key, ())),
239        ))
240    }
241
242    pub fn insert(&mut self, key: K) {
243        self.0.insert(key, ());
244    }
245
246    pub fn contains(&self, key: &K) -> bool {
247        self.0.get(key).is_some()
248    }
249
250    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
251        self.0.iter().map(|(k, _)| k)
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted and fall in ["ba", "bb");
        // absent keys would make these assertions pass vacuously.
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with(&"ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted and start with "ba".
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_update_and_closest() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, 10), (3, 30)]);

        assert_eq!(
            map.update(&1, |v| {
                *v += 1;
                *v
            }),
            Some(11)
        );
        assert_eq!(map.get(&1), Some(&11));

        // Missing keys are not inserted and the closure result is `None`.
        assert_eq!(
            map.update(&2, |v| {
                *v += 1;
                *v
            }),
            None
        );
        assert_eq!(map.get(&2), None);

        assert_eq!(map.closest(&2), Some((&1, &11)));
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&11, &30]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);

        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);
    }
}