1use std::{
2 borrow::{Borrow, Cow},
3 collections::BTreeMap,
4 sync::atomic::{self, AtomicBool},
5};
6
7use crate::CharBag;
8
/// Score factor applied to a matched character that neither sits at the start of the scanned
/// range nor follows a recognized boundary character. It is also the value from which the
/// per-gap penalty below is subtracted.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every additional character skipped
/// between the start of the scanned range and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the distance-based score factor: however many characters are skipped, the factor
/// never drops below this value.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
12
// TODO:
// Use `Path` instead of `&str` for paths.
/// Fuzzy-matching state for one query, reused across all candidates passed to
/// `match_candidates`.
///
/// The scratch buffers (`match_positions`, `last_positions`, `score_matrix`,
/// `best_position_matrix`) are overwritten for every candidate that gets scored.
pub struct Matcher<'a> {
    /// Query characters in their original case. Compared against the candidate's original
    /// characters to decide whether the case-mismatch penalty applies.
    query: &'a [char],
    /// Lowercased query characters; these are the characters actually searched for.
    lowercase_query: &'a [char],
    /// Character bag of the query, handed to `MatchCandidate::has_chars` as a pre-filter.
    query_char_bag: CharBag,
    /// When `true`, a query character that differs from the candidate's original character
    /// (for example only in case) is scored with a severe penalty.
    smart_case: bool,
    /// When `true`, the score of the first query character is divided by the distance from the
    /// last path separator seen so far to the end of the searched text.
    penalize_length: bool,
    /// Pruning threshold used by the recursive scorer; pruning is active only when this is
    /// greater than zero. `new` initializes it to `0.0`.
    min_score: f64,
    /// One UTF-8 byte offset per query character, filled in by `score_match` for the most
    /// recently scored candidate.
    match_positions: Vec<usize>,
    /// For each lowercase query character, the last offset at which it may be matched, as
    /// computed by `find_last_positions`.
    last_positions: Vec<usize>,
    /// Memo table of scores, indexed by `query_idx * path_len + path_idx`; `None` means the
    /// cell has not been computed yet.
    score_matrix: Vec<Option<f64>>,
    /// Laid out like `score_matrix`; stores the best match position found for each cell.
    best_position_matrix: Vec<usize>,
}
27
/// Something that can be fuzzy-matched by `Matcher::match_candidates`.
pub trait MatchCandidate {
    /// Pre-filter hook: `match_candidates` skips the candidate without scoring it when this
    /// returns `false` for the query's character bag.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// Returns the text the query is matched against.
    fn to_string(&self) -> Cow<'_, str>;
}
32
33impl<'a> Matcher<'a> {
34 pub fn new(
35 query: &'a [char],
36 lowercase_query: &'a [char],
37 query_char_bag: CharBag,
38 smart_case: bool,
39 penalize_length: bool,
40 ) -> Self {
41 Self {
42 query,
43 lowercase_query,
44 query_char_bag,
45 min_score: 0.0,
46 last_positions: vec![0; lowercase_query.len()],
47 match_positions: vec![0; query.len()],
48 score_matrix: Vec::new(),
49 best_position_matrix: Vec::new(),
50 smart_case,
51 penalize_length,
52 }
53 }
54
55 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
56 /// the input candidates.
57 pub(crate) fn match_candidates<C, R, F, T>(
58 &mut self,
59 prefix: &[char],
60 lowercase_prefix: &[char],
61 candidates: impl Iterator<Item = T>,
62 results: &mut Vec<R>,
63 cancel_flag: &AtomicBool,
64 build_match: F,
65 ) where
66 C: MatchCandidate,
67 T: Borrow<C>,
68 F: Fn(&C, f64, &Vec<usize>) -> R,
69 {
70 let mut candidate_chars = Vec::new();
71 let mut lowercase_candidate_chars = Vec::new();
72 let mut extra_lowercase_chars = BTreeMap::new();
73
74 for candidate in candidates {
75 if !candidate.borrow().has_chars(self.query_char_bag) {
76 continue;
77 }
78
79 if cancel_flag.load(atomic::Ordering::Relaxed) {
80 break;
81 }
82
83 candidate_chars.clear();
84 lowercase_candidate_chars.clear();
85 extra_lowercase_chars.clear();
86 for (i, c) in candidate.borrow().to_string().chars().enumerate() {
87 candidate_chars.push(c);
88 let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
89 if char_lowercased.len() > 1 {
90 extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
91 }
92 lowercase_candidate_chars.append(&mut char_lowercased);
93 }
94
95 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
96 continue;
97 }
98
99 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
100 self.score_matrix.clear();
101 self.score_matrix.resize(matrix_len, None);
102 self.best_position_matrix.clear();
103 self.best_position_matrix.resize(matrix_len, 0);
104
105 let score = self.score_match(
106 &candidate_chars,
107 &lowercase_candidate_chars,
108 prefix,
109 lowercase_prefix,
110 &extra_lowercase_chars,
111 );
112
113 if score > 0.0 {
114 results.push(build_match(
115 candidate.borrow(),
116 score,
117 &self.match_positions,
118 ));
119 }
120 }
121 }
122
123 fn find_last_positions(
124 &mut self,
125 lowercase_prefix: &[char],
126 lowercase_candidate: &[char],
127 ) -> bool {
128 let mut lowercase_prefix = lowercase_prefix.iter();
129 let mut lowercase_candidate = lowercase_candidate.iter();
130 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
131 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
132 self.last_positions[i] = j + lowercase_prefix.len();
133 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
134 self.last_positions[i] = j;
135 } else {
136 return false;
137 }
138 }
139 true
140 }
141
142 fn score_match(
143 &mut self,
144 path: &[char],
145 path_lowercased: &[char],
146 prefix: &[char],
147 lowercase_prefix: &[char],
148 extra_lowercase_chars: &BTreeMap<usize, usize>,
149 ) -> f64 {
150 let score = self.recursive_score_match(
151 path,
152 path_lowercased,
153 prefix,
154 lowercase_prefix,
155 0,
156 0,
157 self.query.len() as f64,
158 extra_lowercase_chars,
159 ) * self.query.len() as f64;
160
161 if score <= 0.0 {
162 return 0.0;
163 }
164 let path_len = prefix.len() + path.len();
165 let mut cur_start = 0;
166 let mut byte_ix = 0;
167 let mut char_ix = 0;
168 for i in 0..self.query.len() {
169 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170 while char_ix < match_char_ix {
171 let ch = prefix
172 .get(char_ix)
173 .or_else(|| path.get(char_ix - prefix.len()))
174 .unwrap();
175 byte_ix += ch.len_utf8();
176 char_ix += 1;
177 }
178
179 self.match_positions[i] = byte_ix;
180
181 let matched_ch = prefix
182 .get(match_char_ix)
183 .or_else(|| path.get(match_char_ix - prefix.len()))
184 .unwrap();
185 byte_ix += matched_ch.len_utf8();
186
187 cur_start = match_char_ix + 1;
188 char_ix = match_char_ix + 1;
189 }
190
191 score
192 }
193
194 fn recursive_score_match(
195 &mut self,
196 path: &[char],
197 path_lowercased: &[char],
198 prefix: &[char],
199 lowercase_prefix: &[char],
200 query_idx: usize,
201 path_idx: usize,
202 cur_score: f64,
203 extra_lowercase_chars: &BTreeMap<usize, usize>,
204 ) -> f64 {
205 use std::path::MAIN_SEPARATOR;
206
207 if query_idx == self.query.len() {
208 return 1.0;
209 }
210
211 let path_len = prefix.len() + path.len();
212
213 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
214 return memoized;
215 }
216
217 let mut score = 0.0;
218 let mut best_position = 0;
219
220 let query_char = self.lowercase_query[query_idx];
221 let limit = self.last_positions[query_idx];
222
223 let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
224 let safe_limit = limit.min(max_valid_index);
225
226 let mut last_slash = 0;
227 for j in path_idx..=safe_limit {
228 let extra_lowercase_chars_count = extra_lowercase_chars
229 .iter()
230 .take_while(|(i, _)| i < &&j)
231 .map(|(_, increment)| increment)
232 .sum::<usize>();
233 let j_regular = j - extra_lowercase_chars_count;
234
235 let path_char = if j < prefix.len() {
236 lowercase_prefix[j]
237 } else {
238 let path_index = j - prefix.len();
239 if path_index < path_lowercased.len() {
240 path_lowercased[path_index]
241 } else {
242 continue;
243 }
244 };
245 let is_path_sep = path_char == MAIN_SEPARATOR;
246
247 if query_idx == 0 && is_path_sep {
248 last_slash = j_regular;
249 }
250
251 #[cfg(not(target_os = "windows"))]
252 let need_to_score =
253 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
254 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
255 #[cfg(target_os = "windows")]
256 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
257 if need_to_score {
258 let curr = if j_regular < prefix.len() {
259 prefix[j_regular]
260 } else {
261 path[j_regular - prefix.len()]
262 };
263
264 let mut char_score = 1.0;
265 if j > path_idx {
266 let last = if j_regular - 1 < prefix.len() {
267 prefix[j_regular - 1]
268 } else {
269 path[j_regular - 1 - prefix.len()]
270 };
271
272 if last == MAIN_SEPARATOR {
273 char_score = 0.9;
274 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
275 || (last.is_lowercase() && curr.is_uppercase())
276 {
277 char_score = 0.8;
278 } else if last == '.' {
279 char_score = 0.7;
280 } else if query_idx == 0 {
281 char_score = BASE_DISTANCE_PENALTY;
282 } else {
283 char_score = MIN_DISTANCE_PENALTY.max(
284 BASE_DISTANCE_PENALTY
285 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
286 );
287 }
288 }
289
290 // Apply a severe penalty if the case doesn't match.
291 // This will make the exact matches have higher score than the case-insensitive and the
292 // path insensitive matches.
293 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
294 char_score *= 0.001;
295 }
296
297 let mut multiplier = char_score;
298
299 // Scale the score based on how deep within the path we found the match.
300 if self.penalize_length && query_idx == 0 {
301 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
302 }
303
304 let mut next_score = 1.0;
305 if self.min_score > 0.0 {
306 next_score = cur_score * multiplier;
307 // Scores only decrease. If we can't pass the previous best, bail
308 if next_score < self.min_score {
309 // Ensure that score is non-zero so we use it in the memo table.
310 if score == 0.0 {
311 score = 1e-18;
312 }
313 continue;
314 }
315 }
316
317 let new_score = self.recursive_score_match(
318 path,
319 path_lowercased,
320 prefix,
321 lowercase_prefix,
322 query_idx + 1,
323 j + 1,
324 next_score,
325 extra_lowercase_chars,
326 ) * multiplier;
327
328 if new_score > score {
329 score = new_score;
330 best_position = j_regular;
331 // Optimization: can't score better than 1.
332 if new_score == 1.0 {
333 break;
334 }
335 }
336 }
337 }
338
339 if best_position != 0 {
340 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
341 }
342
343 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
344 score
345 }
346}
347
348#[cfg(test)]
349mod tests {
350 use crate::{PathMatch, PathMatchCandidate};
351
352 use super::*;
353 use std::{
354 path::{Path, PathBuf},
355 sync::Arc,
356 };
357
    /// Checks `find_last_positions` on prefix/candidate pairs: it must reject queries whose
    /// characters cannot be found in order, and otherwise report the last usable offset of each
    /// query character in the concatenation of prefix and candidate.
    #[test]
    fn test_get_last_positions() {
        // `c` only occurs in the prefix and `d` only in the candidate, so `d` can never come
        // before `c`.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // `c` is at prefix offset 2; `d` is at candidate offset 1, shifted by the 3 prefix chars.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // The last two query chars are located in the candidate, the first two in the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }
377
    /// Matches ASCII queries against `/`-separated paths and checks both the ranking (best match
    /// first, as sorted by `match_single_path_query`) and the reported byte positions.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        // Contiguous and word-start matches are listed ahead of scattered ones.
        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // A query containing separators only matches the path with the same component layout.
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }
419
    /// todo(windows)
    /// Now, on Windows, users can only use the backslash as a path separator.
    /// I do want to support both the backslash and the forward slash as path separators on Windows.
    ///
    /// Windows counterpart of the `/`-separated test: the same queries and expectations, with
    /// `\` as the separator in both the paths and the query.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        // Contiguous and word-start matches are listed ahead of scattered ones.
        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // A query containing separators only matches the path with the same component layout.
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }
467
    /// Regression test for characters whose lowercase form consists of more chars than the
    /// original: matching must not panic and must report the match at byte offset 0.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }
486
    /// Checks that reported positions are UTF-8 byte offsets rather than char indices when the
    /// paths contain multi-byte characters.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Each keycap emoji occupies 7 bytes, which the offsets below depend on.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }
511
    /// Exercises matching against paths that mix scripts, accented letters and emoji, pinning
    /// which path matches each query and the positions currently reported for it.
    #[test]
    fn match_unicode_path_entries() {
        let mixed_unicode_paths = vec![
            "İolu/oluş",
            "İstanbul/code",
            "Athens/Şanlıurfa",
            "Çanakkale/scripts",
            "paris/Düzce_İl",
            "Berlin_Önemli_Ğündem",
            "KİTAPLIK/london/dosya",
            "tokyo/kyoto/fuji",
            "new_york/san_francisco",
        ];

        // NOTE(review): for the two queries starting with `İ` the expected positions advance in
        // steps of 2 and run past the end of the matched string, so they look like a snapshot of
        // the matcher's current output rather than real highlight offsets — confirm.
        assert_eq!(
            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
        );

        assert_eq!(
            match_single_path_query("İst/code", false, &mixed_unicode_paths),
            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
        );

        assert_eq!(
            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
        );

        assert_eq!(
            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
        );

        assert_eq!(
            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
        );

        let mixed_script_paths = vec![
            "résumé_Москва",
            "naïve_київ_implementation",
            "café_北京_app",
            "東京_über_driver",
            "déjà_vu_cairo",
            "seoul_piñata_game",
            "voilà_istanbul_result",
        ];

        assert_eq!(
            match_single_path_query("résmé", false, &mixed_script_paths),
            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
        );

        assert_eq!(
            match_single_path_query("café北京", false, &mixed_script_paths),
            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
        );

        assert_eq!(
            match_single_path_query("ista", false, &mixed_script_paths),
            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
        );

        let complex_paths = vec![
            "document_📚_library",
            "project_👨👩👧👦_family",
            "flags_🇯🇵🇺🇸🇪🇺_world",
            "code_😀😃😄😁_happy",
            "photo_👩👩👧👦_album",
        ];

        // Emoji occupy 4 bytes each, which shifts the offsets of the characters after them.
        assert_eq!(
            match_single_path_query("doc📚lib", false, &complex_paths),
            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
        );

        assert_eq!(
            match_single_path_query("codehappy", false, &complex_paths),
            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
        );
    }
594
595 fn match_single_path_query<'a>(
596 query: &str,
597 smart_case: bool,
598 paths: &[&'a str],
599 ) -> Vec<(&'a str, Vec<usize>)> {
600 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
601 let query = query.chars().collect::<Vec<_>>();
602 let query_chars = CharBag::from(&lowercase_query[..]);
603
604 let path_arcs: Vec<Arc<Path>> = paths
605 .iter()
606 .map(|path| Arc::from(PathBuf::from(path)))
607 .collect::<Vec<_>>();
608 let mut path_entries = Vec::new();
609 for (i, path) in paths.iter().enumerate() {
610 let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
611 let char_bag = CharBag::from(lowercase_path.as_slice());
612 path_entries.push(PathMatchCandidate {
613 is_dir: false,
614 char_bag,
615 path: &path_arcs[i],
616 });
617 }
618
619 let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);
620
621 let cancel_flag = AtomicBool::new(false);
622 let mut results = Vec::new();
623
624 matcher.match_candidates(
625 &[],
626 &[],
627 path_entries.into_iter(),
628 &mut results,
629 &cancel_flag,
630 |candidate, score, positions| PathMatch {
631 score,
632 worktree_id: 0,
633 positions: positions.clone(),
634 path: Arc::from(candidate.path),
635 path_prefix: "".into(),
636 distance_to_relative_ancestor: usize::MAX,
637 is_dir: false,
638 },
639 );
640 results.sort_by(|a, b| b.cmp(a));
641
642 results
643 .into_iter()
644 .map(|result| {
645 (
646 paths
647 .iter()
648 .copied()
649 .find(|p| result.path.as_ref() == Path::new(p))
650 .unwrap(),
651 result.positions,
652 )
653 })
654 .collect()
655 }
656}