1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Score given to a matched character that is not at a word boundary, before
/// any distance-based reduction is applied.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each additional character
/// skipped between the search start and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance-based character score.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
18pub struct Matcher<'a> {
19 query: &'a [char],
20 lowercase_query: &'a [char],
21 query_char_bag: CharBag,
22 smart_case: bool,
23 max_results: usize,
24 min_score: f64,
25 match_positions: Vec<usize>,
26 last_positions: Vec<usize>,
27 score_matrix: Vec<Option<f64>>,
28 best_position_matrix: Vec<usize>,
29}
30
/// A scored match that can be kept in an ordered result list.
trait Match: Ord {
    /// Returns the score of this match.
    fn score(&self) -> f64;

    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
36trait MatchCandidate {
37 fn has_chars(&self, bag: CharBag) -> bool;
38 fn to_string<'a>(&'a self) -> Cow<'a, str>;
39}
40
41#[derive(Clone, Debug)]
42pub struct PathMatchCandidate<'a> {
43 pub path: &'a Arc<Path>,
44 pub char_bag: CharBag,
45}
46
/// A path that matched a query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within the prefix followed by
    /// the path.
    pub positions: Vec<usize>,
    /// Identifier of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set the path came from; it takes part in
    /// matching ahead of the path itself.
    pub path_prefix: Arc<str>,
}
55
56#[derive(Clone, Debug)]
57pub struct StringMatchCandidate {
58 pub id: usize,
59 pub string: String,
60 pub char_bag: CharBag,
61}
62
63pub trait PathMatchCandidateSet<'a>: Send + Sync {
64 type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
65 fn id(&self) -> usize;
66 fn len(&self) -> usize;
67 fn prefix(&self) -> Arc<str>;
68 fn candidates(&'a self, start: usize) -> Self::Candidates;
69}
70
71impl Match for PathMatch {
72 fn score(&self) -> f64 {
73 self.score
74 }
75
76 fn set_positions(&mut self, positions: Vec<usize>) {
77 self.positions = positions;
78 }
79}
80
81impl Match for StringMatch {
82 fn score(&self) -> f64 {
83 self.score
84 }
85
86 fn set_positions(&mut self, positions: Vec<usize>) {
87 self.positions = positions;
88 }
89}
90
91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
92 fn has_chars(&self, bag: CharBag) -> bool {
93 self.char_bag.is_superset(bag)
94 }
95
96 fn to_string(&self) -> Cow<'a, str> {
97 self.path.to_string_lossy()
98 }
99}
100
101impl StringMatchCandidate {
102 pub fn new(id: usize, string: String) -> Self {
103 Self {
104 id,
105 char_bag: CharBag::from(string.as_str()),
106 string,
107 }
108 }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112 fn has_chars(&self, bag: CharBag) -> bool {
113 self.char_bag.is_superset(bag)
114 }
115
116 fn to_string(&self) -> Cow<'a, str> {
117 self.string.as_str().into()
118 }
119}
120
/// A string that matched a query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Identifier of the candidate that produced this match.
    pub candidate_id: usize,
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate's text.
    pub string: String,
}

impl Ord for StringMatch {
    /// Orders by score first and by candidate id when the scores tie.
    ///
    /// Scores that cannot be compared (NaN) are treated as tied.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.score.partial_cmp(&other.score) {
            Some(Ordering::Equal) | None => self.candidate_id.cmp(&other.candidate_id),
            Some(by_score) => by_score,
        }
    }
}

impl PartialOrd for StringMatch {
    /// Always `Some`, consistent with the total order from `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for StringMatch {
    /// Two matches are equal when `cmp` considers them equal; `positions`
    /// and `string` are not compared.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}
151
152impl PartialEq for PathMatch {
153 fn eq(&self, other: &Self) -> bool {
154 self.cmp(other).is_eq()
155 }
156}
157
158impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162 Some(self.cmp(other))
163 }
164}
165
166impl Ord for PathMatch {
167 fn cmp(&self, other: &Self) -> Ordering {
168 self.score
169 .partial_cmp(&other.score)
170 .unwrap_or(Ordering::Equal)
171 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173 }
174}
175
176pub async fn match_strings(
177 candidates: &[StringMatchCandidate],
178 query: &str,
179 smart_case: bool,
180 max_results: usize,
181 cancel_flag: &AtomicBool,
182 background: Arc<executor::Background>,
183) -> Vec<StringMatch> {
184 if candidates.is_empty() {
185 return Default::default();
186 }
187
188 if query.is_empty() {
189 return candidates
190 .iter()
191 .map(|candidate| StringMatch {
192 candidate_id: candidate.id,
193 score: 0.,
194 positions: Default::default(),
195 string: candidate.string.clone(),
196 })
197 .collect();
198 }
199
200 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
201 let query = query.chars().collect::<Vec<_>>();
202
203 let lowercase_query = &lowercase_query;
204 let query = &query;
205 let query_char_bag = CharBag::from(&lowercase_query[..]);
206
207 let num_cpus = background.num_cpus().min(candidates.len());
208 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
209 let mut segment_results = (0..num_cpus)
210 .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
211 .collect::<Vec<_>>();
212
213 background
214 .scoped(|scope| {
215 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
216 let cancel_flag = &cancel_flag;
217 scope.spawn(async move {
218 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
219 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
220 let mut matcher = Matcher::new(
221 query,
222 lowercase_query,
223 query_char_bag,
224 smart_case,
225 max_results,
226 );
227 matcher.match_strings(
228 &candidates[segment_start..segment_end],
229 results,
230 cancel_flag,
231 );
232 });
233 }
234 })
235 .await;
236
237 let mut results = Vec::new();
238 for segment_result in segment_results {
239 if results.is_empty() {
240 results = segment_result;
241 } else {
242 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
243 }
244 }
245 results
246}
247
248pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
249 candidate_sets: &'a [Set],
250 query: &str,
251 smart_case: bool,
252 max_results: usize,
253 cancel_flag: &AtomicBool,
254 background: Arc<executor::Background>,
255) -> Vec<PathMatch> {
256 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
257 if path_count == 0 {
258 return Vec::new();
259 }
260
261 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
262 let query = query.chars().collect::<Vec<_>>();
263
264 let lowercase_query = &lowercase_query;
265 let query = &query;
266 let query_char_bag = CharBag::from(&lowercase_query[..]);
267
268 let num_cpus = background.num_cpus().min(path_count);
269 let segment_size = (path_count + num_cpus - 1) / num_cpus;
270 let mut segment_results = (0..num_cpus)
271 .map(|_| Vec::with_capacity(max_results))
272 .collect::<Vec<_>>();
273
274 background
275 .scoped(|scope| {
276 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
277 scope.spawn(async move {
278 let segment_start = segment_idx * segment_size;
279 let segment_end = segment_start + segment_size;
280 let mut matcher = Matcher::new(
281 query,
282 lowercase_query,
283 query_char_bag,
284 smart_case,
285 max_results,
286 );
287
288 let mut tree_start = 0;
289 for candidate_set in candidate_sets {
290 let tree_end = tree_start + candidate_set.len();
291
292 if tree_start < segment_end && segment_start < tree_end {
293 let start = cmp::max(tree_start, segment_start) - tree_start;
294 let end = cmp::min(tree_end, segment_end) - tree_start;
295 let candidates = candidate_set.candidates(start).take(end - start);
296
297 matcher.match_paths(
298 candidate_set.id(),
299 candidate_set.prefix(),
300 candidates,
301 results,
302 &cancel_flag,
303 );
304 }
305 if tree_end >= segment_end {
306 break;
307 }
308 tree_start = tree_end;
309 }
310 })
311 }
312 })
313 .await;
314
315 let mut results = Vec::new();
316 for segment_result in segment_results {
317 if results.is_empty() {
318 results = segment_result;
319 } else {
320 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
321 }
322 }
323 results
324}
325
326impl<'a> Matcher<'a> {
327 pub fn new(
328 query: &'a [char],
329 lowercase_query: &'a [char],
330 query_char_bag: CharBag,
331 smart_case: bool,
332 max_results: usize,
333 ) -> Self {
334 Self {
335 query,
336 lowercase_query,
337 query_char_bag,
338 min_score: 0.0,
339 last_positions: vec![0; query.len()],
340 match_positions: vec![0; query.len()],
341 score_matrix: Vec::new(),
342 best_position_matrix: Vec::new(),
343 smart_case,
344 max_results,
345 }
346 }
347
348 pub fn match_strings(
349 &mut self,
350 candidates: &[StringMatchCandidate],
351 results: &mut Vec<StringMatch>,
352 cancel_flag: &AtomicBool,
353 ) {
354 self.match_internal(
355 &[],
356 &[],
357 candidates.iter(),
358 results,
359 cancel_flag,
360 |candidate, score| StringMatch {
361 candidate_id: candidate.id,
362 score,
363 positions: Vec::new(),
364 string: candidate.string.to_string(),
365 },
366 )
367 }
368
369 pub fn match_paths<'c: 'a>(
370 &mut self,
371 tree_id: usize,
372 path_prefix: Arc<str>,
373 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
374 results: &mut Vec<PathMatch>,
375 cancel_flag: &AtomicBool,
376 ) {
377 let prefix = path_prefix.chars().collect::<Vec<_>>();
378 let lowercase_prefix = prefix
379 .iter()
380 .map(|c| c.to_ascii_lowercase())
381 .collect::<Vec<_>>();
382 self.match_internal(
383 &prefix,
384 &lowercase_prefix,
385 path_entries,
386 results,
387 cancel_flag,
388 |candidate, score| PathMatch {
389 score,
390 worktree_id: tree_id,
391 positions: Vec::new(),
392 path: candidate.path.clone(),
393 path_prefix: path_prefix.clone(),
394 },
395 )
396 }
397
398 fn match_internal<C: MatchCandidate, R, F>(
399 &mut self,
400 prefix: &[char],
401 lowercase_prefix: &[char],
402 candidates: impl Iterator<Item = C>,
403 results: &mut Vec<R>,
404 cancel_flag: &AtomicBool,
405 build_match: F,
406 ) where
407 R: Match,
408 F: Fn(&C, f64) -> R,
409 {
410 let mut candidate_chars = Vec::new();
411 let mut lowercase_candidate_chars = Vec::new();
412
413 for candidate in candidates {
414 if !candidate.has_chars(self.query_char_bag) {
415 continue;
416 }
417
418 if cancel_flag.load(atomic::Ordering::Relaxed) {
419 break;
420 }
421
422 candidate_chars.clear();
423 lowercase_candidate_chars.clear();
424 for c in candidate.to_string().chars() {
425 candidate_chars.push(c);
426 lowercase_candidate_chars.push(c.to_ascii_lowercase());
427 }
428
429 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
430 continue;
431 }
432
433 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
434 self.score_matrix.clear();
435 self.score_matrix.resize(matrix_len, None);
436 self.best_position_matrix.clear();
437 self.best_position_matrix.resize(matrix_len, 0);
438
439 let score = self.score_match(
440 &candidate_chars,
441 &lowercase_candidate_chars,
442 &prefix,
443 &lowercase_prefix,
444 );
445
446 if score > 0.0 {
447 let mut mat = build_match(&candidate, score);
448 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
449 if results.len() < self.max_results {
450 mat.set_positions(self.match_positions.clone());
451 results.insert(i, mat);
452 } else if i < results.len() {
453 results.pop();
454 mat.set_positions(self.match_positions.clone());
455 results.insert(i, mat);
456 }
457 if results.len() == self.max_results {
458 self.min_score = results.last().unwrap().score();
459 }
460 }
461 }
462 }
463 }
464
465 fn find_last_positions(
466 &mut self,
467 lowercase_prefix: &[char],
468 lowercase_candidate: &[char],
469 ) -> bool {
470 let mut lowercase_prefix = lowercase_prefix.iter();
471 let mut lowercase_candidate = lowercase_candidate.iter();
472 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
473 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
474 self.last_positions[i] = j + lowercase_prefix.len();
475 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
476 self.last_positions[i] = j;
477 } else {
478 return false;
479 }
480 }
481 true
482 }
483
484 fn score_match(
485 &mut self,
486 path: &[char],
487 path_cased: &[char],
488 prefix: &[char],
489 lowercase_prefix: &[char],
490 ) -> f64 {
491 let score = self.recursive_score_match(
492 path,
493 path_cased,
494 prefix,
495 lowercase_prefix,
496 0,
497 0,
498 self.query.len() as f64,
499 ) * self.query.len() as f64;
500
501 if score <= 0.0 {
502 return 0.0;
503 }
504
505 let path_len = prefix.len() + path.len();
506 let mut cur_start = 0;
507 let mut byte_ix = 0;
508 let mut char_ix = 0;
509 for i in 0..self.query.len() {
510 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
511 while char_ix < match_char_ix {
512 let ch = prefix
513 .get(char_ix)
514 .or_else(|| path.get(char_ix - prefix.len()))
515 .unwrap();
516 byte_ix += ch.len_utf8();
517 char_ix += 1;
518 }
519 cur_start = match_char_ix + 1;
520 self.match_positions[i] = byte_ix;
521 }
522
523 score
524 }
525
526 fn recursive_score_match(
527 &mut self,
528 path: &[char],
529 path_cased: &[char],
530 prefix: &[char],
531 lowercase_prefix: &[char],
532 query_idx: usize,
533 path_idx: usize,
534 cur_score: f64,
535 ) -> f64 {
536 if query_idx == self.query.len() {
537 return 1.0;
538 }
539
540 let path_len = prefix.len() + path.len();
541
542 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
543 return memoized;
544 }
545
546 let mut score = 0.0;
547 let mut best_position = 0;
548
549 let query_char = self.lowercase_query[query_idx];
550 let limit = self.last_positions[query_idx];
551
552 let mut last_slash = 0;
553 for j in path_idx..=limit {
554 let path_char = if j < prefix.len() {
555 lowercase_prefix[j]
556 } else {
557 path_cased[j - prefix.len()]
558 };
559 let is_path_sep = path_char == '/' || path_char == '\\';
560
561 if query_idx == 0 && is_path_sep {
562 last_slash = j;
563 }
564
565 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
566 let curr = if j < prefix.len() {
567 prefix[j]
568 } else {
569 path[j - prefix.len()]
570 };
571
572 let mut char_score = 1.0;
573 if j > path_idx {
574 let last = if j - 1 < prefix.len() {
575 prefix[j - 1]
576 } else {
577 path[j - 1 - prefix.len()]
578 };
579
580 if last == '/' {
581 char_score = 0.9;
582 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
583 char_score = 0.8;
584 } else if last.is_lowercase() && curr.is_uppercase() {
585 char_score = 0.8;
586 } else if last == '.' {
587 char_score = 0.7;
588 } else if query_idx == 0 {
589 char_score = BASE_DISTANCE_PENALTY;
590 } else {
591 char_score = MIN_DISTANCE_PENALTY.max(
592 BASE_DISTANCE_PENALTY
593 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
594 );
595 }
596 }
597
598 // Apply a severe penalty if the case doesn't match.
599 // This will make the exact matches have higher score than the case-insensitive and the
600 // path insensitive matches.
601 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
602 char_score *= 0.001;
603 }
604
605 let mut multiplier = char_score;
606
607 // Scale the score based on how deep within the path we found the match.
608 if query_idx == 0 {
609 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
610 }
611
612 let mut next_score = 1.0;
613 if self.min_score > 0.0 {
614 next_score = cur_score * multiplier;
615 // Scores only decrease. If we can't pass the previous best, bail
616 if next_score < self.min_score {
617 // Ensure that score is non-zero so we use it in the memo table.
618 if score == 0.0 {
619 score = 1e-18;
620 }
621 continue;
622 }
623 }
624
625 let new_score = self.recursive_score_match(
626 path,
627 path_cased,
628 prefix,
629 lowercase_prefix,
630 query_idx + 1,
631 j + 1,
632 next_score,
633 ) * multiplier;
634
635 if new_score > score {
636 score = new_score;
637 best_position = j;
638 // Optimization: can't score better than 1.
639 if new_score == 1.0 {
640 break;
641 }
642 }
643 }
644 }
645
646 if best_position != 0 {
647 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
648 }
649
650 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
651 score
652 }
653}
654
655#[cfg(test)]
656mod tests {
657 use super::*;
658 use std::path::PathBuf;
659
/// Checks `find_last_positions` for an unmatchable query, a query split
/// between prefix and candidate, and a query with a repeated character.
#[test]
fn test_get_last_positions() {
    // 'c' only occurs in the prefix and 'd' only in the candidate, so "dc"
    // cannot be found in order.
    let query: &[char] = &['d', 'c'];
    let mut matcher = Matcher::new(query, query, query.into(), false, 10);
    let found = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
    assert_eq!(found, false);

    // "cd" matches: 'c' in the prefix, 'd' in the candidate.
    let query: &[char] = &['c', 'd'];
    let mut matcher = Matcher::new(query, query, query.into(), false, 10);
    let found = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
    assert_eq!(found, true);
    assert_eq!(matcher.last_positions, vec![2, 4]);

    // The second 'z' takes the candidate's 'z'; the first one falls back to
    // the prefix.
    let query: &[char] = &['z', '/', 'z', 'f'];
    let mut matcher = Matcher::new(query, query, query.into(), false, 10);
    let found =
        matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
    assert_eq!(found, true);
    assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
}
679
/// Checks the ranking and the reported match positions for ASCII paths.
#[test]
fn test_match_path_entries() {
    let paths = vec![
        "",
        "a",
        "ab",
        "abC",
        "abcd",
        "alphabravocharlie",
        "AlphaBravoCharlie",
        "thisisatestdir",
        "/////ThisIsATestDir",
        "/this/is/a/test/dir",
        "/test/tiatd",
    ];

    let abc_matches = match_query("abc", false, &paths);
    assert_eq!(
        abc_matches,
        vec![
            ("abC", vec![0, 1, 2]),
            ("abcd", vec![0, 1, 2]),
            ("AlphaBravoCharlie", vec![0, 5, 10]),
            ("alphabravocharlie", vec![4, 5, 10]),
        ]
    );

    let separator_matches = match_query("t/i/a/t/d", false, &paths);
    assert_eq!(
        separator_matches,
        vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
    );

    let initials_matches = match_query("tiatd", false, &paths);
    assert_eq!(
        initials_matches,
        vec![
            ("/test/tiatd", vec![6, 7, 8, 9, 10]),
            ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
            ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
            ("thisisatestdir", vec![0, 2, 6, 7, 11]),
        ]
    );
}
720
/// Checks that match positions are byte offsets when paths contain
/// multi-byte characters.
#[test]
fn test_match_multibyte_path_entries() {
    let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
    // Each keycap emoji occupies seven bytes.
    assert_eq!("1️⃣".len(), 7);

    let bcd_matches = match_query("bcd", false, &paths);
    assert_eq!(
        bcd_matches,
        vec![
            ("αβγδ/bcde", vec![9, 10, 11]),
            ("aαbβ/cγdδ", vec![3, 7, 10]),
        ]
    );

    let cde_matches = match_query("cde", false, &paths);
    assert_eq!(
        cde_matches,
        vec![
            ("αβγδ/bcde", vec![10, 11, 12]),
            ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
        ]
    );
}
740
741 fn match_query<'a>(
742 query: &str,
743 smart_case: bool,
744 paths: &Vec<&'a str>,
745 ) -> Vec<(&'a str, Vec<usize>)> {
746 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
747 let query = query.chars().collect::<Vec<_>>();
748 let query_chars = CharBag::from(&lowercase_query[..]);
749
750 let path_arcs = paths
751 .iter()
752 .map(|path| Arc::from(PathBuf::from(path)))
753 .collect::<Vec<_>>();
754 let mut path_entries = Vec::new();
755 for (i, path) in paths.iter().enumerate() {
756 let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
757 let char_bag = CharBag::from(lowercase_path.as_slice());
758 path_entries.push(PathMatchCandidate {
759 char_bag,
760 path: path_arcs.get(i).unwrap(),
761 });
762 }
763
764 let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
765
766 let cancel_flag = AtomicBool::new(false);
767 let mut results = Vec::new();
768 matcher.match_paths(
769 0,
770 "".into(),
771 path_entries.into_iter(),
772 &mut results,
773 &cancel_flag,
774 );
775
776 results
777 .into_iter()
778 .map(|result| {
779 (
780 paths
781 .iter()
782 .copied()
783 .find(|p| result.path.as_ref() == Path::new(p))
784 .unwrap(),
785 result.positions,
786 )
787 })
788 .collect()
789 }
790}