matcher.rs

  1use std::{
  2    borrow::Cow,
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12// TODO:
 13// Use `Path` instead of `&str` for paths.
 14pub struct Matcher<'a> {
 15    query: &'a [char],
 16    lowercase_query: &'a [char],
 17    query_char_bag: CharBag,
 18    smart_case: bool,
 19    min_score: f64,
 20    match_positions: Vec<usize>,
 21    last_positions: Vec<usize>,
 22    score_matrix: Vec<Option<f64>>,
 23    best_position_matrix: Vec<usize>,
 24}
 25
 26pub trait MatchCandidate {
 27    fn has_chars(&self, bag: CharBag) -> bool;
 28    fn to_string(&self) -> Cow<'_, str>;
 29}
 30
 31impl<'a> Matcher<'a> {
 32    pub fn new(
 33        query: &'a [char],
 34        lowercase_query: &'a [char],
 35        query_char_bag: CharBag,
 36        smart_case: bool,
 37    ) -> Self {
 38        Self {
 39            query,
 40            lowercase_query,
 41            query_char_bag,
 42            min_score: 0.0,
 43            last_positions: vec![0; lowercase_query.len()],
 44            match_positions: vec![0; query.len()],
 45            score_matrix: Vec::new(),
 46            best_position_matrix: Vec::new(),
 47            smart_case,
 48        }
 49    }
 50
 51    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 52    /// the input candidates.
 53    pub fn match_candidates<C: MatchCandidate, R, F>(
 54        &mut self,
 55        prefix: &[char],
 56        lowercase_prefix: &[char],
 57        candidates: impl Iterator<Item = C>,
 58        results: &mut Vec<R>,
 59        cancel_flag: &AtomicBool,
 60        build_match: F,
 61    ) where
 62        F: Fn(&C, f64, &Vec<usize>) -> R,
 63    {
 64        let mut candidate_chars = Vec::new();
 65        let mut lowercase_candidate_chars = Vec::new();
 66
 67        for candidate in candidates {
 68            if !candidate.has_chars(self.query_char_bag) {
 69                continue;
 70            }
 71
 72            if cancel_flag.load(atomic::Ordering::Relaxed) {
 73                break;
 74            }
 75
 76            candidate_chars.clear();
 77            lowercase_candidate_chars.clear();
 78            for c in candidate.to_string().chars() {
 79                candidate_chars.push(c);
 80                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 81            }
 82
 83            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 84                continue;
 85            }
 86
 87            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 88            self.score_matrix.clear();
 89            self.score_matrix.resize(matrix_len, None);
 90            self.best_position_matrix.clear();
 91            self.best_position_matrix.resize(matrix_len, 0);
 92
 93            let score = self.score_match(
 94                &candidate_chars,
 95                &lowercase_candidate_chars,
 96                prefix,
 97                lowercase_prefix,
 98            );
 99
100            if score > 0.0 {
101                results.push(build_match(&candidate, score, &self.match_positions));
102            }
103        }
104    }
105
106    fn find_last_positions(
107        &mut self,
108        lowercase_prefix: &[char],
109        lowercase_candidate: &[char],
110    ) -> bool {
111        let mut lowercase_prefix = lowercase_prefix.iter();
112        let mut lowercase_candidate = lowercase_candidate.iter();
113        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
114            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
115                self.last_positions[i] = j + lowercase_prefix.len();
116            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
117                self.last_positions[i] = j;
118            } else {
119                return false;
120            }
121        }
122        true
123    }
124
125    fn score_match(
126        &mut self,
127        path: &[char],
128        path_cased: &[char],
129        prefix: &[char],
130        lowercase_prefix: &[char],
131    ) -> f64 {
132        let score = self.recursive_score_match(
133            path,
134            path_cased,
135            prefix,
136            lowercase_prefix,
137            0,
138            0,
139            self.query.len() as f64,
140        ) * self.query.len() as f64;
141
142        if score <= 0.0 {
143            return 0.0;
144        }
145
146        let path_len = prefix.len() + path.len();
147        let mut cur_start = 0;
148        let mut byte_ix = 0;
149        let mut char_ix = 0;
150        for i in 0..self.query.len() {
151            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
152            while char_ix < match_char_ix {
153                let ch = prefix
154                    .get(char_ix)
155                    .or_else(|| path.get(char_ix - prefix.len()))
156                    .unwrap();
157                byte_ix += ch.len_utf8();
158                char_ix += 1;
159            }
160            cur_start = match_char_ix + 1;
161            self.match_positions[i] = byte_ix;
162        }
163
164        score
165    }
166
167    fn recursive_score_match(
168        &mut self,
169        path: &[char],
170        path_cased: &[char],
171        prefix: &[char],
172        lowercase_prefix: &[char],
173        query_idx: usize,
174        path_idx: usize,
175        cur_score: f64,
176    ) -> f64 {
177        use std::path::MAIN_SEPARATOR;
178
179        if query_idx == self.query.len() {
180            return 1.0;
181        }
182
183        let path_len = prefix.len() + path.len();
184
185        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
186            return memoized;
187        }
188
189        let mut score = 0.0;
190        let mut best_position = 0;
191
192        let query_char = self.lowercase_query[query_idx];
193        let limit = self.last_positions[query_idx];
194
195        let mut last_slash = 0;
196        for j in path_idx..=limit {
197            let path_char = if j < prefix.len() {
198                lowercase_prefix[j]
199            } else {
200                path_cased[j - prefix.len()]
201            };
202            let is_path_sep = path_char == MAIN_SEPARATOR;
203
204            if query_idx == 0 && is_path_sep {
205                last_slash = j;
206            }
207
208            #[cfg(not(target_os = "windows"))]
209            let need_to_score =
210                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
211            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
212            #[cfg(target_os = "windows")]
213            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
214            if need_to_score {
215                let curr = if j < prefix.len() {
216                    prefix[j]
217                } else {
218                    path[j - prefix.len()]
219                };
220
221                let mut char_score = 1.0;
222                if j > path_idx {
223                    let last = if j - 1 < prefix.len() {
224                        prefix[j - 1]
225                    } else {
226                        path[j - 1 - prefix.len()]
227                    };
228
229                    if last == MAIN_SEPARATOR {
230                        char_score = 0.9;
231                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
232                        || (last.is_lowercase() && curr.is_uppercase())
233                    {
234                        char_score = 0.8;
235                    } else if last == '.' {
236                        char_score = 0.7;
237                    } else if query_idx == 0 {
238                        char_score = BASE_DISTANCE_PENALTY;
239                    } else {
240                        char_score = MIN_DISTANCE_PENALTY.max(
241                            BASE_DISTANCE_PENALTY
242                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
243                        );
244                    }
245                }
246
247                // Apply a severe penalty if the case doesn't match.
248                // This will make the exact matches have higher score than the case-insensitive and the
249                // path insensitive matches.
250                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
251                    char_score *= 0.001;
252                }
253
254                let mut multiplier = char_score;
255
256                // Scale the score based on how deep within the path we found the match.
257                if query_idx == 0 {
258                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
259                }
260
261                let mut next_score = 1.0;
262                if self.min_score > 0.0 {
263                    next_score = cur_score * multiplier;
264                    // Scores only decrease. If we can't pass the previous best, bail
265                    if next_score < self.min_score {
266                        // Ensure that score is non-zero so we use it in the memo table.
267                        if score == 0.0 {
268                            score = 1e-18;
269                        }
270                        continue;
271                    }
272                }
273
274                let new_score = self.recursive_score_match(
275                    path,
276                    path_cased,
277                    prefix,
278                    lowercase_prefix,
279                    query_idx + 1,
280                    j + 1,
281                    next_score,
282                ) * multiplier;
283
284                if new_score > score {
285                    score = new_score;
286                    best_position = j;
287                    // Optimization: can't score better than 1.
288                    if new_score == 1.0 {
289                        break;
290                    }
291                }
292            }
293        }
294
295        if best_position != 0 {
296            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
297        }
298
299        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
300        score
301    }
302}
303
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// Exercises `find_last_positions` directly. It covers rejection when the query chars
    /// cannot be placed in order, and last-position indices that span prefix + candidate.
    #[test]
    fn test_get_last_positions() {
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        // 'c' is only found in the prefix, which leaves no earlier position for 'd'.
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        // 'c' at prefix index 2; 'd' at candidate index 1, offset by the prefix length 3.
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        // The last two query chars land in the candidate (offset by 4); once the candidate is
        // exhausted, the first two fall back into the prefix.
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// End-to-end matching over a small set of paths. Results are listed best-first (see
    /// `match_single_path_query`), with the byte offset of each matched query char.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // Separators in the query must line up with separators in the path.
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// todo(windows)
    /// Now, on Windows, users can only use the backslash as a path separator.
    /// I do want to support both the backslash and the forward slash as path separators on Windows.
    ///
    /// Windows counterpart of the test above, with `\` as the path separator in both the
    /// paths and the separator query.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Queries and paths containing U+0130, whose lowercase form is two chars ("i\u{307}").
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Positions are byte offsets, so they step over multi-byte chars (Greek letters and
    /// 7-byte keycap emoji sequences) rather than counting chars.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns each matching path
    /// with its match positions.
    ///
    /// Results are sorted in descending `PathMatch` order (presumably best score first;
    /// confirm against `PathMatch`'s `Ord` impl).
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        // Descending order (`b.cmp(a)`).
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` so the assertions can compare literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}