1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
4
/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
///
/// Entries are stored as [`MapEntry`] items whose tree summary is their key,
/// so lookups, removals, and range operations locate entries by seeking the
/// tree on the key rather than by scanning every entry.
#[derive(Clone, PartialEq, Eq)]
pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
where
    K: Clone + Ord,
    V: Clone;
11
/// A single key/value pair stored in a [`TreeMap`]'s underlying tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// The entry's key; it is also the entry's tree summary, so it decides
    /// where the entry sits in the tree.
    key: K,
    /// The value associated with `key`.
    value: V,
}
17
/// Tree summary for [`MapEntry`] items.
///
/// Holds the last key of the summarized entries, which is the greatest one
/// because entries are kept in key order. It is `None` when no entry has been
/// summarized (the zero summary).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);
20
21impl<K> Default for MapKey<K> {
22 fn default() -> Self {
23 Self(None)
24 }
25}
26
/// Borrowed counterpart of [`MapKey`], used as the cursor dimension when
/// seeking through a [`TreeMap`].
///
/// It holds a reference to the last key the cursor has passed, or `None`
/// before any entry.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
29
30impl<K> Default for MapKeyRef<'_, K> {
31 fn default() -> Self {
32 Self(None)
33 }
34}
35
/// An ordered set of keys, stored as a [`TreeMap`] whose values are all `()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TreeSet<K>(TreeMap<K, ()>)
where
    K: Clone + Ord;
40
41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
42 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
43 let tree = SumTree::from_iter(
44 entries
45 .into_iter()
46 .map(|(key, value)| MapEntry { key, value }),
47 (),
48 );
49 Self(tree)
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_empty()
54 }
55
56 pub fn contains_key(&self, key: &K) -> bool {
57 self.get(key).is_some()
58 }
59
60 pub fn get(&self, key: &K) -> Option<&V> {
61 let (.., item) = self
62 .0
63 .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
64 if let Some(item) = item {
65 if Some(key) == item.key().0.as_ref() {
66 Some(&item.value)
67 } else {
68 None
69 }
70 } else {
71 None
72 }
73 }
74
75 pub fn insert(&mut self, key: K, value: V) {
76 self.0.insert_or_replace(MapEntry { key, value }, ());
77 }
78
79 pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
80 self.0
81 .insert_or_replace(MapEntry { key, value }, ())
82 .map(|it| it.value)
83 }
84
85 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
86 let edits: Vec<_> = iter
87 .into_iter()
88 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
89 .collect();
90 self.0.edit(edits, ());
91 }
92
93 pub fn clear(&mut self) {
94 self.0 = SumTree::default();
95 }
96
97 pub fn remove(&mut self, key: &K) -> Option<V> {
98 let mut removed = None;
99 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
100 let key = MapKeyRef(Some(key));
101 let mut new_tree = cursor.slice(&key, Bias::Left);
102 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
103 removed = Some(cursor.item().unwrap().value.clone());
104 cursor.next();
105 }
106 new_tree.append(cursor.suffix(), ());
107 drop(cursor);
108 self.0 = new_tree;
109 removed
110 }
111
112 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
113 let start = MapSeekTargetAdaptor(start);
114 let end = MapSeekTargetAdaptor(end);
115 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
116 let mut new_tree = cursor.slice(&start, Bias::Left);
117 cursor.seek(&end, Bias::Left);
118 new_tree.append(cursor.suffix(), ());
119 drop(cursor);
120 self.0 = new_tree;
121 }
122
123 /// Returns the key-value pair with the greatest key less than or equal to the given key.
124 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
125 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
126 let key = MapKeyRef(Some(key));
127 cursor.seek(&key, Bias::Right);
128 cursor.prev();
129 cursor.item().map(|item| (&item.key, &item.value))
130 }
131
132 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
133 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
134 let from_key = MapKeyRef(Some(from));
135 cursor.seek(&from_key, Bias::Left);
136
137 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
138 }
139
140 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
141 where
142 F: FnOnce(&mut V) -> T,
143 {
144 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
145 let key = MapKeyRef(Some(key));
146 let mut new_tree = cursor.slice(&key, Bias::Left);
147 let mut result = None;
148 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
149 let mut updated = cursor.item().unwrap().clone();
150 result = Some(f(&mut updated.value));
151 new_tree.push(updated, ());
152 cursor.next();
153 }
154 new_tree.append(cursor.suffix(), ());
155 drop(cursor);
156 self.0 = new_tree;
157 result
158 }
159
160 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
161 let mut new_map = SumTree::<MapEntry<K, V>>::default();
162
163 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
164 cursor.next();
165 while let Some(item) = cursor.item() {
166 if predicate(&item.key, &item.value) {
167 new_map.push(item.clone(), ());
168 }
169 cursor.next();
170 }
171 drop(cursor);
172
173 self.0 = new_map;
174 }
175
176 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
177 self.0.iter().map(|entry| (&entry.key, &entry.value))
178 }
179
180 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
181 self.0.iter().map(|entry| &entry.value)
182 }
183
184 pub fn first(&self) -> Option<(&K, &V)> {
185 self.0.first().map(|entry| (&entry.key, &entry.value))
186 }
187
188 pub fn last(&self) -> Option<(&K, &V)> {
189 self.0.last().map(|entry| (&entry.key, &entry.value))
190 }
191
192 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
193 let edits = other
194 .iter()
195 .map(|(key, value)| {
196 Edit::Insert(MapEntry {
197 key: key.to_owned(),
198 value: value.to_owned(),
199 })
200 })
201 .collect();
202
203 self.0.edit(edits, ());
204 }
205}
206
207impl<K, V> Debug for TreeMap<K, V>
208where
209 K: Clone + Debug + Ord,
210 V: Clone + Debug,
211{
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 f.debug_map().entries(self.iter()).finish()
214 }
215}
216
/// Adapts a borrowed [`MapSeekTarget`] into a [`SeekTarget`] over
/// [`MapKeyRef`] cursor positions, so custom targets can drive a tree cursor
/// (as [`TreeMap::remove_range`] does).
#[derive(Debug)]
struct MapSeekTargetAdaptor<'a, T>(&'a T);
219
220impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
221 for MapSeekTargetAdaptor<'_, T>
222{
223 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
224 if let Some(key) = &cursor_location.0 {
225 MapSeekTarget::cmp_cursor(self.0, key)
226 } else {
227 Ordering::Greater
228 }
229 }
230}
231
/// Something that can be positioned within a [`TreeMap`] by comparing it
/// against the map's keys.
///
/// Used for the endpoints of [`TreeMap::remove_range`]. Every `K: Ord` is a
/// target for itself through the blanket impl in this module.
pub trait MapSeekTarget<K> {
    /// Returns how this target orders relative to `cursor_location`, a key in
    /// the map. For a plain key this is `self.cmp(cursor_location)`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}
235
236impl<K: Ord> MapSeekTarget<K> for K {
237 fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
238 self.cmp(cursor_location)
239 }
240}
241
242impl<K, V> Default for TreeMap<K, V>
243where
244 K: Clone + Ord,
245 V: Clone,
246{
247 fn default() -> Self {
248 Self(Default::default())
249 }
250}
251
252impl<K, V> Item for MapEntry<K, V>
253where
254 K: Clone + Ord,
255 V: Clone,
256{
257 type Summary = MapKey<K>;
258
259 fn summary(&self, _cx: ()) -> Self::Summary {
260 self.key()
261 }
262}
263
264impl<K, V> KeyedItem for MapEntry<K, V>
265where
266 K: Clone + Ord,
267 V: Clone,
268{
269 type Key = MapKey<K>;
270
271 fn key(&self) -> Self::Key {
272 MapKey(Some(self.key.clone()))
273 }
274}
275
276impl<K> ContextLessSummary for MapKey<K>
277where
278 K: Clone,
279{
280 fn zero() -> Self {
281 Default::default()
282 }
283
284 fn add_summary(&mut self, summary: &Self) {
285 *self = summary.clone()
286 }
287}
288
289impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
290where
291 K: Clone + Ord,
292{
293 fn zero(_cx: ()) -> Self {
294 Default::default()
295 }
296
297 fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
298 self.0 = summary.0.as_ref();
299 }
300}
301
302impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
303where
304 K: Clone + Ord,
305{
306 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
307 Ord::cmp(&self.0, &cursor_location.0)
308 }
309}
310
311impl<K> Default for TreeSet<K>
312where
313 K: Clone + Ord,
314{
315 fn default() -> Self {
316 Self(Default::default())
317 }
318}
319
320impl<K> TreeSet<K>
321where
322 K: Clone + Ord,
323{
324 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
325 Self(TreeMap::from_ordered_entries(
326 entries.into_iter().map(|key| (key, ())),
327 ))
328 }
329
330 pub fn is_empty(&self) -> bool {
331 self.0.is_empty()
332 }
333
334 pub fn insert(&mut self, key: K) {
335 self.0.insert(key, ());
336 }
337
338 pub fn remove(&mut self, key: &K) -> bool {
339 self.0.remove(key).is_some()
340 }
341
342 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
343 self.0.extend(iter.into_iter().map(|key| (key, ())));
344 }
345
346 pub fn contains(&self, key: &K) -> bool {
347 self.0.get(key).is_some()
348 }
349
350 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
351 self.0.iter().map(|(k, _)| k)
352 }
353
354 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
355 self.0.iter_from(key).map(move |(k, _)| k)
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        /// Seek target that sorts after `self.0` and all of its descendants.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }

    /// `update`, `insert_or_replace`, and `remove` report the values they
    /// touch; `first`/`last` expose the extreme entries.
    #[test]
    fn test_update_replace_and_remove_return_values() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(3, "c");

        // Updating a present key runs the closure and returns its result.
        assert_eq!(
            map.update(&1, |value| {
                *value = "z";
                7
            }),
            Some(7)
        );
        assert_eq!(map.get(&1), Some(&"z"));

        // Updating a missing key leaves the map untouched.
        assert_eq!(
            map.update(&2, |value| {
                *value = "y";
                8
            }),
            None
        );
        assert_eq!(map.get(&2), None);

        assert_eq!(map.insert_or_replace(3, "d"), Some("c"));
        assert_eq!(map.insert_or_replace(4, "e"), None);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"z"), (&3, &"d"), (&4, &"e")]
        );

        assert_eq!(map.first(), Some((&1, &"z")));
        assert_eq!(map.last(), Some((&4, &"e")));

        assert_eq!(map.remove(&3), Some("d"));
        assert_eq!(map.remove(&3), None);
        assert!(!map.contains_key(&3));
    }

    #[test]
    fn test_from_ordered_entries() {
        let map = TreeMap::from_ordered_entries([(1, "a"), (2, "b"), (3, "c")]);
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);
        assert!(!map.is_empty());
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        assert!(set.is_empty());

        set.insert(3);
        set.insert(1);
        set.extend([2, 5]);

        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3, 5]);
        assert_eq!(set.iter_from(&3).copied().collect::<Vec<_>>(), vec![3, 5]);

        assert!(set.remove(&2));
        assert!(!set.remove(&2));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 3, 5]);
    }
}