lib.rs

  1mod char_bag;
  2
  3use std::{
  4    borrow::Cow,
  5    cmp::Ordering,
  6    path::Path,
  7    sync::atomic::{self, AtomicBool},
  8    sync::Arc,
  9};
 10
 11pub use char_bag::CharBag;
 12
// Score multiplier for a matched character that does not follow a path
// separator, word boundary, case transition or dot. It is also the value the
// per-character gap penalty is subtracted from.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
// Amount subtracted from `BASE_DISTANCE_PENALTY` for every character skipped
// between the position where the search started and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
// Floor for the gap-penalized multiplier: however long the gap is, the
// multiplier never drops below this value.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 16
 17pub struct Matcher<'a> {
 18    query: &'a [char],
 19    lowercase_query: &'a [char],
 20    query_char_bag: CharBag,
 21    smart_case: bool,
 22    max_results: usize,
 23    min_score: f64,
 24    match_positions: Vec<usize>,
 25    last_positions: Vec<usize>,
 26    score_matrix: Vec<Option<f64>>,
 27    best_position_matrix: Vec<usize>,
 28}
 29
/// A scored result the matcher can rank.
///
/// The `Ord` bound is what the matcher uses to locate the insertion point of a
/// new result inside the sorted results vector.
trait Match: Ord {
    /// The score assigned to this result.
    fn score(&self) -> f64;
    /// Stores the positions of the matched query characters on the result.
    fn set_positions(&mut self, positions: Vec<usize>);
}
 34
 35trait MatchCandidate {
 36    fn has_chars(&self, bag: CharBag) -> bool;
 37    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 38}
 39
/// A path offered to the matcher as a candidate.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The path to match against; cloned into the resulting `PathMatch`.
    pub path: &'a Arc<Path>,
    /// Character set used to pre-filter the candidate: it is skipped unless
    /// this bag is a superset of the query's bag.
    // NOTE(review): presumably built from the lowercased path, as the tests do — confirm with callers.
    pub char_bag: CharBag,
}
 45
/// A path that matched the query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match quality; only matches with a score above zero are produced.
    pub score: f64,
    /// Byte offsets of the matched characters, measured over the path prefix
    /// followed by the path text.
    pub positions: Vec<usize>,
    /// The `tree_id` that was passed to `Matcher::match_paths`.
    pub worktree_id: usize,
    /// The matched path, shared with the candidate it came from.
    pub path: Arc<Path>,
    /// The prefix that was passed to `Matcher::match_paths`.
    pub path_prefix: Arc<str>,
}
 54
/// A plain string offered to the matcher as a candidate.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// The text to match against; copied into the resulting `StringMatch`.
    pub string: String,
    /// Character set used to pre-filter the candidate: it is skipped unless
    /// this bag is a superset of the query's bag.
    // NOTE(review): presumably built from the lowercased string — confirm with callers.
    pub char_bag: CharBag,
}
 60
/// Lets the matcher rank `PathMatch` values and fill in their positions.
impl Match for PathMatch {
    /// Returns the stored score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 70
/// Lets the matcher rank `StringMatch` values and fill in their positions.
impl Match for StringMatch {
    /// Returns the stored score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 80
 81impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 82    fn has_chars(&self, bag: CharBag) -> bool {
 83        self.char_bag.is_superset(bag)
 84    }
 85
 86    fn to_string(&self) -> Cow<'a, str> {
 87        self.path.to_string_lossy()
 88    }
 89}
 90
 91impl<'a> MatchCandidate for &'a StringMatchCandidate {
 92    fn has_chars(&self, bag: CharBag) -> bool {
 93        self.char_bag.is_superset(bag)
 94    }
 95
 96    fn to_string(&self) -> Cow<'a, str> {
 97        self.string.as_str().into()
 98    }
 99}
100
/// A string that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Match quality; only matches with a score above zero are produced.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the matched candidate's text.
    pub string: String,
}

/// Equality is defined through `Ord::cmp` so that `a == b` holds exactly when
/// `a.cmp(&b)` is `Equal`, as the `Ord`/`Eq` contracts require. Comparing the
/// raw scores instead would call two matches with different strings "equal"
/// and would make a match with a NaN score unequal to itself.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders by score first (incomparable scores such as NaN count as equal) and
/// breaks ties by the matched string. `positions` does not take part.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
130
131impl PartialEq for PathMatch {
132    fn eq(&self, other: &Self) -> bool {
133        self.score.eq(&other.score)
134    }
135}
136
137impl Eq for PathMatch {}
138
139impl PartialOrd for PathMatch {
140    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
141        Some(self.cmp(other))
142    }
143}
144
145impl Ord for PathMatch {
146    fn cmp(&self, other: &Self) -> Ordering {
147        self.score
148            .partial_cmp(&other.score)
149            .unwrap_or(Ordering::Equal)
150            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
151            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
152    }
153}
154
155impl<'a> Matcher<'a> {
156    pub fn new(
157        query: &'a [char],
158        lowercase_query: &'a [char],
159        query_char_bag: CharBag,
160        smart_case: bool,
161        max_results: usize,
162    ) -> Self {
163        Self {
164            query,
165            lowercase_query,
166            query_char_bag,
167            min_score: 0.0,
168            last_positions: vec![0; query.len()],
169            match_positions: vec![0; query.len()],
170            score_matrix: Vec::new(),
171            best_position_matrix: Vec::new(),
172            smart_case,
173            max_results,
174        }
175    }
176
177    pub fn match_strings(
178        &mut self,
179        candidates: &[StringMatchCandidate],
180        results: &mut Vec<StringMatch>,
181        cancel_flag: &AtomicBool,
182    ) {
183        self.match_internal(
184            &[],
185            &[],
186            candidates.iter(),
187            results,
188            cancel_flag,
189            |candidate, score| StringMatch {
190                score,
191                positions: Vec::new(),
192                string: candidate.string.to_string(),
193            },
194        )
195    }
196
197    pub fn match_paths(
198        &mut self,
199        tree_id: usize,
200        path_prefix: Arc<str>,
201        path_entries: impl Iterator<Item = PathMatchCandidate<'a>>,
202        results: &mut Vec<PathMatch>,
203        cancel_flag: &AtomicBool,
204    ) {
205        let prefix = path_prefix.chars().collect::<Vec<_>>();
206        let lowercase_prefix = prefix
207            .iter()
208            .map(|c| c.to_ascii_lowercase())
209            .collect::<Vec<_>>();
210        self.match_internal(
211            &prefix,
212            &lowercase_prefix,
213            path_entries,
214            results,
215            cancel_flag,
216            |candidate, score| PathMatch {
217                score,
218                worktree_id: tree_id,
219                positions: Vec::new(),
220                path: candidate.path.clone(),
221                path_prefix: path_prefix.clone(),
222            },
223        )
224    }
225
226    fn match_internal<C: MatchCandidate, R, F>(
227        &mut self,
228        prefix: &[char],
229        lowercase_prefix: &[char],
230        candidates: impl Iterator<Item = C>,
231        results: &mut Vec<R>,
232        cancel_flag: &AtomicBool,
233        build_match: F,
234    ) where
235        R: Match,
236        F: Fn(&C, f64) -> R,
237    {
238        let mut candidate_chars = Vec::new();
239        let mut lowercase_candidate_chars = Vec::new();
240
241        for candidate in candidates {
242            if !candidate.has_chars(self.query_char_bag) {
243                continue;
244            }
245
246            if cancel_flag.load(atomic::Ordering::Relaxed) {
247                break;
248            }
249
250            candidate_chars.clear();
251            lowercase_candidate_chars.clear();
252            for c in candidate.to_string().chars() {
253                candidate_chars.push(c);
254                lowercase_candidate_chars.push(c.to_ascii_lowercase());
255            }
256
257            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
258                continue;
259            }
260
261            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
262            self.score_matrix.clear();
263            self.score_matrix.resize(matrix_len, None);
264            self.best_position_matrix.clear();
265            self.best_position_matrix.resize(matrix_len, 0);
266
267            let score = self.score_match(
268                &candidate_chars,
269                &lowercase_candidate_chars,
270                &prefix,
271                &lowercase_prefix,
272            );
273
274            if score > 0.0 {
275                let mut mat = build_match(&candidate, score);
276                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
277                    if results.len() < self.max_results {
278                        mat.set_positions(self.match_positions.clone());
279                        results.insert(i, mat);
280                    } else if i < results.len() {
281                        results.pop();
282                        mat.set_positions(self.match_positions.clone());
283                        results.insert(i, mat);
284                    }
285                    if results.len() == self.max_results {
286                        self.min_score = results.last().unwrap().score();
287                    }
288                }
289            }
290        }
291    }
292
293    fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
294        let mut path = path.iter();
295        let mut prefix_iter = prefix.iter();
296        for (i, char) in self.query.iter().enumerate().rev() {
297            if let Some(j) = path.rposition(|c| c == char) {
298                self.last_positions[i] = j + prefix.len();
299            } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
300                self.last_positions[i] = j;
301            } else {
302                return false;
303            }
304        }
305        true
306    }
307
308    fn score_match(
309        &mut self,
310        path: &[char],
311        path_cased: &[char],
312        prefix: &[char],
313        lowercase_prefix: &[char],
314    ) -> f64 {
315        let score = self.recursive_score_match(
316            path,
317            path_cased,
318            prefix,
319            lowercase_prefix,
320            0,
321            0,
322            self.query.len() as f64,
323        ) * self.query.len() as f64;
324
325        if score <= 0.0 {
326            return 0.0;
327        }
328
329        let path_len = prefix.len() + path.len();
330        let mut cur_start = 0;
331        let mut byte_ix = 0;
332        let mut char_ix = 0;
333        for i in 0..self.query.len() {
334            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
335            while char_ix < match_char_ix {
336                let ch = prefix
337                    .get(char_ix)
338                    .or_else(|| path.get(char_ix - prefix.len()))
339                    .unwrap();
340                byte_ix += ch.len_utf8();
341                char_ix += 1;
342            }
343            cur_start = match_char_ix + 1;
344            self.match_positions[i] = byte_ix;
345        }
346
347        score
348    }
349
350    fn recursive_score_match(
351        &mut self,
352        path: &[char],
353        path_cased: &[char],
354        prefix: &[char],
355        lowercase_prefix: &[char],
356        query_idx: usize,
357        path_idx: usize,
358        cur_score: f64,
359    ) -> f64 {
360        if query_idx == self.query.len() {
361            return 1.0;
362        }
363
364        let path_len = prefix.len() + path.len();
365
366        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
367            return memoized;
368        }
369
370        let mut score = 0.0;
371        let mut best_position = 0;
372
373        let query_char = self.lowercase_query[query_idx];
374        let limit = self.last_positions[query_idx];
375
376        let mut last_slash = 0;
377        for j in path_idx..=limit {
378            let path_char = if j < prefix.len() {
379                lowercase_prefix[j]
380            } else {
381                path_cased[j - prefix.len()]
382            };
383            let is_path_sep = path_char == '/' || path_char == '\\';
384
385            if query_idx == 0 && is_path_sep {
386                last_slash = j;
387            }
388
389            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
390                let curr = if j < prefix.len() {
391                    prefix[j]
392                } else {
393                    path[j - prefix.len()]
394                };
395
396                let mut char_score = 1.0;
397                if j > path_idx {
398                    let last = if j - 1 < prefix.len() {
399                        prefix[j - 1]
400                    } else {
401                        path[j - 1 - prefix.len()]
402                    };
403
404                    if last == '/' {
405                        char_score = 0.9;
406                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
407                        char_score = 0.8;
408                    } else if last.is_lowercase() && curr.is_uppercase() {
409                        char_score = 0.8;
410                    } else if last == '.' {
411                        char_score = 0.7;
412                    } else if query_idx == 0 {
413                        char_score = BASE_DISTANCE_PENALTY;
414                    } else {
415                        char_score = MIN_DISTANCE_PENALTY.max(
416                            BASE_DISTANCE_PENALTY
417                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
418                        );
419                    }
420                }
421
422                // Apply a severe penalty if the case doesn't match.
423                // This will make the exact matches have higher score than the case-insensitive and the
424                // path insensitive matches.
425                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
426                    char_score *= 0.001;
427                }
428
429                let mut multiplier = char_score;
430
431                // Scale the score based on how deep within the path we found the match.
432                if query_idx == 0 {
433                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
434                }
435
436                let mut next_score = 1.0;
437                if self.min_score > 0.0 {
438                    next_score = cur_score * multiplier;
439                    // Scores only decrease. If we can't pass the previous best, bail
440                    if next_score < self.min_score {
441                        // Ensure that score is non-zero so we use it in the memo table.
442                        if score == 0.0 {
443                            score = 1e-18;
444                        }
445                        continue;
446                    }
447                }
448
449                let new_score = self.recursive_score_match(
450                    path,
451                    path_cased,
452                    prefix,
453                    lowercase_prefix,
454                    query_idx + 1,
455                    j + 1,
456                    next_score,
457                ) * multiplier;
458
459                if new_score > score {
460                    score = new_score;
461                    best_position = j;
462                    // Optimization: can't score better than 1.
463                    if new_score == 1.0 {
464                        break;
465                    }
466                }
467            }
468        }
469
470        if best_position != 0 {
471            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
472        }
473
474        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
475        score
476    }
477}
478
479#[cfg(test)]
480mod tests {
481    use super::*;
482    use std::path::PathBuf;
483
    // Exercises `find_last_positions` directly with a prefix and a path.
    // Positions index into the prefix followed by the path.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, and no 'd' precedes it there, so the
        // query cannot be placed in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        // 'd' is found in the path (index 1 + prefix length 3), 'c' in the prefix.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // The last two query characters land in the path, the first two fall
        // back to the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }
503
    // End-to-end check of path matching through the `match_query` helper.
    // Each expected tuple is (path, positions of the matched characters), in
    // the order the matcher returns the results.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        // Plain subsequence query.
        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // Query containing path separators.
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        // The same letters without separators match several layouts.
        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }
544
    // Match positions are byte offsets, so multi-byte characters ahead of a
    // match shift its reported position by more than one.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        // Each keycap emoji above occupies 7 bytes.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }
564
565    fn match_query<'a>(
566        query: &str,
567        smart_case: bool,
568        paths: &Vec<&'a str>,
569    ) -> Vec<(&'a str, Vec<usize>)> {
570        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
571        let query = query.chars().collect::<Vec<_>>();
572        let query_chars = CharBag::from(&lowercase_query[..]);
573
574        let path_arcs = paths
575            .iter()
576            .map(|path| Arc::from(PathBuf::from(path)))
577            .collect::<Vec<_>>();
578        let mut path_entries = Vec::new();
579        for (i, path) in paths.iter().enumerate() {
580            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
581            let char_bag = CharBag::from(lowercase_path.as_slice());
582            path_entries.push(PathMatchCandidate {
583                char_bag,
584                path: path_arcs.get(i).unwrap(),
585            });
586        }
587
588        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
589
590        let cancel_flag = AtomicBool::new(false);
591        let mut results = Vec::new();
592        matcher.match_paths(
593            0,
594            "".into(),
595            path_entries.into_iter(),
596            &mut results,
597            &cancel_flag,
598        );
599
600        results
601            .into_iter()
602            .map(|result| {
603                (
604                    paths
605                        .iter()
606                        .copied()
607                        .find(|p| result.path.as_ref() == Path::new(p))
608                        .unwrap(),
609                    result.positions,
610                )
611            })
612            .collect()
613    }
614}