1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone, Debug, PartialEq, Eq)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// One key/value pair held by a [`TreeMap`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MapEntry<K, V> {
    /// Key the entry is ordered by.
    key: K,
    /// Value associated with `key`.
    value: V,
}
16
/// Owned key wrapper used as the summary type of [`MapEntry`]; combining
/// summaries keeps the most recently added key.
#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)]
pub struct MapKey<K>(K);
19
/// Borrowed cursor position over a [`TreeMap`]: `None` before any entry,
/// otherwise a reference to the key most recently passed.
#[derive(Debug, Clone, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 /// Returns the key-value pair with the greatest key less than or equal to the given key.
77 pub fn closest(&self, key: &K) -> Option<(&K, &V)> {
78 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
79 let key = MapKeyRef(Some(key));
80 cursor.seek(&key, Bias::Right, &());
81 cursor.prev(&());
82 cursor.item().map(|item| (&item.key, &item.value))
83 }
84
85 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
86 where
87 F: FnOnce(&mut V) -> T,
88 {
89 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
90 let key = MapKeyRef(Some(key));
91 let mut new_tree = cursor.slice(&key, Bias::Left, &());
92 let mut result = None;
93 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
94 let mut updated = cursor.item().unwrap().clone();
95 result = Some(f(&mut updated.value));
96 new_tree.push(updated, &());
97 cursor.next(&());
98 }
99 new_tree.push_tree(cursor.suffix(&()), &());
100 drop(cursor);
101 self.0 = new_tree;
102 result
103 }
104
105 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
106 let mut new_map = SumTree::<MapEntry<K, V>>::default();
107
108 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
109 cursor.next(&());
110 while let Some(item) = cursor.item() {
111 if predicate(&item.key, &item.value) {
112 new_map.push(item.clone(), &());
113 }
114 cursor.next(&());
115 }
116 drop(cursor);
117
118 self.0 = new_map;
119 }
120
121 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
122 self.0.iter().map(|entry| (&entry.key, &entry.value))
123 }
124
125 pub fn values(&self) -> impl Iterator<Item = &V> + '_ {
126 self.0.iter().map(|entry| &entry.value)
127 }
128}
129
130impl<K, V> Default for TreeMap<K, V>
131where
132 K: Clone + Debug + Default + Ord,
133 V: Clone + Debug,
134{
135 fn default() -> Self {
136 Self(Default::default())
137 }
138}
139
140impl<K, V> Item for MapEntry<K, V>
141where
142 K: Clone + Debug + Default + Ord,
143 V: Clone,
144{
145 type Summary = MapKey<K>;
146
147 fn summary(&self) -> Self::Summary {
148 self.key()
149 }
150}
151
152impl<K, V> KeyedItem for MapEntry<K, V>
153where
154 K: Clone + Debug + Default + Ord,
155 V: Clone,
156{
157 type Key = MapKey<K>;
158
159 fn key(&self) -> Self::Key {
160 MapKey(self.key.clone())
161 }
162}
163
164impl<K> Summary for MapKey<K>
165where
166 K: Clone + Debug + Default,
167{
168 type Context = ();
169
170 fn add_summary(&mut self, summary: &Self, _: &()) {
171 *self = summary.clone()
172 }
173}
174
175impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
176where
177 K: Clone + Debug + Default + Ord,
178{
179 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
180 self.0 = Some(&summary.0)
181 }
182}
183
184impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
185where
186 K: Clone + Debug + Default + Ord,
187{
188 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
189 self.0.cmp(&cursor_location.0)
190 }
191}
192
193impl<K> Default for TreeSet<K>
194where
195 K: Clone + Debug + Default + Ord,
196{
197 fn default() -> Self {
198 Self(Default::default())
199 }
200}
201
202impl<K> TreeSet<K>
203where
204 K: Clone + Debug + Default + Ord,
205{
206 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
207 Self(TreeMap::from_ordered_entries(
208 entries.into_iter().map(|key| (key, ())),
209 ))
210 }
211
212 pub fn insert(&mut self, key: K) {
213 self.0.insert(key, ());
214 }
215
216 pub fn contains(&self, key: &K) -> bool {
217 self.0.get(key).is_some()
218 }
219
220 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
221 self.0.iter().map(|(k, _)| k)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.closest(&0), None);
        assert_eq!(map.closest(&1), Some((&1, &"a")));
        assert_eq!(map.closest(&10), Some((&3, &"c")));

        // `remove` hands back the removed value, and `None` for absent keys.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.closest(&2), Some((&1, &"a")));

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
        assert!(map.is_empty());

        map.insert(4, "d");
        map.insert(5, "e");
        map.insert(6, "f");
        map.retain(|key, _| *key % 2 == 0);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&4, &"d"), (&6, &"f")]);
        assert_eq!(map.values().collect::<Vec<_>>(), vec![&"d", &"f"]);

        map.retain(|_, _| false);
        assert!(map.is_empty());
    }

    #[test]
    fn test_update() {
        let mut map = TreeMap::from_ordered_entries(vec![(1, 10), (2, 20), (3, 30)]);
        assert_eq!(
            map.update(&2, |value| {
                *value += 1;
                *value
            }),
            Some(21)
        );
        assert_eq!(map.get(&2), Some(&21));

        // Keys below and above the existing range are absent: no call, map untouched.
        assert_eq!(map.update(&0, |value| *value), None);
        assert_eq!(map.update(&4, |value| *value), None);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &10), (&2, &21), (&3, &30)]
        );
    }

    #[test]
    fn test_tree_set() {
        let mut set: TreeSet<i32> = TreeSet::default();
        assert!(!set.contains(&1));

        set.insert(3);
        set.insert(1);
        set.insert(2);
        set.insert(2);
        assert!(set.contains(&2));
        assert!(!set.contains(&4));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2, &3]);

        let ordered = TreeSet::from_ordered_entries(vec![5, 7]);
        assert!(ordered.contains(&5));
        assert!(!ordered.contains(&6));
        assert_eq!(ordered.iter().collect::<Vec<_>>(), vec![&5, &7]);
    }
}