matcher.rs

  1use std::{
  2    borrow::Borrow,
  3    collections::BTreeMap,
  4    sync::atomic::{self, AtomicBool},
  5};
  6
  7use crate::CharBag;
  8
  9const BASE_DISTANCE_PENALTY: f64 = 0.6;
 10const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 11const MIN_DISTANCE_PENALTY: f64 = 0.2;
 12
 13// TODO:
 14// Use `Path` instead of `&str` for paths.
 15pub struct Matcher<'a> {
 16    query: &'a [char],
 17    lowercase_query: &'a [char],
 18    query_char_bag: CharBag,
 19    smart_case: bool,
 20    penalize_length: bool,
 21    min_score: f64,
 22    match_positions: Vec<usize>,
 23    last_positions: Vec<usize>,
 24    score_matrix: Vec<Option<f64>>,
 25    best_position_matrix: Vec<usize>,
 26}
 27
 28pub trait MatchCandidate {
 29    fn has_chars(&self, bag: CharBag) -> bool;
 30    fn candidate_chars(&self) -> impl Iterator<Item = char>;
 31}
 32
 33impl<'a> Matcher<'a> {
 34    pub fn new(
 35        query: &'a [char],
 36        lowercase_query: &'a [char],
 37        query_char_bag: CharBag,
 38        smart_case: bool,
 39        penalize_length: bool,
 40    ) -> Self {
 41        Self {
 42            query,
 43            lowercase_query,
 44            query_char_bag,
 45            min_score: 0.0,
 46            last_positions: vec![0; lowercase_query.len()],
 47            match_positions: vec![0; query.len()],
 48            score_matrix: Vec::new(),
 49            best_position_matrix: Vec::new(),
 50            smart_case,
 51            penalize_length,
 52        }
 53    }
 54
 55    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 56    /// the input candidates.
 57    pub(crate) fn match_candidates<C, R, F, T>(
 58        &mut self,
 59        prefix: &[char],
 60        lowercase_prefix: &[char],
 61        candidates: impl Iterator<Item = T>,
 62        results: &mut Vec<R>,
 63        cancel_flag: &AtomicBool,
 64        build_match: F,
 65    ) where
 66        C: MatchCandidate,
 67        T: Borrow<C>,
 68        F: Fn(&C, f64, &Vec<usize>) -> R,
 69    {
 70        let mut candidate_chars = Vec::new();
 71        let mut lowercase_candidate_chars = Vec::new();
 72        let mut extra_lowercase_chars = BTreeMap::new();
 73
 74        for candidate in candidates {
 75            if !candidate.borrow().has_chars(self.query_char_bag) {
 76                continue;
 77            }
 78
 79            if cancel_flag.load(atomic::Ordering::Acquire) {
 80                break;
 81            }
 82
 83            candidate_chars.clear();
 84            lowercase_candidate_chars.clear();
 85            extra_lowercase_chars.clear();
 86            for (i, c) in candidate.borrow().candidate_chars().enumerate() {
 87                candidate_chars.push(c);
 88                let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
 89                if char_lowercased.len() > 1 {
 90                    extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
 91                }
 92                lowercase_candidate_chars.append(&mut char_lowercased);
 93            }
 94
 95            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 96                continue;
 97            }
 98
 99            let matrix_len =
100                self.query.len() * (lowercase_prefix.len() + lowercase_candidate_chars.len());
101            self.score_matrix.clear();
102            self.score_matrix.resize(matrix_len, None);
103            self.best_position_matrix.clear();
104            self.best_position_matrix.resize(matrix_len, 0);
105
106            let score = self.score_match(
107                &candidate_chars,
108                &lowercase_candidate_chars,
109                prefix,
110                lowercase_prefix,
111                &extra_lowercase_chars,
112            );
113
114            if score > 0.0 {
115                results.push(build_match(
116                    candidate.borrow(),
117                    score,
118                    &self.match_positions,
119                ));
120            }
121        }
122    }
123
124    fn find_last_positions(
125        &mut self,
126        lowercase_prefix: &[char],
127        lowercase_candidate: &[char],
128    ) -> bool {
129        let mut lowercase_prefix = lowercase_prefix.iter();
130        let mut lowercase_candidate = lowercase_candidate.iter();
131        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
132            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
133                self.last_positions[i] = j + lowercase_prefix.len();
134            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
135                self.last_positions[i] = j;
136            } else {
137                return false;
138            }
139        }
140        true
141    }
142
143    fn score_match(
144        &mut self,
145        path: &[char],
146        path_lowercased: &[char],
147        prefix: &[char],
148        lowercase_prefix: &[char],
149        extra_lowercase_chars: &BTreeMap<usize, usize>,
150    ) -> f64 {
151        let score = self.recursive_score_match(
152            path,
153            path_lowercased,
154            prefix,
155            lowercase_prefix,
156            0,
157            0,
158            self.query.len() as f64,
159            extra_lowercase_chars,
160        ) * self.query.len() as f64;
161
162        if score <= 0.0 {
163            return 0.0;
164        }
165        let path_len = prefix.len() + path.len();
166        let mut cur_start = 0;
167        let mut byte_ix = 0;
168        let mut char_ix = 0;
169        for i in 0..self.query.len() {
170            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
171            while char_ix < match_char_ix {
172                let ch = prefix
173                    .get(char_ix)
174                    .or_else(|| path.get(char_ix - prefix.len()))
175                    .unwrap();
176                byte_ix += ch.len_utf8();
177                char_ix += 1;
178            }
179
180            self.match_positions[i] = byte_ix;
181
182            let matched_ch = prefix
183                .get(match_char_ix)
184                .or_else(|| path.get(match_char_ix - prefix.len()))
185                .unwrap();
186            byte_ix += matched_ch.len_utf8();
187
188            cur_start = match_char_ix + 1;
189            char_ix = match_char_ix + 1;
190        }
191
192        score
193    }
194
195    fn recursive_score_match(
196        &mut self,
197        path: &[char],
198        path_lowercased: &[char],
199        prefix: &[char],
200        lowercase_prefix: &[char],
201        query_idx: usize,
202        path_idx: usize,
203        cur_score: f64,
204        extra_lowercase_chars: &BTreeMap<usize, usize>,
205    ) -> f64 {
206        if query_idx == self.query.len() {
207            return 1.0;
208        }
209
210        let limit = self.last_positions[query_idx];
211        let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
212        let safe_limit = limit.min(max_valid_index);
213
214        if path_idx > safe_limit {
215            return 0.0;
216        }
217
218        let path_len = prefix.len() + path.len();
219        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
220            return memoized;
221        }
222
223        let mut score = 0.0;
224        let mut best_position = 0;
225
226        let query_char = self.lowercase_query[query_idx];
227
228        let mut last_slash = 0;
229
230        for j in path_idx..=safe_limit {
231            let extra_lowercase_chars_count = extra_lowercase_chars
232                .iter()
233                .take_while(|&(&i, _)| i < j)
234                .map(|(_, increment)| increment)
235                .sum::<usize>();
236            let j_regular = j - extra_lowercase_chars_count;
237
238            let path_char = if j < prefix.len() {
239                lowercase_prefix[j]
240            } else {
241                let path_index = j - prefix.len();
242                match path_lowercased.get(path_index) {
243                    Some(&char) => char,
244                    None => continue,
245                }
246            };
247            let is_path_sep = path_char == '/';
248
249            if query_idx == 0 && is_path_sep {
250                last_slash = j_regular;
251            }
252            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
253            if need_to_score {
254                let curr = match prefix.get(j_regular) {
255                    Some(&curr) => curr,
256                    None => path[j_regular - prefix.len()],
257                };
258
259                let mut char_score = 1.0;
260                if j > path_idx {
261                    let last = match prefix.get(j_regular - 1) {
262                        Some(&last) => last,
263                        None => path[j_regular - 1 - prefix.len()],
264                    };
265
266                    if last == '/' {
267                        char_score = 0.9;
268                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
269                        || (last.is_lowercase() && curr.is_uppercase())
270                    {
271                        char_score = 0.8;
272                    } else if last == '.' {
273                        char_score = 0.7;
274                    } else if query_idx == 0 {
275                        char_score = BASE_DISTANCE_PENALTY;
276                    } else {
277                        char_score = MIN_DISTANCE_PENALTY.max(
278                            BASE_DISTANCE_PENALTY
279                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
280                        );
281                    }
282                }
283
284                // Apply a severe penalty if the case doesn't match.
285                // This will make the exact matches have higher score than the case-insensitive and the
286                // path insensitive matches.
287                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
288                    char_score *= 0.001;
289                }
290
291                let mut multiplier = char_score;
292
293                // Scale the score based on how deep within the path we found the match.
294                if self.penalize_length && query_idx == 0 {
295                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
296                }
297
298                let mut next_score = 1.0;
299                if self.min_score > 0.0 {
300                    next_score = cur_score * multiplier;
301                    // Scores only decrease. If we can't pass the previous best, bail
302                    if next_score < self.min_score {
303                        // Ensure that score is non-zero so we use it in the memo table.
304                        if score == 0.0 {
305                            score = 1e-18;
306                        }
307                        continue;
308                    }
309                }
310
311                let new_score = self.recursive_score_match(
312                    path,
313                    path_lowercased,
314                    prefix,
315                    lowercase_prefix,
316                    query_idx + 1,
317                    j + 1,
318                    next_score,
319                    extra_lowercase_chars,
320                ) * multiplier;
321
322                if new_score > score {
323                    score = new_score;
324                    best_position = j_regular;
325                    // Optimization: can't score better than 1.
326                    if new_score == 1.0 {
327                        break;
328                    }
329                }
330            }
331        }
332
333        if best_position != 0 {
334            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
335        }
336
337        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
338        score
339    }
340}
341
342#[cfg(test)]
343mod tests {
344    use util::rel_path::{RelPath, rel_path};
345
346    use crate::{PathMatch, PathMatchCandidate};
347
348    use super::*;
349    use std::sync::Arc;
350
351    #[test]
352    fn test_get_last_positions() {
353        let mut query: &[char] = &['d', 'c'];
354        let mut matcher = Matcher::new(query, query, query.into(), false, true);
355        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
356        assert!(!result);
357
358        query = &['c', 'd'];
359        let mut matcher = Matcher::new(query, query, query.into(), false, true);
360        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
361        assert!(result);
362        assert_eq!(matcher.last_positions, vec![2, 4]);
363
364        query = &['z', '/', 'z', 'f'];
365        let mut matcher = Matcher::new(query, query, query.into(), false, true);
366        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
367        assert!(result);
368        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
369    }
370
371    #[test]
372    fn test_match_path_entries() {
373        let paths = vec![
374            "",
375            "a",
376            "ab",
377            "abC",
378            "abcd",
379            "alphabravocharlie",
380            "AlphaBravoCharlie",
381            "thisisatestdir",
382            "ThisIsATestDir",
383            "this/is/a/test/dir",
384            "test/tiatd",
385        ];
386
387        assert_eq!(
388            match_single_path_query("abc", false, &paths),
389            vec![
390                ("abC", vec![0, 1, 2]),
391                ("abcd", vec![0, 1, 2]),
392                ("AlphaBravoCharlie", vec![0, 5, 10]),
393                ("alphabravocharlie", vec![4, 5, 10]),
394            ]
395        );
396        assert_eq!(
397            match_single_path_query("t/i/a/t/d", false, &paths),
398            vec![("this/is/a/test/dir", vec![0, 4, 5, 7, 8, 9, 10, 14, 15]),]
399        );
400
401        assert_eq!(
402            match_single_path_query("tiatd", false, &paths),
403            vec![
404                ("test/tiatd", vec![5, 6, 7, 8, 9]),
405                ("ThisIsATestDir", vec![0, 4, 6, 7, 11]),
406                ("this/is/a/test/dir", vec![0, 5, 8, 10, 15]),
407                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
408            ]
409        );
410    }
411
412    #[test]
413    fn test_lowercase_longer_than_uppercase() {
414        // This character has more chars in lower-case than in upper-case.
415        let paths = vec!["\u{0130}"];
416        let query = "\u{0130}";
417        assert_eq!(
418            match_single_path_query(query, false, &paths),
419            vec![("\u{0130}", vec![0])]
420        );
421
422        // Path is the lower-case version of the query
423        let paths = vec!["i\u{307}"];
424        let query = "\u{0130}";
425        assert_eq!(
426            match_single_path_query(query, false, &paths),
427            vec![("i\u{307}", vec![0])]
428        );
429    }
430
431    #[test]
432    fn test_match_multibyte_path_entries() {
433        let paths = vec![
434            "aαbβ/cγdδ",
435            "αβγδ/bcde",
436            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
437            "d/🆒/h",
438        ];
439        assert_eq!("1️⃣".len(), 7);
440        assert_eq!(
441            match_single_path_query("bcd", false, &paths),
442            vec![
443                ("αβγδ/bcde", vec![9, 10, 11]),
444                ("aαbβ/cγdδ", vec![3, 7, 10]),
445            ]
446        );
447        assert_eq!(
448            match_single_path_query("cde", false, &paths),
449            vec![
450                ("αβγδ/bcde", vec![10, 11, 12]),
451                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
452            ]
453        );
454    }
455
456    #[test]
457    fn match_unicode_path_entries() {
458        let mixed_unicode_paths = vec![
459            "İolu/oluş",
460            "İstanbul/code",
461            "Athens/Şanlıurfa",
462            "Çanakkale/scripts",
463            "paris/Düzce_İl",
464            "Berlin_Önemli_Ğündem",
465            "KİTAPLIK/london/dosya",
466            "tokyo/kyoto/fuji",
467            "new_york/san_francisco",
468        ];
469
470        assert_eq!(
471            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
472            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
473        );
474
475        assert_eq!(
476            match_single_path_query("İst/code", false, &mixed_unicode_paths),
477            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
478        );
479
480        assert_eq!(
481            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
482            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
483        );
484
485        assert_eq!(
486            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
487            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
488        );
489
490        assert_eq!(
491            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
492            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
493        );
494
495        let mixed_script_paths = vec![
496            "résumé_Москва",
497            "naïve_київ_implementation",
498            "café_北京_app",
499            "東京_über_driver",
500            "déjà_vu_cairo",
501            "seoul_piñata_game",
502            "voilà_istanbul_result",
503        ];
504
505        assert_eq!(
506            match_single_path_query("résmé", false, &mixed_script_paths),
507            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
508        );
509
510        assert_eq!(
511            match_single_path_query("café北京", false, &mixed_script_paths),
512            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
513        );
514
515        assert_eq!(
516            match_single_path_query("ista", false, &mixed_script_paths),
517            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
518        );
519
520        let complex_paths = vec![
521            "document_📚_library",
522            "project_👨‍👩‍👧‍👦_family",
523            "flags_🇯🇵🇺🇸🇪🇺_world",
524            "code_😀😃😄😁_happy",
525            "photo_👩‍👩‍👧‍👦_album",
526        ];
527
528        assert_eq!(
529            match_single_path_query("doc📚lib", false, &complex_paths),
530            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
531        );
532
533        assert_eq!(
534            match_single_path_query("codehappy", false, &complex_paths),
535            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
536        );
537    }
538
539    fn match_single_path_query<'a>(
540        query: &str,
541        smart_case: bool,
542        paths: &[&'a str],
543    ) -> Vec<(&'a str, Vec<usize>)> {
544        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
545        let query = query.chars().collect::<Vec<_>>();
546        let query_chars = CharBag::from(&lowercase_query[..]);
547
548        let path_arcs: Vec<Arc<RelPath>> = paths
549            .iter()
550            .map(|path| Arc::from(rel_path(path)))
551            .collect::<Vec<_>>();
552        let mut path_entries = Vec::new();
553        for (i, path) in paths.iter().enumerate() {
554            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
555            let char_bag = CharBag::from(lowercase_path.as_slice());
556            path_entries.push(PathMatchCandidate {
557                is_dir: false,
558                char_bag,
559                path: &path_arcs[i],
560            });
561        }
562
563        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);
564
565        let cancel_flag = AtomicBool::new(false);
566        let mut results = Vec::new();
567
568        matcher.match_candidates(
569            &[],
570            &[],
571            path_entries.into_iter(),
572            &mut results,
573            &cancel_flag,
574            |candidate, score, positions| PathMatch {
575                score,
576                worktree_id: 0,
577                positions: positions.clone(),
578                path: candidate.path.into(),
579                path_prefix: RelPath::empty().into(),
580                distance_to_relative_ancestor: usize::MAX,
581                is_dir: false,
582            },
583        );
584        results.sort_by(|a, b| b.cmp(a));
585
586        results
587            .into_iter()
588            .map(|result| {
589                (
590                    paths
591                        .iter()
592                        .copied()
593                        .find(|p| result.path.as_ref() == rel_path(p))
594                        .unwrap(),
595                    result.positions,
596                )
597            })
598            .collect()
599    }
600
601    /// Test for https://github.com/zed-industries/zed/issues/44324
602    #[test]
603    fn test_recursive_score_match_index_out_of_bounds() {
604        let paths = vec!["İ/İ/İ/İ"];
605        let query = "İ/İ";
606
607        // This panicked with "index out of bounds: the len is 21 but the index is 22"
608        let result = match_single_path_query(query, false, &paths);
609        let _ = result;
610    }
611}