// matcher.rs

  1use std::{
  2    borrow::{Borrow, Cow},
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12// TODO:
 13// Use `Path` instead of `&str` for paths.
 14pub struct Matcher<'a> {
 15    query: &'a [char],
 16    lowercase_query: &'a [char],
 17    query_char_bag: CharBag,
 18    smart_case: bool,
 19    min_score: f64,
 20    match_positions: Vec<usize>,
 21    last_positions: Vec<usize>,
 22    score_matrix: Vec<Option<f64>>,
 23    best_position_matrix: Vec<usize>,
 24}
 25
 26pub trait MatchCandidate {
 27    fn has_chars(&self, bag: CharBag) -> bool;
 28    fn to_string(&self) -> Cow<'_, str>;
 29}
 30
 31impl<'a> Matcher<'a> {
 32    pub fn new(
 33        query: &'a [char],
 34        lowercase_query: &'a [char],
 35        query_char_bag: CharBag,
 36        smart_case: bool,
 37    ) -> Self {
 38        Self {
 39            query,
 40            lowercase_query,
 41            query_char_bag,
 42            min_score: 0.0,
 43            last_positions: vec![0; lowercase_query.len()],
 44            match_positions: vec![0; query.len()],
 45            score_matrix: Vec::new(),
 46            best_position_matrix: Vec::new(),
 47            smart_case,
 48        }
 49    }
 50
 51    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 52    /// the input candidates.
 53    pub fn match_candidates<C, R, F, T>(
 54        &mut self,
 55        prefix: &[char],
 56        lowercase_prefix: &[char],
 57        candidates: impl Iterator<Item = T>,
 58        results: &mut Vec<R>,
 59        cancel_flag: &AtomicBool,
 60        build_match: F,
 61    ) where
 62        C: MatchCandidate,
 63        T: Borrow<C>,
 64        F: Fn(&C, f64, &Vec<usize>) -> R,
 65    {
 66        let mut candidate_chars = Vec::new();
 67        let mut lowercase_candidate_chars = Vec::new();
 68
 69        for candidate in candidates {
 70            if !candidate.borrow().has_chars(self.query_char_bag) {
 71                continue;
 72            }
 73
 74            if cancel_flag.load(atomic::Ordering::Relaxed) {
 75                break;
 76            }
 77
 78            candidate_chars.clear();
 79            lowercase_candidate_chars.clear();
 80            for c in candidate.borrow().to_string().chars() {
 81                candidate_chars.push(c);
 82                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 83            }
 84
 85            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 86                continue;
 87            }
 88
 89            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 90            self.score_matrix.clear();
 91            self.score_matrix.resize(matrix_len, None);
 92            self.best_position_matrix.clear();
 93            self.best_position_matrix.resize(matrix_len, 0);
 94
 95            let score = self.score_match(
 96                &candidate_chars,
 97                &lowercase_candidate_chars,
 98                prefix,
 99                lowercase_prefix,
100            );
101
102            if score > 0.0 {
103                results.push(build_match(
104                    candidate.borrow(),
105                    score,
106                    &self.match_positions,
107                ));
108            }
109        }
110    }
111
112    fn find_last_positions(
113        &mut self,
114        lowercase_prefix: &[char],
115        lowercase_candidate: &[char],
116    ) -> bool {
117        let mut lowercase_prefix = lowercase_prefix.iter();
118        let mut lowercase_candidate = lowercase_candidate.iter();
119        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
120            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
121                self.last_positions[i] = j + lowercase_prefix.len();
122            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
123                self.last_positions[i] = j;
124            } else {
125                return false;
126            }
127        }
128        true
129    }
130
131    fn score_match(
132        &mut self,
133        path: &[char],
134        path_cased: &[char],
135        prefix: &[char],
136        lowercase_prefix: &[char],
137    ) -> f64 {
138        let score = self.recursive_score_match(
139            path,
140            path_cased,
141            prefix,
142            lowercase_prefix,
143            0,
144            0,
145            self.query.len() as f64,
146        ) * self.query.len() as f64;
147
148        if score <= 0.0 {
149            return 0.0;
150        }
151
152        let path_len = prefix.len() + path.len();
153        let mut cur_start = 0;
154        let mut byte_ix = 0;
155        let mut char_ix = 0;
156        for i in 0..self.query.len() {
157            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
158            while char_ix < match_char_ix {
159                let ch = prefix
160                    .get(char_ix)
161                    .or_else(|| path.get(char_ix - prefix.len()))
162                    .unwrap();
163                byte_ix += ch.len_utf8();
164                char_ix += 1;
165            }
166            cur_start = match_char_ix + 1;
167            self.match_positions[i] = byte_ix;
168        }
169
170        score
171    }
172
173    fn recursive_score_match(
174        &mut self,
175        path: &[char],
176        path_cased: &[char],
177        prefix: &[char],
178        lowercase_prefix: &[char],
179        query_idx: usize,
180        path_idx: usize,
181        cur_score: f64,
182    ) -> f64 {
183        use std::path::MAIN_SEPARATOR;
184
185        if query_idx == self.query.len() {
186            return 1.0;
187        }
188
189        let path_len = prefix.len() + path.len();
190
191        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
192            return memoized;
193        }
194
195        let mut score = 0.0;
196        let mut best_position = 0;
197
198        let query_char = self.lowercase_query[query_idx];
199        let limit = self.last_positions[query_idx];
200
201        let mut last_slash = 0;
202        for j in path_idx..=limit {
203            let path_char = if j < prefix.len() {
204                lowercase_prefix[j]
205            } else {
206                path_cased[j - prefix.len()]
207            };
208            let is_path_sep = path_char == MAIN_SEPARATOR;
209
210            if query_idx == 0 && is_path_sep {
211                last_slash = j;
212            }
213
214            #[cfg(not(target_os = "windows"))]
215            let need_to_score =
216                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
217            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
218            #[cfg(target_os = "windows")]
219            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
220            if need_to_score {
221                let curr = if j < prefix.len() {
222                    prefix[j]
223                } else {
224                    path[j - prefix.len()]
225                };
226
227                let mut char_score = 1.0;
228                if j > path_idx {
229                    let last = if j - 1 < prefix.len() {
230                        prefix[j - 1]
231                    } else {
232                        path[j - 1 - prefix.len()]
233                    };
234
235                    if last == MAIN_SEPARATOR {
236                        char_score = 0.9;
237                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
238                        || (last.is_lowercase() && curr.is_uppercase())
239                    {
240                        char_score = 0.8;
241                    } else if last == '.' {
242                        char_score = 0.7;
243                    } else if query_idx == 0 {
244                        char_score = BASE_DISTANCE_PENALTY;
245                    } else {
246                        char_score = MIN_DISTANCE_PENALTY.max(
247                            BASE_DISTANCE_PENALTY
248                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
249                        );
250                    }
251                }
252
253                // Apply a severe penalty if the case doesn't match.
254                // This will make the exact matches have higher score than the case-insensitive and the
255                // path insensitive matches.
256                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
257                    char_score *= 0.001;
258                }
259
260                let mut multiplier = char_score;
261
262                // Scale the score based on how deep within the path we found the match.
263                if query_idx == 0 {
264                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
265                }
266
267                let mut next_score = 1.0;
268                if self.min_score > 0.0 {
269                    next_score = cur_score * multiplier;
270                    // Scores only decrease. If we can't pass the previous best, bail
271                    if next_score < self.min_score {
272                        // Ensure that score is non-zero so we use it in the memo table.
273                        if score == 0.0 {
274                            score = 1e-18;
275                        }
276                        continue;
277                    }
278                }
279
280                let new_score = self.recursive_score_match(
281                    path,
282                    path_cased,
283                    prefix,
284                    lowercase_prefix,
285                    query_idx + 1,
286                    j + 1,
287                    next_score,
288                ) * multiplier;
289
290                if new_score > score {
291                    score = new_score;
292                    best_position = j;
293                    // Optimization: can't score better than 1.
294                    if new_score == 1.0 {
295                        break;
296                    }
297                }
298            }
299        }
300
301        if best_position != 0 {
302            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
303        }
304
305        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
306        score
307    }
308}
309
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn test_get_last_positions() {
        let prefix = ['a', 'b', 'c'];
        let candidate = ['b', 'd', 'e', 'f'];

        // 'c' only occurs in the prefix, which leaves nowhere for the preceding 'd' to go.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(!matcher.find_last_positions(&prefix, &candidate));

        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(matcher.find_last_positions(&prefix, &candidate));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Positions past the prefix are offset by the prefix length.
        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']));
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let expected_abc = vec![
            ("abC", vec![0, 1, 2]),
            ("abcd", vec![0, 1, 2]),
            ("AlphaBravoCharlie", vec![0, 5, 10]),
            ("alphabravocharlie", vec![4, 5, 10]),
        ];
        assert_eq!(match_single_path_query("abc", false, &paths), expected_abc);

        let expected_separated = vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])];
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            expected_separated
        );

        let expected_tiatd = vec![
            ("/test/tiatd", vec![6, 7, 8, 9, 10]),
            ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
            ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
            ("thisisatestdir", vec![0, 2, 6, 7, 11]),
        ];
        assert_eq!(match_single_path_query("tiatd", false, &paths), expected_tiatd);
    }

    // todo(windows)
    // Now, on Windows, users can only use the backslash as a path separator.
    // I do want to support both the backslash and the forward slash as path separators on Windows.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        let expected_abc = vec![
            ("abC", vec![0, 1, 2]),
            ("abcd", vec![0, 1, 2]),
            ("AlphaBravoCharlie", vec![0, 5, 10]),
            ("alphabravocharlie", vec![4, 5, 10]),
        ];
        assert_eq!(match_single_path_query("abc", false, &paths), expected_abc);

        let expected_separated = vec![(
            "\\this\\is\\a\\test\\dir",
            vec![1, 5, 6, 8, 9, 10, 11, 15, 16],
        )];
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            expected_separated
        );

        let expected_tiatd = vec![
            ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
            ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
            ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
            ("thisisatestdir", vec![0, 2, 6, 7, 11]),
        ];
        assert_eq!(match_single_path_query("tiatd", false, &paths), expected_tiatd);
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 has more chars in lower-case ("i\u{307}") than in upper-case.
        let query = "\u{0130}";

        let matches = match_single_path_query(query, false, &["\u{0130}"]);
        assert_eq!(matches, vec![("\u{0130}", vec![0])]);

        // The path is the lower-case version of the query.
        let matches = match_single_path_query(query, false, &["i\u{307}"]);
        assert_eq!(matches, vec![("i\u{307}", vec![0])]);
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        // Each keycap is a digit followed by U+FE0F and U+20E3: 1 + 3 + 3 = 7 bytes.
        let keycap = "1\u{fe0f}\u{20e3}";
        assert_eq!(keycap.len(), 7);

        let keycap_path = "c1\u{fe0f}\u{20e3}2\u{fe0f}\u{20e3}3\u{fe0f}\u{20e3}\
            /d4\u{fe0f}\u{20e3}5\u{fe0f}\u{20e3}6\u{fe0f}\u{20e3}\
            /e7\u{fe0f}\u{20e3}8\u{fe0f}\u{20e3}9\u{fe0f}\u{20e3}/f";
        let paths = ["aαbβ/cγdδ", "αβγδ/bcde", keycap_path, "/d/\u{1f192}/h"];

        let expected_bcd = vec![
            ("αβγδ/bcde", vec![9, 10, 11]),
            ("aαbβ/cγdδ", vec![3, 7, 10]),
        ];
        assert_eq!(match_single_path_query("bcd", false, &paths), expected_bcd);

        let expected_cde = vec![
            ("αβγδ/bcde", vec![10, 11, 12]),
            (keycap_path, vec![0, 23, 46]),
        ];
        assert_eq!(match_single_path_query("cde", false, &paths), expected_cde);
    }

    /// Runs `query` against `paths` and returns every matching path with its match byte
    /// positions, best match first.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, path_arc)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: path_arc,
                }
            })
            .collect::<Vec<_>>();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);
        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        let mut matches = Vec::with_capacity(results.len());
        for result in results {
            let path = paths
                .iter()
                .copied()
                .find(|p| result.path.as_ref() == Path::new(p))
                .unwrap();
            matches.push((path, result.positions));
        }
        matches
    }
}