tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Ord,
  9    V: Clone;
 10
 /// A single key/value pair stored as one item of a [`TreeMap`]'s underlying tree.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
 16
 /// Summary of a span of map entries: the key of the span's last entry (the greatest one, as
/// entries are kept sorted), or `None` for an empty span.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapKey<K>(Option<K>);
 19
 20impl<K> Default for MapKey<K> {
 21    fn default() -> Self {
 22        Self(None)
 23    }
 24}
 25
 /// Borrowed cursor dimension over a [`TreeMap`]: the key of the most recently accumulated
/// entry, or `None` before any entry has been passed.
#[derive(Debug, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
 28
 29impl<K> Default for MapKeyRef<'_, K> {
 30    fn default() -> Self {
 31        Self(None)
 32    }
 33}
 34
 35#[derive(Clone, Debug, PartialEq, Eq)]
 36pub struct TreeSet<K>(TreeMap<K, ()>)
 37where
 38    K: Clone + Ord;
 39
 40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
 41    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 42        let tree = SumTree::from_iter(
 43            entries
 44                .into_iter()
 45                .map(|(key, value)| MapEntry { key, value }),
 46            &(),
 47        );
 48        Self(tree)
 49    }
 50
 51    pub fn is_empty(&self) -> bool {
 52        self.0.is_empty()
 53    }
 54
 55    pub fn get(&self, key: &K) -> Option<&V> {
 56        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 57        cursor.seek(&MapKeyRef(Some(key)), Bias::Left);
 58        if let Some(item) = cursor.item() {
 59            if Some(key) == item.key().0.as_ref() {
 60                Some(&item.value)
 61            } else {
 62                None
 63            }
 64        } else {
 65            None
 66        }
 67    }
 68
 69    pub fn insert(&mut self, key: K, value: V) {
 70        self.0.insert_or_replace(MapEntry { key, value }, &());
 71    }
 72
 73    pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
 74        let edits: Vec<_> = iter
 75            .into_iter()
 76            .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
 77            .collect();
 78        self.0.edit(edits, &());
 79    }
 80
 81    pub fn clear(&mut self) {
 82        self.0 = SumTree::default();
 83    }
 84
 85    pub fn remove(&mut self, key: &K) -> Option<V> {
 86        let mut removed = None;
 87        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
 88        let key = MapKeyRef(Some(key));
 89        let mut new_tree = cursor.slice(&key, Bias::Left);
 90        if key.cmp(&cursor.end(), &()) == Ordering::Equal {
 91            removed = Some(cursor.item().unwrap().value.clone());
 92            cursor.next();
 93        }
 94        new_tree.append(cursor.suffix(), &());
 95        drop(cursor);
 96        self.0 = new_tree;
 97        removed
 98    }
 99
100    pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
101        let start = MapSeekTargetAdaptor(start);
102        let end = MapSeekTargetAdaptor(end);
103        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
104        let mut new_tree = cursor.slice(&start, Bias::Left);
105        cursor.seek(&end, Bias::Left);
106        new_tree.append(cursor.suffix(), &());
107        drop(cursor);
108        self.0 = new_tree;
109    }
110
111    /// Returns the key-value pair with the greatest key less than or equal to the given key.
112    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
113        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
114        let key = MapKeyRef(Some(key));
115        cursor.seek(&key, Bias::Right);
116        cursor.prev();
117        cursor.item().map(|item| (&item.key, &item.value))
118    }
119
120    pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
121        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
122        let from_key = MapKeyRef(Some(from));
123        cursor.seek(&from_key, Bias::Left);
124
125        cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
126    }
127
128    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
129    where
130        F: FnOnce(&mut V) -> T,
131    {
132        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
133        let key = MapKeyRef(Some(key));
134        let mut new_tree = cursor.slice(&key, Bias::Left);
135        let mut result = None;
136        if key.cmp(&cursor.end(), &()) == Ordering::Equal {
137            let mut updated = cursor.item().unwrap().clone();
138            result = Some(f(&mut updated.value));
139            new_tree.push(updated, &());
140            cursor.next();
141        }
142        new_tree.append(cursor.suffix(), &());
143        drop(cursor);
144        self.0 = new_tree;
145        result
146    }
147
148    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
149        let mut new_map = SumTree::<MapEntry<K, V>>::default();
150
151        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
152        cursor.next();
153        while let Some(item) = cursor.item() {
154            if predicate(&item.key, &item.value) {
155                new_map.push(item.clone(), &());
156            }
157            cursor.next();
158        }
159        drop(cursor);
160
161        self.0 = new_map;
162    }
163
164    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
165        self.0.iter().map(|entry| (&entry.key, &entry.value))
166    }
167
168    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
169        self.0.iter().map(|entry| &entry.value)
170    }
171
172    pub fn first(&self) -> Option<(&K, &V)> {
173        self.0.first().map(|entry| (&entry.key, &entry.value))
174    }
175
176    pub fn last(&self) -> Option<(&K, &V)> {
177        self.0.last().map(|entry| (&entry.key, &entry.value))
178    }
179
180    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
181        let edits = other
182            .iter()
183            .map(|(key, value)| {
184                Edit::Insert(MapEntry {
185                    key: key.to_owned(),
186                    value: value.to_owned(),
187                })
188            })
189            .collect();
190
191        self.0.edit(edits, &());
192    }
193}
194
195impl<K, V> Debug for TreeMap<K, V>
196where
197    K: Clone + Debug + Ord,
198    V: Clone + Debug,
199{
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_map().entries(self.iter()).finish()
202    }
203}
204
/// Wraps a borrowed [`MapSeekTarget`] so it can be used as a seek target for a cursor over
/// map keys.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'a, T>(&'a T);
207
208impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
209    for MapSeekTargetAdaptor<'_, T>
210{
211    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
212        if let Some(key) = &cursor_location.0 {
213            MapSeekTarget::cmp_cursor(self.0, key)
214        } else {
215            Ordering::Greater
216        }
217    }
218}
219
/// A position in a map's key space that can be compared against stored keys.
///
/// Used by `TreeMap::remove_range` to describe range endpoints that are not necessarily keys
/// themselves.
pub trait MapSeekTarget<K> {
    /// Orders this target relative to the key at the cursor; returning `Ordering::Greater`
    /// means the target lies after `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}
223
224impl<K: Ord> MapSeekTarget<K> for K {
225    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
226        self.cmp(cursor_location)
227    }
228}
229
230impl<K, V> Default for TreeMap<K, V>
231where
232    K: Clone + Ord,
233    V: Clone,
234{
235    fn default() -> Self {
236        Self(Default::default())
237    }
238}
239
240impl<K, V> Item for MapEntry<K, V>
241where
242    K: Clone + Ord,
243    V: Clone,
244{
245    type Summary = MapKey<K>;
246
247    fn summary(&self, _cx: &()) -> Self::Summary {
248        self.key()
249    }
250}
251
252impl<K, V> KeyedItem for MapEntry<K, V>
253where
254    K: Clone + Ord,
255    V: Clone,
256{
257    type Key = MapKey<K>;
258
259    fn key(&self) -> Self::Key {
260        MapKey(Some(self.key.clone()))
261    }
262}
263
264impl<K> Summary for MapKey<K>
265where
266    K: Clone,
267{
268    type Context = ();
269
270    fn zero(_cx: &()) -> Self {
271        Default::default()
272    }
273
274    fn add_summary(&mut self, summary: &Self, _: &()) {
275        *self = summary.clone()
276    }
277}
278
279impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
280where
281    K: Clone + Ord,
282{
283    fn zero(_cx: &()) -> Self {
284        Default::default()
285    }
286
287    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
288        self.0 = summary.0.as_ref();
289    }
290}
291
292impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
293where
294    K: Clone + Ord,
295{
296    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
297        Ord::cmp(&self.0, &cursor_location.0)
298    }
299}
300
301impl<K> Default for TreeSet<K>
302where
303    K: Clone + Ord,
304{
305    fn default() -> Self {
306        Self(Default::default())
307    }
308}
309
310impl<K> TreeSet<K>
311where
312    K: Clone + Ord,
313{
314    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
315        Self(TreeMap::from_ordered_entries(
316            entries.into_iter().map(|key| (key, ())),
317        ))
318    }
319
320    pub fn is_empty(&self) -> bool {
321        self.0.is_empty()
322    }
323
324    pub fn insert(&mut self, key: K) {
325        self.0.insert(key, ());
326    }
327
328    pub fn remove(&mut self, key: &K) -> bool {
329        self.0.remove(key).is_some()
330    }
331
332    pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
333        self.0.extend(iter.into_iter().map(|key| (key, ())));
334    }
335
336    pub fn contains(&self, key: &K) -> bool {
337        self.0.get(key).is_some()
338    }
339
340    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
341        self.0.iter().map(|(k, _)| k)
342    }
343
344    pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
345        self.0.iter_from(key).map(move |(k, _)| k)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // Insert out of order; iteration must always come back sorted by key.
        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        for (key, value) in [(2, "b"), (1, "a"), (3, "c")] {
            assert_eq!(map.get(&key), Some(&value));
        }
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` returns the greatest key that is <= the query.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        for (key, value) in [(4, "d"), (5, "e"), (6, "f")] {
            map.insert(key, value);
        }
        map.retain(|key, _| key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("baa", 3), ("baaab", 4), ("c", 5)] {
            map.insert(key, value);
        }

        let ba_entries: Vec<_> = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect();
        assert_eq!(ba_entries.len(), 2);
        assert!(ba_entries.iter().any(|(key, _)| *key == &"baa"));
        assert!(ba_entries.iter().any(|(key, _)| *key == &"baaab"));

        let c_entries: Vec<_> = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect();
        assert_eq!(c_entries.len(), 1);
        assert!(c_entries.iter().any(|(key, _)| *key == &"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            map.insert(key, value);
        }

        let mut other = TreeMap::default();
        for (key, value) in [("a", 2), ("b", 2), ("d", 4)] {
            other.insert(key, value);
        }

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(map.get(&key), Some(&expected));
        }
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            map.insert(key, value);
        }
        map.extend([("a", 2), ("b", 2), ("d", 4)]);

        assert_eq!(map.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(map.get(&key), Some(&expected));
        }
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        /// Seek target that sorts after every descendant of a path.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        /// Checks each `(path, expected value)` pair against `map`.
        #[track_caller]
        fn assert_values(map: &TreeMap<PathBuf, i32>, expected: &[(&str, Option<i32>)]) {
            for &(path, value) in expected {
                assert_eq!(map.get(&PathBuf::from(path)).copied(), value, "path {path:?}");
            }
        }

        let mut map = TreeMap::default();
        for (path, value) in [
            ("a", 1),
            ("a/a", 1),
            ("b", 2),
            ("b/a/a", 3),
            ("b/a/a/a/b", 4),
            ("c", 5),
            ("c/a", 6),
        ] {
            map.insert(PathBuf::from(path), value);
        }

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );
        assert_values(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("b/a/a", None),
                ("b/a/a/a/b", None),
                ("c", Some(5)),
                ("c/a", Some(6)),
            ],
        );

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));
        assert_values(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("c", None),
                ("c/a", None),
            ],
        );

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));
        assert_values(&map, &[("a", None), ("a/a", None), ("b", Some(2))]);

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));
        assert_values(&map, &[("b", None)]);
    }
}