matcher.rs

  1use std::{
  2    borrow::Cow,
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12pub struct Matcher<'a> {
 13    query: &'a [char],
 14    lowercase_query: &'a [char],
 15    query_char_bag: CharBag,
 16    smart_case: bool,
 17    min_score: f64,
 18    match_positions: Vec<usize>,
 19    last_positions: Vec<usize>,
 20    score_matrix: Vec<Option<f64>>,
 21    best_position_matrix: Vec<usize>,
 22}
 23
 24pub trait MatchCandidate {
 25    fn has_chars(&self, bag: CharBag) -> bool;
 26    fn to_string(&self) -> Cow<'_, str>;
 27}
 28
 29impl<'a> Matcher<'a> {
 30    pub fn new(
 31        query: &'a [char],
 32        lowercase_query: &'a [char],
 33        query_char_bag: CharBag,
 34        smart_case: bool,
 35    ) -> Self {
 36        Self {
 37            query,
 38            lowercase_query,
 39            query_char_bag,
 40            min_score: 0.0,
 41            last_positions: vec![0; lowercase_query.len()],
 42            match_positions: vec![0; query.len()],
 43            score_matrix: Vec::new(),
 44            best_position_matrix: Vec::new(),
 45            smart_case,
 46        }
 47    }
 48
 49    /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
 50    /// the input candidates.
 51    pub fn match_candidates<C: MatchCandidate, R, F>(
 52        &mut self,
 53        prefix: &[char],
 54        lowercase_prefix: &[char],
 55        candidates: impl Iterator<Item = C>,
 56        results: &mut Vec<R>,
 57        cancel_flag: &AtomicBool,
 58        build_match: F,
 59    ) where
 60        F: Fn(&C, f64, &Vec<usize>) -> R,
 61    {
 62        let mut candidate_chars = Vec::new();
 63        let mut lowercase_candidate_chars = Vec::new();
 64
 65        for candidate in candidates {
 66            if !candidate.has_chars(self.query_char_bag) {
 67                continue;
 68            }
 69
 70            if cancel_flag.load(atomic::Ordering::Relaxed) {
 71                break;
 72            }
 73
 74            candidate_chars.clear();
 75            lowercase_candidate_chars.clear();
 76            for c in candidate.to_string().chars() {
 77                candidate_chars.push(c);
 78                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 79            }
 80
 81            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 82                continue;
 83            }
 84
 85            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 86            self.score_matrix.clear();
 87            self.score_matrix.resize(matrix_len, None);
 88            self.best_position_matrix.clear();
 89            self.best_position_matrix.resize(matrix_len, 0);
 90
 91            let score = self.score_match(
 92                &candidate_chars,
 93                &lowercase_candidate_chars,
 94                prefix,
 95                lowercase_prefix,
 96            );
 97
 98            if score > 0.0 {
 99                results.push(build_match(&candidate, score, &self.match_positions));
100            }
101        }
102    }
103
104    fn find_last_positions(
105        &mut self,
106        lowercase_prefix: &[char],
107        lowercase_candidate: &[char],
108    ) -> bool {
109        let mut lowercase_prefix = lowercase_prefix.iter();
110        let mut lowercase_candidate = lowercase_candidate.iter();
111        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
112            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
113                self.last_positions[i] = j + lowercase_prefix.len();
114            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
115                self.last_positions[i] = j;
116            } else {
117                return false;
118            }
119        }
120        true
121    }
122
123    fn score_match(
124        &mut self,
125        path: &[char],
126        path_cased: &[char],
127        prefix: &[char],
128        lowercase_prefix: &[char],
129    ) -> f64 {
130        let score = self.recursive_score_match(
131            path,
132            path_cased,
133            prefix,
134            lowercase_prefix,
135            0,
136            0,
137            self.query.len() as f64,
138        ) * self.query.len() as f64;
139
140        if score <= 0.0 {
141            return 0.0;
142        }
143
144        let path_len = prefix.len() + path.len();
145        let mut cur_start = 0;
146        let mut byte_ix = 0;
147        let mut char_ix = 0;
148        for i in 0..self.query.len() {
149            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
150            while char_ix < match_char_ix {
151                let ch = prefix
152                    .get(char_ix)
153                    .or_else(|| path.get(char_ix - prefix.len()))
154                    .unwrap();
155                byte_ix += ch.len_utf8();
156                char_ix += 1;
157            }
158            cur_start = match_char_ix + 1;
159            self.match_positions[i] = byte_ix;
160        }
161
162        score
163    }
164
165    #[allow(clippy::too_many_arguments)]
166    fn recursive_score_match(
167        &mut self,
168        path: &[char],
169        path_cased: &[char],
170        prefix: &[char],
171        lowercase_prefix: &[char],
172        query_idx: usize,
173        path_idx: usize,
174        cur_score: f64,
175    ) -> f64 {
176        if query_idx == self.query.len() {
177            return 1.0;
178        }
179
180        let path_len = prefix.len() + path.len();
181
182        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
183            return memoized;
184        }
185
186        let mut score = 0.0;
187        let mut best_position = 0;
188
189        let query_char = self.lowercase_query[query_idx];
190        let limit = self.last_positions[query_idx];
191
192        let mut last_slash = 0;
193        for j in path_idx..=limit {
194            let path_char = if j < prefix.len() {
195                lowercase_prefix[j]
196            } else {
197                path_cased[j - prefix.len()]
198            };
199            let is_path_sep = path_char == '/' || path_char == '\\';
200
201            if query_idx == 0 && is_path_sep {
202                last_slash = j;
203            }
204
205            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
206                let curr = if j < prefix.len() {
207                    prefix[j]
208                } else {
209                    path[j - prefix.len()]
210                };
211
212                let mut char_score = 1.0;
213                if j > path_idx {
214                    let last = if j - 1 < prefix.len() {
215                        prefix[j - 1]
216                    } else {
217                        path[j - 1 - prefix.len()]
218                    };
219
220                    if last == '/' {
221                        char_score = 0.9;
222                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
223                        || (last.is_lowercase() && curr.is_uppercase())
224                    {
225                        char_score = 0.8;
226                    } else if last == '.' {
227                        char_score = 0.7;
228                    } else if query_idx == 0 {
229                        char_score = BASE_DISTANCE_PENALTY;
230                    } else {
231                        char_score = MIN_DISTANCE_PENALTY.max(
232                            BASE_DISTANCE_PENALTY
233                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
234                        );
235                    }
236                }
237
238                // Apply a severe penalty if the case doesn't match.
239                // This will make the exact matches have higher score than the case-insensitive and the
240                // path insensitive matches.
241                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
242                    char_score *= 0.001;
243                }
244
245                let mut multiplier = char_score;
246
247                // Scale the score based on how deep within the path we found the match.
248                if query_idx == 0 {
249                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
250                }
251
252                let mut next_score = 1.0;
253                if self.min_score > 0.0 {
254                    next_score = cur_score * multiplier;
255                    // Scores only decrease. If we can't pass the previous best, bail
256                    if next_score < self.min_score {
257                        // Ensure that score is non-zero so we use it in the memo table.
258                        if score == 0.0 {
259                            score = 1e-18;
260                        }
261                        continue;
262                    }
263                }
264
265                let new_score = self.recursive_score_match(
266                    path,
267                    path_cased,
268                    prefix,
269                    lowercase_prefix,
270                    query_idx + 1,
271                    j + 1,
272                    next_score,
273                ) * multiplier;
274
275                if new_score > score {
276                    score = new_score;
277                    best_position = j;
278                    // Optimization: can't score better than 1.
279                    if new_score == 1.0 {
280                        break;
281                    }
282                }
283            }
284        }
285
286        if best_position != 0 {
287            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
288        }
289
290        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
291        score
292    }
293}
294
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// `find_last_positions` must respect query order and offset candidate positions by the
    /// prefix length.
    #[test]
    fn test_get_last_positions() {
        let prefix = ['a', 'b', 'c'];
        let candidate = ['b', 'd', 'e', 'f'];

        // 'c' is only found in the prefix, which leaves nowhere earlier for 'd' to go.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(!matcher.find_last_positions(&prefix, &candidate));

        // Same characters in the opposite order can be placed.
        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(matcher.find_last_positions(&prefix, &candidate));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // A query spanning both the prefix and the candidate.
        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let found =
            matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(found);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Ranking and match positions for a set of ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let abc_matches = match_single_path_query("abc", false, &paths);
        assert_eq!(
            abc_matches,
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );

        let separated_matches = match_single_path_query("t/i/a/t/d", false, &paths);
        assert_eq!(
            separated_matches,
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
        );

        let initials_matches = match_single_path_query("tiatd", false, &paths);
        assert_eq!(
            initials_matches,
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// U+0130 lowercases to two characters, so the lowercase query is longer than the query.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        let query = "\u{0130}";

        // The path is the query itself.
        let same_case = ["\u{0130}"];
        assert_eq!(
            match_single_path_query(query, false, &same_case),
            vec![("\u{0130}", vec![0])]
        );

        // The path is the lower-case version of the query.
        let lowered = ["i\u{307}"];
        assert_eq!(
            match_single_path_query(query, false, &lowered),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Match positions are byte offsets, so multi-byte characters shift them.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = [
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Sanity check on the fixture: each keycap sequence occupies seven bytes.
        assert_eq!("1️⃣".len(), 7);

        let bcd_matches = match_single_path_query("bcd", false, &paths);
        assert_eq!(
            bcd_matches,
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );

        let cde_matches = match_single_path_query("cde", false, &paths);
        assert_eq!(
            cde_matches,
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns the matching paths, best
    /// match first, paired with their match positions.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        // The candidates borrow their paths, so the `Arc`s have to outlive them.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries: Vec<_> = paths
            .iter()
            .enumerate()
            .map(|(index, path)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: &path_arcs[index],
                }
            })
            .collect();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);
        let never_cancelled = AtomicBool::new(false);
        let mut matches = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut matches,
            &never_cancelled,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );

        // Best match first.
        matches.sort_by(|a, b| b.cmp(a));

        // Map every match back to the caller's original string slice.
        let mut ranked = Vec::new();
        for found in matches {
            let original = paths
                .iter()
                .copied()
                .find(|p| found.path.as_ref() == Path::new(p))
                .unwrap();
            ranked.push((original, found.positions));
        }
        ranked
    }
}