1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Score multiplier used when a query character matches away from the start of the
/// searched range and not right after a word boundary: applied as-is to the first query
/// character, and as the starting value of the distance-based penalty for later ones.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every character skipped between
/// two consecutive matched query characters.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the distance-based penalty, no matter how far apart consecutive matches are.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
/// Fuzzy-matching state for a single query.
///
/// The scratch buffers (`score_matrix`, `best_position_matrix`, position vectors) live on
/// the struct so they are cleared and resized per candidate rather than re-created.
pub struct Matcher<'a> {
    /// The query as given, one entry per `char`.
    query: &'a [char],
    /// Lowercased query; indexed with the same positions as `query`.
    lowercase_query: &'a [char],
    /// Character bag of the query, used to skip candidates that cannot match.
    query_char_bag: CharBag,
    /// When true, case mismatches between query and candidate are heavily penalized.
    smart_case: bool,
    /// Maximum number of matches kept in a results vector.
    max_results: usize,
    /// Score of the worst kept match once the results vector is full; used to prune the
    /// search. `0.0` disables pruning.
    min_score: f64,
    /// Byte offsets of the matched query characters for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the last index in the combined prefix + candidate
    /// sequence at which it can still match.
    last_positions: Vec<usize>,
    /// Memoized scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position per `(query_idx, path_idx)` cell, same layout as `score_matrix`.
    best_position_matrix: Vec<usize>,
}
30
/// A scored match result that can be ordered and can receive its matched positions.
trait Match: Ord {
    /// The match score; higher is better.
    fn score(&self) -> f64;
    /// Stores the byte offsets of the matched query characters.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
/// Something that can be fuzzy-matched against a query.
trait MatchCandidate {
    /// Returns whether the candidate's character bag is a superset of `bag`.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The text the query is matched against.
    fn to_string<'a>(&'a self) -> Cow<'a, str>;
}
40
/// A path to be matched by `match_paths`, together with its precomputed character bag.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    pub path: &'a Arc<Path>,
    /// Characters of `path`, used to reject non-matching candidates cheaply.
    pub char_bag: CharBag,
}
46
/// A path that matched a query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched query characters, counted over `path_prefix`
    /// followed by the (lossily converted) `path`.
    pub positions: Vec<usize>,
    /// Id of the candidate set the path came from.
    pub worktree_id: usize,
    pub path: Arc<Path>,
    /// Prefix that was matched in front of `path`.
    pub path_prefix: Arc<str>,
}
55
/// A string to be matched by `match_strings`.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Caller-chosen identifier, copied into `StringMatch::candidate_id`.
    pub id: usize,
    pub string: String,
    /// Characters of `string`, used to reject non-matching candidates cheaply.
    pub char_bag: CharBag,
}
62
/// A collection of path candidates (presumably one worktree) that `match_paths` can
/// split into index ranges and match in parallel.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier reported as `PathMatch::worktree_id`.
    fn id(&self) -> usize;
    /// Number of candidates in the set.
    fn len(&self) -> usize;
    /// Prefix matched in front of every candidate path of this set.
    fn prefix(&self) -> Arc<str>;
    /// Iterates over the candidates, starting at index `start`.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
70
71impl Match for PathMatch {
72 fn score(&self) -> f64 {
73 self.score
74 }
75
76 fn set_positions(&mut self, positions: Vec<usize>) {
77 self.positions = positions;
78 }
79}
80
81impl Match for StringMatch {
82 fn score(&self) -> f64 {
83 self.score
84 }
85
86 fn set_positions(&mut self, positions: Vec<usize>) {
87 self.positions = positions;
88 }
89}
90
91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
92 fn has_chars(&self, bag: CharBag) -> bool {
93 self.char_bag.is_superset(bag)
94 }
95
96 fn to_string(&self) -> Cow<'a, str> {
97 self.path.to_string_lossy()
98 }
99}
100
101impl StringMatchCandidate {
102 pub fn new(id: usize, string: String) -> Self {
103 Self {
104 id,
105 char_bag: CharBag::from(string.as_str()),
106 string,
107 }
108 }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112 fn has_chars(&self, bag: CharBag) -> bool {
113 self.char_bag.is_superset(bag)
114 }
115
116 fn to_string(&self) -> Cow<'a, str> {
117 self.string.as_str().into()
118 }
119}
120
/// A string candidate that matched a query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the matched `StringMatchCandidate`.
    pub candidate_id: usize,
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets into `string` of the matched query characters.
    pub positions: Vec<usize>,
    /// Copy of the matched candidate string.
    pub string: String,
}
128
129impl PartialEq for StringMatch {
130 fn eq(&self, other: &Self) -> bool {
131 self.cmp(other).is_eq()
132 }
133}
134
135impl Eq for StringMatch {}
136
137impl PartialOrd for StringMatch {
138 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
139 Some(self.cmp(other))
140 }
141}
142
143impl Ord for StringMatch {
144 fn cmp(&self, other: &Self) -> Ordering {
145 self.score
146 .partial_cmp(&other.score)
147 .unwrap_or(Ordering::Equal)
148 .then_with(|| self.candidate_id.cmp(&other.candidate_id))
149 }
150}
151
152impl PartialEq for PathMatch {
153 fn eq(&self, other: &Self) -> bool {
154 self.cmp(other).is_eq()
155 }
156}
157
158impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162 Some(self.cmp(other))
163 }
164}
165
166impl Ord for PathMatch {
167 fn cmp(&self, other: &Self) -> Ordering {
168 self.score
169 .partial_cmp(&other.score)
170 .unwrap_or(Ordering::Equal)
171 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173 }
174}
175
176pub async fn match_strings(
177 candidates: &[StringMatchCandidate],
178 query: &str,
179 smart_case: bool,
180 max_results: usize,
181 cancel_flag: &AtomicBool,
182 background: Arc<executor::Background>,
183) -> Vec<StringMatch> {
184 if candidates.is_empty() {
185 return Default::default();
186 }
187
188 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
189 let query = query.chars().collect::<Vec<_>>();
190
191 let lowercase_query = &lowercase_query;
192 let query = &query;
193 let query_char_bag = CharBag::from(&lowercase_query[..]);
194
195 let num_cpus = background.num_cpus().min(candidates.len());
196 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
197 let mut segment_results = (0..num_cpus)
198 .map(|_| Vec::with_capacity(max_results))
199 .collect::<Vec<_>>();
200
201 background
202 .scoped(|scope| {
203 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
204 let cancel_flag = &cancel_flag;
205 scope.spawn(async move {
206 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
207 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
208 let mut matcher = Matcher::new(
209 query,
210 lowercase_query,
211 query_char_bag,
212 smart_case,
213 max_results,
214 );
215 matcher.match_strings(
216 &candidates[segment_start..segment_end],
217 results,
218 cancel_flag,
219 );
220 });
221 }
222 })
223 .await;
224
225 let mut results = Vec::new();
226 for segment_result in segment_results {
227 if results.is_empty() {
228 results = segment_result;
229 } else {
230 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
231 }
232 }
233 results
234}
235
236pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
237 candidate_sets: &'a [Set],
238 query: &str,
239 smart_case: bool,
240 max_results: usize,
241 cancel_flag: &AtomicBool,
242 background: Arc<executor::Background>,
243) -> Vec<PathMatch> {
244 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
245 if path_count == 0 {
246 return Vec::new();
247 }
248
249 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
250 let query = query.chars().collect::<Vec<_>>();
251
252 let lowercase_query = &lowercase_query;
253 let query = &query;
254 let query_char_bag = CharBag::from(&lowercase_query[..]);
255
256 let num_cpus = background.num_cpus().min(path_count);
257 let segment_size = (path_count + num_cpus - 1) / num_cpus;
258 let mut segment_results = (0..num_cpus)
259 .map(|_| Vec::with_capacity(max_results))
260 .collect::<Vec<_>>();
261
262 background
263 .scoped(|scope| {
264 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
265 scope.spawn(async move {
266 let segment_start = segment_idx * segment_size;
267 let segment_end = segment_start + segment_size;
268 let mut matcher = Matcher::new(
269 query,
270 lowercase_query,
271 query_char_bag,
272 smart_case,
273 max_results,
274 );
275
276 let mut tree_start = 0;
277 for candidate_set in candidate_sets {
278 let tree_end = tree_start + candidate_set.len();
279
280 if tree_start < segment_end && segment_start < tree_end {
281 let start = cmp::max(tree_start, segment_start) - tree_start;
282 let end = cmp::min(tree_end, segment_end) - tree_start;
283 let candidates = candidate_set.candidates(start).take(end - start);
284
285 matcher.match_paths(
286 candidate_set.id(),
287 candidate_set.prefix(),
288 candidates,
289 results,
290 &cancel_flag,
291 );
292 }
293 if tree_end >= segment_end {
294 break;
295 }
296 tree_start = tree_end;
297 }
298 })
299 }
300 })
301 .await;
302
303 let mut results = Vec::new();
304 for segment_result in segment_results {
305 if results.is_empty() {
306 results = segment_result;
307 } else {
308 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
309 }
310 }
311 results
312}
313
314impl<'a> Matcher<'a> {
315 pub fn new(
316 query: &'a [char],
317 lowercase_query: &'a [char],
318 query_char_bag: CharBag,
319 smart_case: bool,
320 max_results: usize,
321 ) -> Self {
322 Self {
323 query,
324 lowercase_query,
325 query_char_bag,
326 min_score: 0.0,
327 last_positions: vec![0; query.len()],
328 match_positions: vec![0; query.len()],
329 score_matrix: Vec::new(),
330 best_position_matrix: Vec::new(),
331 smart_case,
332 max_results,
333 }
334 }
335
336 pub fn match_strings(
337 &mut self,
338 candidates: &[StringMatchCandidate],
339 results: &mut Vec<StringMatch>,
340 cancel_flag: &AtomicBool,
341 ) {
342 self.match_internal(
343 &[],
344 &[],
345 candidates.iter(),
346 results,
347 cancel_flag,
348 |candidate, score| StringMatch {
349 candidate_id: candidate.id,
350 score,
351 positions: Vec::new(),
352 string: candidate.string.to_string(),
353 },
354 )
355 }
356
357 pub fn match_paths<'c: 'a>(
358 &mut self,
359 tree_id: usize,
360 path_prefix: Arc<str>,
361 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
362 results: &mut Vec<PathMatch>,
363 cancel_flag: &AtomicBool,
364 ) {
365 let prefix = path_prefix.chars().collect::<Vec<_>>();
366 let lowercase_prefix = prefix
367 .iter()
368 .map(|c| c.to_ascii_lowercase())
369 .collect::<Vec<_>>();
370 self.match_internal(
371 &prefix,
372 &lowercase_prefix,
373 path_entries,
374 results,
375 cancel_flag,
376 |candidate, score| PathMatch {
377 score,
378 worktree_id: tree_id,
379 positions: Vec::new(),
380 path: candidate.path.clone(),
381 path_prefix: path_prefix.clone(),
382 },
383 )
384 }
385
386 fn match_internal<C: MatchCandidate, R, F>(
387 &mut self,
388 prefix: &[char],
389 lowercase_prefix: &[char],
390 candidates: impl Iterator<Item = C>,
391 results: &mut Vec<R>,
392 cancel_flag: &AtomicBool,
393 build_match: F,
394 ) where
395 R: Match,
396 F: Fn(&C, f64) -> R,
397 {
398 let mut candidate_chars = Vec::new();
399 let mut lowercase_candidate_chars = Vec::new();
400
401 for candidate in candidates {
402 if !candidate.has_chars(self.query_char_bag) {
403 continue;
404 }
405
406 if cancel_flag.load(atomic::Ordering::Relaxed) {
407 break;
408 }
409
410 candidate_chars.clear();
411 lowercase_candidate_chars.clear();
412 for c in candidate.to_string().chars() {
413 candidate_chars.push(c);
414 lowercase_candidate_chars.push(c.to_ascii_lowercase());
415 }
416
417 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
418 continue;
419 }
420
421 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
422 self.score_matrix.clear();
423 self.score_matrix.resize(matrix_len, None);
424 self.best_position_matrix.clear();
425 self.best_position_matrix.resize(matrix_len, 0);
426
427 let score = self.score_match(
428 &candidate_chars,
429 &lowercase_candidate_chars,
430 &prefix,
431 &lowercase_prefix,
432 );
433
434 if score > 0.0 {
435 let mut mat = build_match(&candidate, score);
436 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
437 if results.len() < self.max_results {
438 mat.set_positions(self.match_positions.clone());
439 results.insert(i, mat);
440 } else if i < results.len() {
441 results.pop();
442 mat.set_positions(self.match_positions.clone());
443 results.insert(i, mat);
444 }
445 if results.len() == self.max_results {
446 self.min_score = results.last().unwrap().score();
447 }
448 }
449 }
450 }
451 }
452
453 fn find_last_positions(
454 &mut self,
455 lowercase_prefix: &[char],
456 lowercase_candidate: &[char],
457 ) -> bool {
458 let mut lowercase_prefix = lowercase_prefix.iter();
459 let mut lowercase_candidate = lowercase_candidate.iter();
460 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
461 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
462 self.last_positions[i] = j + lowercase_prefix.len();
463 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
464 self.last_positions[i] = j;
465 } else {
466 return false;
467 }
468 }
469 true
470 }
471
472 fn score_match(
473 &mut self,
474 path: &[char],
475 path_cased: &[char],
476 prefix: &[char],
477 lowercase_prefix: &[char],
478 ) -> f64 {
479 let score = self.recursive_score_match(
480 path,
481 path_cased,
482 prefix,
483 lowercase_prefix,
484 0,
485 0,
486 self.query.len() as f64,
487 ) * self.query.len() as f64;
488
489 if score <= 0.0 {
490 return 0.0;
491 }
492
493 let path_len = prefix.len() + path.len();
494 let mut cur_start = 0;
495 let mut byte_ix = 0;
496 let mut char_ix = 0;
497 for i in 0..self.query.len() {
498 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
499 while char_ix < match_char_ix {
500 let ch = prefix
501 .get(char_ix)
502 .or_else(|| path.get(char_ix - prefix.len()))
503 .unwrap();
504 byte_ix += ch.len_utf8();
505 char_ix += 1;
506 }
507 cur_start = match_char_ix + 1;
508 self.match_positions[i] = byte_ix;
509 }
510
511 score
512 }
513
514 fn recursive_score_match(
515 &mut self,
516 path: &[char],
517 path_cased: &[char],
518 prefix: &[char],
519 lowercase_prefix: &[char],
520 query_idx: usize,
521 path_idx: usize,
522 cur_score: f64,
523 ) -> f64 {
524 if query_idx == self.query.len() {
525 return 1.0;
526 }
527
528 let path_len = prefix.len() + path.len();
529
530 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
531 return memoized;
532 }
533
534 let mut score = 0.0;
535 let mut best_position = 0;
536
537 let query_char = self.lowercase_query[query_idx];
538 let limit = self.last_positions[query_idx];
539
540 let mut last_slash = 0;
541 for j in path_idx..=limit {
542 let path_char = if j < prefix.len() {
543 lowercase_prefix[j]
544 } else {
545 path_cased[j - prefix.len()]
546 };
547 let is_path_sep = path_char == '/' || path_char == '\\';
548
549 if query_idx == 0 && is_path_sep {
550 last_slash = j;
551 }
552
553 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
554 let curr = if j < prefix.len() {
555 prefix[j]
556 } else {
557 path[j - prefix.len()]
558 };
559
560 let mut char_score = 1.0;
561 if j > path_idx {
562 let last = if j - 1 < prefix.len() {
563 prefix[j - 1]
564 } else {
565 path[j - 1 - prefix.len()]
566 };
567
568 if last == '/' {
569 char_score = 0.9;
570 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
571 char_score = 0.8;
572 } else if last.is_lowercase() && curr.is_uppercase() {
573 char_score = 0.8;
574 } else if last == '.' {
575 char_score = 0.7;
576 } else if query_idx == 0 {
577 char_score = BASE_DISTANCE_PENALTY;
578 } else {
579 char_score = MIN_DISTANCE_PENALTY.max(
580 BASE_DISTANCE_PENALTY
581 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
582 );
583 }
584 }
585
586 // Apply a severe penalty if the case doesn't match.
587 // This will make the exact matches have higher score than the case-insensitive and the
588 // path insensitive matches.
589 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
590 char_score *= 0.001;
591 }
592
593 let mut multiplier = char_score;
594
595 // Scale the score based on how deep within the path we found the match.
596 if query_idx == 0 {
597 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
598 }
599
600 let mut next_score = 1.0;
601 if self.min_score > 0.0 {
602 next_score = cur_score * multiplier;
603 // Scores only decrease. If we can't pass the previous best, bail
604 if next_score < self.min_score {
605 // Ensure that score is non-zero so we use it in the memo table.
606 if score == 0.0 {
607 score = 1e-18;
608 }
609 continue;
610 }
611 }
612
613 let new_score = self.recursive_score_match(
614 path,
615 path_cased,
616 prefix,
617 lowercase_prefix,
618 query_idx + 1,
619 j + 1,
620 next_score,
621 ) * multiplier;
622
623 if new_score > score {
624 score = new_score;
625 best_position = j;
626 // Optimization: can't score better than 1.
627 if new_score == 1.0 {
628 break;
629 }
630 }
631 }
632 }
633
634 if best_position != 0 {
635 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
636 }
637
638 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
639 score
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// `find_last_positions` rejects queries that are not a subsequence of
    /// prefix + candidate, and otherwise records the last feasible index per query char.
    #[test]
    fn test_get_last_positions() {
        // 'd' would have to appear before 'c', which it does not.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        // 'c' only exists in the prefix; 'd' is at candidate index 1 (offset by prefix len 3).
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Once the candidate is exhausted, earlier query chars fall back to the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Ranking and match positions (byte offsets) for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Match positions are byte offsets, so multi-byte chars shift them accordingly.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `Matcher::match_paths` over `paths` with an empty prefix and returns each
    /// matched path (as the original `&str`) with its match positions, best first.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each result back to the input string it was built from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}