matcher.rs

  1use std::{
  2    borrow::Cow,
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
/// Score multiplier for a matched character that does not directly follow a word boundary
/// (path separator, `-`, `_`, space, digit, `.`, or a lower-to-upper case change). It is used
/// as-is for the first query character and as the starting point for later ones.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for each additional character skipped
/// between the search start and the matched character (second query character onwards).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance-based multiplier, no matter how many characters were skipped.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
// TODO:
// Use `Path` instead of `&str` for paths.
/// Fuzzy matcher that scores candidates against a single, fixed query.
///
/// The vectors are scratch buffers that are overwritten for every candidate processed by
/// [`Matcher::match_candidates`].
pub struct Matcher<'a> {
    /// Query characters in their original case; compared with candidate characters when
    /// applying the case-mismatch penalty.
    query: &'a [char],
    /// Lowercased query characters; these are the ones compared with the lowercased candidate.
    lowercase_query: &'a [char],
    /// Character bag of the query, handed to [`MatchCandidate::has_chars`] as a pre-filter.
    query_char_bag: CharBag,
    /// When `true`, a matched character whose case differs from the query is heavily penalized.
    smart_case: bool,
    /// When greater than zero, branches whose running score drops below it are pruned.
    /// Initialized to `0.0` by [`Matcher::new`].
    min_score: f64,
    /// Byte offset of each matched query character within prefix + candidate, for the most
    /// recently scored candidate. One entry per query character.
    match_positions: Vec<usize>,
    /// For each lowercase query character, the rightmost index (within lowercase prefix +
    /// candidate) at which it can still be matched.
    last_positions: Vec<usize>,
    /// Memoized scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position for each memo cell, indexed like `score_matrix`.
    best_position_matrix: Vec<usize>,
}
 25
/// An item that can be fuzzy-matched by [`Matcher`].
pub trait MatchCandidate {
    /// Cheap pre-filter: candidates for which this returns `false` are skipped without being
    /// scored. Presumably reports whether the candidate contains every character in `bag` —
    /// confirm against the implementors.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text whose characters are matched against the query.
    fn to_string(&self) -> Cow<'_, str>;
}
 30
 31impl<'a> Matcher<'a> {
 32    pub fn new(
 33        query: &'a [char],
 34        lowercase_query: &'a [char],
 35        query_char_bag: CharBag,
 36        smart_case: bool,
 37    ) -> Self {
 38        Self {
 39            query,
 40            lowercase_query,
 41            query_char_bag,
 42            min_score: 0.0,
 43            last_positions: vec![0; lowercase_query.len()],
 44            match_positions: vec![0; query.len()],
 45            score_matrix: Vec::new(),
 46            best_position_matrix: Vec::new(),
 47            smart_case,
 48        }
 49    }
 50
 51    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 52    /// the input candidates.
 53    pub fn match_candidates<C: MatchCandidate, R, F>(
 54        &mut self,
 55        prefix: &[char],
 56        lowercase_prefix: &[char],
 57        candidates: impl Iterator<Item = C>,
 58        results: &mut Vec<R>,
 59        cancel_flag: &AtomicBool,
 60        build_match: F,
 61    ) where
 62        F: Fn(&C, f64, &Vec<usize>) -> R,
 63    {
 64        let mut candidate_chars = Vec::new();
 65        let mut lowercase_candidate_chars = Vec::new();
 66
 67        for candidate in candidates {
 68            if !candidate.has_chars(self.query_char_bag) {
 69                continue;
 70            }
 71
 72            if cancel_flag.load(atomic::Ordering::Relaxed) {
 73                break;
 74            }
 75
 76            candidate_chars.clear();
 77            lowercase_candidate_chars.clear();
 78            for c in candidate.to_string().chars() {
 79                candidate_chars.push(c);
 80                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 81            }
 82
 83            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 84                continue;
 85            }
 86
 87            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 88            self.score_matrix.clear();
 89            self.score_matrix.resize(matrix_len, None);
 90            self.best_position_matrix.clear();
 91            self.best_position_matrix.resize(matrix_len, 0);
 92
 93            let score = self.score_match(
 94                &candidate_chars,
 95                &lowercase_candidate_chars,
 96                prefix,
 97                lowercase_prefix,
 98            );
 99
100            if score > 0.0 {
101                results.push(build_match(&candidate, score, &self.match_positions));
102            }
103        }
104    }
105
106    fn find_last_positions(
107        &mut self,
108        lowercase_prefix: &[char],
109        lowercase_candidate: &[char],
110    ) -> bool {
111        let mut lowercase_prefix = lowercase_prefix.iter();
112        let mut lowercase_candidate = lowercase_candidate.iter();
113        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
114            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
115                self.last_positions[i] = j + lowercase_prefix.len();
116            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
117                self.last_positions[i] = j;
118            } else {
119                return false;
120            }
121        }
122        true
123    }
124
125    fn score_match(
126        &mut self,
127        path: &[char],
128        path_cased: &[char],
129        prefix: &[char],
130        lowercase_prefix: &[char],
131    ) -> f64 {
132        let score = self.recursive_score_match(
133            path,
134            path_cased,
135            prefix,
136            lowercase_prefix,
137            0,
138            0,
139            self.query.len() as f64,
140        ) * self.query.len() as f64;
141
142        if score <= 0.0 {
143            return 0.0;
144        }
145
146        let path_len = prefix.len() + path.len();
147        let mut cur_start = 0;
148        let mut byte_ix = 0;
149        let mut char_ix = 0;
150        for i in 0..self.query.len() {
151            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
152            while char_ix < match_char_ix {
153                let ch = prefix
154                    .get(char_ix)
155                    .or_else(|| path.get(char_ix - prefix.len()))
156                    .unwrap();
157                byte_ix += ch.len_utf8();
158                char_ix += 1;
159            }
160            cur_start = match_char_ix + 1;
161            self.match_positions[i] = byte_ix;
162        }
163
164        score
165    }
166
167    #[allow(clippy::too_many_arguments)]
168    fn recursive_score_match(
169        &mut self,
170        path: &[char],
171        path_cased: &[char],
172        prefix: &[char],
173        lowercase_prefix: &[char],
174        query_idx: usize,
175        path_idx: usize,
176        cur_score: f64,
177    ) -> f64 {
178        use std::path::MAIN_SEPARATOR;
179
180        if query_idx == self.query.len() {
181            return 1.0;
182        }
183
184        let path_len = prefix.len() + path.len();
185
186        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
187            return memoized;
188        }
189
190        let mut score = 0.0;
191        let mut best_position = 0;
192
193        let query_char = self.lowercase_query[query_idx];
194        let limit = self.last_positions[query_idx];
195
196        let mut last_slash = 0;
197        for j in path_idx..=limit {
198            let path_char = if j < prefix.len() {
199                lowercase_prefix[j]
200            } else {
201                path_cased[j - prefix.len()]
202            };
203            let is_path_sep = path_char == MAIN_SEPARATOR;
204
205            if query_idx == 0 && is_path_sep {
206                last_slash = j;
207            }
208
209            #[cfg(not(target_os = "windows"))]
210            let need_to_score =
211                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
212            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
213            #[cfg(target_os = "windows")]
214            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
215            if need_to_score {
216                let curr = if j < prefix.len() {
217                    prefix[j]
218                } else {
219                    path[j - prefix.len()]
220                };
221
222                let mut char_score = 1.0;
223                if j > path_idx {
224                    let last = if j - 1 < prefix.len() {
225                        prefix[j - 1]
226                    } else {
227                        path[j - 1 - prefix.len()]
228                    };
229
230                    if last == MAIN_SEPARATOR {
231                        char_score = 0.9;
232                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
233                        || (last.is_lowercase() && curr.is_uppercase())
234                    {
235                        char_score = 0.8;
236                    } else if last == '.' {
237                        char_score = 0.7;
238                    } else if query_idx == 0 {
239                        char_score = BASE_DISTANCE_PENALTY;
240                    } else {
241                        char_score = MIN_DISTANCE_PENALTY.max(
242                            BASE_DISTANCE_PENALTY
243                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
244                        );
245                    }
246                }
247
248                // Apply a severe penalty if the case doesn't match.
249                // This will make the exact matches have higher score than the case-insensitive and the
250                // path insensitive matches.
251                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
252                    char_score *= 0.001;
253                }
254
255                let mut multiplier = char_score;
256
257                // Scale the score based on how deep within the path we found the match.
258                if query_idx == 0 {
259                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
260                }
261
262                let mut next_score = 1.0;
263                if self.min_score > 0.0 {
264                    next_score = cur_score * multiplier;
265                    // Scores only decrease. If we can't pass the previous best, bail
266                    if next_score < self.min_score {
267                        // Ensure that score is non-zero so we use it in the memo table.
268                        if score == 0.0 {
269                            score = 1e-18;
270                        }
271                        continue;
272                    }
273                }
274
275                let new_score = self.recursive_score_match(
276                    path,
277                    path_cased,
278                    prefix,
279                    lowercase_prefix,
280                    query_idx + 1,
281                    j + 1,
282                    next_score,
283                ) * multiplier;
284
285                if new_score > score {
286                    score = new_score;
287                    best_position = j;
288                    // Optimization: can't score better than 1.
289                    if new_score == 1.0 {
290                        break;
291                    }
292                }
293            }
294        }
295
296        if best_position != 0 {
297            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
298        }
299
300        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
301        score
302    }
303}
304
305#[cfg(test)]
306mod tests {
307    use crate::{PathMatch, PathMatchCandidate};
308
309    use super::*;
310    use std::{
311        path::{Path, PathBuf},
312        sync::Arc,
313    };
314
315    #[test]
316    fn test_get_last_positions() {
317        let mut query: &[char] = &['d', 'c'];
318        let mut matcher = Matcher::new(query, query, query.into(), false);
319        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
320        assert!(!result);
321
322        query = &['c', 'd'];
323        let mut matcher = Matcher::new(query, query, query.into(), false);
324        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
325        assert!(result);
326        assert_eq!(matcher.last_positions, vec![2, 4]);
327
328        query = &['z', '/', 'z', 'f'];
329        let mut matcher = Matcher::new(query, query, query.into(), false);
330        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
331        assert!(result);
332        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
333    }
334
335    #[cfg(not(target_os = "windows"))]
336    #[test]
337    fn test_match_path_entries() {
338        let paths = vec![
339            "",
340            "a",
341            "ab",
342            "abC",
343            "abcd",
344            "alphabravocharlie",
345            "AlphaBravoCharlie",
346            "thisisatestdir",
347            "/////ThisIsATestDir",
348            "/this/is/a/test/dir",
349            "/test/tiatd",
350        ];
351
352        assert_eq!(
353            match_single_path_query("abc", false, &paths),
354            vec![
355                ("abC", vec![0, 1, 2]),
356                ("abcd", vec![0, 1, 2]),
357                ("AlphaBravoCharlie", vec![0, 5, 10]),
358                ("alphabravocharlie", vec![4, 5, 10]),
359            ]
360        );
361        assert_eq!(
362            match_single_path_query("t/i/a/t/d", false, &paths),
363            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
364        );
365
366        assert_eq!(
367            match_single_path_query("tiatd", false, &paths),
368            vec![
369                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
370                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
371                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
372                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
373            ]
374        );
375    }
376
377    /// todo(windows)
378    /// Now, on Windows, users can only use the backslash as a path separator.
379    /// I do want to support both the backslash and the forward slash as path separators on Windows.
380    #[cfg(target_os = "windows")]
381    #[test]
382    fn test_match_path_entries() {
383        let paths = vec![
384            "",
385            "a",
386            "ab",
387            "abC",
388            "abcd",
389            "alphabravocharlie",
390            "AlphaBravoCharlie",
391            "thisisatestdir",
392            "\\\\\\\\\\ThisIsATestDir",
393            "\\this\\is\\a\\test\\dir",
394            "\\test\\tiatd",
395        ];
396
397        assert_eq!(
398            match_single_path_query("abc", false, &paths),
399            vec![
400                ("abC", vec![0, 1, 2]),
401                ("abcd", vec![0, 1, 2]),
402                ("AlphaBravoCharlie", vec![0, 5, 10]),
403                ("alphabravocharlie", vec![4, 5, 10]),
404            ]
405        );
406        assert_eq!(
407            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
408            vec![(
409                "\\this\\is\\a\\test\\dir",
410                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
411            ),]
412        );
413
414        assert_eq!(
415            match_single_path_query("tiatd", false, &paths),
416            vec![
417                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
418                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
419                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
420                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
421            ]
422        );
423    }
424
425    #[test]
426    fn test_lowercase_longer_than_uppercase() {
427        // This character has more chars in lower-case than in upper-case.
428        let paths = vec!["\u{0130}"];
429        let query = "\u{0130}";
430        assert_eq!(
431            match_single_path_query(query, false, &paths),
432            vec![("\u{0130}", vec![0])]
433        );
434
435        // Path is the lower-case version of the query
436        let paths = vec!["i\u{307}"];
437        let query = "\u{0130}";
438        assert_eq!(
439            match_single_path_query(query, false, &paths),
440            vec![("i\u{307}", vec![0])]
441        );
442    }
443
    /// Checks that matched positions are reported as byte offsets, not char indices, for paths
    /// that contain multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Each keycap sequence is 7 bytes long; the expected offsets below rely on that.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }
468
469    fn match_single_path_query<'a>(
470        query: &str,
471        smart_case: bool,
472        paths: &[&'a str],
473    ) -> Vec<(&'a str, Vec<usize>)> {
474        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
475        let query = query.chars().collect::<Vec<_>>();
476        let query_chars = CharBag::from(&lowercase_query[..]);
477
478        let path_arcs: Vec<Arc<Path>> = paths
479            .iter()
480            .map(|path| Arc::from(PathBuf::from(path)))
481            .collect::<Vec<_>>();
482        let mut path_entries = Vec::new();
483        for (i, path) in paths.iter().enumerate() {
484            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
485            let char_bag = CharBag::from(lowercase_path.as_slice());
486            path_entries.push(PathMatchCandidate {
487                is_dir: false,
488                char_bag,
489                path: &path_arcs[i],
490            });
491        }
492
493        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);
494
495        let cancel_flag = AtomicBool::new(false);
496        let mut results = Vec::new();
497
498        matcher.match_candidates(
499            &[],
500            &[],
501            path_entries.into_iter(),
502            &mut results,
503            &cancel_flag,
504            |candidate, score, positions| PathMatch {
505                score,
506                worktree_id: 0,
507                positions: positions.clone(),
508                path: Arc::from(candidate.path),
509                path_prefix: "".into(),
510                distance_to_relative_ancestor: usize::MAX,
511                is_dir: false,
512            },
513        );
514        results.sort_by(|a, b| b.cmp(a));
515
516        results
517            .into_iter()
518            .map(|result| {
519                (
520                    paths
521                        .iter()
522                        .copied()
523                        .find(|p| result.path.as_ref() == Path::new(p))
524                        .unwrap(),
525                    result.positions,
526                )
527            })
528            .collect()
529    }
530}