fuzzy.rs

  1mod char_bag;
  2
  3use crate::{
  4    util,
  5    worktree::{EntryKind, Snapshot},
  6};
  7use gpui::executor;
  8use std::{
  9    borrow::Cow,
 10    cmp::{max, min, Ordering},
 11    path::Path,
 12    sync::atomic::{self, AtomicBool},
 13    sync::Arc,
 14};
 15
 16pub use char_bag::CharBag;
 17
/// Score multiplier given to a matched character that does not follow a word boundary.
/// It is used as-is for the first query character and is the starting value that the
/// per-gap penalty is subtracted from for later query characters.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every candidate character skipped
/// (beyond the first) between the search start and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the distance-based multiplier, however many characters were skipped.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 21
/// Scoring state for one query, reused across every candidate it is matched against.
struct Matcher<'a> {
    /// Query characters as typed; compared with candidate characters to decide whether
    /// the case-mismatch penalty applies.
    query: &'a [char],
    /// Lowercased query characters; these are what candidate characters are matched with.
    lowercase_query: &'a [char],
    /// Bag of the lowercased query characters, used to reject candidates cheaply.
    query_char_bag: CharBag,
    /// When true, a matched character whose case differs from the query is penalized.
    smart_case: bool,
    /// Maximum number of matches kept in a results vector.
    max_results: usize,
    /// Score of the worst retained match once the results vector is full (0.0 before
    /// that); partial matches that cannot reach it are abandoned early.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the last index (over prefix + candidate characters) at
    /// which it may still be matched.
    last_positions: Vec<usize>,
    /// Memo table of scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match index for each memo cell, using the same indexing as `score_matrix`.
    best_position_matrix: Vec<usize>,
}
 34
/// A scored result produced by the matcher. The `Ord` bound is what keeps result
/// vectors sorted and lets new matches be inserted by binary search.
trait Match: Ord {
    /// Returns the score assigned to this match.
    fn score(&self) -> f64;
    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
 39
 40trait MatchCandidate {
 41    fn has_chars(&self, bag: CharBag) -> bool;
 42    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 43}
 44
/// A file path offered to the matcher, together with its precomputed character bag.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// Path of the entry; cloned into the resulting `PathMatch` when it matches.
    pub path: &'a Arc<Path>,
    /// Character bag of the entry, checked against the query's bag before scoring.
    pub char_bag: CharBag,
}
 50
/// A path that matched the query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; only matches scoring above zero are produced.
    pub score: f64,
    /// Byte offsets of the matched characters, counted over the path prefix followed
    /// by the path itself.
    pub positions: Vec<usize>,
    /// Id of the snapshot the path belongs to.
    pub tree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Text treated as preceding the path while matching: the root name when the
    /// snapshot root is a file, `"<root name>/"` when several snapshots are searched,
    /// and empty otherwise.
    pub path_prefix: Arc<str>,
}
 59
/// A plain string offered to the matcher, together with its character bag.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Text the query is matched against; copied into the resulting `StringMatch`.
    pub string: String,
    /// Character bag of the string, checked against the query's bag before scoring.
    pub char_bag: CharBag,
}
 65
 66impl Match for PathMatch {
 67    fn score(&self) -> f64 {
 68        self.score
 69    }
 70
 71    fn set_positions(&mut self, positions: Vec<usize>) {
 72        self.positions = positions;
 73    }
 74}
 75
 76impl Match for StringMatch {
 77    fn score(&self) -> f64 {
 78        self.score
 79    }
 80
 81    fn set_positions(&mut self, positions: Vec<usize>) {
 82        self.positions = positions;
 83    }
 84}
 85
 86impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 87    fn has_chars(&self, bag: CharBag) -> bool {
 88        self.char_bag.is_superset(bag)
 89    }
 90
 91    fn to_string(&self) -> Cow<'a, str> {
 92        self.path.to_string_lossy()
 93    }
 94}
 95
 96impl<'a> MatchCandidate for &'a StringMatchCandidate {
 97    fn has_chars(&self, bag: CharBag) -> bool {
 98        self.char_bag.is_superset(bag)
 99    }
100
101    fn to_string(&self) -> Cow<'a, str> {
102        self.string.as_str().into()
103    }
104}
105
/// A string that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Match score; only matches scoring above zero are produced.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate text that matched.
    pub string: String,
}
112
113impl PartialEq for StringMatch {
114    fn eq(&self, other: &Self) -> bool {
115        self.score.eq(&other.score)
116    }
117}
118
// Marker impl: `Ord` has `Eq` as a supertrait, so it is required for `StringMatch: Ord`.
impl Eq for StringMatch {}
120
121impl PartialOrd for StringMatch {
122    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
123        Some(self.cmp(other))
124    }
125}
126
127impl Ord for StringMatch {
128    fn cmp(&self, other: &Self) -> Ordering {
129        self.score
130            .partial_cmp(&other.score)
131            .unwrap_or(Ordering::Equal)
132            .then_with(|| self.string.cmp(&other.string))
133    }
134}
135
136impl PartialEq for PathMatch {
137    fn eq(&self, other: &Self) -> bool {
138        self.score.eq(&other.score)
139    }
140}
141
// Marker impl: `Ord` has `Eq` as a supertrait, so it is required for `PathMatch: Ord`.
impl Eq for PathMatch {}
143
144impl PartialOrd for PathMatch {
145    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
146        Some(self.cmp(other))
147    }
148}
149
150impl Ord for PathMatch {
151    fn cmp(&self, other: &Self) -> Ordering {
152        self.score
153            .partial_cmp(&other.score)
154            .unwrap_or(Ordering::Equal)
155            .then_with(|| self.tree_id.cmp(&other.tree_id))
156            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
157    }
158}
159
160pub async fn match_strings(
161    candidates: &[StringMatchCandidate],
162    query: &str,
163    smart_case: bool,
164    max_results: usize,
165    cancel_flag: &AtomicBool,
166    background: Arc<executor::Background>,
167) -> Vec<StringMatch> {
168    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
169    let query = query.chars().collect::<Vec<_>>();
170
171    let lowercase_query = &lowercase_query;
172    let query = &query;
173    let query_char_bag = CharBag::from(&lowercase_query[..]);
174
175    let num_cpus = background.num_cpus().min(candidates.len());
176    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
177    let mut segment_results = (0..num_cpus)
178        .map(|_| Vec::with_capacity(max_results))
179        .collect::<Vec<_>>();
180
181    background
182        .scoped(|scope| {
183            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
184                let cancel_flag = &cancel_flag;
185                scope.spawn(async move {
186                    let segment_start = segment_idx * segment_size;
187                    let segment_end = segment_start + segment_size;
188                    let mut matcher = Matcher::new(
189                        query,
190                        lowercase_query,
191                        query_char_bag,
192                        smart_case,
193                        max_results,
194                    );
195                    matcher.match_strings(
196                        &candidates[segment_start..segment_end],
197                        results,
198                        cancel_flag,
199                    );
200                });
201            }
202        })
203        .await;
204
205    let mut results = Vec::new();
206    for segment_result in segment_results {
207        if results.is_empty() {
208            results = segment_result;
209        } else {
210            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
211        }
212    }
213    results
214}
215
216pub async fn match_paths(
217    snapshots: &[Snapshot],
218    query: &str,
219    include_ignored: bool,
220    smart_case: bool,
221    max_results: usize,
222    cancel_flag: &AtomicBool,
223    background: Arc<executor::Background>,
224) -> Vec<PathMatch> {
225    let path_count: usize = if include_ignored {
226        snapshots.iter().map(Snapshot::file_count).sum()
227    } else {
228        snapshots.iter().map(Snapshot::visible_file_count).sum()
229    };
230    if path_count == 0 {
231        return Vec::new();
232    }
233
234    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
235    let query = query.chars().collect::<Vec<_>>();
236
237    let lowercase_query = &lowercase_query;
238    let query = &query;
239    let query_char_bag = CharBag::from(&lowercase_query[..]);
240
241    let num_cpus = background.num_cpus().min(path_count);
242    let segment_size = (path_count + num_cpus - 1) / num_cpus;
243    let mut segment_results = (0..num_cpus)
244        .map(|_| Vec::with_capacity(max_results))
245        .collect::<Vec<_>>();
246
247    background
248        .scoped(|scope| {
249            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
250                scope.spawn(async move {
251                    let segment_start = segment_idx * segment_size;
252                    let segment_end = segment_start + segment_size;
253                    let mut matcher = Matcher::new(
254                        query,
255                        lowercase_query,
256                        query_char_bag,
257                        smart_case,
258                        max_results,
259                    );
260
261                    let mut tree_start = 0;
262                    for snapshot in snapshots {
263                        let tree_end = if include_ignored {
264                            tree_start + snapshot.file_count()
265                        } else {
266                            tree_start + snapshot.visible_file_count()
267                        };
268
269                        if tree_start < segment_end && segment_start < tree_end {
270                            let path_prefix: Arc<str> =
271                                if snapshot.root_entry().map_or(false, |e| e.is_file()) {
272                                    snapshot.root_name().into()
273                                } else if snapshots.len() > 1 {
274                                    format!("{}/", snapshot.root_name()).into()
275                                } else {
276                                    "".into()
277                                };
278
279                            let start = max(tree_start, segment_start) - tree_start;
280                            let end = min(tree_end, segment_end) - tree_start;
281                            if include_ignored {
282                                let paths = snapshot.files(start).take(end - start).map(|entry| {
283                                    if let EntryKind::File(char_bag) = entry.kind {
284                                        PathMatchCandidate {
285                                            path: &entry.path,
286                                            char_bag,
287                                        }
288                                    } else {
289                                        unreachable!()
290                                    }
291                                });
292                                matcher.match_paths(
293                                    snapshot.id(),
294                                    path_prefix,
295                                    paths,
296                                    results,
297                                    &cancel_flag,
298                                );
299                            } else {
300                                let paths =
301                                    snapshot
302                                        .visible_files(start)
303                                        .take(end - start)
304                                        .map(|entry| {
305                                            if let EntryKind::File(char_bag) = entry.kind {
306                                                PathMatchCandidate {
307                                                    path: &entry.path,
308                                                    char_bag,
309                                                }
310                                            } else {
311                                                unreachable!()
312                                            }
313                                        });
314                                matcher.match_paths(
315                                    snapshot.id(),
316                                    path_prefix,
317                                    paths,
318                                    results,
319                                    &cancel_flag,
320                                );
321                            };
322                        }
323                        if tree_end >= segment_end {
324                            break;
325                        }
326                        tree_start = tree_end;
327                    }
328                })
329            }
330        })
331        .await;
332
333    let mut results = Vec::new();
334    for segment_result in segment_results {
335        if results.is_empty() {
336            results = segment_result;
337        } else {
338            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
339        }
340    }
341    results
342}
343
344impl<'a> Matcher<'a> {
345    fn new(
346        query: &'a [char],
347        lowercase_query: &'a [char],
348        query_char_bag: CharBag,
349        smart_case: bool,
350        max_results: usize,
351    ) -> Self {
352        Self {
353            query,
354            lowercase_query,
355            query_char_bag,
356            min_score: 0.0,
357            last_positions: vec![0; query.len()],
358            match_positions: vec![0; query.len()],
359            score_matrix: Vec::new(),
360            best_position_matrix: Vec::new(),
361            smart_case,
362            max_results,
363        }
364    }
365
366    fn match_strings(
367        &mut self,
368        candidates: &[StringMatchCandidate],
369        results: &mut Vec<StringMatch>,
370        cancel_flag: &AtomicBool,
371    ) {
372        self.match_internal(
373            &[],
374            &[],
375            candidates.iter(),
376            results,
377            cancel_flag,
378            |candidate, score| StringMatch {
379                score,
380                positions: Vec::new(),
381                string: candidate.string.to_string(),
382            },
383        )
384    }
385
386    fn match_paths(
387        &mut self,
388        tree_id: usize,
389        path_prefix: Arc<str>,
390        path_entries: impl Iterator<Item = PathMatchCandidate<'a>>,
391        results: &mut Vec<PathMatch>,
392        cancel_flag: &AtomicBool,
393    ) {
394        let prefix = path_prefix.chars().collect::<Vec<_>>();
395        let lowercase_prefix = prefix
396            .iter()
397            .map(|c| c.to_ascii_lowercase())
398            .collect::<Vec<_>>();
399        self.match_internal(
400            &prefix,
401            &lowercase_prefix,
402            path_entries,
403            results,
404            cancel_flag,
405            |candidate, score| PathMatch {
406                score,
407                tree_id,
408                positions: Vec::new(),
409                path: candidate.path.clone(),
410                path_prefix: path_prefix.clone(),
411            },
412        )
413    }
414
415    fn match_internal<C: MatchCandidate, R, F>(
416        &mut self,
417        prefix: &[char],
418        lowercase_prefix: &[char],
419        candidates: impl Iterator<Item = C>,
420        results: &mut Vec<R>,
421        cancel_flag: &AtomicBool,
422        build_match: F,
423    ) where
424        R: Match,
425        F: Fn(&C, f64) -> R,
426    {
427        let mut candidate_chars = Vec::new();
428        let mut lowercase_candidate_chars = Vec::new();
429
430        for candidate in candidates {
431            if !candidate.has_chars(self.query_char_bag) {
432                continue;
433            }
434
435            if cancel_flag.load(atomic::Ordering::Relaxed) {
436                break;
437            }
438
439            candidate_chars.clear();
440            lowercase_candidate_chars.clear();
441            for c in candidate.to_string().chars() {
442                candidate_chars.push(c);
443                lowercase_candidate_chars.push(c.to_ascii_lowercase());
444            }
445
446            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
447                continue;
448            }
449
450            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
451            self.score_matrix.clear();
452            self.score_matrix.resize(matrix_len, None);
453            self.best_position_matrix.clear();
454            self.best_position_matrix.resize(matrix_len, 0);
455
456            let score = self.score_match(
457                &candidate_chars,
458                &lowercase_candidate_chars,
459                &prefix,
460                &lowercase_prefix,
461            );
462
463            if score > 0.0 {
464                let mut mat = build_match(&candidate, score);
465                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
466                    if results.len() < self.max_results {
467                        mat.set_positions(self.match_positions.clone());
468                        results.insert(i, mat);
469                    } else if i < results.len() {
470                        results.pop();
471                        mat.set_positions(self.match_positions.clone());
472                        results.insert(i, mat);
473                    }
474                    if results.len() == self.max_results {
475                        self.min_score = results.last().unwrap().score();
476                    }
477                }
478            }
479        }
480    }
481
482    fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
483        let mut path = path.iter();
484        let mut prefix_iter = prefix.iter();
485        for (i, char) in self.query.iter().enumerate().rev() {
486            if let Some(j) = path.rposition(|c| c == char) {
487                self.last_positions[i] = j + prefix.len();
488            } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
489                self.last_positions[i] = j;
490            } else {
491                return false;
492            }
493        }
494        true
495    }
496
497    fn score_match(
498        &mut self,
499        path: &[char],
500        path_cased: &[char],
501        prefix: &[char],
502        lowercase_prefix: &[char],
503    ) -> f64 {
504        let score = self.recursive_score_match(
505            path,
506            path_cased,
507            prefix,
508            lowercase_prefix,
509            0,
510            0,
511            self.query.len() as f64,
512        ) * self.query.len() as f64;
513
514        if score <= 0.0 {
515            return 0.0;
516        }
517
518        let path_len = prefix.len() + path.len();
519        let mut cur_start = 0;
520        let mut byte_ix = 0;
521        let mut char_ix = 0;
522        for i in 0..self.query.len() {
523            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
524            while char_ix < match_char_ix {
525                let ch = prefix
526                    .get(char_ix)
527                    .or_else(|| path.get(char_ix - prefix.len()))
528                    .unwrap();
529                byte_ix += ch.len_utf8();
530                char_ix += 1;
531            }
532            cur_start = match_char_ix + 1;
533            self.match_positions[i] = byte_ix;
534        }
535
536        score
537    }
538
539    fn recursive_score_match(
540        &mut self,
541        path: &[char],
542        path_cased: &[char],
543        prefix: &[char],
544        lowercase_prefix: &[char],
545        query_idx: usize,
546        path_idx: usize,
547        cur_score: f64,
548    ) -> f64 {
549        if query_idx == self.query.len() {
550            return 1.0;
551        }
552
553        let path_len = prefix.len() + path.len();
554
555        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
556            return memoized;
557        }
558
559        let mut score = 0.0;
560        let mut best_position = 0;
561
562        let query_char = self.lowercase_query[query_idx];
563        let limit = self.last_positions[query_idx];
564
565        let mut last_slash = 0;
566        for j in path_idx..=limit {
567            let path_char = if j < prefix.len() {
568                lowercase_prefix[j]
569            } else {
570                path_cased[j - prefix.len()]
571            };
572            let is_path_sep = path_char == '/' || path_char == '\\';
573
574            if query_idx == 0 && is_path_sep {
575                last_slash = j;
576            }
577
578            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
579                let curr = if j < prefix.len() {
580                    prefix[j]
581                } else {
582                    path[j - prefix.len()]
583                };
584
585                let mut char_score = 1.0;
586                if j > path_idx {
587                    let last = if j - 1 < prefix.len() {
588                        prefix[j - 1]
589                    } else {
590                        path[j - 1 - prefix.len()]
591                    };
592
593                    if last == '/' {
594                        char_score = 0.9;
595                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
596                        char_score = 0.8;
597                    } else if last.is_lowercase() && curr.is_uppercase() {
598                        char_score = 0.8;
599                    } else if last == '.' {
600                        char_score = 0.7;
601                    } else if query_idx == 0 {
602                        char_score = BASE_DISTANCE_PENALTY;
603                    } else {
604                        char_score = MIN_DISTANCE_PENALTY.max(
605                            BASE_DISTANCE_PENALTY
606                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
607                        );
608                    }
609                }
610
611                // Apply a severe penalty if the case doesn't match.
612                // This will make the exact matches have higher score than the case-insensitive and the
613                // path insensitive matches.
614                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
615                    char_score *= 0.001;
616                }
617
618                let mut multiplier = char_score;
619
620                // Scale the score based on how deep within the path we found the match.
621                if query_idx == 0 {
622                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
623                }
624
625                let mut next_score = 1.0;
626                if self.min_score > 0.0 {
627                    next_score = cur_score * multiplier;
628                    // Scores only decrease. If we can't pass the previous best, bail
629                    if next_score < self.min_score {
630                        // Ensure that score is non-zero so we use it in the memo table.
631                        if score == 0.0 {
632                            score = 1e-18;
633                        }
634                        continue;
635                    }
636                }
637
638                let new_score = self.recursive_score_match(
639                    path,
640                    path_cased,
641                    prefix,
642                    lowercase_prefix,
643                    query_idx + 1,
644                    j + 1,
645                    next_score,
646                ) * multiplier;
647
648                if new_score > score {
649                    score = new_score;
650                    best_position = j;
651                    // Optimization: can't score better than 1.
652                    if new_score == 1.0 {
653                        break;
654                    }
655                }
656            }
657        }
658
659        if best_position != 0 {
660            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
661        }
662
663        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
664        score
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Checks the right-to-left scan that records the last usable position of each
    /// query character across a prefix followed by a path.
    #[test]
    fn test_get_last_positions() {
        // 'd' only occurs after the last 'c', so the query cannot match in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Checks the ranking and the reported match positions for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Checks that match positions are byte offsets, not character indices, when the
    /// paths contain multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty path prefix and returns each matched
    /// path together with its match positions, in result order.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map every result back to the input string it was built from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}