1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
/// A key-value map stored as a [`SumTree`] of [`MapEntry`] items.
///
/// Entries are located by seeking a tree cursor with a [`MapKeyRef`] target,
/// so keys must be totally ordered (`Ord`).
#[derive(Clone, Debug)]
pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
where
    K: Clone + Debug + Default + Ord,
    V: Clone + Debug;
10
/// A single key-value pair, stored as an item in the tree backing [`TreeMap`].
#[derive(Clone, Debug)]
pub struct MapEntry<K, V> {
    // Key the entry is stored under; cloned into a `MapKey` for the item's
    // summary and keyed-item key.
    key: K,
    // Value associated with `key`.
    value: V,
}
16
/// Owned key wrapper used as both the `Item::Summary` and the `KeyedItem::Key`
/// of [`MapEntry`].
///
/// When summaries are combined, the summary being added replaces the current
/// one (see this type's `Summary::add_summary`).
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed key used as the cursor dimension and as the seek target for
/// cursors over a [`TreeMap`]'s tree.
///
/// The derived default is `None`; adding a summary sets it to `Some` reference
/// to that summary's key.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
77 where
78 F: FnOnce(&mut V) -> T,
79 {
80 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
81 let key = MapKeyRef(Some(key));
82 let mut new_tree = cursor.slice(&key, Bias::Left, &());
83 let mut result = None;
84 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
85 let mut updated = cursor.item().unwrap().clone();
86 result = Some(f(&mut updated.value));
87 new_tree.push(updated, &());
88 cursor.next(&());
89 }
90 new_tree.push_tree(cursor.suffix(&()), &());
91 drop(cursor);
92 self.0 = new_tree;
93 result
94 }
95
96 pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut predicate: F) {
97 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
98 cursor.seek(&MapKeyRef(None), Bias::Left, &());
99
100 let mut new_map = SumTree::<MapEntry<K, V>>::default();
101 if let Some(item) = cursor.item() {
102 if predicate(&item.key, &item.value) {
103 new_map.push(item.clone(), &());
104 }
105 cursor.next(&());
106 }
107 drop(cursor);
108
109 self.0 = new_map;
110 }
111
112 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
113 self.0.iter().map(|entry| (&entry.key, &entry.value))
114 }
115}
116
117impl<K, V> Default for TreeMap<K, V>
118where
119 K: Clone + Debug + Default + Ord,
120 V: Clone + Debug,
121{
122 fn default() -> Self {
123 Self(Default::default())
124 }
125}
126
127impl<K, V> Item for MapEntry<K, V>
128where
129 K: Clone + Debug + Default + Ord,
130 V: Clone,
131{
132 type Summary = MapKey<K>;
133
134 fn summary(&self) -> Self::Summary {
135 self.key()
136 }
137}
138
139impl<K, V> KeyedItem for MapEntry<K, V>
140where
141 K: Clone + Debug + Default + Ord,
142 V: Clone,
143{
144 type Key = MapKey<K>;
145
146 fn key(&self) -> Self::Key {
147 MapKey(self.key.clone())
148 }
149}
150
151impl<K> Summary for MapKey<K>
152where
153 K: Clone + Debug + Default,
154{
155 type Context = ();
156
157 fn add_summary(&mut self, summary: &Self, _: &()) {
158 *self = summary.clone()
159 }
160}
161
162impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
163where
164 K: Clone + Debug + Default + Ord,
165{
166 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
167 self.0 = Some(&summary.0)
168 }
169}
170
171impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
172where
173 K: Clone + Debug + Default + Ord,
174{
175 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
176 self.0.cmp(&cursor_location.0)
177 }
178}
179
180impl<K> Default for TreeSet<K>
181where
182 K: Clone + Debug + Default + Ord,
183{
184 fn default() -> Self {
185 Self(Default::default())
186 }
187}
188
189impl<K> TreeSet<K>
190where
191 K: Clone + Debug + Default + Ord,
192{
193 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
194 Self(TreeMap::from_ordered_entries(
195 entries.into_iter().map(|key| (key, ())),
196 ))
197 }
198
199 pub fn insert(&mut self, key: K) {
200 self.0.insert(key, ());
201 }
202
203 pub fn contains(&self, key: &K) -> bool {
204 self.0.get(key).is_some()
205 }
206
207 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
208 self.0.iter().map(|(k, _)| k)
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert, get, remove and iteration order on a small map.
    #[test]
    fn test_basic() {
        // Snapshot of the map's entries in iteration order.
        fn entries<'a>(map: &'a TreeMap<i32, &'static str>) -> Vec<(&'a i32, &'a &'static str)> {
            map.iter().collect()
        }

        let mut map = TreeMap::default();
        assert_eq!(entries(&map), vec![]);

        // Keys are inserted out of order on purpose.
        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(entries(&map), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(entries(&map), vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]);

        // Remove from the middle, then the end, then the last remaining entry.
        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(entries(&map), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(entries(&map), vec![]);
    }
}