1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// One key/value pair stored as an item of a [`TreeMap`]'s underlying tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// The entry's key; ordering and seeking within the tree are based on it.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Owned key wrapper used both as a [`MapEntry`]'s sum-tree summary and as its key.
///
/// Ordering is delegated to the wrapped `K`.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed view of a [`MapKey`], used as the cursor dimension when walking a [`TreeMap`].
///
/// The default value, `None`, is the position before any entry has been passed.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.append(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 pub fn remove_range(&mut self, start: &impl MapSeekTarget<K>, end: &impl MapSeekTarget<K>) {
77 let start = MapSeekTargetAdaptor(start);
78 let end = MapSeekTargetAdaptor(end);
79 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
80 let mut new_tree = cursor.slice(&start, Bias::Left, &());
81 cursor.seek(&end, Bias::Left, &());
82 new_tree.append(cursor.suffix(&()), &());
83 drop(cursor);
84 self.0 = new_tree;
85 }
86
87 /// Returns the key-value pair with the greatest key less than or equal to the given key.
88 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
89 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
90 let key = MapKeyRef(Some(key));
91 cursor.seek(&key, Bias::Right, &());
92 cursor.prev(&());
93 cursor.item().map(|item| (&item.key, &item.value))
94 }
95
96 pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
97 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
98 let from_key = MapKeyRef(Some(from));
99 cursor.seek(&from_key, Bias::Left, &());
100
101 cursor
102 .into_iter()
103 .map(|map_entry| (&map_entry.key, &map_entry.value))
104 }
105
106 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
107 where
108 F: FnOnce(&mut V) -> T,
109 {
110 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
111 let key = MapKeyRef(Some(key));
112 let mut new_tree = cursor.slice(&key, Bias::Left, &());
113 let mut result = None;
114 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
115 let mut updated = cursor.item().unwrap().clone();
116 result = Some(f(&mut updated.value));
117 new_tree.push(updated, &());
118 cursor.next(&());
119 }
120 new_tree.append(cursor.suffix(&()), &());
121 drop(cursor);
122 self.0 = new_tree;
123 result
124 }
125
126 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
127 let mut new_map = SumTree::<MapEntry<K, V>>::default();
128
129 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
130 cursor.next(&());
131 while let Some(item) = cursor.item() {
132 if predicate(&item.key, &item.value) {
133 new_map.push(item.clone(), &());
134 }
135 cursor.next(&());
136 }
137 drop(cursor);
138
139 self.0 = new_map;
140 }
141
142 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
143 self.0.iter().map(|entry| (&entry.key, &entry.value))
144 }
145
146 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
147 self.0.iter().map(|entry| &entry.value)
148 }
149
150 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
151 let edits = other
152 .iter()
153 .map(|(key, value)| {
154 Edit::Insert(MapEntry {
155 key: key.to_owned(),
156 value: value.to_owned(),
157 })
158 })
159 .collect();
160
161 self.0.edit(edits, &());
162 }
163}
164
165impl<K: Debug, V: Debug> Debug for TreeMap<K, V>
166where
167 K: Clone + Debug + Default + Ord,
168 V: Clone + Debug,
169{
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_map().entries(self.iter()).finish()
172 }
173}
174
175#[derive(Debug)]
176struct MapSeekTargetAdaptor<'a, T>(&'a T);
177
178impl<'a, K: Debug + Clone + Default + Ord, T: MapSeekTarget<K>>
179 SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapSeekTargetAdaptor<'_, T>
180{
181 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
182 if let Some(key) = &cursor_location.0 {
183 MapSeekTarget::cmp_cursor(self.0, key)
184 } else {
185 Ordering::Greater
186 }
187 }
188}
189
/// A position to seek to in a [`TreeMap`], expressed as a comparison against the
/// keys a cursor moves over.
///
/// Implementing this lets callers seek with an ordering other than `K`'s own `Ord`,
/// for example one that treats a key and a group of related keys as a single position.
pub trait MapSeekTarget<K>: Debug {
    /// Compares this target with the key at `cursor_location`.
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering;
}

/// Every ordered key is a seek target for keys of its own type, using its `Ord` impl.
impl<K> MapSeekTarget<K> for K
where
    K: Debug + Ord,
{
    fn cmp_cursor(&self, cursor_location: &K) -> Ordering {
        Ord::cmp(self, cursor_location)
    }
}
199
200impl<K, V> Default for TreeMap<K, V>
201where
202 K: Clone + Debug + Default + Ord,
203 V: Clone + Debug,
204{
205 fn default() -> Self {
206 Self(Default::default())
207 }
208}
209
210impl<K, V> Item for MapEntry<K, V>
211where
212 K: Clone + Debug + Default + Ord,
213 V: Clone,
214{
215 type Summary = MapKey<K>;
216
217 fn summary(&self) -> Self::Summary {
218 self.key()
219 }
220}
221
222impl<K, V> KeyedItem for MapEntry<K, V>
223where
224 K: Clone + Debug + Default + Ord,
225 V: Clone,
226{
227 type Key = MapKey<K>;
228
229 fn key(&self) -> Self::Key {
230 MapKey(self.key.clone())
231 }
232}
233
234impl<K> Summary for MapKey<K>
235where
236 K: Clone + Debug + Default,
237{
238 type Context = ();
239
240 fn add_summary(&mut self, summary: &Self, _: &()) {
241 *self = summary.clone()
242 }
243}
244
245impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
246where
247 K: Clone + Debug + Default + Ord,
248{
249 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
250 self.0 = Some(&summary.0)
251 }
252}
253
254impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
255where
256 K: Clone + Debug + Default + Ord,
257{
258 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
259 Ord::cmp(&self.0, &cursor_location.0)
260 }
261}
262
263impl<K> Default for TreeSet<K>
264where
265 K: Clone + Debug + Default + Ord,
266{
267 fn default() -> Self {
268 Self(Default::default())
269 }
270}
271
272impl<K> TreeSet<K>
273where
274 K: Clone + Debug + Default + Ord,
275{
276 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
277 Self(TreeMap::from_ordered_entries(
278 entries.into_iter().map(|key| (key, ())),
279 ))
280 }
281
282 pub fn insert(&mut self, key: K) {
283 self.0.insert(key, ());
284 }
285
286 pub fn contains(&self, key: &K) -> bool {
287 self.0.get(key).is_some()
288 }
289
290 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
291 self.0.iter().map(|(k, _)| k)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|(k, _)| k == &&"baa"));
        assert!(result.iter().any(|(k, _)| k == &&"baaab"));

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();

        assert_eq!(result.len(), 1);
        assert!(result.iter().any(|(k, _)| k == &&"c"));

        // Starting past the last key yields nothing.
        assert_eq!(map.iter_from(&"d").count(), 0);
    }

    #[test]
    fn test_update_and_remove_return_values() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(3, "c");

        assert_eq!(
            map.update(&1, |value| {
                *value = "z";
                10
            }),
            Some(10)
        );
        assert_eq!(map.get(&1), Some(&"z"));

        // Updating a missing key returns `None` and leaves the map unchanged.
        assert_eq!(map.update(&2, |_| 20), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.remove(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z")]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(2);
        set.insert(1);

        assert!(set.contains(&1));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);

        let from_ordered = TreeSet::from_ordered_entries(vec![1, 2, 3]);
        assert_eq!(
            from_ordered.iter().collect::<Vec<_>>(),
            set.iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_remove_between_and_path_successor() {
        use std::path::{Path, PathBuf};

        /// Seek target that sorts after a path and all of its descendants.
        #[derive(Debug)]
        pub struct PathDescendants<'a>(&'a Path);

        impl MapSeekTarget<PathBuf> for PathDescendants<'_> {
            fn cmp_cursor(&self, key: &PathBuf) -> Ordering {
                if key.starts_with(&self.0) {
                    Ordering::Greater
                } else {
                    self.0.cmp(key)
                }
            }
        }

        let mut map = TreeMap::default();

        map.insert(PathBuf::from("a"), 1);
        map.insert(PathBuf::from("a/a"), 1);
        map.insert(PathBuf::from("b"), 2);
        map.insert(PathBuf::from("b/a/a"), 3);
        map.insert(PathBuf::from("b/a/a/a/b"), 4);
        map.insert(PathBuf::from("c"), 5);
        map.insert(PathBuf::from("c/a"), 6);

        map.remove_range(
            &PathBuf::from("b/a"),
            &PathDescendants(&PathBuf::from("b/a")),
        );

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("b/a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b/a/a/a/b")), None);
        assert_eq!(map.get(&PathBuf::from("c")), Some(&5));
        assert_eq!(map.get(&PathBuf::from("c/a")), Some(&6));

        map.remove_range(&PathBuf::from("c"), &PathDescendants(&PathBuf::from("c")));

        assert_eq!(map.get(&PathBuf::from("a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("a/a")), Some(&1));
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));
        assert_eq!(map.get(&PathBuf::from("c")), None);
        assert_eq!(map.get(&PathBuf::from("c/a")), None);

        map.remove_range(&PathBuf::from("a"), &PathDescendants(&PathBuf::from("a")));

        assert_eq!(map.get(&PathBuf::from("a")), None);
        assert_eq!(map.get(&PathBuf::from("a/a")), None);
        assert_eq!(map.get(&PathBuf::from("b")), Some(&2));

        map.remove_range(&PathBuf::from("b"), &PathDescendants(&PathBuf::from("b")));

        assert_eq!(map.get(&PathBuf::from("b")), None);
    }
}