tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug, iter};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, Debug, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.push_tree(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 77    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 78        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 79        let key = MapKeyRef(Some(key));
 80        cursor.seek(&key, Bias::Right, &());
 81        cursor.prev(&());
 82        cursor.item().map(|item| (&item.key, &item.value))
 83    }
 84
 85    pub fn remove_between(&mut self, from: &K, until: &K) {
 86        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 87        let from_key = MapKeyRef(Some(from));
 88        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
 89        let until_key = MapKeyRef(Some(until));
 90        cursor.seek_forward(&until_key, Bias::Left, &());
 91        new_tree.push_tree(cursor.suffix(&()), &());
 92        drop(cursor);
 93        self.0 = new_tree;
 94    }
 95
 96    pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
 97    where
 98        F: FnMut(&K, &V) -> bool,
 99    {
100        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101        let from_key = MapKeyRef(Some(from));
102        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103        while let Some(item) = cursor.item() {
104            if !f(&item.key, &item.value) {
105                break;
106            }
107            cursor.next(&());
108        }
109        new_tree.push_tree(cursor.suffix(&()), &());
110        drop(cursor);
111        self.0 = new_tree;
112    }
113
114
115    pub fn get_from_while<'tree, F>(&'tree self, from: &'tree K, mut f: F) -> impl Iterator<Item = (&K, &V)> + '_
116        where
117            F: FnMut(&K, &K, &V) -> bool + 'tree,
118        {
119            let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
120            let from_key = MapKeyRef(Some(from));
121            cursor.seek(&from_key, Bias::Left, &());
122
123            iter::from_fn(move || {
124                let result = cursor.item().and_then(|item| {
125                    (f(from, &item.key, &item.value))
126                        .then(|| (&item.key, &item.value))
127                });
128                cursor.next(&());
129                result
130            })
131        }
132
133
134    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
135    where
136        F: FnOnce(&mut V) -> T,
137    {
138        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
139        let key = MapKeyRef(Some(key));
140        let mut new_tree = cursor.slice(&key, Bias::Left, &());
141        let mut result = None;
142        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
143            let mut updated = cursor.item().unwrap().clone();
144            result = Some(f(&mut updated.value));
145            new_tree.push(updated, &());
146            cursor.next(&());
147        }
148        new_tree.push_tree(cursor.suffix(&()), &());
149        drop(cursor);
150        self.0 = new_tree;
151        result
152    }
153
154    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
155        let mut new_map = SumTree::<MapEntry<K, V>>::default();
156
157        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
158        cursor.next(&());
159        while let Some(item) = cursor.item() {
160            if predicate(&item.key, &item.value) {
161                new_map.push(item.clone(), &());
162            }
163            cursor.next(&());
164        }
165        drop(cursor);
166
167        self.0 = new_map;
168    }
169
170    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
171        self.0.iter().map(|entry| (&entry.key, &entry.value))
172    }
173
174    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
175        self.0.iter().map(|entry| &entry.value)
176    }
177
178    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
179        let edits = other
180            .iter()
181            .map(|(key, value)| {
182                Edit::Insert(MapEntry {
183                    key: key.to_owned(),
184                    value: value.to_owned(),
185                })
186            })
187            .collect();
188
189        self.0.edit(edits, &());
190    }
191}
192
193impl<K, V> Default for TreeMap<K, V>
194where
195    K: Clone + Debug + Default + Ord,
196    V: Clone + Debug,
197{
198    fn default() -> Self {
199        Self(Default::default())
200    }
201}
202
203impl<K, V> Item for MapEntry<K, V>
204where
205    K: Clone + Debug + Default + Ord,
206    V: Clone,
207{
208    type Summary = MapKey<K>;
209
210    fn summary(&self) -> Self::Summary {
211        self.key()
212    }
213}
214
215impl<K, V> KeyedItem for MapEntry<K, V>
216where
217    K: Clone + Debug + Default + Ord,
218    V: Clone,
219{
220    type Key = MapKey<K>;
221
222    fn key(&self) -> Self::Key {
223        MapKey(self.key.clone())
224    }
225}
226
227impl<K> Summary for MapKey<K>
228where
229    K: Clone + Debug + Default,
230{
231    type Context = ();
232
233    fn add_summary(&mut self, summary: &Self, _: &()) {
234        *self = summary.clone()
235    }
236}
237
238impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
239where
240    K: Clone + Debug + Default + Ord,
241{
242    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
243        self.0 = Some(&summary.0)
244    }
245}
246
247impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
248where
249    K: Clone + Debug + Default + Ord,
250{
251    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
252        self.0.cmp(&cursor_location.0)
253    }
254}
255
256impl<K> Default for TreeSet<K>
257where
258    K: Clone + Debug + Default + Ord,
259{
260    fn default() -> Self {
261        Self(Default::default())
262    }
263}
264
265impl<K> TreeSet<K>
266where
267    K: Clone + Debug + Default + Ord,
268{
269    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
270        Self(TreeMap::from_ordered_entries(
271            entries.into_iter().map(|key| (key, ())),
272        ))
273    }
274
275    pub fn insert(&mut self, key: K) {
276        self.0.insert(key, ());
277    }
278
279    pub fn contains(&self, key: &K) -> bool {
280        self.0.get(key).is_some()
281    }
282
283    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
284        self.0.iter().map(|(k, _)| k)
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        // `remove` hands back the removed value, and `None` once the key is gone.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        assert!(!map.is_empty());
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        // Check the keys that were actually inserted in the removed range.
        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with(&"ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_get_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // The predicate receives `(from, key, value)`.
        let result = map
            .get_from_while(&"ba", |_, key, _| key.starts_with(&"ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .get_from_while(&"c", |_, key, _| key.starts_with(&"c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }
}