1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Ord,
9 V: Clone;
10
/// A single key/value pair, the item type stored inside a `TreeMap`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// Key the entry is looked up by.
    key: K,
    /// Value associated with `key`.
    value: V,
}
16
/// Owned, optional key used as the summary of a map entry.
///
/// `None` is the default ("zero") value; because the ordering is derived,
/// `None` sorts before every `Some(key)`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(Option<K>);
19
20impl<K> Default for MapKey<K> {
21 fn default() -> Self {
22 Self(None)
23 }
24}
25
/// Borrowed form of a map key: an optional reference to a key.
///
/// `None` is the default value, used before any key has been seen.
#[derive(Debug, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
28
29impl<K> Default for MapKeyRef<'_, K> {
30 fn default() -> Self {
31 Self(None)
32 }
33}
34
35#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct TreeSet<K>(TreeMap<K, ()>)
37where
38 K: Clone + Ord;
39
40impl<K: Clone + Ord, V: Clone> TreeMap<K, V> {
41 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
42 let tree = SumTree::from_iter(
43 entries
44 .into_iter()
45 .map(|(key, value)| MapEntry { key, value }),
46 &(),
47 );
48 Self(tree)
49 }
50
51 pub fn is_empty(&self) -> bool {
52 self.0.is_empty()
53 }
54
55 pub fn get(&self, key: &K) -> Option<&V> {
56 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
57 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
58 if let Some(item) = cursor.item() {
59 if Some(key) == item.key().0.as_ref() {
60 Some(&item.value)
61 } else {
62 None
63 }
64 } else {
65 None
66 }
67 }
68
69 pub fn insert(&mut self, key: K, value: V) {
70 self.0.insert_or_replace(MapEntry { key, value }, &());
71 }
72
73 pub fn clear(&mut self) {
74 self.0 = SumTree::default();
75 }
76
77 pub fn remove(&mut self, key: &K) -> Option<V> {
78 let mut removed = None;
79 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
80 let key = MapKeyRef(Some(key));
81 let mut new_tree = cursor.slice(&key, Bias::Left, &());
82 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
83 removed = Some(cursor.item().unwrap().value.clone());
84 cursor.next(&());
85 }
86 new_tree.append(cursor.suffix(&()), &());
87 drop(cursor);
88 self.0 = new_tree;
89 removed
90 }
91
92 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
93 let start = MapSeekTargetAdaptor(start);
94 let end = MapSeekTargetAdaptor(end);
95 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
96 let mut new_tree = cursor.slice(&start, Bias::Left, &());
97 cursor.seek(&end, Bias::Left, &());
98 new_tree.append(cursor.suffix(&()), &());
99 drop(cursor);
100 self.0 = new_tree;
101 }
102
103 /// Returns the key-value pair with the greatest key less than or equal to the given key.
104 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
105 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
106 let key = MapKeyRef(Some(key));
107 cursor.seek(&key, Bias::Right, &());
108 cursor.prev(&());
109 cursor.item().map(|item| (&item.key, &item.value))
110 }
111
112 pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&'a K, &'a V)> + 'a {
113 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
114 let from_key = MapKeyRef(Some(from));
115 cursor.seek(&from_key, Bias::Left, &());
116
117 cursor.map(|map_entry| (&map_entry.key, &map_entry.value))
118 }
119
120 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
121 where
122 F: FnOnce(&mut V) -> T,
123 {
124 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
125 let key = MapKeyRef(Some(key));
126 let mut new_tree = cursor.slice(&key, Bias::Left, &());
127 let mut result = None;
128 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
129 let mut updated = cursor.item().unwrap().clone();
130 result = Some(f(&mut updated.value));
131 new_tree.push(updated, &());
132 cursor.next(&());
133 }
134 new_tree.append(cursor.suffix(&()), &());
135 drop(cursor);
136 self.0 = new_tree;
137 result
138 }
139
140 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
141 let mut new_map = SumTree::<MapEntry<K, V>>::default();
142
143 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>(&());
144 cursor.next(&());
145 while let Some(item) = cursor.item() {
146 if predicate(&item.key, &item.value) {
147 new_map.push(item.clone(), &());
148 }
149 cursor.next(&());
150 }
151 drop(cursor);
152
153 self.0 = new_map;
154 }
155
156 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
157 self.0.iter().map(|entry| (&entry.key, &entry.value))
158 }
159
160 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
161 self.0.iter().map(|entry| &entry.value)
162 }
163
164 pub fn first(&self) -> Option<(&K, &V)> {
165 self.0.first().map(|entry| (&entry.key, &entry.value))
166 }
167
168 pub fn last(&self) -> Option<(&K, &V)> {
169 self.0.last().map(|entry| (&entry.key, &entry.value))
170 }
171
172 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
173 let edits = other
174 .iter()
175 .map(|(key, value)| {
176 Edit::Insert(MapEntry {
177 key: key.to_owned(),
178 value: value.to_owned(),
179 })
180 })
181 .collect();
182
183 self.0.edit(edits, &());
184 }
185}
186
187impl<K, V> Debug for TreeMap<K, V>
188where
189 K: Clone + Debug + Ord,
190 V: Clone + Debug,
191{
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 f.debug_map().entries(self.iter()).finish()
194 }
195}
196
/// Wraps a borrowed `T` in a newtype of its own.
///
/// A `SeekTarget` impl elsewhere in this file lets the wrapper be used for
/// seeking whenever `T` implements `MapSeekTarget`.
#[derive(Debug)]
struct MapSeekTargetAdaptor<'t, T>(&'t T);
199
200impl<'a, K: Clone + Ord, T: MapSeekTarget<K>> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
201 for MapSeekTargetAdaptor<'_, T>
202{
203 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
204 if let Some(key) = &cursor_location.0 {
205 MapSeekTarget::cmp_cursor(self.0, key)
206 } else {
207 Ordering::Greater
208 }
209 }
210}
211
/// A target that can be compared against the keys of a `TreeMap`, used by
/// `TreeMap::remove_range` to locate the bounds of the range.
pub trait MapSeekTarget<K> {
    /// Returns how this target orders relative to the key at `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}
215
216impl<K: Ord> MapSeekTarget<K> for K {
217 fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
218 self.cmp(cursor_location)
219 }
220}
221
222impl<K, V> Default for TreeMap<K, V>
223where
224 K: Clone + Ord,
225 V: Clone,
226{
227 fn default() -> Self {
228 Self(Default::default())
229 }
230}
231
232impl<K, V> Item for MapEntry<K, V>
233where
234 K: Clone + Ord,
235 V: Clone,
236{
237 type Summary = MapKey<K>;
238
239 fn summary(&self, _cx: &()) -> Self::Summary {
240 self.key()
241 }
242}
243
244impl<K, V> KeyedItem for MapEntry<K, V>
245where
246 K: Clone + Ord,
247 V: Clone,
248{
249 type Key = MapKey<K>;
250
251 fn key(&self) -> Self::Key {
252 MapKey(Some(self.key.clone()))
253 }
254}
255
256impl<K> Summary for MapKey<K>
257where
258 K: Clone,
259{
260 type Context = ();
261
262 fn zero(_cx: &()) -> Self {
263 Default::default()
264 }
265
266 fn add_summary(&mut self, summary: &Self, _: &()) {
267 *self = summary.clone()
268 }
269}
270
271impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
272where
273 K: Clone + Ord,
274{
275 fn zero(_cx: &()) -> Self {
276 Default::default()
277 }
278
279 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
280 self.0 = summary.0.as_ref();
281 }
282}
283
284impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
285where
286 K: Clone + Ord,
287{
288 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
289 Ord::cmp(&self.0, &cursor_location.0)
290 }
291}
292
293impl<K> Default for TreeSet<K>
294where
295 K: Clone + Ord,
296{
297 fn default() -> Self {
298 Self(Default::default())
299 }
300}
301
302impl<K> TreeSet<K>
303where
304 K: Clone + Ord,
305{
306 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
307 Self(TreeMap::from_ordered_entries(
308 entries.into_iter().map(|key| (key, ())),
309 ))
310 }
311
312 pub fn insert(&mut self, key: K) {
313 self.0.insert(key, ());
314 }
315
316 pub fn contains(&self, key: &K) -> bool {
317 self.0.get(key).is_some()
318 }
319
320 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
321 self.0.iter().map(|(k, _)| k)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert, get, iter, closest, remove and retain on a small
    /// integer-keyed map.
    #[test]
    fn test_basic() {
        let mut tree = TreeMap::default();
        assert_eq!(tree.iter().collect::<Vec<_>>(), vec![]);

        // Keys are inserted out of order; iteration must come back sorted.
        tree.insert(3, "c");
        assert_eq!(tree.get(&3), Some(&"c"));
        assert_eq!(tree.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        tree.insert(1, "a");
        assert_eq!(tree.get(&1), Some(&"a"));
        assert_eq!(
            tree.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&3, &"c")]
        );

        tree.insert(2, "b");
        assert_eq!(tree.get(&2), Some(&"b"));
        assert_eq!(tree.get(&1), Some(&"a"));
        assert_eq!(tree.get(&3), Some(&"c"));
        assert_eq!(
            tree.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `closest` yields the entry with the greatest key <= the argument.
        assert_eq!(tree.closest(&0), None);
        assert_eq!(tree.closest(&1), Some((&1, &"a")));
        assert_eq!(tree.closest(&10), Some((&3, &"c")));

        tree.remove(&2);
        assert_eq!(tree.get(&2), None);
        assert_eq!(
            tree.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&3, &"c")]
        );

        assert_eq!(tree.closest(&2), Some((&1, &"a")));

        tree.remove(&3);
        assert_eq!(tree.get(&3), None);
        assert_eq!(tree.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        tree.remove(&1);
        assert_eq!(tree.get(&1), None);
        assert_eq!(tree.iter().collect::<Vec<_>>(), vec![]);

        // Only the even keys survive `retain`.
        tree.insert(4, "d");
        tree.insert(5, "e");
        tree.insert(6, "f");
        tree.retain(|key, _| *key % 2 == 0);
        assert_eq!(
            tree.iter().collect::<Vec<_>>(),
            vec![(&4, &"d"), (&6, &"f")]
        );
    }

    /// `iter_from` starts at the first key that is not less than the argument.
    #[test]
    fn test_iter_from() {
        let mut tree = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("baa", 3), ("baaab", 4), ("c", 5)] {
            tree.insert(key, value);
        }

        let with_ba_prefix = tree
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(with_ba_prefix.len(), 2);
        assert!(with_ba_prefix.iter().any(|(k, _)| k == &&"baa"));
        assert!(with_ba_prefix.iter().any(|(k, _)| k == &&"baaab"));

        let with_c_prefix = tree
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(with_c_prefix.len(), 1);
        assert!(with_c_prefix.iter().any(|(k, _)| k == &&"c"));
    }

    /// `insert_tree` merges another map in, overwriting values of shared keys.
    #[test]
    fn test_insert_tree() {
        let mut tree = TreeMap::default();
        for (key, value) in [("a", 1), ("b", 2), ("c", 3)] {
            tree.insert(key, value);
        }

        let mut incoming = TreeMap::default();
        for (key, value) in [("a", 2), ("b", 2), ("d", 4)] {
            incoming.insert(key, value);
        }

        tree.insert_tree(incoming);

        assert_eq!(tree.iter().count(), 4);
        for (key, expected) in [("a", 2), ("b", 2), ("c", 3), ("d", 4)] {
            assert_eq!(tree.get(&key), Some(&expected));
        }
    }

    /// `remove_range` with a custom seek target removes a path together with
    /// all of its descendants.
    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        /// Seek target that compares as greater than every key starting with
        /// the wrapped path, so a seek to it lands past all such keys.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        // Looks up `path` and copies the value out so results compare by value.
        let value_at =
            |tree: &TreeMap<PathBuf, i32>, path: &str| tree.get(&PathBuf::from(path)).copied();
        // Removes everything from `root` through the last key starting with `root`.
        let remove_subtree = |tree: &mut TreeMap<PathBuf, i32>, root: &str| {
            let root = PathBuf::from(root);
            tree.remove_range(&root, &PathDescendants(&root));
        };

        let mut tree = TreeMap::default();
        for (path, value) in [
            ("a", 1),
            ("a/a", 1),
            ("b", 2),
            ("b/a/a", 3),
            ("b/a/a/a/b", 4),
            ("c", 5),
            ("c/a", 6),
        ] {
            tree.insert(PathBuf::from(path), value);
        }

        remove_subtree(&mut tree, "b/a");

        assert_eq!(value_at(&tree, "a"), Some(1));
        assert_eq!(value_at(&tree, "a/a"), Some(1));
        assert_eq!(value_at(&tree, "b"), Some(2));
        assert_eq!(value_at(&tree, "b/a/a"), None);
        assert_eq!(value_at(&tree, "b/a/a/a/b"), None);
        assert_eq!(value_at(&tree, "c"), Some(5));
        assert_eq!(value_at(&tree, "c/a"), Some(6));

        remove_subtree(&mut tree, "c");

        assert_eq!(value_at(&tree, "a"), Some(1));
        assert_eq!(value_at(&tree, "a/a"), Some(1));
        assert_eq!(value_at(&tree, "b"), Some(2));
        assert_eq!(value_at(&tree, "c"), None);
        assert_eq!(value_at(&tree, "c/a"), None);

        remove_subtree(&mut tree, "a");

        assert_eq!(value_at(&tree, "a"), None);
        assert_eq!(value_at(&tree, "a/a"), None);
        assert_eq!(value_at(&tree, "b"), Some(2));

        remove_subtree(&mut tree, "b");

        assert_eq!(value_at(&tree, "b"), None);
    }
}