1use std::{cmp::Ordering, fmt::Debug};
2
3use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
4
5#[derive(Clone)]
6pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
7where
8 K: Clone + Debug + Default + Ord,
9 V: Clone + Debug;
10
/// A single key/value pair stored in a [`TreeMap`].
#[derive(Clone)]
pub struct MapEntry<K, V> {
    /// The entry's key; lookups and ordering are based on it.
    key: K,
    /// The value associated with `key`.
    value: V,
}
16
/// Owned key wrapper used both as a [`MapEntry`]'s key and as the tree's summary type.
#[derive(Clone, Debug, Default)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct MapKey<K>(K);
19
/// Borrowed key used as the cursor position while seeking through a [`TreeMap`].
///
/// `None` (the `Default`) represents the position at the start of the tree.
#[derive(Clone, Debug, Default)]
pub struct MapKeyRef<'a, K>(Option<&'a K>);
22
23impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
24 pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
25 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
26 cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
27 if let Some(item) = cursor.item() {
28 if *key == item.key().0 {
29 Some(&item.value)
30 } else {
31 None
32 }
33 } else {
34 None
35 }
36 }
37
38 pub fn insert(&mut self, key: K, value: V) {
39 self.0.insert_or_replace(MapEntry { key, value }, &());
40 }
41
42 pub fn remove<'a>(&mut self, key: &'a K) -> Option<V> {
43 let mut removed = None;
44 let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
45 let key = MapKeyRef(Some(key));
46 let mut new_tree = cursor.slice(&key, Bias::Left, &());
47 if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
48 removed = Some(cursor.item().unwrap().value.clone());
49 cursor.next(&());
50 }
51 new_tree.push_tree(cursor.suffix(&()), &());
52 drop(cursor);
53 self.0 = new_tree;
54 removed
55 }
56
57 pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> {
58 self.0.iter().map(|entry| (&entry.key, &entry.value))
59 }
60}
61
62impl<K, V> Default for TreeMap<K, V>
63where
64 K: Clone + Debug + Default + Ord,
65 V: Clone + Debug,
66{
67 fn default() -> Self {
68 Self(Default::default())
69 }
70}
71
72impl<K, V> Item for MapEntry<K, V>
73where
74 K: Clone + Debug + Default + Ord,
75 V: Clone,
76{
77 type Summary = MapKey<K>;
78
79 fn summary(&self) -> Self::Summary {
80 self.key()
81 }
82}
83
84impl<K, V> KeyedItem for MapEntry<K, V>
85where
86 K: Clone + Debug + Default + Ord,
87 V: Clone,
88{
89 type Key = MapKey<K>;
90
91 fn key(&self) -> Self::Key {
92 MapKey(self.key.clone())
93 }
94}
95
96impl<K> Summary for MapKey<K>
97where
98 K: Clone + Debug + Default,
99{
100 type Context = ();
101
102 fn add_summary(&mut self, summary: &Self, _: &()) {
103 *self = summary.clone()
104 }
105}
106
107impl<'a, K> Dimension<'a, MapKey<K>> for MapKeyRef<'a, K>
108where
109 K: Clone + Debug + Default + Ord,
110{
111 fn add_summary(&mut self, summary: &'a MapKey<K>, _: &()) {
112 self.0 = Some(&summary.0)
113 }
114}
115
116impl<'a, K> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>> for MapKeyRef<'_, K>
117where
118 K: Clone + Debug + Default + Ord,
119{
120 fn cmp(&self, cursor_location: &MapKeyRef<K>, _: &()) -> Ordering {
121 self.0.cmp(&cursor_location.0)
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts, looks up, and removes keys, checking both the returned values
    /// and the iteration order after every step.
    #[test]
    fn test_basic() {
        let mut map = TreeMap::default();
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);

        map.insert(3, "c");
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);

        map.insert(1, "a");
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        map.insert(2, "b");
        assert_eq!(map.get(&2), Some(&"b"));
        assert_eq!(map.get(&1), Some(&"a"));
        assert_eq!(map.get(&3), Some(&"c"));
        assert_eq!(
            map.iter().collect::<Vec<_>>(),
            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
        );

        assert_eq!(map.remove(&2), Some("b"));
        assert_eq!(map.get(&2), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);

        assert_eq!(map.remove(&3), Some("c"));
        assert_eq!(map.get(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);

        assert_eq!(map.remove(&1), Some("a"));
        assert_eq!(map.get(&1), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
    }

    /// Removing a key that is not present returns `None` and leaves the map unchanged.
    #[test]
    fn test_remove_missing_key() {
        let mut map = TreeMap::default();
        assert_eq!(map.remove(&1), None::<&str>);

        map.insert(2, "b");
        assert_eq!(map.remove(&1), None);
        assert_eq!(map.remove(&3), None);
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&2, &"b")]);
    }

    /// Inserting an existing key overwrites its value instead of adding a duplicate.
    #[test]
    fn test_insert_replaces_existing_key() {
        let mut map = TreeMap::default();
        map.insert(1, "a");
        map.insert(1, "z");
        assert_eq!(map.get(&1), Some(&"z"));
        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"z")]);
    }
}