1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Score multiplier for a matched character that does not follow a word
/// boundary. For the first query character it is used as-is; for later ones
/// it is the starting value of the distance-based penalty.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// How much the distance-based multiplier drops for every candidate character
/// skipped since the previous matched query character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the distance-based multiplier, however far apart matches are.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
18pub struct Matcher<'a> {
19 query: &'a [char],
20 lowercase_query: &'a [char],
21 query_char_bag: CharBag,
22 smart_case: bool,
23 max_results: usize,
24 min_score: f64,
25 match_positions: Vec<usize>,
26 last_positions: Vec<usize>,
27 score_matrix: Vec<Option<f64>>,
28 best_position_matrix: Vec<usize>,
29}
30
/// A scored match that can be ranked and can receive its match positions
/// after it has been accepted into a results list.
trait Match: Ord {
    /// The match's fuzzy score; larger is better.
    fn score(&self) -> f64;

    /// Stores the positions of the matched query characters.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
36trait MatchCandidate {
37 fn has_chars(&self, bag: CharBag) -> bool;
38 fn to_string<'a>(&'a self) -> Cow<'a, str>;
39}
40
41#[derive(Clone, Debug)]
42pub struct PathMatchCandidate<'a> {
43 pub path: &'a Arc<Path>,
44 pub char_bag: CharBag,
45}
46
/// A successful fuzzy match of a query against a path.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Fuzzy score; larger is better.
    pub score: f64,
    /// Byte offsets of the matched query characters within
    /// `path_prefix` followed by `path`.
    pub positions: Vec<usize>,
    /// Id of the candidate set (worktree) the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set, included when scoring.
    pub path_prefix: Arc<str>,
}
55
56#[derive(Clone, Debug)]
57pub struct StringMatchCandidate {
58 pub string: String,
59 pub char_bag: CharBag,
60}
61
62pub trait PathMatchCandidateSet<'a>: Send + Sync {
63 type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
64 fn id(&self) -> usize;
65 fn len(&self) -> usize;
66 fn prefix(&self) -> Arc<str>;
67 fn candidates(&'a self, start: usize) -> Self::Candidates;
68}
69
70impl Match for PathMatch {
71 fn score(&self) -> f64 {
72 self.score
73 }
74
75 fn set_positions(&mut self, positions: Vec<usize>) {
76 self.positions = positions;
77 }
78}
79
80impl Match for StringMatch {
81 fn score(&self) -> f64 {
82 self.score
83 }
84
85 fn set_positions(&mut self, positions: Vec<usize>) {
86 self.positions = positions;
87 }
88}
89
90impl<'a> MatchCandidate for PathMatchCandidate<'a> {
91 fn has_chars(&self, bag: CharBag) -> bool {
92 self.char_bag.is_superset(bag)
93 }
94
95 fn to_string(&self) -> Cow<'a, str> {
96 self.path.to_string_lossy()
97 }
98}
99
100impl<'a> MatchCandidate for &'a StringMatchCandidate {
101 fn has_chars(&self, bag: CharBag) -> bool {
102 self.char_bag.is_superset(bag)
103 }
104
105 fn to_string(&self) -> Cow<'a, str> {
106 self.string.as_str().into()
107 }
108}
109
/// A successful fuzzy match of a query against a `StringMatchCandidate`.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Index of the matched candidate in the slice given to the matcher.
    pub candidate_index: usize,
    /// Fuzzy score; larger is better.
    pub score: f64,
    /// Byte offsets into `string` of the matched query characters.
    pub positions: Vec<usize>,
    /// Copy of the candidate's string.
    pub string: String,
}

/// Equality is derived from [`Ord`] so that `a == b` exactly when
/// `a.cmp(b) == Ordering::Equal`, as the `Eq`/`Ord` contracts require. This
/// also keeps `==` reflexive when a score is NaN.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders by score, then string, then candidate index.
///
/// Incomparable (NaN) scores are treated as equal. The final
/// `candidate_index` tie-break keeps two candidates with identical text and
/// score distinct. Without it they compare `Equal`, and the binary-search
/// insertion in the matcher silently drops the second one.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
            .then_with(|| self.candidate_index.cmp(&other.candidate_index))
    }
}
140
141impl PartialEq for PathMatch {
142 fn eq(&self, other: &Self) -> bool {
143 self.score.eq(&other.score)
144 }
145}
146
147impl Eq for PathMatch {}
148
149impl PartialOrd for PathMatch {
150 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
151 Some(self.cmp(other))
152 }
153}
154
155impl Ord for PathMatch {
156 fn cmp(&self, other: &Self) -> Ordering {
157 self.score
158 .partial_cmp(&other.score)
159 .unwrap_or(Ordering::Equal)
160 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
161 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
162 }
163}
164
165pub async fn match_strings(
166 candidates: &[StringMatchCandidate],
167 query: &str,
168 smart_case: bool,
169 max_results: usize,
170 cancel_flag: &AtomicBool,
171 background: Arc<executor::Background>,
172) -> Vec<StringMatch> {
173 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
174 let query = query.chars().collect::<Vec<_>>();
175
176 let lowercase_query = &lowercase_query;
177 let query = &query;
178 let query_char_bag = CharBag::from(&lowercase_query[..]);
179
180 let num_cpus = background.num_cpus().min(candidates.len());
181 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
182 let mut segment_results = (0..num_cpus)
183 .map(|_| Vec::with_capacity(max_results))
184 .collect::<Vec<_>>();
185
186 background
187 .scoped(|scope| {
188 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
189 let cancel_flag = &cancel_flag;
190 scope.spawn(async move {
191 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
192 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
193 let mut matcher = Matcher::new(
194 query,
195 lowercase_query,
196 query_char_bag,
197 smart_case,
198 max_results,
199 );
200 matcher.match_strings(
201 segment_start,
202 &candidates[segment_start..segment_end],
203 results,
204 cancel_flag,
205 );
206 });
207 }
208 })
209 .await;
210
211 let mut results = Vec::new();
212 for segment_result in segment_results {
213 if results.is_empty() {
214 results = segment_result;
215 } else {
216 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
217 }
218 }
219 results
220}
221
222pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
223 candidate_sets: &'a [Set],
224 query: &str,
225 smart_case: bool,
226 max_results: usize,
227 cancel_flag: &AtomicBool,
228 background: Arc<executor::Background>,
229) -> Vec<PathMatch> {
230 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
231 if path_count == 0 {
232 return Vec::new();
233 }
234
235 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
236 let query = query.chars().collect::<Vec<_>>();
237
238 let lowercase_query = &lowercase_query;
239 let query = &query;
240 let query_char_bag = CharBag::from(&lowercase_query[..]);
241
242 let num_cpus = background.num_cpus().min(path_count);
243 let segment_size = (path_count + num_cpus - 1) / num_cpus;
244 let mut segment_results = (0..num_cpus)
245 .map(|_| Vec::with_capacity(max_results))
246 .collect::<Vec<_>>();
247
248 background
249 .scoped(|scope| {
250 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
251 scope.spawn(async move {
252 let segment_start = segment_idx * segment_size;
253 let segment_end = segment_start + segment_size;
254 let mut matcher = Matcher::new(
255 query,
256 lowercase_query,
257 query_char_bag,
258 smart_case,
259 max_results,
260 );
261
262 let mut tree_start = 0;
263 for candidate_set in candidate_sets {
264 let tree_end = tree_start + candidate_set.len();
265
266 if tree_start < segment_end && segment_start < tree_end {
267 let start = cmp::max(tree_start, segment_start) - tree_start;
268 let end = cmp::min(tree_end, segment_end) - tree_start;
269 let candidates = candidate_set.candidates(start).take(end - start);
270
271 matcher.match_paths(
272 candidate_set.id(),
273 candidate_set.prefix(),
274 candidates,
275 results,
276 &cancel_flag,
277 );
278 }
279 if tree_end >= segment_end {
280 break;
281 }
282 tree_start = tree_end;
283 }
284 })
285 }
286 })
287 .await;
288
289 let mut results = Vec::new();
290 for segment_result in segment_results {
291 if results.is_empty() {
292 results = segment_result;
293 } else {
294 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
295 }
296 }
297 results
298}
299
300impl<'a> Matcher<'a> {
301 pub fn new(
302 query: &'a [char],
303 lowercase_query: &'a [char],
304 query_char_bag: CharBag,
305 smart_case: bool,
306 max_results: usize,
307 ) -> Self {
308 Self {
309 query,
310 lowercase_query,
311 query_char_bag,
312 min_score: 0.0,
313 last_positions: vec![0; query.len()],
314 match_positions: vec![0; query.len()],
315 score_matrix: Vec::new(),
316 best_position_matrix: Vec::new(),
317 smart_case,
318 max_results,
319 }
320 }
321
322 pub fn match_strings(
323 &mut self,
324 start_index: usize,
325 candidates: &[StringMatchCandidate],
326 results: &mut Vec<StringMatch>,
327 cancel_flag: &AtomicBool,
328 ) {
329 self.match_internal(
330 &[],
331 &[],
332 start_index,
333 candidates.iter(),
334 results,
335 cancel_flag,
336 |candidate_index, candidate, score| StringMatch {
337 candidate_index,
338 score,
339 positions: Vec::new(),
340 string: candidate.string.to_string(),
341 },
342 )
343 }
344
345 pub fn match_paths<'c: 'a>(
346 &mut self,
347 tree_id: usize,
348 path_prefix: Arc<str>,
349 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
350 results: &mut Vec<PathMatch>,
351 cancel_flag: &AtomicBool,
352 ) {
353 let prefix = path_prefix.chars().collect::<Vec<_>>();
354 let lowercase_prefix = prefix
355 .iter()
356 .map(|c| c.to_ascii_lowercase())
357 .collect::<Vec<_>>();
358 self.match_internal(
359 &prefix,
360 &lowercase_prefix,
361 0,
362 path_entries,
363 results,
364 cancel_flag,
365 |_, candidate, score| PathMatch {
366 score,
367 worktree_id: tree_id,
368 positions: Vec::new(),
369 path: candidate.path.clone(),
370 path_prefix: path_prefix.clone(),
371 },
372 )
373 }
374
375 fn match_internal<C: MatchCandidate, R, F>(
376 &mut self,
377 prefix: &[char],
378 lowercase_prefix: &[char],
379 start_index: usize,
380 candidates: impl Iterator<Item = C>,
381 results: &mut Vec<R>,
382 cancel_flag: &AtomicBool,
383 build_match: F,
384 ) where
385 R: Match,
386 F: Fn(usize, &C, f64) -> R,
387 {
388 let mut candidate_chars = Vec::new();
389 let mut lowercase_candidate_chars = Vec::new();
390
391 for (candidate_ix, candidate) in candidates.enumerate() {
392 if !candidate.has_chars(self.query_char_bag) {
393 continue;
394 }
395
396 if cancel_flag.load(atomic::Ordering::Relaxed) {
397 break;
398 }
399
400 candidate_chars.clear();
401 lowercase_candidate_chars.clear();
402 for c in candidate.to_string().chars() {
403 candidate_chars.push(c);
404 lowercase_candidate_chars.push(c.to_ascii_lowercase());
405 }
406
407 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
408 continue;
409 }
410
411 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
412 self.score_matrix.clear();
413 self.score_matrix.resize(matrix_len, None);
414 self.best_position_matrix.clear();
415 self.best_position_matrix.resize(matrix_len, 0);
416
417 let score = self.score_match(
418 &candidate_chars,
419 &lowercase_candidate_chars,
420 &prefix,
421 &lowercase_prefix,
422 );
423
424 if score > 0.0 {
425 let mut mat = build_match(start_index + candidate_ix, &candidate, score);
426 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
427 if results.len() < self.max_results {
428 mat.set_positions(self.match_positions.clone());
429 results.insert(i, mat);
430 } else if i < results.len() {
431 results.pop();
432 mat.set_positions(self.match_positions.clone());
433 results.insert(i, mat);
434 }
435 if results.len() == self.max_results {
436 self.min_score = results.last().unwrap().score();
437 }
438 }
439 }
440 }
441 }
442
443 fn find_last_positions(
444 &mut self,
445 lowercase_prefix: &[char],
446 lowercase_candidate: &[char],
447 ) -> bool {
448 let mut lowercase_prefix = lowercase_prefix.iter();
449 let mut lowercase_candidate = lowercase_candidate.iter();
450 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
451 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
452 self.last_positions[i] = j + lowercase_prefix.len();
453 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
454 self.last_positions[i] = j;
455 } else {
456 return false;
457 }
458 }
459 true
460 }
461
462 fn score_match(
463 &mut self,
464 path: &[char],
465 path_cased: &[char],
466 prefix: &[char],
467 lowercase_prefix: &[char],
468 ) -> f64 {
469 let score = self.recursive_score_match(
470 path,
471 path_cased,
472 prefix,
473 lowercase_prefix,
474 0,
475 0,
476 self.query.len() as f64,
477 ) * self.query.len() as f64;
478
479 if score <= 0.0 {
480 return 0.0;
481 }
482
483 let path_len = prefix.len() + path.len();
484 let mut cur_start = 0;
485 let mut byte_ix = 0;
486 let mut char_ix = 0;
487 for i in 0..self.query.len() {
488 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
489 while char_ix < match_char_ix {
490 let ch = prefix
491 .get(char_ix)
492 .or_else(|| path.get(char_ix - prefix.len()))
493 .unwrap();
494 byte_ix += ch.len_utf8();
495 char_ix += 1;
496 }
497 cur_start = match_char_ix + 1;
498 self.match_positions[i] = byte_ix;
499 }
500
501 score
502 }
503
504 fn recursive_score_match(
505 &mut self,
506 path: &[char],
507 path_cased: &[char],
508 prefix: &[char],
509 lowercase_prefix: &[char],
510 query_idx: usize,
511 path_idx: usize,
512 cur_score: f64,
513 ) -> f64 {
514 if query_idx == self.query.len() {
515 return 1.0;
516 }
517
518 let path_len = prefix.len() + path.len();
519
520 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
521 return memoized;
522 }
523
524 let mut score = 0.0;
525 let mut best_position = 0;
526
527 let query_char = self.lowercase_query[query_idx];
528 let limit = self.last_positions[query_idx];
529
530 let mut last_slash = 0;
531 for j in path_idx..=limit {
532 let path_char = if j < prefix.len() {
533 lowercase_prefix[j]
534 } else {
535 path_cased[j - prefix.len()]
536 };
537 let is_path_sep = path_char == '/' || path_char == '\\';
538
539 if query_idx == 0 && is_path_sep {
540 last_slash = j;
541 }
542
543 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
544 let curr = if j < prefix.len() {
545 prefix[j]
546 } else {
547 path[j - prefix.len()]
548 };
549
550 let mut char_score = 1.0;
551 if j > path_idx {
552 let last = if j - 1 < prefix.len() {
553 prefix[j - 1]
554 } else {
555 path[j - 1 - prefix.len()]
556 };
557
558 if last == '/' {
559 char_score = 0.9;
560 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
561 char_score = 0.8;
562 } else if last.is_lowercase() && curr.is_uppercase() {
563 char_score = 0.8;
564 } else if last == '.' {
565 char_score = 0.7;
566 } else if query_idx == 0 {
567 char_score = BASE_DISTANCE_PENALTY;
568 } else {
569 char_score = MIN_DISTANCE_PENALTY.max(
570 BASE_DISTANCE_PENALTY
571 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
572 );
573 }
574 }
575
576 // Apply a severe penalty if the case doesn't match.
577 // This will make the exact matches have higher score than the case-insensitive and the
578 // path insensitive matches.
579 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
580 char_score *= 0.001;
581 }
582
583 let mut multiplier = char_score;
584
585 // Scale the score based on how deep within the path we found the match.
586 if query_idx == 0 {
587 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
588 }
589
590 let mut next_score = 1.0;
591 if self.min_score > 0.0 {
592 next_score = cur_score * multiplier;
593 // Scores only decrease. If we can't pass the previous best, bail
594 if next_score < self.min_score {
595 // Ensure that score is non-zero so we use it in the memo table.
596 if score == 0.0 {
597 score = 1e-18;
598 }
599 continue;
600 }
601 }
602
603 let new_score = self.recursive_score_match(
604 path,
605 path_cased,
606 prefix,
607 lowercase_prefix,
608 query_idx + 1,
609 j + 1,
610 next_score,
611 ) * multiplier;
612
613 if new_score > score {
614 score = new_score;
615 best_position = j;
616 // Optimization: can't score better than 1.
617 if new_score == 1.0 {
618 break;
619 }
620 }
621 }
622 }
623
624 if best_position != 0 {
625 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
626 }
627
628 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
629 score
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_get_last_positions() {
        // The only 'c' precedes the only 'd', so "dc" cannot be placed.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(!matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));

        // 'd' is found in the candidate (offset by the 3 prefix chars) and 'c'
        // falls back to the prefix.
        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        assert!(matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']));
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` (worktree 0, empty prefix) and returns each
    /// matched path with its match positions, in ranked order.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, arc)| {
                let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
                PathMatchCandidate {
                    path: arc,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                }
            })
            .collect::<Vec<_>>();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        let mut matched = Vec::with_capacity(results.len());
        for result in results {
            let original = paths
                .iter()
                .copied()
                .find(|p| result.path.as_ref() == Path::new(p))
                .unwrap();
            matched.push((original, result.positions));
        }
        matched
    }
}