1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// One key/value pair stored in a [`TreeMap`]'s underlying tree.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// The entry's key; the tree is searched by this value.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Summary of a run of [`MapEntry`] values: the key of the run's last entry.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed cursor position over [`MapKey`] summaries. `None` (the default)
/// is the position before any entry has been passed.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 /// Returns the key-value pair with the greatest key less than or equal to the given key.
77 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
78 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
79 let key = MapKeyRef(Some(key));
80 cursor.seek(&key, Bias::Right, &());
81 cursor.prev(&());
82 cursor.item().map(|item| (&item.key, &item.value))
83 }
84
85 pub fn remove_between(&mut self, from: &K, until: &K) {
86 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
87 let from_key = MapKeyRef(Some(from));
88 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
89 let until_key = MapKeyRef(Some(until));
90 cursor.seek_forward(&until_key, Bias::Left, &());
91 new_tree.push_tree(cursor.suffix(&()), &());
92 drop(cursor);
93 self.0 = new_tree;
94 }
95
96 pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
97 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
98 let from_key = MapKeyRef(Some(from));
99 cursor.seek(&from_key, Bias::Left, &());
100
101 cursor
102 .into_iter()
103 .map(|map_entry| (&map_entry.key, &map_entry.value))
104 }
105
106 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
107 where
108 F: FnOnce(&mut V) -> T,
109 {
110 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
111 let key = MapKeyRef(Some(key));
112 let mut new_tree = cursor.slice(&key, Bias::Left, &());
113 let mut result = None;
114 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
115 let mut updated = cursor.item().unwrap().clone();
116 result = Some(f(&mut updated.value));
117 new_tree.push(updated, &());
118 cursor.next(&());
119 }
120 new_tree.push_tree(cursor.suffix(&()), &());
121 drop(cursor);
122 self.0 = new_tree;
123 result
124 }
125
126 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
127 let mut new_map = SumTree::<MapEntry<K, V>>::default();
128
129 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
130 cursor.next(&());
131 while let Some(item) = cursor.item() {
132 if predicate(&item.key, &item.value) {
133 new_map.push(item.clone(), &());
134 }
135 cursor.next(&());
136 }
137 drop(cursor);
138
139 self.0 = new_map;
140 }
141
142 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
143 self.0.iter().map(|entry| (&entry.key, &entry.value))
144 }
145
146 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
147 self.0.iter().map(|entry| &entry.value)
148 }
149
150 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
151 let edits = other
152 .iter()
153 .map(|(key, value)| {
154 Edit::Insert(MapEntry {
155 key: key.to_owned(),
156 value: value.to_owned(),
157 })
158 })
159 .collect();
160
161 self.0.edit(edits, &());
162 }
163
164 pub fn remove_by<F>(&mut self, key: &K, f: F)
165 where
166 F: Fn(&K) -> bool,
167 {
168 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
169 let key = MapKeyRef(Some(key));
170 let mut new_tree = cursor.slice(&key, Bias::Left, &());
171 let until = RemoveByTarget(key, &f);
172 cursor.seek_forward(&until, Bias::Right, &());
173 new_tree.push_tree(cursor.suffix(&()), &());
174 drop(cursor);
175 self.0 = new_tree;
176 }
177}
178
179struct RemoveByTarget<'a, K>(MapKeyRef<'a, K>, &'a dyn Fn(&K) -> bool);
180
181impl<'a, K: Debug> Debug for RemoveByTarget<'a, K> {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("RemoveByTarget")
184 .field("key", &self.0)
185 .field("F", &"<...>")
186 .finish()
187 }
188}
189
190impl<'a, K: Debug + Clone + Default + Ord> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
191 for RemoveByTarget<'_, K>
192{
193 fn cmp(
194 &self,
195 cursor_location: &MapKeyRef<'a, K>,
196 _cx: &<MapKey<K> as Summary>::Context,
197 ) -> Ordering {
198 if let Some(cursor_location) = cursor_location.0 {
199 if (self.1)(cursor_location) {
200 Ordering::Equal
201 } else {
202 self.0 .0.unwrap().cmp(cursor_location)
203 }
204 } else {
205 Ordering::Greater
206 }
207 }
208}
209
210impl<K, V> Default for TreeMap<K, V>
211where
212 K: Clone + Debug + Default + Ord,
213 V: Clone + Debug,
214{
215 fn default() -> Self {
216 Self(Default::default())
217 }
218}
219
220impl<K, V> Item for MapEntry<K, V>
221where
222 K: Clone + Debug + Default + Ord,
223 V: Clone,
224{
225 type Summary = MapKey<K>;
226
227 fn summary(&self) -> Self::Summary {
228 self.key()
229 }
230}
231
232impl<K, V> KeyedItem for MapEntry<K, V>
233where
234 K: Clone + Debug + Default + Ord,
235 V: Clone,
236{
237 type Key = MapKey<K>;
238
239 fn key(&self) -> Self::Key {
240 MapKey(self.key.clone())
241 }
242}
243
244impl<K> Summary for MapKey<K>
245where
246 K: Clone + Debug + Default,
247{
248 type Context = ();
249
250 fn add_summary(&mut self, summary: &Self, _: &()) {
251 *self = summary.clone()
252 }
253}
254
255impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
256where
257 K: Clone + Debug + Default + Ord,
258{
259 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
260 self.0 = Some(&summary.0)
261 }
262}
263
264impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
265where
266 K: Clone + Debug + Default + Ord,
267{
268 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
269 self.0.cmp(&cursor_location.0)
270 }
271}
272
273impl<K> Default for TreeSet<K>
274where
275 K: Clone + Debug + Default + Ord,
276{
277 fn default() -> Self {
278 Self(Default::default())
279 }
280}
281
282impl<K> TreeSet<K>
283where
284 K: Clone + Debug + Default + Ord,
285{
286 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
287 Self(TreeMap::from_ordered_entries(
288 entries.into_iter().map(|key| (key, ())),
289 ))
290 }
291
292 pub fn insert(&mut self, key: K) {
293 self.0.insert(key, ());
294 }
295
296 pub fn contains(&self, key: &K) -> bool {
297 self.0.get(key).is_some()
298 }
299
300 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
301 self.0.iter().map(|(k, _)| k)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        // `remove` hands back the removed value, and `None` for absent keys.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
        assert!(map.is_empty());

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        let updated = map.update(&1, |value| {
            *value += 1;
            *value
        });
        assert_eq!(updated, Some(11));
        // Missing keys leave the map untouched and never call the closure.
        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &11), (&2, &20)]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted inside ["ba", "bb").
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_by() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("aa", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);
        map.insert("ca", 6);

        map.remove_by(&"ba", |key| key.starts_with("ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"aa"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted with the "ba" prefix.
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(map.get(&"ca"), Some(&6));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"aa", &1), (&"b", &2), (&"c", &5), (&"ca", &6)]
        );

        map.remove_by(&"c", |key| key.starts_with("c"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"aa"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), None);
        assert_eq!(map.get(&"ca"), None);

        map.remove_by(&"a", |key| key.starts_with("a"));

        assert_eq!(map.get(&"a"), None);
        assert_eq!(map.get(&"aa"), None);
        assert_eq!(map.get(&"b"), Some(&2));

        map.remove_by(&"b", |key| key.starts_with("b"));

        assert_eq!(map.get(&"b"), None);
        assert!(map.is_empty());
    }

    #[test]
    fn test_iter_from() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        let result = map
            .iter_from(&"ba")
            .take_while(|(key, _)| key.starts_with(&"ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .iter_from(&"c")
            .take_while(|(key, _)| key.starts_with(&"c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);

        // Starting past the last key yields nothing.
        assert_eq!(map.iter_from(&"d").count(), 0);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(3);

        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);

        let ordered = TreeSet::from_ordered_entries(vec![1, 2, 5]);
        assert!(ordered.contains(&2));
        assert!(!ordered.contains(&3));
        assert_eq!(ordered.iter().collect::<Vec<_>>(), vec![&1, &2, &5]);
    }
}