1mod char_bag;
2
3use gpui::executor;
4use std::{
5 borrow::Cow,
6 cmp::{self, Ordering},
7 path::Path,
8 sync::atomic::{self, AtomicBool},
9 sync::Arc,
10};
11
12pub use char_bag::CharBag;
13
/// Per-character score for a matched character that earns no word-boundary
/// bonus; also the value the gap penalty is subtracted from.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every additional
/// candidate character skipped before the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the gap-penalized per-character score.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
17
/// Scores candidates against a single query and keeps the best results.
///
/// The scratch buffers are reused across candidates so that matching a batch
/// does not allocate per candidate.
pub struct Matcher<'a> {
    /// Query characters in their original case; compared against the
    /// candidate's original characters to apply the case-mismatch penalty.
    query: &'a [char],
    /// Lowercased query characters used for the actual character matching.
    lowercase_query: &'a [char],
    /// Bag of query characters used to cheaply reject candidates.
    query_char_bag: CharBag,
    /// When set, a case mismatch between query and candidate is penalized.
    smart_case: bool,
    /// Maximum number of entries kept in a results vector.
    max_results: usize,
    /// Score of the last kept result once the results vector is full; used to
    /// prune branches that cannot beat it. Zero until then.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored
    /// candidate, one entry per query character.
    match_positions: Vec<usize>,
    /// For each query character, the last index (in prefix + candidate
    /// characters) at which it may still be matched.
    last_positions: Vec<usize>,
    /// Memo table of scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position for each memo table cell, using the same indexing
    /// as `score_matrix`.
    best_position_matrix: Vec<usize>,
}
30
/// A scored match that can be ordered and inserted into a results vector.
trait Match: Ord {
    /// Returns the score of this match.
    fn score(&self) -> f64;
    /// Stores the positions of the matched characters on this match.
    fn set_positions(&mut self, positions: Vec<usize>);
}
35
/// Something a query can be matched against.
trait MatchCandidate {
    /// Returns whether this candidate passes the character-bag pre-filter for
    /// `bag`; candidates that fail are skipped without being scored.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text of this candidate that the query is matched against.
    fn to_string<'a>(&'a self) -> Cow<'a, str>;
}
40
/// A path that can be matched against a query.
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
    /// The path whose lossy string form is matched.
    pub path: &'a Arc<Path>,
    /// Bag of the path's characters, used to pre-filter candidates.
    /// NOTE(review): the tests build this from the lowercased path — confirm
    /// that all callers do the same.
    pub char_bag: CharBag,
}
46
/// A scored match of a query against a path.
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters, measured over the path prefix
    /// followed by the path.
    pub positions: Vec<usize>,
    /// Identifier of the candidate set the path came from.
    pub worktree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// Prefix of the candidate set, which takes part in matching.
    pub path_prefix: Arc<str>,
}
55
/// A string that can be matched against a query.
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// The text the query is matched against.
    pub string: String,
    /// Bag of the string's characters, used to pre-filter candidates.
    /// NOTE(review): presumably built from the lowercased string — confirm
    /// against callers.
    pub char_bag: CharBag,
}
61
/// A collection of path candidates that `match_paths` can search.
pub trait PathMatchCandidateSet<'a>: Send + Sync {
    /// Iterator type returned by [`candidates`](Self::candidates).
    type Candidates: Iterator<Item = PathMatchCandidate<'a>>;
    /// Identifier reported as `worktree_id` on matches from this set.
    fn id(&self) -> usize;
    /// Number of candidates in this set; used to split work into segments.
    fn len(&self) -> usize;
    /// Prefix reported as `path_prefix` on matches and matched together with
    /// each path.
    fn prefix(&self) -> Arc<str>;
    /// Returns the candidates of this set.
    /// NOTE(review): callers pass an offset into the set and then `take` a
    /// count, so this presumably starts iteration at index `start` — confirm
    /// against implementors.
    fn candidates(&'a self, start: usize) -> Self::Candidates;
}
69
70impl Match for PathMatch {
71 fn score(&self) -> f64 {
72 self.score
73 }
74
75 fn set_positions(&mut self, positions: Vec<usize>) {
76 self.positions = positions;
77 }
78}
79
80impl Match for StringMatch {
81 fn score(&self) -> f64 {
82 self.score
83 }
84
85 fn set_positions(&mut self, positions: Vec<usize>) {
86 self.positions = positions;
87 }
88}
89
90impl<'a> MatchCandidate for PathMatchCandidate<'a> {
91 fn has_chars(&self, bag: CharBag) -> bool {
92 self.char_bag.is_superset(bag)
93 }
94
95 fn to_string(&self) -> Cow<'a, str> {
96 self.path.to_string_lossy()
97 }
98}
99
100impl<'a> MatchCandidate for &'a StringMatchCandidate {
101 fn has_chars(&self, bag: CharBag) -> bool {
102 self.char_bag.is_superset(bag)
103 }
104
105 fn to_string(&self) -> Cow<'a, str> {
106 self.string.as_str().into()
107 }
108}
109
/// A scored match of a query against a string candidate.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Match score; higher is better.
    pub score: f64,
    /// Byte offsets of the matched characters within `string`.
    pub positions: Vec<usize>,
    /// The text of the matched candidate.
    pub string: String,
}

impl PartialEq for StringMatch {
    /// Two matches are equal exactly when `cmp` orders them as equal.
    ///
    /// Delegating to `cmp` keeps `==` consistent with `Ord` (which also looks
    /// at `string`) and keeps equality reflexive for NaN scores, as `Eq`
    /// requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StringMatch {
    /// Orders by score, then by string. Scores that cannot be compared (NaN)
    /// are treated as equal so that the ordering stays total.
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
139
140impl PartialEq for PathMatch {
141 fn eq(&self, other: &Self) -> bool {
142 self.score.eq(&other.score)
143 }
144}
145
146impl Eq for PathMatch {}
147
148impl PartialOrd for PathMatch {
149 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
150 Some(self.cmp(other))
151 }
152}
153
154impl Ord for PathMatch {
155 fn cmp(&self, other: &Self) -> Ordering {
156 self.score
157 .partial_cmp(&other.score)
158 .unwrap_or(Ordering::Equal)
159 .then_with(|| self.worktree_id.cmp(&other.worktree_id))
160 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
161 }
162}
163
164pub async fn match_strings(
165 candidates: &[StringMatchCandidate],
166 query: &str,
167 smart_case: bool,
168 max_results: usize,
169 cancel_flag: &AtomicBool,
170 background: Arc<executor::Background>,
171) -> Vec<StringMatch> {
172 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
173 let query = query.chars().collect::<Vec<_>>();
174
175 let lowercase_query = &lowercase_query;
176 let query = &query;
177 let query_char_bag = CharBag::from(&lowercase_query[..]);
178
179 let num_cpus = background.num_cpus().min(candidates.len());
180 let segment_size = (candidates.len() + num_cpus - 1) / num_cpus;
181 let mut segment_results = (0..num_cpus)
182 .map(|_| Vec::with_capacity(max_results))
183 .collect::<Vec<_>>();
184
185 background
186 .scoped(|scope| {
187 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
188 let cancel_flag = &cancel_flag;
189 scope.spawn(async move {
190 let segment_start = segment_idx * segment_size;
191 let segment_end = segment_start + segment_size;
192 let mut matcher = Matcher::new(
193 query,
194 lowercase_query,
195 query_char_bag,
196 smart_case,
197 max_results,
198 );
199 matcher.match_strings(
200 &candidates[segment_start..segment_end],
201 results,
202 cancel_flag,
203 );
204 });
205 }
206 })
207 .await;
208
209 let mut results = Vec::new();
210 for segment_result in segment_results {
211 if results.is_empty() {
212 results = segment_result;
213 } else {
214 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
215 }
216 }
217 results
218}
219
220pub async fn match_paths<'a, Set: PathMatchCandidateSet<'a>>(
221 candidate_sets: &'a [Set],
222 query: &str,
223 smart_case: bool,
224 max_results: usize,
225 cancel_flag: &AtomicBool,
226 background: Arc<executor::Background>,
227) -> Vec<PathMatch> {
228 let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum();
229 if path_count == 0 {
230 return Vec::new();
231 }
232
233 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
234 let query = query.chars().collect::<Vec<_>>();
235
236 let lowercase_query = &lowercase_query;
237 let query = &query;
238 let query_char_bag = CharBag::from(&lowercase_query[..]);
239
240 let num_cpus = background.num_cpus().min(path_count);
241 let segment_size = (path_count + num_cpus - 1) / num_cpus;
242 let mut segment_results = (0..num_cpus)
243 .map(|_| Vec::with_capacity(max_results))
244 .collect::<Vec<_>>();
245
246 background
247 .scoped(|scope| {
248 for (segment_idx, results) in segment_results.iter_mut().enumerate() {
249 scope.spawn(async move {
250 let segment_start = segment_idx * segment_size;
251 let segment_end = segment_start + segment_size;
252 let mut matcher = Matcher::new(
253 query,
254 lowercase_query,
255 query_char_bag,
256 smart_case,
257 max_results,
258 );
259
260 let mut tree_start = 0;
261 for candidate_set in candidate_sets {
262 let tree_end = tree_start + candidate_set.len();
263
264 if tree_start < segment_end && segment_start < tree_end {
265 let start = cmp::max(tree_start, segment_start) - tree_start;
266 let end = cmp::min(tree_end, segment_end) - tree_start;
267 let candidates = candidate_set.candidates(start).take(end - start);
268
269 matcher.match_paths(
270 candidate_set.id(),
271 candidate_set.prefix(),
272 candidates,
273 results,
274 &cancel_flag,
275 );
276 }
277 if tree_end >= segment_end {
278 break;
279 }
280 tree_start = tree_end;
281 }
282 })
283 }
284 })
285 .await;
286
287 let mut results = Vec::new();
288 for segment_result in segment_results {
289 if results.is_empty() {
290 results = segment_result;
291 } else {
292 util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(&a));
293 }
294 }
295 results
296}
297
298impl<'a> Matcher<'a> {
299 pub fn new(
300 query: &'a [char],
301 lowercase_query: &'a [char],
302 query_char_bag: CharBag,
303 smart_case: bool,
304 max_results: usize,
305 ) -> Self {
306 Self {
307 query,
308 lowercase_query,
309 query_char_bag,
310 min_score: 0.0,
311 last_positions: vec![0; query.len()],
312 match_positions: vec![0; query.len()],
313 score_matrix: Vec::new(),
314 best_position_matrix: Vec::new(),
315 smart_case,
316 max_results,
317 }
318 }
319
320 pub fn match_strings(
321 &mut self,
322 candidates: &[StringMatchCandidate],
323 results: &mut Vec<StringMatch>,
324 cancel_flag: &AtomicBool,
325 ) {
326 self.match_internal(
327 &[],
328 &[],
329 candidates.iter(),
330 results,
331 cancel_flag,
332 |candidate, score| StringMatch {
333 score,
334 positions: Vec::new(),
335 string: candidate.string.to_string(),
336 },
337 )
338 }
339
340 pub fn match_paths<'c: 'a>(
341 &mut self,
342 tree_id: usize,
343 path_prefix: Arc<str>,
344 path_entries: impl Iterator<Item = PathMatchCandidate<'c>>,
345 results: &mut Vec<PathMatch>,
346 cancel_flag: &AtomicBool,
347 ) {
348 let prefix = path_prefix.chars().collect::<Vec<_>>();
349 let lowercase_prefix = prefix
350 .iter()
351 .map(|c| c.to_ascii_lowercase())
352 .collect::<Vec<_>>();
353 self.match_internal(
354 &prefix,
355 &lowercase_prefix,
356 path_entries,
357 results,
358 cancel_flag,
359 |candidate, score| PathMatch {
360 score,
361 worktree_id: tree_id,
362 positions: Vec::new(),
363 path: candidate.path.clone(),
364 path_prefix: path_prefix.clone(),
365 },
366 )
367 }
368
369 fn match_internal<C: MatchCandidate, R, F>(
370 &mut self,
371 prefix: &[char],
372 lowercase_prefix: &[char],
373 candidates: impl Iterator<Item = C>,
374 results: &mut Vec<R>,
375 cancel_flag: &AtomicBool,
376 build_match: F,
377 ) where
378 R: Match,
379 F: Fn(&C, f64) -> R,
380 {
381 let mut candidate_chars = Vec::new();
382 let mut lowercase_candidate_chars = Vec::new();
383
384 for candidate in candidates {
385 if !candidate.has_chars(self.query_char_bag) {
386 continue;
387 }
388
389 if cancel_flag.load(atomic::Ordering::Relaxed) {
390 break;
391 }
392
393 candidate_chars.clear();
394 lowercase_candidate_chars.clear();
395 for c in candidate.to_string().chars() {
396 candidate_chars.push(c);
397 lowercase_candidate_chars.push(c.to_ascii_lowercase());
398 }
399
400 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
401 continue;
402 }
403
404 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
405 self.score_matrix.clear();
406 self.score_matrix.resize(matrix_len, None);
407 self.best_position_matrix.clear();
408 self.best_position_matrix.resize(matrix_len, 0);
409
410 let score = self.score_match(
411 &candidate_chars,
412 &lowercase_candidate_chars,
413 &prefix,
414 &lowercase_prefix,
415 );
416
417 if score > 0.0 {
418 let mut mat = build_match(&candidate, score);
419 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
420 if results.len() < self.max_results {
421 mat.set_positions(self.match_positions.clone());
422 results.insert(i, mat);
423 } else if i < results.len() {
424 results.pop();
425 mat.set_positions(self.match_positions.clone());
426 results.insert(i, mat);
427 }
428 if results.len() == self.max_results {
429 self.min_score = results.last().unwrap().score();
430 }
431 }
432 }
433 }
434 }
435
436 fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
437 let mut path = path.iter();
438 let mut prefix_iter = prefix.iter();
439 for (i, char) in self.query.iter().enumerate().rev() {
440 if let Some(j) = path.rposition(|c| c == char) {
441 self.last_positions[i] = j + prefix.len();
442 } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
443 self.last_positions[i] = j;
444 } else {
445 return false;
446 }
447 }
448 true
449 }
450
451 fn score_match(
452 &mut self,
453 path: &[char],
454 path_cased: &[char],
455 prefix: &[char],
456 lowercase_prefix: &[char],
457 ) -> f64 {
458 let score = self.recursive_score_match(
459 path,
460 path_cased,
461 prefix,
462 lowercase_prefix,
463 0,
464 0,
465 self.query.len() as f64,
466 ) * self.query.len() as f64;
467
468 if score <= 0.0 {
469 return 0.0;
470 }
471
472 let path_len = prefix.len() + path.len();
473 let mut cur_start = 0;
474 let mut byte_ix = 0;
475 let mut char_ix = 0;
476 for i in 0..self.query.len() {
477 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
478 while char_ix < match_char_ix {
479 let ch = prefix
480 .get(char_ix)
481 .or_else(|| path.get(char_ix - prefix.len()))
482 .unwrap();
483 byte_ix += ch.len_utf8();
484 char_ix += 1;
485 }
486 cur_start = match_char_ix + 1;
487 self.match_positions[i] = byte_ix;
488 }
489
490 score
491 }
492
493 fn recursive_score_match(
494 &mut self,
495 path: &[char],
496 path_cased: &[char],
497 prefix: &[char],
498 lowercase_prefix: &[char],
499 query_idx: usize,
500 path_idx: usize,
501 cur_score: f64,
502 ) -> f64 {
503 if query_idx == self.query.len() {
504 return 1.0;
505 }
506
507 let path_len = prefix.len() + path.len();
508
509 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
510 return memoized;
511 }
512
513 let mut score = 0.0;
514 let mut best_position = 0;
515
516 let query_char = self.lowercase_query[query_idx];
517 let limit = self.last_positions[query_idx];
518
519 let mut last_slash = 0;
520 for j in path_idx..=limit {
521 let path_char = if j < prefix.len() {
522 lowercase_prefix[j]
523 } else {
524 path_cased[j - prefix.len()]
525 };
526 let is_path_sep = path_char == '/' || path_char == '\\';
527
528 if query_idx == 0 && is_path_sep {
529 last_slash = j;
530 }
531
532 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
533 let curr = if j < prefix.len() {
534 prefix[j]
535 } else {
536 path[j - prefix.len()]
537 };
538
539 let mut char_score = 1.0;
540 if j > path_idx {
541 let last = if j - 1 < prefix.len() {
542 prefix[j - 1]
543 } else {
544 path[j - 1 - prefix.len()]
545 };
546
547 if last == '/' {
548 char_score = 0.9;
549 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
550 char_score = 0.8;
551 } else if last.is_lowercase() && curr.is_uppercase() {
552 char_score = 0.8;
553 } else if last == '.' {
554 char_score = 0.7;
555 } else if query_idx == 0 {
556 char_score = BASE_DISTANCE_PENALTY;
557 } else {
558 char_score = MIN_DISTANCE_PENALTY.max(
559 BASE_DISTANCE_PENALTY
560 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
561 );
562 }
563 }
564
565 // Apply a severe penalty if the case doesn't match.
566 // This will make the exact matches have higher score than the case-insensitive and the
567 // path insensitive matches.
568 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
569 char_score *= 0.001;
570 }
571
572 let mut multiplier = char_score;
573
574 // Scale the score based on how deep within the path we found the match.
575 if query_idx == 0 {
576 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
577 }
578
579 let mut next_score = 1.0;
580 if self.min_score > 0.0 {
581 next_score = cur_score * multiplier;
582 // Scores only decrease. If we can't pass the previous best, bail
583 if next_score < self.min_score {
584 // Ensure that score is non-zero so we use it in the memo table.
585 if score == 0.0 {
586 score = 1e-18;
587 }
588 continue;
589 }
590 }
591
592 let new_score = self.recursive_score_match(
593 path,
594 path_cased,
595 prefix,
596 lowercase_prefix,
597 query_idx + 1,
598 j + 1,
599 next_score,
600 ) * multiplier;
601
602 if new_score > score {
603 score = new_score;
604 best_position = j;
605 // Optimization: can't score better than 1.
606 if new_score == 1.0 {
607 break;
608 }
609 }
610 }
611 }
612
613 if best_position != 0 {
614 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
615 }
616
617 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
618 score
619 }
620}
621
622#[cfg(test)]
623mod tests {
624 use super::*;
625 use std::path::PathBuf;
626
627 #[test]
628 fn test_get_last_positions() {
629 let mut query: &[char] = &['d', 'c'];
630 let mut matcher = Matcher::new(query, query, query.into(), false, 10);
631 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
632 assert_eq!(result, false);
633
634 query = &['c', 'd'];
635 let mut matcher = Matcher::new(query, query, query.into(), false, 10);
636 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
637 assert_eq!(result, true);
638 assert_eq!(matcher.last_positions, vec![2, 4]);
639
640 query = &['z', '/', 'z', 'f'];
641 let mut matcher = Matcher::new(query, query, query.into(), false, 10);
642 let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
643 assert_eq!(result, true);
644 assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
645 }
646
647 #[test]
648 fn test_match_path_entries() {
649 let paths = vec![
650 "",
651 "a",
652 "ab",
653 "abC",
654 "abcd",
655 "alphabravocharlie",
656 "AlphaBravoCharlie",
657 "thisisatestdir",
658 "/////ThisIsATestDir",
659 "/this/is/a/test/dir",
660 "/test/tiatd",
661 ];
662
663 assert_eq!(
664 match_query("abc", false, &paths),
665 vec![
666 ("abC", vec![0, 1, 2]),
667 ("abcd", vec![0, 1, 2]),
668 ("AlphaBravoCharlie", vec![0, 5, 10]),
669 ("alphabravocharlie", vec![4, 5, 10]),
670 ]
671 );
672 assert_eq!(
673 match_query("t/i/a/t/d", false, &paths),
674 vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
675 );
676
677 assert_eq!(
678 match_query("tiatd", false, &paths),
679 vec![
680 ("/test/tiatd", vec![6, 7, 8, 9, 10]),
681 ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
682 ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
683 ("thisisatestdir", vec![0, 2, 6, 7, 11]),
684 ]
685 );
686 }
687
    /// Checks that matched positions are reported as byte offsets, not
    /// character indices, when paths contain multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        // Each keycap emoji is 7 bytes long, which the offsets below rely on.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }
707
708 fn match_query<'a>(
709 query: &str,
710 smart_case: bool,
711 paths: &Vec<&'a str>,
712 ) -> Vec<(&'a str, Vec<usize>)> {
713 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
714 let query = query.chars().collect::<Vec<_>>();
715 let query_chars = CharBag::from(&lowercase_query[..]);
716
717 let path_arcs = paths
718 .iter()
719 .map(|path| Arc::from(PathBuf::from(path)))
720 .collect::<Vec<_>>();
721 let mut path_entries = Vec::new();
722 for (i, path) in paths.iter().enumerate() {
723 let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
724 let char_bag = CharBag::from(lowercase_path.as_slice());
725 path_entries.push(PathMatchCandidate {
726 char_bag,
727 path: path_arcs.get(i).unwrap(),
728 });
729 }
730
731 let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);
732
733 let cancel_flag = AtomicBool::new(false);
734 let mut results = Vec::new();
735 matcher.match_paths(
736 0,
737 "".into(),
738 path_entries.into_iter(),
739 &mut results,
740 &cancel_flag,
741 );
742
743 results
744 .into_iter()
745 .map(|result| {
746 (
747 paths
748 .iter()
749 .copied()
750 .find(|p| result.path.as_ref() == Path::new(p))
751 .unwrap(),
752 result.positions,
753 )
754 })
755 .collect()
756 }
757}