1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
4
5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
6#[derive(Clone, PartialEq, Eq)]
7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
8where
9 K: Clone + Ord,
10 V: Clone;
11
/// One key/value pair as stored inside a [`TreeMap`]'s tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// The key the tree is ordered by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
17
/// Tree summary for a run of [`MapEntry`] items: the key of the last entry
/// summarized, or `None` when nothing has been summarized yet.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapKey<K>(Option<K>);

/// The empty summary carries no key. Implemented by hand so that `K` does not
/// need to implement `Default`.
impl<K> Default for MapKey<K> {
    fn default() -> Self {
        MapKey(None)
    }
}
26
/// Borrowed form of [`MapKey`], used as the cursor dimension while seeking
/// through a [`TreeMap`]'s tree.
#[derive(Debug, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

/// The starting position carries no key. Implemented by hand so that `K` does
/// not need to implement `Default`.
impl<'a, K> Default for MapKeyRef<'a, K> {
    fn default() -> Self {
        MapKeyRef(None)
    }
}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub struct TreeSet<K>(TreeMap<K, ()>)
38where
39 K: Clone + Ord;
40
41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
42 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
43 let tree = SumTree::from_iter(
44 entries
45 .into_iter()
46 .map(|(key, value)| MapEntry { key, value }),
47 (),
48 );
49 Self(tree)
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_empty()
54 }
55
56 pub fn get(&self, key: &K) -> Option<&V> {
57 let (.., item) = self
58 .0
59 .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
60 if let Some(item) = item {
61 if Some(key) == item.key().0.as_ref() {
62 Some(&item.value)
63 } else {
64 None
65 }
66 } else {
67 None
68 }
69 }
70
71 pub fn insert(&mut self, key: K, value: V) {
72 self.0.insert_or_replace(MapEntry { key, value }, ());
73 }
74
75 pub fn insert_or_replace(&mut self, key: K, value: V) -> Option<V> {
76 self.0
77 .insert_or_replace(MapEntry { key, value }, ())
78 .map(|it| it.value)
79 }
80
81 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
82 let edits: Vec<_> = iter
83 .into_iter()
84 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
85 .collect();
86 self.0.edit(edits, ());
87 }
88
89 pub fn clear(&mut self) {
90 self.0 = SumTree::default();
91 }
92
93 pub fn remove(&mut self, key: &K) -> Option<V> {
94 let mut removed = None;
95 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
96 let key = MapKeyRef(Some(key));
97 let mut new_tree = cursor.slice(&key, Bias::Left);
98 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
99 removed = Some(cursor.item().unwrap().value.clone());
100 cursor.next();
101 }
102 new_tree.append(cursor.suffix(), ());
103 drop(cursor);
104 self.0 = new_tree;
105 removed
106 }
107
108 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
109 let start = MapSeekTargetAdaptor(start);
110 let end = MapSeekTargetAdaptor(end);
111 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
112 let mut new_tree = cursor.slice(&start, Bias::Left);
113 cursor.seek(&end, Bias::Left);
114 new_tree.append(cursor.suffix(), ());
115 drop(cursor);
116 self.0 = new_tree;
117 }
118
119 /// Returns the key-value pair with the greatest key less than or equal to the given key.
120 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
121 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
122 let key = MapKeyRef(Some(key));
123 cursor.seek(&key, Bias::Right);
124 cursor.prev();
125 cursor.item().map(|item| (&item.key, &item.value))
126 }
127
128 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
129 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
130 let from_key = MapKeyRef(Some(from));
131 cursor.seek(&from_key, Bias::Left);
132
133 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
134 }
135
136 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
137 where
138 F: FnOnce(&mut V) -> T,
139 {
140 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
141 let key = MapKeyRef(Some(key));
142 let mut new_tree = cursor.slice(&key, Bias::Left);
143 let mut result = None;
144 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
145 let mut updated = cursor.item().unwrap().clone();
146 result = Some(f(&mut updated.value));
147 new_tree.push(updated, ());
148 cursor.next();
149 }
150 new_tree.append(cursor.suffix(), ());
151 drop(cursor);
152 self.0 = new_tree;
153 result
154 }
155
156 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
157 let mut new_map = SumTree::<MapEntry<K, V>>::default();
158
159 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
160 cursor.next();
161 while let Some(item) = cursor.item() {
162 if predicate(&item.key, &item.value) {
163 new_map.push(item.clone(), ());
164 }
165 cursor.next();
166 }
167 drop(cursor);
168
169 self.0 = new_map;
170 }
171
172 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
173 self.0.iter().map(|entry| (&entry.key, &entry.value))
174 }
175
176 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
177 self.0.iter().map(|entry| &entry.value)
178 }
179
180 pub fn first(&self) -> Option<(&K, &V)> {
181 self.0.first().map(|entry| (&entry.key, &entry.value))
182 }
183
184 pub fn last(&self) -> Option<(&K, &V)> {
185 self.0.last().map(|entry| (&entry.key, &entry.value))
186 }
187
188 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
189 let edits = other
190 .iter()
191 .map(|(key, value)| {
192 Edit::Insert(MapEntry {
193 key: key.to_owned(),
194 value: value.to_owned(),
195 })
196 })
197 .collect();
198
199 self.0.edit(edits, ());
200 }
201}
202
203impl<K, V> Debug for TreeMap<K, V>
204where
205 K: Clone + Debug + Ord,
206 V: Clone + Debug,
207{
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 f.debug_map().entries(self.iter()).finish()
210 }
211}
212
213#[derive(Debug)]
214struct MapSeekTargetAdaptor<'a, T>(&'a T);
215
216impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
217 for MapSeekTargetAdaptor<'_, T>
218{
219 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
220 if let Some(key) = &cursor_location.0 {
221 MapSeekTarget::cmp_cursor(self.0, key)
222 } else {
223 Ordering::Greater
224 }
225 }
226}
227
/// A target that can be located in a [`TreeMap`] by comparing it against the
/// keys of the map's entries.
pub trait MapSeekTarget<K> {
    /// Orders this target relative to the entry key `cursor_location`.
    ///
    /// [`Ordering::Greater`] means the target lies after that key.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for itself, using its [`Ord`] ordering.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
237
238impl<K, V> Default for TreeMap<K, V>
239where
240 K: Clone + Ord,
241 V: Clone,
242{
243 fn default() -> Self {
244 Self(Default::default())
245 }
246}
247
248impl<K, V> Item for MapEntry<K, V>
249where
250 K: Clone + Ord,
251 V: Clone,
252{
253 type Summary = MapKey<K>;
254
255 fn summary(&self, _cx: ()) -> Self::Summary {
256 self.key()
257 }
258}
259
260impl<K, V> KeyedItem for MapEntry<K, V>
261where
262 K: Clone + Ord,
263 V: Clone,
264{
265 type Key = MapKey<K>;
266
267 fn key(&self) -> Self::Key {
268 MapKey(Some(self.key.clone()))
269 }
270}
271
272impl<K> ContextLessSummary for MapKey<K>
273where
274 K: Clone,
275{
276 fn zero() -> Self {
277 Default::default()
278 }
279
280 fn add_summary(&mut self, summary: &Self) {
281 *self = summary.clone()
282 }
283}
284
285impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
286where
287 K: Clone + Ord,
288{
289 fn zero(_cx: ()) -> Self {
290 Default::default()
291 }
292
293 fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
294 self.0 = summary.0.as_ref();
295 }
296}
297
298impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
299where
300 K: Clone + Ord,
301{
302 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
303 Ord::cmp(&self.0, &cursor_location.0)
304 }
305}
306
307impl<K> Default for TreeSet<K>
308where
309 K: Clone + Ord,
310{
311 fn default() -> Self {
312 Self(Default::default())
313 }
314}
315
316impl<K> TreeSet<K>
317where
318 K: Clone + Ord,
319{
320 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
321 Self(TreeMap::from_ordered_entries(
322 entries.into_iter().map(|key| (key, ())),
323 ))
324 }
325
326 pub fn is_empty(&self) -> bool {
327 self.0.is_empty()
328 }
329
330 pub fn insert(&mut self, key: K) {
331 self.0.insert(key, ());
332 }
333
334 pub fn remove(&mut self, key: &K) -> bool {
335 self.0.remove(key).is_some()
336 }
337
338 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
339 self.0.extend(iter.into_iter().map(|key| (key, ())));
340 }
341
342 pub fn contains(&self, key: &K) -> bool {
343 self.0.get(key).is_some()
344 }
345
346 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
347 self.0.iter().map(|(k, _)| k)
348 }
349
350 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
351 self.0.iter_from(key).map(move |(k, _)| k)
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_insert_or_replace_update_and_remove_values() {
        let mut map = TreeMap::default();
        assert_eq!(map.insert_or_replace(1, "a"), None);
        assert_eq!(map.insert_or_replace(1, "b"), Some("a"));
        assert_eq!(map.get(&1), Some(&"b"));

        // `update` returns the closure's result and writes the change back.
        assert_eq!(
            map.update(&1, |value| std::mem::replace(value, "c")),
            Some("b")
        );
        assert_eq!(map.get(&1), Some(&"c"));

        // Updating a missing key does not call the closure and inserts nothing.
        assert_eq!(map.update(&2, |value| *value), None);
        assert_eq!(map.get(&2), None);

        assert_eq!(map.remove(&1), Some("c"));
        assert_eq!(map.remove(&1), None);
        assert!(map.is_empty());
    }

    #[test]
    fn test_first_last_values() {
        let mut map = TreeMap::default();
        assert_eq!(map.first(), None);
        assert_eq!(map.last(), None);

        map.insert(2, "b");
        map.insert(3, "c");
        map.insert(1, "a");
        assert_eq!(map.first(), Some((&1, &"a")));
        assert_eq!(map.last(), Some((&3, &"c")));
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"a", &"b", &"c"]);
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        assert!(set.is_empty());

        set.insert(3);
        set.insert(1);
        set.insert(2);
        set.insert(2);
        assert!(!set.is_empty());
        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
        assert_eq!(set.iter_from(&2).collect::<Vec<_>>(), vec![&2, &3]);

        assert!(set.remove(&2));
        assert!(!set.remove(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}