Write a simple unit test for TreeMap and fix bug in `remove`

Antonio Scandurra created

Change summary

crates/sum_tree/src/tree_map.rs | 42 ++++++++++++++++++++++++++++++----
1 file changed, 37 insertions(+), 5 deletions(-)

Detailed changes

crates/sum_tree/src/tree_map.rs 🔗

@@ -5,7 +5,7 @@ use crate::{Bias, Dimension, Item, KeyedItem, SeekTarget, SumTree, Summary};
 #[derive(Clone)]
 pub struct TreeMap<K, V>(SumTree<MapEntry<K, V>>)
 where
-    K: Clone + Debug + Default,
+    K: Clone + Debug + Default + Ord,
     V: Clone + Debug;
 
 #[derive(Clone)]
@@ -41,7 +41,7 @@ impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
         let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
         let key = MapKeyRef(Some(key));
         let mut new_tree = cursor.slice(&key, Bias::Left, &());
-        if key.cmp(cursor.start(), &()) == Ordering::Equal {
+        if key.cmp(&cursor.end(&()), &()) == Ordering::Equal {
             removed = Some(cursor.item().unwrap().value.clone());
             cursor.next(&());
         }
@@ -58,7 +58,7 @@ impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 
 impl<K, V> Default for TreeMap<K, V>
 where
-    K: Clone + Debug + Default,
+    K: Clone + Debug + Default + Ord,
     V: Clone + Debug,
 {
     fn default() -> Self {
@@ -68,13 +68,13 @@ where
 
 impl<K, V> Item for MapEntry<K, V>
 where
-    K: Clone + Debug + Default + Clone,
+    K: Clone + Debug + Default + Ord,
     V: Clone,
 {
     type Summary = MapKey<K>;
 
     fn summary(&self) -> Self::Summary {
-        todo!()
+        self.key()
     }
 }
 
@@ -118,3 +118,35 @@ where
         self.0.cmp(&cursor_location.0)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_basic() {
+        let mut map = TreeMap::default();
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
+
+        map.insert(3, "c");
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);
+
+        map.insert(1, "a");
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
+
+        map.insert(2, "b");
+        assert_eq!(
+            map.iter().collect::<Vec<_>>(),
+            vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
+        );
+
+        map.remove(&2);
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
+
+        map.remove(&3);
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);
+
+        map.remove(&1);
+        assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
+    }
+}