1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair stored in a [`TreeMap`].
///
/// `Debug` is derived so that maps and entries can be inspected in assertions and logs; the
/// derived impl only applies when both `K` and `V` are themselves `Debug`.
#[derive(Clone, Debug)]
pub struct MapEntry<K, V> {
    key: K,
    value: V,
}
16
/// Owned key wrapper that serves both as an entry's sort key and as the tree's summary type.
///
/// Comparisons are delegated to the wrapped key.
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed key used as a cursor dimension and seek target.
///
/// The default value is `None`, which orders before any `Some(key)`.
#[derive(Clone, Default)]
#[derive(Debug)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23#[derive(Clone)]
24pub struct TreeSet<K>(TreeMap<K, ()>)
25where
26 K: Clone + Debug + Default + Ord;
27
28impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
29 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
30 let tree = SumTree::from_iter(
31 entries
32 .into_iter()
33 .map(|(key, value)| MapEntry { key, value }),
34 &(),
35 );
36 Self(tree)
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.0.is_empty()
41 }
42
43 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
46 if let Some(item) = cursor.item() {
47 if *key == item.key().0 {
48 Some(&item.value)
49 } else {
50 None
51 }
52 } else {
53 None
54 }
55 }
56
57 pub fn insert(&mut self, key: K, value: V) {
58 self.0.insert_or_replace(MapEntry { key, value }, &());
59 }
60
61 pub fn remove(&mut self, key: &K) -> Option<V> {
62 let mut removed = None;
63 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
64 let key = MapKeyRef(Some(key));
65 let mut new_tree = cursor.slice(&key, Bias::Left, &());
66 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
67 removed = Some(cursor.item().unwrap().value.clone());
68 cursor.next(&());
69 }
70 new_tree.push_tree(cursor.suffix(&()), &());
71 drop(cursor);
72 self.0 = new_tree;
73 removed
74 }
75
76 pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> + '_ {
77 self.0.iter().map(|entry| (&entry.key, &entry.value))
78 }
79}
80
81impl<K, V> Default for TreeMap<K, V>
82where
83 K: Clone + Debug + Default + Ord,
84 V: Clone + Debug,
85{
86 fn default() -> Self {
87 Self(Default::default())
88 }
89}
90
91impl<K, V> Item for MapEntry<K, V>
92where
93 K: Clone + Debug + Default + Ord,
94 V: Clone,
95{
96 type Summary = MapKey<K>;
97
98 fn summary(&self) -> Self::Summary {
99 self.key()
100 }
101}
102
103impl<K, V> KeyedItem for MapEntry<K, V>
104where
105 K: Clone + Debug + Default + Ord,
106 V: Clone,
107{
108 type Key = MapKey<K>;
109
110 fn key(&self) -> Self::Key {
111 MapKey(self.key.clone())
112 }
113}
114
115impl<K> Summary for MapKey<K>
116where
117 K: Clone + Debug + Default,
118{
119 type Context = ();
120
121 fn add_summary(&mut self, summary: &Self, _: &()) {
122 *self = summary.clone()
123 }
124}
125
126impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
127where
128 K: Clone + Debug + Default + Ord,
129{
130 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
131 self.0 = Some(&summary.0)
132 }
133}
134
135impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
136where
137 K: Clone + Debug + Default + Ord,
138{
139 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
140 self.0.cmp(&cursor_location.0)
141 }
142}
143
144impl<K> Default for TreeSet<K>
145where
146 K: Clone + Debug + Default + Ord,
147{
148 fn default() -> Self {
149 Self(Default::default())
150 }
151}
152
153impl<K> TreeSet<K>
154where
155 K: Clone + Debug + Default + Ord,
156{
157 pub fn from_ordered_entries(entries: impl IntoIterator<Item = K>) -> Self {
158 Self(TreeMap::from_ordered_entries(
159 entries.into_iter().map(|key| (key, ())),
160 ))
161 }
162
163 pub fn insert(&mut self, key: K) {
164 self.0.insert(key, ());
165 }
166
167 pub fn contains(&self, key: &K) -> bool {
168 self.0.get(key).is_some()
169 }
170
171 pub fn iter(&self) -> impl Iterator<Item = &K> + '_ {
172 self.0.iter().map(|(k, _)| k)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises insert/get/remove/iter on a small map, including `remove`'s return value.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert!(map.is_empty());
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert!(!map.is_empty());
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    /// Inserting an existing key must overwrite its value rather than add a duplicate.
    #[test]
    fn test_insert_replaces_existing_value() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(1, "z");
        assert_eq!(map.get(&1), Some(&"z"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z")]);
    }

    /// Removing absent keys (before, between, and after stored keys) is a no-op.
    #[test]
    fn test_remove_missing_key() {
        let mut empty = TreeMap::<i32, &str>::default();
        assert_eq!(empty.remove(&1), None);

        let mut map = TreeMap::from_ordered_entries(vec![(1, "a"), (3, "c")]);
        assert_eq!(map.remove(&0), None);
        assert_eq!(map.remove(&2), None);
        assert_eq!(map.remove(&4), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
    }

    /// A map built from sorted entries supports lookups of every key.
    #[test]
    fn test_from_ordered_entries() {
        let map = TreeMap::from_ordered_entries(vec![(1, "a"), (2, "b"), (5, "e")]);
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&5), Some(&"e"));
        assert_eq!(map.get(&3), None);
    }

    /// Covers the `TreeSet` wrapper: insert, duplicate insert, contains, iter.
    #[test]
    fn test_tree_set() {
        let mut set = TreeSet::default();
        set.insert(3);
        set.insert(1);
        set.insert(3);
        assert!(set.contains(&1));
        assert!(set.contains(&3));
        assert!(!set.contains(&2));
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &3]);

        let set = TreeSet::from_ordered_entries(vec![1, 2, 3]);
        assert_eq!(set.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
    }
}