matcher.rs

  1use std::{
  2    borrow::{Borrow, Cow},
  3    collections::BTreeMap,
  4    sync::atomic::{self, AtomicBool},
  5};
  6
  7use crate::CharBag;
  8
/// Per-character score multiplier used when a match does not directly follow a path
/// separator, a word boundary, or a `.`. It is also the starting value from which the
/// gap penalty is subtracted for query characters after the first.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every additional candidate character
/// skipped between the search start and the match.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the gap-penalized multiplier, so distant matches never drop below it.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 12
 13// TODO:
 14// Use `Path` instead of `&str` for paths.
 15pub struct Matcher<'a> {
 16    query: &'a [char],
 17    lowercase_query: &'a [char],
 18    query_char_bag: CharBag,
 19    smart_case: bool,
 20    penalize_length: bool,
 21    min_score: f64,
 22    match_positions: Vec<usize>,
 23    last_positions: Vec<usize>,
 24    score_matrix: Vec<Option<f64>>,
 25    best_position_matrix: Vec<usize>,
 26}
 27
 28pub trait MatchCandidate {
 29    fn has_chars(&self, bag: CharBag) -> bool;
 30    fn to_string(&self) -> Cow<'_, str>;
 31}
 32
 33impl<'a> Matcher<'a> {
 34    pub fn new(
 35        query: &'a [char],
 36        lowercase_query: &'a [char],
 37        query_char_bag: CharBag,
 38        smart_case: bool,
 39        penalize_length: bool,
 40    ) -> Self {
 41        Self {
 42            query,
 43            lowercase_query,
 44            query_char_bag,
 45            min_score: 0.0,
 46            last_positions: vec![0; lowercase_query.len()],
 47            match_positions: vec![0; query.len()],
 48            score_matrix: Vec::new(),
 49            best_position_matrix: Vec::new(),
 50            smart_case,
 51            penalize_length,
 52        }
 53    }
 54
 55    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 56    /// the input candidates.
 57    pub(crate) fn match_candidates<C, R, F, T>(
 58        &mut self,
 59        prefix: &[char],
 60        lowercase_prefix: &[char],
 61        candidates: impl Iterator<Item = T>,
 62        results: &mut Vec<R>,
 63        cancel_flag: &AtomicBool,
 64        build_match: F,
 65    ) where
 66        C: MatchCandidate,
 67        T: Borrow<C>,
 68        F: Fn(&C, f64, &Vec<usize>) -> R,
 69    {
 70        let mut candidate_chars = Vec::new();
 71        let mut lowercase_candidate_chars = Vec::new();
 72        let mut extra_lowercase_chars = BTreeMap::new();
 73
 74        for candidate in candidates {
 75            if !candidate.borrow().has_chars(self.query_char_bag) {
 76                continue;
 77            }
 78
 79            if cancel_flag.load(atomic::Ordering::Relaxed) {
 80                break;
 81            }
 82
 83            candidate_chars.clear();
 84            lowercase_candidate_chars.clear();
 85            extra_lowercase_chars.clear();
 86            for (i, c) in candidate.borrow().to_string().chars().enumerate() {
 87                candidate_chars.push(c);
 88                let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
 89                if char_lowercased.len() > 1 {
 90                    extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
 91                }
 92                lowercase_candidate_chars.append(&mut char_lowercased);
 93            }
 94
 95            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 96                continue;
 97            }
 98
 99            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
100            self.score_matrix.clear();
101            self.score_matrix.resize(matrix_len, None);
102            self.best_position_matrix.clear();
103            self.best_position_matrix.resize(matrix_len, 0);
104
105            let score = self.score_match(
106                &candidate_chars,
107                &lowercase_candidate_chars,
108                prefix,
109                lowercase_prefix,
110                &extra_lowercase_chars,
111            );
112
113            if score > 0.0 {
114                results.push(build_match(
115                    candidate.borrow(),
116                    score,
117                    &self.match_positions,
118                ));
119            }
120        }
121    }
122
123    fn find_last_positions(
124        &mut self,
125        lowercase_prefix: &[char],
126        lowercase_candidate: &[char],
127    ) -> bool {
128        let mut lowercase_prefix = lowercase_prefix.iter();
129        let mut lowercase_candidate = lowercase_candidate.iter();
130        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
131            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
132                self.last_positions[i] = j + lowercase_prefix.len();
133            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
134                self.last_positions[i] = j;
135            } else {
136                return false;
137            }
138        }
139        true
140    }
141
142    fn score_match(
143        &mut self,
144        path: &[char],
145        path_lowercased: &[char],
146        prefix: &[char],
147        lowercase_prefix: &[char],
148        extra_lowercase_chars: &BTreeMap<usize, usize>,
149    ) -> f64 {
150        let score = self.recursive_score_match(
151            path,
152            path_lowercased,
153            prefix,
154            lowercase_prefix,
155            0,
156            0,
157            self.query.len() as f64,
158            extra_lowercase_chars,
159        ) * self.query.len() as f64;
160
161        if score <= 0.0 {
162            return 0.0;
163        }
164        let path_len = prefix.len() + path.len();
165        let mut cur_start = 0;
166        let mut byte_ix = 0;
167        let mut char_ix = 0;
168        for i in 0..self.query.len() {
169            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170            while char_ix < match_char_ix {
171                let ch = prefix
172                    .get(char_ix)
173                    .or_else(|| path.get(char_ix - prefix.len()))
174                    .unwrap();
175                byte_ix += ch.len_utf8();
176                char_ix += 1;
177            }
178
179            self.match_positions[i] = byte_ix;
180
181            let matched_ch = prefix
182                .get(match_char_ix)
183                .or_else(|| path.get(match_char_ix - prefix.len()))
184                .unwrap();
185            byte_ix += matched_ch.len_utf8();
186
187            cur_start = match_char_ix + 1;
188            char_ix = match_char_ix + 1;
189        }
190
191        score
192    }
193
194    fn recursive_score_match(
195        &mut self,
196        path: &[char],
197        path_lowercased: &[char],
198        prefix: &[char],
199        lowercase_prefix: &[char],
200        query_idx: usize,
201        path_idx: usize,
202        cur_score: f64,
203        extra_lowercase_chars: &BTreeMap<usize, usize>,
204    ) -> f64 {
205        use std::path::MAIN_SEPARATOR;
206
207        if query_idx == self.query.len() {
208            return 1.0;
209        }
210
211        let limit = self.last_positions[query_idx];
212        let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
213        let safe_limit = limit.min(max_valid_index);
214
215        if path_idx > safe_limit {
216            return 0.0;
217        }
218
219        let path_len = prefix.len() + path.len();
220        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
221            return memoized;
222        }
223
224        let mut score = 0.0;
225        let mut best_position = 0;
226
227        let query_char = self.lowercase_query[query_idx];
228
229        let mut last_slash = 0;
230
231        for j in path_idx..=safe_limit {
232            let extra_lowercase_chars_count = extra_lowercase_chars
233                .iter()
234                .take_while(|&(&i, _)| i < j)
235                .map(|(_, increment)| increment)
236                .sum::<usize>();
237            let j_regular = j - extra_lowercase_chars_count;
238
239            let path_char = if j < prefix.len() {
240                lowercase_prefix[j]
241            } else {
242                let path_index = j - prefix.len();
243                match path_lowercased.get(path_index) {
244                    Some(&char) => char,
245                    None => continue,
246                }
247            };
248            let is_path_sep = path_char == MAIN_SEPARATOR;
249
250            if query_idx == 0 && is_path_sep {
251                last_slash = j_regular;
252            }
253
254            #[cfg(not(target_os = "windows"))]
255            let need_to_score =
256                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
257            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
258            #[cfg(target_os = "windows")]
259            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
260            if need_to_score {
261                let curr = match prefix.get(j_regular) {
262                    Some(&curr) => curr,
263                    None => path[j_regular - prefix.len()],
264                };
265
266                let mut char_score = 1.0;
267                if j > path_idx {
268                    let last = match prefix.get(j_regular - 1) {
269                        Some(&last) => last,
270                        None => path[j_regular - 1 - prefix.len()],
271                    };
272
273                    if last == MAIN_SEPARATOR {
274                        char_score = 0.9;
275                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
276                        || (last.is_lowercase() && curr.is_uppercase())
277                    {
278                        char_score = 0.8;
279                    } else if last == '.' {
280                        char_score = 0.7;
281                    } else if query_idx == 0 {
282                        char_score = BASE_DISTANCE_PENALTY;
283                    } else {
284                        char_score = MIN_DISTANCE_PENALTY.max(
285                            BASE_DISTANCE_PENALTY
286                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
287                        );
288                    }
289                }
290
291                // Apply a severe penalty if the case doesn't match.
292                // This will make the exact matches have higher score than the case-insensitive and the
293                // path insensitive matches.
294                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
295                    char_score *= 0.001;
296                }
297
298                let mut multiplier = char_score;
299
300                // Scale the score based on how deep within the path we found the match.
301                if self.penalize_length && query_idx == 0 {
302                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
303                }
304
305                let mut next_score = 1.0;
306                if self.min_score > 0.0 {
307                    next_score = cur_score * multiplier;
308                    // Scores only decrease. If we can't pass the previous best, bail
309                    if next_score < self.min_score {
310                        // Ensure that score is non-zero so we use it in the memo table.
311                        if score == 0.0 {
312                            score = 1e-18;
313                        }
314                        continue;
315                    }
316                }
317
318                let new_score = self.recursive_score_match(
319                    path,
320                    path_lowercased,
321                    prefix,
322                    lowercase_prefix,
323                    query_idx + 1,
324                    j + 1,
325                    next_score,
326                    extra_lowercase_chars,
327                ) * multiplier;
328
329                if new_score > score {
330                    score = new_score;
331                    best_position = j_regular;
332                    // Optimization: can't score better than 1.
333                    if new_score == 1.0 {
334                        break;
335                    }
336                }
337            }
338        }
339
340        if best_position != 0 {
341            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
342        }
343
344        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
345        score
346    }
347}
348
349#[cfg(test)]
350mod tests {
351    use crate::{PathMatch, PathMatchCandidate};
352
353    use super::*;
354    use std::{
355        path::{Path, PathBuf},
356        sync::Arc,
357    };
358
359    #[test]
360    fn test_get_last_positions() {
361        let mut query: &[char] = &['d', 'c'];
362        let mut matcher = Matcher::new(query, query, query.into(), false, true);
363        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
364        assert!(!result);
365
366        query = &['c', 'd'];
367        let mut matcher = Matcher::new(query, query, query.into(), false, true);
368        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
369        assert!(result);
370        assert_eq!(matcher.last_positions, vec![2, 4]);
371
372        query = &['z', '/', 'z', 'f'];
373        let mut matcher = Matcher::new(query, query, query.into(), false, true);
374        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
375        assert!(result);
376        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
377    }
378
379    #[cfg(not(target_os = "windows"))]
380    #[test]
381    fn test_match_path_entries() {
382        let paths = vec![
383            "",
384            "a",
385            "ab",
386            "abC",
387            "abcd",
388            "alphabravocharlie",
389            "AlphaBravoCharlie",
390            "thisisatestdir",
391            "/////ThisIsATestDir",
392            "/this/is/a/test/dir",
393            "/test/tiatd",
394        ];
395
396        assert_eq!(
397            match_single_path_query("abc", false, &paths),
398            vec![
399                ("abC", vec![0, 1, 2]),
400                ("abcd", vec![0, 1, 2]),
401                ("AlphaBravoCharlie", vec![0, 5, 10]),
402                ("alphabravocharlie", vec![4, 5, 10]),
403            ]
404        );
405        assert_eq!(
406            match_single_path_query("t/i/a/t/d", false, &paths),
407            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
408        );
409
410        assert_eq!(
411            match_single_path_query("tiatd", false, &paths),
412            vec![
413                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
414                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
415                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
416                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
417            ]
418        );
419    }
420
421    /// todo(windows)
422    /// Now, on Windows, users can only use the backslash as a path separator.
423    /// I do want to support both the backslash and the forward slash as path separators on Windows.
424    #[cfg(target_os = "windows")]
425    #[test]
426    fn test_match_path_entries() {
427        let paths = vec![
428            "",
429            "a",
430            "ab",
431            "abC",
432            "abcd",
433            "alphabravocharlie",
434            "AlphaBravoCharlie",
435            "thisisatestdir",
436            "\\\\\\\\\\ThisIsATestDir",
437            "\\this\\is\\a\\test\\dir",
438            "\\test\\tiatd",
439        ];
440
441        assert_eq!(
442            match_single_path_query("abc", false, &paths),
443            vec![
444                ("abC", vec![0, 1, 2]),
445                ("abcd", vec![0, 1, 2]),
446                ("AlphaBravoCharlie", vec![0, 5, 10]),
447                ("alphabravocharlie", vec![4, 5, 10]),
448            ]
449        );
450        assert_eq!(
451            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
452            vec![(
453                "\\this\\is\\a\\test\\dir",
454                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
455            ),]
456        );
457
458        assert_eq!(
459            match_single_path_query("tiatd", false, &paths),
460            vec![
461                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
462                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
463                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
464                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
465            ]
466        );
467    }
468
469    #[test]
470    fn test_lowercase_longer_than_uppercase() {
471        // This character has more chars in lower-case than in upper-case.
472        let paths = vec!["\u{0130}"];
473        let query = "\u{0130}";
474        assert_eq!(
475            match_single_path_query(query, false, &paths),
476            vec![("\u{0130}", vec![0])]
477        );
478
479        // Path is the lower-case version of the query
480        let paths = vec!["i\u{307}"];
481        let query = "\u{0130}";
482        assert_eq!(
483            match_single_path_query(query, false, &paths),
484            vec![("i\u{307}", vec![0])]
485        );
486    }
487
488    #[test]
489    fn test_match_multibyte_path_entries() {
490        let paths = vec![
491            "aαbβ/cγdδ",
492            "αβγδ/bcde",
493            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
494            "/d/🆒/h",
495        ];
496        assert_eq!("1️⃣".len(), 7);
497        assert_eq!(
498            match_single_path_query("bcd", false, &paths),
499            vec![
500                ("αβγδ/bcde", vec![9, 10, 11]),
501                ("aαbβ/cγdδ", vec![3, 7, 10]),
502            ]
503        );
504        assert_eq!(
505            match_single_path_query("cde", false, &paths),
506            vec![
507                ("αβγδ/bcde", vec![10, 11, 12]),
508                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
509            ]
510        );
511    }
512
513    #[test]
514    fn match_unicode_path_entries() {
515        let mixed_unicode_paths = vec![
516            "İolu/oluş",
517            "İstanbul/code",
518            "Athens/Şanlıurfa",
519            "Çanakkale/scripts",
520            "paris/Düzce_İl",
521            "Berlin_Önemli_Ğündem",
522            "KİTAPLIK/london/dosya",
523            "tokyo/kyoto/fuji",
524            "new_york/san_francisco",
525        ];
526
527        assert_eq!(
528            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
529            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
530        );
531
532        assert_eq!(
533            match_single_path_query("İst/code", false, &mixed_unicode_paths),
534            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
535        );
536
537        assert_eq!(
538            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
539            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
540        );
541
542        assert_eq!(
543            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
544            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
545        );
546
547        assert_eq!(
548            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
549            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
550        );
551
552        let mixed_script_paths = vec![
553            "résumé_Москва",
554            "naïve_київ_implementation",
555            "café_北京_app",
556            "東京_über_driver",
557            "déjà_vu_cairo",
558            "seoul_piñata_game",
559            "voilà_istanbul_result",
560        ];
561
562        assert_eq!(
563            match_single_path_query("résmé", false, &mixed_script_paths),
564            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
565        );
566
567        assert_eq!(
568            match_single_path_query("café北京", false, &mixed_script_paths),
569            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
570        );
571
572        assert_eq!(
573            match_single_path_query("ista", false, &mixed_script_paths),
574            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
575        );
576
577        let complex_paths = vec![
578            "document_📚_library",
579            "project_👨‍👩‍👧‍👦_family",
580            "flags_🇯🇵🇺🇸🇪🇺_world",
581            "code_😀😃😄😁_happy",
582            "photo_👩‍👩‍👧‍👦_album",
583        ];
584
585        assert_eq!(
586            match_single_path_query("doc📚lib", false, &complex_paths),
587            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
588        );
589
590        assert_eq!(
591            match_single_path_query("codehappy", false, &complex_paths),
592            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
593        );
594    }
595
596    fn match_single_path_query<'a>(
597        query: &str,
598        smart_case: bool,
599        paths: &[&'a str],
600    ) -> Vec<(&'a str, Vec<usize>)> {
601        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
602        let query = query.chars().collect::<Vec<_>>();
603        let query_chars = CharBag::from(&lowercase_query[..]);
604
605        let path_arcs: Vec<Arc<Path>> = paths
606            .iter()
607            .map(|path| Arc::from(PathBuf::from(path)))
608            .collect::<Vec<_>>();
609        let mut path_entries = Vec::new();
610        for (i, path) in paths.iter().enumerate() {
611            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
612            let char_bag = CharBag::from(lowercase_path.as_slice());
613            path_entries.push(PathMatchCandidate {
614                is_dir: false,
615                char_bag,
616                path: &path_arcs[i],
617            });
618        }
619
620        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);
621
622        let cancel_flag = AtomicBool::new(false);
623        let mut results = Vec::new();
624
625        matcher.match_candidates(
626            &[],
627            &[],
628            path_entries.into_iter(),
629            &mut results,
630            &cancel_flag,
631            |candidate, score, positions| PathMatch {
632                score,
633                worktree_id: 0,
634                positions: positions.clone(),
635                path: Arc::from(candidate.path),
636                path_prefix: "".into(),
637                distance_to_relative_ancestor: usize::MAX,
638                is_dir: false,
639            },
640        );
641        results.sort_by(|a, b| b.cmp(a));
642
643        results
644            .into_iter()
645            .map(|result| {
646                (
647                    paths
648                        .iter()
649                        .copied()
650                        .find(|p| result.path.as_ref() == Path::new(p))
651                        .unwrap(),
652                    result.positions,
653                )
654            })
655            .collect()
656    }
657}