1use std::{cmp::Ordering, fmt::Debug, iter};
2
3use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// One key/value pair stored as an item in a [`TreeMap`]'s underlying tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered and looked up by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Wrapper around an entry's key; used as the tree summary and key type for map entries.
#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed form of [`MapKey`] used as a cursor dimension and seek target; `None` is the
/// default (nothing accumulated yet).
#[derive(Debug, Default, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 /// Returns the key-value pair with the greatest key less than or equal to the given key.
77 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
78 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
79 let key = MapKeyRef(Some(key));
80 cursor.seek(&key, Bias::Right, &());
81 cursor.prev(&());
82 cursor.item().map(|item| (&item.key, &item.value))
83 }
84
85 pub fn remove_between(&mut self, from: &K, until: &K) {
86 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
87 let from_key = MapKeyRef(Some(from));
88 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
89 let until_key = MapKeyRef(Some(until));
90 cursor.seek_forward(&until_key, Bias::Left, &());
91 new_tree.push_tree(cursor.suffix(&()), &());
92 drop(cursor);
93 self.0 = new_tree;
94 }
95
96 pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
97 where
98 F: FnMut(&K, &V) -> bool,
99 {
100 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101 let from_key = MapKeyRef(Some(from));
102 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103 while let Some(item) = cursor.item() {
104 if !f(&item.key, &item.value) {
105 break;
106 }
107 cursor.next(&());
108 }
109 new_tree.push_tree(cursor.suffix(&()), &());
110 drop(cursor);
111 self.0 = new_tree;
112 }
113
114
115 pub fn get_from_while<'tree, F>(&'tree self, from: &'tree K, mut f: F) -> impl Iterator<Item = (&K, &V)> + '_
116 where
117 F: FnMut(&K, &K, &V) -> bool + 'tree,
118 {
119 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
120 let from_key = MapKeyRef(Some(from));
121 cursor.seek(&from_key, Bias::Left, &());
122
123 iter::from_fn(move || {
124 let result = cursor.item().and_then(|item| {
125 (f(from, &item.key, &item.value))
126 .then(|| (&item.key, &item.value))
127 });
128 cursor.next(&());
129 result
130 })
131 }
132
133
134 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
135 where
136 F: FnOnce(&mut V) -> T,
137 {
138 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
139 let key = MapKeyRef(Some(key));
140 let mut new_tree = cursor.slice(&key, Bias::Left, &());
141 let mut result = None;
142 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
143 let mut updated = cursor.item().unwrap().clone();
144 result = Some(f(&mut updated.value));
145 new_tree.push(updated, &());
146 cursor.next(&());
147 }
148 new_tree.push_tree(cursor.suffix(&()), &());
149 drop(cursor);
150 self.0 = new_tree;
151 result
152 }
153
154 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
155 let mut new_map = SumTree::<MapEntry<K, V>>::default();
156
157 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
158 cursor.next(&());
159 while let Some(item) = cursor.item() {
160 if predicate(&item.key, &item.value) {
161 new_map.push(item.clone(), &());
162 }
163 cursor.next(&());
164 }
165 drop(cursor);
166
167 self.0 = new_map;
168 }
169
170 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
171 self.0.iter().map(|entry| (&entry.key, &entry.value))
172 }
173
174 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
175 self.0.iter().map(|entry| &entry.value)
176 }
177
178 pub fn insert_tree(&mut self, other: TreeMap<K, V>) {
179 let edits = other
180 .iter()
181 .map(|(key, value)| {
182 Edit::Insert(MapEntry {
183 key: key.to_owned(),
184 value: value.to_owned(),
185 })
186 })
187 .collect();
188
189 self.0.edit(edits, &());
190 }
191}
192
193impl<K, V> Default for TreeMap<K, V>
194where
195 K: Clone + Debug + Default + Ord,
196 V: Clone + Debug,
197{
198 fn default() -> Self {
199 Self(Default::default())
200 }
201}
202
203impl<K, V> Item for MapEntry<K, V>
204where
205 K: Clone + Debug + Default + Ord,
206 V: Clone,
207{
208 type Summary = MapKey<K>;
209
210 fn summary(&self) -> Self::Summary {
211 self.key()
212 }
213}
214
215impl<K, V> KeyedItem for MapEntry<K, V>
216where
217 K: Clone + Debug + Default + Ord,
218 V: Clone,
219{
220 type Key = MapKey<K>;
221
222 fn key(&self) -> Self::Key {
223 MapKey(self.key.clone())
224 }
225}
226
227impl<K> Summary for MapKey<K>
228where
229 K: Clone + Debug + Default,
230{
231 type Context = ();
232
233 fn add_summary(&mut self, summary: &Self, _: &()) {
234 *self = summary.clone()
235 }
236}
237
238impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
239where
240 K: Clone + Debug + Default + Ord,
241{
242 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
243 self.0 = Some(&summary.0)
244 }
245}
246
247impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
248where
249 K: Clone + Debug + Default + Ord,
250{
251 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
252 self.0.cmp(&cursor_location.0)
253 }
254}
255
256impl<K> Default for TreeSet<K>
257where
258 K: Clone + Debug + Default + Ord,
259{
260 fn default() -> Self {
261 Self(Default::default())
262 }
263}
264
265impl<K> TreeSet<K>
266where
267 K: Clone + Debug + Default + Ord,
268{
269 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
270 Self(TreeMap::from_ordered_entries(
271 entries.into_iter().map(|key| (key, ())),
272 ))
273 }
274
275 pub fn insert(&mut self, key: K) {
276 self.0.insert(key, ());
277 }
278
279 pub fn contains(&self, key: &K) -> bool {
280 self.0.get(key).is_some()
281 }
282
283 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
284 self.0.iter().map(|(k, _)| k)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        // Check the keys that were actually inserted and fall inside `"ba".."bb"`.
        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
    }

    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with(&"ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
    }

    #[test]
    fn test_get_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        // The predicate receives `(from, key, value)`.
        let result = map
            .get_from_while(&"ba", |_, key, _| key.starts_with(&"ba"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"baa", &3), (&"baaab", &4)]);

        let result = map
            .get_from_while(&"c", |_, key, _| key.starts_with(&"c"))
            .collect::<Vec<_>>();
        assert_eq!(result, vec![(&"c", &5)]);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::default();
        map.insert(1, 10);
        map.insert(2, 20);

        let updated = map.update(&1, |value| {
            *value += 1;
            *value
        });
        assert_eq!(updated, Some(11));
        assert_eq!(map.get(&1), Some(&11));

        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &11), (&2, &20)]);
    }

    #[test]
    fn test_insert_tree() {
        let mut map = TreeMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        let mut other = TreeMap::default();
        other.insert("a", 2);
        other.insert("b", 2);
        other.insert("d", 4);

        map.insert_tree(other);

        assert_eq!(map.iter().count(), 4);
        assert_eq!(map.get(&"a"), Some(&2));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
        assert_eq!(map.get(&"d"), Some(&4));
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::from_ordered_entries(vec![1, 3]);
        set.insert(2);

        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);
    }
}