1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
4
5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
6#[derive(Clone, PartialEq, Eq)]
7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
8where
9 K: Clone + Ord,
10 V: Clone;
11
/// A single key/value pair as stored inside a [`TreeMap`]'s tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
17
/// Owned summary type of a [`TreeMap`]'s tree: the key of an entry, or
/// `None` for the empty/zero summary.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);

impl<K> Default for MapKey<K> {
    /// The keyless summary.
    fn default() -> Self {
        MapKey(None)
    }
}
26
/// Borrowed counterpart of [`MapKey`], used as the cursor dimension when
/// seeking through a [`TreeMap`]'s tree. `None` is the starting position.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

impl<'a, K> Default for MapKeyRef<'a, K> {
    /// The keyless starting position.
    fn default() -> Self {
        MapKeyRef(None)
    }
}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub struct TreeSet<K>(TreeMap<K, ()>)
38where
39 K: Clone + Ord;
40
41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
42 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
43 let tree = SumTree::from_iter(
44 entries
45 .into_iter()
46 .map(|(key, value)| MapEntry { key, value }),
47 (),
48 );
49 Self(tree)
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_empty()
54 }
55
56 pub fn get(&self, key: &K) -> Option<&V> {
57 let (.., item) = self
58 .0
59 .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
60 if let Some(item) = item {
61 if Some(key) == item.key().0.as_ref() {
62 Some(&item.value)
63 } else {
64 None
65 }
66 } else {
67 None
68 }
69 }
70
71 pub fn insert(&mut self, key: K, value: V) {
72 self.0.insert_or_replace(MapEntry { key, value }, ());
73 }
74
75 pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
76 self.0
77 .insert_or_replace(MapEntry { key, value }, ())
78 .map(|it| it.value)
79 }
80
81 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
82 let edits: Vec<_> = iter
83 .into_iter()
84 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
85 .collect();
86 if edits.is_empty() {
87 return;
88 }
89 self.0.edit(edits, ());
90 }
91
92 pub fn clear(&mut self) {
93 self.0 = SumTree::default();
94 }
95
96 pub fn remove(&mut self, key: &K) -> Option<V> {
97 let mut removed = None;
98 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
99 let key = MapKeyRef(Some(key));
100 let mut new_tree = cursor.slice(&key, Bias::Left);
101 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
102 removed = Some(cursor.item().unwrap().value.clone());
103 cursor.next();
104 }
105 new_tree.append(cursor.suffix(), ());
106 drop(cursor);
107 self.0 = new_tree;
108 removed
109 }
110
111 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
112 let start = MapSeekTargetAdaptor(start);
113 let end = MapSeekTargetAdaptor(end);
114 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
115 let mut new_tree = cursor.slice(&start, Bias::Left);
116 cursor.seek(&end, Bias::Left);
117 new_tree.append(cursor.suffix(), ());
118 drop(cursor);
119 self.0 = new_tree;
120 }
121
122 /// Returns the key-value pair with the greatest key less than or equal to the given key.
123 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
124 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
125 let key = MapKeyRef(Some(key));
126 cursor.seek(&key, Bias::Right);
127 cursor.prev();
128 cursor.item().map(|item| (&item.key, &item.value))
129 }
130
131 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
132 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
133 let from_key = MapKeyRef(Some(from));
134 cursor.seek(&from_key, Bias::Left);
135
136 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
137 }
138
139 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
140 where
141 F: FnOnce(&mut V) -> T,
142 {
143 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
144 let key = MapKeyRef(Some(key));
145 let mut new_tree = cursor.slice(&key, Bias::Left);
146 let mut result = None;
147 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
148 let mut updated = cursor.item().unwrap().clone();
149 result = Some(f(&mut updated.value));
150 new_tree.push(updated, ());
151 cursor.next();
152 }
153 new_tree.append(cursor.suffix(), ());
154 drop(cursor);
155 self.0 = new_tree;
156 result
157 }
158
159 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
160 let mut new_map = SumTree::<MapEntry<K, V>>::default();
161
162 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
163 cursor.next();
164 while let Some(item) = cursor.item() {
165 if predicate(&item.key, &item.value) {
166 new_map.push(item.clone(), ());
167 }
168 cursor.next();
169 }
170 drop(cursor);
171
172 self.0 = new_map;
173 }
174
175 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
176 self.0.iter().map(|entry| (&entry.key, &entry.value))
177 }
178
179 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
180 self.0.iter().map(|entry| &entry.value)
181 }
182
183 pub fn first(&self) -> Option<(&K, &V)> {
184 self.0.first().map(|entry| (&entry.key, &entry.value))
185 }
186
187 pub fn last(&self) -> Option<(&K, &V)> {
188 self.0.last().map(|entry| (&entry.key, &entry.value))
189 }
190
191 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
192 let edits = other
193 .iter()
194 .map(|(key, value)| {
195 Edit::Insert(MapEntry {
196 key: key.to_owned(),
197 value: value.to_owned(),
198 })
199 })
200 .collect();
201
202 self.0.edit(edits, ());
203 }
204}
205
206impl<K, V> Debug for TreeMap<K, V>
207where
208 K: Clone + Debug + Ord,
209 V: Clone + Debug,
210{
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 f.debug_map().entries(self.iter()).finish()
213 }
214}
215
216#[derive(Debug)]
217struct MapSeekTargetAdaptor<'a, T>(&'a T);
218
219impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
220 for MapSeekTargetAdaptor<'_, T>
221{
222 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
223 if let Some(key) = &cursor_location.0 {
224 MapSeekTarget::cmp_cursor(self.0, key)
225 } else {
226 Ordering::Greater
227 }
228 }
229}
230
/// Something a [`TreeMap`] cursor can seek to, judged against the keys it
/// passes over.
pub trait MapSeekTarget<K> {
    /// Orders this target relative to the key at `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for keys of its own type, using its
/// natural ordering.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
240
241impl<K, V> Default for TreeMap<K, V>
242where
243 K: Clone + Ord,
244 V: Clone,
245{
246 fn default() -> Self {
247 Self(Default::default())
248 }
249}
250
251impl<K, V> Item for MapEntry<K, V>
252where
253 K: Clone + Ord,
254 V: Clone,
255{
256 type Summary = MapKey<K>;
257
258 fn summary(&self, _cx: ()) -> Self::Summary {
259 self.key()
260 }
261}
262
263impl<K, V> KeyedItem for MapEntry<K, V>
264where
265 K: Clone + Ord,
266 V: Clone,
267{
268 type Key = MapKey<K>;
269
270 fn key(&self) -> Self::Key {
271 MapKey(Some(self.key.clone()))
272 }
273}
274
275impl<K> ContextLessSummary for MapKey<K>
276where
277 K: Clone,
278{
279 fn zero() -> Self {
280 Default::default()
281 }
282
283 fn add_summary(&mut self, summary: &Self) {
284 *self = summary.clone()
285 }
286}
287
288impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
289where
290 K: Clone + Ord,
291{
292 fn zero(_cx: ()) -> Self {
293 Default::default()
294 }
295
296 fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
297 self.0 = summary.0.as_ref();
298 }
299}
300
301impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
302where
303 K: Clone + Ord,
304{
305 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
306 Ord::cmp(&self.0, &cursor_location.0)
307 }
308}
309
310impl<K> Default for TreeSet<K>
311where
312 K: Clone + Ord,
313{
314 fn default() -> Self {
315 Self(Default::default())
316 }
317}
318
319impl<K> TreeSet<K>
320where
321 K: Clone + Ord,
322{
323 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
324 Self(TreeMap::from_ordered_entries(
325 entries.into_iter().map(|key| (key, ())),
326 ))
327 }
328
329 pub fn is_empty(&self) -> bool {
330 self.0.is_empty()
331 }
332
333 pub fn insert(&mut self, key: K) {
334 self.0.insert(key, ());
335 }
336
337 pub fn remove(&mut self, key: &K) -> bool {
338 self.0.remove(key).is_some()
339 }
340
341 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
342 self.0.extend(iter.into_iter().map(|key| (key, ())));
343 }
344
345 pub fn contains(&self, key: &K) -> bool {
346 self.0.get(key).is_some()
347 }
348
349 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
350 self.0.iter().map(|(k, _)| k)
351 }
352
353 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
354 self.0.iter_from(key).map(move |(k, _)| k)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of a map's entries, in iteration order, for comparisons.
    fn collect_entries<K: Clone + Ord, V: Clone>(map: &TreeMap<K, V>) -> Vec<(&K, &V)> {
        map.iter().collect()
    }

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(collect_entries(&map), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(collect_entries(&map), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(collect_entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        for (key, value) in [(2, "b"), (1, "a"), (3, "c")] {
            assert_eq!(map.get(&key), Some(&value));
        }
        assert_eq!(
            collect_entries(&map),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(collect_entries(&map), vec![(&1, &"a"), (&3, &"c")]);
        // With 2 gone, the closest key at or below 2 is 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(collect_entries(&map), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(collect_entries(&map), vec![]);

        for (key, value) in [(4, "d"), (5, "e"), (6, "f")] {
            map.insert(key, value);
        }
        map.retain(|key, _| key % 2 == 0);
        assert_eq!(collect_entries(&map), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("baa", 3), ("baaab", 4), ("c", 5)] {
            map.insert(key, value);
        }

        let ba: Vec<_> = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect();
        assert_eq!(ba.len(), 2);
        assert!(ba.iter().any(|(key, _)| **key == "baa"));
        assert!(ba.iter().any(|(key, _)| **key == "baaab"));

        let c: Vec<_> = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect();
        assert_eq!(c.len(), 1);
        assert!(c.iter().any(|(key, _)| **key == "c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            map.insert(key, value);
        }

        let mut other = TreeMap::default();
        for (key, value) in [("a", 2), ("b", 2), ("d", 4)] {
            other.insert(key, value);
        }

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(map.get(&key), Some(&expected));
        }
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            map.insert(key, value);
        }
        map.extend([("a", 2), ("b", 2), ("d", 4)]);

        assert_eq!(map.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(map.get(&key), Some(&expected));
        }
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                // Report every descendant of the target path as lying before
                // it, so a seek to this target skips the whole subtree.
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        /// Asserts the value (or absence) of each listed path in `map`.
        fn check(map: &TreeMap<PathBuf, i32>, expected: &[(&str, Option<i32>)]) {
            for &(name, value) in expected {
                assert_eq!(
                    map.get(&PathBuf::from(name)).copied(),
                    value,
                    "unexpected value for {:?}",
                    name
                );
            }
        }

        let path = |name: &str| PathBuf::from(name);

        let mut map = TreeMap::default();
        for (name, value) in [
            ("a", 1),
            ("a/a", 1),
            ("b", 2),
            ("b/a/a", 3),
            ("b/a/a/a/b", 4),
            ("c", 5),
            ("c/a", 6),
        ] {
            map.insert(path(name), value);
        }

        map.remove_range(&path("b/a"), &PathDescendants(&path("b/a")));
        check(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("b/a/a", None),
                ("b/a/a/a/b", None),
                ("c", Some(5)),
                ("c/a", Some(6)),
            ],
        );

        map.remove_range(&path("c"), &PathDescendants(&path("c")));
        check(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("c", None),
                ("c/a", None),
            ],
        );

        map.remove_range(&path("a"), &PathDescendants(&path("a")));
        check(&map, &[("a", None), ("a/a", None), ("b", Some(2))]);

        map.remove_range(&path("b"), &PathDescendants(&path("b")));
        check(&map, &[("b", None)]);
    }
}