matcher.rs

  1use std::{
  2    borrow::Borrow,
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::{CharBag, char_bag::simple_lowercase};
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12// TODO:
 13// Use `Path` instead of `&str` for paths.
 14pub struct Matcher<'a> {
 15    query: &'a [char],
 16    lowercase_query: &'a [char],
 17    query_char_bag: CharBag,
 18    smart_case: bool,
 19    penalize_length: bool,
 20    min_score: f64,
 21    match_positions: Vec<usize>,
 22    last_positions: Vec<usize>,
 23    score_matrix: Vec<Option<f64>>,
 24    best_position_matrix: Vec<usize>,
 25}
 26
 27pub trait MatchCandidate {
 28    fn has_chars(&self, bag: CharBag) -> bool;
 29    fn candidate_chars(&self) -> impl Iterator<Item = char>;
 30}
 31
 32impl<'a> Matcher<'a> {
 33    pub fn new(
 34        query: &'a [char],
 35        lowercase_query: &'a [char],
 36        query_char_bag: CharBag,
 37        smart_case: bool,
 38        penalize_length: bool,
 39    ) -> Self {
 40        Self {
 41            query,
 42            lowercase_query,
 43            query_char_bag,
 44            min_score: 0.0,
 45            last_positions: vec![0; lowercase_query.len()],
 46            match_positions: vec![0; query.len()],
 47            score_matrix: Vec::new(),
 48            best_position_matrix: Vec::new(),
 49            smart_case,
 50            penalize_length,
 51        }
 52    }
 53
 54    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 55    /// the input candidates.
 56    pub(crate) fn match_candidates<C, R, F, T>(
 57        &mut self,
 58        prefix: &[char],
 59        lowercase_prefix: &[char],
 60        candidates: impl Iterator<Item = T>,
 61        results: &mut Vec<R>,
 62        cancel_flag: &AtomicBool,
 63        build_match: F,
 64    ) where
 65        C: MatchCandidate,
 66        T: Borrow<C>,
 67        F: Fn(&C, f64, &Vec<usize>) -> R,
 68    {
 69        let mut candidate_chars = Vec::new();
 70        let mut lowercase_candidate_chars = Vec::new();
 71
 72        for candidate in candidates {
 73            if !candidate.borrow().has_chars(self.query_char_bag) {
 74                continue;
 75            }
 76
 77            if cancel_flag.load(atomic::Ordering::Acquire) {
 78                break;
 79            }
 80
 81            candidate_chars.clear();
 82            lowercase_candidate_chars.clear();
 83            for c in candidate.borrow().candidate_chars() {
 84                candidate_chars.push(c);
 85                lowercase_candidate_chars.push(simple_lowercase(c));
 86            }
 87
 88            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 89                continue;
 90            }
 91
 92            let matrix_len =
 93                self.query.len() * (lowercase_prefix.len() + lowercase_candidate_chars.len());
 94            self.score_matrix.clear();
 95            self.score_matrix.resize(matrix_len, None);
 96            self.best_position_matrix.clear();
 97            self.best_position_matrix.resize(matrix_len, 0);
 98
 99            let score = self.score_match(
100                &candidate_chars,
101                &lowercase_candidate_chars,
102                prefix,
103                lowercase_prefix,
104            );
105
106            if score > 0.0 {
107                results.push(build_match(
108                    candidate.borrow(),
109                    score,
110                    &self.match_positions,
111                ));
112            }
113        }
114    }
115
116    fn find_last_positions(
117        &mut self,
118        lowercase_prefix: &[char],
119        lowercase_candidate: &[char],
120    ) -> bool {
121        let mut lowercase_prefix = lowercase_prefix.iter();
122        let mut lowercase_candidate = lowercase_candidate.iter();
123        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
124            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
125                self.last_positions[i] = j + lowercase_prefix.len();
126            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
127                self.last_positions[i] = j;
128            } else {
129                return false;
130            }
131        }
132        true
133    }
134
135    fn score_match(
136        &mut self,
137        path: &[char],
138        path_lowercased: &[char],
139        prefix: &[char],
140        lowercase_prefix: &[char],
141    ) -> f64 {
142        let score = self.recursive_score_match(
143            path,
144            path_lowercased,
145            prefix,
146            lowercase_prefix,
147            0,
148            0,
149            self.query.len() as f64,
150        ) * self.query.len() as f64;
151
152        if score <= 0.0 {
153            return 0.0;
154        }
155        let path_len = prefix.len() + path.len();
156        let mut cur_start = 0;
157        let mut byte_ix = 0;
158        let mut char_ix = 0;
159        for i in 0..self.query.len() {
160            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
161            while char_ix < match_char_ix {
162                let ch = prefix
163                    .get(char_ix)
164                    .or_else(|| path.get(char_ix - prefix.len()))
165                    .unwrap();
166                byte_ix += ch.len_utf8();
167                char_ix += 1;
168            }
169
170            self.match_positions[i] = byte_ix;
171
172            let matched_ch = prefix
173                .get(match_char_ix)
174                .or_else(|| path.get(match_char_ix - prefix.len()))
175                .unwrap();
176            byte_ix += matched_ch.len_utf8();
177
178            cur_start = match_char_ix + 1;
179            char_ix = match_char_ix + 1;
180        }
181
182        score
183    }
184
185    fn recursive_score_match(
186        &mut self,
187        path: &[char],
188        path_lowercased: &[char],
189        prefix: &[char],
190        lowercase_prefix: &[char],
191        query_idx: usize,
192        path_idx: usize,
193        cur_score: f64,
194    ) -> f64 {
195        if query_idx == self.query.len() {
196            return 1.0;
197        }
198
199        let limit = self.last_positions[query_idx];
200        let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
201        let safe_limit = limit.min(max_valid_index);
202
203        if path_idx > safe_limit {
204            return 0.0;
205        }
206
207        let path_len = prefix.len() + path.len();
208        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
209            return memoized;
210        }
211
212        let mut score = 0.0;
213        let mut best_position = 0;
214
215        let query_char = self.lowercase_query[query_idx];
216
217        let mut last_slash = 0;
218
219        for j in path_idx..=safe_limit {
220            let path_char = if j < prefix.len() {
221                lowercase_prefix[j]
222            } else {
223                let path_index = j - prefix.len();
224                match path_lowercased.get(path_index) {
225                    Some(&char) => char,
226                    None => continue,
227                }
228            };
229            let is_path_sep = path_char == '/';
230
231            if query_idx == 0 && is_path_sep {
232                last_slash = j;
233            }
234            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
235            if need_to_score {
236                let curr = match prefix.get(j) {
237                    Some(&curr) => curr,
238                    None => path[j - prefix.len()],
239                };
240
241                let mut char_score = 1.0;
242                if j > path_idx {
243                    let last = match prefix.get(j - 1) {
244                        Some(&last) => last,
245                        None => path[j - 1 - prefix.len()],
246                    };
247
248                    if last == '/' {
249                        char_score = 0.9;
250                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
251                        || (last.is_lowercase() && curr.is_uppercase())
252                    {
253                        char_score = 0.8;
254                    } else if last == '.' {
255                        char_score = 0.7;
256                    } else if query_idx == 0 {
257                        char_score = BASE_DISTANCE_PENALTY;
258                    } else {
259                        char_score = MIN_DISTANCE_PENALTY.max(
260                            BASE_DISTANCE_PENALTY
261                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
262                        );
263                    }
264                }
265
266                // Apply a severe penalty if the case doesn't match.
267                // This will make the exact matches have higher score than the case-insensitive and the
268                // path insensitive matches.
269                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
270                    char_score *= 0.001;
271                }
272
273                let mut multiplier = char_score;
274
275                // Scale the score based on how deep within the path we found the match.
276                if self.penalize_length && query_idx == 0 {
277                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
278                }
279
280                let mut next_score = 1.0;
281                if self.min_score > 0.0 {
282                    next_score = cur_score * multiplier;
283                    // Scores only decrease. If we can't pass the previous best, bail
284                    if next_score < self.min_score {
285                        // Ensure that score is non-zero so we use it in the memo table.
286                        if score == 0.0 {
287                            score = 1e-18;
288                        }
289                        continue;
290                    }
291                }
292
293                let new_score = self.recursive_score_match(
294                    path,
295                    path_lowercased,
296                    prefix,
297                    lowercase_prefix,
298                    query_idx + 1,
299                    j + 1,
300                    next_score,
301                ) * multiplier;
302
303                if new_score > score {
304                    score = new_score;
305                    best_position = j;
306                    // Optimization: can't score better than 1.
307                    if new_score == 1.0 {
308                        break;
309                    }
310                }
311            }
312        }
313
314        if best_position != 0 {
315            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
316        }
317
318        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
319        score
320    }
321}
322
#[cfg(test)]
mod tests {
    // Unit tests for `Matcher`. Positions asserted below are byte offsets into the matched
    // path, not char indices.
    use util::rel_path::{RelPath, rel_path};

    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::sync::Arc;

    /// `find_last_positions` records, per query char, the latest index in prefix ++ candidate at
    /// which it can still be matched, or returns `false` when the query is not a subsequence.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix (index 2), and no 'd' precedes it there, so the query
        // cannot be matched in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // 'd' is at candidate index 1, i.e. 4 after the 3-char prefix; 'c' falls back to the
        // prefix at index 2.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Repeated query chars get successively earlier positions, crossing from the candidate
        // (offset by the 4-char prefix) into the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// End-to-end matching over ASCII paths; results come back in descending `PathMatch` order.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "ThisIsATestDir",
            "this/is/a/test/dir",
            "test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("this/is/a/test/dir", vec![0, 4, 5, 7, 8, 9, 10, 14, 15]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("test/tiatd", vec![5, 6, 7, 8, 9]),
                ("ThisIsATestDir", vec![0, 4, 6, 7, 11]),
                ("this/is/a/test/dir", vec![0, 5, 8, 10, 15]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "d/🆒/h",
        ];
        // Each keycap sequence is 7 bytes, which is why the "cde" positions below jump by 23.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Matching across non-ASCII scripts; positions must account for multi-byte characters
    /// (e.g. 'İ' is 2 bytes, so the 'o' after it is at offset 2).
    #[test]
    fn match_unicode_path_entries() {
        let mixed_unicode_paths = vec![
            "İolu/oluş",
            "İstanbul/code",
            "Athens/Şanlıurfa",
            "Çanakkale/scripts",
            "paris/Düzce_İl",
            "Berlin_Önemli_Ğündem",
            "KİTAPLIK/london/dosya",
            "tokyo/kyoto/fuji",
            "new_york/san_francisco",
        ];

        assert_eq!(
            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
            vec![("İolu/oluş", vec![0, 2, 5, 6, 7, 8, 9])]
        );

        assert_eq!(
            match_single_path_query("İst/code", false, &mixed_unicode_paths),
            vec![("İstanbul/code", vec![0, 2, 3, 9, 10, 11, 12, 13])]
        );

        assert_eq!(
            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
        );

        assert_eq!(
            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
        );

        assert_eq!(
            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
        );

        let mixed_script_paths = vec![
            "résumé_Москва",
            "naïve_київ_implementation",
            "café_北京_app",
            "東京_über_driver",
            "déjà_vu_cairo",
            "seoul_piñata_game",
            "voilà_istanbul_result",
        ];

        assert_eq!(
            match_single_path_query("résmé", false, &mixed_script_paths),
            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
        );

        assert_eq!(
            match_single_path_query("café北京", false, &mixed_script_paths),
            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
        );

        assert_eq!(
            match_single_path_query("ista", false, &mixed_script_paths),
            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
        );

        // Emoji, including multi-codepoint ZWJ sequences and flag pairs.
        let complex_paths = vec![
            "document_📚_library",
            "project_👨‍👩‍👧‍👦_family",
            "flags_🇯🇵🇺🇸🇪🇺_world",
            "code_😀😃😄😁_happy",
            "photo_👩‍👩‍👧‍👦_album",
        ];

        assert_eq!(
            match_single_path_query("doc📚lib", false, &complex_paths),
            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
        );

        assert_eq!(
            match_single_path_query("codehappy", false, &complex_paths),
            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
        );
    }

    #[test]
    fn test_positions_are_valid_char_boundaries_with_expanding_lowercase() {
        // İ (U+0130) lowercases to "i\u{307}" (2 chars) under full case folding.
        // With simple case mapping (used by this matcher), İ → 'i' (1 char),
        // so positions remain valid byte boundaries.
        let paths = vec!["İstanbul/code.rs", "aİbİc/dİeİf.txt", "src/İmport/İndex.ts"];

        for query in &["code", "İst", "dİe", "İndex", "İmport", "abcdef"] {
            let results = match_single_path_query(query, false, &paths);
            for (path, positions) in &results {
                for &pos in positions {
                    assert!(
                        path.is_char_boundary(pos),
                        "Position {pos} is not a valid char boundary in path {path:?} \
                         (query: {query:?}, all positions: {positions:?})"
                    );
                }
            }
        }
    }

    #[test]
    fn test_positions_valid_with_various_multibyte_chars() {
        // German ß uppercases to SS but lowercases to itself — no expansion.
        // Armenian ligatures and other characters that could expand under full
        // case folding should still produce valid byte boundaries.
        // NOTE(review): in this copy the "file/path.rs" and "ffoo/bar.txt" literals read as
        // plain ASCII rather than the U+FB01 / U+FB00 ligatures the comments describe. Confirm
        // the intended characters are present; otherwise the ligature cases are not exercised.
        let paths = vec![
            "straße/config.rs",
            "Straße/München/file.txt",
            "file/path.rs",     // fi (U+FB01, fi ligature)
            "ffoo/bar.txt",     // ff (U+FB00, ff ligature)
            "aÇbŞc/dÖeÜf.txt", // Turkish chars that don't expand
        ];

        for query in &["config", "Mün", "file", "bar", "abcdef", "straße", "ÇŞ"] {
            let results = match_single_path_query(query, false, &paths);
            for (path, positions) in &results {
                for &pos in positions {
                    assert!(
                        path.is_char_boundary(pos),
                        "Position {pos} is not a valid char boundary in path {path:?} \
                         (query: {query:?}, all positions: {positions:?})"
                    );
                }
            }
        }
    }

    /// Runs `query` against `paths` with an empty prefix and returns each matching path with
    /// its byte-offset match positions, sorted in descending `PathMatch` order.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.chars().map(simple_lowercase).collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<RelPath>> = paths
            .iter()
            .map(|path| Arc::from(rel_path(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path: Vec<char> = path.chars().map(simple_lowercase).collect();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: candidate.path.into(),
                path_prefix: RelPath::empty().into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` so callers can compare against literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == rel_path(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }

    /// Test for https://github.com/zed-industries/zed/issues/44324
    #[test]
    fn test_recursive_score_match_index_out_of_bounds() {
        let paths = vec!["İ/İ/İ/İ"];
        let query = "İ/İ";

        // This panicked with "index out of bounds: the len is 21 but the index is 22"
        // Only asserts that matching completes without panicking.
        let result = match_single_path_query(query, false, &paths);
        let _ = result;
    }
}