1use std::{
2 borrow::Cow,
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
// Multipliers applied to a matched char that is not the first char after the previous match
// and does not follow a separator, word delimiter, digit, `.` or a lowercase→uppercase change.

/// Floor for the distance-based multiplier, however many chars separate consecutive matches.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
/// Multiplier for such a match. The first query char always gets exactly this value; later
/// query chars start from it and lose [`ADDITIONAL_DISTANCE_PENALTY`] per extra skipped char.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for each skipped char beyond the first.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
11
12// TODO:
13// Use `Path` instead of `&str` for paths.
14pub struct Matcher<'a> {
15 query: &'a [char],
16 lowercase_query: &'a [char],
17 query_char_bag: CharBag,
18 smart_case: bool,
19 min_score: f64,
20 match_positions: Vec<usize>,
21 last_positions: Vec<usize>,
22 score_matrix: Vec<Option<f64>>,
23 best_position_matrix: Vec<usize>,
24}
25
26pub trait MatchCandidate {
27 fn has_chars(&self, bag: CharBag) -> bool;
28 fn to_string(&self) -> Cow<'_, str>;
29}
30
31impl<'a> Matcher<'a> {
32 pub fn new(
33 query: &'a [char],
34 lowercase_query: &'a [char],
35 query_char_bag: CharBag,
36 smart_case: bool,
37 ) -> Self {
38 Self {
39 query,
40 lowercase_query,
41 query_char_bag,
42 min_score: 0.0,
43 last_positions: vec![0; lowercase_query.len()],
44 match_positions: vec![0; query.len()],
45 score_matrix: Vec::new(),
46 best_position_matrix: Vec::new(),
47 smart_case,
48 }
49 }
50
51 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
52 /// the input candidates.
53 pub fn match_candidates<C: MatchCandidate, R, F>(
54 &mut self,
55 prefix: &[char],
56 lowercase_prefix: &[char],
57 candidates: impl Iterator<Item = C>,
58 results: &mut Vec<R>,
59 cancel_flag: &AtomicBool,
60 build_match: F,
61 ) where
62 F: Fn(&C, f64, &Vec<usize>) -> R,
63 {
64 let mut candidate_chars = Vec::new();
65 let mut lowercase_candidate_chars = Vec::new();
66
67 for candidate in candidates {
68 if !candidate.has_chars(self.query_char_bag) {
69 continue;
70 }
71
72 if cancel_flag.load(atomic::Ordering::Relaxed) {
73 break;
74 }
75
76 candidate_chars.clear();
77 lowercase_candidate_chars.clear();
78 for c in candidate.to_string().chars() {
79 candidate_chars.push(c);
80 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
81 }
82
83 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
84 continue;
85 }
86
87 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
88 self.score_matrix.clear();
89 self.score_matrix.resize(matrix_len, None);
90 self.best_position_matrix.clear();
91 self.best_position_matrix.resize(matrix_len, 0);
92
93 let score = self.score_match(
94 &candidate_chars,
95 &lowercase_candidate_chars,
96 prefix,
97 lowercase_prefix,
98 );
99
100 if score > 0.0 {
101 results.push(build_match(&candidate, score, &self.match_positions));
102 }
103 }
104 }
105
106 fn find_last_positions(
107 &mut self,
108 lowercase_prefix: &[char],
109 lowercase_candidate: &[char],
110 ) -> bool {
111 let mut lowercase_prefix = lowercase_prefix.iter();
112 let mut lowercase_candidate = lowercase_candidate.iter();
113 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
114 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
115 self.last_positions[i] = j + lowercase_prefix.len();
116 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
117 self.last_positions[i] = j;
118 } else {
119 return false;
120 }
121 }
122 true
123 }
124
125 fn score_match(
126 &mut self,
127 path: &[char],
128 path_cased: &[char],
129 prefix: &[char],
130 lowercase_prefix: &[char],
131 ) -> f64 {
132 let score = self.recursive_score_match(
133 path,
134 path_cased,
135 prefix,
136 lowercase_prefix,
137 0,
138 0,
139 self.query.len() as f64,
140 ) * self.query.len() as f64;
141
142 if score <= 0.0 {
143 return 0.0;
144 }
145
146 let path_len = prefix.len() + path.len();
147 let mut cur_start = 0;
148 let mut byte_ix = 0;
149 let mut char_ix = 0;
150 for i in 0..self.query.len() {
151 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
152 while char_ix < match_char_ix {
153 let ch = prefix
154 .get(char_ix)
155 .or_else(|| path.get(char_ix - prefix.len()))
156 .unwrap();
157 byte_ix += ch.len_utf8();
158 char_ix += 1;
159 }
160 cur_start = match_char_ix + 1;
161 self.match_positions[i] = byte_ix;
162 }
163
164 score
165 }
166
167 fn recursive_score_match(
168 &mut self,
169 path: &[char],
170 path_cased: &[char],
171 prefix: &[char],
172 lowercase_prefix: &[char],
173 query_idx: usize,
174 path_idx: usize,
175 cur_score: f64,
176 ) -> f64 {
177 use std::path::MAIN_SEPARATOR;
178
179 if query_idx == self.query.len() {
180 return 1.0;
181 }
182
183 let path_len = prefix.len() + path.len();
184
185 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
186 return memoized;
187 }
188
189 let mut score = 0.0;
190 let mut best_position = 0;
191
192 let query_char = self.lowercase_query[query_idx];
193 let limit = self.last_positions[query_idx];
194
195 let mut last_slash = 0;
196 for j in path_idx..=limit {
197 let path_char = if j < prefix.len() {
198 lowercase_prefix[j]
199 } else {
200 path_cased[j - prefix.len()]
201 };
202 let is_path_sep = path_char == MAIN_SEPARATOR;
203
204 if query_idx == 0 && is_path_sep {
205 last_slash = j;
206 }
207
208 #[cfg(not(target_os = "windows"))]
209 let need_to_score =
210 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
211 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
212 #[cfg(target_os = "windows")]
213 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
214 if need_to_score {
215 let curr = if j < prefix.len() {
216 prefix[j]
217 } else {
218 path[j - prefix.len()]
219 };
220
221 let mut char_score = 1.0;
222 if j > path_idx {
223 let last = if j - 1 < prefix.len() {
224 prefix[j - 1]
225 } else {
226 path[j - 1 - prefix.len()]
227 };
228
229 if last == MAIN_SEPARATOR {
230 char_score = 0.9;
231 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
232 || (last.is_lowercase() && curr.is_uppercase())
233 {
234 char_score = 0.8;
235 } else if last == '.' {
236 char_score = 0.7;
237 } else if query_idx == 0 {
238 char_score = BASE_DISTANCE_PENALTY;
239 } else {
240 char_score = MIN_DISTANCE_PENALTY.max(
241 BASE_DISTANCE_PENALTY
242 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
243 );
244 }
245 }
246
247 // Apply a severe penalty if the case doesn't match.
248 // This will make the exact matches have higher score than the case-insensitive and the
249 // path insensitive matches.
250 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
251 char_score *= 0.001;
252 }
253
254 let mut multiplier = char_score;
255
256 // Scale the score based on how deep within the path we found the match.
257 if query_idx == 0 {
258 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
259 }
260
261 let mut next_score = 1.0;
262 if self.min_score > 0.0 {
263 next_score = cur_score * multiplier;
264 // Scores only decrease. If we can't pass the previous best, bail
265 if next_score < self.min_score {
266 // Ensure that score is non-zero so we use it in the memo table.
267 if score == 0.0 {
268 score = 1e-18;
269 }
270 continue;
271 }
272 }
273
274 let new_score = self.recursive_score_match(
275 path,
276 path_cased,
277 prefix,
278 lowercase_prefix,
279 query_idx + 1,
280 j + 1,
281 next_score,
282 ) * multiplier;
283
284 if new_score > score {
285 score = new_score;
286 best_position = j;
287 // Optimization: can't score better than 1.
288 if new_score == 1.0 {
289 break;
290 }
291 }
292 }
293 }
294
295 if best_position != 0 {
296 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
297 }
298
299 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
300 score
301 }
302}
303
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// `find_last_positions` must reject a query whose chars cannot be placed in order, and
    /// otherwise record, for each query char, its right-most usable position in the prefix
    /// followed by the candidate.
    #[test]
    fn test_get_last_positions() {
        // 'c' only occurs in the prefix, before every 'd', so the query cannot be placed.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // 'd' is found in the candidate (offset by the 3-char prefix), 'c' in the prefix.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Once the candidate is used up, earlier query chars fall back to the prefix.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Matches several queries against flat, camel-case and slash-separated paths, checking
    /// which paths match, the byte positions of the matched chars, and the result order
    /// produced by `match_single_path_query`.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        // Separators in the query must line up with separators in the path.
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Windows counterpart of the test above, using `\` as the path separator.
    ///
    /// todo(windows): On Windows, queries can currently use only the backslash as a path
    /// separator. Both the backslash and the forward slash should be supported.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Queries and paths containing a char whose lowercase form is longer than the char itself.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Match positions are byte offsets, so multi-byte chars before a match shift them.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Each keycap emoji is a digit plus two combining code points, 7 bytes in total.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` (with an empty prefix) and returns each matching path with
    /// the byte positions of its matched chars. Results are sorted with `b.cmp(a)`, i.e. in
    /// descending `PathMatch` order (presumably highest score first — TODO confirm against
    /// `PathMatch`'s `Ord` impl).
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        // Candidates borrow their paths, so the `Arc<Path>`s must outlive `path_entries`.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the input `&str` it came from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}