1use std::{
2 borrow::{Borrow, Cow},
3 collections::BTreeMap,
4 sync::atomic::{self, AtomicBool},
5};
6
7use crate::CharBag;
8
9const BASE_DISTANCE_PENALTY: f64 = 0.6;
10const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
11const MIN_DISTANCE_PENALTY: f64 = 0.2;
12
13// TODO:
14// Use `Path` instead of `&str` for paths.
/// Fuzzy matcher that scores candidate strings against a single query.
///
/// The matcher owns scratch buffers that are reused for every candidate, so one
/// instance can score many candidates without reallocating per candidate.
pub struct Matcher<'a> {
    /// Query characters as supplied by the caller; compared against the
    /// candidate's original characters when applying the mismatch penalty.
    query: &'a [char],
    /// Characters compared against the lowercased candidate while matching.
    /// Presumably the lowercased form of `query` — supplied by the caller.
    lowercase_query: &'a [char],
    /// Handed to `MatchCandidate::has_chars` to reject candidates cheaply.
    query_char_bag: CharBag,
    /// When set, a matched candidate char that differs from the corresponding
    /// `query` char has its score multiplied by a severe penalty.
    smart_case: bool,
    /// Pruning threshold for partial scores; `new` initializes it to `0.0`,
    /// which leaves pruning disabled.
    min_score: f64,
    /// One byte offset per query char for the most recently scored candidate.
    match_positions: Vec<usize>,
    /// One entry per `lowercase_query` char: the last index (within prefix
    /// followed by candidate) at which that char may be matched.
    last_positions: Vec<usize>,
    /// Memo table of scores, addressed by (query index, start index).
    score_matrix: Vec<Option<f64>>,
    /// For each memo cell, the best position found for that query char.
    best_position_matrix: Vec<usize>,
}
26
/// Something that can be fuzzy-matched by [`Matcher`].
pub trait MatchCandidate {
    /// Cheap pre-filter: the matcher skips a candidate entirely when this
    /// returns `false` for the query's character bag.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The text that query characters are matched against.
    fn to_string(&self) -> Cow<'_, str>;
}
31
32impl<'a> Matcher<'a> {
33 pub fn new(
34 query: &'a [char],
35 lowercase_query: &'a [char],
36 query_char_bag: CharBag,
37 smart_case: bool,
38 ) -> Self {
39 Self {
40 query,
41 lowercase_query,
42 query_char_bag,
43 min_score: 0.0,
44 last_positions: vec![0; lowercase_query.len()],
45 match_positions: vec![0; query.len()],
46 score_matrix: Vec::new(),
47 best_position_matrix: Vec::new(),
48 smart_case,
49 }
50 }
51
52 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
53 /// the input candidates.
54 pub(crate) fn match_candidates<C, R, F, T>(
55 &mut self,
56 prefix: &[char],
57 lowercase_prefix: &[char],
58 candidates: impl Iterator<Item = T>,
59 results: &mut Vec<R>,
60 cancel_flag: &AtomicBool,
61 build_match: F,
62 ) where
63 C: MatchCandidate,
64 T: Borrow<C>,
65 F: Fn(&C, f64, &Vec<usize>) -> R,
66 {
67 let mut candidate_chars = Vec::new();
68 let mut lowercase_candidate_chars = Vec::new();
69 let mut extra_lowercase_chars = BTreeMap::new();
70
71 for candidate in candidates {
72 if !candidate.borrow().has_chars(self.query_char_bag) {
73 continue;
74 }
75
76 if cancel_flag.load(atomic::Ordering::Relaxed) {
77 break;
78 }
79
80 candidate_chars.clear();
81 lowercase_candidate_chars.clear();
82 extra_lowercase_chars.clear();
83 for (i, c) in candidate.borrow().to_string().chars().enumerate() {
84 candidate_chars.push(c);
85 let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
86 if char_lowercased.len() > 1 {
87 extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
88 }
89 lowercase_candidate_chars.append(&mut char_lowercased);
90 }
91
92 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
93 continue;
94 }
95
96 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
97 self.score_matrix.clear();
98 self.score_matrix.resize(matrix_len, None);
99 self.best_position_matrix.clear();
100 self.best_position_matrix.resize(matrix_len, 0);
101
102 let score = self.score_match(
103 &candidate_chars,
104 &lowercase_candidate_chars,
105 prefix,
106 lowercase_prefix,
107 &extra_lowercase_chars,
108 );
109
110 if score > 0.0 {
111 results.push(build_match(
112 candidate.borrow(),
113 score,
114 &self.match_positions,
115 ));
116 }
117 }
118 }
119
120 fn find_last_positions(
121 &mut self,
122 lowercase_prefix: &[char],
123 lowercase_candidate: &[char],
124 ) -> bool {
125 let mut lowercase_prefix = lowercase_prefix.iter();
126 let mut lowercase_candidate = lowercase_candidate.iter();
127 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
128 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
129 self.last_positions[i] = j + lowercase_prefix.len();
130 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
131 self.last_positions[i] = j;
132 } else {
133 return false;
134 }
135 }
136 true
137 }
138
139 fn score_match(
140 &mut self,
141 path: &[char],
142 path_lowercased: &[char],
143 prefix: &[char],
144 lowercase_prefix: &[char],
145 extra_lowercase_chars: &BTreeMap<usize, usize>,
146 ) -> f64 {
147 let score = self.recursive_score_match(
148 path,
149 path_lowercased,
150 prefix,
151 lowercase_prefix,
152 0,
153 0,
154 self.query.len() as f64,
155 extra_lowercase_chars,
156 ) * self.query.len() as f64;
157
158 if score <= 0.0 {
159 return 0.0;
160 }
161 let path_len = prefix.len() + path.len();
162 let mut cur_start = 0;
163 let mut byte_ix = 0;
164 let mut char_ix = 0;
165 for i in 0..self.query.len() {
166 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
167 while char_ix < match_char_ix {
168 let ch = prefix
169 .get(char_ix)
170 .or_else(|| path.get(char_ix - prefix.len()))
171 .unwrap();
172 byte_ix += ch.len_utf8();
173 char_ix += 1;
174 }
175
176 self.match_positions[i] = byte_ix;
177
178 let matched_ch = prefix
179 .get(match_char_ix)
180 .or_else(|| path.get(match_char_ix - prefix.len()))
181 .unwrap();
182 byte_ix += matched_ch.len_utf8();
183
184 cur_start = match_char_ix + 1;
185 char_ix = match_char_ix + 1;
186 }
187
188 score
189 }
190
191 fn recursive_score_match(
192 &mut self,
193 path: &[char],
194 path_lowercased: &[char],
195 prefix: &[char],
196 lowercase_prefix: &[char],
197 query_idx: usize,
198 path_idx: usize,
199 cur_score: f64,
200 extra_lowercase_chars: &BTreeMap<usize, usize>,
201 ) -> f64 {
202 use std::path::MAIN_SEPARATOR;
203
204 if query_idx == self.query.len() {
205 return 1.0;
206 }
207
208 let path_len = prefix.len() + path.len();
209
210 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
211 return memoized;
212 }
213
214 let mut score = 0.0;
215 let mut best_position = 0;
216
217 let query_char = self.lowercase_query[query_idx];
218 let limit = self.last_positions[query_idx];
219
220 let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
221 let safe_limit = limit.min(max_valid_index);
222
223 let mut last_slash = 0;
224 for j in path_idx..=safe_limit {
225 let extra_lowercase_chars_count = extra_lowercase_chars
226 .iter()
227 .take_while(|(i, _)| i < &&j)
228 .map(|(_, increment)| increment)
229 .sum::<usize>();
230 let j_regular = j - extra_lowercase_chars_count;
231
232 let path_char = if j < prefix.len() {
233 lowercase_prefix[j]
234 } else {
235 let path_index = j - prefix.len();
236 if path_index < path_lowercased.len() {
237 path_lowercased[path_index]
238 } else {
239 continue;
240 }
241 };
242 let is_path_sep = path_char == MAIN_SEPARATOR;
243
244 if query_idx == 0 && is_path_sep {
245 last_slash = j_regular;
246 }
247
248 #[cfg(not(target_os = "windows"))]
249 let need_to_score =
250 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
251 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
252 #[cfg(target_os = "windows")]
253 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
254 if need_to_score {
255 let curr = if j_regular < prefix.len() {
256 prefix[j_regular]
257 } else {
258 path[j_regular - prefix.len()]
259 };
260
261 let mut char_score = 1.0;
262 if j > path_idx {
263 let last = if j_regular - 1 < prefix.len() {
264 prefix[j_regular - 1]
265 } else {
266 path[j_regular - 1 - prefix.len()]
267 };
268
269 if last == MAIN_SEPARATOR {
270 char_score = 0.9;
271 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
272 || (last.is_lowercase() && curr.is_uppercase())
273 {
274 char_score = 0.8;
275 } else if last == '.' {
276 char_score = 0.7;
277 } else if query_idx == 0 {
278 char_score = BASE_DISTANCE_PENALTY;
279 } else {
280 char_score = MIN_DISTANCE_PENALTY.max(
281 BASE_DISTANCE_PENALTY
282 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
283 );
284 }
285 }
286
287 // Apply a severe penalty if the case doesn't match.
288 // This will make the exact matches have higher score than the case-insensitive and the
289 // path insensitive matches.
290 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
291 char_score *= 0.001;
292 }
293
294 let mut multiplier = char_score;
295
296 // Scale the score based on how deep within the path we found the match.
297 if query_idx == 0 {
298 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
299 }
300
301 let mut next_score = 1.0;
302 if self.min_score > 0.0 {
303 next_score = cur_score * multiplier;
304 // Scores only decrease. If we can't pass the previous best, bail
305 if next_score < self.min_score {
306 // Ensure that score is non-zero so we use it in the memo table.
307 if score == 0.0 {
308 score = 1e-18;
309 }
310 continue;
311 }
312 }
313
314 let new_score = self.recursive_score_match(
315 path,
316 path_lowercased,
317 prefix,
318 lowercase_prefix,
319 query_idx + 1,
320 j + 1,
321 next_score,
322 extra_lowercase_chars,
323 ) * multiplier;
324
325 if new_score > score {
326 score = new_score;
327 best_position = j_regular;
328 // Optimization: can't score better than 1.
329 if new_score == 1.0 {
330 break;
331 }
332 }
333 }
334 }
335
336 if best_position != 0 {
337 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
338 }
339
340 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
341 score
342 }
343}
344
/// Unit tests for the matcher. Path queries are run through
/// `match_single_path_query`, which reports match positions as byte offsets.
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// `find_last_positions` fails when the query chars cannot be found in
    /// order, and otherwise records positions with candidate indices offset by
    /// the prefix length.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, before 'd', so "dc" cannot match.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // 'c' is found in the prefix (index 2), 'd' in the candidate (1 + 3).
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Ranking and match positions for ASCII paths using `/` separators.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Windows counterpart of the test above, using `\` separators.
    ///
    /// todo(windows)
    /// Now, on Windows, users can only use the backslash as a path separator.
    /// I do want to support both the backslash and the forward slash as path separators on Windows.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Matching must not panic when a char's lowercase form is longer than the
    /// char itself.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Match positions are byte offsets, so multi-byte chars before a match
    /// shift the reported positions.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Each keycap emoji occupies 7 bytes, which the offsets below rely on.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Queries against paths mixing scripts, accents and emoji.
    ///
    /// NOTE(review): the expected positions for the paths starting with `İ`
    /// pin the matcher's current output; they do not look like the byte
    /// offsets of the matched chars — confirm before relying on them.
    #[test]
    fn match_unicode_path_entries() {
        let mixed_unicode_paths = vec![
            "İolu/oluş",
            "İstanbul/code",
            "Athens/Şanlıurfa",
            "Çanakkale/scripts",
            "paris/Düzce_İl",
            "Berlin_Önemli_Ğündem",
            "KİTAPLIK/london/dosya",
            "tokyo/kyoto/fuji",
            "new_york/san_francisco",
        ];

        assert_eq!(
            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
        );

        assert_eq!(
            match_single_path_query("İst/code", false, &mixed_unicode_paths),
            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
        );

        assert_eq!(
            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
        );

        assert_eq!(
            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
        );

        assert_eq!(
            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
        );

        let mixed_script_paths = vec![
            "résumé_Москва",
            "naïve_київ_implementation",
            "café_北京_app",
            "東京_über_driver",
            "déjà_vu_cairo",
            "seoul_piñata_game",
            "voilà_istanbul_result",
        ];

        assert_eq!(
            match_single_path_query("résmé", false, &mixed_script_paths),
            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
        );

        assert_eq!(
            match_single_path_query("café北京", false, &mixed_script_paths),
            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
        );

        assert_eq!(
            match_single_path_query("ista", false, &mixed_script_paths),
            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
        );

        let complex_paths = vec![
            "document_📚_library",
            "project_👨👩👧👦_family",
            "flags_🇯🇵🇺🇸🇪🇺_world",
            "code_😀😃😄😁_happy",
            "photo_👩👩👧👦_album",
        ];

        assert_eq!(
            match_single_path_query("doc📚lib", false, &complex_paths),
            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
        );

        assert_eq!(
            match_single_path_query("codehappy", false, &complex_paths),
            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
        );
    }

    /// Matches `query` against `paths` with an empty prefix and returns each
    /// matching path together with its match positions, ordered by the
    /// descending `Ord` of the resulting `PathMatch` values.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        // Best match first.
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` it was built from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}