lib.rs

  1mod char_bag;
  2
  3use gpui::executor;
  4use std::{
  5    borrow::Cow,
  6    cmp::{self, Ordering},
  7    path::Path,
  8    sync::atomic::{self, AtomicBool},
  9    sync::Arc,
 10};
 11
 12pub use char_bag::CharBag;
 13
 14const BASE_DISTANCE_PENALTY: f64 = 0.6;
 15const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
 16const MIN_DISTANCE_PENALTY: f64 = 0.2;
 17
 18pub struct Matcher<'a> {
 19    query: &'a [char],
 20    lowercase_query: &'a [char],
 21    query_char_bag: CharBag,
 22    smart_case: bool,
 23    max_results: usize,
 24    min_score: f64,
 25    match_positions: Vec<usize>,
 26    last_positions: Vec<usize>,
 27    score_matrix: Vec<Option<f64>>,
 28    best_position_matrix: Vec<usize>,
 29}
 30
 31trait Match: Ord {
 32    fn score(&self) -> f64;
 33    fn set_positions(&mut self, positions: Vec<usize>);
 34}
 35
 36trait MatchCandidate {
 37    fn has_chars(&self, bag: CharBag) -> bool;
 38    fn to_string<'a>(&'a self) -> Cow<'a, str>;
 39}
 40
 41#[derive(Clone, Debug)]
 42pub struct PathMatchCandidate<'a> {
 43    pub path: &'a Arc<Path>,
 44    pub char_bag: CharBag,
 45}
 46
 47#[derive(Clone, Debug)]
 48pub struct PathMatch {
 49    pub score: f64,
 50    pub positions: Vec<usize>,
 51    pub worktree_id: usize,
 52    pub path: Arc<Path>,
 53    pub path_prefix: Arc<str>,
 54}
 55
 56#[derive(Clone, Debug)]
 57pub struct StringMatchCandidate {
 58    pub string: String,
 59    pub char_bag: CharBag,
 60}
 61
 62pub trait PathMatchCandidateSet<'a>: Send + Sync {
 63    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
 64    fn id(&self) -> usize;
 65    fn len(&self) -> usize;
 66    fn prefix(&self) -> Arc<str>;
 67    fn candidates(&'a self, start: usize) -> Self::Candidates;
 68}
 69
 70impl Match for PathMatch {
 71    fn score(&self) -> f64 {
 72        self.score
 73    }
 74
 75    fn set_positions(&mut self, positions: Vec<usize>) {
 76        self.positions = positions;
 77    }
 78}
 79
 80impl Match for StringMatch {
 81    fn score(&self) -> f64 {
 82        self.score
 83    }
 84
 85    fn set_positions(&mut self, positions: Vec<usize>) {
 86        self.positions = positions;
 87    }
 88}
 89
 90impl<'a> MatchCandidate for PathMatchCandidate<'a> {
 91    fn has_chars(&self, bag: CharBag) -> bool {
 92        self.char_bag.is_superset(bag)
 93    }
 94
 95    fn to_string(&self) -> Cow<'a, str> {
 96        self.path.to_string_lossy()
 97    }
 98}
 99
100impl<'a> MatchCandidate for &'a StringMatchCandidate {
101    fn has_chars(&self, bag: CharBag) -> bool {
102        self.char_bag.is_superset(bag)
103    }
104
105    fn to_string(&self) -> Cow<'a, str> {
106        self.string.as_str().into()
107    }
108}
109
/// A string that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Score assigned by the matcher.
    pub score: f64,
    /// Positions of the matched query characters, as recorded by the matcher.
    pub positions: Vec<usize>,
    /// The matched text.
    pub string: String,
}

impl PartialEq for StringMatch {
    /// Two matches are equal exactly when [`Ord::cmp`] reports them as equal, i.e. when
    /// both the score and the string agree. Defining equality through `cmp` keeps
    /// `PartialEq` consistent with `Ord`, as the `Ord` contract requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StringMatch {
    /// Orders by score first and by string second.
    ///
    /// Scores that cannot be compared (NaN) are treated as equal so that the ordering
    /// stays total.
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
139
140impl PartialEq for PathMatch {
141    fn eq(&self, other: &Self) -> bool {
142        self.score.eq(&other.score)
143    }
144}
145
146impl Eq for PathMatch {}
147
148impl PartialOrd for PathMatch {
149    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
150        Some(self.cmp(other))
151    }
152}
153
154impl Ord for PathMatch {
155    fn cmp(&self, other: &Self) -> Ordering {
156        self.score
157            .partial_cmp(&other.score)
158            .unwrap_or(Ordering::Equal)
159            .then_with(|| self.worktree_id.cmp(&other.worktree_id))
160            .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
161    }
162}
163
164pub async fn match_strings(
165    candidates: &[StringMatchCandidate],
166    query: &str,
167    smart_case: bool,
168    max_results: usize,
169    cancel_flag: &AtomicBool,
170    background: Arc<executor::Background>,
171) -> Vec<StringMatch> {
172    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
173    let query = query.chars().collect::<Vec<_>>();
174
175    let lowercase_query = &lowercase_query;
176    let query = &query;
177    let query_char_bag = CharBag::from(&lowercase_query[..]);
178
179    let num_cpus = background.num_cpus().min(candidates.len());
180    let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
181    let mut segment_results = (0..num_cpus)
182        .map(|_| Vec::with_capacity(max_results))
183        .collect::<Vec<_>>();
184
185    background
186        .scoped(|scope| {
187            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
188                let cancel_flag = &cancel_flag;
189                scope.spawn(async move {
190                    let segment_start = segment_idx * segment_size;
191                    let segment_end = segment_start + segment_size;
192                    let mut matcher = Matcher::new(
193                        query,
194                        lowercase_query,
195                        query_char_bag,
196                        smart_case,
197                        max_results,
198                    );
199                    matcher.match_strings(
200                        &candidates[segment_start..segment_end],
201                        results,
202                        cancel_flag,
203                    );
204                });
205            }
206        })
207        .await;
208
209    let mut results = Vec::new();
210    for segment_result in segment_results {
211        if results.is_empty() {
212            results = segment_result;
213        } else {
214            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
215        }
216    }
217    results
218}
219
220pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
221    candidate_sets: &'a [Set],
222    query: &str,
223    smart_case: bool,
224    max_results: usize,
225    cancel_flag: &AtomicBool,
226    background: Arc<executor::Background>,
227) -> Vec<PathMatch> {
228    let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
229    if path_count == 0 {
230        return Vec::new();
231    }
232
233    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
234    let query = query.chars().collect::<Vec<_>>();
235
236    let lowercase_query = &lowercase_query;
237    let query = &query;
238    let query_char_bag = CharBag::from(&lowercase_query[..]);
239
240    let num_cpus = background.num_cpus().min(path_count);
241    let segment_size = (path_count + num_cpus - 1) / num_cpus;
242    let mut segment_results = (0..num_cpus)
243        .map(|_| Vec::with_capacity(max_results))
244        .collect::<Vec<_>>();
245
246    background
247        .scoped(|scope| {
248            for (segment_idx, results) in segment_results.iter_mut().enumerate() {
249                scope.spawn(async move {
250                    let segment_start = segment_idx * segment_size;
251                    let segment_end = segment_start + segment_size;
252                    let mut matcher = Matcher::new(
253                        query,
254                        lowercase_query,
255                        query_char_bag,
256                        smart_case,
257                        max_results,
258                    );
259
260                    let mut tree_start = 0;
261                    for candidate_set in candidate_sets {
262                        let tree_end = tree_start + candidate_set.len();
263
264                        if tree_start < segment_end && segment_start < tree_end {
265                            let start = cmp::max(tree_start, segment_start) - tree_start;
266                            let end = cmp::min(tree_end, segment_end) - tree_start;
267                            let candidates = candidate_set.candidates(start).take(end - start);
268
269                            matcher.match_paths(
270                                candidate_set.id(),
271                                candidate_set.prefix(),
272                                candidates,
273                                results,
274                                &cancel_flag,
275                            );
276                        }
277                        if tree_end >= segment_end {
278                            break;
279                        }
280                        tree_start = tree_end;
281                    }
282                })
283            }
284        })
285        .await;
286
287    let mut results = Vec::new();
288    for segment_result in segment_results {
289        if results.is_empty() {
290            results = segment_result;
291        } else {
292            util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
293        }
294    }
295    results
296}
297
298impl<'a> Matcher<'a> {
299    pub fn new(
300        query: &'a [char],
301        lowercase_query: &'a [char],
302        query_char_bag: CharBag,
303        smart_case: bool,
304        max_results: usize,
305    ) -> Self {
306        Self {
307            query,
308            lowercase_query,
309            query_char_bag,
310            min_score: 0.0,
311            last_positions: vec![0; query.len()],
312            match_positions: vec![0; query.len()],
313            score_matrix: Vec::new(),
314            best_position_matrix: Vec::new(),
315            smart_case,
316            max_results,
317        }
318    }
319
320    pub fn match_strings(
321        &mut self,
322        candidates: &[StringMatchCandidate],
323        results: &mut Vec<StringMatch>,
324        cancel_flag: &AtomicBool,
325    ) {
326        self.match_internal(
327            &[],
328            &[],
329            candidates.iter(),
330            results,
331            cancel_flag,
332            |candidate, score| StringMatch {
333                score,
334                positions: Vec::new(),
335                string: candidate.string.to_string(),
336            },
337        )
338    }
339
340    pub fn match_paths<'c: 'a>(
341        &mut self,
342        tree_id: usize,
343        path_prefix: Arc<str>,
344        path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
345        results: &mut Vec<PathMatch>,
346        cancel_flag: &AtomicBool,
347    ) {
348        let prefix = path_prefix.chars().collect::<Vec<_>>();
349        let lowercase_prefix = prefix
350            .iter()
351            .map(|c| c.to_ascii_lowercase())
352            .collect::<Vec<_>>();
353        self.match_internal(
354            &prefix,
355            &lowercase_prefix,
356            path_entries,
357            results,
358            cancel_flag,
359            |candidate, score| PathMatch {
360                score,
361                worktree_id: tree_id,
362                positions: Vec::new(),
363                path: candidate.path.clone(),
364                path_prefix: path_prefix.clone(),
365            },
366        )
367    }
368
369    fn match_internal<C: MatchCandidate, R, F>(
370        &mut self,
371        prefix: &[char],
372        lowercase_prefix: &[char],
373        candidates: impl Iterator<Item = C>,
374        results: &mut Vec<R>,
375        cancel_flag: &AtomicBool,
376        build_match: F,
377    ) where
378        R: Match,
379        F: Fn(&C, f64) -> R,
380    {
381        let mut candidate_chars = Vec::new();
382        let mut lowercase_candidate_chars = Vec::new();
383
384        for candidate in candidates {
385            if !candidate.has_chars(self.query_char_bag) {
386                continue;
387            }
388
389            if cancel_flag.load(atomic::Ordering::Relaxed) {
390                break;
391            }
392
393            candidate_chars.clear();
394            lowercase_candidate_chars.clear();
395            for c in candidate.to_string().chars() {
396                candidate_chars.push(c);
397                lowercase_candidate_chars.push(c.to_ascii_lowercase());
398            }
399
400            if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
401                continue;
402            }
403
404            let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
405            self.score_matrix.clear();
406            self.score_matrix.resize(matrix_len, None);
407            self.best_position_matrix.clear();
408            self.best_position_matrix.resize(matrix_len, 0);
409
410            let score = self.score_match(
411                &candidate_chars,
412                &lowercase_candidate_chars,
413                &prefix,
414                &lowercase_prefix,
415            );
416
417            if score > 0.0 {
418                let mut mat = build_match(&candidate, score);
419                if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
420                    if results.len() < self.max_results {
421                        mat.set_positions(self.match_positions.clone());
422                        results.insert(i, mat);
423                    } else if i < results.len() {
424                        results.pop();
425                        mat.set_positions(self.match_positions.clone());
426                        results.insert(i, mat);
427                    }
428                    if results.len() == self.max_results {
429                        self.min_score = results.last().unwrap().score();
430                    }
431                }
432            }
433        }
434    }
435
436    fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
437        let mut path = path.iter();
438        let mut prefix_iter = prefix.iter();
439        for (i, char) in self.query.iter().enumerate().rev() {
440            if let Some(j) = path.rposition(|c| c == char) {
441                self.last_positions[i] = j + prefix.len();
442            } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
443                self.last_positions[i] = j;
444            } else {
445                return false;
446            }
447        }
448        true
449    }
450
451    fn score_match(
452        &mut self,
453        path: &[char],
454        path_cased: &[char],
455        prefix: &[char],
456        lowercase_prefix: &[char],
457    ) -> f64 {
458        let score = self.recursive_score_match(
459            path,
460            path_cased,
461            prefix,
462            lowercase_prefix,
463            0,
464            0,
465            self.query.len() as f64,
466        ) * self.query.len() as f64;
467
468        if score <= 0.0 {
469            return 0.0;
470        }
471
472        let path_len = prefix.len() + path.len();
473        let mut cur_start = 0;
474        let mut byte_ix = 0;
475        let mut char_ix = 0;
476        for i in 0..self.query.len() {
477            let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
478            while char_ix < match_char_ix {
479                let ch = prefix
480                    .get(char_ix)
481                    .or_else(|| path.get(char_ix - prefix.len()))
482                    .unwrap();
483                byte_ix += ch.len_utf8();
484                char_ix += 1;
485            }
486            cur_start = match_char_ix + 1;
487            self.match_positions[i] = byte_ix;
488        }
489
490        score
491    }
492
493    fn recursive_score_match(
494        &mut self,
495        path: &[char],
496        path_cased: &[char],
497        prefix: &[char],
498        lowercase_prefix: &[char],
499        query_idx: usize,
500        path_idx: usize,
501        cur_score: f64,
502    ) -> f64 {
503        if query_idx == self.query.len() {
504            return 1.0;
505        }
506
507        let path_len = prefix.len() + path.len();
508
509        if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
510            return memoized;
511        }
512
513        let mut score = 0.0;
514        let mut best_position = 0;
515
516        let query_char = self.lowercase_query[query_idx];
517        let limit = self.last_positions[query_idx];
518
519        let mut last_slash = 0;
520        for j in path_idx..=limit {
521            let path_char = if j < prefix.len() {
522                lowercase_prefix[j]
523            } else {
524                path_cased[j - prefix.len()]
525            };
526            let is_path_sep = path_char == '/' || path_char == '\\';
527
528            if query_idx == 0 && is_path_sep {
529                last_slash = j;
530            }
531
532            if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
533                let curr = if j < prefix.len() {
534                    prefix[j]
535                } else {
536                    path[j - prefix.len()]
537                };
538
539                let mut char_score = 1.0;
540                if j > path_idx {
541                    let last = if j - 1 < prefix.len() {
542                        prefix[j - 1]
543                    } else {
544                        path[j - 1 - prefix.len()]
545                    };
546
547                    if last == '/' {
548                        char_score = 0.9;
549                    } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
550                        char_score = 0.8;
551                    } else if last.is_lowercase() && curr.is_uppercase() {
552                        char_score = 0.8;
553                    } else if last == '.' {
554                        char_score = 0.7;
555                    } else if query_idx == 0 {
556                        char_score = BASE_DISTANCE_PENALTY;
557                    } else {
558                        char_score = MIN_DISTANCE_PENALTY.max(
559                            BASE_DISTANCE_PENALTY
560                                - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
561                        );
562                    }
563                }
564
565                // Apply a severe penalty if the case doesn't match.
566                // This will make the exact matches have higher score than the case-insensitive and the
567                // path insensitive matches.
568                if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
569                    char_score *= 0.001;
570                }
571
572                let mut multiplier = char_score;
573
574                // Scale the score based on how deep within the path we found the match.
575                if query_idx == 0 {
576                    multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
577                }
578
579                let mut next_score = 1.0;
580                if self.min_score > 0.0 {
581                    next_score = cur_score * multiplier;
582                    // Scores only decrease. If we can't pass the previous best, bail
583                    if next_score < self.min_score {
584                        // Ensure that score is non-zero so we use it in the memo table.
585                        if score == 0.0 {
586                            score = 1e-18;
587                        }
588                        continue;
589                    }
590                }
591
592                let new_score = self.recursive_score_match(
593                    path,
594                    path_cased,
595                    prefix,
596                    lowercase_prefix,
597                    query_idx + 1,
598                    j + 1,
599                    next_score,
600                ) * multiplier;
601
602                if new_score > score {
603                    score = new_score;
604                    best_position = j;
605                    // Optimization: can't score better than 1.
606                    if new_score == 1.0 {
607                        break;
608                    }
609                }
610            }
611        }
612
613        if best_position != 0 {
614            self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
615        }
616
617        self.score_matrix[query_idx * path_len + path_idx] = Some(score);
618        score
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Checks the last-position computation over a prefix followed by a path.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix and 'd' only in the path, so "dc" cannot be
        // matched in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // Reversed, the query matches: 'c' in the prefix, 'd' in the path.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Positions in the path are offset by the prefix length.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Checks ranking and match positions for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Checks that match positions are byte offsets when paths contain multibyte
    /// characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns each matched path
    /// with its match positions, in result order.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map every result back to the input string it was created from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}