matcher.rs

  1use std::{
  2    borrow::{Borrow, Cow},
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12// TODO:
 13// Use `Path` instead of `&str` for paths.
 14pub struct Matcher<'a> {
 15    query: &'a [char],
 16    lowercase_query: &'a [char],
 17    query_char_bag: CharBag,
 18    smart_case: bool,
 19    min_score: f64,
 20    match_positions: Vec<usize>,
 21    last_positions: Vec<usize>,
 22    score_matrix: Vec<Option<f64>>,
 23    best_position_matrix: Vec<usize>,
 24}
 25
 26pub trait MatchCandidate {
 27    fn has_chars(&self, bag: CharBag) -> bool;
 28    fn to_string(&self) -> Cow<'_, str>;
 29}
 30
 31impl<'a> Matcher<'a> {
 32    pub fn new(
 33        query: &'a [char],
 34        lowercase_query: &'a [char],
 35        query_char_bag: CharBag,
 36        smart_case: bool,
 37    ) -> Self {
 38        Self {
 39            query,
 40            lowercase_query,
 41            query_char_bag,
 42            min_score: 0.0,
 43            last_positions: vec![0; lowercase_query.len()],
 44            match_positions: vec![0; query.len()],
 45            score_matrix: Vec::new(),
 46            best_position_matrix: Vec::new(),
 47            smart_case,
 48        }
 49    }
 50
 51    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 52    /// the input candidates.
 53    pub fn match_candidates<C, R, F, T>(
 54        &mut self,
 55        prefix: &[char],
 56        lowercase_prefix: &[char],
 57        candidates: impl Iterator<Item = T>,
 58        results: &mut Vec<R>,
 59        cancel_flag: &AtomicBool,
 60        build_match: F,
 61    ) where
 62        C: MatchCandidate,
 63        T: Borrow<C>,
 64        F: Fn(&C, f64, &Vec<usize>) -> R,
 65    {
 66        let mut candidate_chars = Vec::new();
 67        let mut lowercase_candidate_chars = Vec::new();
 68
 69        for candidate in candidates {
 70            if !candidate.borrow().has_chars(self.query_char_bag) {
 71                continue;
 72            }
 73
 74            if cancel_flag.load(atomic::Ordering::Relaxed) {
 75                break;
 76            }
 77
 78            candidate_chars.clear();
 79            lowercase_candidate_chars.clear();
 80            for c in candidate.borrow().to_string().chars() {
 81                candidate_chars.push(c);
 82                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 83            }
 84
 85            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 86                continue;
 87            }
 88
 89            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 90            self.score_matrix.clear();
 91            self.score_matrix.resize(matrix_len, None);
 92            self.best_position_matrix.clear();
 93            self.best_position_matrix.resize(matrix_len, 0);
 94
 95            let score = self.score_match(
 96                &candidate_chars,
 97                &lowercase_candidate_chars,
 98                prefix,
 99                lowercase_prefix,
100            );
101
102            if score > 0.0 {
103                results.push(build_match(
104                    candidate.borrow(),
105                    score,
106                    &self.match_positions,
107                ));
108            }
109        }
110    }
111
112    fn find_last_positions(
113        &mut self,
114        lowercase_prefix: &[char],
115        lowercase_candidate: &[char],
116    ) -> bool {
117        let mut lowercase_prefix = lowercase_prefix.iter();
118        let mut lowercase_candidate = lowercase_candidate.iter();
119        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
120            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
121                self.last_positions[i] = j + lowercase_prefix.len();
122            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
123                self.last_positions[i] = j;
124            } else {
125                return false;
126            }
127        }
128        true
129    }
130
131    fn score_match(
132        &mut self,
133        path: &[char],
134        path_cased: &[char],
135        prefix: &[char],
136        lowercase_prefix: &[char],
137    ) -> f64 {
138        let score = self.recursive_score_match(
139            path,
140            path_cased,
141            prefix,
142            lowercase_prefix,
143            0,
144            0,
145            self.query.len() as f64,
146        ) * self.query.len() as f64;
147
148        if score <= 0.0 {
149            return 0.0;
150        }
151
152        let path_len = prefix.len() + path.len();
153        let mut cur_start = 0;
154        let mut byte_ix = 0;
155        let mut char_ix = 0;
156        for i in 0..self.query.len() {
157            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
158            while char_ix < match_char_ix {
159                let ch = prefix
160                    .get(char_ix)
161                    .or_else(|| path.get(char_ix - prefix.len()))
162                    .unwrap();
163                byte_ix += ch.len_utf8();
164                char_ix += 1;
165            }
166            cur_start = match_char_ix + 1;
167            self.match_positions[i] = byte_ix;
168        }
169
170        score
171    }
172
173    #[allow(clippy::too_many_arguments)]
174    fn recursive_score_match(
175        &mut self,
176        path: &[char],
177        path_cased: &[char],
178        prefix: &[char],
179        lowercase_prefix: &[char],
180        query_idx: usize,
181        path_idx: usize,
182        cur_score: f64,
183    ) -> f64 {
184        use std::path::MAIN_SEPARATOR;
185
186        if query_idx == self.query.len() {
187            return 1.0;
188        }
189
190        let path_len = prefix.len() + path.len();
191
192        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
193            return memoized;
194        }
195
196        let mut score = 0.0;
197        let mut best_position = 0;
198
199        let query_char = self.lowercase_query[query_idx];
200        let limit = self.last_positions[query_idx];
201
202        let mut last_slash = 0;
203        for j in path_idx..=limit {
204            let path_char = if j < prefix.len() {
205                lowercase_prefix[j]
206            } else {
207                path_cased[j - prefix.len()]
208            };
209            let is_path_sep = path_char == MAIN_SEPARATOR;
210
211            if query_idx == 0 && is_path_sep {
212                last_slash = j;
213            }
214
215            #[cfg(not(target_os = "windows"))]
216            let need_to_score =
217                query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
218            // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
219            #[cfg(target_os = "windows")]
220            let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
221            if need_to_score {
222                let curr = if j < prefix.len() {
223                    prefix[j]
224                } else {
225                    path[j - prefix.len()]
226                };
227
228                let mut char_score = 1.0;
229                if j > path_idx {
230                    let last = if j - 1 < prefix.len() {
231                        prefix[j - 1]
232                    } else {
233                        path[j - 1 - prefix.len()]
234                    };
235
236                    if last == MAIN_SEPARATOR {
237                        char_score = 0.9;
238                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
239                        || (last.is_lowercase() && curr.is_uppercase())
240                    {
241                        char_score = 0.8;
242                    } else if last == '.' {
243                        char_score = 0.7;
244                    } else if query_idx == 0 {
245                        char_score = BASE_DISTANCE_PENALTY;
246                    } else {
247                        char_score = MIN_DISTANCE_PENALTY.max(
248                            BASE_DISTANCE_PENALTY
249                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
250                        );
251                    }
252                }
253
254                // Apply a severe penalty if the case doesn't match.
255                // This will make the exact matches have higher score than the case-insensitive and the
256                // path insensitive matches.
257                if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
258                    char_score *= 0.001;
259                }
260
261                let mut multiplier = char_score;
262
263                // Scale the score based on how deep within the path we found the match.
264                if query_idx == 0 {
265                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
266                }
267
268                let mut next_score = 1.0;
269                if self.min_score > 0.0 {
270                    next_score = cur_score * multiplier;
271                    // Scores only decrease. If we can't pass the previous best, bail
272                    if next_score < self.min_score {
273                        // Ensure that score is non-zero so we use it in the memo table.
274                        if score == 0.0 {
275                            score = 1e-18;
276                        }
277                        continue;
278                    }
279                }
280
281                let new_score = self.recursive_score_match(
282                    path,
283                    path_cased,
284                    prefix,
285                    lowercase_prefix,
286                    query_idx + 1,
287                    j + 1,
288                    next_score,
289                ) * multiplier;
290
291                if new_score > score {
292                    score = new_score;
293                    best_position = j;
294                    // Optimization: can't score better than 1.
295                    if new_score == 1.0 {
296                        break;
297                    }
298                }
299            }
300        }
301
302        if best_position != 0 {
303            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
304        }
305
306        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
307        score
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PathMatch, PathMatchCandidate};
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn test_get_last_positions() {
        // Runs `find_last_positions` for `query` (also used as its own lowercase form) and
        // returns the recorded positions, or `None` when the search fails.
        fn last_positions(
            query: &[char],
            prefix: &[char],
            candidate: &[char],
        ) -> Option<Vec<usize>> {
            let mut matcher = Matcher::new(query, query, query.into(), false);
            if matcher.find_last_positions(prefix, candidate) {
                Some(matcher.last_positions)
            } else {
                None
            }
        }

        assert_eq!(
            last_positions(&['d', 'c'], &['a', 'b', 'c'], &['b', 'd', 'e', 'f']),
            None
        );
        assert_eq!(
            last_positions(&['c', 'd'], &['a', 'b', 'c'], &['b', 'd', 'e', 'f']),
            Some(vec![2, 4])
        );
        assert_eq!(
            last_positions(
                &['z', '/', 'z', 'f'],
                &['z', 'e', 'd', '/'],
                &['z', 'e', 'd', '/', 'f']
            ),
            Some(vec![0, 3, 4, 8])
        );
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let cases: Vec<(&str, Vec<(&str, Vec<usize>)>)> = vec![
            (
                "abc",
                vec![
                    ("abC", vec![0, 1, 2]),
                    ("abcd", vec![0, 1, 2]),
                    ("AlphaBravoCharlie", vec![0, 5, 10]),
                    ("alphabravocharlie", vec![4, 5, 10]),
                ],
            ),
            (
                "t/i/a/t/d",
                vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])],
            ),
            (
                "tiatd",
                vec![
                    ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                    ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                    ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                    ("thisisatestdir", vec![0, 2, 6, 7, 11]),
                ],
            ),
        ];

        for (query, expected) in cases {
            assert_eq!(
                match_single_path_query(query, false, &paths),
                expected,
                "query {:?}",
                query
            );
        }
    }

    // TODO(windows): on Windows only the backslash is currently accepted as a path separator
    // in queries; both the backslash and the forward slash should be supported.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        let cases: Vec<(&str, Vec<(&str, Vec<usize>)>)> = vec![
            (
                "abc",
                vec![
                    ("abC", vec![0, 1, 2]),
                    ("abcd", vec![0, 1, 2]),
                    ("AlphaBravoCharlie", vec![0, 5, 10]),
                    ("alphabravocharlie", vec![4, 5, 10]),
                ],
            ),
            (
                "t\\i\\a\\t\\d",
                vec![(
                    "\\this\\is\\a\\test\\dir",
                    vec![1, 5, 6, 8, 9, 10, 11, 15, 16],
                )],
            ),
            (
                "tiatd",
                vec![
                    ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                    ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                    ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                    ("thisisatestdir", vec![0, 2, 6, 7, 11]),
                ],
            ),
        ];

        for (query, expected) in cases {
            assert_eq!(
                match_single_path_query(query, false, &paths),
                expected,
                "query {:?}",
                query
            );
        }
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 has more chars in lower-case than in upper-case. It must match both a path
        // equal to itself and a path equal to its lower-case expansion.
        for path in ["\u{0130}", "i\u{307}"] {
            assert_eq!(
                match_single_path_query("\u{0130}", false, &[path]),
                vec![(path, vec![0])]
            );
        }
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Keycap emoji are multi-byte sequences, so byte offsets diverge from char indices.
        assert_eq!("1️⃣".len(), 7);

        let bcd_matches = match_single_path_query("bcd", false, &paths);
        assert_eq!(
            bcd_matches,
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );

        let cde_matches = match_single_path_query("cde", false, &paths);
        assert_eq!(
            cde_matches,
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Matches `query` against `paths` with an empty prefix. Returns each matching path with
    /// its match byte positions, sorted descending by `PathMatch`'s ordering.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, path_arc)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: path_arc,
                }
            })
            .collect::<Vec<_>>();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);
        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        results
            .into_iter()
            .map(|result| {
                let original = paths
                    .iter()
                    .copied()
                    .find(|p| result.path.as_ref() == Path::new(p))
                    .unwrap();
                (original, result.positions)
            })
            .collect()
    }
}