1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
///
/// Each entry is stored as a [`MapEntry`] item whose summary is its own key,
/// which lets the tree be searched by key. Equality compares the underlying
/// trees.
#[derive(Clone, PartialEq, Eq)]
pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
where
    K: Clone + Ord,
    V: Clone;
11
/// A single key/value pair stored as an item in a [`TreeMap`]'s tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered and looked up by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
17
/// Tree summary for [`MapEntry`] items: the key of the most recently folded-in
/// entry, or `None` when nothing has been summarized yet.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);

impl<K> Default for MapKey<K> {
    /// Returns the empty summary, `MapKey(None)`.
    ///
    /// Implemented by hand rather than derived so that `K` does not need to
    /// implement `Default`.
    fn default() -> Self {
        let nothing: Option<K> = Option::None;
        MapKey(nothing)
    }
}
26
/// Borrowed counterpart of [`MapKey`]: a cursor location expressed as a
/// reference to the key, or `None` for a location with no key.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

impl<K> Default for MapKeyRef<'_, K> {
    /// Returns a location with no key.
    fn default() -> Self {
        let nothing: Option<&K> = Option::None;
        MapKeyRef(nothing)
    }
}
35
/// An ordered set built on [`TreeMap`]; each element is stored as a key that
/// maps to `()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TreeSet<K>(TreeMap<K, ()>)
where
    K: Clone + Ord;
40
41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
42 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
43 let tree = SumTree::from_iter(
44 entries
45 .into_iter()
46 .map(|(key, value)| MapEntry { key, value }),
47 &(),
48 );
49 Self(tree)
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_empty()
54 }
55
56 pub fn get(&self, key: &K) -> Option<&V> {
57 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
58 cursor.seek(&MapKeyRef(Some(key)), Bias::Left);
59 if let Some(item) = cursor.item() {
60 if Some(key) == item.key().0.as_ref() {
61 Some(&item.value)
62 } else {
63 None
64 }
65 } else {
66 None
67 }
68 }
69
70 pub fn insert(&mut self, key: K, value: V) {
71 self.0.insert_or_replace(MapEntry { key, value }, &());
72 }
73
74 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
75 let edits: Vec<_> = iter
76 .into_iter()
77 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
78 .collect();
79 self.0.edit(edits, &());
80 }
81
82 pub fn clear(&mut self) {
83 self.0 = SumTree::default();
84 }
85
86 pub fn remove(&mut self, key: &K) -> Option<V> {
87 let mut removed = None;
88 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
89 let key = MapKeyRef(Some(key));
90 let mut new_tree = cursor.slice(&key, Bias::Left);
91 if key.cmp(&cursor.end(), &()) == Ordering::Equal {
92 removed = Some(cursor.item().unwrap().value.clone());
93 cursor.next();
94 }
95 new_tree.append(cursor.suffix(), &());
96 drop(cursor);
97 self.0 = new_tree;
98 removed
99 }
100
101 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
102 let start = MapSeekTargetAdaptor(start);
103 let end = MapSeekTargetAdaptor(end);
104 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
105 let mut new_tree = cursor.slice(&start, Bias::Left);
106 cursor.seek(&end, Bias::Left);
107 new_tree.append(cursor.suffix(), &());
108 drop(cursor);
109 self.0 = new_tree;
110 }
111
112 /// Returns the key-value pair with the greatest key less than or equal to the given key.
113 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
114 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
115 let key = MapKeyRef(Some(key));
116 cursor.seek(&key, Bias::Right);
117 cursor.prev();
118 cursor.item().map(|item| (&item.key, &item.value))
119 }
120
121 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
122 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
123 let from_key = MapKeyRef(Some(from));
124 cursor.seek(&from_key, Bias::Left);
125
126 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
127 }
128
129 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
130 where
131 F: FnOnce(&mut V) -> T,
132 {
133 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
134 let key = MapKeyRef(Some(key));
135 let mut new_tree = cursor.slice(&key, Bias::Left);
136 let mut result = None;
137 if key.cmp(&cursor.end(), &()) == Ordering::Equal {
138 let mut updated = cursor.item().unwrap().clone();
139 result = Some(f(&mut updated.value));
140 new_tree.push(updated, &());
141 cursor.next();
142 }
143 new_tree.append(cursor.suffix(), &());
144 drop(cursor);
145 self.0 = new_tree;
146 result
147 }
148
149 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
150 let mut new_map = SumTree::<MapEntry<K, V>>::default();
151
152 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
153 cursor.next();
154 while let Some(item) = cursor.item() {
155 if predicate(&item.key, &item.value) {
156 new_map.push(item.clone(), &());
157 }
158 cursor.next();
159 }
160 drop(cursor);
161
162 self.0 = new_map;
163 }
164
165 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
166 self.0.iter().map(|entry| (&entry.key, &entry.value))
167 }
168
169 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
170 self.0.iter().map(|entry| &entry.value)
171 }
172
173 pub fn first(&self) -> Option<(&K, &V)> {
174 self.0.first().map(|entry| (&entry.key, &entry.value))
175 }
176
177 pub fn last(&self) -> Option<(&K, &V)> {
178 self.0.last().map(|entry| (&entry.key, &entry.value))
179 }
180
181 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
182 let edits = other
183 .iter()
184 .map(|(key, value)| {
185 Edit::Insert(MapEntry {
186 key: key.to_owned(),
187 value: value.to_owned(),
188 })
189 })
190 .collect();
191
192 self.0.edit(edits, &());
193 }
194}
195
196impl<K, V> Debug for TreeMap<K, V>
197where
198 K: Clone + Debug + Ord,
199 V: Clone + Debug,
200{
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_map().entries(self.iter()).finish()
203 }
204}
205
206#[derive(Debug)]
207struct MapSeekTargetAdaptor<'a, T>(&'a T);
208
209impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
210 for MapSeekTargetAdaptor<'_, T>
211{
212 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
213 if let Some(key) = &cursor_location.0 {
214 MapSeekTarget::cmp_cursor(self.0, key)
215 } else {
216 Ordering::Greater
217 }
218 }
219}
220
/// A value that can be compared against the keys of a [`TreeMap`] while
/// seeking, such as the bounds passed to `TreeMap::remove_range`.
pub trait MapSeekTarget<K> {
    /// Compares this target with the key at the cursor's current location.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for maps keyed by its own type, using its
/// natural ordering.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
230
231impl<K, V> Default for TreeMap<K, V>
232where
233 K: Clone + Ord,
234 V: Clone,
235{
236 fn default() -> Self {
237 Self(Default::default())
238 }
239}
240
241impl<K, V> Item for MapEntry<K, V>
242where
243 K: Clone + Ord,
244 V: Clone,
245{
246 type Summary = MapKey<K>;
247
248 fn summary(&self, _cx: &()) -> Self::Summary {
249 self.key()
250 }
251}
252
253impl<K, V> KeyedItem for MapEntry<K, V>
254where
255 K: Clone + Ord,
256 V: Clone,
257{
258 type Key = MapKey<K>;
259
260 fn key(&self) -> Self::Key {
261 MapKey(Some(self.key.clone()))
262 }
263}
264
265impl<K> Summary for MapKey<K>
266where
267 K: Clone,
268{
269 type Context = ();
270
271 fn zero(_cx: &()) -> Self {
272 Default::default()
273 }
274
275 fn add_summary(&mut self, summary: &Self, _: &()) {
276 *self = summary.clone()
277 }
278}
279
280impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
281where
282 K: Clone + Ord,
283{
284 fn zero(_cx: &()) -> Self {
285 Default::default()
286 }
287
288 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
289 self.0 = summary.0.as_ref();
290 }
291}
292
293impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
294where
295 K: Clone + Ord,
296{
297 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
298 Ord::cmp(&self.0, &cursor_location.0)
299 }
300}
301
302impl<K> Default for TreeSet<K>
303where
304 K: Clone + Ord,
305{
306 fn default() -> Self {
307 Self(Default::default())
308 }
309}
310
311impl<K> TreeSet<K>
312where
313 K: Clone + Ord,
314{
315 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
316 Self(TreeMap::from_ordered_entries(
317 entries.into_iter().map(|key| (key, ())),
318 ))
319 }
320
321 pub fn is_empty(&self) -> bool {
322 self.0.is_empty()
323 }
324
325 pub fn insert(&mut self, key: K) {
326 self.0.insert(key, ());
327 }
328
329 pub fn remove(&mut self, key: &K) -> bool {
330 self.0.remove(key).is_some()
331 }
332
333 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
334 self.0.extend(iter.into_iter().map(|key| (key, ())));
335 }
336
337 pub fn contains(&self, key: &K) -> bool {
338 self.0.get(key).is_some()
339 }
340
341 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
342 self.0.iter().map(|(k, _)| k)
343 }
344
345 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
346 self.0.iter_from(key).map(move |(k, _)| k)
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks one integer-keyed map through insert, get, iter, closest, remove
    /// and retain, checking its full contents after each mutation.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        // Keys are inserted out of order; iteration must still be ascending.
        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the greatest key <= the query, or None if every key
        // is larger than the query.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 removed, the closest key at or below 2 is 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only the even keys.
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    /// `iter_from` starts at the first key not less than the query, even when
    /// the query itself is absent ("ba"). Callers bound a prefix scan with
    /// `take_while`.
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        // The query key itself is included when present.
        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    /// Entries from `other` overwrite shared keys ("a"); keys found only in
    /// `self` ("c") or only in `other` ("d") all end up in the result.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    /// Same expectations as `test_insert_tree`, but the new pairs come from a
    /// plain iterator passed to `extend`.
    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    /// Uses `remove_range` with a custom `MapSeekTarget` as the end bound to
    /// delete a path together with all of its descendants.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        /// Seek target that reports every descendant of the wrapped path as
        /// lying before it (`Greater`). As an end bound, it therefore carries
        /// removal past the whole subtree.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // "b/a" itself is absent, but its descendants must still be removed
        // while the sibling and parent entries survive.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Removing a subtree at the end of the map.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Removing a subtree at the start of the map.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Removing the last remaining entry.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}