1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair stored as one item of a map's tree.
#[derive(Clone)]
pub struct MapEntry<K, V> {
    /// The key this entry is ordered and looked up by.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// An owned copy of an entry's key, used as the tree's summary type.
///
/// Ordering and equality are those of the wrapped key.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// A borrowed key, used as the cursor dimension and seek target when walking the tree.
///
/// The `Default` value wraps `None`; presumably this stands for the position before any
/// entry has been passed (confirm against the sum_tree cursor).
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
24 pub fn from_ordered_entries(entries: impl IntoIterator<Item = (K, V)>) -> Self {
25 let tree = SumTree::from_iter(
26 entries
27 .into_iter()
28 .map(|(key, value)| MapEntry { key, value }),
29 &(),
30 );
31 Self(tree)
32 }
33
34 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
35 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
36 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
37 if let Some(item) = cursor.item() {
38 if *key == item.key().0 {
39 Some(&item.value)
40 } else {
41 None
42 }
43 } else {
44 None
45 }
46 }
47
48 pub fn insert(&mut self, key: K, value: V) {
49 self.0.insert_or_replace(MapEntry { key, value }, &());
50 }
51
52 pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
53 let mut removed = None;
54 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
55 let key = MapKeyRef(Some(key));
56 let mut new_tree = cursor.slice(&key, Bias::Left, &());
57 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
58 removed = Some(cursor.item().unwrap().value.clone());
59 cursor.next(&());
60 }
61 new_tree.push_tree(cursor.suffix(&()), &());
62 drop(cursor);
63 self.0 = new_tree;
64 removed
65 }
66
67 pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
68 self.0.iter().map(|entry| (&entry.key, &entry.value))
69 }
70}
71
72impl<K, V> Default for TreeMap<K, V>
73where
74 K: Clone + Debug + Default + Ord,
75 V: Clone + Debug,
76{
77 fn default() -> Self {
78 Self(Default::default())
79 }
80}
81
82impl<K, V> Item for MapEntry<K, V>
83where
84 K: Clone + Debug + Default + Ord,
85 V: Clone,
86{
87 type Summary = MapKey<K>;
88
89 fn summary(&self) -> Self::Summary {
90 self.key()
91 }
92}
93
94impl<K, V> KeyedItem for MapEntry<K, V>
95where
96 K: Clone + Debug + Default + Ord,
97 V: Clone,
98{
99 type Key = MapKey<K>;
100
101 fn key(&self) -> Self::Key {
102 MapKey(self.key.clone())
103 }
104}
105
106impl<K> Summary for MapKey<K>
107where
108 K: Clone + Debug + Default,
109{
110 type Context = ();
111
112 fn add_summary(&mut self, summary: &Self, _: &()) {
113 *self = summary.clone()
114 }
115}
116
117impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
118where
119 K: Clone + Debug + Default + Ord,
120{
121 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
122 self.0 = Some(&summary.0)
123 }
124}
125
126impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
127where
128 K: Clone + Debug + Default + Ord,
129{
130 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
131 self.0.cmp(&cursor_location.0)
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        // `remove` hands back the value that was stored.
        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    #[test]
    fn test_insert_replaces_existing_value() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(1, "z");
        assert_eq!(map.get(&1), Some(&"z"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z")]);
    }

    #[test]
    fn test_missing_keys() {
        let mut map = TreeMap::default();
        assert_eq!(map.get(&1), None);
        assert_eq!(map.remove(&1), None);

        map.insert(2, "b");
        map.insert(4, "d");
        // Probe keys before, between, and after the stored ones.
        for missing in [1, 3, 5] {
            assert_eq!(map.get(&missing), None);
            assert_eq!(map.remove(&missing), None);
        }
        // Failed removals must leave every entry in place.
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&2, &"b"), (&4, &"d")]);
    }

    #[test]
    fn test_from_ordered_entries() {
        let map = TreeMap::from_ordered_entries(vec![(1, "a"), (2, "b"), (3, "c")]);
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.get(&4), None);
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );
    }
}