fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
 14const BASE_DISTANCE_PENALTY: f64 = 0.6;
 15const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 16const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
 18pub struct Matcher<'a> {
 19    query: &'a [char],
 20    lowercase_query: &'a [char],
 21    query_char_bag: CharBag,
 22    smart_case: bool,
 23    max_results: usize,
 24    min_score: f64,
 25    match_positions: Vec<usize>,
 26    last_positions: Vec<usize>,
 27    score_matrix: Vec<Option<f64>>,
 28    best_position_matrix: Vec<usize>,
 29}
 30
 31trait Match: Ord {
 32    fn score(&self) -> f64;
 33    fn set_positions(&mut self, positions: Vec<usize>);
 34}
 35
 36trait MatchCandidate {
 37    fn has_chars(&self, bag: CharBag) -> bool;
 38    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 39}
 40
 41#[derive(Clone, Debug)]
 42pub struct PathMatchCandidate<'a> {
 43    pub path: &'a Arc<Path>,
 44    pub char_bag: CharBag,
 45}
 46
 47#[derive(Clone, Debug)]
 48pub struct PathMatch {
 49    pub score: f64,
 50    pub positions: Vec<usize>,
 51    pub worktree_id: usize,
 52    pub path: Arc<Path>,
 53    pub path_prefix: Arc<str>,
 54}
 55
 56#[derive(Clone, Debug)]
 57pub struct StringMatchCandidate {
 58    pub id: usize,
 59    pub string: String,
 60    pub char_bag: CharBag,
 61}
 62
 63pub trait PathMatchCandidateSet<'a>: Send + Sync {
 64    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
 65    fn id(&self) -> usize;
 66    fn len(&self) -> usize;
 67    fn prefix(&self) -> Arc<str>;
 68    fn candidates(&'a self, start: usize) -> Self::Candidates;
 69}
 70
 71impl Match for PathMatch {
 72    fn score(&self) -> f64 {
 73        self.score
 74    }
 75
 76    fn set_positions(&mut self, positions: Vec<usize>) {
 77        self.positions = positions;
 78    }
 79}
 80
 81impl Match for StringMatch {
 82    fn score(&self) -> f64 {
 83        self.score
 84    }
 85
 86    fn set_positions(&mut self, positions: Vec<usize>) {
 87        self.positions = positions;
 88    }
 89}
 90
 91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 92    fn has_chars(&self, bag: CharBag) -> bool {
 93        self.char_bag.is_superset(bag)
 94    }
 95
 96    fn to_string(&self) -> Cow<'a, str> {
 97        self.path.to_string_lossy()
 98    }
 99}
100
101impl<'a> MatchCandidate for &'a StringMatchCandidate {
102    fn has_chars(&self, bag: CharBag) -> bool {
103        self.char_bag.is_superset(bag)
104    }
105
106    fn to_string(&self) -> Cow<'a, str> {
107        self.string.as_str().into()
108    }
109}
110
/// A string candidate that matched the query.
#[derive(Debug, Clone)]
pub struct StringMatch {
    /// `id` of the `StringMatchCandidate` that produced this match.
    pub candidate_id: usize,
    /// Match quality; higher is better.
    pub score: f64,
    /// Byte offset of each matched query character within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate's text.
    pub string: String,
}

// Equality is defined through `Ord`, so only `score` and `candidate_id` take part in it;
// `positions` and `string` are ignored.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for StringMatch {
    /// Orders by score, then by candidate id. Scores that cannot be compared (NaN) are
    /// treated as equal so the candidate id decides.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.score.partial_cmp(&other.score) {
            Some(Ordering::Equal) | None => self.candidate_id.cmp(&other.candidate_id),
            Some(decided) => decided,
        }
    }
}
141
142impl PartialEq for PathMatch {
143    fn eq(&self, other: &Self) -> bool {
144        self.cmp(other).is_eq()
145    }
146}
147
148impl Eq for PathMatch {}
149
150impl PartialOrd for PathMatch {
151    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152        Some(self.cmp(other))
153    }
154}
155
156impl Ord for PathMatch {
157    fn cmp(&self, other: &Self) -> Ordering {
158        self.score
159            .partial_cmp(&other.score)
160            .unwrap_or(Ordering::Equal)
161            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
162            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
163    }
164}
165
166pub async fn match_strings(
167    candidates: &[StringMatchCandidate],
168    query: &str,
169    smart_case: bool,
170    max_results: usize,
171    cancel_flag: &AtomicBool,
172    background: Arc<executor::Background>,
173) -> Vec<StringMatch> {
174    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
175    let query = query.chars().collect::<Vec<_>>();
176
177    let lowercase_query = &lowercase_query;
178    let query = &query;
179    let query_char_bag = CharBag::from(&lowercase_query[..]);
180
181    let num_cpus = background.num_cpus().min(candidates.len());
182    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
183    let mut segment_results = (0..num_cpus)
184        .map(|_| Vec::with_capacity(max_results))
185        .collect::<Vec<_>>();
186
187    background
188        .scoped(|scope| {
189            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
190                let cancel_flag = &cancel_flag;
191                scope.spawn(async move {
192                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
193                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
194                    let mut matcher = Matcher::new(
195                        query,
196                        lowercase_query,
197                        query_char_bag,
198                        smart_case,
199                        max_results,
200                    );
201                    matcher.match_strings(
202                        &candidates[segment_start..segment_end],
203                        results,
204                        cancel_flag,
205                    );
206                });
207            }
208        })
209        .await;
210
211    let mut results = Vec::new();
212    for segment_result in segment_results {
213        if results.is_empty() {
214            results = segment_result;
215        } else {
216            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
217        }
218    }
219    results
220}
221
222pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
223    candidate_sets: &'a [Set],
224    query: &str,
225    smart_case: bool,
226    max_results: usize,
227    cancel_flag: &AtomicBool,
228    background: Arc<executor::Background>,
229) -> Vec<PathMatch> {
230    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
231    if path_count == 0 {
232        return Vec::new();
233    }
234
235    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
236    let query = query.chars().collect::<Vec<_>>();
237
238    let lowercase_query = &lowercase_query;
239    let query = &query;
240    let query_char_bag = CharBag::from(&lowercase_query[..]);
241
242    let num_cpus = background.num_cpus().min(path_count);
243    let segment_size = (path_count + num_cpus - 1) / num_cpus;
244    let mut segment_results = (0..num_cpus)
245        .map(|_| Vec::with_capacity(max_results))
246        .collect::<Vec<_>>();
247
248    background
249        .scoped(|scope| {
250            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
251                scope.spawn(async move {
252                    let segment_start = segment_idx * segment_size;
253                    let segment_end = segment_start + segment_size;
254                    let mut matcher = Matcher::new(
255                        query,
256                        lowercase_query,
257                        query_char_bag,
258                        smart_case,
259                        max_results,
260                    );
261
262                    let mut tree_start = 0;
263                    for candidate_set in candidate_sets {
264                        let tree_end = tree_start + candidate_set.len();
265
266                        if tree_start < segment_end && segment_start < tree_end {
267                            let start = cmp::max(tree_start, segment_start) - tree_start;
268                            let end = cmp::min(tree_end, segment_end) - tree_start;
269                            let candidates = candidate_set.candidates(start).take(end - start);
270
271                            matcher.match_paths(
272                                candidate_set.id(),
273                                candidate_set.prefix(),
274                                candidates,
275                                results,
276                                &cancel_flag,
277                            );
278                        }
279                        if tree_end >= segment_end {
280                            break;
281                        }
282                        tree_start = tree_end;
283                    }
284                })
285            }
286        })
287        .await;
288
289    let mut results = Vec::new();
290    for segment_result in segment_results {
291        if results.is_empty() {
292            results = segment_result;
293        } else {
294            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
295        }
296    }
297    results
298}
299
300impl<'a> Matcher<'a> {
301    pub fn new(
302        query: &'a [char],
303        lowercase_query: &'a [char],
304        query_char_bag: CharBag,
305        smart_case: bool,
306        max_results: usize,
307    ) -> Self {
308        Self {
309            query,
310            lowercase_query,
311            query_char_bag,
312            min_score: 0.0,
313            last_positions: vec![0; query.len()],
314            match_positions: vec![0; query.len()],
315            score_matrix: Vec::new(),
316            best_position_matrix: Vec::new(),
317            smart_case,
318            max_results,
319        }
320    }
321
322    pub fn match_strings(
323        &mut self,
324        candidates: &[StringMatchCandidate],
325        results: &mut Vec<StringMatch>,
326        cancel_flag: &AtomicBool,
327    ) {
328        self.match_internal(
329            &[],
330            &[],
331            candidates.iter(),
332            results,
333            cancel_flag,
334            |candidate, score| StringMatch {
335                candidate_id: candidate.id,
336                score,
337                positions: Vec::new(),
338                string: candidate.string.to_string(),
339            },
340        )
341    }
342
343    pub fn match_paths<'c: 'a>(
344        &mut self,
345        tree_id: usize,
346        path_prefix: Arc<str>,
347        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
348        results: &mut Vec<PathMatch>,
349        cancel_flag: &AtomicBool,
350    ) {
351        let prefix = path_prefix.chars().collect::<Vec<_>>();
352        let lowercase_prefix = prefix
353            .iter()
354            .map(|c| c.to_ascii_lowercase())
355            .collect::<Vec<_>>();
356        self.match_internal(
357            &prefix,
358            &lowercase_prefix,
359            path_entries,
360            results,
361            cancel_flag,
362            |candidate, score| PathMatch {
363                score,
364                worktree_id: tree_id,
365                positions: Vec::new(),
366                path: candidate.path.clone(),
367                path_prefix: path_prefix.clone(),
368            },
369        )
370    }
371
372    fn match_internal<C: MatchCandidate, R, F>(
373        &mut self,
374        prefix: &[char],
375        lowercase_prefix: &[char],
376        candidates: impl Iterator<Item = C>,
377        results: &mut Vec<R>,
378        cancel_flag: &AtomicBool,
379        build_match: F,
380    ) where
381        R: Match,
382        F: Fn(&C, f64) -> R,
383    {
384        let mut candidate_chars = Vec::new();
385        let mut lowercase_candidate_chars = Vec::new();
386
387        for candidate in candidates {
388            if !candidate.has_chars(self.query_char_bag) {
389                continue;
390            }
391
392            if cancel_flag.load(atomic::Ordering::Relaxed) {
393                break;
394            }
395
396            candidate_chars.clear();
397            lowercase_candidate_chars.clear();
398            for c in candidate.to_string().chars() {
399                candidate_chars.push(c);
400                lowercase_candidate_chars.push(c.to_ascii_lowercase());
401            }
402
403            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
404                continue;
405            }
406
407            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
408            self.score_matrix.clear();
409            self.score_matrix.resize(matrix_len, None);
410            self.best_position_matrix.clear();
411            self.best_position_matrix.resize(matrix_len, 0);
412
413            let score = self.score_match(
414                &candidate_chars,
415                &lowercase_candidate_chars,
416                &prefix,
417                &lowercase_prefix,
418            );
419
420            if score > 0.0 {
421                let mut mat = build_match(&candidate, score);
422                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
423                    if results.len() < self.max_results {
424                        mat.set_positions(self.match_positions.clone());
425                        results.insert(i, mat);
426                    } else if i < results.len() {
427                        results.pop();
428                        mat.set_positions(self.match_positions.clone());
429                        results.insert(i, mat);
430                    }
431                    if results.len() == self.max_results {
432                        self.min_score = results.last().unwrap().score();
433                    }
434                }
435            }
436        }
437    }
438
439    fn find_last_positions(
440        &mut self,
441        lowercase_prefix: &[char],
442        lowercase_candidate: &[char],
443    ) -> bool {
444        let mut lowercase_prefix = lowercase_prefix.iter();
445        let mut lowercase_candidate = lowercase_candidate.iter();
446        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
447            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
448                self.last_positions[i] = j + lowercase_prefix.len();
449            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
450                self.last_positions[i] = j;
451            } else {
452                return false;
453            }
454        }
455        true
456    }
457
458    fn score_match(
459        &mut self,
460        path: &[char],
461        path_cased: &[char],
462        prefix: &[char],
463        lowercase_prefix: &[char],
464    ) -> f64 {
465        let score = self.recursive_score_match(
466            path,
467            path_cased,
468            prefix,
469            lowercase_prefix,
470            0,
471            0,
472            self.query.len() as f64,
473        ) * self.query.len() as f64;
474
475        if score <= 0.0 {
476            return 0.0;
477        }
478
479        let path_len = prefix.len() + path.len();
480        let mut cur_start = 0;
481        let mut byte_ix = 0;
482        let mut char_ix = 0;
483        for i in 0..self.query.len() {
484            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
485            while char_ix < match_char_ix {
486                let ch = prefix
487                    .get(char_ix)
488                    .or_else(|| path.get(char_ix - prefix.len()))
489                    .unwrap();
490                byte_ix += ch.len_utf8();
491                char_ix += 1;
492            }
493            cur_start = match_char_ix + 1;
494            self.match_positions[i] = byte_ix;
495        }
496
497        score
498    }
499
500    fn recursive_score_match(
501        &mut self,
502        path: &[char],
503        path_cased: &[char],
504        prefix: &[char],
505        lowercase_prefix: &[char],
506        query_idx: usize,
507        path_idx: usize,
508        cur_score: f64,
509    ) -> f64 {
510        if query_idx == self.query.len() {
511            return 1.0;
512        }
513
514        let path_len = prefix.len() + path.len();
515
516        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
517            return memoized;
518        }
519
520        let mut score = 0.0;
521        let mut best_position = 0;
522
523        let query_char = self.lowercase_query[query_idx];
524        let limit = self.last_positions[query_idx];
525
526        let mut last_slash = 0;
527        for j in path_idx..=limit {
528            let path_char = if j < prefix.len() {
529                lowercase_prefix[j]
530            } else {
531                path_cased[j - prefix.len()]
532            };
533            let is_path_sep = path_char == '/' || path_char == '\\';
534
535            if query_idx == 0 && is_path_sep {
536                last_slash = j;
537            }
538
539            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
540                let curr = if j < prefix.len() {
541                    prefix[j]
542                } else {
543                    path[j - prefix.len()]
544                };
545
546                let mut char_score = 1.0;
547                if j > path_idx {
548                    let last = if j - 1 < prefix.len() {
549                        prefix[j - 1]
550                    } else {
551                        path[j - 1 - prefix.len()]
552                    };
553
554                    if last == '/' {
555                        char_score = 0.9;
556                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
557                        char_score = 0.8;
558                    } else if last.is_lowercase() && curr.is_uppercase() {
559                        char_score = 0.8;
560                    } else if last == '.' {
561                        char_score = 0.7;
562                    } else if query_idx == 0 {
563                        char_score = BASE_DISTANCE_PENALTY;
564                    } else {
565                        char_score = MIN_DISTANCE_PENALTY.max(
566                            BASE_DISTANCE_PENALTY
567                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
568                        );
569                    }
570                }
571
572                // Apply a severe penalty if the case doesn't match.
573                // This will make the exact matches have higher score than the case-insensitive and the
574                // path insensitive matches.
575                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
576                    char_score *= 0.001;
577                }
578
579                let mut multiplier = char_score;
580
581                // Scale the score based on how deep within the path we found the match.
582                if query_idx == 0 {
583                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
584                }
585
586                let mut next_score = 1.0;
587                if self.min_score > 0.0 {
588                    next_score = cur_score * multiplier;
589                    // Scores only decrease. If we can't pass the previous best, bail
590                    if next_score < self.min_score {
591                        // Ensure that score is non-zero so we use it in the memo table.
592                        if score == 0.0 {
593                            score = 1e-18;
594                        }
595                        continue;
596                    }
597                }
598
599                let new_score = self.recursive_score_match(
600                    path,
601                    path_cased,
602                    prefix,
603                    lowercase_prefix,
604                    query_idx + 1,
605                    j + 1,
606                    next_score,
607                ) * multiplier;
608
609                if new_score > score {
610                    score = new_score;
611                    best_position = j;
612                    // Optimization: can't score better than 1.
613                    if new_score == 1.0 {
614                        break;
615                    }
616                }
617            }
618        }
619
620        if best_position != 0 {
621            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
622        }
623
624        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
625        score
626    }
627}
628
629#[cfg(test)]
630mod tests {
631    use super::*;
632    use std::path::PathBuf;
633
634    #[test]
635    fn test_get_last_positions() {
636        let mut query: &[char] = &['d', 'c'];
637        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
638        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
639        assert_eq!(result, false);
640
641        query = &['c', 'd'];
642        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
643        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
644        assert_eq!(result, true);
645        assert_eq!(matcher.last_positions, vec![2, 4]);
646
647        query = &['z', '/', 'z', 'f'];
648        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
649        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
650        assert_eq!(result, true);
651        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
652    }
653
654    #[test]
655    fn test_match_path_entries() {
656        let paths = vec![
657            "",
658            "a",
659            "ab",
660            "abC",
661            "abcd",
662            "alphabravocharlie",
663            "AlphaBravoCharlie",
664            "thisisatestdir",
665            "/////ThisIsATestDir",
666            "/this/is/a/test/dir",
667            "/test/tiatd",
668        ];
669
670        assert_eq!(
671            match_query("abc", false, &paths),
672            vec![
673                ("abC", vec![0, 1, 2]),
674                ("abcd", vec![0, 1, 2]),
675                ("AlphaBravoCharlie", vec![0, 5, 10]),
676                ("alphabravocharlie", vec![4, 5, 10]),
677            ]
678        );
679        assert_eq!(
680            match_query("t/i/a/t/d", false, &paths),
681            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
682        );
683
684        assert_eq!(
685            match_query("tiatd", false, &paths),
686            vec![
687                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
688                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
689                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
690                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
691            ]
692        );
693    }
694
695    #[test]
696    fn test_match_multibyte_path_entries() {
697        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
698        assert_eq!("1️⃣".len(), 7);
699        assert_eq!(
700            match_query("bcd", false, &paths),
701            vec![
702                ("αβγδ/bcde", vec![9, 10, 11]),
703                ("aαbβ/cγdδ", vec![3, 7, 10]),
704            ]
705        );
706        assert_eq!(
707            match_query("cde", false, &paths),
708            vec![
709                ("αβγδ/bcde", vec![10, 11, 12]),
710                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
711            ]
712        );
713    }
714
715    fn match_query<'a>(
716        query: &str,
717        smart_case: bool,
718        paths: &Vec<&'a str>,
719    ) -> Vec<(&'a str, Vec<usize>)> {
720        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
721        let query = query.chars().collect::<Vec<_>>();
722        let query_chars = CharBag::from(&lowercase_query[..]);
723
724        let path_arcs = paths
725            .iter()
726            .map(|path| Arc::from(PathBuf::from(path)))
727            .collect::<Vec<_>>();
728        let mut path_entries = Vec::new();
729        for (i, path) in paths.iter().enumerate() {
730            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
731            let char_bag = CharBag::from(lowercase_path.as_slice());
732            path_entries.push(PathMatchCandidate {
733                char_bag,
734                path: path_arcs.get(i).unwrap(),
735            });
736        }
737
738        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
739
740        let cancel_flag = AtomicBool::new(false);
741        let mut results = Vec::new();
742        matcher.match_paths(
743            0,
744            "".into(),
745            path_entries.into_iter(),
746            &mut results,
747            &cancel_flag,
748        );
749
750        results
751            .into_iter()
752            .map(|result| {
753                (
754                    paths
755                        .iter()
756                        .copied()
757                        .find(|p| result.path.as_ref() == Path::new(p))
758                        .unwrap(),
759                    result.positions,
760                )
761            })
762            .collect()
763    }
764}