fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
 14const BASE_DISTANCE_PENALTY: f64 = 0.6;
 15const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 16const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
 18pub struct Matcher<'a> {
 19    query: &'a [char],
 20    lowercase_query: &'a [char],
 21    query_char_bag: CharBag,
 22    smart_case: bool,
 23    max_results: usize,
 24    min_score: f64,
 25    match_positions: Vec<usize>,
 26    last_positions: Vec<usize>,
 27    score_matrix: Vec<Option<f64>>,
 28    best_position_matrix: Vec<usize>,
 29}
 30
 31trait Match: Ord {
 32    fn score(&self) -> f64;
 33    fn set_positions(&mut self, positions: Vec<usize>);
 34}
 35
 36trait MatchCandidate {
 37    fn has_chars(&self, bag: CharBag) -> bool;
 38    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 39}
 40
/// A path to be fuzzy-matched, as yielded by a `PathMatchCandidateSet`.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The candidate path; its lossy string form is what the query is matched against.
    pub path: &'a Arc<Path>,
    /// Characters present in the path, used to skip candidates that cannot match.
    /// NOTE(review): the tests build this from the lowercased path; presumably callers do too.
    pub char_bag: CharBag,
}
 46
/// A path that matched the query, as produced by `match_paths`.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match quality; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters in the string formed by `path_prefix`
    /// followed by the (lossily converted) `path`.
    pub positions: Vec<usize>,
    /// `id()` of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Text that was scored as though it preceded `path`.
    pub path_prefix: Arc<str>,
}
 55
/// A string to be fuzzy-matched by `match_strings`.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// The text the query is matched against.
    pub string: String,
    /// Characters present in `string`, used to skip candidates that cannot match.
    pub char_bag: CharBag,
}
 61
 62pub trait PathMatchCandidateSet<'a>: Send + Sync {
 63    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
 64    fn id(&self) -> usize;
 65    fn len(&self) -> usize;
 66    fn prefix(&self) -> Arc<str>;
 67    fn candidates(&'a self, start: usize) -> Self::Candidates;
 68}
 69
 70impl Match for PathMatch {
 71    fn score(&self) -> f64 {
 72        self.score
 73    }
 74
 75    fn set_positions(&mut self, positions: Vec<usize>) {
 76        self.positions = positions;
 77    }
 78}
 79
 80impl Match for StringMatch {
 81    fn score(&self) -> f64 {
 82        self.score
 83    }
 84
 85    fn set_positions(&mut self, positions: Vec<usize>) {
 86        self.positions = positions;
 87    }
 88}
 89
 90impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 91    fn has_chars(&self, bag: CharBag) -> bool {
 92        self.char_bag.is_superset(bag)
 93    }
 94
 95    fn to_string(&self) -> Cow<'a, str> {
 96        self.path.to_string_lossy()
 97    }
 98}
 99
100impl<'a> MatchCandidate for &'a StringMatchCandidate {
101    fn has_chars(&self, bag: CharBag) -> bool {
102        self.char_bag.is_superset(bag)
103    }
104
105    fn to_string(&self) -> Cow<'a, str> {
106        self.string.as_str().into()
107    }
108}
109
/// A string candidate that matched the query, as produced by `match_strings`.
///
/// Matches order by `score` (higher is greater), with `string` as a tie-breaker.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Index of the matched candidate in the slice given to `match_strings`.
    pub candidate_index: usize,
    /// Match quality; higher is better.
    pub score: f64,
    /// Byte offsets into `string` of the characters that matched the query.
    pub positions: Vec<usize>,
    /// Copy of the matched candidate's text.
    pub string: String,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(&b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require. Comparing only the
// scores would make two different strings with equal scores `==` while `cmp` still orders
// them apart.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            // Incomparable scores (NaN) fall through to the tie-breaker.
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
140
141impl PartialEq for PathMatch {
142    fn eq(&self, other: &Self) -> bool {
143        self.score.eq(&other.score)
144    }
145}
146
147impl Eq for PathMatch {}
148
149impl PartialOrd for PathMatch {
150    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
151        Some(self.cmp(other))
152    }
153}
154
155impl Ord for PathMatch {
156    fn cmp(&self, other: &Self) -> Ordering {
157        self.score
158            .partial_cmp(&other.score)
159            .unwrap_or(Ordering::Equal)
160            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
161            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
162    }
163}
164
165pub async fn match_strings(
166    candidates: &[StringMatchCandidate],
167    query: &str,
168    smart_case: bool,
169    max_results: usize,
170    cancel_flag: &AtomicBool,
171    background: Arc<executor::Background>,
172) -> Vec<StringMatch> {
173    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
174    let query = query.chars().collect::<Vec<_>>();
175
176    let lowercase_query = &lowercase_query;
177    let query = &query;
178    let query_char_bag = CharBag::from(&lowercase_query[..]);
179
180    let num_cpus = background.num_cpus().min(candidates.len());
181    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
182    let mut segment_results = (0..num_cpus)
183        .map(|_| Vec::with_capacity(max_results))
184        .collect::<Vec<_>>();
185
186    background
187        .scoped(|scope| {
188            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
189                let cancel_flag = &cancel_flag;
190                scope.spawn(async move {
191                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
192                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
193                    let mut matcher = Matcher::new(
194                        query,
195                        lowercase_query,
196                        query_char_bag,
197                        smart_case,
198                        max_results,
199                    );
200                    matcher.match_strings(
201                        segment_start,
202                        &candidates[segment_start..segment_end],
203                        results,
204                        cancel_flag,
205                    );
206                });
207            }
208        })
209        .await;
210
211    let mut results = Vec::new();
212    for segment_result in segment_results {
213        if results.is_empty() {
214            results = segment_result;
215        } else {
216            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
217        }
218    }
219    results
220}
221
222pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
223    candidate_sets: &'a [Set],
224    query: &str,
225    smart_case: bool,
226    max_results: usize,
227    cancel_flag: &AtomicBool,
228    background: Arc<executor::Background>,
229) -> Vec<PathMatch> {
230    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
231    if path_count == 0 {
232        return Vec::new();
233    }
234
235    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
236    let query = query.chars().collect::<Vec<_>>();
237
238    let lowercase_query = &lowercase_query;
239    let query = &query;
240    let query_char_bag = CharBag::from(&lowercase_query[..]);
241
242    let num_cpus = background.num_cpus().min(path_count);
243    let segment_size = (path_count + num_cpus - 1) / num_cpus;
244    let mut segment_results = (0..num_cpus)
245        .map(|_| Vec::with_capacity(max_results))
246        .collect::<Vec<_>>();
247
248    background
249        .scoped(|scope| {
250            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
251                scope.spawn(async move {
252                    let segment_start = segment_idx * segment_size;
253                    let segment_end = segment_start + segment_size;
254                    let mut matcher = Matcher::new(
255                        query,
256                        lowercase_query,
257                        query_char_bag,
258                        smart_case,
259                        max_results,
260                    );
261
262                    let mut tree_start = 0;
263                    for candidate_set in candidate_sets {
264                        let tree_end = tree_start + candidate_set.len();
265
266                        if tree_start < segment_end && segment_start < tree_end {
267                            let start = cmp::max(tree_start, segment_start) - tree_start;
268                            let end = cmp::min(tree_end, segment_end) - tree_start;
269                            let candidates = candidate_set.candidates(start).take(end - start);
270
271                            matcher.match_paths(
272                                candidate_set.id(),
273                                candidate_set.prefix(),
274                                candidates,
275                                results,
276                                &cancel_flag,
277                            );
278                        }
279                        if tree_end >= segment_end {
280                            break;
281                        }
282                        tree_start = tree_end;
283                    }
284                })
285            }
286        })
287        .await;
288
289    let mut results = Vec::new();
290    for segment_result in segment_results {
291        if results.is_empty() {
292            results = segment_result;
293        } else {
294            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
295        }
296    }
297    results
298}
299
300impl<'a> Matcher<'a> {
301    pub fn new(
302        query: &'a [char],
303        lowercase_query: &'a [char],
304        query_char_bag: CharBag,
305        smart_case: bool,
306        max_results: usize,
307    ) -> Self {
308        Self {
309            query,
310            lowercase_query,
311            query_char_bag,
312            min_score: 0.0,
313            last_positions: vec![0; query.len()],
314            match_positions: vec![0; query.len()],
315            score_matrix: Vec::new(),
316            best_position_matrix: Vec::new(),
317            smart_case,
318            max_results,
319        }
320    }
321
322    pub fn match_strings(
323        &mut self,
324        start_index: usize,
325        candidates: &[StringMatchCandidate],
326        results: &mut Vec<StringMatch>,
327        cancel_flag: &AtomicBool,
328    ) {
329        self.match_internal(
330            &[],
331            &[],
332            start_index,
333            candidates.iter(),
334            results,
335            cancel_flag,
336            |candidate_index, candidate, score| StringMatch {
337                candidate_index,
338                score,
339                positions: Vec::new(),
340                string: candidate.string.to_string(),
341            },
342        )
343    }
344
345    pub fn match_paths<'c: 'a>(
346        &mut self,
347        tree_id: usize,
348        path_prefix: Arc<str>,
349        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
350        results: &mut Vec<PathMatch>,
351        cancel_flag: &AtomicBool,
352    ) {
353        let prefix = path_prefix.chars().collect::<Vec<_>>();
354        let lowercase_prefix = prefix
355            .iter()
356            .map(|c| c.to_ascii_lowercase())
357            .collect::<Vec<_>>();
358        self.match_internal(
359            &prefix,
360            &lowercase_prefix,
361            0,
362            path_entries,
363            results,
364            cancel_flag,
365            |_, candidate, score| PathMatch {
366                score,
367                worktree_id: tree_id,
368                positions: Vec::new(),
369                path: candidate.path.clone(),
370                path_prefix: path_prefix.clone(),
371            },
372        )
373    }
374
375    fn match_internal<C: MatchCandidate, R, F>(
376        &mut self,
377        prefix: &[char],
378        lowercase_prefix: &[char],
379        start_index: usize,
380        candidates: impl Iterator<Item = C>,
381        results: &mut Vec<R>,
382        cancel_flag: &AtomicBool,
383        build_match: F,
384    ) where
385        R: Match,
386        F: Fn(usize, &C, f64) -> R,
387    {
388        let mut candidate_chars = Vec::new();
389        let mut lowercase_candidate_chars = Vec::new();
390
391        for (candidate_ix, candidate) in candidates.enumerate() {
392            if !candidate.has_chars(self.query_char_bag) {
393                continue;
394            }
395
396            if cancel_flag.load(atomic::Ordering::Relaxed) {
397                break;
398            }
399
400            candidate_chars.clear();
401            lowercase_candidate_chars.clear();
402            for c in candidate.to_string().chars() {
403                candidate_chars.push(c);
404                lowercase_candidate_chars.push(c.to_ascii_lowercase());
405            }
406
407            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
408                continue;
409            }
410
411            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
412            self.score_matrix.clear();
413            self.score_matrix.resize(matrix_len, None);
414            self.best_position_matrix.clear();
415            self.best_position_matrix.resize(matrix_len, 0);
416
417            let score = self.score_match(
418                &candidate_chars,
419                &lowercase_candidate_chars,
420                &prefix,
421                &lowercase_prefix,
422            );
423
424            if score > 0.0 {
425                let mut mat = build_match(start_index + candidate_ix, &candidate, score);
426                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
427                    if results.len() < self.max_results {
428                        mat.set_positions(self.match_positions.clone());
429                        results.insert(i, mat);
430                    } else if i < results.len() {
431                        results.pop();
432                        mat.set_positions(self.match_positions.clone());
433                        results.insert(i, mat);
434                    }
435                    if results.len() == self.max_results {
436                        self.min_score = results.last().unwrap().score();
437                    }
438                }
439            }
440        }
441    }
442
443    fn find_last_positions(
444        &mut self,
445        lowercase_prefix: &[char],
446        lowercase_candidate: &[char],
447    ) -> bool {
448        let mut lowercase_prefix = lowercase_prefix.iter();
449        let mut lowercase_candidate = lowercase_candidate.iter();
450        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
451            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
452                self.last_positions[i] = j + lowercase_prefix.len();
453            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
454                self.last_positions[i] = j;
455            } else {
456                return false;
457            }
458        }
459        true
460    }
461
462    fn score_match(
463        &mut self,
464        path: &[char],
465        path_cased: &[char],
466        prefix: &[char],
467        lowercase_prefix: &[char],
468    ) -> f64 {
469        let score = self.recursive_score_match(
470            path,
471            path_cased,
472            prefix,
473            lowercase_prefix,
474            0,
475            0,
476            self.query.len() as f64,
477        ) * self.query.len() as f64;
478
479        if score <= 0.0 {
480            return 0.0;
481        }
482
483        let path_len = prefix.len() + path.len();
484        let mut cur_start = 0;
485        let mut byte_ix = 0;
486        let mut char_ix = 0;
487        for i in 0..self.query.len() {
488            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
489            while char_ix < match_char_ix {
490                let ch = prefix
491                    .get(char_ix)
492                    .or_else(|| path.get(char_ix - prefix.len()))
493                    .unwrap();
494                byte_ix += ch.len_utf8();
495                char_ix += 1;
496            }
497            cur_start = match_char_ix + 1;
498            self.match_positions[i] = byte_ix;
499        }
500
501        score
502    }
503
504    fn recursive_score_match(
505        &mut self,
506        path: &[char],
507        path_cased: &[char],
508        prefix: &[char],
509        lowercase_prefix: &[char],
510        query_idx: usize,
511        path_idx: usize,
512        cur_score: f64,
513    ) -> f64 {
514        if query_idx == self.query.len() {
515            return 1.0;
516        }
517
518        let path_len = prefix.len() + path.len();
519
520        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
521            return memoized;
522        }
523
524        let mut score = 0.0;
525        let mut best_position = 0;
526
527        let query_char = self.lowercase_query[query_idx];
528        let limit = self.last_positions[query_idx];
529
530        let mut last_slash = 0;
531        for j in path_idx..=limit {
532            let path_char = if j < prefix.len() {
533                lowercase_prefix[j]
534            } else {
535                path_cased[j - prefix.len()]
536            };
537            let is_path_sep = path_char == '/' || path_char == '\\';
538
539            if query_idx == 0 && is_path_sep {
540                last_slash = j;
541            }
542
543            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
544                let curr = if j < prefix.len() {
545                    prefix[j]
546                } else {
547                    path[j - prefix.len()]
548                };
549
550                let mut char_score = 1.0;
551                if j > path_idx {
552                    let last = if j - 1 < prefix.len() {
553                        prefix[j - 1]
554                    } else {
555                        path[j - 1 - prefix.len()]
556                    };
557
558                    if last == '/' {
559                        char_score = 0.9;
560                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
561                        char_score = 0.8;
562                    } else if last.is_lowercase() && curr.is_uppercase() {
563                        char_score = 0.8;
564                    } else if last == '.' {
565                        char_score = 0.7;
566                    } else if query_idx == 0 {
567                        char_score = BASE_DISTANCE_PENALTY;
568                    } else {
569                        char_score = MIN_DISTANCE_PENALTY.max(
570                            BASE_DISTANCE_PENALTY
571                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
572                        );
573                    }
574                }
575
576                // Apply a severe penalty if the case doesn't match.
577                // This will make the exact matches have higher score than the case-insensitive and the
578                // path insensitive matches.
579                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
580                    char_score *= 0.001;
581                }
582
583                let mut multiplier = char_score;
584
585                // Scale the score based on how deep within the path we found the match.
586                if query_idx == 0 {
587                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
588                }
589
590                let mut next_score = 1.0;
591                if self.min_score > 0.0 {
592                    next_score = cur_score * multiplier;
593                    // Scores only decrease. If we can't pass the previous best, bail
594                    if next_score < self.min_score {
595                        // Ensure that score is non-zero so we use it in the memo table.
596                        if score == 0.0 {
597                            score = 1e-18;
598                        }
599                        continue;
600                    }
601                }
602
603                let new_score = self.recursive_score_match(
604                    path,
605                    path_cased,
606                    prefix,
607                    lowercase_prefix,
608                    query_idx + 1,
609                    j + 1,
610                    next_score,
611                ) * multiplier;
612
613                if new_score > score {
614                    score = new_score;
615                    best_position = j;
616                    // Optimization: can't score better than 1.
617                    if new_score == 1.0 {
618                        break;
619                    }
620                }
621            }
622        }
623
624        if best_position != 0 {
625            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
626        }
627
628        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
629        score
630    }
631}
632
/// Unit tests for the fuzzy matcher: subsequence pre-check and end-to-end path scoring.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// `find_last_positions` rejects out-of-order queries and, for accepted ones, records
    /// the latest feasible position of each query char (prefix first, then candidate offset
    /// by the prefix length).
    #[test]
    fn test_get_last_positions() {
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// End-to-end ranking and byte positions for ASCII paths (results are best first).
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Positions are byte offsets, so multi-byte characters before a match shift them.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with a single `Matcher` (tree id 0, empty prefix, at most
    /// 100 results). Returns each match's original path string with its positions, in result
    /// order (best first).
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each result back to the original `&str` it came from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}