1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Score factor for a matched character that does not start at the search position and is
/// not preceded by a separator, word boundary or `.`.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every further character skipped
/// before the match (applies to every query character except the first).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance-based score factor.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
18pub struct Matcher<'a> {
19 query: &'a [char],
20 lowercase_query: &'a [char],
21 query_char_bag: CharBag,
22 smart_case: bool,
23 max_results: usize,
24 min_score: f64,
25 match_positions: Vec<usize>,
26 last_positions: Vec<usize>,
27 score_matrix: Vec<Option<f64>>,
28 best_position_matrix: Vec<usize>,
29}
30
/// A scored match that can be ordered against other matches of the same type.
trait Match: Ord {
    /// Returns the score assigned to this match.
    fn score(&self) -> f64;

    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
36trait MatchCandidate {
37 fn has_chars(&self, bag: CharBag) -> bool;
38 fn to_string(&self) -> Cow<'_, str>;
39}
40
41#[derive(Clone, Debug)]
42pub struct PathMatchCandidate<'a> {
43 pub path: &'a Arc<Path>,
44 pub char_bag: CharBag,
45}
46
/// A path that matched a query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters, counted from the start of the prefix
    /// followed by the path.
    pub positions: Vec<usize>,
    /// Identifier of the candidate set (worktree) the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix that was logically prepended to `path` while matching.
    pub path_prefix: Arc<str>,
}
55
56#[derive(Clone, Debug)]
57pub struct StringMatchCandidate {
58 pub id: usize,
59 pub string: String,
60 pub char_bag: CharBag,
61}
62
63pub trait PathMatchCandidateSet<'a>: Send + Sync {
64 type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
65 fn id(&self) -> usize;
66 fn len(&self) -> usize;
67 fn is_empty(&self) -> bool {
68 self.len() == 0
69 }
70 fn prefix(&self) -> Arc<str>;
71 fn candidates(&'a self, start: usize) -> Self::Candidates;
72}
73
74impl Match for PathMatch {
75 fn score(&self) -> f64 {
76 self.score
77 }
78
79 fn set_positions(&mut self, positions: Vec<usize>) {
80 self.positions = positions;
81 }
82}
83
84impl Match for StringMatch {
85 fn score(&self) -> f64 {
86 self.score
87 }
88
89 fn set_positions(&mut self, positions: Vec<usize>) {
90 self.positions = positions;
91 }
92}
93
94impl<'a> MatchCandidate for PathMatchCandidate<'a> {
95 fn has_chars(&self, bag: CharBag) -> bool {
96 self.char_bag.is_superset(bag)
97 }
98
99 fn to_string(&self) -> Cow<'a, str> {
100 self.path.to_string_lossy()
101 }
102}
103
104impl StringMatchCandidate {
105 pub fn new(id: usize, string: String) -> Self {
106 Self {
107 id,
108 char_bag: CharBag::from(string.as_str()),
109 string,
110 }
111 }
112}
113
114impl<'a> MatchCandidate for &'a StringMatchCandidate {
115 fn has_chars(&self, bag: CharBag) -> bool {
116 self.char_bag.is_superset(bag)
117 }
118
119 fn to_string(&self) -> Cow<'a, str> {
120 self.string.as_str().into()
121 }
122}
123
/// A string candidate that matched a query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the candidate this match was produced from.
    pub candidate_id: usize,
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// A copy of the candidate's string.
    pub string: String,
}

impl Ord for StringMatch {
    /// Orders by score first and by candidate id second. Scores that cannot be compared
    /// (NaN) are treated as equal, so the candidate id decides.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.score.partial_cmp(&other.score) {
            Some(Ordering::Less) => Ordering::Less,
            Some(Ordering::Greater) => Ordering::Greater,
            Some(Ordering::Equal) | None => self.candidate_id.cmp(&other.candidate_id),
        }
    }
}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for StringMatch {
    /// Two matches are equal when [`Ord::cmp`] says so; `positions` and `string` are not
    /// taken into account.
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for StringMatch {}
154
155impl PartialEq for PathMatch {
156 fn eq(&self, other: &Self) -> bool {
157 self.cmp(other).is_eq()
158 }
159}
160
161impl Eq for PathMatch {}
162
163impl PartialOrd for PathMatch {
164 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
165 Some(self.cmp(other))
166 }
167}
168
169impl Ord for PathMatch {
170 fn cmp(&self, other: &Self) -> Ordering {
171 self.score
172 .partial_cmp(&other.score)
173 .unwrap_or(Ordering::Equal)
174 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
175 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
176 }
177}
178
179pub async fn match_strings(
180 candidates: &[StringMatchCandidate],
181 query: &str,
182 smart_case: bool,
183 max_results: usize,
184 cancel_flag: &AtomicBool,
185 background: Arc<executor::Background>,
186) -> Vec<StringMatch> {
187 if candidates.is_empty() || max_results == 0 {
188 return Default::default();
189 }
190
191 if query.is_empty() {
192 return candidates
193 .iter()
194 .map(|candidate| StringMatch {
195 candidate_id: candidate.id,
196 score: 0.,
197 positions: Default::default(),
198 string: candidate.string.clone(),
199 })
200 .collect();
201 }
202
203 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
204 let query = query.chars().collect::<Vec<_>>();
205
206 let lowercase_query = &lowercase_query;
207 let query = &query;
208 let query_char_bag = CharBag::from(&lowercase_query[..]);
209
210 let num_cpus = background.num_cpus().min(candidates.len());
211 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
212 let mut segment_results = (0..num_cpus)
213 .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
214 .collect::<Vec<_>>();
215
216 background
217 .scoped(|scope| {
218 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
219 let cancel_flag = &cancel_flag;
220 scope.spawn(async move {
221 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
222 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
223 let mut matcher = Matcher::new(
224 query,
225 lowercase_query,
226 query_char_bag,
227 smart_case,
228 max_results,
229 );
230 matcher.match_strings(
231 &candidates[segment_start..segment_end],
232 results,
233 cancel_flag,
234 );
235 });
236 }
237 })
238 .await;
239
240 let mut results = Vec::new();
241 for segment_result in segment_results {
242 if results.is_empty() {
243 results = segment_result;
244 } else {
245 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
246 }
247 }
248 results
249}
250
251pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
252 candidate_sets: &'a [Set],
253 query: &str,
254 smart_case: bool,
255 max_results: usize,
256 cancel_flag: &AtomicBool,
257 background: Arc<executor::Background>,
258) -> Vec<PathMatch> {
259 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
260 if path_count == 0 {
261 return Vec::new();
262 }
263
264 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
265 let query = query.chars().collect::<Vec<_>>();
266
267 let lowercase_query = &lowercase_query;
268 let query = &query;
269 let query_char_bag = CharBag::from(&lowercase_query[..]);
270
271 let num_cpus = background.num_cpus().min(path_count);
272 let segment_size = (path_count + num_cpus - 1) / num_cpus;
273 let mut segment_results = (0..num_cpus)
274 .map(|_| Vec::with_capacity(max_results))
275 .collect::<Vec<_>>();
276
277 background
278 .scoped(|scope| {
279 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
280 scope.spawn(async move {
281 let segment_start = segment_idx * segment_size;
282 let segment_end = segment_start + segment_size;
283 let mut matcher = Matcher::new(
284 query,
285 lowercase_query,
286 query_char_bag,
287 smart_case,
288 max_results,
289 );
290
291 let mut tree_start = 0;
292 for candidate_set in candidate_sets {
293 let tree_end = tree_start + candidate_set.len();
294
295 if tree_start < segment_end && segment_start < tree_end {
296 let start = cmp::max(tree_start, segment_start) - tree_start;
297 let end = cmp::min(tree_end, segment_end) - tree_start;
298 let candidates = candidate_set.candidates(start).take(end - start);
299
300 matcher.match_paths(
301 candidate_set.id(),
302 candidate_set.prefix(),
303 candidates,
304 results,
305 cancel_flag,
306 );
307 }
308 if tree_end >= segment_end {
309 break;
310 }
311 tree_start = tree_end;
312 }
313 })
314 }
315 })
316 .await;
317
318 let mut results = Vec::new();
319 for segment_result in segment_results {
320 if results.is_empty() {
321 results = segment_result;
322 } else {
323 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a));
324 }
325 }
326 results
327}
328
329impl<'a> Matcher<'a> {
330 pub fn new(
331 query: &'a [char],
332 lowercase_query: &'a [char],
333 query_char_bag: CharBag,
334 smart_case: bool,
335 max_results: usize,
336 ) -> Self {
337 Self {
338 query,
339 lowercase_query,
340 query_char_bag,
341 min_score: 0.0,
342 last_positions: vec![0; query.len()],
343 match_positions: vec![0; query.len()],
344 score_matrix: Vec::new(),
345 best_position_matrix: Vec::new(),
346 smart_case,
347 max_results,
348 }
349 }
350
351 pub fn match_strings(
352 &mut self,
353 candidates: &[StringMatchCandidate],
354 results: &mut Vec<StringMatch>,
355 cancel_flag: &AtomicBool,
356 ) {
357 self.match_internal(
358 &[],
359 &[],
360 candidates.iter(),
361 results,
362 cancel_flag,
363 |candidate, score| StringMatch {
364 candidate_id: candidate.id,
365 score,
366 positions: Vec::new(),
367 string: candidate.string.to_string(),
368 },
369 )
370 }
371
372 pub fn match_paths<'c: 'a>(
373 &mut self,
374 tree_id: usize,
375 path_prefix: Arc<str>,
376 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
377 results: &mut Vec<PathMatch>,
378 cancel_flag: &AtomicBool,
379 ) {
380 let prefix = path_prefix.chars().collect::<Vec<_>>();
381 let lowercase_prefix = prefix
382 .iter()
383 .map(|c| c.to_ascii_lowercase())
384 .collect::<Vec<_>>();
385 self.match_internal(
386 &prefix,
387 &lowercase_prefix,
388 path_entries,
389 results,
390 cancel_flag,
391 |candidate, score| PathMatch {
392 score,
393 worktree_id: tree_id,
394 positions: Vec::new(),
395 path: candidate.path.clone(),
396 path_prefix: path_prefix.clone(),
397 },
398 )
399 }
400
401 fn match_internal<C: MatchCandidate, R, F>(
402 &mut self,
403 prefix: &[char],
404 lowercase_prefix: &[char],
405 candidates: impl Iterator<Item = C>,
406 results: &mut Vec<R>,
407 cancel_flag: &AtomicBool,
408 build_match: F,
409 ) where
410 R: Match,
411 F: Fn(&C, f64) -> R,
412 {
413 let mut candidate_chars = Vec::new();
414 let mut lowercase_candidate_chars = Vec::new();
415
416 for candidate in candidates {
417 if !candidate.has_chars(self.query_char_bag) {
418 continue;
419 }
420
421 if cancel_flag.load(atomic::Ordering::Relaxed) {
422 break;
423 }
424
425 candidate_chars.clear();
426 lowercase_candidate_chars.clear();
427 for c in candidate.to_string().chars() {
428 candidate_chars.push(c);
429 lowercase_candidate_chars.push(c.to_ascii_lowercase());
430 }
431
432 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
433 continue;
434 }
435
436 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
437 self.score_matrix.clear();
438 self.score_matrix.resize(matrix_len, None);
439 self.best_position_matrix.clear();
440 self.best_position_matrix.resize(matrix_len, 0);
441
442 let score = self.score_match(
443 &candidate_chars,
444 &lowercase_candidate_chars,
445 prefix,
446 lowercase_prefix,
447 );
448
449 if score > 0.0 {
450 let mut mat = build_match(&candidate, score);
451 if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
452 if results.len() < self.max_results {
453 mat.set_positions(self.match_positions.clone());
454 results.insert(i, mat);
455 } else if i < results.len() {
456 results.pop();
457 mat.set_positions(self.match_positions.clone());
458 results.insert(i, mat);
459 }
460 if results.len() == self.max_results {
461 self.min_score = results.last().unwrap().score();
462 }
463 }
464 }
465 }
466 }
467
468 fn find_last_positions(
469 &mut self,
470 lowercase_prefix: &[char],
471 lowercase_candidate: &[char],
472 ) -> bool {
473 let mut lowercase_prefix = lowercase_prefix.iter();
474 let mut lowercase_candidate = lowercase_candidate.iter();
475 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
476 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
477 self.last_positions[i] = j + lowercase_prefix.len();
478 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
479 self.last_positions[i] = j;
480 } else {
481 return false;
482 }
483 }
484 true
485 }
486
487 fn score_match(
488 &mut self,
489 path: &[char],
490 path_cased: &[char],
491 prefix: &[char],
492 lowercase_prefix: &[char],
493 ) -> f64 {
494 let score = self.recursive_score_match(
495 path,
496 path_cased,
497 prefix,
498 lowercase_prefix,
499 0,
500 0,
501 self.query.len() as f64,
502 ) * self.query.len() as f64;
503
504 if score <= 0.0 {
505 return 0.0;
506 }
507
508 let path_len = prefix.len() + path.len();
509 let mut cur_start = 0;
510 let mut byte_ix = 0;
511 let mut char_ix = 0;
512 for i in 0..self.query.len() {
513 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
514 while char_ix < match_char_ix {
515 let ch = prefix
516 .get(char_ix)
517 .or_else(|| path.get(char_ix - prefix.len()))
518 .unwrap();
519 byte_ix += ch.len_utf8();
520 char_ix += 1;
521 }
522 cur_start = match_char_ix + 1;
523 self.match_positions[i] = byte_ix;
524 }
525
526 score
527 }
528
529 #[allow(clippy::too_many_arguments)]
530 fn recursive_score_match(
531 &mut self,
532 path: &[char],
533 path_cased: &[char],
534 prefix: &[char],
535 lowercase_prefix: &[char],
536 query_idx: usize,
537 path_idx: usize,
538 cur_score: f64,
539 ) -> f64 {
540 if query_idx == self.query.len() {
541 return 1.0;
542 }
543
544 let path_len = prefix.len() + path.len();
545
546 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
547 return memoized;
548 }
549
550 let mut score = 0.0;
551 let mut best_position = 0;
552
553 let query_char = self.lowercase_query[query_idx];
554 let limit = self.last_positions[query_idx];
555
556 let mut last_slash = 0;
557 for j in path_idx..=limit {
558 let path_char = if j < prefix.len() {
559 lowercase_prefix[j]
560 } else {
561 path_cased[j - prefix.len()]
562 };
563 let is_path_sep = path_char == '/' || path_char == '\\';
564
565 if query_idx == 0 && is_path_sep {
566 last_slash = j;
567 }
568
569 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
570 let curr = if j < prefix.len() {
571 prefix[j]
572 } else {
573 path[j - prefix.len()]
574 };
575
576 let mut char_score = 1.0;
577 if j > path_idx {
578 let last = if j - 1 < prefix.len() {
579 prefix[j - 1]
580 } else {
581 path[j - 1 - prefix.len()]
582 };
583
584 if last == '/' {
585 char_score = 0.9;
586 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
587 || (last.is_lowercase() && curr.is_uppercase())
588 {
589 char_score = 0.8;
590 } else if last == '.' {
591 char_score = 0.7;
592 } else if query_idx == 0 {
593 char_score = BASE_DISTANCE_PENALTY;
594 } else {
595 char_score = MIN_DISTANCE_PENALTY.max(
596 BASE_DISTANCE_PENALTY
597 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
598 );
599 }
600 }
601
602 // Apply a severe penalty if the case doesn't match.
603 // This will make the exact matches have higher score than the case-insensitive and the
604 // path insensitive matches.
605 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
606 char_score *= 0.001;
607 }
608
609 let mut multiplier = char_score;
610
611 // Scale the score based on how deep within the path we found the match.
612 if query_idx == 0 {
613 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
614 }
615
616 let mut next_score = 1.0;
617 if self.min_score > 0.0 {
618 next_score = cur_score * multiplier;
619 // Scores only decrease. If we can't pass the previous best, bail
620 if next_score < self.min_score {
621 // Ensure that score is non-zero so we use it in the memo table.
622 if score == 0.0 {
623 score = 1e-18;
624 }
625 continue;
626 }
627 }
628
629 let new_score = self.recursive_score_match(
630 path,
631 path_cased,
632 prefix,
633 lowercase_prefix,
634 query_idx + 1,
635 j + 1,
636 next_score,
637 ) * multiplier;
638
639 if new_score > score {
640 score = new_score;
641 best_position = j;
642 // Optimization: can't score better than 1.
643 if new_score == 1.0 {
644 break;
645 }
646 }
647 }
648 }
649
650 if best_position != 0 {
651 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
652 }
653
654 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
655 score
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// `find_last_positions` must report whether the query occurs in order and, if so,
    /// the last usable index for every query character.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, in front of the only 'd': no in-order match.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(!matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));

        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']));
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Matches a few queries against ASCII paths and checks result order and positions.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let abc_matches = match_query("abc", false, &paths);
        assert_eq!(
            abc_matches,
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );

        let separator_matches = match_query("t/i/a/t/d", false, &paths);
        assert_eq!(
            separator_matches,
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
        );

        let initials_matches = match_query("tiatd", false, &paths);
        assert_eq!(
            initials_matches,
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Positions are byte offsets, so multi-byte characters in front of a match must
    /// shift it by their encoded length.
    #[test]
    fn test_match_multibyte_path_entries() {
        // Each keycap emoji is a digit followed by U+FE0F and U+20E3: 1 + 3 + 3 bytes.
        const KEYCAP_ONE: &str = "1\u{fe0f}\u{20e3}";
        const KEYCAP_PATH: &str = "c1\u{fe0f}\u{20e3}2\u{fe0f}\u{20e3}3\u{fe0f}\u{20e3}/d4\u{fe0f}\u{20e3}5\u{fe0f}\u{20e3}6\u{fe0f}\u{20e3}/e7\u{fe0f}\u{20e3}8\u{fe0f}\u{20e3}9\u{fe0f}\u{20e3}/f";

        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", KEYCAP_PATH, "/d/🆒/h"];
        assert_eq!(KEYCAP_ONE.len(), 7);

        let bcd_matches = match_query("bcd", false, &paths);
        assert_eq!(
            bcd_matches,
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );

        let cde_matches = match_query("cde", false, &paths);
        assert_eq!(
            cde_matches,
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                (KEYCAP_PATH, vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns the matched paths
    /// together with their match positions, in result order.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries: Vec<PathMatchCandidate> = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, arc)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: arc,
                }
            })
            .collect();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        let mut matched = Vec::with_capacity(results.len());
        for result in results {
            // Map the matched path back to the `&str` it was created from.
            let text = paths
                .iter()
                .copied()
                .find(|p| result.path.as_ref() == Path::new(p))
                .unwrap();
            matched.push((text, result.positions));
        }
        matched
    }
}