tree_map.rs

  1use std::{cmp::Ordering, fmt::Debug, iter};
  2
  3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
  4
  5#[derive(Clone, Debug, PartialEq, Eq)]
  6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
  7where
  8    K: Clone + Debug + Default + Ord,
  9    V: Clone + Debug;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub struct MapEntry<K, V> {
 13    key: K,
 14    value: V,
 15}
 16
 17#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 18pub struct MapKey<K>(K);
 19
 20#[derive(Clone, Debug, Default)]
 21pub struct MapKeyRef<'a, K>(Option<&'a K>);
 22
 23#[derive(Clone)]
 24pub struct TreeSet<K>(TreeMap<K, ()>)
 25where
 26    K: Clone + Debug + Default + Ord;
 27
 28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 29    pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
 30        let tree = SumTree::from_iter(
 31            entries
 32                .into_iter()
 33                .map(|(key, value)| MapEntry { key, value }),
 34            &(),
 35        );
 36        Self(tree)
 37    }
 38
 39    pub fn is_empty(&self) -> bool {
 40        self.0.is_empty()
 41    }
 42
 43    pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
 44        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 45        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
 46        if let Some(item) = cursor.item() {
 47            if *key == item.key().0 {
 48                Some(&item.value)
 49            } else {
 50                None
 51            }
 52        } else {
 53            None
 54        }
 55    }
 56
 57    pub fn insert(&mut self, key: K, value: V) {
 58        self.0.insert_or_replace(MapEntry { key, value }, &());
 59    }
 60
 61    pub fn remove(&mut self, key: &K) -> Option<V> {
 62        let mut removed = None;
 63        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 64        let key = MapKeyRef(Some(key));
 65        let mut new_tree = cursor.slice(&key, Bias::Left, &());
 66        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
 67            removed = Some(cursor.item().unwrap().value.clone());
 68            cursor.next(&());
 69        }
 70        new_tree.push_tree(cursor.suffix(&()), &());
 71        drop(cursor);
 72        self.0 = new_tree;
 73        removed
 74    }
 75
 76    /// Returns the key-value pair with the greatest key less than or equal to the given key.
 77    pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
 78        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 79        let key = MapKeyRef(Some(key));
 80        cursor.seek(&key, Bias::Right, &());
 81        cursor.prev(&());
 82        cursor.item().map(|item| (&item.key, &item.value))
 83    }
 84
 85    pub fn remove_between(&mut self, from: &K, until: &K) {
 86        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
 87        let from_key = MapKeyRef(Some(from));
 88        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
 89        let until_key = MapKeyRef(Some(until));
 90        cursor.seek_forward(&until_key, Bias::Left, &());
 91        new_tree.push_tree(cursor.suffix(&()), &());
 92        drop(cursor);
 93        self.0 = new_tree;
 94    }
 95
 96    pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
 97    where
 98        F: FnMut(&K, &V) -> bool,
 99    {
100        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101        let from_key = MapKeyRef(Some(from));
102        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103        while let Some(item) = cursor.item() {
104            if !f(&item.key, &item.value) {
105                break;
106            }
107            cursor.next(&());
108        }
109        new_tree.push_tree(cursor.suffix(&()), &());
110        drop(cursor);
111        self.0 = new_tree;
112    }
113
114    pub fn get_from_while<'tree, F>(
115        &'tree self,
116        from: &'tree K,
117        mut f: F,
118    ) -> impl Iterator<Item = (&K, &V)> + '_
119    where
120        F: FnMut(&K, &K, &V) -> bool + 'tree,
121    {
122        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
123        let from_key = MapKeyRef(Some(from));
124        cursor.seek(&from_key, Bias::Left, &());
125
126        iter::from_fn(move || {
127            let result = cursor.item().and_then(|item| {
128                (f(from, &item.key, &item.value)).then(|| (&item.key, &item.value))
129            });
130            cursor.next(&());
131            result
132        })
133    }
134
135    pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
136    where
137        F: FnOnce(&mut V) -> T,
138    {
139        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
140        let key = MapKeyRef(Some(key));
141        let mut new_tree = cursor.slice(&key, Bias::Left, &());
142        let mut result = None;
143        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
144            let mut updated = cursor.item().unwrap().clone();
145            result = Some(f(&mut updated.value));
146            new_tree.push(updated, &());
147            cursor.next(&());
148        }
149        new_tree.push_tree(cursor.suffix(&()), &());
150        drop(cursor);
151        self.0 = new_tree;
152        result
153    }
154
155    pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
156        let mut new_map = SumTree::<MapEntry<K, V>>::default();
157
158        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
159        cursor.next(&());
160        while let Some(item) = cursor.item() {
161            if predicate(&item.key, &item.value) {
162                new_map.push(item.clone(), &());
163            }
164            cursor.next(&());
165        }
166        drop(cursor);
167
168        self.0 = new_map;
169    }
170
171    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
172        self.0.iter().map(|entry| (&entry.key, &entry.value))
173    }
174
175    pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
176        self.0.iter().map(|entry| &entry.value)
177    }
178
179    pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
180        let edits = other
181            .iter()
182            .map(|(key, value)| {
183                Edit::Insert(MapEntry {
184                    key: key.to_owned(),
185                    value: value.to_owned(),
186                })
187            })
188            .collect();
189
190        self.0.edit(edits, &());
191    }
192}
193
194impl<K, V> Default for TreeMap<K, V>
195where
196    K: Clone + Debug + Default + Ord,
197    V: Clone + Debug,
198{
199    fn default() -> Self {
200        Self(Default::default())
201    }
202}
203
204impl<K, V> Item for MapEntry<K, V>
205where
206    K: Clone + Debug + Default + Ord,
207    V: Clone,
208{
209    type Summary = MapKey<K>;
210
211    fn summary(&self) -> Self::Summary {
212        self.key()
213    }
214}
215
216impl<K, V> KeyedItem for MapEntry<K, V>
217where
218    K: Clone + Debug + Default + Ord,
219    V: Clone,
220{
221    type Key = MapKey<K>;
222
223    fn key(&self) -> Self::Key {
224        MapKey(self.key.clone())
225    }
226}
227
228impl<K> Summary for MapKey<K>
229where
230    K: Clone + Debug + Default,
231{
232    type Context = ();
233
234    fn add_summary(&mut self, summary: &Self, _: &()) {
235        *self = summary.clone()
236    }
237}
238
239impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
240where
241    K: Clone + Debug + Default + Ord,
242{
243    fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
244        self.0 = Some(&summary.0)
245    }
246}
247
248impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
249where
250    K: Clone + Debug + Default + Ord,
251{
252    fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
253        self.0.cmp(&cursor_location.0)
254    }
255}
256
257impl<K> Default for TreeSet<K>
258where
259    K: Clone + Debug + Default + Ord,
260{
261    fn default() -> Self {
262        Self(Default::default())
263    }
264}
265
266impl<K> TreeSet<K>
267where
268    K: Clone + Debug + Default + Ord,
269{
270    pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
271        Self(TreeMap::from_ordered_entries(
272            entries.into_iter().map(|key| (key, ())),
273        ))
274    }
275
276    pub fn insert(&mut self, key: K) {
277        self.0.insert(key, ());
278    }
279
280    pub fn contains(&self, key: &K) -> bool {
281        self.0.get(key).is_some()
282    }
283
284    pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
285        self.0.iter().map(|(k, _)| k)
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        // Check the keys that were actually inserted and should now be gone.
        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with("ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_get_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // The predicate receives `(from, key, value)`.
        let result = map
            .get_from_while(&"ba", |_, key, _| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .get_from_while(&"c", |_, key, _| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, 10);

        let result = map.update(&1, |value| {
            *value += 1;
            *value
        });
        assert_eq!(result, Some(11));
        assert_eq!(map.get(&1), Some(&11));

        assert_eq!(map.update(&2, |value| *value), None);
        assert_eq!(map.get(&2), None);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::from_ordered_entries(vec![1, 3]);
        set.insert(2);

        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
    }
}