1use std::{
2 borrow::Cow,
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
/// Floor for the score of a match that gets no boundary bonus, however many characters it
/// skips since the previous match.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
/// Score of a match that gets no boundary bonus when it is for the first query character, or
/// when it skips a single character since the previous match.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Subtracted from [`BASE_DISTANCE_PENALTY`] for every further character skipped since the
/// previous match.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
11
12// TODO:
13// Use `Path` instead of `&str` for paths.
14pub struct Matcher<'a> {
15 query: &'a [char],
16 lowercase_query: &'a [char],
17 query_char_bag: CharBag,
18 smart_case: bool,
19 min_score: f64,
20 match_positions: Vec<usize>,
21 last_positions: Vec<usize>,
22 score_matrix: Vec<Option<f64>>,
23 best_position_matrix: Vec<usize>,
24}
25
26pub trait MatchCandidate {
27 fn has_chars(&self, bag: CharBag) -> bool;
28 fn to_string(&self) -> Cow<'_, str>;
29}
30
31impl<'a> Matcher<'a> {
32 pub fn new(
33 query: &'a [char],
34 lowercase_query: &'a [char],
35 query_char_bag: CharBag,
36 smart_case: bool,
37 ) -> Self {
38 Self {
39 query,
40 lowercase_query,
41 query_char_bag,
42 min_score: 0.0,
43 last_positions: vec![0; lowercase_query.len()],
44 match_positions: vec![0; query.len()],
45 score_matrix: Vec::new(),
46 best_position_matrix: Vec::new(),
47 smart_case,
48 }
49 }
50
51 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
52 /// the input candidates.
53 pub fn match_candidates<C: MatchCandidate, R, F>(
54 &mut self,
55 prefix: &[char],
56 lowercase_prefix: &[char],
57 candidates: impl Iterator<Item = C>,
58 results: &mut Vec<R>,
59 cancel_flag: &AtomicBool,
60 build_match: F,
61 ) where
62 F: Fn(&C, f64, &Vec<usize>) -> R,
63 {
64 let mut candidate_chars = Vec::new();
65 let mut lowercase_candidate_chars = Vec::new();
66
67 for candidate in candidates {
68 if !candidate.has_chars(self.query_char_bag) {
69 continue;
70 }
71
72 if cancel_flag.load(atomic::Ordering::Relaxed) {
73 break;
74 }
75
76 candidate_chars.clear();
77 lowercase_candidate_chars.clear();
78 for c in candidate.to_string().chars() {
79 candidate_chars.push(c);
80 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
81 }
82
83 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
84 continue;
85 }
86
87 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
88 self.score_matrix.clear();
89 self.score_matrix.resize(matrix_len, None);
90 self.best_position_matrix.clear();
91 self.best_position_matrix.resize(matrix_len, 0);
92
93 let score = self.score_match(
94 &candidate_chars,
95 &lowercase_candidate_chars,
96 prefix,
97 lowercase_prefix,
98 );
99
100 if score > 0.0 {
101 results.push(build_match(&candidate, score, &self.match_positions));
102 }
103 }
104 }
105
106 fn find_last_positions(
107 &mut self,
108 lowercase_prefix: &[char],
109 lowercase_candidate: &[char],
110 ) -> bool {
111 let mut lowercase_prefix = lowercase_prefix.iter();
112 let mut lowercase_candidate = lowercase_candidate.iter();
113 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
114 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
115 self.last_positions[i] = j + lowercase_prefix.len();
116 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
117 self.last_positions[i] = j;
118 } else {
119 return false;
120 }
121 }
122 true
123 }
124
125 fn score_match(
126 &mut self,
127 path: &[char],
128 path_cased: &[char],
129 prefix: &[char],
130 lowercase_prefix: &[char],
131 ) -> f64 {
132 let score = self.recursive_score_match(
133 path,
134 path_cased,
135 prefix,
136 lowercase_prefix,
137 0,
138 0,
139 self.query.len() as f64,
140 ) * self.query.len() as f64;
141
142 if score <= 0.0 {
143 return 0.0;
144 }
145
146 let path_len = prefix.len() + path.len();
147 let mut cur_start = 0;
148 let mut byte_ix = 0;
149 let mut char_ix = 0;
150 for i in 0..self.query.len() {
151 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
152 while char_ix < match_char_ix {
153 let ch = prefix
154 .get(char_ix)
155 .or_else(|| path.get(char_ix - prefix.len()))
156 .unwrap();
157 byte_ix += ch.len_utf8();
158 char_ix += 1;
159 }
160 cur_start = match_char_ix + 1;
161 self.match_positions[i] = byte_ix;
162 }
163
164 score
165 }
166
167 #[allow(clippy::too_many_arguments)]
168 fn recursive_score_match(
169 &mut self,
170 path: &[char],
171 path_cased: &[char],
172 prefix: &[char],
173 lowercase_prefix: &[char],
174 query_idx: usize,
175 path_idx: usize,
176 cur_score: f64,
177 ) -> f64 {
178 use std::path::MAIN_SEPARATOR;
179
180 if query_idx == self.query.len() {
181 return 1.0;
182 }
183
184 let path_len = prefix.len() + path.len();
185
186 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
187 return memoized;
188 }
189
190 let mut score = 0.0;
191 let mut best_position = 0;
192
193 let query_char = self.lowercase_query[query_idx];
194 let limit = self.last_positions[query_idx];
195
196 let mut last_slash = 0;
197 for j in path_idx..=limit {
198 let path_char = if j < prefix.len() {
199 lowercase_prefix[j]
200 } else {
201 path_cased[j - prefix.len()]
202 };
203 let is_path_sep = path_char == MAIN_SEPARATOR;
204
205 if query_idx == 0 && is_path_sep {
206 last_slash = j;
207 }
208
209 #[cfg(not(target_os = "windows"))]
210 let need_to_score =
211 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
212 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
213 #[cfg(target_os = "windows")]
214 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
215 if need_to_score {
216 let curr = if j < prefix.len() {
217 prefix[j]
218 } else {
219 path[j - prefix.len()]
220 };
221
222 let mut char_score = 1.0;
223 if j > path_idx {
224 let last = if j - 1 < prefix.len() {
225 prefix[j - 1]
226 } else {
227 path[j - 1 - prefix.len()]
228 };
229
230 if last == MAIN_SEPARATOR {
231 char_score = 0.9;
232 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
233 || (last.is_lowercase() && curr.is_uppercase())
234 {
235 char_score = 0.8;
236 } else if last == '.' {
237 char_score = 0.7;
238 } else if query_idx == 0 {
239 char_score = BASE_DISTANCE_PENALTY;
240 } else {
241 char_score = MIN_DISTANCE_PENALTY.max(
242 BASE_DISTANCE_PENALTY
243 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
244 );
245 }
246 }
247
248 // Apply a severe penalty if the case doesn't match.
249 // This will make the exact matches have higher score than the case-insensitive and the
250 // path insensitive matches.
251 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
252 char_score *= 0.001;
253 }
254
255 let mut multiplier = char_score;
256
257 // Scale the score based on how deep within the path we found the match.
258 if query_idx == 0 {
259 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
260 }
261
262 let mut next_score = 1.0;
263 if self.min_score > 0.0 {
264 next_score = cur_score * multiplier;
265 // Scores only decrease. If we can't pass the previous best, bail
266 if next_score < self.min_score {
267 // Ensure that score is non-zero so we use it in the memo table.
268 if score == 0.0 {
269 score = 1e-18;
270 }
271 continue;
272 }
273 }
274
275 let new_score = self.recursive_score_match(
276 path,
277 path_cased,
278 prefix,
279 lowercase_prefix,
280 query_idx + 1,
281 j + 1,
282 next_score,
283 ) * multiplier;
284
285 if new_score > score {
286 score = new_score;
287 best_position = j;
288 // Optimization: can't score better than 1.
289 if new_score == 1.0 {
290 break;
291 }
292 }
293 }
294 }
295
296 if best_position != 0 {
297 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
298 }
299
300 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
301 score
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PathMatch, PathMatchCandidate};
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn test_get_last_positions() {
        // 'd' must come before 'c', but the only 'c' is in the prefix: no match.
        let query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(!matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));

        let query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        let query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        assert!(matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']));
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    // todo(windows): on Windows only the backslash is accepted as a path separator in queries.
    // Both the backslash and the forward slash should eventually be supported there.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 has more chars in its lowercase form than in its uppercase form.
        let query = "\u{0130}";

        let paths = ["\u{0130}"];
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // The path is already the lowercase form of the query.
        let paths = ["i\u{307}"];
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = [
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Matches `query` against `paths` with an empty prefix.
    ///
    /// Returns each matching path with its match byte positions, ordered from best to worst
    /// `PathMatch`.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let path_entries: Vec<_> = paths
            .iter()
            .zip(&path_arcs)
            .map(|(path, path_arc)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: path_arc,
                }
            })
            .collect();

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);
        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        results
            .into_iter()
            .map(|result| {
                let original = paths
                    .iter()
                    .copied()
                    .find(|p| result.path.as_ref() == Path::new(p))
                    .unwrap();
                (original, result.positions)
            })
            .collect()
    }
}