1use std::{
2 cmp::Ordering,
3 fmt::Debug,
4 path::{Path, PathBuf},
5};
6
7use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
8
9#[derive(Clone, Debug, PartialEq, Eq)]
10pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
11where
12 K: Clone + Debug + Default + Ord,
13 V: Clone + Debug;
14
/// A single key-value pair, the item type stored in the tree behind a map.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// Key this entry is stored under.
    key: K,
    /// Value associated with `key`.
    value: V,
}
20
/// Owned newtype around a map key; compares and orders exactly like the wrapped key.
#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub struct MapKey<K>(K);
23
/// Borrowed form of a map key: `None` by default, otherwise a reference to a key.
#[derive(Debug, Default, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
26
27#[derive(Clone)]
28pub struct TreeSet<K>(TreeMap<K, ()>)
29where
30 K: Clone + Debug + Default + Ord;
31
32impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
33 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
34 let tree = SumTree::from_iter(
35 entries
36 .into_iter()
37 .map(|(key, value)| MapEntry { key, value }),
38 &(),
39 );
40 Self(tree)
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.0.is_empty()
45 }
46
47 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
48 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
49 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
50 if let Some(item) = cursor.item() {
51 if *key == item.key().0 {
52 Some(&item.value)
53 } else {
54 None
55 }
56 } else {
57 None
58 }
59 }
60
61 pub fn insert(&mut self, key: K, value: V) {
62 self.0.insert_or_replace(MapEntry { key, value }, &());
63 }
64
65 pub fn remove(&mut self, key: &K) -> Option<V> {
66 let mut removed = None;
67 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
68 let key = MapKeyRef(Some(key));
69 let mut new_tree = cursor.slice(&key, Bias::Left, &());
70 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
71 removed = Some(cursor.item().unwrap().value.clone());
72 cursor.next(&());
73 }
74 new_tree.push_tree(cursor.suffix(&()), &());
75 drop(cursor);
76 self.0 = new_tree;
77 removed
78 }
79
80 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
81 let start = MapSeekTargetAdaptor(start);
82 let end = MapSeekTargetAdaptor(end);
83 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
84 let mut new_tree = cursor.slice(&start, Bias::Left, &());
85 cursor.seek(&end, Bias::Left, &());
86 new_tree.push_tree(cursor.suffix(&()), &());
87 drop(cursor);
88 self.0 = new_tree;
89 }
90
91 /// Returns the key-value pair with the greatest key less than or equal to the given key.
92 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
93 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
94 let key = MapKeyRef(Some(key));
95 cursor.seek(&key, Bias::Right, &());
96 cursor.prev(&());
97 cursor.item().map(|item| (&item.key, &item.value))
98 }
99
100 pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
101 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
102 let from_key = MapKeyRef(Some(from));
103 cursor.seek(&from_key, Bias::Left, &());
104
105 cursor
106 .into_iter()
107 .map(|map_entry| (&map_entry.key, &map_entry.value))
108 }
109
110 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
111 where
112 F: FnOnce(&mut V) -> T,
113 {
114 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
115 let key = MapKeyRef(Some(key));
116 let mut new_tree = cursor.slice(&key, Bias::Left, &());
117 let mut result = None;
118 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
119 let mut updated = cursor.item().unwrap().clone();
120 result = Some(f(&mut updated.value));
121 new_tree.push(updated, &());
122 cursor.next(&());
123 }
124 new_tree.push_tree(cursor.suffix(&()), &());
125 drop(cursor);
126 self.0 = new_tree;
127 result
128 }
129
130 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
131 let mut new_map = SumTree::<MapEntry<K, V>>::default();
132
133 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
134 cursor.next(&());
135 while let Some(item) = cursor.item() {
136 if predicate(&item.key, &item.value) {
137 new_map.push(item.clone(), &());
138 }
139 cursor.next(&());
140 }
141 drop(cursor);
142
143 self.0 = new_map;
144 }
145
146 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
147 self.0.iter().map(|entry| (&entry.key, &entry.value))
148 }
149
150 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
151 self.0.iter().map(|entry| &entry.value)
152 }
153
154 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
155 let edits = other
156 .iter()
157 .map(|(key, value)| {
158 Edit::Insert(MapEntry {
159 key: key.to_owned(),
160 value: value.to_owned(),
161 })
162 })
163 .collect();
164
165 self.0.edit(edits, &());
166 }
167}
168
/// Private wrapper holding a reference to a seek target so the wrapper type can carry
/// its own trait implementations.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'target, Target>(&'target Target);
171
172impl<'a, K: Debug + Clone + Default + Ord, T: MapSeekTarget<K>>
173 SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapSeekTargetAdaptor<'_, T>
174{
175 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
176 MapSeekTarget::cmp(self.0, cursor_location)
177 }
178}
179
180pub trait MapSeekTarget<K>: Debug {
181 fn cmp(&self, cursor_location: &MapKeyRef<K>) -> Ordering;
182}
183
184impl<K: Debug + Ord> MapSeekTarget<K> for K {
185 fn cmp(&self, cursor_location: &MapKeyRef<K>) -> Ordering {
186 if let Some(key) = &cursor_location.0 {
187 self.cmp(key)
188 } else {
189 Ordering::Greater
190 }
191 }
192}
193
/// Seek target wrapping a borrowed path; its [`MapSeekTarget`] implementation sorts after
/// every key that starts with that path.
///
/// The field is public so that code outside this module's crate can construct the
/// wrapper; with a private field the public type could not be built by external callers.
#[derive(Debug)]
pub struct PathDescendants<'a>(pub &'a Path);
196
197impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
198 fn cmp(&self, cursor_location: &MapKeyRef<PathBuf>) -> Ordering {
199 if let Some(key) = &cursor_location.0 {
200 if key.starts_with(&self.0) {
201 Ordering::Greater
202 } else {
203 self.0.cmp(key)
204 }
205 } else {
206 Ordering::Greater
207 }
208 }
209}
210
211impl<K, V> Default for TreeMap<K, V>
212where
213 K: Clone + Debug + Default + Ord,
214 V: Clone + Debug,
215{
216 fn default() -> Self {
217 Self(Default::default())
218 }
219}
220
221impl<K, V> Item for MapEntry<K, V>
222where
223 K: Clone + Debug + Default + Ord,
224 V: Clone,
225{
226 type Summary = MapKey<K>;
227
228 fn summary(&self) -> Self::Summary {
229 self.key()
230 }
231}
232
233impl<K, V> KeyedItem for MapEntry<K, V>
234where
235 K: Clone + Debug + Default + Ord,
236 V: Clone,
237{
238 type Key = MapKey<K>;
239
240 fn key(&self) -> Self::Key {
241 MapKey(self.key.clone())
242 }
243}
244
245impl<K> Summary for MapKey<K>
246where
247 K: Clone + Debug + Default,
248{
249 type Context = ();
250
251 fn add_summary(&mut self, summary: &Self, _: &()) {
252 *self = summary.clone()
253 }
254}
255
256impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
257where
258 K: Clone + Debug + Default + Ord,
259{
260 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
261 self.0 = Some(&summary.0)
262 }
263}
264
265impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
266where
267 K: Clone + Debug + Default + Ord,
268{
269 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
270 Ord::cmp(&self.0, &cursor_location.0)
271 }
272}
273
274impl<K> Default for TreeSet<K>
275where
276 K: Clone + Debug + Default + Ord,
277{
278 fn default() -> Self {
279 Self(Default::default())
280 }
281}
282
283impl<K> TreeSet<K>
284where
285 K: Clone + Debug + Default + Ord,
286{
287 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
288 Self(TreeMap::from_ordered_entries(
289 entries.into_iter().map(|key| (key, ())),
290 ))
291 }
292
293 pub fn insert(&mut self, key: K) {
294 self.0.insert(key, ());
295 }
296
297 pub fn contains(&self, key: &K) -> bool {
298 self.0.get(key).is_some()
299 }
300
301 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
302 self.0.iter().map(|(k, _)| k)
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert, get, iteration order, `closest`, `remove`, and `retain`.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    /// Checks that `iter_from` starts at the first key not before the requested one.
    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));
    }

    /// Checks that `insert_tree` adds new keys and takes the other map's values for
    /// keys present in both.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    /// Checks that `remove_range` with a `PathDescendants` end bound removes a path's
    /// descendants (and the path itself when it is the start bound and present).
    #[test]
    fn test_remove_between_and_path_successor() {
        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(&PathBuf::from("b/a"), &PathDescendants(&PathBuf::from("b/a")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}