fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
/// Score multiplier for a match that does not follow a word boundary. It is used as-is for
/// the first query character and as the starting value for later characters, before the
/// per-character distance deduction is applied.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each additional character skipped
/// between the position where the search started and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-based multiplier.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
/// Reusable scoring state for fuzzy-matching one query against many candidates.
pub struct Matcher<'a> {
    /// Query characters exactly as given; compared against the candidate's original
    /// characters when deciding whether to apply the case-mismatch penalty.
    query: &'a [char],
    /// Lowercased query characters; these drive the actual character matching.
    lowercase_query: &'a [char],
    /// Character bag of the query, used to reject candidates before any scoring is done.
    query_char_bag: CharBag,
    /// When set, a query character that differs from the matched character (e.g. in case)
    /// is heavily penalized.
    smart_case: bool,
    /// Maximum number of matches retained in a results vector.
    max_results: usize,
    /// Score of the worst retained match once the results are full (0.0 until then);
    /// partial scores that fall below it are abandoned early.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the last character index (counting the prefix first, then
    /// the candidate) at which it can still be matched.
    last_positions: Vec<usize>,
    /// Memoized scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best matching character index for each memoized cell; indexed like `score_matrix`.
    best_position_matrix: Vec<usize>,
}
 30
/// A scored match that can be kept in a ranked result list.
///
/// The `Ord` supertrait supplies the ordering used when inserting into the results.
trait Match: Ord {
    /// Returns the score assigned to this match.
    fn score(&self) -> f64;
    /// Stores the positions of the matched query characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
 35
/// Something the matcher can score against a query.
trait MatchCandidate {
    /// Pre-filter: returns whether this candidate passes the character-bag check for `bag`.
    /// Candidates that fail it are skipped without being scored.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text that is matched against the query.
    fn to_string<'a>(&'a self) -> Cow<'a, str>;
}
 40
/// A path to be scored, borrowed from its owner, together with its character bag.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The path whose lossy string form is matched against the query.
    pub path: &'a Arc<Path>,
    /// Character bag checked against the query's bag before any scoring is attempted.
    pub char_bag: CharBag,
}
 46
/// A path that matched the query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match quality; matches with higher scores are ranked first.
    pub score: f64,
    /// Byte offsets of the matched characters, counted over the path prefix followed by
    /// the path.
    pub positions: Vec<usize>,
    /// Id of the candidate set this path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set; it is matched ahead of the path itself.
    pub path_prefix: Arc<str>,
}
 55
/// A string to be scored, tagged with a caller-chosen id.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Identifier copied into `StringMatch::candidate_id` for every match of this candidate.
    pub id: usize,
    /// The text matched against the query.
    pub string: String,
    /// Character bag of `string`, checked against the query's bag before scoring.
    pub char_bag: CharBag,
}
 62
/// A collection of path candidates that `match_paths` can split across worker tasks.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    /// Iterator over the candidates of this set.
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier reported as `PathMatch::worktree_id` for matches from this set.
    fn id(&self) -> usize;
    /// Number of candidates in the set; used to size the per-worker segments.
    fn len(&self) -> usize;
    /// Prefix reported as `PathMatch::path_prefix` and matched ahead of every path.
    fn prefix(&self) -> Arc<str>;
    /// Returns the candidates beginning at offset `start`. `match_paths` calls this with the
    /// offset of the first candidate of a segment and then takes only as many items as the
    /// segment covers.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
 70
/// Lets the generic matching loop rank and annotate `PathMatch` values.
impl Match for PathMatch {
    /// Returns the stored score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 80
/// Lets the generic matching loop rank and annotate `StringMatch` values.
impl Match for StringMatch {
    /// Returns the stored score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
 90
/// Path candidates are matched against the lossy string form of their path.
impl<'a> MatchCandidate for PathMatchCandidate<'a> {
    /// Returns whether the candidate's character bag is a superset of `bag`.
    fn has_chars(&self, bag: CharBag) -> bool {
        self.char_bag.is_superset(bag)
    }

    /// Returns the path as text; `to_string_lossy` only allocates when the path is not
    /// valid UTF-8.
    fn to_string(&self) -> Cow<'a, str> {
        self.path.to_string_lossy()
    }
}
100
101impl StringMatchCandidate {
102    pub fn new(id: usize, string: String) -> Self {
103        Self {
104            id,
105            char_bag: CharBag::from(string.as_str()),
106            string,
107        }
108    }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112    fn has_chars(&self, bag: CharBag) -> bool {
113        self.char_bag.is_superset(bag)
114    }
115
116    fn to_string(&self) -> Cow<'a, str> {
117        self.string.as_str().into()
118    }
119}
120
/// A string candidate that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Id of the candidate that produced this match.
    pub candidate_id: usize,
    /// Match quality; matches with higher scores are ranked first.
    pub score: f64,
    /// Positions of the matched query characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate's text.
    pub string: String,
}

/// Equality is defined through the ordering: same score (or incomparable scores) and same
/// candidate id. `positions` and `string` do not take part.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders by score, then by candidate id.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.score.partial_cmp(&other.score) {
            // Scores that tie, or that cannot be compared (NaN), fall back to the id.
            Some(Ordering::Equal) | None => self.candidate_id.cmp(&other.candidate_id),
            Some(by_score) => by_score,
        }
    }
}
151
152impl PartialEq for PathMatch {
153    fn eq(&self, other: &Self) -> bool {
154        self.cmp(other).is_eq()
155    }
156}
157
158impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162        Some(self.cmp(other))
163    }
164}
165
166impl Ord for PathMatch {
167    fn cmp(&self, other: &Self) -> Ordering {
168        self.score
169            .partial_cmp(&other.score)
170            .unwrap_or(Ordering::Equal)
171            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173    }
174}
175
/// Fuzzy-matches `query` against `candidates`, spreading the work over `background`.
///
/// - Returns an empty vector when there are no candidates or `max_results` is zero.
/// - An empty `query` matches everything: every candidate is returned in input order with a
///   score of zero and no positions. Note that this path does not apply `max_results`.
/// - Otherwise the candidates are divided into contiguous segments (at most one per CPU
///   reported by `background`), each segment is matched independently with its own
///   `Matcher`, and the per-segment results are merged with `util::extend_sorted`, limited
///   to `max_results`.
///
/// `cancel_flag` is polled by each worker between candidates; once it is set the workers
/// stop early and whatever was collected so far is returned.
pub async fn match_strings(
    candidates: &[StringMatchCandidate],
    query: &str,
    smart_case: bool,
    max_results: usize,
    cancel_flag: &AtomicBool,
    background: Arc<executor::Background>,
) -> Vec<StringMatch> {
    if candidates.is_empty() || max_results == 0 {
        return Default::default();
    }

    if query.is_empty() {
        return candidates
            .iter()
            .map(|candidate| StringMatch {
                candidate_id: candidate.id,
                score: 0.,
                positions: Default::default(),
                string: candidate.string.clone(),
            })
            .collect();
    }

    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
    let query = query.chars().collect::<Vec<_>>();

    // Shadow the vectors with references so the `async move` blocks below copy the
    // references instead of trying to move the vectors themselves.
    let lowercase_query = &lowercase_query;
    let query = &query;
    let query_char_bag = CharBag::from(&lowercase_query[..]);

    // Never use more segments than there are candidates; the size is rounded up so the
    // segments together cover every candidate.
    let num_cpus = background.num_cpus().min(candidates.len());
    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
    let mut segment_results = (0..num_cpus)
        .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
        .collect::<Vec<_>>();

    background
        .scoped(|scope| {
            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
                let cancel_flag = &cancel_flag;
                scope.spawn(async move {
                    // Clamp both ends: because of the rounding above, trailing segments
                    // may start at or run past the end of `candidates`.
                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
                    let mut matcher = Matcher::new(
                        query,
                        lowercase_query,
                        query_char_bag,
                        smart_case,
                        max_results,
                    );
                    matcher.match_strings(
                        &candidates[segment_start..segment_end],
                        results,
                        cancel_flag,
                    );
                });
            }
        })
        .await;

    // Fold the per-segment results into one list; the first non-empty segment is reused
    // as the accumulator to avoid copying it.
    let mut results = Vec::new();
    for segment_result in segment_results {
        if results.is_empty() {
            results = segment_result;
        } else {
            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
        }
    }
    results
}
247
248pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
249    candidate_sets: &'a [Set],
250    query: &str,
251    smart_case: bool,
252    max_results: usize,
253    cancel_flag: &AtomicBool,
254    background: Arc<executor::Background>,
255) -> Vec<PathMatch> {
256    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
257    if path_count == 0 {
258        return Vec::new();
259    }
260
261    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
262    let query = query.chars().collect::<Vec<_>>();
263
264    let lowercase_query = &lowercase_query;
265    let query = &query;
266    let query_char_bag = CharBag::from(&lowercase_query[..]);
267
268    let num_cpus = background.num_cpus().min(path_count);
269    let segment_size = (path_count + num_cpus - 1) / num_cpus;
270    let mut segment_results = (0..num_cpus)
271        .map(|_| Vec::with_capacity(max_results))
272        .collect::<Vec<_>>();
273
274    background
275        .scoped(|scope| {
276            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
277                scope.spawn(async move {
278                    let segment_start = segment_idx * segment_size;
279                    let segment_end = segment_start + segment_size;
280                    let mut matcher = Matcher::new(
281                        query,
282                        lowercase_query,
283                        query_char_bag,
284                        smart_case,
285                        max_results,
286                    );
287
288                    let mut tree_start = 0;
289                    for candidate_set in candidate_sets {
290                        let tree_end = tree_start + candidate_set.len();
291
292                        if tree_start < segment_end && segment_start < tree_end {
293                            let start = cmp::max(tree_start, segment_start) - tree_start;
294                            let end = cmp::min(tree_end, segment_end) - tree_start;
295                            let candidates = candidate_set.candidates(start).take(end - start);
296
297                            matcher.match_paths(
298                                candidate_set.id(),
299                                candidate_set.prefix(),
300                                candidates,
301                                results,
302                                &cancel_flag,
303                            );
304                        }
305                        if tree_end >= segment_end {
306                            break;
307                        }
308                        tree_start = tree_end;
309                    }
310                })
311            }
312        })
313        .await;
314
315    let mut results = Vec::new();
316    for segment_result in segment_results {
317        if results.is_empty() {
318            results = segment_result;
319        } else {
320            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
321        }
322    }
323    results
324}
325
326impl<'a> Matcher<'a> {
327    pub fn new(
328        query: &'a [char],
329        lowercase_query: &'a [char],
330        query_char_bag: CharBag,
331        smart_case: bool,
332        max_results: usize,
333    ) -> Self {
334        Self {
335            query,
336            lowercase_query,
337            query_char_bag,
338            min_score: 0.0,
339            last_positions: vec![0; query.len()],
340            match_positions: vec![0; query.len()],
341            score_matrix: Vec::new(),
342            best_position_matrix: Vec::new(),
343            smart_case,
344            max_results,
345        }
346    }
347
348    pub fn match_strings(
349        &mut self,
350        candidates: &[StringMatchCandidate],
351        results: &mut Vec<StringMatch>,
352        cancel_flag: &AtomicBool,
353    ) {
354        self.match_internal(
355            &[],
356            &[],
357            candidates.iter(),
358            results,
359            cancel_flag,
360            |candidate, score| StringMatch {
361                candidate_id: candidate.id,
362                score,
363                positions: Vec::new(),
364                string: candidate.string.to_string(),
365            },
366        )
367    }
368
369    pub fn match_paths<'c: 'a>(
370        &mut self,
371        tree_id: usize,
372        path_prefix: Arc<str>,
373        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
374        results: &mut Vec<PathMatch>,
375        cancel_flag: &AtomicBool,
376    ) {
377        let prefix = path_prefix.chars().collect::<Vec<_>>();
378        let lowercase_prefix = prefix
379            .iter()
380            .map(|c| c.to_ascii_lowercase())
381            .collect::<Vec<_>>();
382        self.match_internal(
383            &prefix,
384            &lowercase_prefix,
385            path_entries,
386            results,
387            cancel_flag,
388            |candidate, score| PathMatch {
389                score,
390                worktree_id: tree_id,
391                positions: Vec::new(),
392                path: candidate.path.clone(),
393                path_prefix: path_prefix.clone(),
394            },
395        )
396    }
397
398    fn match_internal<C: MatchCandidate, R, F>(
399        &mut self,
400        prefix: &[char],
401        lowercase_prefix: &[char],
402        candidates: impl Iterator<Item = C>,
403        results: &mut Vec<R>,
404        cancel_flag: &AtomicBool,
405        build_match: F,
406    ) where
407        R: Match,
408        F: Fn(&C, f64) -> R,
409    {
410        let mut candidate_chars = Vec::new();
411        let mut lowercase_candidate_chars = Vec::new();
412
413        for candidate in candidates {
414            if !candidate.has_chars(self.query_char_bag) {
415                continue;
416            }
417
418            if cancel_flag.load(atomic::Ordering::Relaxed) {
419                break;
420            }
421
422            candidate_chars.clear();
423            lowercase_candidate_chars.clear();
424            for c in candidate.to_string().chars() {
425                candidate_chars.push(c);
426                lowercase_candidate_chars.push(c.to_ascii_lowercase());
427            }
428
429            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
430                continue;
431            }
432
433            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
434            self.score_matrix.clear();
435            self.score_matrix.resize(matrix_len, None);
436            self.best_position_matrix.clear();
437            self.best_position_matrix.resize(matrix_len, 0);
438
439            let score = self.score_match(
440                &candidate_chars,
441                &lowercase_candidate_chars,
442                &prefix,
443                &lowercase_prefix,
444            );
445
446            if score > 0.0 {
447                let mut mat = build_match(&candidate, score);
448                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
449                    if results.len() < self.max_results {
450                        mat.set_positions(self.match_positions.clone());
451                        results.insert(i, mat);
452                    } else if i < results.len() {
453                        results.pop();
454                        mat.set_positions(self.match_positions.clone());
455                        results.insert(i, mat);
456                    }
457                    if results.len() == self.max_results {
458                        self.min_score = results.last().unwrap().score();
459                    }
460                }
461            }
462        }
463    }
464
465    fn find_last_positions(
466        &mut self,
467        lowercase_prefix: &[char],
468        lowercase_candidate: &[char],
469    ) -> bool {
470        let mut lowercase_prefix = lowercase_prefix.iter();
471        let mut lowercase_candidate = lowercase_candidate.iter();
472        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
473            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
474                self.last_positions[i] = j + lowercase_prefix.len();
475            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
476                self.last_positions[i] = j;
477            } else {
478                return false;
479            }
480        }
481        true
482    }
483
484    fn score_match(
485        &mut self,
486        path: &[char],
487        path_cased: &[char],
488        prefix: &[char],
489        lowercase_prefix: &[char],
490    ) -> f64 {
491        let score = self.recursive_score_match(
492            path,
493            path_cased,
494            prefix,
495            lowercase_prefix,
496            0,
497            0,
498            self.query.len() as f64,
499        ) * self.query.len() as f64;
500
501        if score <= 0.0 {
502            return 0.0;
503        }
504
505        let path_len = prefix.len() + path.len();
506        let mut cur_start = 0;
507        let mut byte_ix = 0;
508        let mut char_ix = 0;
509        for i in 0..self.query.len() {
510            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
511            while char_ix < match_char_ix {
512                let ch = prefix
513                    .get(char_ix)
514                    .or_else(|| path.get(char_ix - prefix.len()))
515                    .unwrap();
516                byte_ix += ch.len_utf8();
517                char_ix += 1;
518            }
519            cur_start = match_char_ix + 1;
520            self.match_positions[i] = byte_ix;
521        }
522
523        score
524    }
525
526    fn recursive_score_match(
527        &mut self,
528        path: &[char],
529        path_cased: &[char],
530        prefix: &[char],
531        lowercase_prefix: &[char],
532        query_idx: usize,
533        path_idx: usize,
534        cur_score: f64,
535    ) -> f64 {
536        if query_idx == self.query.len() {
537            return 1.0;
538        }
539
540        let path_len = prefix.len() + path.len();
541
542        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
543            return memoized;
544        }
545
546        let mut score = 0.0;
547        let mut best_position = 0;
548
549        let query_char = self.lowercase_query[query_idx];
550        let limit = self.last_positions[query_idx];
551
552        let mut last_slash = 0;
553        for j in path_idx..=limit {
554            let path_char = if j < prefix.len() {
555                lowercase_prefix[j]
556            } else {
557                path_cased[j - prefix.len()]
558            };
559            let is_path_sep = path_char == '/' || path_char == '\\';
560
561            if query_idx == 0 && is_path_sep {
562                last_slash = j;
563            }
564
565            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
566                let curr = if j < prefix.len() {
567                    prefix[j]
568                } else {
569                    path[j - prefix.len()]
570                };
571
572                let mut char_score = 1.0;
573                if j > path_idx {
574                    let last = if j - 1 < prefix.len() {
575                        prefix[j - 1]
576                    } else {
577                        path[j - 1 - prefix.len()]
578                    };
579
580                    if last == '/' {
581                        char_score = 0.9;
582                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
583                        char_score = 0.8;
584                    } else if last.is_lowercase() && curr.is_uppercase() {
585                        char_score = 0.8;
586                    } else if last == '.' {
587                        char_score = 0.7;
588                    } else if query_idx == 0 {
589                        char_score = BASE_DISTANCE_PENALTY;
590                    } else {
591                        char_score = MIN_DISTANCE_PENALTY.max(
592                            BASE_DISTANCE_PENALTY
593                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
594                        );
595                    }
596                }
597
598                // Apply a severe penalty if the case doesn't match.
599                // This will make the exact matches have higher score than the case-insensitive and the
600                // path insensitive matches.
601                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
602                    char_score *= 0.001;
603                }
604
605                let mut multiplier = char_score;
606
607                // Scale the score based on how deep within the path we found the match.
608                if query_idx == 0 {
609                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
610                }
611
612                let mut next_score = 1.0;
613                if self.min_score > 0.0 {
614                    next_score = cur_score * multiplier;
615                    // Scores only decrease. If we can't pass the previous best, bail
616                    if next_score < self.min_score {
617                        // Ensure that score is non-zero so we use it in the memo table.
618                        if score == 0.0 {
619                            score = 1e-18;
620                        }
621                        continue;
622                    }
623                }
624
625                let new_score = self.recursive_score_match(
626                    path,
627                    path_cased,
628                    prefix,
629                    lowercase_prefix,
630                    query_idx + 1,
631                    j + 1,
632                    next_score,
633                ) * multiplier;
634
635                if new_score > score {
636                    score = new_score;
637                    best_position = j;
638                    // Optimization: can't score better than 1.
639                    if new_score == 1.0 {
640                        break;
641                    }
642                }
643            }
644        }
645
646        if best_position != 0 {
647            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
648        }
649
650        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
651        score
652    }
653}
654
655#[cfg(test)]
656mod tests {
657    use super::*;
658    use std::path::PathBuf;
659
    /// Exercises `find_last_positions` with a query that cannot be found in order, one that
    /// spans prefix and candidate, and one with a repeated character.
    #[test]
    fn test_get_last_positions() {
        // 'c' occurs only in the prefix, i.e. before the only 'd', so "dc" cannot match.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        // 'c' is found in the prefix (index 2); 'd' in the candidate (index 1, offset by
        // the prefix length of 3).
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // The second 'z' is taken from the candidate and the first one from the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }
679
    /// Checks which paths match, in which order, and at which positions, for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        // Expected results are listed best match first.
        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // A query containing separators only matches the path that has them.
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }
720
    /// Checks that reported positions are byte offsets, not character indices, when the
    /// paths contain multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        // Sanity check: a single keycap emoji occupies seven bytes.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }
740
741    fn match_query<'a>(
742        query: &str,
743        smart_case: bool,
744        paths: &Vec<&'a str>,
745    ) -> Vec<(&'a str, Vec<usize>)> {
746        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
747        let query = query.chars().collect::<Vec<_>>();
748        let query_chars = CharBag::from(&lowercase_query[..]);
749
750        let path_arcs = paths
751            .iter()
752            .map(|path| Arc::from(PathBuf::from(path)))
753            .collect::<Vec<_>>();
754        let mut path_entries = Vec::new();
755        for (i, path) in paths.iter().enumerate() {
756            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
757            let char_bag = CharBag::from(lowercase_path.as_slice());
758            path_entries.push(PathMatchCandidate {
759                char_bag,
760                path: path_arcs.get(i).unwrap(),
761            });
762        }
763
764        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
765
766        let cancel_flag = AtomicBool::new(false);
767        let mut results = Vec::new();
768        matcher.match_paths(
769            0,
770            "".into(),
771            path_entries.into_iter(),
772            &mut results,
773            &cancel_flag,
774        );
775
776        results
777            .into_iter()
778            .map(|result| {
779                (
780                    paths
781                        .iter()
782                        .copied()
783                        .find(|p| result.path.as_ref() == Path::new(p))
784                        .unwrap(),
785                    result.positions,
786                )
787            })
788            .collect()
789    }
790}