1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair as stored in a [`TreeMap`]'s underlying tree.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// The entry's key.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Owned wrapper around an entry's key; used both as a [`MapEntry`]'s key and as its summary.
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapKey<K>(K);
19
/// A borrowed key, used as the cursor dimension and as the seek target.
///
/// The `Default` value holds `None`, meaning no entry has been accounted for yet.
#[derive(Debug, Default, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 /// Returns the key-value pair with the greatest key less than or equal to the given key.
77 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
78 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
79 let key = MapKeyRef(Some(key));
80 cursor.seek(&key, Bias::Right, &());
81 cursor.prev(&());
82 cursor.item().map(|item| (&item.key, &item.value))
83 }
84
85 pub fn remove_between(&mut self, from: &K, until: &K)
86 {
87 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
88 let from_key = MapKeyRef(Some(from));
89 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
90 let until_key = MapKeyRef(Some(until));
91 cursor.seek_forward(&until_key, Bias::Left, &());
92 new_tree.push_tree(cursor.suffix(&()), &());
93 drop(cursor);
94 self.0 = new_tree;
95 }
96
97 pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
98 where F: FnMut(&K, &V) -> bool
99 {
100 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
101 let from_key = MapKeyRef(Some(from));
102 let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
103 while let Some(item) = cursor.item() {
104 if !f(&item.key, &item.value) {
105 break;
106 }
107 cursor.next(&());
108 }
109 new_tree.push_tree(cursor.suffix(&()), &());
110 drop(cursor);
111 self.0 = new_tree;
112 }
113
114
115 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
116 where
117 F: FnOnce(&mut V) -> T,
118 {
119 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
120 let key = MapKeyRef(Some(key));
121 let mut new_tree = cursor.slice(&key, Bias::Left, &());
122 let mut result = None;
123 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
124 let mut updated = cursor.item().unwrap().clone();
125 result = Some(f(&mut updated.value));
126 new_tree.push(updated, &());
127 cursor.next(&());
128 }
129 new_tree.push_tree(cursor.suffix(&()), &());
130 drop(cursor);
131 self.0 = new_tree;
132 result
133 }
134
135 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
136 let mut new_map = SumTree::<MapEntry<K, V>>::default();
137
138 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
139 cursor.next(&());
140 while let Some(item) = cursor.item() {
141 if predicate(&item.key, &item.value) {
142 new_map.push(item.clone(), &());
143 }
144 cursor.next(&());
145 }
146 drop(cursor);
147
148 self.0 = new_map;
149 }
150
151 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
152 self.0.iter().map(|entry| (&entry.key, &entry.value))
153 }
154
155 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
156 self.0.iter().map(|entry| &entry.value)
157 }
158}
159
160impl<K, V> Default for TreeMap<K, V>
161where
162 K: Clone + Debug + Default + Ord,
163 V: Clone + Debug,
164{
165 fn default() -> Self {
166 Self(Default::default())
167 }
168}
169
170impl<K, V> Item for MapEntry<K, V>
171where
172 K: Clone + Debug + Default + Ord,
173 V: Clone,
174{
175 type Summary = MapKey<K>;
176
177 fn summary(&self) -> Self::Summary {
178 self.key()
179 }
180}
181
182impl<K, V> KeyedItem for MapEntry<K, V>
183where
184 K: Clone + Debug + Default + Ord,
185 V: Clone,
186{
187 type Key = MapKey<K>;
188
189 fn key(&self) -> Self::Key {
190 MapKey(self.key.clone())
191 }
192}
193
194impl<K> Summary for MapKey<K>
195where
196 K: Clone + Debug + Default,
197{
198 type Context = ();
199
200 fn add_summary(&mut self, summary: &Self, _: &()) {
201 *self = summary.clone()
202 }
203}
204
205impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
206where
207 K: Clone + Debug + Default + Ord,
208{
209 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
210 self.0 = Some(&summary.0)
211 }
212}
213
214impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
215where
216 K: Clone + Debug + Default + Ord,
217{
218 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
219 self.0.cmp(&cursor_location.0)
220 }
221}
222
223impl<K> Default for TreeSet<K>
224where
225 K: Clone + Debug + Default + Ord,
226{
227 fn default() -> Self {
228 Self(Default::default())
229 }
230}
231
232impl<K> TreeSet<K>
233where
234 K: Clone + Debug + Default + Ord,
235{
236 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
237 Self(TreeMap::from_ordered_entries(
238 entries.into_iter().map(|key| (key, ())),
239 ))
240 }
241
242 pub fn insert(&mut self, key: K) {
243 self.0.insert(key, ());
244 }
245
246 pub fn contains(&self, key: &K) -> bool {
247 self.0.get(key).is_some()
248 }
249
250 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
251 self.0.iter().map(|(k, _)| k)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        // `remove` reports the removed value, and `None` once the key is gone.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
        assert!(map.is_empty());

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, 10), (2, 20)]);

        assert_eq!(
            map.update(&2, |value| {
                *value += 5;
                *value
            }),
            Some(25)
        );
        assert_eq!(map.get(&2), Some(&25));

        // Updating a missing key neither calls back into the map nor inserts.
        assert_eq!(map.update(&3, |value| *value), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &10), (&2, &25)]);
    }

    #[test]
    fn test_remove_between() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_between(&"ba", &"bb");

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // These are the keys actually inserted inside ["ba", "bb").
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_remove_between_keeps_until() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]);

        // The range is half-open: `from` is removed, `until` is kept.
        map.remove_between(&2, &4);

        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &'a'), (&4, &'d')]);
    }

    #[test]
    fn test_remove_from_while() {
        let mut map = TreeMap::default();

        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("baa", 3);
        map.insert("baaab", 4);
        map.insert("c", 5);

        map.remove_from_while(&"ba", |key, _| key.starts_with(&"ba"));

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        // These are the keys actually inserted that start with "ba".
        assert_eq!(map.get(&"baa"), None);
        assert_eq!(map.get(&"baaab"), None);
        assert_eq!(map.get(&"c"), Some(&5));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&"a", &1), (&"b", &2), (&"c", &5)]
        );
    }

    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(2);
        set.insert(1);

        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);

        let set = TreeSet::from_ordered_entries(vec![1, 5, 9]);
        assert!(set.contains(&5));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 5, 9]);
    }
}