Fix `TreeMap::get` always returning `None`

Antonio Scandurra created

Change summary

crates/sum_tree/src/tree_map.rs | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)

Detailed changes

crates/sum_tree/src/tree_map.rs 🔗

@@ -23,10 +23,13 @@ pub struct MapKeyRef<'a, K>(Option<&'a K>);
 impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
     pub fn get<'a>(&self, key: &'a K) -> Option<&V> {
         let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
-        let key = MapKeyRef(Some(key));
-        cursor.seek(&key, Bias::Left, &());
-        if key.cmp(cursor.start(), &()) == Ordering::Equal {
-            Some(&cursor.item().unwrap().value)
+        cursor.seek(&MapKeyRef(Some(key)), Bias::Left, &());
+        if let Some(item) = cursor.item() {
+            if *key == item.key().0 {
+                Some(&item.value)
+            } else {
+                None
+            }
         } else {
             None
         }
@@ -129,24 +132,32 @@ mod tests {
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
 
         map.insert(3, "c");
+        assert_eq!(map.get(&3), Some(&"c"));
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&3, &"c")]);
 
         map.insert(1, "a");
+        assert_eq!(map.get(&1), Some(&"a"));
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
 
         map.insert(2, "b");
+        assert_eq!(map.get(&2), Some(&"b"));
+        assert_eq!(map.get(&1), Some(&"a"));
+        assert_eq!(map.get(&3), Some(&"c"));
         assert_eq!(
             map.iter().collect::<Vec<_>>(),
             vec![(&1, &"a"), (&2, &"b"), (&3, &"c")]
         );
 
         map.remove(&2);
+        assert_eq!(map.get(&2), None);
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a"), (&3, &"c")]);
 
         map.remove(&3);
+        assert_eq!(map.get(&3), None);
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![(&1, &"a")]);
 
         map.remove(&1);
+        assert_eq!(map.get(&1), None);
         assert_eq!(map.iter().collect::<Vec<_>>(), vec![]);
     }
 }