1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Starting score for a matched character that lies past the search start and
/// does not follow a separator, word-boundary, or case-transition character.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each further character
/// skipped before the match (used for query characters after the first).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound on the distance-decayed character score.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
/// Scores candidates against a single query and keeps the best results.
///
/// The scratch buffers are reused across candidates so that matching a large
/// candidate list does not reallocate for every entry.
pub struct Matcher<'a> {
    /// Query characters with their original case.
    query: &'a [char],
    /// Lowercased query characters, indexed in parallel with `query`.
    lowercase_query: &'a [char],
    /// Character set of the query, used to cheaply reject candidates.
    query_char_bag: CharBag,
    /// When set, a case mismatch between query and candidate is heavily penalized.
    smart_case: bool,
    /// Maximum number of matches kept in a results vector.
    max_results: usize,
    /// Score of the worst kept match once the results vector is full; used to
    /// prune scoring paths that cannot beat it. Stays `0.0` until then.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// For each query character, the right-most character index (counted over
    /// prefix + candidate) at which it may still be matched.
    last_positions: Vec<usize>,
    /// Memo table of scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position for each memo cell, indexed like `score_matrix`.
    best_position_matrix: Vec<usize>,
}
30
/// A scored match that can be kept in a sorted results vector.
trait Match: Ord {
    /// Returns the score assigned to this match.
    fn score(&self) -> f64;
    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
/// Something that can be fuzzy-matched against a query.
trait MatchCandidate {
    /// Cheap pre-filter: reports whether this candidate's character bag covers `bag`.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text that the query is matched against.
    fn to_string<'a>(&'a self) -> Cow<'a, str>;
}
40
/// A borrowed path together with its precomputed character bag.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The path to match against.
    pub path: &'a Arc<Path>,
    /// Character bag used to reject the candidate before scoring.
    pub char_bag: CharBag,
}
46
/// A path that matched the query.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters, counted over the prefix
    /// followed by the path.
    pub positions: Vec<usize>,
    /// Id of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set the path came from; it takes part in matching.
    pub path_prefix: Arc<str>,
}
55
/// A string candidate together with its precomputed character bag.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Caller-chosen identifier, copied into `StringMatch::candidate_id`.
    pub id: usize,
    /// The text to match against.
    pub string: String,
    /// Character bag used to reject the candidate before scoring.
    pub char_bag: CharBag,
}
62
/// A collection of path candidates that `match_paths` can split across tasks.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    /// Iterator type yielding the set's candidates.
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier recorded as `PathMatch::worktree_id` on matches from this set.
    fn id(&self) -> usize;
    /// Number of candidates in the set.
    fn len(&self) -> usize;
    /// Prefix that is matched together with (and in front of) every path in the set.
    fn prefix(&self) -> Arc<str>;
    /// Returns the candidates beginning at offset `start` within the set.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
70
/// Lets `PathMatch` be produced and stored by the generic matching loop.
impl Match for PathMatch {
    /// Returns the stored match score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
80
/// Lets `StringMatch` be produced and stored by the generic matching loop.
impl Match for StringMatch {
    /// Returns the stored match score.
    fn score(&self) -> f64 {
        self.score
    }

    /// Replaces the stored match positions.
    fn set_positions(&mut self, positions: Vec<usize>) {
        self.positions = positions;
    }
}
90
91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
92 fn has_chars(&self, bag: CharBag) -> bool {
93 self.char_bag.is_superset(bag)
94 }
95
96 fn to_string(&self) -> Cow<'a, str> {
97 self.path.to_string_lossy()
98 }
99}
100
101impl<'a> MatchCandidate for &'a StringMatchCandidate {
102 fn has_chars(&self, bag: CharBag) -> bool {
103 self.char_bag.is_superset(bag)
104 }
105
106 fn to_string(&self) -> Cow<'a, str> {
107 self.string.as_str().into()
108 }
109}
110
/// A string candidate that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// `id` of the `StringMatchCandidate` that produced this match.
    pub candidate_id: usize,
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate's text.
    pub string: String,
}
118
119impl PartialEq for StringMatch {
120 fn eq(&self, other: &Self) -> bool {
121 self.cmp(other).is_eq()
122 }
123}
124
// Marker impl required by `Ord`; equality itself comes from `PartialEq`.
impl Eq for StringMatch {}
126
127impl PartialOrd for StringMatch {
128 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
129 Some(self.cmp(other))
130 }
131}
132
133impl Ord for StringMatch {
134 fn cmp(&self, other: &Self) -> Ordering {
135 self.score
136 .partial_cmp(&other.score)
137 .unwrap_or(Ordering::Equal)
138 .then_with(|| self.candidate_id.cmp(&other.candidate_id))
139 }
140}
141
142impl PartialEq for PathMatch {
143 fn eq(&self, other: &Self) -> bool {
144 self.cmp(other).is_eq()
145 }
146}
147
// Marker impl required by `Ord`; equality itself comes from `PartialEq`.
impl Eq for PathMatch {}
149
150impl PartialOrd for PathMatch {
151 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152 Some(self.cmp(other))
153 }
154}
155
156impl Ord for PathMatch {
157 fn cmp(&self, other: &Self) -> Ordering {
158 self.score
159 .partial_cmp(&other.score)
160 .unwrap_or(Ordering::Equal)
161 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
162 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
163 }
164}
165
166pub async fn match_strings(
167 candidates: &[StringMatchCandidate],
168 query: &str,
169 smart_case: bool,
170 max_results: usize,
171 cancel_flag: &AtomicBool,
172 background: Arc<executor::Background>,
173) -> Vec<StringMatch> {
174 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
175 let query = query.chars().collect::<Vec<_>>();
176
177 let lowercase_query = &lowercase_query;
178 let query = &query;
179 let query_char_bag = CharBag::from(&lowercase_query[..]);
180
181 let num_cpus = background.num_cpus().min(candidates.len());
182 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
183 let mut segment_results = (0..num_cpus)
184 .map(|_| Vec::with_capacity(max_results))
185 .collect::<Vec<_>>();
186
187 background
188 .scoped(|scope| {
189 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
190 let cancel_flag = &cancel_flag;
191 scope.spawn(async move {
192 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
193 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
194 let mut matcher = Matcher::new(
195 query,
196 lowercase_query,
197 query_char_bag,
198 smart_case,
199 max_results,
200 );
201 matcher.match_strings(
202 &candidates[segment_start..segment_end],
203 results,
204 cancel_flag,
205 );
206 });
207 }
208 })
209 .await;
210
211 let mut results = Vec::new();
212 for segment_result in segment_results {
213 if results.is_empty() {
214 results = segment_result;
215 } else {
216 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
217 }
218 }
219 results
220}
221
222pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
223 candidate_sets: &'a [Set],
224 query: &str,
225 smart_case: bool,
226 max_results: usize,
227 cancel_flag: &AtomicBool,
228 background: Arc<executor::Background>,
229) -> Vec<PathMatch> {
230 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
231 if path_count == 0 {
232 return Vec::new();
233 }
234
235 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
236 let query = query.chars().collect::<Vec<_>>();
237
238 let lowercase_query = &lowercase_query;
239 let query = &query;
240 let query_char_bag = CharBag::from(&lowercase_query[..]);
241
242 let num_cpus = background.num_cpus().min(path_count);
243 let segment_size = (path_count + num_cpus - 1) / num_cpus;
244 let mut segment_results = (0..num_cpus)
245 .map(|_| Vec::with_capacity(max_results))
246 .collect::<Vec<_>>();
247
248 background
249 .scoped(|scope| {
250 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
251 scope.spawn(async move {
252 let segment_start = segment_idx * segment_size;
253 let segment_end = segment_start + segment_size;
254 let mut matcher = Matcher::new(
255 query,
256 lowercase_query,
257 query_char_bag,
258 smart_case,
259 max_results,
260 );
261
262 let mut tree_start = 0;
263 for candidate_set in candidate_sets {
264 let tree_end = tree_start + candidate_set.len();
265
266 if tree_start < segment_end && segment_start < tree_end {
267 let start = cmp::max(tree_start, segment_start) - tree_start;
268 let end = cmp::min(tree_end, segment_end) - tree_start;
269 let candidates = candidate_set.candidates(start).take(end - start);
270
271 matcher.match_paths(
272 candidate_set.id(),
273 candidate_set.prefix(),
274 candidates,
275 results,
276 &cancel_flag,
277 );
278 }
279 if tree_end >= segment_end {
280 break;
281 }
282 tree_start = tree_end;
283 }
284 })
285 }
286 })
287 .await;
288
289 let mut results = Vec::new();
290 for segment_result in segment_results {
291 if results.is_empty() {
292 results = segment_result;
293 } else {
294 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
295 }
296 }
297 results
298}
299
300impl<'a> Matcher<'a> {
301 pub fn new(
302 query: &'a [char],
303 lowercase_query: &'a [char],
304 query_char_bag: CharBag,
305 smart_case: bool,
306 max_results: usize,
307 ) -> Self {
308 Self {
309 query,
310 lowercase_query,
311 query_char_bag,
312 min_score: 0.0,
313 last_positions: vec![0; query.len()],
314 match_positions: vec![0; query.len()],
315 score_matrix: Vec::new(),
316 best_position_matrix: Vec::new(),
317 smart_case,
318 max_results,
319 }
320 }
321
322 pub fn match_strings(
323 &mut self,
324 candidates: &[StringMatchCandidate],
325 results: &mut Vec<StringMatch>,
326 cancel_flag: &AtomicBool,
327 ) {
328 self.match_internal(
329 &[],
330 &[],
331 candidates.iter(),
332 results,
333 cancel_flag,
334 |candidate, score| StringMatch {
335 candidate_id: candidate.id,
336 score,
337 positions: Vec::new(),
338 string: candidate.string.to_string(),
339 },
340 )
341 }
342
343 pub fn match_paths<'c: 'a>(
344 &mut self,
345 tree_id: usize,
346 path_prefix: Arc<str>,
347 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
348 results: &mut Vec<PathMatch>,
349 cancel_flag: &AtomicBool,
350 ) {
351 let prefix = path_prefix.chars().collect::<Vec<_>>();
352 let lowercase_prefix = prefix
353 .iter()
354 .map(|c| c.to_ascii_lowercase())
355 .collect::<Vec<_>>();
356 self.match_internal(
357 &prefix,
358 &lowercase_prefix,
359 path_entries,
360 results,
361 cancel_flag,
362 |candidate, score| PathMatch {
363 score,
364 worktree_id: tree_id,
365 positions: Vec::new(),
366 path: candidate.path.clone(),
367 path_prefix: path_prefix.clone(),
368 },
369 )
370 }
371
372 fn match_internal<C: MatchCandidate, R, F>(
373 &mut self,
374 prefix: &[char],
375 lowercase_prefix: &[char],
376 candidates: impl Iterator<Item = C>,
377 results: &mut Vec<R>,
378 cancel_flag: &AtomicBool,
379 build_match: F,
380 ) where
381 R: Match,
382 F: Fn(&C, f64) -> R,
383 {
384 let mut candidate_chars = Vec::new();
385 let mut lowercase_candidate_chars = Vec::new();
386
387 for candidate in candidates {
388 if !candidate.has_chars(self.query_char_bag) {
389 continue;
390 }
391
392 if cancel_flag.load(atomic::Ordering::Relaxed) {
393 break;
394 }
395
396 candidate_chars.clear();
397 lowercase_candidate_chars.clear();
398 for c in candidate.to_string().chars() {
399 candidate_chars.push(c);
400 lowercase_candidate_chars.push(c.to_ascii_lowercase());
401 }
402
403 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
404 continue;
405 }
406
407 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
408 self.score_matrix.clear();
409 self.score_matrix.resize(matrix_len, None);
410 self.best_position_matrix.clear();
411 self.best_position_matrix.resize(matrix_len, 0);
412
413 let score = self.score_match(
414 &candidate_chars,
415 &lowercase_candidate_chars,
416 &prefix,
417 &lowercase_prefix,
418 );
419
420 if score > 0.0 {
421 let mut mat = build_match(&candidate, score);
422 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
423 if results.len() < self.max_results {
424 mat.set_positions(self.match_positions.clone());
425 results.insert(i, mat);
426 } else if i < results.len() {
427 results.pop();
428 mat.set_positions(self.match_positions.clone());
429 results.insert(i, mat);
430 }
431 if results.len() == self.max_results {
432 self.min_score = results.last().unwrap().score();
433 }
434 }
435 }
436 }
437 }
438
439 fn find_last_positions(
440 &mut self,
441 lowercase_prefix: &[char],
442 lowercase_candidate: &[char],
443 ) -> bool {
444 let mut lowercase_prefix = lowercase_prefix.iter();
445 let mut lowercase_candidate = lowercase_candidate.iter();
446 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
447 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
448 self.last_positions[i] = j + lowercase_prefix.len();
449 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
450 self.last_positions[i] = j;
451 } else {
452 return false;
453 }
454 }
455 true
456 }
457
458 fn score_match(
459 &mut self,
460 path: &[char],
461 path_cased: &[char],
462 prefix: &[char],
463 lowercase_prefix: &[char],
464 ) -> f64 {
465 let score = self.recursive_score_match(
466 path,
467 path_cased,
468 prefix,
469 lowercase_prefix,
470 0,
471 0,
472 self.query.len() as f64,
473 ) * self.query.len() as f64;
474
475 if score <= 0.0 {
476 return 0.0;
477 }
478
479 let path_len = prefix.len() + path.len();
480 let mut cur_start = 0;
481 let mut byte_ix = 0;
482 let mut char_ix = 0;
483 for i in 0..self.query.len() {
484 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
485 while char_ix < match_char_ix {
486 let ch = prefix
487 .get(char_ix)
488 .or_else(|| path.get(char_ix - prefix.len()))
489 .unwrap();
490 byte_ix += ch.len_utf8();
491 char_ix += 1;
492 }
493 cur_start = match_char_ix + 1;
494 self.match_positions[i] = byte_ix;
495 }
496
497 score
498 }
499
500 fn recursive_score_match(
501 &mut self,
502 path: &[char],
503 path_cased: &[char],
504 prefix: &[char],
505 lowercase_prefix: &[char],
506 query_idx: usize,
507 path_idx: usize,
508 cur_score: f64,
509 ) -> f64 {
510 if query_idx == self.query.len() {
511 return 1.0;
512 }
513
514 let path_len = prefix.len() + path.len();
515
516 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
517 return memoized;
518 }
519
520 let mut score = 0.0;
521 let mut best_position = 0;
522
523 let query_char = self.lowercase_query[query_idx];
524 let limit = self.last_positions[query_idx];
525
526 let mut last_slash = 0;
527 for j in path_idx..=limit {
528 let path_char = if j < prefix.len() {
529 lowercase_prefix[j]
530 } else {
531 path_cased[j - prefix.len()]
532 };
533 let is_path_sep = path_char == '/' || path_char == '\\';
534
535 if query_idx == 0 && is_path_sep {
536 last_slash = j;
537 }
538
539 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
540 let curr = if j < prefix.len() {
541 prefix[j]
542 } else {
543 path[j - prefix.len()]
544 };
545
546 let mut char_score = 1.0;
547 if j > path_idx {
548 let last = if j - 1 < prefix.len() {
549 prefix[j - 1]
550 } else {
551 path[j - 1 - prefix.len()]
552 };
553
554 if last == '/' {
555 char_score = 0.9;
556 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
557 char_score = 0.8;
558 } else if last.is_lowercase() && curr.is_uppercase() {
559 char_score = 0.8;
560 } else if last == '.' {
561 char_score = 0.7;
562 } else if query_idx == 0 {
563 char_score = BASE_DISTANCE_PENALTY;
564 } else {
565 char_score = MIN_DISTANCE_PENALTY.max(
566 BASE_DISTANCE_PENALTY
567 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
568 );
569 }
570 }
571
572 // Apply a severe penalty if the case doesn't match.
573 // This will make the exact matches have higher score than the case-insensitive and the
574 // path insensitive matches.
575 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
576 char_score *= 0.001;
577 }
578
579 let mut multiplier = char_score;
580
581 // Scale the score based on how deep within the path we found the match.
582 if query_idx == 0 {
583 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
584 }
585
586 let mut next_score = 1.0;
587 if self.min_score > 0.0 {
588 next_score = cur_score * multiplier;
589 // Scores only decrease. If we can't pass the previous best, bail
590 if next_score < self.min_score {
591 // Ensure that score is non-zero so we use it in the memo table.
592 if score == 0.0 {
593 score = 1e-18;
594 }
595 continue;
596 }
597 }
598
599 let new_score = self.recursive_score_match(
600 path,
601 path_cased,
602 prefix,
603 lowercase_prefix,
604 query_idx + 1,
605 j + 1,
606 next_score,
607 ) * multiplier;
608
609 if new_score > score {
610 score = new_score;
611 best_position = j;
612 // Optimization: can't score better than 1.
613 if new_score == 1.0 {
614 break;
615 }
616 }
617 }
618 }
619
620 if best_position != 0 {
621 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
622 }
623
624 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
625 score
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Checks the right-most match positions computed over prefix + candidate.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs before 'd', so the query cannot match in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Checks result order and matched positions for ASCII paths.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Checks that reported positions are byte offsets, not character indices.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns each
    /// matched path with its match positions, best match first.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        // The candidates borrow their paths, so the `Arc`s must outlive them.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, arc)| {
                let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
                PathMatchCandidate {
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: arc,
                }
            })
            .collect::<Vec<_>>();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each match back to the original `&str` it was built from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}