1mod char_bag;
2
3use crate::{
4 util,
5 worktree::{EntryKind, Snapshot},
6};
7use gpui::executor;
8use std::{
9 borrow::Cow,
10 cmp::{max, min, Ordering},
11 path::Path,
12 sync::atomic::{self, AtomicBool},
13 sync::Arc,
14};
15
16pub use char_bag::CharBag;
17
/// Score multiplier for a matched character that is not at a word boundary and is
/// separated from the previous match by a gap.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Additional reduction of that multiplier for every further character in the gap.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the gap multiplier, however long the gap is.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
21
22struct Matcher<'a> {
23 query: &'a [char],
24 lowercase_query: &'a [char],
25 query_char_bag: CharBag,
26 smart_case: bool,
27 max_results: usize,
28 min_score: f64,
29 match_positions: Vec<usize>,
30 last_positions: Vec<usize>,
31 score_matrix: Vec<Option<f64>>,
32 best_position_matrix: Vec<usize>,
33}
34
/// A scored fuzzy-match result that can be ranked against other results of its type.
trait Match: Ord {
    /// The match score; a higher score ranks better.
    fn score(&self) -> f64;

    /// Records the byte offsets of the matched characters.
    fn set_positions(&mut self, positions: Vec<usize>);
}
39
40trait MatchCandidate {
41 fn has_chars(&self, bag: CharBag) -> bool;
42 fn to_string<'a>(&'a self) -> Cow<'a, str>;
43}
44
45#[derive(Clone, Debug)]
46pub struct PathMatchCandidate<'a> {
47 pub path: &'a Arc<Path>,
48 pub char_bag: CharBag,
49}
50
51#[derive(Clone, Debug)]
52pub struct PathMatch {
53 pub score: f64,
54 pub positions: Vec<usize>,
55 pub tree_id: usize,
56 pub path: Arc<Path>,
57 pub path_prefix: Arc<str>,
58}
59
60#[derive(Clone, Debug)]
61pub struct StringMatchCandidate {
62 pub string: String,
63 pub char_bag: CharBag,
64}
65
66impl Match for PathMatch {
67 fn score(&self) -> f64 {
68 self.score
69 }
70
71 fn set_positions(&mut self, positions: Vec<usize>) {
72 self.positions = positions;
73 }
74}
75
76impl Match for StringMatch {
77 fn score(&self) -> f64 {
78 self.score
79 }
80
81 fn set_positions(&mut self, positions: Vec<usize>) {
82 self.positions = positions;
83 }
84}
85
86impl<'a> MatchCandidate for PathMatchCandidate<'a> {
87 fn has_chars(&self, bag: CharBag) -> bool {
88 self.char_bag.is_superset(bag)
89 }
90
91 fn to_string(&self) -> Cow<'a, str> {
92 self.path.to_string_lossy()
93 }
94}
95
96impl<'a> MatchCandidate for &'a StringMatchCandidate {
97 fn has_chars(&self, bag: CharBag) -> bool {
98 self.char_bag.is_superset(bag)
99 }
100
101 fn to_string(&self) -> Cow<'a, str> {
102 self.string.as_str().into()
103 }
104}
105
/// A string candidate that matched a fuzzy query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// The matched candidate text.
    pub string: String,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(b) == Ordering::Equal`, as `Eq` and `Ord` require. Comparing scores alone
// would call two different strings with the same score "equal" even though `cmp`
// orders them by string.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders by score (incomparable, i.e. `NaN`, scores count as equal), then by string,
/// so that the order is total and deterministic.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
135
136impl PartialEq for PathMatch {
137 fn eq(&self, other: &Self) -> bool {
138 self.score.eq(&other.score)
139 }
140}
141
142impl Eq for PathMatch {}
143
144impl PartialOrd for PathMatch {
145 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
146 Some(self.cmp(other))
147 }
148}
149
150impl Ord for PathMatch {
151 fn cmp(&self, other: &Self) -> Ordering {
152 self.score
153 .partial_cmp(&other.score)
154 .unwrap_or(Ordering::Equal)
155 .then_with(|| self.tree_id.cmp(&other.tree_id))
156 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
157 }
158}
159
160pub async fn match_strings(
161 candidates: &[StringMatchCandidate],
162 query: &str,
163 smart_case: bool,
164 max_results: usize,
165 cancel_flag: &AtomicBool,
166 background: Arc<executor::Background>,
167) -> Vec<StringMatch> {
168 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
169 let query = query.chars().collect::<Vec<_>>();
170
171 let lowercase_query = &lowercase_query;
172 let query = &query;
173 let query_char_bag = CharBag::from(&lowercase_query[..]);
174
175 let num_cpus = background.num_cpus().min(candidates.len());
176 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
177 let mut segment_results = (0..num_cpus)
178 .map(|_| Vec::with_capacity(max_results))
179 .collect::<Vec<_>>();
180
181 background
182 .scoped(|scope| {
183 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
184 let cancel_flag = &cancel_flag;
185 scope.spawn(async move {
186 let segment_start = segment_idx * segment_size;
187 let segment_end = segment_start + segment_size;
188 let mut matcher = Matcher::new(
189 query,
190 lowercase_query,
191 query_char_bag,
192 smart_case,
193 max_results,
194 );
195 matcher.match_strings(
196 &candidates[segment_start..segment_end],
197 results,
198 cancel_flag,
199 );
200 });
201 }
202 })
203 .await;
204
205 let mut results = Vec::new();
206 for segment_result in segment_results {
207 if results.is_empty() {
208 results = segment_result;
209 } else {
210 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
211 }
212 }
213 results
214}
215
216pub async fn match_paths(
217 snapshots: &[Snapshot],
218 query: &str,
219 include_ignored: bool,
220 smart_case: bool,
221 max_results: usize,
222 cancel_flag: &AtomicBool,
223 background: Arc<executor::Background>,
224) -> Vec<PathMatch> {
225 let path_count: usize = if include_ignored {
226 snapshots.iter().map(Snapshot::file_count).sum()
227 } else {
228 snapshots.iter().map(Snapshot::visible_file_count).sum()
229 };
230 if path_count == 0 {
231 return Vec::new();
232 }
233
234 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
235 let query = query.chars().collect::<Vec<_>>();
236
237 let lowercase_query = &lowercase_query;
238 let query = &query;
239 let query_char_bag = CharBag::from(&lowercase_query[..]);
240
241 let num_cpus = background.num_cpus().min(path_count);
242 let segment_size = (path_count + num_cpus - 1) / num_cpus;
243 let mut segment_results = (0..num_cpus)
244 .map(|_| Vec::with_capacity(max_results))
245 .collect::<Vec<_>>();
246
247 background
248 .scoped(|scope| {
249 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
250 scope.spawn(async move {
251 let segment_start = segment_idx * segment_size;
252 let segment_end = segment_start + segment_size;
253 let mut matcher = Matcher::new(
254 query,
255 lowercase_query,
256 query_char_bag,
257 smart_case,
258 max_results,
259 );
260
261 let mut tree_start = 0;
262 for snapshot in snapshots {
263 let tree_end = if include_ignored {
264 tree_start + snapshot.file_count()
265 } else {
266 tree_start + snapshot.visible_file_count()
267 };
268
269 if tree_start < segment_end && segment_start < tree_end {
270 let path_prefix: Arc<str> =
271 if snapshot.root_entry().map_or(false, |e| e.is_file()) {
272 snapshot.root_name().into()
273 } else if snapshots.len() > 1 {
274 format!("{}/", snapshot.root_name()).into()
275 } else {
276 "".into()
277 };
278
279 let start = max(tree_start, segment_start) - tree_start;
280 let end = min(tree_end, segment_end) - tree_start;
281 if include_ignored {
282 let paths = snapshot.files(start).take(end - start).map(|entry| {
283 if let EntryKind::File(char_bag) = entry.kind {
284 PathMatchCandidate {
285 path: &entry.path,
286 char_bag,
287 }
288 } else {
289 unreachable!()
290 }
291 });
292 matcher.match_paths(
293 snapshot.id(),
294 path_prefix,
295 paths,
296 results,
297 &cancel_flag,
298 );
299 } else {
300 let paths =
301 snapshot
302 .visible_files(start)
303 .take(end - start)
304 .map(|entry| {
305 if let EntryKind::File(char_bag) = entry.kind {
306 PathMatchCandidate {
307 path: &entry.path,
308 char_bag,
309 }
310 } else {
311 unreachable!()
312 }
313 });
314 matcher.match_paths(
315 snapshot.id(),
316 path_prefix,
317 paths,
318 results,
319 &cancel_flag,
320 );
321 };
322 }
323 if tree_end >= segment_end {
324 break;
325 }
326 tree_start = tree_end;
327 }
328 })
329 }
330 })
331 .await;
332
333 let mut results = Vec::new();
334 for segment_result in segment_results {
335 if results.is_empty() {
336 results = segment_result;
337 } else {
338 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
339 }
340 }
341 results
342}
343
344impl<'a> Matcher<'a> {
345 fn new(
346 query: &'a [char],
347 lowercase_query: &'a [char],
348 query_char_bag: CharBag,
349 smart_case: bool,
350 max_results: usize,
351 ) -> Self {
352 Self {
353 query,
354 lowercase_query,
355 query_char_bag,
356 min_score: 0.0,
357 last_positions: vec![0; query.len()],
358 match_positions: vec![0; query.len()],
359 score_matrix: Vec::new(),
360 best_position_matrix: Vec::new(),
361 smart_case,
362 max_results,
363 }
364 }
365
366 fn match_strings(
367 &mut self,
368 candidates: &[StringMatchCandidate],
369 results: &mut Vec<StringMatch>,
370 cancel_flag: &AtomicBool,
371 ) {
372 self.match_internal(
373 &[],
374 &[],
375 candidates.iter(),
376 results,
377 cancel_flag,
378 |candidate, score| StringMatch {
379 score,
380 positions: Vec::new(),
381 string: candidate.string.to_string(),
382 },
383 )
384 }
385
386 fn match_paths(
387 &mut self,
388 tree_id: usize,
389 path_prefix: Arc<str>,
390 path_entries: impl Iterator<Item = PathMatchCandidate<'a>>,
391 results: &mut Vec<PathMatch>,
392 cancel_flag: &AtomicBool,
393 ) {
394 let prefix = path_prefix.chars().collect::<Vec<_>>();
395 let lowercase_prefix = prefix
396 .iter()
397 .map(|c| c.to_ascii_lowercase())
398 .collect::<Vec<_>>();
399 self.match_internal(
400 &prefix,
401 &lowercase_prefix,
402 path_entries,
403 results,
404 cancel_flag,
405 |candidate, score| PathMatch {
406 score,
407 tree_id,
408 positions: Vec::new(),
409 path: candidate.path.clone(),
410 path_prefix: path_prefix.clone(),
411 },
412 )
413 }
414
415 fn match_internal<C: MatchCandidate, R, F>(
416 &mut self,
417 prefix: &[char],
418 lowercase_prefix: &[char],
419 candidates: impl Iterator<Item = C>,
420 results: &mut Vec<R>,
421 cancel_flag: &AtomicBool,
422 build_match: F,
423 ) where
424 R: Match,
425 F: Fn(&C, f64) -> R,
426 {
427 let mut candidate_chars = Vec::new();
428 let mut lowercase_candidate_chars = Vec::new();
429
430 for candidate in candidates {
431 if !candidate.has_chars(self.query_char_bag) {
432 continue;
433 }
434
435 if cancel_flag.load(atomic::Ordering::Relaxed) {
436 break;
437 }
438
439 candidate_chars.clear();
440 lowercase_candidate_chars.clear();
441 for c in candidate.to_string().chars() {
442 candidate_chars.push(c);
443 lowercase_candidate_chars.push(c.to_ascii_lowercase());
444 }
445
446 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
447 continue;
448 }
449
450 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
451 self.score_matrix.clear();
452 self.score_matrix.resize(matrix_len, None);
453 self.best_position_matrix.clear();
454 self.best_position_matrix.resize(matrix_len, 0);
455
456 let score = self.score_match(
457 &candidate_chars,
458 &lowercase_candidate_chars,
459 &prefix,
460 &lowercase_prefix,
461 );
462
463 if score > 0.0 {
464 let mut mat = build_match(&candidate, score);
465 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
466 if results.len() < self.max_results {
467 mat.set_positions(self.match_positions.clone());
468 results.insert(i, mat);
469 } else if i < results.len() {
470 results.pop();
471 mat.set_positions(self.match_positions.clone());
472 results.insert(i, mat);
473 }
474 if results.len() == self.max_results {
475 self.min_score = results.last().unwrap().score();
476 }
477 }
478 }
479 }
480 }
481
482 fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
483 let mut path = path.iter();
484 let mut prefix_iter = prefix.iter();
485 for (i, char) in self.query.iter().enumerate().rev() {
486 if let Some(j) = path.rposition(|c| c == char) {
487 self.last_positions[i] = j + prefix.len();
488 } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
489 self.last_positions[i] = j;
490 } else {
491 return false;
492 }
493 }
494 true
495 }
496
497 fn score_match(
498 &mut self,
499 path: &[char],
500 path_cased: &[char],
501 prefix: &[char],
502 lowercase_prefix: &[char],
503 ) -> f64 {
504 let score = self.recursive_score_match(
505 path,
506 path_cased,
507 prefix,
508 lowercase_prefix,
509 0,
510 0,
511 self.query.len() as f64,
512 ) * self.query.len() as f64;
513
514 if score <= 0.0 {
515 return 0.0;
516 }
517
518 let path_len = prefix.len() + path.len();
519 let mut cur_start = 0;
520 let mut byte_ix = 0;
521 let mut char_ix = 0;
522 for i in 0..self.query.len() {
523 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
524 while char_ix < match_char_ix {
525 let ch = prefix
526 .get(char_ix)
527 .or_else(|| path.get(char_ix - prefix.len()))
528 .unwrap();
529 byte_ix += ch.len_utf8();
530 char_ix += 1;
531 }
532 cur_start = match_char_ix + 1;
533 self.match_positions[i] = byte_ix;
534 }
535
536 score
537 }
538
539 fn recursive_score_match(
540 &mut self,
541 path: &[char],
542 path_cased: &[char],
543 prefix: &[char],
544 lowercase_prefix: &[char],
545 query_idx: usize,
546 path_idx: usize,
547 cur_score: f64,
548 ) -> f64 {
549 if query_idx == self.query.len() {
550 return 1.0;
551 }
552
553 let path_len = prefix.len() + path.len();
554
555 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
556 return memoized;
557 }
558
559 let mut score = 0.0;
560 let mut best_position = 0;
561
562 let query_char = self.lowercase_query[query_idx];
563 let limit = self.last_positions[query_idx];
564
565 let mut last_slash = 0;
566 for j in path_idx..=limit {
567 let path_char = if j < prefix.len() {
568 lowercase_prefix[j]
569 } else {
570 path_cased[j - prefix.len()]
571 };
572 let is_path_sep = path_char == '/' || path_char == '\\';
573
574 if query_idx == 0 && is_path_sep {
575 last_slash = j;
576 }
577
578 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
579 let curr = if j < prefix.len() {
580 prefix[j]
581 } else {
582 path[j - prefix.len()]
583 };
584
585 let mut char_score = 1.0;
586 if j > path_idx {
587 let last = if j - 1 < prefix.len() {
588 prefix[j - 1]
589 } else {
590 path[j - 1 - prefix.len()]
591 };
592
593 if last == '/' {
594 char_score = 0.9;
595 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
596 char_score = 0.8;
597 } else if last.is_lowercase() && curr.is_uppercase() {
598 char_score = 0.8;
599 } else if last == '.' {
600 char_score = 0.7;
601 } else if query_idx == 0 {
602 char_score = BASE_DISTANCE_PENALTY;
603 } else {
604 char_score = MIN_DISTANCE_PENALTY.max(
605 BASE_DISTANCE_PENALTY
606 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
607 );
608 }
609 }
610
611 // Apply a severe penalty if the case doesn't match.
612 // This will make the exact matches have higher score than the case-insensitive and the
613 // path insensitive matches.
614 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
615 char_score *= 0.001;
616 }
617
618 let mut multiplier = char_score;
619
620 // Scale the score based on how deep within the path we found the match.
621 if query_idx == 0 {
622 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
623 }
624
625 let mut next_score = 1.0;
626 if self.min_score > 0.0 {
627 next_score = cur_score * multiplier;
628 // Scores only decrease. If we can't pass the previous best, bail
629 if next_score < self.min_score {
630 // Ensure that score is non-zero so we use it in the memo table.
631 if score == 0.0 {
632 score = 1e-18;
633 }
634 continue;
635 }
636 }
637
638 let new_score = self.recursive_score_match(
639 path,
640 path_cased,
641 prefix,
642 lowercase_prefix,
643 query_idx + 1,
644 j + 1,
645 next_score,
646 ) * multiplier;
647
648 if new_score > score {
649 score = new_score;
650 best_position = j;
651 // Optimization: can't score better than 1.
652 if new_score == 1.0 {
653 break;
654 }
655 }
656 }
657 }
658
659 if best_position != 0 {
660 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
661 }
662
663 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
664 score
665 }
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_get_last_positions() {
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs a single-tree, empty-prefix path match of `query` over `paths` and returns
    /// each matched path together with the byte offsets of its matched characters,
    /// best match first.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}