matcher.rs

  1use std::{
  2    borrow::{Borrow, Cow},
  3    collections::BTreeMap,
  4    sync::atomic::{self, AtomicBool},
  5};
  6
  7use crate::CharBag;
  8
/// Char score for the first query char when it matches after a gap and not at a word
/// boundary; also the starting value of the distance penalty for later query chars.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every skipped position beyond the first
/// between the previous match and the current one.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-penalized char score.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 12
// TODO:
// Use `Path` instead of `&str` for paths.
/// Fuzzy matcher that scores candidates against a single query, reusing its scratch buffers
/// (memo matrices and position vectors) from one candidate to the next.
pub struct Matcher<'a> {
    /// The query as typed; compared against the original candidate chars when deciding on the
    /// case-mismatch penalty.
    query: &'a [char],
    /// Lowercased query; its chars are searched for in the lowercased candidate text. May be
    /// longer than `query` when lowercasing expands a char.
    lowercase_query: &'a [char],
    /// Passed to `MatchCandidate::has_chars` to reject candidates before scoring.
    query_char_bag: CharBag,
    /// When true, every case mismatch between `query` and the matched char is penalized.
    smart_case: bool,
    /// Pruning threshold for the recursive search; only consulted when positive.
    min_score: f64,
    /// Byte offsets of the matched chars, one entry per `query` char.
    match_positions: Vec<usize>,
    /// For each `lowercase_query` char, the latest position at which it may match.
    last_positions: Vec<usize>,
    /// Memoized sub-scores, one row per query char.
    score_matrix: Vec<Option<f64>>,
    /// Best match position for each memoized cell; used to rebuild `match_positions`.
    best_position_matrix: Vec<usize>,
}
 26
/// A value that can be fuzzy-matched by `Matcher`.
pub trait MatchCandidate {
    /// Pre-filter: candidates returning `false` are skipped before any scoring (presumably
    /// when they lack some char of `bag` — confirm against implementors).
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The text that is matched against the query.
    fn to_string(&self) -> Cow<'_, str>;
}
 31
 32impl<'a> Matcher<'a> {
 33    pub fn new(
 34        query: &'a [char],
 35        lowercase_query: &'a [char],
 36        query_char_bag: CharBag,
 37        smart_case: bool,
 38    ) -> Self {
 39        Self {
 40            query,
 41            lowercase_query,
 42            query_char_bag,
 43            min_score: 0.0,
 44            last_positions: vec![0; lowercase_query.len()],
 45            match_positions: vec![0; query.len()],
 46            score_matrix: Vec::new(),
 47            best_position_matrix: Vec::new(),
 48            smart_case,
 49        }
 50    }
 51
 52    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 53    /// the input candidates.
 54    pub(crate) fn match_candidates<C, R, F, T>(
 55        &mut self,
 56        prefix: &[char],
 57        lowercase_prefix: &[char],
 58        candidates: impl Iterator<Item = T>,
 59        results: &mut Vec<R>,
 60        cancel_flag: &AtomicBool,
 61        build_match: F,
 62    ) where
 63        C: MatchCandidate,
 64        T: Borrow<C>,
 65        F: Fn(&C, f64, &Vec<usize>) -> R,
 66    {
 67        let mut candidate_chars = Vec::new();
 68        let mut lowercase_candidate_chars = Vec::new();
 69        let mut extra_lowercase_chars = BTreeMap::new();
 70
 71        for candidate in candidates {
 72            if !candidate.borrow().has_chars(self.query_char_bag) {
 73                continue;
 74            }
 75
 76            if cancel_flag.load(atomic::Ordering::Relaxed) {
 77                break;
 78            }
 79
 80            candidate_chars.clear();
 81            lowercase_candidate_chars.clear();
 82            extra_lowercase_chars.clear();
 83            for (i, c) in candidate.borrow().to_string().chars().enumerate() {
 84                candidate_chars.push(c);
 85                let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
 86                if char_lowercased.len() > 1 {
 87                    extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
 88                }
 89                lowercase_candidate_chars.append(&mut char_lowercased);
 90            }
 91
 92            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 93                continue;
 94            }
 95
 96            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 97            self.score_matrix.clear();
 98            self.score_matrix.resize(matrix_len, None);
 99            self.best_position_matrix.clear();
100            self.best_position_matrix.resize(matrix_len, 0);
101
102            let score = self.score_match(
103                &candidate_chars,
104                &lowercase_candidate_chars,
105                prefix,
106                lowercase_prefix,
107                &extra_lowercase_chars,
108            );
109
110            if score > 0.0 {
111                results.push(build_match(
112                    candidate.borrow(),
113                    score,
114                    &self.match_positions,
115                ));
116            }
117        }
118    }
119
120    fn find_last_positions(
121        &mut self,
122        lowercase_prefix: &[char],
123        lowercase_candidate: &[char],
124    ) -> bool {
125        let mut lowercase_prefix = lowercase_prefix.iter();
126        let mut lowercase_candidate = lowercase_candidate.iter();
127        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
128            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
129                self.last_positions[i] = j + lowercase_prefix.len();
130            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
131                self.last_positions[i] = j;
132            } else {
133                return false;
134            }
135        }
136        true
137    }
138
139    fn score_match(
140        &mut self,
141        path: &[char],
142        path_lowercased: &[char],
143        prefix: &[char],
144        lowercase_prefix: &[char],
145        extra_lowercase_chars: &BTreeMap<usize, usize>,
146    ) -> f64 {
147        let score = self.recursive_score_match(
148            path,
149            path_lowercased,
150            prefix,
151            lowercase_prefix,
152            0,
153            0,
154            self.query.len() as f64,
155            extra_lowercase_chars,
156        ) * self.query.len() as f64;
157
158        if score <= 0.0 {
159            return 0.0;
160        }
161        let path_len = prefix.len() + path.len();
162        let mut cur_start = 0;
163        let mut byte_ix = 0;
164        let mut char_ix = 0;
165        for i in 0..self.query.len() {
166            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
167            while char_ix < match_char_ix {
168                let ch = prefix
169                    .get(char_ix)
170                    .or_else(|| path.get(char_ix - prefix.len()))
171                    .unwrap();
172                byte_ix += ch.len_utf8();
173                char_ix += 1;
174            }
175
176            self.match_positions[i] = byte_ix;
177
178            let matched_ch = prefix
179                .get(match_char_ix)
180                .or_else(|| path.get(match_char_ix - prefix.len()))
181                .unwrap();
182            byte_ix += matched_ch.len_utf8();
183
184            cur_start = match_char_ix + 1;
185            char_ix = match_char_ix + 1;
186        }
187
188        score
189    }
190
191    fn recursive_score_match(
192        &mut self,
193        path: &[char],
194        path_lowercased: &[char],
195        prefix: &[char],
196        lowercase_prefix: &[char],
197        query_idx: usize,
198        path_idx: usize,
199        cur_score: f64,
200        extra_lowercase_chars: &BTreeMap<usize, usize>,
201    ) -> f64 {
202        use std::path::MAIN_SEPARATOR;
203
204        if query_idx == self.query.len() {
205            return 1.0;
206        }
207
208        let path_len = prefix.len() + path.len();
209
210        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
211            return memoized;
212        }
213
214        let mut score = 0.0;
215        let mut best_position = 0;
216
217        let query_char = self.lowercase_query[query_idx];
218        let limit = self.last_positions[query_idx];
219
220        let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
221        let safe_limit = limit.min(max_valid_index);
222
223        let mut last_slash = 0;
224        for j in path_idx..=safe_limit {
225            let extra_lowercase_chars_count = extra_lowercase_chars
226                .iter()
227                .take_while(|(i, _)| i < &&j)
228                .map(|(_, increment)| increment)
229                .sum::<usize>();
230            let j_regular = j - extra_lowercase_chars_count;
231
232            let path_char = if j < prefix.len() {
233                lowercase_prefix[j]
234            } else {
235                let path_index = j - prefix.len();
236                if path_index < path_lowercased.len() {
237                    path_lowercased[path_index]
238                } else {
239                    continue;
240                }
241            };
242            let is_path_sep = path_char == MAIN_SEPARATOR;
243
244            if query_idx == 0 && is_path_sep {
245                last_slash = j_regular;
246            }
247
248            #[cfg(not(target_os = "windows"))]
249            let need_to_score =
250                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
251            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
252            #[cfg(target_os = "windows")]
253            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
254            if need_to_score {
255                let curr = if j_regular < prefix.len() {
256                    prefix[j_regular]
257                } else {
258                    path[j_regular - prefix.len()]
259                };
260
261                let mut char_score = 1.0;
262                if j > path_idx {
263                    let last = if j_regular - 1 < prefix.len() {
264                        prefix[j_regular - 1]
265                    } else {
266                        path[j_regular - 1 - prefix.len()]
267                    };
268
269                    if last == MAIN_SEPARATOR {
270                        char_score = 0.9;
271                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
272                        || (last.is_lowercase() && curr.is_uppercase())
273                    {
274                        char_score = 0.8;
275                    } else if last == '.' {
276                        char_score = 0.7;
277                    } else if query_idx == 0 {
278                        char_score = BASE_DISTANCE_PENALTY;
279                    } else {
280                        char_score = MIN_DISTANCE_PENALTY.max(
281                            BASE_DISTANCE_PENALTY
282                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
283                        );
284                    }
285                }
286
287                // Apply a severe penalty if the case doesn't match.
288                // This will make the exact matches have higher score than the case-insensitive and the
289                // path insensitive matches.
290                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
291                    char_score *= 0.001;
292                }
293
294                let mut multiplier = char_score;
295
296                // Scale the score based on how deep within the path we found the match.
297                if query_idx == 0 {
298                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
299                }
300
301                let mut next_score = 1.0;
302                if self.min_score > 0.0 {
303                    next_score = cur_score * multiplier;
304                    // Scores only decrease. If we can't pass the previous best, bail
305                    if next_score < self.min_score {
306                        // Ensure that score is non-zero so we use it in the memo table.
307                        if score == 0.0 {
308                            score = 1e-18;
309                        }
310                        continue;
311                    }
312                }
313
314                let new_score = self.recursive_score_match(
315                    path,
316                    path_lowercased,
317                    prefix,
318                    lowercase_prefix,
319                    query_idx + 1,
320                    j + 1,
321                    next_score,
322                    extra_lowercase_chars,
323                ) * multiplier;
324
325                if new_score > score {
326                    score = new_score;
327                    best_position = j_regular;
328                    // Optimization: can't score better than 1.
329                    if new_score == 1.0 {
330                        break;
331                    }
332                }
333            }
334        }
335
336        if best_position != 0 {
337            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
338        }
339
340        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
341        score
342    }
343}
344
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// Checks `find_last_positions`; positions count across the prefix followed by the
    /// candidate.
    #[test]
    fn test_get_last_positions() {
        // `c` only occurs in the prefix, and there is no `d` before it there, so no match.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // `d` is found in the candidate (index 1, offset by the 3-char prefix), `c` in the prefix.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Once the candidate is exhausted, earlier query chars fall back to the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// TODO(windows): on Windows, queries currently have to use the backslash as the path
    /// separator; both the backslash and the forward slash should be accepted.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Match positions are byte offsets, so multi-byte chars before a match shift them.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // A keycap emoji is the digit + U+FE0F + U+20E3: 1 + 3 + 3 bytes.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    #[test]
    fn match_unicode_path_entries() {
        let mixed_unicode_paths = vec![
            "İolu/oluş",
            "İstanbul/code",
            "Athens/Şanlıurfa",
            "Çanakkale/scripts",
            "paris/Düzce_İl",
            "Berlin_Önemli_Ğündem",
            "KİTAPLIK/london/dosya",
            "tokyo/kyoto/fuji",
            "new_york/san_francisco",
        ];

        // NOTE(review): "İolu/oluş" is 11 bytes and "İstanbul/code" is 14 bytes, so several of
        // the expected offsets in the next two assertions are not char offsets: 10 falls inside
        // `ş`, and 12 and 14 are at or past the end of the string. They appear to pin the
        // current behaviour when lowercasing expands `İ` (U+0130) into two chars. Confirm the
        // intent before relying on them.
        assert_eq!(
            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
        );

        assert_eq!(
            match_single_path_query("İst/code", false, &mixed_unicode_paths),
            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
        );

        assert_eq!(
            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
        );

        assert_eq!(
            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
        );

        assert_eq!(
            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
        );

        let mixed_script_paths = vec![
            "résumé_Москва",
            "naïve_київ_implementation",
            "café_北京_app",
            "東京_über_driver",
            "déjà_vu_cairo",
            "seoul_piñata_game",
            "voilà_istanbul_result",
        ];

        assert_eq!(
            match_single_path_query("résmé", false, &mixed_script_paths),
            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
        );

        assert_eq!(
            match_single_path_query("café北京", false, &mixed_script_paths),
            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
        );

        assert_eq!(
            match_single_path_query("ista", false, &mixed_script_paths),
            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
        );

        let complex_paths = vec![
            "document_📚_library",
            "project_👨‍👩‍👧‍👦_family",
            "flags_🇯🇵🇺🇸🇪🇺_world",
            "code_😀😃😄😁_happy",
            "photo_👩‍👩‍👧‍👦_album",
        ];

        assert_eq!(
            match_single_path_query("doc📚lib", false, &complex_paths),
            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
        );

        assert_eq!(
            match_single_path_query("codehappy", false, &complex_paths),
            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
        );
    }

    /// Runs `query` against `paths`, with an empty prefix and no cancellation. Returns each
    /// matched path with its match byte offsets, sorted in descending `PathMatch` order
    /// (presumably best score first — depends on `PathMatch`'s `Ord`).
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            // Char bags are built from the lowercased path, like the query's.
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` so assertions can compare literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}