1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, ContextLessSummary, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree};
4
5/// A cheaply-cloneable ordered map based on a [SumTree](crate::SumTree).
6#[derive(Clone, PartialEq, Eq)]
7pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
8where
9 K: Clone + Ord,
10 V: Clone;
11
/// A single key/value pair stored in a [`TreeMap`]'s underlying tree.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
17
/// The summary of a [`TreeMap`] subtree: the key of its last entry, or `None`
/// when no entry has been summarized yet.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub struct MapKey<K>(Option<K>);

impl<K> Default for MapKey<K> {
    /// Returns the empty key (`None`), which orders before every `Some` key.
    /// Implemented by hand so that `K` is not required to be `Default`.
    fn default() -> Self {
        MapKey(None)
    }
}
26
/// A borrowed counterpart of [`MapKey`]. It is used as the cursor dimension
/// when seeking, so positions can be tracked without cloning keys.
#[derive(Debug, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

impl<'a, K> Default for MapKeyRef<'a, K> {
    /// The starting position: no key has been passed yet.
    fn default() -> Self {
        MapKeyRef(None)
    }
}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub struct TreeSet<K>(TreeMap<K, ()>)
38where
39 K: Clone + Ord;
40
41impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
42 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
43 let tree = SumTree::from_iter(
44 entries
45 .into_iter()
46 .map(|(key, value)| MapEntry { key, value }),
47 (),
48 );
49 Self(tree)
50 }
51
52 pub fn is_empty(&self) -> bool {
53 self.0.is_empty()
54 }
55
56 pub fn get(&self, key: &K) -> Option<&V> {
57 let (.., item) = self
58 .0
59 .find::<MapKeyRef<'_, K>, _>((), &MapKeyRef(Some(key)), Bias::Left);
60 if let Some(item) = item {
61 if Some(key) == item.key().0.as_ref() {
62 Some(&item.value)
63 } else {
64 None
65 }
66 } else {
67 None
68 }
69 }
70
71 pub fn insert(&mut self, key: K, value: V) {
72 self.0.insert_or_replace(MapEntry { key, value }, ());
73 }
74
75 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
76 let edits: Vec<_> = iter
77 .into_iter()
78 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
79 .collect();
80 self.0.edit(edits, ());
81 }
82
83 pub fn clear(&mut self) {
84 self.0 = SumTree::default();
85 }
86
87 pub fn remove(&mut self, key: &K) -> Option<V> {
88 let mut removed = None;
89 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
90 let key = MapKeyRef(Some(key));
91 let mut new_tree = cursor.slice(&key, Bias::Left);
92 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
93 removed = Some(cursor.item().unwrap().value.clone());
94 cursor.next();
95 }
96 new_tree.append(cursor.suffix(), ());
97 drop(cursor);
98 self.0 = new_tree;
99 removed
100 }
101
102 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
103 let start = MapSeekTargetAdaptor(start);
104 let end = MapSeekTargetAdaptor(end);
105 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
106 let mut new_tree = cursor.slice(&start, Bias::Left);
107 cursor.seek(&end, Bias::Left);
108 new_tree.append(cursor.suffix(), ());
109 drop(cursor);
110 self.0 = new_tree;
111 }
112
113 /// Returns the key-value pair with the greatest key less than or equal to the given key.
114 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
115 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
116 let key = MapKeyRef(Some(key));
117 cursor.seek(&key, Bias::Right);
118 cursor.prev();
119 cursor.item().map(|item| (&item.key, &item.value))
120 }
121
122 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
123 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
124 let from_key = MapKeyRef(Some(from));
125 cursor.seek(&from_key, Bias::Left);
126
127 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
128 }
129
130 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
131 where
132 F: FnOnce(&mut V) -> T,
133 {
134 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
135 let key = MapKeyRef(Some(key));
136 let mut new_tree = cursor.slice(&key, Bias::Left);
137 let mut result = None;
138 if key.cmp(&cursor.end(), ()) == Ordering::Equal {
139 let mut updated = cursor.item().unwrap().clone();
140 result = Some(f(&mut updated.value));
141 new_tree.push(updated, ());
142 cursor.next();
143 }
144 new_tree.append(cursor.suffix(), ());
145 drop(cursor);
146 self.0 = new_tree;
147 result
148 }
149
150 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
151 let mut new_map = SumTree::<MapEntry<K, V>>::default();
152
153 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(());
154 cursor.next();
155 while let Some(item) = cursor.item() {
156 if predicate(&item.key, &item.value) {
157 new_map.push(item.clone(), ());
158 }
159 cursor.next();
160 }
161 drop(cursor);
162
163 self.0 = new_map;
164 }
165
166 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
167 self.0.iter().map(|entry| (&entry.key, &entry.value))
168 }
169
170 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
171 self.0.iter().map(|entry| &entry.value)
172 }
173
174 pub fn first(&self) -> Option<(&K, &V)> {
175 self.0.first().map(|entry| (&entry.key, &entry.value))
176 }
177
178 pub fn last(&self) -> Option<(&K, &V)> {
179 self.0.last().map(|entry| (&entry.key, &entry.value))
180 }
181
182 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
183 let edits = other
184 .iter()
185 .map(|(key, value)| {
186 Edit::Insert(MapEntry {
187 key: key.to_owned(),
188 value: value.to_owned(),
189 })
190 })
191 .collect();
192
193 self.0.edit(edits, ());
194 }
195}
196
197impl<K, V> Debug for TreeMap<K, V>
198where
199 K: Clone + Debug + Ord,
200 V: Clone + Debug,
201{
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 f.debug_map().entries(self.iter()).finish()
204 }
205}
206
/// Wraps a borrowed [`MapSeekTarget`] so it can be used as a `SeekTarget`
/// over the tree's `MapKeyRef` cursor dimension (see `TreeMap::remove_range`).
#[derive(Debug)]
struct MapSeekTargetAdaptor<'a, T>(&'a T);
209
210impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
211 for MapSeekTargetAdaptor<'_, T>
212{
213 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
214 if let Some(key) = &cursor_location.0 {
215 MapSeekTarget::cmp_cursor(self.0, key)
216 } else {
217 Ordering::Greater
218 }
219 }
220}
221
/// A position in a [`TreeMap`] that can be compared against the key at the
/// cursor.
///
/// This lets seeks use criteria other than plain key equality.
pub trait MapSeekTarget<K> {
    /// Compares this target with the key the cursor has reached.
    ///
    /// `Greater` means the target is still ahead of `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Every ordered key is its own seek target, compared with `Ord`.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
231
232impl<K, V> Default for TreeMap<K, V>
233where
234 K: Clone + Ord,
235 V: Clone,
236{
237 fn default() -> Self {
238 Self(Default::default())
239 }
240}
241
242impl<K, V> Item for MapEntry<K, V>
243where
244 K: Clone + Ord,
245 V: Clone,
246{
247 type Summary = MapKey<K>;
248
249 fn summary(&self, _cx: ()) -> Self::Summary {
250 self.key()
251 }
252}
253
254impl<K, V> KeyedItem for MapEntry<K, V>
255where
256 K: Clone + Ord,
257 V: Clone,
258{
259 type Key = MapKey<K>;
260
261 fn key(&self) -> Self::Key {
262 MapKey(Some(self.key.clone()))
263 }
264}
265
266impl<K> ContextLessSummary for MapKey<K>
267where
268 K: Clone,
269{
270 fn zero() -> Self {
271 Default::default()
272 }
273
274 fn add_summary(&mut self, summary: &Self) {
275 *self = summary.clone()
276 }
277}
278
279impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
280where
281 K: Clone + Ord,
282{
283 fn zero(_cx: ()) -> Self {
284 Default::default()
285 }
286
287 fn add_summary(&mut self, summary: &'a MapKey<K>, _: ()) {
288 self.0 = summary.0.as_ref();
289 }
290}
291
292impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
293where
294 K: Clone + Ord,
295{
296 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: ()) -> Ordering {
297 Ord::cmp(&self.0, &cursor_location.0)
298 }
299}
300
301impl<K> Default for TreeSet<K>
302where
303 K: Clone + Ord,
304{
305 fn default() -> Self {
306 Self(Default::default())
307 }
308}
309
310impl<K> TreeSet<K>
311where
312 K: Clone + Ord,
313{
314 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
315 Self(TreeMap::from_ordered_entries(
316 entries.into_iter().map(|key| (key, ())),
317 ))
318 }
319
320 pub fn is_empty(&self) -> bool {
321 self.0.is_empty()
322 }
323
324 pub fn insert(&mut self, key: K) {
325 self.0.insert(key, ());
326 }
327
328 pub fn remove(&mut self, key: &K) -> bool {
329 self.0.remove(key).is_some()
330 }
331
332 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
333 self.0.extend(iter.into_iter().map(|key| (key, ())));
334 }
335
336 pub fn contains(&self, key: &K) -> bool {
337 self.0.get(key).is_some()
338 }
339
340 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
341 self.0.iter().map(|(k, _)| k)
342 }
343
344 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
345 self.0.iter_from(key).map(move |(k, _)| k)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises insert/get/iter ordering, `closest`, `remove` and `retain`
    // on a small integer-keyed map.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        // Inserting a smaller key must still iterate in key order.
        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the greatest key <= the query, or None when every
        // key is larger than the query.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 removed, the closest key at or below 2 is 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only entries matching the predicate (even keys here).
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    // `iter_from` combined with `take_while` acts as a prefix scan.
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // "ba" is not itself a key; iteration starts at the next key, "baa".
        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        // Starting exactly at an existing (last) key includes that key.
        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    // Merging another map: shared keys take `other`'s value, new keys are added,
    // and keys only in `map` are kept.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // Same merge semantics as `insert_tree`, but from a plain iterator of pairs.
    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // `remove_range` with a custom end target that skips past a path and all
    // of its descendants, removing the whole subtree.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        // Seek target that sits after every path starting with `self.0`
        // (component-wise `Path::starts_with`, which includes the path itself).
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                // Descendants compare as "target still ahead", so the seek
                // moves past all of them; other keys use normal path order.
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // "b/a" itself is not a key, but its descendants are.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Removing the last subtree in the map.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Removing the first subtree in the map.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Removing the only remaining entry.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}