1use std::{cmp::Ordering, fmt::Debug, iter};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair, stored as one item of the tree backing a [`TreeMap`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MapEntry<K, V> {
    /// Key identifying this entry; a clone of it serves as the entry's summary.
    key: K,
    /// Value associated with `key`.
    value: V,
}
16
/// Owned key wrapper that serves as both the summary and the key of a [`MapEntry`].
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed counterpart of [`MapKey`], used as a cursor dimension and as a seek target.
///
/// The default value holds `None`; adding a summary replaces it with a reference
/// to that summary's key.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 /// Returns the key-value pair with the greatest key less than or equal to the given key.
77 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
78 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
79 let key = MapKeyRef(Some(key));
80 cursor.seek(&key, Bias::Right, &());
81 cursor.prev(&());
82 cursor.item().map(|item| (&item.key, &item.value))
83 }
84
85 pub fn remove_between(&mut self, from: &K, until: &K) {
86 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
87 let from_key = MapKeyRef(Some(from));
88 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
89 let until_key = MapKeyRef(Some(until));
90 cursor.seek_forward(&until_key, Bias::Left, &());
91 new_tree.push_tree(cursor.suffix(&()), &());
92 drop(cursor);
93 self.0 = new_tree;
94 }
95
96 pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
97 where
98 F: FnMut(&K, &V) -> bool,
99 {
100 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101 let from_key = MapKeyRef(Some(from));
102 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103 while let Some(item) = cursor.item() {
104 if !f(&item.key, &item.value) {
105 break;
106 }
107 cursor.next(&());
108 }
109 new_tree.push_tree(cursor.suffix(&()), &());
110 drop(cursor);
111 self.0 = new_tree;
112 }
113
114 pub fn get_from_while<'tree, F>(
115 &'tree self,
116 from: &'tree K,
117 mut f: F,
118 ) -> impl Iterator<Item = (&K, &V)> + '_
119 where
120 F: FnMut(&K, &K, &V) -> bool + 'tree,
121 {
122 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
123 let from_key = MapKeyRef(Some(from));
124 cursor.seek(&from_key, Bias::Left, &());
125
126 iter::from_fn(move || {
127 let result = cursor.item().and_then(|item| {
128 (f(from, &item.key, &item.value)).then(|| (&item.key, &item.value))
129 });
130 cursor.next(&());
131 result
132 })
133 }
134
135 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
136 where
137 F: FnOnce(&mut V) -> T,
138 {
139 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
140 let key = MapKeyRef(Some(key));
141 let mut new_tree = cursor.slice(&key, Bias::Left, &());
142 let mut result = None;
143 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
144 let mut updated = cursor.item().unwrap().clone();
145 result = Some(f(&mut updated.value));
146 new_tree.push(updated, &());
147 cursor.next(&());
148 }
149 new_tree.push_tree(cursor.suffix(&()), &());
150 drop(cursor);
151 self.0 = new_tree;
152 result
153 }
154
155 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
156 let mut new_map = SumTree::<MapEntry<K, V>>::default();
157
158 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
159 cursor.next(&());
160 while let Some(item) = cursor.item() {
161 if predicate(&item.key, &item.value) {
162 new_map.push(item.clone(), &());
163 }
164 cursor.next(&());
165 }
166 drop(cursor);
167
168 self.0 = new_map;
169 }
170
171 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
172 self.0.iter().map(|entry| (&entry.key, &entry.value))
173 }
174
175 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
176 self.0.iter().map(|entry| &entry.value)
177 }
178
179 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
180 let edits = other
181 .iter()
182 .map(|(key, value)| {
183 Edit::Insert(MapEntry {
184 key: key.to_owned(),
185 value: value.to_owned(),
186 })
187 })
188 .collect();
189
190 self.0.edit(edits, &());
191 }
192}
193
194impl<K, V> Default for TreeMap<K, V>
195where
196 K: Clone + Debug + Default + Ord,
197 V: Clone + Debug,
198{
199 fn default() -> Self {
200 Self(Default::default())
201 }
202}
203
204impl<K, V> Item for MapEntry<K, V>
205where
206 K: Clone + Debug + Default + Ord,
207 V: Clone,
208{
209 type Summary = MapKey<K>;
210
211 fn summary(&self) -> Self::Summary {
212 self.key()
213 }
214}
215
216impl<K, V> KeyedItem for MapEntry<K, V>
217where
218 K: Clone + Debug + Default + Ord,
219 V: Clone,
220{
221 type Key = MapKey<K>;
222
223 fn key(&self) -> Self::Key {
224 MapKey(self.key.clone())
225 }
226}
227
228impl<K> Summary for MapKey<K>
229where
230 K: Clone + Debug + Default,
231{
232 type Context = ();
233
234 fn add_summary(&mut self, summary: &Self, _: &()) {
235 *self = summary.clone()
236 }
237}
238
239impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
240where
241 K: Clone + Debug + Default + Ord,
242{
243 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
244 self.0 = Some(&summary.0)
245 }
246}
247
248impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
249where
250 K: Clone + Debug + Default + Ord,
251{
252 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
253 self.0.cmp(&cursor_location.0)
254 }
255}
256
257impl<K> Default for TreeSet<K>
258where
259 K: Clone + Debug + Default + Ord,
260{
261 fn default() -> Self {
262 Self(Default::default())
263 }
264}
265
266impl<K> TreeSet<K>
267where
268 K: Clone + Debug + Default + Ord,
269{
270 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
271 Self(TreeMap::from_ordered_entries(
272 entries.into_iter().map(|key| (key, ())),
273 ))
274 }
275
276 pub fn insert(&mut self, key: K) {
277 self.0.insert(key, ());
278 }
279
280 pub fn contains(&self, key: &K) -> bool {
281 self.0.get(key).is_some()
282 }
283
284 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
285 self.0.iter().map(|(k, _)| k)
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert, get, iter, closest, remove and retain on a small map.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    /// Entries between the two bounds are removed; entries outside are kept.
    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted inside the removed range.
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
    }

    /// Entries from the start key are removed while the predicate holds.
    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with("ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // Check the keys that were actually inserted and match the predicate.
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
    }

    /// Entries from the start key are yielded, in order, while the predicate holds.
    #[test]
    fn test_get_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // The predicate receives (from, key, value); only the entry key matters here.
        let result = map
            .get_from_while(&"ba", |_, key, _| key.starts_with("ba"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .get_from_while(&"c", |_, key, _| key.starts_with("c"))
            .collect::<Vec<_>>();

        assert_eq!(result, vec![(&"c", &5)]);
    }

    /// Inserting another map overwrites equal keys and adds the new ones.
    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }
}