1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Multiplier applied to a matched character that is not adjacent to the
/// previous match and not preceded by a word boundary. For the first query
/// character this is used as-is; later characters start from it and decay.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// How much the distance multiplier shrinks for every additional skipped
/// character between two consecutive matches.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance multiplier, however far apart matches are.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
18pub struct Matcher<'a> {
19 query: &'a [char],
20 lowercase_query: &'a [char],
21 query_char_bag: CharBag,
22 smart_case: bool,
23 max_results: usize,
24 min_score: f64,
25 match_positions: Vec<usize>,
26 last_positions: Vec<usize>,
27 score_matrix: Vec<Option<f64>>,
28 best_position_matrix: Vec<usize>,
29}
30
31trait Match: Ord {
32 fn score(&self) -> f64;
33 fn set_positions(&mut self, positions: Vec<usize>);
34}
35
36trait MatchCandidate {
37 fn has_chars(&self, bag: CharBag) -> bool;
38 fn to_string<'a>(&'a self) -> Cow<'a, str>;
39}
40
41#[derive(Clone, Debug)]
42pub struct PathMatchCandidate<'a> {
43 pub path: &'a Arc<Path>,
44 pub char_bag: CharBag,
45}
46
/// A path that matched the query, with its score and match positions.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher ranks first.
    pub score: f64,
    /// Byte offsets of the matched characters within `path_prefix` followed
    /// by the path's string form.
    pub positions: Vec<usize>,
    /// Id of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix that was matched in front of `path`.
    pub path_prefix: Arc<str>,
}
55
56#[derive(Clone, Debug)]
57pub struct StringMatchCandidate {
58 pub id: usize,
59 pub string: String,
60 pub char_bag: CharBag,
61}
62
63pub trait PathMatchCandidateSet<'a>: Send + Sync {
64 type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
65 fn id(&self) -> usize;
66 fn len(&self) -> usize;
67 fn prefix(&self) -> Arc<str>;
68 fn candidates(&'a self, start: usize) -> Self::Candidates;
69}
70
71impl Match for PathMatch {
72 fn score(&self) -> f64 {
73 self.score
74 }
75
76 fn set_positions(&mut self, positions: Vec<usize>) {
77 self.positions = positions;
78 }
79}
80
81impl Match for StringMatch {
82 fn score(&self) -> f64 {
83 self.score
84 }
85
86 fn set_positions(&mut self, positions: Vec<usize>) {
87 self.positions = positions;
88 }
89}
90
91impl<'a> MatchCandidate for PathMatchCandidate<'a> {
92 fn has_chars(&self, bag: CharBag) -> bool {
93 self.char_bag.is_superset(bag)
94 }
95
96 fn to_string(&self) -> Cow<'a, str> {
97 self.path.to_string_lossy()
98 }
99}
100
101impl StringMatchCandidate {
102 pub fn new(id: usize, string: String) -> Self {
103 Self {
104 id,
105 char_bag: CharBag::from(string.as_str()),
106 string,
107 }
108 }
109}
110
111impl<'a> MatchCandidate for &'a StringMatchCandidate {
112 fn has_chars(&self, bag: CharBag) -> bool {
113 self.char_bag.is_superset(bag)
114 }
115
116 fn to_string(&self) -> Cow<'a, str> {
117 self.string.as_str().into()
118 }
119}
120
/// A string candidate that matched the query.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the matched candidate.
    pub candidate_id: usize,
    /// Match score; higher ranks first.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// Copy of the candidate's text.
    pub string: String,
}

/// Equality follows `Ord`: same score and same candidate id.
impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        Ord::cmp(self, other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders by score, then by candidate id. Incomparable scores (NaN) are
/// treated as equal so the id decides.
impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_score = self
            .score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal);
        let by_id = self.candidate_id.cmp(&other.candidate_id);
        by_score.then(by_id)
    }
}
151
152impl PartialEq for PathMatch {
153 fn eq(&self, other: &Self) -> bool {
154 self.cmp(other).is_eq()
155 }
156}
157
158impl Eq for PathMatch {}
159
160impl PartialOrd for PathMatch {
161 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
162 Some(self.cmp(other))
163 }
164}
165
166impl Ord for PathMatch {
167 fn cmp(&self, other: &Self) -> Ordering {
168 self.score
169 .partial_cmp(&other.score)
170 .unwrap_or(Ordering::Equal)
171 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
172 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
173 }
174}
175
176pub async fn match_strings(
177 candidates: &[StringMatchCandidate],
178 query: &str,
179 smart_case: bool,
180 max_results: usize,
181 cancel_flag: &AtomicBool,
182 background: Arc<executor::Background>,
183) -> Vec<StringMatch> {
184 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
185 let query = query.chars().collect::<Vec<_>>();
186
187 let lowercase_query = &lowercase_query;
188 let query = &query;
189 let query_char_bag = CharBag::from(&lowercase_query[..]);
190
191 let num_cpus = background.num_cpus().min(candidates.len());
192 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
193 let mut segment_results = (0..num_cpus)
194 .map(|_| Vec::with_capacity(max_results))
195 .collect::<Vec<_>>();
196
197 background
198 .scoped(|scope| {
199 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
200 let cancel_flag = &cancel_flag;
201 scope.spawn(async move {
202 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
203 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
204 let mut matcher = Matcher::new(
205 query,
206 lowercase_query,
207 query_char_bag,
208 smart_case,
209 max_results,
210 );
211 matcher.match_strings(
212 &candidates[segment_start..segment_end],
213 results,
214 cancel_flag,
215 );
216 });
217 }
218 })
219 .await;
220
221 let mut results = Vec::new();
222 for segment_result in segment_results {
223 if results.is_empty() {
224 results = segment_result;
225 } else {
226 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
227 }
228 }
229 results
230}
231
232pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
233 candidate_sets: &'a [Set],
234 query: &str,
235 smart_case: bool,
236 max_results: usize,
237 cancel_flag: &AtomicBool,
238 background: Arc<executor::Background>,
239) -> Vec<PathMatch> {
240 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
241 if path_count == 0 {
242 return Vec::new();
243 }
244
245 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
246 let query = query.chars().collect::<Vec<_>>();
247
248 let lowercase_query = &lowercase_query;
249 let query = &query;
250 let query_char_bag = CharBag::from(&lowercase_query[..]);
251
252 let num_cpus = background.num_cpus().min(path_count);
253 let segment_size = (path_count + num_cpus - 1) / num_cpus;
254 let mut segment_results = (0..num_cpus)
255 .map(|_| Vec::with_capacity(max_results))
256 .collect::<Vec<_>>();
257
258 background
259 .scoped(|scope| {
260 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
261 scope.spawn(async move {
262 let segment_start = segment_idx * segment_size;
263 let segment_end = segment_start + segment_size;
264 let mut matcher = Matcher::new(
265 query,
266 lowercase_query,
267 query_char_bag,
268 smart_case,
269 max_results,
270 );
271
272 let mut tree_start = 0;
273 for candidate_set in candidate_sets {
274 let tree_end = tree_start + candidate_set.len();
275
276 if tree_start < segment_end && segment_start < tree_end {
277 let start = cmp::max(tree_start, segment_start) - tree_start;
278 let end = cmp::min(tree_end, segment_end) - tree_start;
279 let candidates = candidate_set.candidates(start).take(end - start);
280
281 matcher.match_paths(
282 candidate_set.id(),
283 candidate_set.prefix(),
284 candidates,
285 results,
286 &cancel_flag,
287 );
288 }
289 if tree_end >= segment_end {
290 break;
291 }
292 tree_start = tree_end;
293 }
294 })
295 }
296 })
297 .await;
298
299 let mut results = Vec::new();
300 for segment_result in segment_results {
301 if results.is_empty() {
302 results = segment_result;
303 } else {
304 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
305 }
306 }
307 results
308}
309
310impl<'a> Matcher<'a> {
311 pub fn new(
312 query: &'a [char],
313 lowercase_query: &'a [char],
314 query_char_bag: CharBag,
315 smart_case: bool,
316 max_results: usize,
317 ) -> Self {
318 Self {
319 query,
320 lowercase_query,
321 query_char_bag,
322 min_score: 0.0,
323 last_positions: vec![0; query.len()],
324 match_positions: vec![0; query.len()],
325 score_matrix: Vec::new(),
326 best_position_matrix: Vec::new(),
327 smart_case,
328 max_results,
329 }
330 }
331
332 pub fn match_strings(
333 &mut self,
334 candidates: &[StringMatchCandidate],
335 results: &mut Vec<StringMatch>,
336 cancel_flag: &AtomicBool,
337 ) {
338 self.match_internal(
339 &[],
340 &[],
341 candidates.iter(),
342 results,
343 cancel_flag,
344 |candidate, score| StringMatch {
345 candidate_id: candidate.id,
346 score,
347 positions: Vec::new(),
348 string: candidate.string.to_string(),
349 },
350 )
351 }
352
353 pub fn match_paths<'c: 'a>(
354 &mut self,
355 tree_id: usize,
356 path_prefix: Arc<str>,
357 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
358 results: &mut Vec<PathMatch>,
359 cancel_flag: &AtomicBool,
360 ) {
361 let prefix = path_prefix.chars().collect::<Vec<_>>();
362 let lowercase_prefix = prefix
363 .iter()
364 .map(|c| c.to_ascii_lowercase())
365 .collect::<Vec<_>>();
366 self.match_internal(
367 &prefix,
368 &lowercase_prefix,
369 path_entries,
370 results,
371 cancel_flag,
372 |candidate, score| PathMatch {
373 score,
374 worktree_id: tree_id,
375 positions: Vec::new(),
376 path: candidate.path.clone(),
377 path_prefix: path_prefix.clone(),
378 },
379 )
380 }
381
382 fn match_internal<C: MatchCandidate, R, F>(
383 &mut self,
384 prefix: &[char],
385 lowercase_prefix: &[char],
386 candidates: impl Iterator<Item = C>,
387 results: &mut Vec<R>,
388 cancel_flag: &AtomicBool,
389 build_match: F,
390 ) where
391 R: Match,
392 F: Fn(&C, f64) -> R,
393 {
394 let mut candidate_chars = Vec::new();
395 let mut lowercase_candidate_chars = Vec::new();
396
397 for candidate in candidates {
398 if !candidate.has_chars(self.query_char_bag) {
399 continue;
400 }
401
402 if cancel_flag.load(atomic::Ordering::Relaxed) {
403 break;
404 }
405
406 candidate_chars.clear();
407 lowercase_candidate_chars.clear();
408 for c in candidate.to_string().chars() {
409 candidate_chars.push(c);
410 lowercase_candidate_chars.push(c.to_ascii_lowercase());
411 }
412
413 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
414 continue;
415 }
416
417 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
418 self.score_matrix.clear();
419 self.score_matrix.resize(matrix_len, None);
420 self.best_position_matrix.clear();
421 self.best_position_matrix.resize(matrix_len, 0);
422
423 let score = self.score_match(
424 &candidate_chars,
425 &lowercase_candidate_chars,
426 &prefix,
427 &lowercase_prefix,
428 );
429
430 if score > 0.0 {
431 let mut mat = build_match(&candidate, score);
432 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
433 if results.len() < self.max_results {
434 mat.set_positions(self.match_positions.clone());
435 results.insert(i, mat);
436 } else if i < results.len() {
437 results.pop();
438 mat.set_positions(self.match_positions.clone());
439 results.insert(i, mat);
440 }
441 if results.len() == self.max_results {
442 self.min_score = results.last().unwrap().score();
443 }
444 }
445 }
446 }
447 }
448
449 fn find_last_positions(
450 &mut self,
451 lowercase_prefix: &[char],
452 lowercase_candidate: &[char],
453 ) -> bool {
454 let mut lowercase_prefix = lowercase_prefix.iter();
455 let mut lowercase_candidate = lowercase_candidate.iter();
456 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
457 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
458 self.last_positions[i] = j + lowercase_prefix.len();
459 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
460 self.last_positions[i] = j;
461 } else {
462 return false;
463 }
464 }
465 true
466 }
467
468 fn score_match(
469 &mut self,
470 path: &[char],
471 path_cased: &[char],
472 prefix: &[char],
473 lowercase_prefix: &[char],
474 ) -> f64 {
475 let score = self.recursive_score_match(
476 path,
477 path_cased,
478 prefix,
479 lowercase_prefix,
480 0,
481 0,
482 self.query.len() as f64,
483 ) * self.query.len() as f64;
484
485 if score <= 0.0 {
486 return 0.0;
487 }
488
489 let path_len = prefix.len() + path.len();
490 let mut cur_start = 0;
491 let mut byte_ix = 0;
492 let mut char_ix = 0;
493 for i in 0..self.query.len() {
494 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
495 while char_ix < match_char_ix {
496 let ch = prefix
497 .get(char_ix)
498 .or_else(|| path.get(char_ix - prefix.len()))
499 .unwrap();
500 byte_ix += ch.len_utf8();
501 char_ix += 1;
502 }
503 cur_start = match_char_ix + 1;
504 self.match_positions[i] = byte_ix;
505 }
506
507 score
508 }
509
510 fn recursive_score_match(
511 &mut self,
512 path: &[char],
513 path_cased: &[char],
514 prefix: &[char],
515 lowercase_prefix: &[char],
516 query_idx: usize,
517 path_idx: usize,
518 cur_score: f64,
519 ) -> f64 {
520 if query_idx == self.query.len() {
521 return 1.0;
522 }
523
524 let path_len = prefix.len() + path.len();
525
526 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
527 return memoized;
528 }
529
530 let mut score = 0.0;
531 let mut best_position = 0;
532
533 let query_char = self.lowercase_query[query_idx];
534 let limit = self.last_positions[query_idx];
535
536 let mut last_slash = 0;
537 for j in path_idx..=limit {
538 let path_char = if j < prefix.len() {
539 lowercase_prefix[j]
540 } else {
541 path_cased[j - prefix.len()]
542 };
543 let is_path_sep = path_char == '/' || path_char == '\\';
544
545 if query_idx == 0 && is_path_sep {
546 last_slash = j;
547 }
548
549 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
550 let curr = if j < prefix.len() {
551 prefix[j]
552 } else {
553 path[j - prefix.len()]
554 };
555
556 let mut char_score = 1.0;
557 if j > path_idx {
558 let last = if j - 1 < prefix.len() {
559 prefix[j - 1]
560 } else {
561 path[j - 1 - prefix.len()]
562 };
563
564 if last == '/' {
565 char_score = 0.9;
566 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
567 char_score = 0.8;
568 } else if last.is_lowercase() && curr.is_uppercase() {
569 char_score = 0.8;
570 } else if last == '.' {
571 char_score = 0.7;
572 } else if query_idx == 0 {
573 char_score = BASE_DISTANCE_PENALTY;
574 } else {
575 char_score = MIN_DISTANCE_PENALTY.max(
576 BASE_DISTANCE_PENALTY
577 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
578 );
579 }
580 }
581
582 // Apply a severe penalty if the case doesn't match.
583 // This will make the exact matches have higher score than the case-insensitive and the
584 // path insensitive matches.
585 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
586 char_score *= 0.001;
587 }
588
589 let mut multiplier = char_score;
590
591 // Scale the score based on how deep within the path we found the match.
592 if query_idx == 0 {
593 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
594 }
595
596 let mut next_score = 1.0;
597 if self.min_score > 0.0 {
598 next_score = cur_score * multiplier;
599 // Scores only decrease. If we can't pass the previous best, bail
600 if next_score < self.min_score {
601 // Ensure that score is non-zero so we use it in the memo table.
602 if score == 0.0 {
603 score = 1e-18;
604 }
605 continue;
606 }
607 }
608
609 let new_score = self.recursive_score_match(
610 path,
611 path_cased,
612 prefix,
613 lowercase_prefix,
614 query_idx + 1,
615 j + 1,
616 next_score,
617 ) * multiplier;
618
619 if new_score > score {
620 score = new_score;
621 best_position = j;
622 // Optimization: can't score better than 1.
623 if new_score == 1.0 {
624 break;
625 }
626 }
627 }
628 }
629
630 if best_position != 0 {
631 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
632 }
633
634 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
635 score
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, before any 'd', so "dc" cannot match in order.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let found = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!found);

        // Positions count prefix characters first: 'd' is candidate index 1,
        // i.e. 3 + 1 = 4 overall.
        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let found = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(found);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let found =
            matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(found);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let expected_abc = vec![
            ("abC", vec![0, 1, 2]),
            ("abcd", vec![0, 1, 2]),
            ("AlphaBravoCharlie", vec![0, 5, 10]),
            ("alphabravocharlie", vec![4, 5, 10]),
        ];
        assert_eq!(match_query("abc", false, &paths), expected_abc);

        let expected_separators = vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])];
        assert_eq!(match_query("t/i/a/t/d", false, &paths), expected_separators);

        let expected_tiatd = vec![
            ("/test/tiatd", vec![6, 7, 8, 9, 10]),
            ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
            ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
            ("thisisatestdir", vec![0, 2, 6, 7, 11]),
        ];
        assert_eq!(match_query("tiatd", false, &paths), expected_tiatd);
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        // Positions are byte offsets, so multi-byte characters shift them.
        assert_eq!("1️⃣".len(), 7);

        let expected_bcd = vec![
            ("αβγδ/bcde", vec![9, 10, 11]),
            ("aαbβ/cγdδ", vec![3, 7, 10]),
        ];
        assert_eq!(match_query("bcd", false, &paths), expected_bcd);

        let expected_cde = vec![
            ("αβγδ/bcde", vec![10, 11, 12]),
            ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
        ];
        assert_eq!(match_query("cde", false, &paths), expected_cde);
    }

    /// Runs `query` against `paths` with an empty prefix and returns each
    /// matched path (in result order) paired with its match positions.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(lowercase_query.as_slice());

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries: Vec<PathMatchCandidate> = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, arc)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: arc,
                }
            })
            .collect();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        let mut matched = Vec::with_capacity(results.len());
        for result in results {
            let original = paths
                .iter()
                .copied()
                .find(|p| result.path.as_ref() == Path::new(p))
                .unwrap();
            matched.push((original, result.positions));
        }
        matched
    }
}