matcher.rs

  1use std::{
  2    borrow::Cow,
  3    sync::atomic::{self, AtomicBool},
  4};
  5
  6use crate::CharBag;
  7
  8const BASE_DISTANCE_PENALTY: f64 = 0.6;
  9const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 10const MIN_DISTANCE_PENALTY: f64 = 0.2;
 11
 12pub struct Matcher<'a> {
 13    query: &'a [char],
 14    lowercase_query: &'a [char],
 15    query_char_bag: CharBag,
 16    smart_case: bool,
 17    max_results: usize,
 18    min_score: f64,
 19    match_positions: Vec<usize>,
 20    last_positions: Vec<usize>,
 21    score_matrix: Vec<Option<f64>>,
 22    best_position_matrix: Vec<usize>,
 23}
 24
 25pub trait Match: Ord {
 26    fn score(&self) -> f64;
 27    fn set_positions(&mut self, positions: Vec<usize>);
 28}
 29
 30pub trait MatchCandidate {
 31    fn has_chars(&self, bag: CharBag) -> bool;
 32    fn to_string(&self) -> Cow<'_, str>;
 33}
 34
 35impl<'a> Matcher<'a> {
 36    pub fn new(
 37        query: &'a [char],
 38        lowercase_query: &'a [char],
 39        query_char_bag: CharBag,
 40        smart_case: bool,
 41        max_results: usize,
 42    ) -> Self {
 43        Self {
 44            query,
 45            lowercase_query,
 46            query_char_bag,
 47            min_score: 0.0,
 48            last_positions: vec![0; lowercase_query.len()],
 49            match_positions: vec![0; query.len()],
 50            score_matrix: Vec::new(),
 51            best_position_matrix: Vec::new(),
 52            smart_case,
 53            max_results,
 54        }
 55    }
 56
 57    pub fn match_candidates<C: MatchCandidate, R, F>(
 58        &mut self,
 59        prefix: &[char],
 60        lowercase_prefix: &[char],
 61        candidates: impl Iterator<Item = C>,
 62        results: &mut Vec<R>,
 63        cancel_flag: &AtomicBool,
 64        build_match: F,
 65    ) where
 66        R: Match,
 67        F: Fn(&C, f64) -> R,
 68    {
 69        let mut candidate_chars = Vec::new();
 70        let mut lowercase_candidate_chars = Vec::new();
 71
 72        for candidate in candidates {
 73            if !candidate.has_chars(self.query_char_bag) {
 74                continue;
 75            }
 76
 77            if cancel_flag.load(atomic::Ordering::Relaxed) {
 78                break;
 79            }
 80
 81            candidate_chars.clear();
 82            lowercase_candidate_chars.clear();
 83            for c in candidate.to_string().chars() {
 84                candidate_chars.push(c);
 85                lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
 86            }
 87
 88            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
 89                continue;
 90            }
 91
 92            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
 93            self.score_matrix.clear();
 94            self.score_matrix.resize(matrix_len, None);
 95            self.best_position_matrix.clear();
 96            self.best_position_matrix.resize(matrix_len, 0);
 97
 98            let score = self.score_match(
 99                &candidate_chars,
100                &lowercase_candidate_chars,
101                prefix,
102                lowercase_prefix,
103            );
104
105            if score > 0.0 {
106                let mut mat = build_match(&candidate, score);
107                if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
108                    if results.len() < self.max_results {
109                        mat.set_positions(self.match_positions.clone());
110                        results.insert(i, mat);
111                    } else if i < results.len() {
112                        results.pop();
113                        mat.set_positions(self.match_positions.clone());
114                        results.insert(i, mat);
115                    }
116                    if results.len() == self.max_results {
117                        self.min_score = results.last().unwrap().score();
118                    }
119                }
120            }
121        }
122    }
123
124    fn find_last_positions(
125        &mut self,
126        lowercase_prefix: &[char],
127        lowercase_candidate: &[char],
128    ) -> bool {
129        let mut lowercase_prefix = lowercase_prefix.iter();
130        let mut lowercase_candidate = lowercase_candidate.iter();
131        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
132            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
133                self.last_positions[i] = j + lowercase_prefix.len();
134            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
135                self.last_positions[i] = j;
136            } else {
137                return false;
138            }
139        }
140        true
141    }
142
143    fn score_match(
144        &mut self,
145        path: &[char],
146        path_cased: &[char],
147        prefix: &[char],
148        lowercase_prefix: &[char],
149    ) -> f64 {
150        let score = self.recursive_score_match(
151            path,
152            path_cased,
153            prefix,
154            lowercase_prefix,
155            0,
156            0,
157            self.query.len() as f64,
158        ) * self.query.len() as f64;
159
160        if score <= 0.0 {
161            return 0.0;
162        }
163
164        let path_len = prefix.len() + path.len();
165        let mut cur_start = 0;
166        let mut byte_ix = 0;
167        let mut char_ix = 0;
168        for i in 0..self.query.len() {
169            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170            while char_ix < match_char_ix {
171                let ch = prefix
172                    .get(char_ix)
173                    .or_else(|| path.get(char_ix - prefix.len()))
174                    .unwrap();
175                byte_ix += ch.len_utf8();
176                char_ix += 1;
177            }
178            cur_start = match_char_ix + 1;
179            self.match_positions[i] = byte_ix;
180        }
181
182        score
183    }
184
185    #[allow(clippy::too_many_arguments)]
186    fn recursive_score_match(
187        &mut self,
188        path: &[char],
189        path_cased: &[char],
190        prefix: &[char],
191        lowercase_prefix: &[char],
192        query_idx: usize,
193        path_idx: usize,
194        cur_score: f64,
195    ) -> f64 {
196        if query_idx == self.query.len() {
197            return 1.0;
198        }
199
200        let path_len = prefix.len() + path.len();
201
202        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
203            return memoized;
204        }
205
206        let mut score = 0.0;
207        let mut best_position = 0;
208
209        let query_char = self.lowercase_query[query_idx];
210        let limit = self.last_positions[query_idx];
211
212        let mut last_slash = 0;
213        for j in path_idx..=limit {
214            let path_char = if j < prefix.len() {
215                lowercase_prefix[j]
216            } else {
217                path_cased[j - prefix.len()]
218            };
219            let is_path_sep = path_char == '/' || path_char == '\\';
220
221            if query_idx == 0 && is_path_sep {
222                last_slash = j;
223            }
224
225            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
226                let curr = if j < prefix.len() {
227                    prefix[j]
228                } else {
229                    path[j - prefix.len()]
230                };
231
232                let mut char_score = 1.0;
233                if j > path_idx {
234                    let last = if j - 1 < prefix.len() {
235                        prefix[j - 1]
236                    } else {
237                        path[j - 1 - prefix.len()]
238                    };
239
240                    if last == '/' {
241                        char_score = 0.9;
242                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
243                        || (last.is_lowercase() && curr.is_uppercase())
244                    {
245                        char_score = 0.8;
246                    } else if last == '.' {
247                        char_score = 0.7;
248                    } else if query_idx == 0 {
249                        char_score = BASE_DISTANCE_PENALTY;
250                    } else {
251                        char_score = MIN_DISTANCE_PENALTY.max(
252                            BASE_DISTANCE_PENALTY
253                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
254                        );
255                    }
256                }
257
258                // Apply a severe penalty if the case doesn't match.
259                // This will make the exact matches have higher score than the case-insensitive and the
260                // path insensitive matches.
261                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
262                    char_score *= 0.001;
263                }
264
265                let mut multiplier = char_score;
266
267                // Scale the score based on how deep within the path we found the match.
268                if query_idx == 0 {
269                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
270                }
271
272                let mut next_score = 1.0;
273                if self.min_score > 0.0 {
274                    next_score = cur_score * multiplier;
275                    // Scores only decrease. If we can't pass the previous best, bail
276                    if next_score < self.min_score {
277                        // Ensure that score is non-zero so we use it in the memo table.
278                        if score == 0.0 {
279                            score = 1e-18;
280                        }
281                        continue;
282                    }
283                }
284
285                let new_score = self.recursive_score_match(
286                    path,
287                    path_cased,
288                    prefix,
289                    lowercase_prefix,
290                    query_idx + 1,
291                    j + 1,
292                    next_score,
293                ) * multiplier;
294
295                if new_score > score {
296                    score = new_score;
297                    best_position = j;
298                    // Optimization: can't score better than 1.
299                    if new_score == 1.0 {
300                        break;
301                    }
302                }
303            }
304        }
305
306        if best_position != 0 {
307            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
308        }
309
310        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
311        score
312    }
313}
314
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// Checks `find_last_positions` on prefix/candidate pairs, including a
    /// query whose characters are split between prefix and candidate.
    #[test]
    fn test_get_last_positions() {
        // Runs the lookup with `query` used as both the original and the
        // lowercased query; yields the recorded positions on success.
        fn last_positions(
            query: &[char],
            prefix: &[char],
            candidate: &[char],
        ) -> Option<Vec<usize>> {
            let mut matcher = Matcher::new(query, query, query.into(), false, 10);
            if matcher.find_last_positions(prefix, candidate) {
                Some(matcher.last_positions.clone())
            } else {
                None
            }
        }

        assert_eq!(
            last_positions(&['d', 'c'], &['a', 'b', 'c'], &['b', 'd', 'e', 'f']),
            None
        );
        assert_eq!(
            last_positions(&['c', 'd'], &['a', 'b', 'c'], &['b', 'd', 'e', 'f']),
            Some(vec![2, 4])
        );
        assert_eq!(
            last_positions(
                &['z', '/', 'z', 'f'],
                &['z', 'e', 'd', '/'],
                &['z', 'e', 'd', '/', 'f']
            ),
            Some(vec![0, 3, 4, 8])
        );
    }

    /// Checks ranking and match positions for plain ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let abc_matches = match_single_path_query("abc", false, &paths);
        assert_eq!(
            abc_matches,
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );

        let separator_matches = match_single_path_query("t/i/a/t/d", false, &paths);
        assert_eq!(
            separator_matches,
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
        );

        let initials_matches = match_single_path_query("tiatd", false, &paths);
        assert_eq!(
            initials_matches,
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Checks a query character whose lowercase form is longer than the
    /// character itself.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        let query = "\u{0130}";
        let cases = [
            // This character has more chars in lower-case than in upper-case.
            "\u{0130}",
            // Path is the lower-case version of the query.
            "i\u{307}",
        ];

        for &path in &cases {
            assert_eq!(
                match_single_path_query(query, false, &[path]),
                vec![(path, vec![0])]
            );
        }
    }

    /// Checks that reported positions are byte offsets when the paths contain
    /// multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = [
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Sanity check on the fixture: one keycap emoji occupies seven bytes.
        assert_eq!("1️⃣".len(), 7);

        let bcd_matches = match_single_path_query("bcd", false, &paths);
        assert_eq!(
            bcd_matches,
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );

        let cde_matches = match_single_path_query("cde", false, &paths);
        assert_eq!(
            cde_matches,
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Matches `query` against `paths` with an empty prefix and returns each
    /// hit as `(path, match positions)` in result order.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let query_chars = query.chars().collect::<Vec<_>>();
        let lowercase_query_chars = query.to_lowercase().chars().collect::<Vec<_>>();
        let query_bag = CharBag::from(&lowercase_query_chars[..]);

        // The candidates borrow from these, so they must outlive the matching.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();

        let path_entries: Vec<_> = paths
            .iter()
            .zip(path_arcs.iter())
            .map(|(path, path_arc)| {
                let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: path_arc,
                }
            })
            .collect();

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        let mut matcher = Matcher::new(
            &query_chars,
            &lowercase_query_chars,
            query_bag,
            smart_case,
            100,
        );

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score| PathMatch {
                score,
                worktree_id: 0,
                positions: Vec::new(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );

        // Map each result back to the caller's string so that the returned
        // references carry the `'a` lifetime.
        let mut matched = Vec::with_capacity(results.len());
        for result in results {
            let original = paths
                .iter()
                .copied()
                .find(|p| result.path.as_ref() == Path::new(p))
                .unwrap();
            matched.push((original, result.positions));
        }
        matched
    }
}