1use std::{
2 borrow::{Borrow, Cow},
3 collections::BTreeMap,
4 sync::atomic::{self, AtomicBool},
5};
6
7use crate::CharBag;
8
/// Score multiplier for a matched character that does not directly follow a
/// separator, word boundary or `.`. Used as-is for the first query character
/// and as the starting value of the gap penalty for later ones.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Lower bound of the gap penalty: however many characters are skipped between
/// two consecutive matches, the multiplier never drops below this value.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every additional
/// character skipped between two consecutive matched query characters.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
12
13// TODO:
14// Use `Path` instead of `&str` for paths.
15pub struct Matcher<'a> {
16 query: &'a [char],
17 lowercase_query: &'a [char],
18 query_char_bag: CharBag,
19 smart_case: bool,
20 penalize_length: bool,
21 min_score: f64,
22 match_positions: Vec<usize>,
23 last_positions: Vec<usize>,
24 score_matrix: Vec<Option<f64>>,
25 best_position_matrix: Vec<usize>,
26}
27
28pub trait MatchCandidate {
29 fn has_chars(&self, bag: CharBag) -> bool;
30 fn to_string(&self) -> Cow<'_, str>;
31}
32
33impl<'a> Matcher<'a> {
34 pub fn new(
35 query: &'a [char],
36 lowercase_query: &'a [char],
37 query_char_bag: CharBag,
38 smart_case: bool,
39 penalize_length: bool,
40 ) -> Self {
41 Self {
42 query,
43 lowercase_query,
44 query_char_bag,
45 min_score: 0.0,
46 last_positions: vec![0; lowercase_query.len()],
47 match_positions: vec![0; query.len()],
48 score_matrix: Vec::new(),
49 best_position_matrix: Vec::new(),
50 smart_case,
51 penalize_length,
52 }
53 }
54
55 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
56 /// the input candidates.
57 pub(crate) fn match_candidates<C, R, F, T>(
58 &mut self,
59 prefix: &[char],
60 lowercase_prefix: &[char],
61 candidates: impl Iterator<Item = T>,
62 results: &mut Vec<R>,
63 cancel_flag: &AtomicBool,
64 build_match: F,
65 ) where
66 C: MatchCandidate,
67 T: Borrow<C>,
68 F: Fn(&C, f64, &Vec<usize>) -> R,
69 {
70 let mut candidate_chars = Vec::new();
71 let mut lowercase_candidate_chars = Vec::new();
72 let mut extra_lowercase_chars = BTreeMap::new();
73
74 for candidate in candidates {
75 if !candidate.borrow().has_chars(self.query_char_bag) {
76 continue;
77 }
78
79 if cancel_flag.load(atomic::Ordering::Relaxed) {
80 break;
81 }
82
83 candidate_chars.clear();
84 lowercase_candidate_chars.clear();
85 extra_lowercase_chars.clear();
86 for (i, c) in candidate.borrow().to_string().chars().enumerate() {
87 candidate_chars.push(c);
88 let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
89 if char_lowercased.len() > 1 {
90 extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
91 }
92 lowercase_candidate_chars.append(&mut char_lowercased);
93 }
94
95 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
96 continue;
97 }
98
99 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
100 self.score_matrix.clear();
101 self.score_matrix.resize(matrix_len, None);
102 self.best_position_matrix.clear();
103 self.best_position_matrix.resize(matrix_len, 0);
104
105 let score = self.score_match(
106 &candidate_chars,
107 &lowercase_candidate_chars,
108 prefix,
109 lowercase_prefix,
110 &extra_lowercase_chars,
111 );
112
113 if score > 0.0 {
114 results.push(build_match(
115 candidate.borrow(),
116 score,
117 &self.match_positions,
118 ));
119 }
120 }
121 }
122
123 fn find_last_positions(
124 &mut self,
125 lowercase_prefix: &[char],
126 lowercase_candidate: &[char],
127 ) -> bool {
128 let mut lowercase_prefix = lowercase_prefix.iter();
129 let mut lowercase_candidate = lowercase_candidate.iter();
130 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
131 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
132 self.last_positions[i] = j + lowercase_prefix.len();
133 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
134 self.last_positions[i] = j;
135 } else {
136 return false;
137 }
138 }
139 true
140 }
141
142 fn score_match(
143 &mut self,
144 path: &[char],
145 path_lowercased: &[char],
146 prefix: &[char],
147 lowercase_prefix: &[char],
148 extra_lowercase_chars: &BTreeMap<usize, usize>,
149 ) -> f64 {
150 let score = self.recursive_score_match(
151 path,
152 path_lowercased,
153 prefix,
154 lowercase_prefix,
155 0,
156 0,
157 self.query.len() as f64,
158 extra_lowercase_chars,
159 ) * self.query.len() as f64;
160
161 if score <= 0.0 {
162 return 0.0;
163 }
164 let path_len = prefix.len() + path.len();
165 let mut cur_start = 0;
166 let mut byte_ix = 0;
167 let mut char_ix = 0;
168 for i in 0..self.query.len() {
169 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170 while char_ix < match_char_ix {
171 let ch = prefix
172 .get(char_ix)
173 .or_else(|| path.get(char_ix - prefix.len()))
174 .unwrap();
175 byte_ix += ch.len_utf8();
176 char_ix += 1;
177 }
178
179 self.match_positions[i] = byte_ix;
180
181 let matched_ch = prefix
182 .get(match_char_ix)
183 .or_else(|| path.get(match_char_ix - prefix.len()))
184 .unwrap();
185 byte_ix += matched_ch.len_utf8();
186
187 cur_start = match_char_ix + 1;
188 char_ix = match_char_ix + 1;
189 }
190
191 score
192 }
193
194 fn recursive_score_match(
195 &mut self,
196 path: &[char],
197 path_lowercased: &[char],
198 prefix: &[char],
199 lowercase_prefix: &[char],
200 query_idx: usize,
201 path_idx: usize,
202 cur_score: f64,
203 extra_lowercase_chars: &BTreeMap<usize, usize>,
204 ) -> f64 {
205 use std::path::MAIN_SEPARATOR;
206
207 if query_idx == self.query.len() {
208 return 1.0;
209 }
210
211 let limit = self.last_positions[query_idx];
212 let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
213 let safe_limit = limit.min(max_valid_index);
214
215 if path_idx > safe_limit {
216 return 0.0;
217 }
218
219 let path_len = prefix.len() + path.len();
220 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
221 return memoized;
222 }
223
224 let mut score = 0.0;
225 let mut best_position = 0;
226
227 let query_char = self.lowercase_query[query_idx];
228
229 let mut last_slash = 0;
230
231 for j in path_idx..=safe_limit {
232 let extra_lowercase_chars_count = extra_lowercase_chars
233 .iter()
234 .take_while(|&(&i, _)| i < j)
235 .map(|(_, increment)| increment)
236 .sum::<usize>();
237 let j_regular = j - extra_lowercase_chars_count;
238
239 let path_char = if j < prefix.len() {
240 lowercase_prefix[j]
241 } else {
242 let path_index = j - prefix.len();
243 match path_lowercased.get(path_index) {
244 Some(&char) => char,
245 None => continue,
246 }
247 };
248 let is_path_sep = path_char == MAIN_SEPARATOR;
249
250 if query_idx == 0 && is_path_sep {
251 last_slash = j_regular;
252 }
253
254 #[cfg(not(target_os = "windows"))]
255 let need_to_score =
256 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
257 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
258 #[cfg(target_os = "windows")]
259 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
260 if need_to_score {
261 let curr = match prefix.get(j_regular) {
262 Some(&curr) => curr,
263 None => path[j_regular - prefix.len()],
264 };
265
266 let mut char_score = 1.0;
267 if j > path_idx {
268 let last = match prefix.get(j_regular - 1) {
269 Some(&last) => last,
270 None => path[j_regular - 1 - prefix.len()],
271 };
272
273 if last == MAIN_SEPARATOR {
274 char_score = 0.9;
275 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
276 || (last.is_lowercase() && curr.is_uppercase())
277 {
278 char_score = 0.8;
279 } else if last == '.' {
280 char_score = 0.7;
281 } else if query_idx == 0 {
282 char_score = BASE_DISTANCE_PENALTY;
283 } else {
284 char_score = MIN_DISTANCE_PENALTY.max(
285 BASE_DISTANCE_PENALTY
286 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
287 );
288 }
289 }
290
291 // Apply a severe penalty if the case doesn't match.
292 // This will make the exact matches have higher score than the case-insensitive and the
293 // path insensitive matches.
294 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
295 char_score *= 0.001;
296 }
297
298 let mut multiplier = char_score;
299
300 // Scale the score based on how deep within the path we found the match.
301 if self.penalize_length && query_idx == 0 {
302 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
303 }
304
305 let mut next_score = 1.0;
306 if self.min_score > 0.0 {
307 next_score = cur_score * multiplier;
308 // Scores only decrease. If we can't pass the previous best, bail
309 if next_score < self.min_score {
310 // Ensure that score is non-zero so we use it in the memo table.
311 if score == 0.0 {
312 score = 1e-18;
313 }
314 continue;
315 }
316 }
317
318 let new_score = self.recursive_score_match(
319 path,
320 path_lowercased,
321 prefix,
322 lowercase_prefix,
323 query_idx + 1,
324 j + 1,
325 next_score,
326 extra_lowercase_chars,
327 ) * multiplier;
328
329 if new_score > score {
330 score = new_score;
331 best_position = j_regular;
332 // Optimization: can't score better than 1.
333 if new_score == 1.0 {
334 break;
335 }
336 }
337 }
338 }
339
340 if best_position != 0 {
341 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
342 }
343
344 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
345 score
346 }
347}
348
349#[cfg(test)]
350mod tests {
351 use crate::{PathMatch, PathMatchCandidate};
352
353 use super::*;
354 use std::{
355 path::{Path, PathBuf},
356 sync::Arc,
357 };
358
359 #[test]
360 fn test_get_last_positions() {
361 let mut query: &[char] = &['d', 'c'];
362 let mut matcher = Matcher::new(query, query, query.into(), false, true);
363 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
364 assert!(!result);
365
366 query = &['c', 'd'];
367 let mut matcher = Matcher::new(query, query, query.into(), false, true);
368 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
369 assert!(result);
370 assert_eq!(matcher.last_positions, vec![2, 4]);
371
372 query = &['z', '/', 'z', 'f'];
373 let mut matcher = Matcher::new(query, query, query.into(), false, true);
374 let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
375 assert!(result);
376 assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
377 }
378
379 #[cfg(not(target_os = "windows"))]
380 #[test]
381 fn test_match_path_entries() {
382 let paths = vec![
383 "",
384 "a",
385 "ab",
386 "abC",
387 "abcd",
388 "alphabravocharlie",
389 "AlphaBravoCharlie",
390 "thisisatestdir",
391 "/////ThisIsATestDir",
392 "/this/is/a/test/dir",
393 "/test/tiatd",
394 ];
395
396 assert_eq!(
397 match_single_path_query("abc", false, &paths),
398 vec![
399 ("abC", vec![0, 1, 2]),
400 ("abcd", vec![0, 1, 2]),
401 ("AlphaBravoCharlie", vec![0, 5, 10]),
402 ("alphabravocharlie", vec![4, 5, 10]),
403 ]
404 );
405 assert_eq!(
406 match_single_path_query("t/i/a/t/d", false, &paths),
407 vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
408 );
409
410 assert_eq!(
411 match_single_path_query("tiatd", false, &paths),
412 vec![
413 ("/test/tiatd", vec![6, 7, 8, 9, 10]),
414 ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
415 ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
416 ("thisisatestdir", vec![0, 2, 6, 7, 11]),
417 ]
418 );
419 }
420
421 /// todo(windows)
422 /// Now, on Windows, users can only use the backslash as a path separator.
423 /// I do want to support both the backslash and the forward slash as path separators on Windows.
424 #[cfg(target_os = "windows")]
425 #[test]
426 fn test_match_path_entries() {
427 let paths = vec![
428 "",
429 "a",
430 "ab",
431 "abC",
432 "abcd",
433 "alphabravocharlie",
434 "AlphaBravoCharlie",
435 "thisisatestdir",
436 "\\\\\\\\\\ThisIsATestDir",
437 "\\this\\is\\a\\test\\dir",
438 "\\test\\tiatd",
439 ];
440
441 assert_eq!(
442 match_single_path_query("abc", false, &paths),
443 vec![
444 ("abC", vec![0, 1, 2]),
445 ("abcd", vec![0, 1, 2]),
446 ("AlphaBravoCharlie", vec![0, 5, 10]),
447 ("alphabravocharlie", vec![4, 5, 10]),
448 ]
449 );
450 assert_eq!(
451 match_single_path_query("t\\i\\a\\t\\d", false, &paths),
452 vec![(
453 "\\this\\is\\a\\test\\dir",
454 vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
455 ),]
456 );
457
458 assert_eq!(
459 match_single_path_query("tiatd", false, &paths),
460 vec![
461 ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
462 ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
463 ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
464 ("thisisatestdir", vec![0, 2, 6, 7, 11]),
465 ]
466 );
467 }
468
469 #[test]
470 fn test_lowercase_longer_than_uppercase() {
471 // This character has more chars in lower-case than in upper-case.
472 let paths = vec!["\u{0130}"];
473 let query = "\u{0130}";
474 assert_eq!(
475 match_single_path_query(query, false, &paths),
476 vec![("\u{0130}", vec![0])]
477 );
478
479 // Path is the lower-case version of the query
480 let paths = vec!["i\u{307}"];
481 let query = "\u{0130}";
482 assert_eq!(
483 match_single_path_query(query, false, &paths),
484 vec![("i\u{307}", vec![0])]
485 );
486 }
487
488 #[test]
489 fn test_match_multibyte_path_entries() {
490 let paths = vec![
491 "aαbβ/cγdδ",
492 "αβγδ/bcde",
493 "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
494 "/d/🆒/h",
495 ];
496 assert_eq!("1️⃣".len(), 7);
497 assert_eq!(
498 match_single_path_query("bcd", false, &paths),
499 vec![
500 ("αβγδ/bcde", vec![9, 10, 11]),
501 ("aαbβ/cγdδ", vec![3, 7, 10]),
502 ]
503 );
504 assert_eq!(
505 match_single_path_query("cde", false, &paths),
506 vec![
507 ("αβγδ/bcde", vec![10, 11, 12]),
508 ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
509 ]
510 );
511 }
512
513 #[test]
514 fn match_unicode_path_entries() {
515 let mixed_unicode_paths = vec![
516 "İolu/oluş",
517 "İstanbul/code",
518 "Athens/Şanlıurfa",
519 "Çanakkale/scripts",
520 "paris/Düzce_İl",
521 "Berlin_Önemli_Ğündem",
522 "KİTAPLIK/london/dosya",
523 "tokyo/kyoto/fuji",
524 "new_york/san_francisco",
525 ];
526
527 assert_eq!(
528 match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
529 vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
530 );
531
532 assert_eq!(
533 match_single_path_query("İst/code", false, &mixed_unicode_paths),
534 vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
535 );
536
537 assert_eq!(
538 match_single_path_query("athens/şa", false, &mixed_unicode_paths),
539 vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
540 );
541
542 assert_eq!(
543 match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
544 vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
545 );
546
547 assert_eq!(
548 match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
549 vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
550 );
551
552 let mixed_script_paths = vec![
553 "résumé_Москва",
554 "naïve_київ_implementation",
555 "café_北京_app",
556 "東京_über_driver",
557 "déjà_vu_cairo",
558 "seoul_piñata_game",
559 "voilà_istanbul_result",
560 ];
561
562 assert_eq!(
563 match_single_path_query("résmé", false, &mixed_script_paths),
564 vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
565 );
566
567 assert_eq!(
568 match_single_path_query("café北京", false, &mixed_script_paths),
569 vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
570 );
571
572 assert_eq!(
573 match_single_path_query("ista", false, &mixed_script_paths),
574 vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
575 );
576
577 let complex_paths = vec![
578 "document_📚_library",
579 "project_👨👩👧👦_family",
580 "flags_🇯🇵🇺🇸🇪🇺_world",
581 "code_😀😃😄😁_happy",
582 "photo_👩👩👧👦_album",
583 ];
584
585 assert_eq!(
586 match_single_path_query("doc📚lib", false, &complex_paths),
587 vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
588 );
589
590 assert_eq!(
591 match_single_path_query("codehappy", false, &complex_paths),
592 vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
593 );
594 }
595
596 fn match_single_path_query<'a>(
597 query: &str,
598 smart_case: bool,
599 paths: &[&'a str],
600 ) -> Vec<(&'a str, Vec<usize>)> {
601 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
602 let query = query.chars().collect::<Vec<_>>();
603 let query_chars = CharBag::from(&lowercase_query[..]);
604
605 let path_arcs: Vec<Arc<Path>> = paths
606 .iter()
607 .map(|path| Arc::from(PathBuf::from(path)))
608 .collect::<Vec<_>>();
609 let mut path_entries = Vec::new();
610 for (i, path) in paths.iter().enumerate() {
611 let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
612 let char_bag = CharBag::from(lowercase_path.as_slice());
613 path_entries.push(PathMatchCandidate {
614 is_dir: false,
615 char_bag,
616 path: &path_arcs[i],
617 });
618 }
619
620 let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);
621
622 let cancel_flag = AtomicBool::new(false);
623 let mut results = Vec::new();
624
625 matcher.match_candidates(
626 &[],
627 &[],
628 path_entries.into_iter(),
629 &mut results,
630 &cancel_flag,
631 |candidate, score, positions| PathMatch {
632 score,
633 worktree_id: 0,
634 positions: positions.clone(),
635 path: Arc::from(candidate.path),
636 path_prefix: "".into(),
637 distance_to_relative_ancestor: usize::MAX,
638 is_dir: false,
639 },
640 );
641 results.sort_by(|a, b| b.cmp(a));
642
643 results
644 .into_iter()
645 .map(|result| {
646 (
647 paths
648 .iter()
649 .copied()
650 .find(|p| result.path.as_ref() == Path::new(p))
651 .unwrap(),
652 result.positions,
653 )
654 })
655 .collect()
656 }
657}