Use more efficient sum tree traversals for removal and improve ergonomics with iter_from

Mikayla Maki and Nathan created

co-authored-by: Nathan <nathan@zed.dev>

Change summary

crates/project/src/worktree.rs  |   5 
crates/sum_tree/src/tree_map.rs | 120 +++++++++++++++++++++++-----------
2 files changed, 84 insertions(+), 41 deletions(-)

Detailed changes

crates/project/src/worktree.rs 🔗

@@ -185,7 +185,8 @@ impl RepositoryEntry {
             .relativize(snapshot, path)
             .and_then(|repo_path| {
                 self.worktree_statuses
-                    .get_from_while(&repo_path, |repo_path, key, _| key.starts_with(repo_path))
+                    .iter_from(&repo_path)
+                    .take_while(|(key, _)| key.starts_with(&repo_path))
                     .map(|(_, status)| status)
                     // Short circut once we've found the highest level
                     .take_until(|status| status == &&GitFileStatus::Conflict)
@@ -3022,7 +3023,7 @@ impl BackgroundScanner {
             snapshot.repository_entries.update(&work_dir, |entry| {
                 entry
                     .worktree_statuses
-                    .remove_from_while(&repo_path, |stored_path, _| {
+                    .remove_by(&repo_path, |stored_path, _| {
                         stored_path.starts_with(&repo_path)
                     })
             });

crates/sum_tree/src/tree_map.rs 🔗

@@ -1,4 +1,4 @@
-use std::{cmp::Ordering, fmt::Debug, iter};
+use std::{cmp::Ordering, fmt::Debug};
 
 use crate::{Bias, Dimension, Edit, Item, KeyedItem, SeekTarget, SumTree, Summary};
 
@@ -93,43 +93,14 @@ impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
         self.0 = new_tree;
     }
 
-    pub fn remove_from_while<F>(&mut self, from: &K, mut f: F)
-    where
-        F: FnMut(&K, &V) -> bool,
-    {
-        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
-        let from_key = MapKeyRef(Some(from));
-        let mut new_tree = cursor.slice(&from_key, Bias::Left, &());
-        while let Some(item) = cursor.item() {
-            if !f(&item.key, &item.value) {
-                break;
-            }
-            cursor.next(&());
-        }
-        new_tree.push_tree(cursor.suffix(&()), &());
-        drop(cursor);
-        self.0 = new_tree;
-    }
-
-    pub fn get_from_while<'tree, F>(
-        &'tree self,
-        from: &'tree K,
-        mut f: F,
-    ) -> impl Iterator<Item = (&K, &V)> + '_
-    where
-        F: FnMut(&K, &K, &V) -> bool + 'tree,
-    {
+    pub fn iter_from<'a>(&'a self, from: &'a K) -> impl Iterator<Item = (&K, &V)> + '_ {
         let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
         let from_key = MapKeyRef(Some(from));
         cursor.seek(&from_key, Bias::Left, &());
 
-        iter::from_fn(move || {
-            let result = cursor.item().and_then(|item| {
-                (f(from, &item.key, &item.value)).then(|| (&item.key, &item.value))
-            });
-            cursor.next(&());
-            result
-        })
+        cursor
+            .into_iter()
+            .map(|map_entry| (&map_entry.key, &map_entry.value))
     }
 
     pub fn update<F, T>(&mut self, key: &K, f: F) -> Option<T>
@@ -189,6 +160,51 @@ impl<K: Clone + Debug + Default + Ord, V: Clone + Debug> TreeMap<K, V> {
 
         self.0.edit(edits, &());
     }
+
+    pub fn remove_by<F>(&mut self, key: &K, f: F)
+    where
+        F: Fn(&K) -> bool,
+    {
+        let mut cursor = self.0.cursor::<MapKeyRef<'_, K>>();
+        let key = MapKeyRef(Some(key));
+        let mut new_tree = cursor.slice(&key, Bias::Left, &());
+        let until = RemoveByTarget(key, &f);
+        cursor.seek_forward(&until, Bias::Right, &());
+        new_tree.push_tree(cursor.suffix(&()), &());
+        drop(cursor);
+        self.0 = new_tree;
+    }
+}
+
+struct RemoveByTarget<'a, K>(MapKeyRef<'a, K>, &'a dyn Fn(&K) -> bool);
+
+impl<'a, K: Debug> Debug for RemoveByTarget<'a, K> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RemoveByTarget")
+            .field("key", &self.0)
+            .field("F", &"<...>")
+            .finish()
+    }
+}
+
+impl<'a, K: Debug + Clone + Default + Ord> SeekTarget<'a, MapKey<K>, MapKeyRef<'a, K>>
+    for RemoveByTarget<'_, K>
+{
+    fn cmp(
+        &self,
+        cursor_location: &MapKeyRef<'a, K>,
+        _cx: &<MapKey<K> as Summary>::Context,
+    ) -> Ordering {
+        if let Some(cursor_location) = cursor_location.0 {
+            if (self.1)(cursor_location) {
+                Ordering::Equal
+            } else {
+                self.0 .0.unwrap().cmp(cursor_location)
+            }
+        } else {
+            Ordering::Greater
+        }
+    }
 }
 
 impl<K, V> Default for TreeMap<K, V>
@@ -357,26 +373,50 @@ mod tests {
     }
 
     #[test]
-    fn test_remove_from_while() {
+    fn test_remove_by() {
         let mut map = TreeMap::default();
 
         map.insert("a", 1);
+        map.insert("aa", 1);
         map.insert("b", 2);
         map.insert("baa", 3);
         map.insert("baaab", 4);
         map.insert("c", 5);
+        map.insert("ca", 6);
 
-        map.remove_from_while(&"ba", |key, _| key.starts_with(&"ba"));
+        map.remove_by(&"ba", |key| key.starts_with("ba"));
 
         assert_eq!(map.get(&"a"), Some(&1));
+        assert_eq!(map.get(&"aa"), Some(&1));
         assert_eq!(map.get(&"b"), Some(&2));
         assert_eq!(map.get(&"baaa"), None);
         assert_eq!(map.get(&"baaaab"), None);
         assert_eq!(map.get(&"c"), Some(&5));
+        assert_eq!(map.get(&"ca"), Some(&6));
+
+
+        map.remove_by(&"c", |key| key.starts_with("c"));
+
+        assert_eq!(map.get(&"a"), Some(&1));
+        assert_eq!(map.get(&"aa"), Some(&1));
+        assert_eq!(map.get(&"b"), Some(&2));
+        assert_eq!(map.get(&"c"), None);
+        assert_eq!(map.get(&"ca"), None);
+
+        map.remove_by(&"a", |key| key.starts_with("a"));
+
+        assert_eq!(map.get(&"a"), None);
+        assert_eq!(map.get(&"aa"), None);
+        assert_eq!(map.get(&"b"), Some(&2));
+
+        map.remove_by(&"b", |key| key.starts_with("b"));
+
+        assert_eq!(map.get(&"b"), None);
+
     }
 
     #[test]
-    fn test_get_from_while() {
+    fn test_iter_from() {
         let mut map = TreeMap::default();
 
         map.insert("a", 1);
@@ -386,7 +426,8 @@ mod tests {
         map.insert("c", 5);
 
         let result = map
-            .get_from_while(&"ba", |_, key, _| key.starts_with(&"ba"))
+            .iter_from(&"ba")
+            .take_while(|(key, _)| key.starts_with(&"ba"))
             .collect::<Vec<_>>();
 
         assert_eq!(result.len(), 2);
@@ -394,7 +435,8 @@ mod tests {
         assert!(result.iter().find(|(k, _)| k == &&"baaab").is_some());
 
         let result = map
-            .get_from_while(&"c", |_, key, _| key.starts_with(&"c"))
+            .iter_from(&"c")
+            .take_while(|(key, _)| key.starts_with(&"c"))
             .collect::<Vec<_>>();
 
         assert_eq!(result.len(), 1);