fuzzy.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
/// Score multiplier for a match that does not follow a word boundary. It is used as-is for the
/// first query character and as the starting point of the distance-scaled penalty for later ones.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each additional character (beyond the
/// first) skipped since the previous query character's match.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance-scaled penalty.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
/// Fuzzy-matching state for a single query: the query, scoring options, and scratch buffers
/// that are reused from one candidate to the next.
pub struct Matcher<'a> {
    /// Query characters as typed; compared against candidate characters for case penalties.
    query: &'a [char],
    /// Lowercased query characters. Indexed by the same position as `query`, so both slices
    /// must have the same length.
    lowercase_query: &'a [char],
    /// Character bag of the lowercased query, used to cheaply reject candidates.
    query_char_bag: CharBag,
    /// When set, matches whose case differs from the query are heavily penalized.
    smart_case: bool,
    /// Maximum number of results kept in the result list.
    max_results: usize,
    /// Score of the worst kept result once the list is full; scoring prunes below it.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the last position in `prefix + candidate` where it can match.
    last_positions: Vec<usize>,
    /// Memoized scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position for each memoized `(query_idx, path_idx)` state.
    best_position_matrix: Vec<usize>,
}
 30
/// A scored fuzzy-match result that can be ordered and kept in a bounded, sorted result list.
trait Match: Ord {
    /// The match score; higher is better.
    fn score(&self) -> f64;
    /// Stores the byte offsets of the matched characters.
    fn set_positions(&mut self, positions: Vec<usize>);
}
 35
/// Something the `Matcher` can score against a query.
trait MatchCandidate {
    /// Cheap pre-filter: whether the candidate's character bag is a superset of `bag`.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The candidate's text, which is what gets matched and scored.
    fn to_string(&self) -> Cow<'_, str>;
}
 40
/// A path to be matched, together with its precomputed character bag.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The candidate path.
    pub path: &'a Arc<Path>,
    /// Characters present in the path; candidates lacking query characters are skipped.
    pub char_bag: CharBag,
}
 46
/// A successful match of a query against a path.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `path_prefix` followed by the (lossily
    /// converted) `path`.
    pub positions: Vec<usize>,
    /// Id of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix that was matched in front of `path`.
    pub path_prefix: Arc<str>,
}
 55
/// A string to be matched, identified by a caller-chosen id.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Caller-supplied identifier, echoed back as `StringMatch::candidate_id`.
    pub id: usize,
    /// The text to match against.
    pub string: String,
    /// Characters present in `string`, used to skip candidates lacking query characters.
    pub char_bag: CharBag,
}
 62
/// An indexable set of path candidates sharing one id and one prefix, which `match_paths` can
/// slice into segments for parallel matching.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    /// Iterator type returned by `candidates`.
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier copied into `PathMatch::worktree_id` for matches from this set.
    fn id(&self) -> usize;
    /// Number of candidates in the set.
    fn len(&self) -> usize;
    /// Whether the set has no candidates.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Text scored in front of every path of this set; copied into `PathMatch::path_prefix`.
    fn prefix(&self) -> Arc<str>;
    /// Returns the candidates starting at index `start`.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
 73
 74impl Match for PathMatch {
 75    fn score(&self) -> f64 {
 76        self.score
 77    }
 78
 79    fn set_positions(&mut self, positions: Vec<usize>) {
 80        self.positions = positions;
 81    }
 82}
 83
 84impl Match for StringMatch {
 85    fn score(&self) -> f64 {
 86        self.score
 87    }
 88
 89    fn set_positions(&mut self, positions: Vec<usize>) {
 90        self.positions = positions;
 91    }
 92}
 93
 94impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 95    fn has_chars(&self, bag: CharBag) -> bool {
 96        self.char_bag.is_superset(bag)
 97    }
 98
 99    fn to_string(&self) -> Cow<'a, str> {
100        self.path.to_string_lossy()
101    }
102}
103
104impl StringMatchCandidate {
105    pub fn new(id: usize, string: String) -> Self {
106        Self {
107            id,
108            char_bag: CharBag::from(string.as_str()),
109            string,
110        }
111    }
112}
113
114impl<'a> MatchCandidate for &'a StringMatchCandidate {
115    fn has_chars(&self, bag: CharBag) -> bool {
116        self.char_bag.is_superset(bag)
117    }
118
119    fn to_string(&self) -> Cow<'a, str> {
120        self.string.as_str().into()
121    }
122}
123
/// A successful match of a query against a `StringMatchCandidate`.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the candidate that matched.
    pub candidate_id: usize,
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets into `string` of the matched characters.
    pub positions: Vec<usize>,
    /// The candidate's text.
    pub string: String,
}

impl PartialEq for StringMatch {
    /// Two matches are equal when `Ord::cmp` considers them equal (same score and id).
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for StringMatch {
    /// Orders by score, then by candidate id. Incomparable (NaN) scores count as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        let by_score = match self.score.partial_cmp(&other.score) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        };
        by_score.then(self.candidate_id.cmp(&other.candidate_id))
    }
}
154
155impl PartialEq for PathMatch {
156    fn eq(&self, other: &Self) -> bool {
157        self.cmp(other).is_eq()
158    }
159}
160
161impl Eq for PathMatch {}
162
163impl PartialOrd for PathMatch {
164    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
165        Some(self.cmp(other))
166    }
167}
168
169impl Ord for PathMatch {
170    fn cmp(&self, other: &Self) -> Ordering {
171        self.score
172            .partial_cmp(&other.score)
173            .unwrap_or(Ordering::Equal)
174            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
175            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
176    }
177}
178
179pub async fn match_strings(
180    candidates: &[StringMatchCandidate],
181    query: &str,
182    smart_case: bool,
183    max_results: usize,
184    cancel_flag: &AtomicBool,
185    background: Arc<executor::Background>,
186) -> Vec<StringMatch> {
187    if candidates.is_empty() || max_results == 0 {
188        return Default::default();
189    }
190
191    if query.is_empty() {
192        return candidates
193            .iter()
194            .map(|candidate| StringMatch {
195                candidate_id: candidate.id,
196                score: 0.,
197                positions: Default::default(),
198                string: candidate.string.clone(),
199            })
200            .collect();
201    }
202
203    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
204    let query = query.chars().collect::<Vec<_>>();
205
206    let lowercase_query = &lowercase_query;
207    let query = &query;
208    let query_char_bag = CharBag::from(&lowercase_query[..]);
209
210    let num_cpus = background.num_cpus().min(candidates.len());
211    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
212    let mut segment_results = (0..num_cpus)
213        .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
214        .collect::<Vec<_>>();
215
216    background
217        .scoped(|scope| {
218            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
219                let cancel_flag = &cancel_flag;
220                scope.spawn(async move {
221                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
222                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
223                    let mut matcher = Matcher::new(
224                        query,
225                        lowercase_query,
226                        query_char_bag,
227                        smart_case,
228                        max_results,
229                    );
230                    matcher.match_strings(
231                        &candidates[segment_start..segment_end],
232                        results,
233                        cancel_flag,
234                    );
235                });
236            }
237        })
238        .await;
239
240    let mut results = Vec::new();
241    for segment_result in segment_results {
242        if results.is_empty() {
243            results = segment_result;
244        } else {
245            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
246        }
247    }
248    results
249}
250
251pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
252    candidate_sets: &'a [Set],
253    query: &str,
254    smart_case: bool,
255    max_results: usize,
256    cancel_flag: &AtomicBool,
257    background: Arc<executor::Background>,
258) -> Vec<PathMatch> {
259    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
260    if path_count == 0 {
261        return Vec::new();
262    }
263
264    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
265    let query = query.chars().collect::<Vec<_>>();
266
267    let lowercase_query = &lowercase_query;
268    let query = &query;
269    let query_char_bag = CharBag::from(&lowercase_query[..]);
270
271    let num_cpus = background.num_cpus().min(path_count);
272    let segment_size = (path_count + num_cpus - 1) / num_cpus;
273    let mut segment_results = (0..num_cpus)
274        .map(|_| Vec::with_capacity(max_results))
275        .collect::<Vec<_>>();
276
277    background
278        .scoped(|scope| {
279            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
280                scope.spawn(async move {
281                    let segment_start = segment_idx * segment_size;
282                    let segment_end = segment_start + segment_size;
283                    let mut matcher = Matcher::new(
284                        query,
285                        lowercase_query,
286                        query_char_bag,
287                        smart_case,
288                        max_results,
289                    );
290
291                    let mut tree_start = 0;
292                    for candidate_set in candidate_sets {
293                        let tree_end = tree_start + candidate_set.len();
294
295                        if tree_start < segment_end && segment_start < tree_end {
296                            let start = cmp::max(tree_start, segment_start) - tree_start;
297                            let end = cmp::min(tree_end, segment_end) - tree_start;
298                            let candidates = candidate_set.candidates(start).take(end - start);
299
300                            matcher.match_paths(
301                                candidate_set.id(),
302                                candidate_set.prefix(),
303                                candidates,
304                                results,
305                                cancel_flag,
306                            );
307                        }
308                        if tree_end >= segment_end {
309                            break;
310                        }
311                        tree_start = tree_end;
312                    }
313                })
314            }
315        })
316        .await;
317
318    let mut results = Vec::new();
319    for segment_result in segment_results {
320        if results.is_empty() {
321            results = segment_result;
322        } else {
323            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
324        }
325    }
326    results
327}
328
329impl<'a> Matcher<'a> {
330    pub fn new(
331        query: &'a [char],
332        lowercase_query: &'a [char],
333        query_char_bag: CharBag,
334        smart_case: bool,
335        max_results: usize,
336    ) -> Self {
337        Self {
338            query,
339            lowercase_query,
340            query_char_bag,
341            min_score: 0.0,
342            last_positions: vec![0; query.len()],
343            match_positions: vec![0; query.len()],
344            score_matrix: Vec::new(),
345            best_position_matrix: Vec::new(),
346            smart_case,
347            max_results,
348        }
349    }
350
351    pub fn match_strings(
352        &mut self,
353        candidates: &[StringMatchCandidate],
354        results: &mut Vec<StringMatch>,
355        cancel_flag: &AtomicBool,
356    ) {
357        self.match_internal(
358            &[],
359            &[],
360            candidates.iter(),
361            results,
362            cancel_flag,
363            |candidate, score| StringMatch {
364                candidate_id: candidate.id,
365                score,
366                positions: Vec::new(),
367                string: candidate.string.to_string(),
368            },
369        )
370    }
371
372    pub fn match_paths<'c: 'a>(
373        &mut self,
374        tree_id: usize,
375        path_prefix: Arc<str>,
376        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
377        results: &mut Vec<PathMatch>,
378        cancel_flag: &AtomicBool,
379    ) {
380        let prefix = path_prefix.chars().collect::<Vec<_>>();
381        let lowercase_prefix = prefix
382            .iter()
383            .map(|c| c.to_ascii_lowercase())
384            .collect::<Vec<_>>();
385        self.match_internal(
386            &prefix,
387            &lowercase_prefix,
388            path_entries,
389            results,
390            cancel_flag,
391            |candidate, score| PathMatch {
392                score,
393                worktree_id: tree_id,
394                positions: Vec::new(),
395                path: candidate.path.clone(),
396                path_prefix: path_prefix.clone(),
397            },
398        )
399    }
400
401    fn match_internal<C: MatchCandidate, R, F>(
402        &mut self,
403        prefix: &[char],
404        lowercase_prefix: &[char],
405        candidates: impl Iterator<Item = C>,
406        results: &mut Vec<R>,
407        cancel_flag: &AtomicBool,
408        build_match: F,
409    ) where
410        R: Match,
411        F: Fn(&C, f64) -> R,
412    {
413        let mut candidate_chars = Vec::new();
414        let mut lowercase_candidate_chars = Vec::new();
415
416        for candidate in candidates {
417            if !candidate.has_chars(self.query_char_bag) {
418                continue;
419            }
420
421            if cancel_flag.load(atomic::Ordering::Relaxed) {
422                break;
423            }
424
425            candidate_chars.clear();
426            lowercase_candidate_chars.clear();
427            for c in candidate.to_string().chars() {
428                candidate_chars.push(c);
429                lowercase_candidate_chars.push(c.to_ascii_lowercase());
430            }
431
432            if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
433                continue;
434            }
435
436            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
437            self.score_matrix.clear();
438            self.score_matrix.resize(matrix_len, None);
439            self.best_position_matrix.clear();
440            self.best_position_matrix.resize(matrix_len, 0);
441
442            let score = self.score_match(
443                &candidate_chars,
444                &lowercase_candidate_chars,
445                prefix,
446                lowercase_prefix,
447            );
448
449            if score > 0.0 {
450                let mut mat = build_match(&candidate, score);
451                if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
452                    if results.len() < self.max_results {
453                        mat.set_positions(self.match_positions.clone());
454                        results.insert(i, mat);
455                    } else if i < results.len() {
456                        results.pop();
457                        mat.set_positions(self.match_positions.clone());
458                        results.insert(i, mat);
459                    }
460                    if results.len() == self.max_results {
461                        self.min_score = results.last().unwrap().score();
462                    }
463                }
464            }
465        }
466    }
467
468    fn find_last_positions(
469        &mut self,
470        lowercase_prefix: &[char],
471        lowercase_candidate: &[char],
472    ) -> bool {
473        let mut lowercase_prefix = lowercase_prefix.iter();
474        let mut lowercase_candidate = lowercase_candidate.iter();
475        for (i, char) in self.lowercase_query.iter().enumerate().rev() {
476            if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
477                self.last_positions[i] = j + lowercase_prefix.len();
478            } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
479                self.last_positions[i] = j;
480            } else {
481                return false;
482            }
483        }
484        true
485    }
486
487    fn score_match(
488        &mut self,
489        path: &[char],
490        path_cased: &[char],
491        prefix: &[char],
492        lowercase_prefix: &[char],
493    ) -> f64 {
494        let score = self.recursive_score_match(
495            path,
496            path_cased,
497            prefix,
498            lowercase_prefix,
499            0,
500            0,
501            self.query.len() as f64,
502        ) * self.query.len() as f64;
503
504        if score <= 0.0 {
505            return 0.0;
506        }
507
508        let path_len = prefix.len() + path.len();
509        let mut cur_start = 0;
510        let mut byte_ix = 0;
511        let mut char_ix = 0;
512        for i in 0..self.query.len() {
513            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
514            while char_ix < match_char_ix {
515                let ch = prefix
516                    .get(char_ix)
517                    .or_else(|| path.get(char_ix - prefix.len()))
518                    .unwrap();
519                byte_ix += ch.len_utf8();
520                char_ix += 1;
521            }
522            cur_start = match_char_ix + 1;
523            self.match_positions[i] = byte_ix;
524        }
525
526        score
527    }
528
529    #[allow(clippy::too_many_arguments)]
530    fn recursive_score_match(
531        &mut self,
532        path: &[char],
533        path_cased: &[char],
534        prefix: &[char],
535        lowercase_prefix: &[char],
536        query_idx: usize,
537        path_idx: usize,
538        cur_score: f64,
539    ) -> f64 {
540        if query_idx == self.query.len() {
541            return 1.0;
542        }
543
544        let path_len = prefix.len() + path.len();
545
546        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
547            return memoized;
548        }
549
550        let mut score = 0.0;
551        let mut best_position = 0;
552
553        let query_char = self.lowercase_query[query_idx];
554        let limit = self.last_positions[query_idx];
555
556        let mut last_slash = 0;
557        for j in path_idx..=limit {
558            let path_char = if j < prefix.len() {
559                lowercase_prefix[j]
560            } else {
561                path_cased[j - prefix.len()]
562            };
563            let is_path_sep = path_char == '/' || path_char == '\\';
564
565            if query_idx == 0 && is_path_sep {
566                last_slash = j;
567            }
568
569            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
570                let curr = if j < prefix.len() {
571                    prefix[j]
572                } else {
573                    path[j - prefix.len()]
574                };
575
576                let mut char_score = 1.0;
577                if j > path_idx {
578                    let last = if j - 1 < prefix.len() {
579                        prefix[j - 1]
580                    } else {
581                        path[j - 1 - prefix.len()]
582                    };
583
584                    if last == '/' {
585                        char_score = 0.9;
586                    } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
587                        || (last.is_lowercase() && curr.is_uppercase())
588                    {
589                        char_score = 0.8;
590                    } else if last == '.' {
591                        char_score = 0.7;
592                    } else if query_idx == 0 {
593                        char_score = BASE_DISTANCE_PENALTY;
594                    } else {
595                        char_score = MIN_DISTANCE_PENALTY.max(
596                            BASE_DISTANCE_PENALTY
597                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
598                        );
599                    }
600                }
601
602                // Apply a severe penalty if the case doesn't match.
603                // This will make the exact matches have higher score than the case-insensitive and the
604                // path insensitive matches.
605                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
606                    char_score *= 0.001;
607                }
608
609                let mut multiplier = char_score;
610
611                // Scale the score based on how deep within the path we found the match.
612                if query_idx == 0 {
613                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
614                }
615
616                let mut next_score = 1.0;
617                if self.min_score > 0.0 {
618                    next_score = cur_score * multiplier;
619                    // Scores only decrease. If we can't pass the previous best, bail
620                    if next_score < self.min_score {
621                        // Ensure that score is non-zero so we use it in the memo table.
622                        if score == 0.0 {
623                            score = 1e-18;
624                        }
625                        continue;
626                    }
627                }
628
629                let new_score = self.recursive_score_match(
630                    path,
631                    path_cased,
632                    prefix,
633                    lowercase_prefix,
634                    query_idx + 1,
635                    j + 1,
636                    next_score,
637                ) * multiplier;
638
639                if new_score > score {
640                    score = new_score;
641                    best_position = j;
642                    // Optimization: can't score better than 1.
643                    if new_score == 1.0 {
644                        break;
645                    }
646                }
647            }
648        }
649
650        if best_position != 0 {
651            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
652        }
653
654        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
655        score
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // `find_last_positions` records, per query char, the last index in `prefix + candidate`
    // where it can still match in order; candidate indices are offset by the prefix length.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, which precedes the candidate, so 'd' has nowhere to go.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // 'd' is found in the candidate (1 + prefix len 3 = 4); 'c' falls back to the prefix.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    // Expected results are listed best-first; positions are byte offsets of the matched chars.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    // Positions are byte offsets, so multibyte characters (Greek letters, keycap emoji of
    // 7 bytes each) shift them.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    // Runs `Matcher::match_paths` over `paths` with an empty prefix (worktree id 0, at most 100
    // results) and returns each matching path with its byte positions, best first.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each result back to the original `&str` so assertions can compare literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}