Fix fuzzy matching after removing root dirname from stored paths

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/src/file_finder.rs       |  25 +++++-
zed/src/worktree.rs          |  38 +++++++---
zed/src/worktree/char_bag.rs |  10 ++
zed/src/worktree/fuzzy.rs    | 134 ++++++++++++++++++++++++-------------
4 files changed, 139 insertions(+), 68 deletions(-)

Detailed changes

zed/src/file_finder.rs 🔗

@@ -24,6 +24,7 @@ pub struct FileFinder {
     search_count: usize,
     latest_search_id: usize,
     matches: Vec<PathMatch>,
+    include_root_name: bool,
     selected: usize,
     list_state: UniformListState,
 }
@@ -138,7 +139,12 @@ impl FileFinder {
     ) -> Option<ElementBox> {
         let tree_id = path_match.tree_id;
 
-        self.worktree(tree_id, app).map(|_| {
+        self.worktree(tree_id, app).map(|tree| {
+            let prefix = if self.include_root_name {
+                tree.root_name_chars()
+            } else {
+                &[]
+            };
             let path = path_match.path.clone();
             let path_string = &path_match.path_string;
             let file_name = Path::new(&path_string)
@@ -148,7 +154,8 @@ impl FileFinder {
                 .to_string();
 
             let path_positions = path_match.positions.clone();
-            let file_name_start = path_string.chars().count() - file_name.chars().count();
+            let file_name_start =
+                prefix.len() + path_string.chars().count() - file_name.chars().count();
             let mut file_name_positions = Vec::new();
             file_name_positions.extend(path_positions.iter().filter_map(|pos| {
                 if pos >= &file_name_start {
@@ -162,6 +169,9 @@ impl FileFinder {
             let highlight_color = ColorU::from_u32(0x304ee2ff);
             let bold = *Properties::new().weight(Weight::BOLD);
 
+            let mut full_path = prefix.iter().collect::<String>();
+            full_path.push_str(&path_string);
+
             let mut container = Container::new(
                 Flex::row()
                     .with_child(
@@ -191,7 +201,7 @@ impl FileFinder {
                                 )
                                 .with_child(
                                     Label::new(
-                                        path_string.into(),
+                                        full_path,
                                         settings.ui_font_family,
                                         settings.ui_font_size,
                                     )
@@ -275,6 +285,7 @@ impl FileFinder {
             search_count: 0,
             latest_search_id: 0,
             matches: Vec::new(),
+            include_root_name: false,
             selected: 0,
             list_state: UniformListState::new(),
         }
@@ -348,16 +359,17 @@ impl FileFinder {
         let search_id = util::post_inc(&mut self.search_count);
         let pool = ctx.as_ref().thread_pool().clone();
         let task = ctx.background_executor().spawn(async move {
+            let include_root_name = snapshots.len() > 1;
             let matches = match_paths(
                 snapshots.iter(),
                 &query,
-                snapshots.len() > 1,
+                include_root_name,
                 false,
                 false,
                 100,
                 pool,
             );
-            (search_id, matches)
+            (search_id, include_root_name, matches)
         });
 
         ctx.spawn(task, Self::update_matches).detach();
@@ -365,12 +377,13 @@ impl FileFinder {
 
     fn update_matches(
         &mut self,
-        (search_id, matches): (usize, Vec<PathMatch>),
+        (search_id, include_root_name, matches): (usize, bool, Vec<PathMatch>),
         ctx: &mut ViewContext<Self>,
     ) {
         if search_id >= self.latest_search_id {
             self.latest_search_id = search_id;
             self.matches = matches;
+            self.include_root_name = include_root_name;
             self.selected = 0;
             self.list_state.scroll_to(0);
             ctx.notify();

zed/src/worktree.rs 🔗

@@ -32,6 +32,8 @@ use std::{
     time::Duration,
 };
 
+use self::char_bag::CharBag;
+
 lazy_static! {
     static ref GITIGNORE: &'static OsStr = OsStr::new(".gitignore");
 }
@@ -59,12 +61,17 @@ pub struct FileHandle {
 
 impl Worktree {
     pub fn new(path: impl Into<Arc<Path>>, ctx: &mut ModelContext<Self>) -> Self {
+        let abs_path = path.into();
+        let root_name_chars = abs_path.file_name().map_or(Vec::new(), |n| {
+            n.to_string_lossy().chars().chain(Some('/')).collect()
+        });
         let (scan_state_tx, scan_state_rx) = smol::channel::unbounded();
         let id = ctx.model_id();
         let snapshot = Snapshot {
             id,
             scan_id: 0,
-            abs_path: path.into(),
+            abs_path,
+            root_name_chars,
             ignores: Default::default(),
             entries: Default::default(),
         };
@@ -209,6 +216,7 @@ pub struct Snapshot {
     id: usize,
     scan_id: usize,
     abs_path: Arc<Path>,
+    root_name_chars: Vec<char>,
     ignores: BTreeMap<Arc<Path>, (Arc<Gitignore>, usize)>,
     entries: SumTree<Entry>,
 }
@@ -238,13 +246,17 @@ impl Snapshot {
     }
 
     pub fn root_entry(&self) -> Entry {
-        self.entry_for_path(&self.abs_path).unwrap()
+        self.entry_for_path("").unwrap()
     }
 
     pub fn root_name(&self) -> Option<&OsStr> {
         self.abs_path.file_name()
     }
 
+    pub fn root_name_chars(&self) -> &[char] {
+        &self.root_name_chars
+    }
+
     fn entry_for_path(&self, path: impl AsRef<Path>) -> Option<Entry> {
         let mut cursor = self.entries.cursor::<_, ()>();
         if cursor.seek(&PathSearch::Exact(path.as_ref()), SeekBias::Left) {
@@ -259,8 +271,6 @@ impl Snapshot {
     }
 
     fn is_path_ignored(&self, path: &Path) -> Result<bool> {
-        dbg!(path);
-
         let mut entry = self
             .entry_for_path(path)
             .ok_or_else(|| anyhow!("entry does not exist in worktree"))?;
@@ -272,7 +282,6 @@ impl Snapshot {
                 entry.path().parent().and_then(|p| self.entry_for_path(p))
             {
                 let parent_path = parent_entry.path();
-                dbg!(parent_path);
                 if let Some((ignore, _)) = self.ignores.get(parent_path) {
                     let relative_path = path.strip_prefix(parent_path).unwrap();
                     match ignore.matched_path_or_any_parents(relative_path, entry.is_dir()) {
@@ -567,11 +576,14 @@ struct BackgroundScanner {
     notify: Sender<ScanState>,
     other_mount_paths: HashSet<PathBuf>,
     thread_pool: scoped_pool::Pool,
+    root_char_bag: CharBag,
 }
 
 impl BackgroundScanner {
     fn new(snapshot: Arc<Mutex<Snapshot>>, notify: Sender<ScanState>, worktree_id: usize) -> Self {
+        let root_char_bag = CharBag::from(snapshot.lock().root_name_chars.as_slice());
         let mut scanner = Self {
+            root_char_bag,
             snapshot,
             notify,
             other_mount_paths: Default::default(),
@@ -673,7 +685,7 @@ impl BackgroundScanner {
             });
         } else {
             self.snapshot.lock().insert_entry(Entry::File {
-                path_entry: PathEntry::new(inode, path.clone()),
+                path_entry: PathEntry::new(inode, self.root_char_bag, path.clone()),
                 path,
                 inode,
                 is_symlink,
@@ -719,7 +731,7 @@ impl BackgroundScanner {
                 });
             } else {
                 new_entries.push(Entry::File {
-                    path_entry: PathEntry::new(child_inode, child_path.clone()),
+                    path_entry: PathEntry::new(child_inode, self.root_char_bag, child_path.clone()),
                     path: child_path,
                     inode: child_inode,
                     is_symlink: child_is_symlink,
@@ -958,7 +970,7 @@ impl BackgroundScanner {
             }
         } else {
             Entry::File {
-                path_entry: PathEntry::new(inode, path.clone()),
+                path_entry: PathEntry::new(inode, self.root_char_bag, path.clone()),
                 path,
                 inode,
                 is_symlink,
@@ -1113,14 +1125,14 @@ mod tests {
                     10,
                     ctx.thread_pool().clone(),
                 )
-                .iter()
-                .map(|result| result.path.clone())
+                .into_iter()
+                .map(|result| result.path)
                 .collect::<Vec<Arc<Path>>>();
                 assert_eq!(
                     results,
                     vec![
-                        PathBuf::from("root_link/banana/carrot/date").into(),
-                        PathBuf::from("root_link/banana/carrot/endive").into(),
+                        PathBuf::from("banana/carrot/date").into(),
+                        PathBuf::from("banana/carrot/endive").into(),
                     ]
                 );
             })
@@ -1288,6 +1300,7 @@ mod tests {
                     abs_path: root_dir.path().into(),
                     entries: Default::default(),
                     ignores: Default::default(),
+                    root_name_chars: Default::default(),
                 })),
                 notify_tx,
                 0,
@@ -1321,6 +1334,7 @@ mod tests {
                     abs_path: root_dir.path().into(),
                     entries: Default::default(),
                     ignores: Default::default(),
+                    root_name_chars: Default::default(),
                 })),
                 notify_tx,
                 1,

zed/src/worktree/char_bag.rs 🔗

@@ -1,4 +1,4 @@
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, Default)]
 pub struct CharBag(u64);
 
 impl CharBag {
@@ -23,6 +23,14 @@ impl CharBag {
     }
 }
 
+impl Extend<char> for CharBag {
+    fn extend<T: IntoIterator<Item = char>>(&mut self, iter: T) {
+        for c in iter {
+            self.insert(c);
+        }
+    }
+}
+
 impl From<&str> for CharBag {
     fn from(s: &str) -> Self {
         let mut bag = Self(0);

zed/src/worktree/fuzzy.rs 🔗

@@ -21,11 +21,12 @@ pub struct PathEntry {
 }
 
 impl PathEntry {
-    pub fn new(ino: u64, path: Arc<Path>) -> Self {
+    pub fn new(ino: u64, root_char_bag: CharBag, path: Arc<Path>) -> Self {
         let path_str = path.to_string_lossy();
         let lowercase_path = path_str.to_lowercase().chars().collect::<Vec<_>>().into();
         let path_chars: Arc<[char]> = path_str.chars().collect::<Vec<_>>().into();
-        let char_bag = CharBag::from(path_chars.as_ref());
+        let mut char_bag = root_char_bag;
+        char_bag.extend(path_chars.iter().copied());
 
         Self {
             ino,
@@ -136,21 +137,9 @@ where
                             }
                         });
 
-                        let skipped_prefix_len = if include_root_name {
-                            0
-                        } else if let Entry::Dir { .. } = snapshot.root_entry() {
-                            if let Some(name) = snapshot.root_name() {
-                                name.to_string_lossy().chars().count() + 1
-                            } else {
-                                1
-                            }
-                        } else {
-                            0
-                        };
-
                         match_single_tree_paths(
                             snapshot,
-                            skipped_prefix_len,
+                            include_root_name,
                             path_entries,
                             query,
                             lowercase_query,
@@ -186,7 +175,7 @@ where
 
 fn match_single_tree_paths<'a>(
     snapshot: &Snapshot,
-    skipped_prefix_len: usize,
+    include_root_name: bool,
     path_entries: impl Iterator<Item = &'a PathEntry>,
     query: &[char],
     lowercase_query: &[char],
@@ -200,6 +189,12 @@ fn match_single_tree_paths<'a>(
     score_matrix: &mut Vec<Option<f64>>,
     best_position_matrix: &mut Vec<usize>,
 ) {
+    let prefix = if include_root_name {
+        snapshot.root_name_chars.as_slice()
+    } else {
+        &[]
+    };
+
     for path_entry in path_entries {
         if !path_entry.char_bag.is_superset(query_chars) {
             continue;
@@ -207,25 +202,25 @@ fn match_single_tree_paths<'a>(
 
         if !find_last_positions(
             last_positions,
-            skipped_prefix_len,
+            prefix,
             &path_entry.lowercase_path,
             &lowercase_query[..],
         ) {
             continue;
         }
 
-        let matrix_len = query.len() * (path_entry.path_chars.len() - skipped_prefix_len);
+        let matrix_len = query.len() * (path_entry.path_chars.len() + prefix.len());
         score_matrix.clear();
         score_matrix.resize(matrix_len, None);
         best_position_matrix.clear();
-        best_position_matrix.resize(matrix_len, skipped_prefix_len);
+        best_position_matrix.resize(matrix_len, 0);
 
         let score = score_match(
             &query[..],
             &lowercase_query[..],
             &path_entry.path_chars,
             &path_entry.lowercase_path,
-            skipped_prefix_len,
+            prefix,
             smart_case,
             &last_positions,
             score_matrix,
@@ -237,11 +232,7 @@ fn match_single_tree_paths<'a>(
         if score > 0.0 {
             results.push(Reverse(PathMatch {
                 tree_id: snapshot.id,
-                path_string: path_entry
-                    .path_chars
-                    .iter()
-                    .skip(skipped_prefix_len)
-                    .collect(),
+                path_string: path_entry.path_chars.iter().collect(),
                 path: path_entry.path.clone(),
                 score,
                 positions: match_positions.clone(),
@@ -255,18 +246,17 @@ fn match_single_tree_paths<'a>(
 
 fn find_last_positions(
     last_positions: &mut Vec<usize>,
-    skipped_prefix_len: usize,
+    prefix: &[char],
     path: &[char],
     query: &[char],
 ) -> bool {
     let mut path = path.iter();
+    let mut prefix_iter = prefix.iter();
     for (i, char) in query.iter().enumerate().rev() {
         if let Some(j) = path.rposition(|c| c == char) {
-            if j >= skipped_prefix_len {
-                last_positions[i] = j;
-            } else {
-                return false;
-            }
+            last_positions[i] = j + prefix.len();
+        } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
+            last_positions[i] = j;
         } else {
             return false;
         }
@@ -279,7 +269,7 @@ fn score_match(
     query_cased: &[char],
     path: &[char],
     path_cased: &[char],
-    skipped_prefix_len: usize,
+    prefix: &[char],
     smart_case: bool,
     last_positions: &[usize],
     score_matrix: &mut [Option<f64>],
@@ -292,14 +282,14 @@ fn score_match(
         query_cased,
         path,
         path_cased,
-        skipped_prefix_len,
+        prefix,
         smart_case,
         last_positions,
         score_matrix,
         best_position_matrix,
         min_score,
         0,
-        skipped_prefix_len,
+        0,
         query.len() as f64,
     ) * query.len() as f64;
 
@@ -307,10 +297,10 @@ fn score_match(
         return 0.0;
     }
 
-    let path_len = path.len() - skipped_prefix_len;
+    let path_len = path.len() + prefix.len();
     let mut cur_start = 0;
     for i in 0..query.len() {
-        match_positions[i] = best_position_matrix[i * path_len + cur_start] - skipped_prefix_len;
+        match_positions[i] = best_position_matrix[i * path_len + cur_start];
         cur_start = match_positions[i] + 1;
     }
 
@@ -322,7 +312,7 @@ fn recursive_score_match(
     query_cased: &[char],
     path: &[char],
     path_cased: &[char],
-    skipped_prefix_len: usize,
+    prefix: &[char],
     smart_case: bool,
     last_positions: &[usize],
     score_matrix: &mut [Option<f64>],
@@ -336,9 +326,9 @@ fn recursive_score_match(
         return 1.0;
     }
 
-    let path_len = path.len() - skipped_prefix_len;
+    let path_len = prefix.len() + path.len();
 
-    if let Some(memoized) = score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] {
+    if let Some(memoized) = score_matrix[query_idx * path_len + path_idx] {
         return memoized;
     }
 
@@ -350,7 +340,11 @@ fn recursive_score_match(
 
     let mut last_slash = 0;
     for j in path_idx..=limit {
-        let path_char = path_cased[j];
+        let path_char = if j < prefix.len() {
+            prefix[j]
+        } else {
+            path_cased[j - prefix.len()]
+        };
         let is_path_sep = path_char == '/' || path_char == '\\';
 
         if query_idx == 0 && is_path_sep {
@@ -358,10 +352,19 @@ fn recursive_score_match(
         }
 
         if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
+            let curr = if j < prefix.len() {
+                prefix[j]
+            } else {
+                path[j - prefix.len()]
+            };
+
             let mut char_score = 1.0;
             if j > path_idx {
-                let last = path[j - 1];
-                let curr = path[j];
+                let last = if j - 1 < prefix.len() {
+                    prefix[j - 1]
+                } else {
+                    path[j - 1 - prefix.len()]
+                };
 
                 if last == '/' {
                     char_score = 0.9;
@@ -384,15 +387,15 @@ fn recursive_score_match(
             // Apply a severe penalty if the case doesn't match.
             // This will make the exact matches have higher score than the case-insensitive and the
             // path insensitive matches.
-            if (smart_case || path[j] == '/') && query[query_idx] != path[j] {
+            if (smart_case || curr == '/') && query[query_idx] != curr {
                 char_score *= 0.001;
             }
 
             let mut multiplier = char_score;
 
-            // Scale the score based on how deep within the patch we found the match.
+            // Scale the score based on how deep within the path we found the match.
             if query_idx == 0 {
-                multiplier /= (path.len() - last_slash) as f64;
+                multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
             }
 
             let mut next_score = 1.0;
@@ -413,7 +416,7 @@ fn recursive_score_match(
                 query_cased,
                 path,
                 path_cased,
-                skipped_prefix_len,
+                prefix,
                 smart_case,
                 last_positions,
                 score_matrix,
@@ -436,10 +439,10 @@ fn recursive_score_match(
     }
 
     if best_position != 0 {
-        best_position_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = best_position;
+        best_position_matrix[query_idx * path_len + path_idx] = best_position;
     }
 
-    score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = Some(score);
+    score_matrix[query_idx * path_len + path_idx] = Some(score);
     score
 }
 
@@ -448,6 +451,38 @@ mod tests {
     use super::*;
     use std::path::PathBuf;
 
+    #[test]
+    fn test_get_last_positions() {
+        let mut last_positions = vec![0; 2];
+        let result = find_last_positions(
+            &mut last_positions,
+            &['a', 'b', 'c'],
+            &['b', 'd', 'e', 'f'],
+            &['d', 'c'],
+        );
+        assert_eq!(result, false);
+
+        last_positions.resize(2, 0);
+        let result = find_last_positions(
+            &mut last_positions,
+            &['a', 'b', 'c'],
+            &['b', 'd', 'e', 'f'],
+            &['c', 'd'],
+        );
+        assert_eq!(result, true);
+        assert_eq!(last_positions, vec![2, 4]);
+
+        last_positions.resize(4, 0);
+        let result = find_last_positions(
+            &mut last_positions,
+            &['z', 'e', 'd', '/'],
+            &['z', 'e', 'd', '/', 'f'],
+            &['z', '/', 'z', 'f'],
+        );
+        assert_eq!(result, true);
+        assert_eq!(last_positions, vec![0, 3, 4, 8]);
+    }
+
     #[test]
     fn test_match_path_entries() {
         let paths = vec![
@@ -526,8 +561,9 @@ mod tests {
                 abs_path: PathBuf::new().into(),
                 ignores: Default::default(),
                 entries: Default::default(),
+                root_name_chars: Vec::new(),
             },
-            0,
+            false,
             path_entries.iter(),
             &query[..],
             &lowercase_query[..],