1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair, the item type stored in a [`TreeMap`]'s tree.
///
/// `Debug` is derived so entries can be printed when inspecting a map; the
/// derive only applies when both `K` and `V` are `Debug`.
#[derive(Clone, Debug)]
pub struct MapEntry<K, V> {
    /// Key the entry is stored and looked up under.
    key: K,
    /// Value associated with `key`.
    value: V,
}
16
/// Owned key wrapper used as the summary type of a [`MapEntry`].
///
/// Comparison and ordering are derived, so they follow those of the wrapped
/// key.
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed counterpart of [`MapKey`], used as a cursor dimension and as a
/// seek target.
///
/// The default value holds `None`; it becomes `Some` once a summary has been
/// added to it.
#[derive(Debug, Default, Clone)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
24 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
25 let tree = SumTree::from_iter(
26 entries
27 .into_iter()
28 .map(|(key, value)| MapEntry { key, value }),
29 &(),
30 );
31 Self(tree)
32 }
33
34 pub fn is_empty(&self) -> bool {
35 self.0.is_empty()
36 }
37
38 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
39 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
40 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
41 if let Some(item) = cursor.item() {
42 if *key == item.key().0 {
43 Some(&item.value)
44 } else {
45 None
46 }
47 } else {
48 None
49 }
50 }
51
52 pub fn insert(&mut self, key: K, value: V) {
53 self.0.insert_or_replace(MapEntry { key, value }, &());
54 }
55
56 pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
57 let mut removed = None;
58 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
59 let key = MapKeyRef(Some(key));
60 let mut new_tree = cursor.slice(&key, Bias::Left, &());
61 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
62 removed = Some(cursor.item().unwrap().value.clone());
63 cursor.next(&());
64 }
65 new_tree.push_tree(cursor.suffix(&()), &());
66 drop(cursor);
67 self.0 = new_tree;
68 removed
69 }
70
71 pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
72 self.0.iter().map(|entry| (&entry.key, &entry.value))
73 }
74}
75
76impl<K, V> Default for TreeMap<K, V>
77where
78 K: Clone + Debug + Default + Ord,
79 V: Clone + Debug,
80{
81 fn default() -> Self {
82 Self(Default::default())
83 }
84}
85
86impl<K, V> Item for MapEntry<K, V>
87where
88 K: Clone + Debug + Default + Ord,
89 V: Clone,
90{
91 type Summary = MapKey<K>;
92
93 fn summary(&self) -> Self::Summary {
94 self.key()
95 }
96}
97
98impl<K, V> KeyedItem for MapEntry<K, V>
99where
100 K: Clone + Debug + Default + Ord,
101 V: Clone,
102{
103 type Key = MapKey<K>;
104
105 fn key(&self) -> Self::Key {
106 MapKey(self.key.clone())
107 }
108}
109
110impl<K> Summary for MapKey<K>
111where
112 K: Clone + Debug + Default,
113{
114 type Context = ();
115
116 fn add_summary(&mut self, summary: &Self, _: &()) {
117 *self = summary.clone()
118 }
119}
120
121impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
122where
123 K: Clone + Debug + Default + Ord,
124{
125 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
126 self.0 = Some(&summary.0)
127 }
128}
129
130impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
131where
132 K: Clone + Debug + Default + Ord,
133{
134 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
135 self.0.cmp(&cursor_location.0)
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshots the map's entries in iteration order.
    fn entries<'a>(map: &'a TreeMap<i32, &'static str>) -> Vec<(&'a i32, &'a &'static str)> {
        map.iter().collect()
    }

    /// Inserts keys out of order, then removes them one by one, checking
    /// lookups and iteration order after every step.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(entries(&map), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(entries(&map), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(entries(&map), vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]);

        map.remove(&2);
        assert_eq!(map.get(&2), None);
        assert_eq!(entries(&map), vec![(&1, &"a"), (&3, &"c")]);

        map.remove(&3);
        assert_eq!(map.get(&3), None);
        assert_eq!(entries(&map), vec![(&1, &"a")]);

        map.remove(&1);
        assert_eq!(map.get(&1), None);
        assert_eq!(entries(&map), vec![]);
    }
}