1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
/// An ordered map from `K` to `V`, stored as a [`SumTree`] of [`MapEntry`]
/// items kept in ascending key order.
///
/// Lookups position a cursor by key (see the methods in `impl TreeMap`), so the
/// tree's entries must stay sorted by key.
#[derive(Clone, PartialEq, Eq)]
pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
where
    K: Clone + Ord,
    V: Clone;
10
/// A single key/value pair stored as one item of a `TreeMap`'s underlying tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    // The entry's key; entries are ordered by this field.
    key: K,
    // The value associated with `key`.
    value: V,
}
16
/// Summary of a run of map entries: the key of the last entry in the run, or
/// `None` when the run is empty.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);

// Written by hand instead of derived: `#[derive(Default)]` would add a
// `K: Default` bound, which map keys are not required to satisfy.
impl<K> Default for MapKey<K> {
    fn default() -> Self {
        MapKey(None)
    }
}
25
/// Borrowed form of `MapKey`, used as a cursor dimension: it holds a reference
/// to the key of the last entry the cursor has moved past, or `None` at the
/// very start of the tree.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

// Hand-written so that no `K: Default` bound is imposed.
impl<K> Default for MapKeyRef<'_, K> {
    fn default() -> Self {
        MapKeyRef(None)
    }
}
34
/// An ordered set of `K`, stored as a [`TreeMap`] whose values are all `()`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TreeSet<K>(TreeMap<K, ()>)
where
    K: Clone + Ord;
39
40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
41 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
42 let tree = SumTree::from_iter(
43 entries
44 .into_iter()
45 .map(|(key, value)| MapEntry { key, value }),
46 &(),
47 );
48 Self(tree)
49 }
50
51 pub fn is_empty(&self) -> bool {
52 self.0.is_empty()
53 }
54
55 pub fn get(&self, key: &K) -> Option<&V> {
56 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
57 cursor.seek(&MapKeyRef(Some(key)), Bias::Left);
58 if let Some(item) = cursor.item() {
59 if Some(key) == item.key().0.as_ref() {
60 Some(&item.value)
61 } else {
62 None
63 }
64 } else {
65 None
66 }
67 }
68
69 pub fn insert(&mut self, key: K, value: V) {
70 self.0.insert_or_replace(MapEntry { key, value }, &());
71 }
72
73 pub fn extend(&mut self, iter: impl IntoIterator<Item = (K, V)>) {
74 let edits: Vec<_> = iter
75 .into_iter()
76 .map(|(key, value)| Edit::Insert(MapEntry { key, value }))
77 .collect();
78 self.0.edit(edits, &());
79 }
80
81 pub fn clear(&mut self) {
82 self.0 = SumTree::default();
83 }
84
85 pub fn remove(&mut self, key: &K) -> Option<V> {
86 let mut removed = None;
87 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
88 let key = MapKeyRef(Some(key));
89 let mut new_tree = cursor.slice(&key, Bias::Left);
90 if key.cmp(&cursor.end(), &()) == Ordering::Equal {
91 removed = Some(cursor.item().unwrap().value.clone());
92 cursor.next();
93 }
94 new_tree.append(cursor.suffix(), &());
95 drop(cursor);
96 self.0 = new_tree;
97 removed
98 }
99
100 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
101 let start = MapSeekTargetAdaptor(start);
102 let end = MapSeekTargetAdaptor(end);
103 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
104 let mut new_tree = cursor.slice(&start, Bias::Left);
105 cursor.seek(&end, Bias::Left);
106 new_tree.append(cursor.suffix(), &());
107 drop(cursor);
108 self.0 = new_tree;
109 }
110
111 /// Returns the key-value pair with the greatest key less than or equal to the given key.
112 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
113 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
114 let key = MapKeyRef(Some(key));
115 cursor.seek(&key, Bias::Right);
116 cursor.prev();
117 cursor.item().map(|item| (&item.key, &item.value))
118 }
119
120 pub fn iter_from<'a>(&'a self, from: &K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
121 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
122 let from_key = MapKeyRef(Some(from));
123 cursor.seek(&from_key, Bias::Left);
124
125 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
126 }
127
128 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
129 where
130 F: FnOnce(&mut V) -> T,
131 {
132 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
133 let key = MapKeyRef(Some(key));
134 let mut new_tree = cursor.slice(&key, Bias::Left);
135 let mut result = None;
136 if key.cmp(&cursor.end(), &()) == Ordering::Equal {
137 let mut updated = cursor.item().unwrap().clone();
138 result = Some(f(&mut updated.value));
139 new_tree.push(updated, &());
140 cursor.next();
141 }
142 new_tree.append(cursor.suffix(), &());
143 drop(cursor);
144 self.0 = new_tree;
145 result
146 }
147
148 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
149 let mut new_map = SumTree::<MapEntry<K, V>>::default();
150
151 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
152 cursor.next();
153 while let Some(item) = cursor.item() {
154 if predicate(&item.key, &item.value) {
155 new_map.push(item.clone(), &());
156 }
157 cursor.next();
158 }
159 drop(cursor);
160
161 self.0 = new_map;
162 }
163
164 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
165 self.0.iter().map(|entry| (&entry.key, &entry.value))
166 }
167
168 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
169 self.0.iter().map(|entry| &entry.value)
170 }
171
172 pub fn first(&self) -> Option<(&K, &V)> {
173 self.0.first().map(|entry| (&entry.key, &entry.value))
174 }
175
176 pub fn last(&self) -> Option<(&K, &V)> {
177 self.0.last().map(|entry| (&entry.key, &entry.value))
178 }
179
180 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
181 let edits = other
182 .iter()
183 .map(|(key, value)| {
184 Edit::Insert(MapEntry {
185 key: key.to_owned(),
186 value: value.to_owned(),
187 })
188 })
189 .collect();
190
191 self.0.edit(edits, &());
192 }
193}
194
195impl<K, V> Debug for TreeMap<K, V>
196where
197 K: Clone + Debug + Ord,
198 V: Clone + Debug,
199{
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.debug_map().entries(self.iter()).finish()
202 }
203}
204
205#[derive(Debug)]
206struct MapSeekTargetAdaptor<'a, T>(&'a T);
207
208impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
209 for MapSeekTargetAdaptor<'_, T>
210{
211 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
212 if let Some(key) = &cursor_location.0 {
213 MapSeekTarget::cmp_cursor(self.0, key)
214 } else {
215 Ordering::Greater
216 }
217 }
218}
219
/// Something that can be compared against the keys of a `TreeMap` in order to
/// position a cursor, for example the bounds passed to `TreeMap::remove_range`.
pub trait MapSeekTarget<K> {
    /// Orders `self` relative to the key at `cursor_location`. `Greater`
    /// means the target lies further along than that key.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for keys of its own type, using `Ord`.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
229
230impl<K, V> Default for TreeMap<K, V>
231where
232 K: Clone + Ord,
233 V: Clone,
234{
235 fn default() -> Self {
236 Self(Default::default())
237 }
238}
239
240impl<K, V> Item for MapEntry<K, V>
241where
242 K: Clone + Ord,
243 V: Clone,
244{
245 type Summary = MapKey<K>;
246
247 fn summary(&self, _cx: &()) -> Self::Summary {
248 self.key()
249 }
250}
251
252impl<K, V> KeyedItem for MapEntry<K, V>
253where
254 K: Clone + Ord,
255 V: Clone,
256{
257 type Key = MapKey<K>;
258
259 fn key(&self) -> Self::Key {
260 MapKey(Some(self.key.clone()))
261 }
262}
263
264impl<K> Summary for MapKey<K>
265where
266 K: Clone,
267{
268 type Context = ();
269
270 fn zero(_cx: &()) -> Self {
271 Default::default()
272 }
273
274 fn add_summary(&mut self, summary: &Self, _: &()) {
275 *self = summary.clone()
276 }
277}
278
279impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
280where
281 K: Clone + Ord,
282{
283 fn zero(_cx: &()) -> Self {
284 Default::default()
285 }
286
287 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
288 self.0 = summary.0.as_ref();
289 }
290}
291
292impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
293where
294 K: Clone + Ord,
295{
296 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
297 Ord::cmp(&self.0, &cursor_location.0)
298 }
299}
300
301impl<K> Default for TreeSet<K>
302where
303 K: Clone + Ord,
304{
305 fn default() -> Self {
306 Self(Default::default())
307 }
308}
309
310impl<K> TreeSet<K>
311where
312 K: Clone + Ord,
313{
314 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
315 Self(TreeMap::from_ordered_entries(
316 entries.into_iter().map(|key| (key, ())),
317 ))
318 }
319
320 pub fn is_empty(&self) -> bool {
321 self.0.is_empty()
322 }
323
324 pub fn insert(&mut self, key: K) {
325 self.0.insert(key, ());
326 }
327
328 pub fn remove(&mut self, key: &K) -> bool {
329 self.0.remove(key).is_some()
330 }
331
332 pub fn extend(&mut self, iter: impl IntoIterator<Item = K>) {
333 self.0.extend(iter.into_iter().map(|key| (key, ())));
334 }
335
336 pub fn contains(&self, key: &K) -> bool {
337 self.0.get(key).is_some()
338 }
339
340 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
341 self.0.iter().map(|(k, _)| k)
342 }
343
344 pub fn iter_from<'a>(&'a self, key: &K) -> impl Iterator<Item = &'a K> + 'a {
345 self.0.iter_from(key).map(move |(k, _)| k)
346 }
347}
348
// Unit tests for `TreeMap`.
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises insertion, lookup, ordered iteration, `closest`, `remove`, and
    // `retain` on a small integer-keyed map. Each assertion depends on the
    // mutations performed before it, so the statement order is significant.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        // Inserting a smaller key places it first in iteration order.
        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the greatest key <= the query, or `None` when every
        // key is greater than the query.
        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        // With 2 gone, the closest key at or below 2 is 1.
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        // `retain` keeps only entries accepted by the predicate (even keys here).
        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    // `iter_from` starts at the first key not less than the query. "b" sorts
    // before "ba" and is therefore skipped. The test then walks forward in
    // ascending order, which enables prefix scans via `take_while`.
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    // `insert_tree` merges another map into this one; values from `other`
    // overwrite values stored under the same key ("a": 1 -> 2).
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // `extend` behaves like `insert_tree` but takes an iterator of pairs:
    // existing keys are overwritten and new keys are added.
    #[test]
    fn test_extend() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        map.extend([("a", 2), ("b", 2), ("d", 4)]);
        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    // Uses `remove_range` with a custom `MapSeekTarget` to delete a path
    // together with all of its descendants, leaving siblings untouched.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        // End-of-range target: it compares `Greater` than every key that starts
        // with the wrapped path, so a seek to it skips the path and all of its
        // descendants. Other keys are ordered normally.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        // "b/a" itself is not a key; only its descendants are removed, and "b"
        // (which sorts before "b/a") survives.
        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        // Removing the last subtree in key order.
        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        // Removing the first subtree in key order.
        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        // Removing the only remaining entry.
        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}