fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
/// Score factor for a matched character that was reached by skipping over other
/// characters and is not preceded by a word/path boundary. Used as-is for the first
/// query character and as the starting value for later ones.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every further character skipped
/// between the search start and the matched character (non-first query characters only).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-based score factor, however many characters were skipped.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
/// Scores candidates against a single query and keeps the best `max_results` of them.
///
/// The scratch buffers (`match_positions`, `last_positions`, `score_matrix`,
/// `best_position_matrix`) are reused from one candidate to the next.
pub struct Matcher<'a> {
    /// Query characters exactly as typed; compared against the candidate's original
    /// characters when applying the case-mismatch penalty.
    query: &'a [char],
    /// Lowercased query characters; these are what candidates are matched against.
    lowercase_query: &'a [char],
    /// Character bag of the query, handed to `MatchCandidate::has_chars` as a cheap pre-filter.
    query_char_bag: CharBag,
    /// When set, a character whose case differs from the query is heavily penalized.
    smart_case: bool,
    /// Maximum number of matches kept in a result vector.
    max_results: usize,
    /// Score of the lowest-ranked kept match once the result vector is full (0.0 before
    /// that); used to abandon search branches that cannot reach it.
    min_score: f64,
    /// Byte offset of each matched query character for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the last char index (within prefix + candidate) at which
    /// it may be matched.
    last_positions: Vec<usize>,
    /// Memo table of partial scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best matching char index for each memo cell, indexed like `score_matrix`.
    best_position_matrix: Vec<usize>,
}
 30
/// A match result that the generic matching loop can rank and complete.
///
/// The `Ord` bound supplies the ordering used to keep result vectors sorted.
trait Match: Ord {
    /// Returns the score of this match.
    fn score(&self) -> f64;
    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
 35
/// Something that can be fuzzy-matched against a query.
trait MatchCandidate {
    /// Cheap pre-filter: reports whether this candidate's character bag is a superset of
    /// `bag`. Candidates returning `false` are skipped without being scored.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text that the query is matched against.
    fn to_string<'a>(&'a self) -> Cow<'a, str>;
}
 40
/// A borrowed path that can be matched against a query.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The path to match; it is cloned into the resulting `PathMatch` on success.
    pub path: &'a Arc<Path>,
    /// Character bag used to pre-filter this candidate against the query's bag.
    /// NOTE(review): presumably built from the lowercased path, as the tests do — confirm
    /// against the producers of these candidates.
    pub char_bag: CharBag,
}
 46
/// A path that matched a query, together with its score and match positions.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Fuzzy-match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters, counted over `path_prefix` followed by `path`.
    pub positions: Vec<usize>,
    /// Id of the candidate set (`PathMatchCandidateSet::id`) the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set; it takes part in matching ahead of `path`.
    pub path_prefix: Arc<str>,
}
 55
/// A string that can be matched against a query.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Caller-chosen identifier, copied into `StringMatch::candidate_id`.
    pub id: usize,
    /// The text to match against.
    pub string: String,
    /// Character bag of `string`, used to pre-filter this candidate.
    pub char_bag: CharBag,
}
 62
/// A collection of path candidates that `match_paths` can split across workers.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    /// Iterator type produced by `candidates`.
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier of this set; it ends up in `PathMatch::worktree_id`.
    fn id(&self) -> usize;
    /// Number of candidates in this set.
    fn len(&self) -> usize;
    /// Prefix that is matched ahead of every path in this set.
    fn prefix(&self) -> Arc<str>;
    /// Returns the candidates beginning at index `start`; `match_paths` limits the
    /// iterator with `take` to stay inside its segment.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
 70
// Lets the generic matching loop rank and complete `PathMatch` values.
impl Match for PathMatch {
    /// Returns the stored fuzzy-match score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions with `positions`.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 80
// Lets the generic matching loop rank and complete `StringMatch` values.
impl Match for StringMatch {
    /// Returns the stored fuzzy-match score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions with `positions`.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 90
impl<'a> MatchCandidate for PathMatchCandidate<'a> {
    /// Reports whether this candidate's character bag is a superset of `bag`.
    fn has_chars(&self, bag: CharBag) -> bool {
        self.char_bag.is_superset(bag)
    }

    /// Returns the path as text; invalid Unicode is replaced lossily, which is the only
    /// case in which the returned `Cow` is owned.
    fn to_string(&self) -> Cow<'a, str> {
        self.path.to_string_lossy()
    }
}
100
101impl StringMatchCandidate {
102    pub fn new(id: usize, string: String) -> Self {
103        Self {
104            id,
105            char_bag: CharBag::from(string.as_str()),
106            string,
107        }
108    }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112    fn has_chars(&self, bag: CharBag) -> bool {
113        self.char_bag.is_superset(bag)
114    }
115
116    fn to_string(&self) -> Cow<'a, str> {
117        self.string.as_str().into()
118    }
119}
120
/// A string that matched a query, together with its score and match positions.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// `StringMatchCandidate::id` of the candidate that produced this match.
    pub candidate_id: usize,
    /// Fuzzy-match score; higher is better.
    pub score: f64,
    /// Byte offsets into `string` of the matched characters.
    pub positions: Vec<usize>,
    /// Copy of the matched candidate's text.
    pub string: String,
}
128
129impl PartialEq for StringMatch {
130    fn eq(&self, other: &Self) -> bool {
131        self.cmp(other).is_eq()
132    }
133}
134
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl.
impl Eq for StringMatch {}
136
137impl PartialOrd for StringMatch {
138    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
139        Some(self.cmp(other))
140    }
141}
142
143impl Ord for StringMatch {
144    fn cmp(&self, other: &Self) -> Ordering {
145        self.score
146            .partial_cmp(&other.score)
147            .unwrap_or(Ordering::Equal)
148            .then_with(|| self.candidate_id.cmp(&other.candidate_id))
149    }
150}
151
152impl PartialEq for PathMatch {
153    fn eq(&self, other: &Self) -> bool {
154        self.cmp(other).is_eq()
155    }
156}
157
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl.
impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162        Some(self.cmp(other))
163    }
164}
165
166impl Ord for PathMatch {
167    fn cmp(&self, other: &Self) -> Ordering {
168        self.score
169            .partial_cmp(&other.score)
170            .unwrap_or(Ordering::Equal)
171            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173    }
174}
175
176pub async fn match_strings(
177    candidates: &[StringMatchCandidate],
178    query: &str,
179    smart_case: bool,
180    max_results: usize,
181    cancel_flag: &AtomicBool,
182    background: Arc<executor::Background>,
183) -> Vec<StringMatch> {
184    if candidates.is_empty() {
185        return Default::default();
186    }
187
188    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
189    let query = query.chars().collect::<Vec<_>>();
190
191    let lowercase_query = &lowercase_query;
192    let query = &query;
193    let query_char_bag = CharBag::from(&lowercase_query[..]);
194
195    let num_cpus = background.num_cpus().min(candidates.len());
196    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
197    let mut segment_results = (0..num_cpus)
198        .map(|_| Vec::with_capacity(max_results))
199        .collect::<Vec<_>>();
200
201    background
202        .scoped(|scope| {
203            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
204                let cancel_flag = &cancel_flag;
205                scope.spawn(async move {
206                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
207                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
208                    let mut matcher = Matcher::new(
209                        query,
210                        lowercase_query,
211                        query_char_bag,
212                        smart_case,
213                        max_results,
214                    );
215                    matcher.match_strings(
216                        &candidates[segment_start..segment_end],
217                        results,
218                        cancel_flag,
219                    );
220                });
221            }
222        })
223        .await;
224
225    let mut results = Vec::new();
226    for segment_result in segment_results {
227        if results.is_empty() {
228            results = segment_result;
229        } else {
230            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
231        }
232    }
233    results
234}
235
236pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
237    candidate_sets: &'a [Set],
238    query: &str,
239    smart_case: bool,
240    max_results: usize,
241    cancel_flag: &AtomicBool,
242    background: Arc<executor::Background>,
243) -> Vec<PathMatch> {
244    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
245    if path_count == 0 {
246        return Vec::new();
247    }
248
249    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
250    let query = query.chars().collect::<Vec<_>>();
251
252    let lowercase_query = &lowercase_query;
253    let query = &query;
254    let query_char_bag = CharBag::from(&lowercase_query[..]);
255
256    let num_cpus = background.num_cpus().min(path_count);
257    let segment_size = (path_count + num_cpus - 1) / num_cpus;
258    let mut segment_results = (0..num_cpus)
259        .map(|_| Vec::with_capacity(max_results))
260        .collect::<Vec<_>>();
261
262    background
263        .scoped(|scope| {
264            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
265                scope.spawn(async move {
266                    let segment_start = segment_idx * segment_size;
267                    let segment_end = segment_start + segment_size;
268                    let mut matcher = Matcher::new(
269                        query,
270                        lowercase_query,
271                        query_char_bag,
272                        smart_case,
273                        max_results,
274                    );
275
276                    let mut tree_start = 0;
277                    for candidate_set in candidate_sets {
278                        let tree_end = tree_start + candidate_set.len();
279
280                        if tree_start < segment_end && segment_start < tree_end {
281                            let start = cmp::max(tree_start, segment_start) - tree_start;
282                            let end = cmp::min(tree_end, segment_end) - tree_start;
283                            let candidates = candidate_set.candidates(start).take(end - start);
284
285                            matcher.match_paths(
286                                candidate_set.id(),
287                                candidate_set.prefix(),
288                                candidates,
289                                results,
290                                &cancel_flag,
291                            );
292                        }
293                        if tree_end >= segment_end {
294                            break;
295                        }
296                        tree_start = tree_end;
297                    }
298                })
299            }
300        })
301        .await;
302
303    let mut results = Vec::new();
304    for segment_result in segment_results {
305        if results.is_empty() {
306            results = segment_result;
307        } else {
308            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
309        }
310    }
311    results
312}
313
314impl<'a> Matcher<'a> {
315    pub fn new(
316        query: &'a [char],
317        lowercase_query: &'a [char],
318        query_char_bag: CharBag,
319        smart_case: bool,
320        max_results: usize,
321    ) -> Self {
322        Self {
323            query,
324            lowercase_query,
325            query_char_bag,
326            min_score: 0.0,
327            last_positions: vec![0; query.len()],
328            match_positions: vec![0; query.len()],
329            score_matrix: Vec::new(),
330            best_position_matrix: Vec::new(),
331            smart_case,
332            max_results,
333        }
334    }
335
336    pub fn match_strings(
337        &mut self,
338        candidates: &[StringMatchCandidate],
339        results: &mut Vec<StringMatch>,
340        cancel_flag: &AtomicBool,
341    ) {
342        self.match_internal(
343            &[],
344            &[],
345            candidates.iter(),
346            results,
347            cancel_flag,
348            |candidate, score| StringMatch {
349                candidate_id: candidate.id,
350                score,
351                positions: Vec::new(),
352                string: candidate.string.to_string(),
353            },
354        )
355    }
356
357    pub fn match_paths<'c: 'a>(
358        &mut self,
359        tree_id: usize,
360        path_prefix: Arc<str>,
361        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
362        results: &mut Vec<PathMatch>,
363        cancel_flag: &AtomicBool,
364    ) {
365        let prefix = path_prefix.chars().collect::<Vec<_>>();
366        let lowercase_prefix = prefix
367            .iter()
368            .map(|c| c.to_ascii_lowercase())
369            .collect::<Vec<_>>();
370        self.match_internal(
371            &prefix,
372            &lowercase_prefix,
373            path_entries,
374            results,
375            cancel_flag,
376            |candidate, score| PathMatch {
377                score,
378                worktree_id: tree_id,
379                positions: Vec::new(),
380                path: candidate.path.clone(),
381                path_prefix: path_prefix.clone(),
382            },
383        )
384    }
385
386    fn match_internal<C: MatchCandidate, R, F>(
387        &mut self,
388        prefix: &[char],
389        lowercase_prefix: &[char],
390        candidates: impl Iterator<Item = C>,
391        results: &mut Vec<R>,
392        cancel_flag: &AtomicBool,
393        build_match: F,
394    ) where
395        R: Match,
396        F: Fn(&C, f64) -> R,
397    {
398        let mut candidate_chars = Vec::new();
399        let mut lowercase_candidate_chars = Vec::new();
400
401        for candidate in candidates {
402            if !candidate.has_chars(self.query_char_bag) {
403                continue;
404            }
405
406            if cancel_flag.load(atomic::Ordering::Relaxed) {
407                break;
408            }
409
410            candidate_chars.clear();
411            lowercase_candidate_chars.clear();
412            for c in candidate.to_string().chars() {
413                candidate_chars.push(c);
414                lowercase_candidate_chars.push(c.to_ascii_lowercase());
415            }
416
417            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
418                continue;
419            }
420
421            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
422            self.score_matrix.clear();
423            self.score_matrix.resize(matrix_len, None);
424            self.best_position_matrix.clear();
425            self.best_position_matrix.resize(matrix_len, 0);
426
427            let score = self.score_match(
428                &candidate_chars,
429                &lowercase_candidate_chars,
430                &prefix,
431                &lowercase_prefix,
432            );
433
434            if score > 0.0 {
435                let mut mat = build_match(&candidate, score);
436                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
437                    if results.len() < self.max_results {
438                        mat.set_positions(self.match_positions.clone());
439                        results.insert(i, mat);
440                    } else if i < results.len() {
441                        results.pop();
442                        mat.set_positions(self.match_positions.clone());
443                        results.insert(i, mat);
444                    }
445                    if results.len() == self.max_results {
446                        self.min_score = results.last().unwrap().score();
447                    }
448                }
449            }
450        }
451    }
452
453    fn find_last_positions(
454        &mut self,
455        lowercase_prefix: &[char],
456        lowercase_candidate: &[char],
457    ) -> bool {
458        let mut lowercase_prefix = lowercase_prefix.iter();
459        let mut lowercase_candidate = lowercase_candidate.iter();
460        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
461            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
462                self.last_positions[i] = j + lowercase_prefix.len();
463            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
464                self.last_positions[i] = j;
465            } else {
466                return false;
467            }
468        }
469        true
470    }
471
472    fn score_match(
473        &mut self,
474        path: &[char],
475        path_cased: &[char],
476        prefix: &[char],
477        lowercase_prefix: &[char],
478    ) -> f64 {
479        let score = self.recursive_score_match(
480            path,
481            path_cased,
482            prefix,
483            lowercase_prefix,
484            0,
485            0,
486            self.query.len() as f64,
487        ) * self.query.len() as f64;
488
489        if score <= 0.0 {
490            return 0.0;
491        }
492
493        let path_len = prefix.len() + path.len();
494        let mut cur_start = 0;
495        let mut byte_ix = 0;
496        let mut char_ix = 0;
497        for i in 0..self.query.len() {
498            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
499            while char_ix < match_char_ix {
500                let ch = prefix
501                    .get(char_ix)
502                    .or_else(|| path.get(char_ix - prefix.len()))
503                    .unwrap();
504                byte_ix += ch.len_utf8();
505                char_ix += 1;
506            }
507            cur_start = match_char_ix + 1;
508            self.match_positions[i] = byte_ix;
509        }
510
511        score
512    }
513
514    fn recursive_score_match(
515        &mut self,
516        path: &[char],
517        path_cased: &[char],
518        prefix: &[char],
519        lowercase_prefix: &[char],
520        query_idx: usize,
521        path_idx: usize,
522        cur_score: f64,
523    ) -> f64 {
524        if query_idx == self.query.len() {
525            return 1.0;
526        }
527
528        let path_len = prefix.len() + path.len();
529
530        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
531            return memoized;
532        }
533
534        let mut score = 0.0;
535        let mut best_position = 0;
536
537        let query_char = self.lowercase_query[query_idx];
538        let limit = self.last_positions[query_idx];
539
540        let mut last_slash = 0;
541        for j in path_idx..=limit {
542            let path_char = if j < prefix.len() {
543                lowercase_prefix[j]
544            } else {
545                path_cased[j - prefix.len()]
546            };
547            let is_path_sep = path_char == '/' || path_char == '\\';
548
549            if query_idx == 0 && is_path_sep {
550                last_slash = j;
551            }
552
553            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
554                let curr = if j < prefix.len() {
555                    prefix[j]
556                } else {
557                    path[j - prefix.len()]
558                };
559
560                let mut char_score = 1.0;
561                if j > path_idx {
562                    let last = if j - 1 < prefix.len() {
563                        prefix[j - 1]
564                    } else {
565                        path[j - 1 - prefix.len()]
566                    };
567
568                    if last == '/' {
569                        char_score = 0.9;
570                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
571                        char_score = 0.8;
572                    } else if last.is_lowercase() && curr.is_uppercase() {
573                        char_score = 0.8;
574                    } else if last == '.' {
575                        char_score = 0.7;
576                    } else if query_idx == 0 {
577                        char_score = BASE_DISTANCE_PENALTY;
578                    } else {
579                        char_score = MIN_DISTANCE_PENALTY.max(
580                            BASE_DISTANCE_PENALTY
581                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
582                        );
583                    }
584                }
585
586                // Apply a severe penalty if the case doesn't match.
587                // This will make the exact matches have higher score than the case-insensitive and the
588                // path insensitive matches.
589                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
590                    char_score *= 0.001;
591                }
592
593                let mut multiplier = char_score;
594
595                // Scale the score based on how deep within the path we found the match.
596                if query_idx == 0 {
597                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
598                }
599
600                let mut next_score = 1.0;
601                if self.min_score > 0.0 {
602                    next_score = cur_score * multiplier;
603                    // Scores only decrease. If we can't pass the previous best, bail
604                    if next_score < self.min_score {
605                        // Ensure that score is non-zero so we use it in the memo table.
606                        if score == 0.0 {
607                            score = 1e-18;
608                        }
609                        continue;
610                    }
611                }
612
613                let new_score = self.recursive_score_match(
614                    path,
615                    path_cased,
616                    prefix,
617                    lowercase_prefix,
618                    query_idx + 1,
619                    j + 1,
620                    next_score,
621                ) * multiplier;
622
623                if new_score > score {
624                    score = new_score;
625                    best_position = j;
626                    // Optimization: can't score better than 1.
627                    if new_score == 1.0 {
628                        break;
629                    }
630                }
631            }
632        }
633
634        if best_position != 0 {
635            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
636        }
637
638        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
639        score
640    }
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// `find_last_positions` must report whether the query occurs in order within
    /// prefix + candidate, and record the last usable char index of each query character.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs before 'd', so "dc" cannot be matched in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // 'c' is found in the prefix (index 2), 'd' in the candidate (3 + 1).
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Checks ranking and match positions for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Match positions are byte offsets, so multi-byte characters before a match must
    /// shift the reported positions accordingly.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        // Each keycap emoji is 7 bytes long, which the expected positions below rely on.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns, best match first,
    /// each matched path together with its match positions.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        // The candidates borrow their paths, so the `Arc`s have to outlive them.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, path_arc)| {
                let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
                PathMatchCandidate {
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: path_arc,
                }
            })
            .collect::<Vec<_>>();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map every result back to the `&str` it was created from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}