tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, Debug, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.push_tree(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 77    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 78        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 79        let key = MapKeyRef(Some(key));
 80        cursor.seek(&key, Bias::Right, &());
 81        cursor.prev(&());
 82        cursor.item().map(|item| (&item.key, &item.value))
 83    }
 84
 85    pub fn remove_between(&mut self, from: &K, until: &K) {
 86        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 87        let from_key = MapKeyRef(Some(from));
 88        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
 89        let until_key = MapKeyRef(Some(until));
 90        cursor.seek_forward(&until_key, Bias::Left, &());
 91        new_tree.push_tree(cursor.suffix(&()), &());
 92        drop(cursor);
 93        self.0 = new_tree;
 94    }
 95
 96    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
 97        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 98        let from_key = MapKeyRef(Some(from));
 99        cursor.seek(&from_key, Bias::Left, &());
100
101        cursor
102            .into_iter()
103            .map(|map_entry| (&map_entry.key, &map_entry.value))
104    }
105
106    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
107    where
108        F: FnOnce(&mut V) -> T,
109    {
110        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
111        let key = MapKeyRef(Some(key));
112        let mut new_tree = cursor.slice(&key, Bias::Left, &());
113        let mut result = None;
114        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
115            let mut updated = cursor.item().unwrap().clone();
116            result = Some(f(&mut updated.value));
117            new_tree.push(updated, &());
118            cursor.next(&());
119        }
120        new_tree.push_tree(cursor.suffix(&()), &());
121        drop(cursor);
122        self.0 = new_tree;
123        result
124    }
125
126    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
127        let mut new_map = SumTree::<MapEntry<K, V>>::default();
128
129        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
130        cursor.next(&());
131        while let Some(item) = cursor.item() {
132            if predicate(&item.key, &item.value) {
133                new_map.push(item.clone(), &());
134            }
135            cursor.next(&());
136        }
137        drop(cursor);
138
139        self.0 = new_map;
140    }
141
142    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
143        self.0.iter().map(|entry| (&entry.key, &entry.value))
144    }
145
146    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
147        self.0.iter().map(|entry| &entry.value)
148    }
149
150    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
151        let edits = other
152            .iter()
153            .map(|(key, value)| {
154                Edit::Insert(MapEntry {
155                    key: key.to_owned(),
156                    value: value.to_owned(),
157                })
158            })
159            .collect();
160
161        self.0.edit(edits, &());
162    }
163
164    pub fn remove_by<F>(&mut self, key: &K, f: F)
165    where
166        F: Fn(&K) -> bool,
167    {
168        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
169        let key = MapKeyRef(Some(key));
170        let mut new_tree = cursor.slice(&key, Bias::Left, &());
171        let until = RemoveByTarget(key, &f);
172        cursor.seek_forward(&until, Bias::Right, &());
173        new_tree.push_tree(cursor.suffix(&()), &());
174        drop(cursor);
175        self.0 = new_tree;
176    }
177}
178
179struct RemoveByTarget<'a, K>(MapKeyRef<'a, K>, &'a dyn Fn(&K) -> bool);
180
181impl<'a, K: Debug> Debug for RemoveByTarget<'a, K> {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("RemoveByTarget")
184            .field("key", &self.0)
185            .field("F", &"<...>")
186            .finish()
187    }
188}
189
190impl<'a, K: Debug + Clone + Default + Ord> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
191    for RemoveByTarget<'_, K>
192{
193    fn cmp(
194        &self,
195        cursor_location: &MapKeyRef<'a, K>,
196        _cx: &<MapKey<K> as Summary>::Context,
197    ) -> Ordering {
198        if let Some(cursor_location) = cursor_location.0 {
199            if (self.1)(cursor_location) {
200                Ordering::Equal
201            } else {
202                self.0 .0.unwrap().cmp(cursor_location)
203            }
204        } else {
205            Ordering::Greater
206        }
207    }
208}
209
210impl<K, V> Default for TreeMap<K, V>
211where
212    K: Clone + Debug + Default + Ord,
213    V: Clone + Debug,
214{
215    fn default() -> Self {
216        Self(Default::default())
217    }
218}
219
220impl<K, V> Item for MapEntry<K, V>
221where
222    K: Clone + Debug + Default + Ord,
223    V: Clone,
224{
225    type Summary = MapKey<K>;
226
227    fn summary(&self) -> Self::Summary {
228        self.key()
229    }
230}
231
232impl<K, V> KeyedItem for MapEntry<K, V>
233where
234    K: Clone + Debug + Default + Ord,
235    V: Clone,
236{
237    type Key = MapKey<K>;
238
239    fn key(&self) -> Self::Key {
240        MapKey(self.key.clone())
241    }
242}
243
244impl<K> Summary for MapKey<K>
245where
246    K: Clone + Debug + Default,
247{
248    type Context = ();
249
250    fn add_summary(&mut self, summary: &Self, _: &()) {
251        *self = summary.clone()
252    }
253}
254
255impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
256where
257    K: Clone + Debug + Default + Ord,
258{
259    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
260        self.0 = Some(&summary.0)
261    }
262}
263
264impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
265where
266    K: Clone + Debug + Default + Ord,
267{
268    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
269        self.0.cmp(&cursor_location.0)
270    }
271}
272
273impl<K> Default for TreeSet<K>
274where
275    K: Clone + Debug + Default + Ord,
276{
277    fn default() -> Self {
278        Self(Default::default())
279    }
280}
281
282impl<K> TreeSet<K>
283where
284    K: Clone + Debug + Default + Ord,
285{
286    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
287        Self(TreeMap::from_ordered_entries(
288            entries.into_iter().map(|key| (key, ())),
289        ))
290    }
291
292    pub fn insert(&mut self, key: K) {
293        self.0.insert(key, ());
294    }
295
296    pub fn contains(&self, key: &K) -> bool {
297        self.0.get(key).is_some()
298    }
299
300    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
301        self.0.iter().map(|(k, _)| k)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert!(!map.is_empty());
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        // `remove` hands back the removed value, and `None` once it is gone.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // These are the keys that actually fall inside "ba".."bb".
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_by() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("aa", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);
        map.insert("ca", 6);

        map.remove_by(&"ba", |key| key.starts_with("ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"aa"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(map.get(&"ca"), Some(&6));

        map.remove_by(&"c", |key| key.starts_with("c"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"aa"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), None);
        assert_eq!(map.get(&"ca"), None);

        map.remove_by(&"a", |key| key.starts_with("a"));

        assert_eq!(map.get(&"a"), None);
        assert_eq!(map.get(&"aa"), None);
        assert_eq!(map.get(&"b"), Some(&2));

        map.remove_by(&"b", |key| key.starts_with("b"));

        assert_eq!(map.get(&"b"), None);
        assert_eq!(map.iter().count(), 0);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(2, "b");

        assert_eq!(
            map.update(&2, |value| std::mem::replace(value, "B")),
            Some("b")
        );
        assert_eq!(map.get(&2), Some(&"B"));
        // A missing key leaves the map untouched and never calls the closure.
        assert_eq!(map.update(&3, |_| ()), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&2, &"B")]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(2);
        set.insert(1);
        set.insert(2);

        assert!(set.contains(&1));
        assert!(set.contains(&2));
        assert!(!set.contains(&3));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2]);
    }
}