fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
 14const BASE_DISTANCE_PENALTY: f64 = 0.6;
 15const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 16const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
 18pub struct Matcher<'a> {
 19    query: &'a [char],
 20    lowercase_query: &'a [char],
 21    query_char_bag: CharBag,
 22    smart_case: bool,
 23    max_results: usize,
 24    min_score: f64,
 25    match_positions: Vec<usize>,
 26    last_positions: Vec<usize>,
 27    score_matrix: Vec<Option<f64>>,
 28    best_position_matrix: Vec<usize>,
 29}
 30
 31trait Match: Ord {
 32    fn score(&self) -> f64;
 33    fn set_positions(&mut self, positions: Vec<usize>);
 34}
 35
 36trait MatchCandidate {
 37    fn has_chars(&self, bag: CharBag) -> bool;
 38    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 39}
 40
 41#[derive(Clone, Debug)]
 42pub struct PathMatchCandidate<'a> {
 43    pub path: &'a Arc<Path>,
 44    pub char_bag: CharBag,
 45}
 46
 47#[derive(Clone, Debug)]
 48pub struct PathMatch {
 49    pub score: f64,
 50    pub positions: Vec<usize>,
 51    pub worktree_id: usize,
 52    pub path: Arc<Path>,
 53    pub path_prefix: Arc<str>,
 54}
 55
 56#[derive(Clone, Debug)]
 57pub struct StringMatchCandidate {
 58    pub id: usize,
 59    pub string: String,
 60    pub char_bag: CharBag,
 61}
 62
 63pub trait PathMatchCandidateSet<'a>: Send + Sync {
 64    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
 65    fn id(&self) -> usize;
 66    fn len(&self) -> usize;
 67    fn prefix(&self) -> Arc<str>;
 68    fn candidates(&'a self, start: usize) -> Self::Candidates;
 69}
 70
 71impl Match for PathMatch {
 72    fn score(&self) -> f64 {
 73        self.score
 74    }
 75
 76    fn set_positions(&mut self, positions: Vec<usize>) {
 77        self.positions = positions;
 78    }
 79}
 80
 81impl Match for StringMatch {
 82    fn score(&self) -> f64 {
 83        self.score
 84    }
 85
 86    fn set_positions(&mut self, positions: Vec<usize>) {
 87        self.positions = positions;
 88    }
 89}
 90
 91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 92    fn has_chars(&self, bag: CharBag) -> bool {
 93        self.char_bag.is_superset(bag)
 94    }
 95
 96    fn to_string(&self) -> Cow<'a, str> {
 97        self.path.to_string_lossy()
 98    }
 99}
100
101impl StringMatchCandidate {
102    pub fn new(id: usize, string: String) -> Self {
103        Self {
104            id,
105            char_bag: CharBag::from(string.as_str()),
106            string,
107        }
108    }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112    fn has_chars(&self, bag: CharBag) -> bool {
113        self.char_bag.is_superset(bag)
114    }
115
116    fn to_string(&self) -> Cow<'a, str> {
117        self.string.as_str().into()
118    }
119}
120
/// A string candidate that matched a query.
#[derive(Debug, Clone)]
pub struct StringMatch {
    pub candidate_id: usize,
    pub score: f64,
    pub positions: Vec<usize>,
    pub string: String,
}

/// Equality is defined by the total order below (score, then candidate id).
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        Ord::cmp(self, other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders by score (incomparable scores count as equal), breaking ties by
/// candidate id.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_score = match self.score.partial_cmp(&other.score) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        };
        by_score.then(self.candidate_id.cmp(&other.candidate_id))
    }
}
151
152impl PartialEq for PathMatch {
153    fn eq(&self, other: &Self) -> bool {
154        self.cmp(other).is_eq()
155    }
156}
157
158impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162        Some(self.cmp(other))
163    }
164}
165
166impl Ord for PathMatch {
167    fn cmp(&self, other: &Self) -> Ordering {
168        self.score
169            .partial_cmp(&other.score)
170            .unwrap_or(Ordering::Equal)
171            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173    }
174}
175
176pub async fn match_strings(
177    candidates: &[StringMatchCandidate],
178    query: &str,
179    smart_case: bool,
180    max_results: usize,
181    cancel_flag: &AtomicBool,
182    background: Arc<executor::Background>,
183) -> Vec<StringMatch> {
184    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
185    let query = query.chars().collect::<Vec<_>>();
186
187    let lowercase_query = &lowercase_query;
188    let query = &query;
189    let query_char_bag = CharBag::from(&lowercase_query[..]);
190
191    let num_cpus = background.num_cpus().min(candidates.len());
192    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
193    let mut segment_results = (0..num_cpus)
194        .map(|_| Vec::with_capacity(max_results))
195        .collect::<Vec<_>>();
196
197    background
198        .scoped(|scope| {
199            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
200                let cancel_flag = &cancel_flag;
201                scope.spawn(async move {
202                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
203                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
204                    let mut matcher = Matcher::new(
205                        query,
206                        lowercase_query,
207                        query_char_bag,
208                        smart_case,
209                        max_results,
210                    );
211                    matcher.match_strings(
212                        &candidates[segment_start..segment_end],
213                        results,
214                        cancel_flag,
215                    );
216                });
217            }
218        })
219        .await;
220
221    let mut results = Vec::new();
222    for segment_result in segment_results {
223        if results.is_empty() {
224            results = segment_result;
225        } else {
226            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
227        }
228    }
229    results
230}
231
232pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
233    candidate_sets: &'a [Set],
234    query: &str,
235    smart_case: bool,
236    max_results: usize,
237    cancel_flag: &AtomicBool,
238    background: Arc<executor::Background>,
239) -> Vec<PathMatch> {
240    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
241    if path_count == 0 {
242        return Vec::new();
243    }
244
245    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
246    let query = query.chars().collect::<Vec<_>>();
247
248    let lowercase_query = &lowercase_query;
249    let query = &query;
250    let query_char_bag = CharBag::from(&lowercase_query[..]);
251
252    let num_cpus = background.num_cpus().min(path_count);
253    let segment_size = (path_count + num_cpus - 1) / num_cpus;
254    let mut segment_results = (0..num_cpus)
255        .map(|_| Vec::with_capacity(max_results))
256        .collect::<Vec<_>>();
257
258    background
259        .scoped(|scope| {
260            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
261                scope.spawn(async move {
262                    let segment_start = segment_idx * segment_size;
263                    let segment_end = segment_start + segment_size;
264                    let mut matcher = Matcher::new(
265                        query,
266                        lowercase_query,
267                        query_char_bag,
268                        smart_case,
269                        max_results,
270                    );
271
272                    let mut tree_start = 0;
273                    for candidate_set in candidate_sets {
274                        let tree_end = tree_start + candidate_set.len();
275
276                        if tree_start < segment_end && segment_start < tree_end {
277                            let start = cmp::max(tree_start, segment_start) - tree_start;
278                            let end = cmp::min(tree_end, segment_end) - tree_start;
279                            let candidates = candidate_set.candidates(start).take(end - start);
280
281                            matcher.match_paths(
282                                candidate_set.id(),
283                                candidate_set.prefix(),
284                                candidates,
285                                results,
286                                &cancel_flag,
287                            );
288                        }
289                        if tree_end >= segment_end {
290                            break;
291                        }
292                        tree_start = tree_end;
293                    }
294                })
295            }
296        })
297        .await;
298
299    let mut results = Vec::new();
300    for segment_result in segment_results {
301        if results.is_empty() {
302            results = segment_result;
303        } else {
304            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
305        }
306    }
307    results
308}
309
310impl<'a> Matcher<'a> {
311    pub fn new(
312        query: &'a [char],
313        lowercase_query: &'a [char],
314        query_char_bag: CharBag,
315        smart_case: bool,
316        max_results: usize,
317    ) -> Self {
318        Self {
319            query,
320            lowercase_query,
321            query_char_bag,
322            min_score: 0.0,
323            last_positions: vec![0; query.len()],
324            match_positions: vec![0; query.len()],
325            score_matrix: Vec::new(),
326            best_position_matrix: Vec::new(),
327            smart_case,
328            max_results,
329        }
330    }
331
332    pub fn match_strings(
333        &mut self,
334        candidates: &[StringMatchCandidate],
335        results: &mut Vec<StringMatch>,
336        cancel_flag: &AtomicBool,
337    ) {
338        self.match_internal(
339            &[],
340            &[],
341            candidates.iter(),
342            results,
343            cancel_flag,
344            |candidate, score| StringMatch {
345                candidate_id: candidate.id,
346                score,
347                positions: Vec::new(),
348                string: candidate.string.to_string(),
349            },
350        )
351    }
352
353    pub fn match_paths<'c: 'a>(
354        &mut self,
355        tree_id: usize,
356        path_prefix: Arc<str>,
357        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
358        results: &mut Vec<PathMatch>,
359        cancel_flag: &AtomicBool,
360    ) {
361        let prefix = path_prefix.chars().collect::<Vec<_>>();
362        let lowercase_prefix = prefix
363            .iter()
364            .map(|c| c.to_ascii_lowercase())
365            .collect::<Vec<_>>();
366        self.match_internal(
367            &prefix,
368            &lowercase_prefix,
369            path_entries,
370            results,
371            cancel_flag,
372            |candidate, score| PathMatch {
373                score,
374                worktree_id: tree_id,
375                positions: Vec::new(),
376                path: candidate.path.clone(),
377                path_prefix: path_prefix.clone(),
378            },
379        )
380    }
381
382    fn match_internal<C: MatchCandidate, R, F>(
383        &mut self,
384        prefix: &[char],
385        lowercase_prefix: &[char],
386        candidates: impl Iterator<Item = C>,
387        results: &mut Vec<R>,
388        cancel_flag: &AtomicBool,
389        build_match: F,
390    ) where
391        R: Match,
392        F: Fn(&C, f64) -> R,
393    {
394        let mut candidate_chars = Vec::new();
395        let mut lowercase_candidate_chars = Vec::new();
396
397        for candidate in candidates {
398            if !candidate.has_chars(self.query_char_bag) {
399                continue;
400            }
401
402            if cancel_flag.load(atomic::Ordering::Relaxed) {
403                break;
404            }
405
406            candidate_chars.clear();
407            lowercase_candidate_chars.clear();
408            for c in candidate.to_string().chars() {
409                candidate_chars.push(c);
410                lowercase_candidate_chars.push(c.to_ascii_lowercase());
411            }
412
413            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
414                continue;
415            }
416
417            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
418            self.score_matrix.clear();
419            self.score_matrix.resize(matrix_len, None);
420            self.best_position_matrix.clear();
421            self.best_position_matrix.resize(matrix_len, 0);
422
423            let score = self.score_match(
424                &candidate_chars,
425                &lowercase_candidate_chars,
426                &prefix,
427                &lowercase_prefix,
428            );
429
430            if score > 0.0 {
431                let mut mat = build_match(&candidate, score);
432                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
433                    if results.len() < self.max_results {
434                        mat.set_positions(self.match_positions.clone());
435                        results.insert(i, mat);
436                    } else if i < results.len() {
437                        results.pop();
438                        mat.set_positions(self.match_positions.clone());
439                        results.insert(i, mat);
440                    }
441                    if results.len() == self.max_results {
442                        self.min_score = results.last().unwrap().score();
443                    }
444                }
445            }
446        }
447    }
448
449    fn find_last_positions(
450        &mut self,
451        lowercase_prefix: &[char],
452        lowercase_candidate: &[char],
453    ) -> bool {
454        let mut lowercase_prefix = lowercase_prefix.iter();
455        let mut lowercase_candidate = lowercase_candidate.iter();
456        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
457            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
458                self.last_positions[i] = j + lowercase_prefix.len();
459            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
460                self.last_positions[i] = j;
461            } else {
462                return false;
463            }
464        }
465        true
466    }
467
468    fn score_match(
469        &mut self,
470        path: &[char],
471        path_cased: &[char],
472        prefix: &[char],
473        lowercase_prefix: &[char],
474    ) -> f64 {
475        let score = self.recursive_score_match(
476            path,
477            path_cased,
478            prefix,
479            lowercase_prefix,
480            0,
481            0,
482            self.query.len() as f64,
483        ) * self.query.len() as f64;
484
485        if score <= 0.0 {
486            return 0.0;
487        }
488
489        let path_len = prefix.len() + path.len();
490        let mut cur_start = 0;
491        let mut byte_ix = 0;
492        let mut char_ix = 0;
493        for i in 0..self.query.len() {
494            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
495            while char_ix < match_char_ix {
496                let ch = prefix
497                    .get(char_ix)
498                    .or_else(|| path.get(char_ix - prefix.len()))
499                    .unwrap();
500                byte_ix += ch.len_utf8();
501                char_ix += 1;
502            }
503            cur_start = match_char_ix + 1;
504            self.match_positions[i] = byte_ix;
505        }
506
507        score
508    }
509
510    fn recursive_score_match(
511        &mut self,
512        path: &[char],
513        path_cased: &[char],
514        prefix: &[char],
515        lowercase_prefix: &[char],
516        query_idx: usize,
517        path_idx: usize,
518        cur_score: f64,
519    ) -> f64 {
520        if query_idx == self.query.len() {
521            return 1.0;
522        }
523
524        let path_len = prefix.len() + path.len();
525
526        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
527            return memoized;
528        }
529
530        let mut score = 0.0;
531        let mut best_position = 0;
532
533        let query_char = self.lowercase_query[query_idx];
534        let limit = self.last_positions[query_idx];
535
536        let mut last_slash = 0;
537        for j in path_idx..=limit {
538            let path_char = if j < prefix.len() {
539                lowercase_prefix[j]
540            } else {
541                path_cased[j - prefix.len()]
542            };
543            let is_path_sep = path_char == '/' || path_char == '\\';
544
545            if query_idx == 0 && is_path_sep {
546                last_slash = j;
547            }
548
549            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
550                let curr = if j < prefix.len() {
551                    prefix[j]
552                } else {
553                    path[j - prefix.len()]
554                };
555
556                let mut char_score = 1.0;
557                if j > path_idx {
558                    let last = if j - 1 < prefix.len() {
559                        prefix[j - 1]
560                    } else {
561                        path[j - 1 - prefix.len()]
562                    };
563
564                    if last == '/' {
565                        char_score = 0.9;
566                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
567                        char_score = 0.8;
568                    } else if last.is_lowercase() && curr.is_uppercase() {
569                        char_score = 0.8;
570                    } else if last == '.' {
571                        char_score = 0.7;
572                    } else if query_idx == 0 {
573                        char_score = BASE_DISTANCE_PENALTY;
574                    } else {
575                        char_score = MIN_DISTANCE_PENALTY.max(
576                            BASE_DISTANCE_PENALTY
577                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
578                        );
579                    }
580                }
581
582                // Apply a severe penalty if the case doesn't match.
583                // This will make the exact matches have higher score than the case-insensitive and the
584                // path insensitive matches.
585                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
586                    char_score *= 0.001;
587                }
588
589                let mut multiplier = char_score;
590
591                // Scale the score based on how deep within the path we found the match.
592                if query_idx == 0 {
593                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
594                }
595
596                let mut next_score = 1.0;
597                if self.min_score > 0.0 {
598                    next_score = cur_score * multiplier;
599                    // Scores only decrease. If we can't pass the previous best, bail
600                    if next_score < self.min_score {
601                        // Ensure that score is non-zero so we use it in the memo table.
602                        if score == 0.0 {
603                            score = 1e-18;
604                        }
605                        continue;
606                    }
607                }
608
609                let new_score = self.recursive_score_match(
610                    path,
611                    path_cased,
612                    prefix,
613                    lowercase_prefix,
614                    query_idx + 1,
615                    j + 1,
616                    next_score,
617                ) * multiplier;
618
619                if new_score > score {
620                    score = new_score;
621                    best_position = j;
622                    // Optimization: can't score better than 1.
623                    if new_score == 1.0 {
624                        break;
625                    }
626                }
627            }
628        }
629
630        if best_position != 0 {
631            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
632        }
633
634        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
635        score
636    }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // `find_last_positions` walks the query backwards. For each query
    // character it records the last index (prefix first, then candidate)
    // where that character occurs before the slot chosen for the next one.
    #[test]
    fn test_get_last_positions() {
        // 'c' is only available in the prefix (index 2), so 'd' would have to
        // appear in the prefix before it, which it does not.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        // Candidate indices are offset by the prefix length (3), so the 'd'
        // at candidate index 1 is recorded as 4.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Once the candidate has been consumed from the back, the remaining
        // query characters fall back to positions in the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    // Results are listed best-first; each entry pairs the matched path with
    // the byte offsets of its matched characters.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    // Positions are byte offsets, not character indices, so multibyte
    // characters before a match shift them (each keycap emoji below is 7
    // bytes long).
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    // Runs `Matcher::match_paths` over `paths` with an empty prefix and a
    // 100-result budget. Returns each matched path (as the original `&str`)
    // with its match positions, in result order.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each result back to the input string it came from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}