1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Ord,
9 V: Clone;
10
/// A single key/value pair stored inside a [`TreeMap`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered and looked up by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Owned summary of a run of map entries: `Some(key)` of an entry, or `None`
/// for the empty summary.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);

// Written by hand because `#[derive(Default)]` would add a `K: Default` bound.
impl<K> Default for MapKey<K> {
    fn default() -> Self {
        MapKey(Option::None)
    }
}
25
/// Borrowed counterpart of [`MapKey`], used as a cursor position: `Some(key)`
/// of an entry, or `None` for the starting (zero) position.
#[derive(Clone, Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);

// Written by hand so that `K` need not implement `Default`.
impl<K> Default for MapKeyRef<'_, K> {
    fn default() -> Self {
        MapKeyRef(Option::None)
    }
}
34
35#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct TreeSet<K>(TreeMap<K, ()>)
37where
38 K: Clone + Ord;
39
40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
41 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
42 let tree = SumTree::from_iter(
43 entries
44 .into_iter()
45 .map(|(key, value)| MapEntry { key, value }),
46 &(),
47 );
48 Self(tree)
49 }
50
51 pub fn is_empty(&self) -> bool {
52 self.0.is_empty()
53 }
54
55 pub fn get(&self, key: &K) -> Option<&V> {
56 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
57 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
58 if let Some(item) = cursor.item() {
59 if Some(key) == item.key().0.as_ref() {
60 Some(&item.value)
61 } else {
62 None
63 }
64 } else {
65 None
66 }
67 }
68
69 pub fn insert(&mut self, key: K, value: V) {
70 self.0.insert_or_replace(MapEntry { key, value }, &());
71 }
72
73 pub fn remove(&mut self, key: &K) -> Option<V> {
74 let mut removed = None;
75 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
76 let key = MapKeyRef(Some(key));
77 let mut new_tree = cursor.slice(&key, Bias::Left, &());
78 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
79 removed = Some(cursor.item().unwrap().value.clone());
80 cursor.next(&());
81 }
82 new_tree.append(cursor.suffix(&()), &());
83 drop(cursor);
84 self.0 = new_tree;
85 removed
86 }
87
88 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
89 let start = MapSeekTargetAdaptor(start);
90 let end = MapSeekTargetAdaptor(end);
91 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
92 let mut new_tree = cursor.slice(&start, Bias::Left, &());
93 cursor.seek(&end, Bias::Left, &());
94 new_tree.append(cursor.suffix(&()), &());
95 drop(cursor);
96 self.0 = new_tree;
97 }
98
99 /// Returns the key-value pair with the greatest key less than or equal to the given key.
100 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
101 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
102 let key = MapKeyRef(Some(key));
103 cursor.seek(&key, Bias::Right, &());
104 cursor.prev(&());
105 cursor.item().map(|item| (&item.key, &item.value))
106 }
107
108 pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
109 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
110 let from_key = MapKeyRef(Some(from));
111 cursor.seek(&from_key, Bias::Left, &());
112
113 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
114 }
115
116 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
117 where
118 F: FnOnce(&mut V) -> T,
119 {
120 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
121 let key = MapKeyRef(Some(key));
122 let mut new_tree = cursor.slice(&key, Bias::Left, &());
123 let mut result = None;
124 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
125 let mut updated = cursor.item().unwrap().clone();
126 result = Some(f(&mut updated.value));
127 new_tree.push(updated, &());
128 cursor.next(&());
129 }
130 new_tree.append(cursor.suffix(&()), &());
131 drop(cursor);
132 self.0 = new_tree;
133 result
134 }
135
136 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
137 let mut new_map = SumTree::<MapEntry<K, V>>::default();
138
139 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
140 cursor.next(&());
141 while let Some(item) = cursor.item() {
142 if predicate(&item.key, &item.value) {
143 new_map.push(item.clone(), &());
144 }
145 cursor.next(&());
146 }
147 drop(cursor);
148
149 self.0 = new_map;
150 }
151
152 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
153 self.0.iter().map(|entry| (&entry.key, &entry.value))
154 }
155
156 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
157 self.0.iter().map(|entry| &entry.value)
158 }
159
160 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
161 let edits = other
162 .iter()
163 .map(|(key, value)| {
164 Edit::Insert(MapEntry {
165 key: key.to_owned(),
166 value: value.to_owned(),
167 })
168 })
169 .collect();
170
171 self.0.edit(edits, &());
172 }
173}
174
175impl<K, V> Debug for TreeMap<K, V>
176where
177 K: Clone + Debug + Ord,
178 V: Clone + Debug,
179{
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_map().entries(self.iter()).finish()
182 }
183}
184
185#[derive(Debug)]
186struct MapSeekTargetAdaptor<'a, T>(&'a T);
187
188impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
189 for MapSeekTargetAdaptor<'_, T>
190{
191 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
192 if let Some(key) = &cursor_location.0 {
193 MapSeekTarget::cmp_cursor(self.0, key)
194 } else {
195 Ordering::Greater
196 }
197 }
198}
199
/// Something that can locate a position in a [`TreeMap`] by comparing itself
/// against the map's keys.
pub trait MapSeekTarget<K> {
    /// Compares `self` with a key stored in the map, in the sense of `Ord::cmp`
    /// (`self` first, `cursor_location` second).
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Any ordered key is a seek target for maps keyed by its own type.
impl<K: Ord> MapSeekTarget<K> for K {
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
209
210impl<K, V> Default for TreeMap<K, V>
211where
212 K: Clone + Ord,
213 V: Clone,
214{
215 fn default() -> Self {
216 Self(Default::default())
217 }
218}
219
220impl<K, V> Item for MapEntry<K, V>
221where
222 K: Clone + Ord,
223 V: Clone,
224{
225 type Summary = MapKey<K>;
226
227 fn summary(&self, _cx: &()) -> Self::Summary {
228 self.key()
229 }
230}
231
232impl<K, V> KeyedItem for MapEntry<K, V>
233where
234 K: Clone + Ord,
235 V: Clone,
236{
237 type Key = MapKey<K>;
238
239 fn key(&self) -> Self::Key {
240 MapKey(Some(self.key.clone()))
241 }
242}
243
244impl<K> Summary for MapKey<K>
245where
246 K: Clone,
247{
248 type Context = ();
249
250 fn zero(_cx: &()) -> Self {
251 Default::default()
252 }
253
254 fn add_summary(&mut self, summary: &Self, _: &()) {
255 *self = summary.clone()
256 }
257}
258
259impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
260where
261 K: Clone + Ord,
262{
263 fn zero(_cx: &()) -> Self {
264 Default::default()
265 }
266
267 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
268 self.0 = summary.0.as_ref();
269 }
270}
271
272impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
273where
274 K: Clone + Ord,
275{
276 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
277 Ord::cmp(&self.0, &cursor_location.0)
278 }
279}
280
281impl<K> Default for TreeSet<K>
282where
283 K: Clone + Ord,
284{
285 fn default() -> Self {
286 Self(Default::default())
287 }
288}
289
290impl<K> TreeSet<K>
291where
292 K: Clone + Ord,
293{
294 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
295 Self(TreeMap::from_ordered_entries(
296 entries.into_iter().map(|key| (key, ())),
297 ))
298 }
299
300 pub fn insert(&mut self, key: K) {
301 self.0.insert(key, ());
302 }
303
304 pub fn contains(&self, key: &K) -> bool {
305 self.0.get(key).is_some()
306 }
307
308 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
309 self.0.iter().map(|(k, _)| k)
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    // Snapshot of a map's entries, in iteration order.
    fn entries<K: Clone + Ord, V: Clone>(map: &TreeMap<K, V>) -> Vec<(&K, &V)> {
        map.iter().collect()
    }

    #[test]
    fn test_basic() {
        let mut map = TreeMap::<i32, &str>::default();
        assert!(entries(&map).is_empty());

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(entries(&map), [(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(entries(&map), [(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        for (key, value) in [(2, "b"), (1, "a"), (3, "c")] {
            assert_eq!(map.get(&key), Some(&value));
        }
        assert_eq!(entries(&map), [(&1, &"a"), (&2, &"b"), (&3, &"c")]);

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(entries(&map), [(&1, &"a"), (&3, &"c")]);
        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(entries(&map), [(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert!(entries(&map).is_empty());

        for (key, value) in [(4, "d"), (5, "e"), (6, "f")] {
            map.insert(key, value);
        }
        map.retain(|key, _| key % 2 == 0);
        assert_eq!(entries(&map), [(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::<&str, i32>::default();
        for (key, value) in [("a", 1), ("b", 2), ("baa", 3), ("baaab", 4), ("c", 5)] {
            map.insert(key, value);
        }

        // Keys from `prefix` onward, cut off at the first key lacking the prefix.
        let keys_with_prefix = |prefix: &'static str| {
            map.iter_from(&prefix)
                .take_while(|(key, _)| key.starts_with(prefix))
                .map(|(key, _)| *key)
                .collect::<Vec<_>>()
        };

        let ba_keys = keys_with_prefix("ba");
        assert_eq!(ba_keys.len(), 2);
        assert!(ba_keys.contains(&"baa"));
        assert!(ba_keys.contains(&"baaab"));

        let c_keys = keys_with_prefix("c");
        assert_eq!(c_keys.len(), 1);
        assert!(c_keys.contains(&"c"));
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::<&str, i32>::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            map.insert(key, value);
        }

        let mut other = TreeMap::<&str, i32>::default();
        for (key, value) in [("a", 2), ("b", 2), ("d", 4)] {
            other.insert(key, value);
        }

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(map.get(&key), Some(&expected));
        }
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        // Seek target that orders after the wrapped path and all of its descendants.
        #[derive(Debug)]
        struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        // Removes `path` together with every path nested beneath it.
        fn remove_subtree(map: &mut TreeMap<PathBuf, i32>, path: &str) {
            let path = PathBuf::from(path);
            map.remove_range(&path, &PathDescendants(&path));
        }

        // Checks each `(path, expected)` pair against `map.get`.
        fn assert_lookups(map: &TreeMap<PathBuf, i32>, expected: &[(&str, Option<i32>)]) {
            for &(path, value) in expected {
                assert_eq!(map.get(&PathBuf::from(path)), value.as_ref());
            }
        }

        let mut map = TreeMap::<PathBuf, i32>::default();
        for (path, value) in [
            ("a", 1),
            ("a/a", 1),
            ("b", 2),
            ("b/a/a", 3),
            ("b/a/a/a/b", 4),
            ("c", 5),
            ("c/a", 6),
        ] {
            map.insert(PathBuf::from(path), value);
        }

        remove_subtree(&mut map, "b/a");
        assert_lookups(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("b/a/a", None),
                ("b/a/a/a/b", None),
                ("c", Some(5)),
                ("c/a", Some(6)),
            ],
        );

        remove_subtree(&mut map, "c");
        assert_lookups(
            &map,
            &[
                ("a", Some(1)),
                ("a/a", Some(1)),
                ("b", Some(2)),
                ("c", None),
                ("c/a", None),
            ],
        );

        remove_subtree(&mut map, "a");
        assert_lookups(&map, &[("a", None), ("a/a", None), ("b", Some(2))]);

        remove_subtree(&mut map, "b");
        assert_lookups(&map, &[("b", None)]);
    }
}