1use std::{
2 borrow::Borrow,
3 collections::BTreeMap,
4 sync::atomic::{self, AtomicBool},
5};
6
7use crate::CharBag;
8
/// Floor for the score of a matched char that is separated from the previous match, so that
/// long gaps never push a char's score below this value.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
/// Score of a matched char that is neither contiguous with the previous match nor preceded by a
/// separator, case change or `.`. The first query char gets exactly this value; later query
/// chars start from it and lose [`ADDITIONAL_DISTANCE_PENALTY`] per extra skipped char.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every skipped char beyond the first one
/// between two consecutive matches.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
12
13// TODO:
14// Use `Path` instead of `&str` for paths.
15pub struct Matcher<'a> {
16 query: &'a [char],
17 lowercase_query: &'a [char],
18 query_char_bag: CharBag,
19 smart_case: bool,
20 penalize_length: bool,
21 min_score: f64,
22 match_positions: Vec<usize>,
23 last_positions: Vec<usize>,
24 score_matrix: Vec<Option<f64>>,
25 best_position_matrix: Vec<usize>,
26}
27
28pub trait MatchCandidate {
29 fn has_chars(&self, bag: CharBag) -> bool;
30 fn candidate_chars(&self) -> impl Iterator<Item = char>;
31}
32
33impl<'a> Matcher<'a> {
34 pub fn new(
35 query: &'a [char],
36 lowercase_query: &'a [char],
37 query_char_bag: CharBag,
38 smart_case: bool,
39 penalize_length: bool,
40 ) -> Self {
41 Self {
42 query,
43 lowercase_query,
44 query_char_bag,
45 min_score: 0.0,
46 last_positions: vec![0; lowercase_query.len()],
47 match_positions: vec![0; query.len()],
48 score_matrix: Vec::new(),
49 best_position_matrix: Vec::new(),
50 smart_case,
51 penalize_length,
52 }
53 }
54
55 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
56 /// the input candidates.
57 pub(crate) fn match_candidates<C, R, F, T>(
58 &mut self,
59 prefix: &[char],
60 lowercase_prefix: &[char],
61 candidates: impl Iterator<Item = T>,
62 results: &mut Vec<R>,
63 cancel_flag: &AtomicBool,
64 build_match: F,
65 ) where
66 C: MatchCandidate,
67 T: Borrow<C>,
68 F: Fn(&C, f64, &Vec<usize>) -> R,
69 {
70 let mut candidate_chars = Vec::new();
71 let mut lowercase_candidate_chars = Vec::new();
72 let mut extra_lowercase_chars = BTreeMap::new();
73
74 for candidate in candidates {
75 if !candidate.borrow().has_chars(self.query_char_bag) {
76 continue;
77 }
78
79 if cancel_flag.load(atomic::Ordering::Acquire) {
80 break;
81 }
82
83 candidate_chars.clear();
84 lowercase_candidate_chars.clear();
85 extra_lowercase_chars.clear();
86 for (i, c) in candidate.borrow().candidate_chars().enumerate() {
87 candidate_chars.push(c);
88 let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
89 if char_lowercased.len() > 1 {
90 extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
91 }
92 lowercase_candidate_chars.append(&mut char_lowercased);
93 }
94
95 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
96 continue;
97 }
98
99 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
100 self.score_matrix.clear();
101 self.score_matrix.resize(matrix_len, None);
102 self.best_position_matrix.clear();
103 self.best_position_matrix.resize(matrix_len, 0);
104
105 let score = self.score_match(
106 &candidate_chars,
107 &lowercase_candidate_chars,
108 prefix,
109 lowercase_prefix,
110 &extra_lowercase_chars,
111 );
112
113 if score > 0.0 {
114 results.push(build_match(
115 candidate.borrow(),
116 score,
117 &self.match_positions,
118 ));
119 }
120 }
121 }
122
123 fn find_last_positions(
124 &mut self,
125 lowercase_prefix: &[char],
126 lowercase_candidate: &[char],
127 ) -> bool {
128 let mut lowercase_prefix = lowercase_prefix.iter();
129 let mut lowercase_candidate = lowercase_candidate.iter();
130 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
131 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
132 self.last_positions[i] = j + lowercase_prefix.len();
133 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
134 self.last_positions[i] = j;
135 } else {
136 return false;
137 }
138 }
139 true
140 }
141
142 fn score_match(
143 &mut self,
144 path: &[char],
145 path_lowercased: &[char],
146 prefix: &[char],
147 lowercase_prefix: &[char],
148 extra_lowercase_chars: &BTreeMap<usize, usize>,
149 ) -> f64 {
150 let score = self.recursive_score_match(
151 path,
152 path_lowercased,
153 prefix,
154 lowercase_prefix,
155 0,
156 0,
157 self.query.len() as f64,
158 extra_lowercase_chars,
159 ) * self.query.len() as f64;
160
161 if score <= 0.0 {
162 return 0.0;
163 }
164 let path_len = prefix.len() + path.len();
165 let mut cur_start = 0;
166 let mut byte_ix = 0;
167 let mut char_ix = 0;
168 for i in 0..self.query.len() {
169 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170 while char_ix < match_char_ix {
171 let ch = prefix
172 .get(char_ix)
173 .or_else(|| path.get(char_ix - prefix.len()))
174 .unwrap();
175 byte_ix += ch.len_utf8();
176 char_ix += 1;
177 }
178
179 self.match_positions[i] = byte_ix;
180
181 let matched_ch = prefix
182 .get(match_char_ix)
183 .or_else(|| path.get(match_char_ix - prefix.len()))
184 .unwrap();
185 byte_ix += matched_ch.len_utf8();
186
187 cur_start = match_char_ix + 1;
188 char_ix = match_char_ix + 1;
189 }
190
191 score
192 }
193
194 fn recursive_score_match(
195 &mut self,
196 path: &[char],
197 path_lowercased: &[char],
198 prefix: &[char],
199 lowercase_prefix: &[char],
200 query_idx: usize,
201 path_idx: usize,
202 cur_score: f64,
203 extra_lowercase_chars: &BTreeMap<usize, usize>,
204 ) -> f64 {
205 if query_idx == self.query.len() {
206 return 1.0;
207 }
208
209 let limit = self.last_positions[query_idx];
210 let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
211 let safe_limit = limit.min(max_valid_index);
212
213 if path_idx > safe_limit {
214 return 0.0;
215 }
216
217 let path_len = prefix.len() + path.len();
218 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
219 return memoized;
220 }
221
222 let mut score = 0.0;
223 let mut best_position = 0;
224
225 let query_char = self.lowercase_query[query_idx];
226
227 let mut last_slash = 0;
228
229 for j in path_idx..=safe_limit {
230 let extra_lowercase_chars_count = extra_lowercase_chars
231 .iter()
232 .take_while(|&(&i, _)| i < j)
233 .map(|(_, increment)| increment)
234 .sum::<usize>();
235 let j_regular = j - extra_lowercase_chars_count;
236
237 let path_char = if j < prefix.len() {
238 lowercase_prefix[j]
239 } else {
240 let path_index = j - prefix.len();
241 match path_lowercased.get(path_index) {
242 Some(&char) => char,
243 None => continue,
244 }
245 };
246 let is_path_sep = path_char == '/';
247
248 if query_idx == 0 && is_path_sep {
249 last_slash = j_regular;
250 }
251 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
252 if need_to_score {
253 let curr = match prefix.get(j_regular) {
254 Some(&curr) => curr,
255 None => path[j_regular - prefix.len()],
256 };
257
258 let mut char_score = 1.0;
259 if j > path_idx {
260 let last = match prefix.get(j_regular - 1) {
261 Some(&last) => last,
262 None => path[j_regular - 1 - prefix.len()],
263 };
264
265 if last == '/' {
266 char_score = 0.9;
267 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
268 || (last.is_lowercase() && curr.is_uppercase())
269 {
270 char_score = 0.8;
271 } else if last == '.' {
272 char_score = 0.7;
273 } else if query_idx == 0 {
274 char_score = BASE_DISTANCE_PENALTY;
275 } else {
276 char_score = MIN_DISTANCE_PENALTY.max(
277 BASE_DISTANCE_PENALTY
278 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
279 );
280 }
281 }
282
283 // Apply a severe penalty if the case doesn't match.
284 // This will make the exact matches have higher score than the case-insensitive and the
285 // path insensitive matches.
286 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
287 char_score *= 0.001;
288 }
289
290 let mut multiplier = char_score;
291
292 // Scale the score based on how deep within the path we found the match.
293 if self.penalize_length && query_idx == 0 {
294 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
295 }
296
297 let mut next_score = 1.0;
298 if self.min_score > 0.0 {
299 next_score = cur_score * multiplier;
300 // Scores only decrease. If we can't pass the previous best, bail
301 if next_score < self.min_score {
302 // Ensure that score is non-zero so we use it in the memo table.
303 if score == 0.0 {
304 score = 1e-18;
305 }
306 continue;
307 }
308 }
309
310 let new_score = self.recursive_score_match(
311 path,
312 path_lowercased,
313 prefix,
314 lowercase_prefix,
315 query_idx + 1,
316 j + 1,
317 next_score,
318 extra_lowercase_chars,
319 ) * multiplier;
320
321 if new_score > score {
322 score = new_score;
323 best_position = j_regular;
324 // Optimization: can't score better than 1.
325 if new_score == 1.0 {
326 break;
327 }
328 }
329 }
330 }
331
332 if best_position != 0 {
333 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
334 }
335
336 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
337 score
338 }
339}
340
341#[cfg(test)]
342mod tests {
343 use util::rel_path::{RelPath, rel_path};
344
345 use crate::{PathMatch, PathMatchCandidate};
346
347 use super::*;
348 use std::sync::Arc;
349
350 #[test]
351 fn test_get_last_positions() {
352 let mut query: &[char] = &['d', 'c'];
353 let mut matcher = Matcher::new(query, query, query.into(), false, true);
354 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
355 assert!(!result);
356
357 query = &['c', 'd'];
358 let mut matcher = Matcher::new(query, query, query.into(), false, true);
359 let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
360 assert!(result);
361 assert_eq!(matcher.last_positions, vec![2, 4]);
362
363 query = &['z', '/', 'z', 'f'];
364 let mut matcher = Matcher::new(query, query, query.into(), false, true);
365 let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
366 assert!(result);
367 assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
368 }
369
370 #[test]
371 fn test_match_path_entries() {
372 let paths = vec![
373 "",
374 "a",
375 "ab",
376 "abC",
377 "abcd",
378 "alphabravocharlie",
379 "AlphaBravoCharlie",
380 "thisisatestdir",
381 "ThisIsATestDir",
382 "this/is/a/test/dir",
383 "test/tiatd",
384 ];
385
386 assert_eq!(
387 match_single_path_query("abc", false, &paths),
388 vec![
389 ("abC", vec![0, 1, 2]),
390 ("abcd", vec![0, 1, 2]),
391 ("AlphaBravoCharlie", vec![0, 5, 10]),
392 ("alphabravocharlie", vec![4, 5, 10]),
393 ]
394 );
395 assert_eq!(
396 match_single_path_query("t/i/a/t/d", false, &paths),
397 vec![("this/is/a/test/dir", vec![0, 4, 5, 7, 8, 9, 10, 14, 15]),]
398 );
399
400 assert_eq!(
401 match_single_path_query("tiatd", false, &paths),
402 vec![
403 ("test/tiatd", vec![5, 6, 7, 8, 9]),
404 ("ThisIsATestDir", vec![0, 4, 6, 7, 11]),
405 ("this/is/a/test/dir", vec![0, 5, 8, 10, 15]),
406 ("thisisatestdir", vec![0, 2, 6, 7, 11]),
407 ]
408 );
409 }
410
411 #[test]
412 fn test_lowercase_longer_than_uppercase() {
413 // This character has more chars in lower-case than in upper-case.
414 let paths = vec!["\u{0130}"];
415 let query = "\u{0130}";
416 assert_eq!(
417 match_single_path_query(query, false, &paths),
418 vec![("\u{0130}", vec![0])]
419 );
420
421 // Path is the lower-case version of the query
422 let paths = vec!["i\u{307}"];
423 let query = "\u{0130}";
424 assert_eq!(
425 match_single_path_query(query, false, &paths),
426 vec![("i\u{307}", vec![0])]
427 );
428 }
429
430 #[test]
431 fn test_match_multibyte_path_entries() {
432 let paths = vec![
433 "aαbβ/cγdδ",
434 "αβγδ/bcde",
435 "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
436 "d/🆒/h",
437 ];
438 assert_eq!("1️⃣".len(), 7);
439 assert_eq!(
440 match_single_path_query("bcd", false, &paths),
441 vec![
442 ("αβγδ/bcde", vec![9, 10, 11]),
443 ("aαbβ/cγdδ", vec![3, 7, 10]),
444 ]
445 );
446 assert_eq!(
447 match_single_path_query("cde", false, &paths),
448 vec![
449 ("αβγδ/bcde", vec![10, 11, 12]),
450 ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
451 ]
452 );
453 }
454
455 #[test]
456 fn match_unicode_path_entries() {
457 let mixed_unicode_paths = vec![
458 "İolu/oluş",
459 "İstanbul/code",
460 "Athens/Şanlıurfa",
461 "Çanakkale/scripts",
462 "paris/Düzce_İl",
463 "Berlin_Önemli_Ğündem",
464 "KİTAPLIK/london/dosya",
465 "tokyo/kyoto/fuji",
466 "new_york/san_francisco",
467 ];
468
469 assert_eq!(
470 match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
471 vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
472 );
473
474 assert_eq!(
475 match_single_path_query("İst/code", false, &mixed_unicode_paths),
476 vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
477 );
478
479 assert_eq!(
480 match_single_path_query("athens/şa", false, &mixed_unicode_paths),
481 vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
482 );
483
484 assert_eq!(
485 match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
486 vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
487 );
488
489 assert_eq!(
490 match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
491 vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
492 );
493
494 let mixed_script_paths = vec![
495 "résumé_Москва",
496 "naïve_київ_implementation",
497 "café_北京_app",
498 "東京_über_driver",
499 "déjà_vu_cairo",
500 "seoul_piñata_game",
501 "voilà_istanbul_result",
502 ];
503
504 assert_eq!(
505 match_single_path_query("résmé", false, &mixed_script_paths),
506 vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
507 );
508
509 assert_eq!(
510 match_single_path_query("café北京", false, &mixed_script_paths),
511 vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
512 );
513
514 assert_eq!(
515 match_single_path_query("ista", false, &mixed_script_paths),
516 vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
517 );
518
519 let complex_paths = vec![
520 "document_📚_library",
521 "project_👨👩👧👦_family",
522 "flags_🇯🇵🇺🇸🇪🇺_world",
523 "code_😀😃😄😁_happy",
524 "photo_👩👩👧👦_album",
525 ];
526
527 assert_eq!(
528 match_single_path_query("doc📚lib", false, &complex_paths),
529 vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
530 );
531
532 assert_eq!(
533 match_single_path_query("codehappy", false, &complex_paths),
534 vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
535 );
536 }
537
538 fn match_single_path_query<'a>(
539 query: &str,
540 smart_case: bool,
541 paths: &[&'a str],
542 ) -> Vec<(&'a str, Vec<usize>)> {
543 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
544 let query = query.chars().collect::<Vec<_>>();
545 let query_chars = CharBag::from(&lowercase_query[..]);
546
547 let path_arcs: Vec<Arc<RelPath>> = paths
548 .iter()
549 .map(|path| Arc::from(rel_path(path)))
550 .collect::<Vec<_>>();
551 let mut path_entries = Vec::new();
552 for (i, path) in paths.iter().enumerate() {
553 let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
554 let char_bag = CharBag::from(lowercase_path.as_slice());
555 path_entries.push(PathMatchCandidate {
556 is_dir: false,
557 char_bag,
558 path: &path_arcs[i],
559 });
560 }
561
562 let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);
563
564 let cancel_flag = AtomicBool::new(false);
565 let mut results = Vec::new();
566
567 matcher.match_candidates(
568 &[],
569 &[],
570 path_entries.into_iter(),
571 &mut results,
572 &cancel_flag,
573 |candidate, score, positions| PathMatch {
574 score,
575 worktree_id: 0,
576 positions: positions.clone(),
577 path: candidate.path.into(),
578 path_prefix: RelPath::empty().into(),
579 distance_to_relative_ancestor: usize::MAX,
580 is_dir: false,
581 },
582 );
583 results.sort_by(|a, b| b.cmp(a));
584
585 results
586 .into_iter()
587 .map(|result| {
588 (
589 paths
590 .iter()
591 .copied()
592 .find(|p| result.path.as_ref() == rel_path(p))
593 .unwrap(),
594 result.positions,
595 )
596 })
597 .collect()
598 }
599}