1use std::{
2 borrow::{Borrow, Cow},
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
/// Score given to a matched character that is not at the current search position and does not
/// start a "word" (it is not preceded by the path separator, `-`, `_`, a space, a digit, `.`, or a
/// lowercase→uppercase transition). For query characters after the first, this is the starting
/// value that [`ADDITIONAL_DISTANCE_PENALTY`] reduces further.
const BASE_DISTANCE_PENALTY: f64 = 0.6;

/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every candidate character skipped beyond
/// the first between two consecutive query-character matches.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;

/// Floor for the distance-based score, no matter how far apart consecutive matches are.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
11
12// TODO:
13// Use `Path` instead of `&str` for paths.
14pub struct Matcher<'a> {
15 query: &'a [char],
16 lowercase_query: &'a [char],
17 query_char_bag: CharBag,
18 smart_case: bool,
19 min_score: f64,
20 match_positions: Vec<usize>,
21 last_positions: Vec<usize>,
22 score_matrix: Vec<Option<f64>>,
23 best_position_matrix: Vec<usize>,
24}
25
26pub trait MatchCandidate {
27 fn has_chars(&self, bag: CharBag) -> bool;
28 fn to_string(&self) -> Cow<'_, str>;
29}
30
31impl<'a> Matcher<'a> {
32 pub fn new(
33 query: &'a [char],
34 lowercase_query: &'a [char],
35 query_char_bag: CharBag,
36 smart_case: bool,
37 ) -> Self {
38 Self {
39 query,
40 lowercase_query,
41 query_char_bag,
42 min_score: 0.0,
43 last_positions: vec![0; lowercase_query.len()],
44 match_positions: vec![0; query.len()],
45 score_matrix: Vec::new(),
46 best_position_matrix: Vec::new(),
47 smart_case,
48 }
49 }
50
51 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
52 /// the input candidates.
53 pub fn match_candidates<C, R, F, T>(
54 &mut self,
55 prefix: &[char],
56 lowercase_prefix: &[char],
57 candidates: impl Iterator<Item = T>,
58 results: &mut Vec<R>,
59 cancel_flag: &AtomicBool,
60 build_match: F,
61 ) where
62 C: MatchCandidate,
63 T: Borrow<C>,
64 F: Fn(&C, f64, &Vec<usize>) -> R,
65 {
66 let mut candidate_chars = Vec::new();
67 let mut lowercase_candidate_chars = Vec::new();
68
69 for candidate in candidates {
70 if !candidate.borrow().has_chars(self.query_char_bag) {
71 continue;
72 }
73
74 if cancel_flag.load(atomic::Ordering::Relaxed) {
75 break;
76 }
77
78 candidate_chars.clear();
79 lowercase_candidate_chars.clear();
80 for c in candidate.borrow().to_string().chars() {
81 candidate_chars.push(c);
82 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
83 }
84
85 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
86 continue;
87 }
88
89 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
90 self.score_matrix.clear();
91 self.score_matrix.resize(matrix_len, None);
92 self.best_position_matrix.clear();
93 self.best_position_matrix.resize(matrix_len, 0);
94
95 let score = self.score_match(
96 &candidate_chars,
97 &lowercase_candidate_chars,
98 prefix,
99 lowercase_prefix,
100 );
101
102 if score > 0.0 {
103 results.push(build_match(
104 candidate.borrow(),
105 score,
106 &self.match_positions,
107 ));
108 }
109 }
110 }
111
112 fn find_last_positions(
113 &mut self,
114 lowercase_prefix: &[char],
115 lowercase_candidate: &[char],
116 ) -> bool {
117 let mut lowercase_prefix = lowercase_prefix.iter();
118 let mut lowercase_candidate = lowercase_candidate.iter();
119 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
120 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
121 self.last_positions[i] = j + lowercase_prefix.len();
122 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
123 self.last_positions[i] = j;
124 } else {
125 return false;
126 }
127 }
128 true
129 }
130
131 fn score_match(
132 &mut self,
133 path: &[char],
134 path_cased: &[char],
135 prefix: &[char],
136 lowercase_prefix: &[char],
137 ) -> f64 {
138 let score = self.recursive_score_match(
139 path,
140 path_cased,
141 prefix,
142 lowercase_prefix,
143 0,
144 0,
145 self.query.len() as f64,
146 ) * self.query.len() as f64;
147
148 if score <= 0.0 {
149 return 0.0;
150 }
151
152 let path_len = prefix.len() + path.len();
153 let mut cur_start = 0;
154 let mut byte_ix = 0;
155 let mut char_ix = 0;
156 for i in 0..self.query.len() {
157 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
158 while char_ix < match_char_ix {
159 let ch = prefix
160 .get(char_ix)
161 .or_else(|| path.get(char_ix - prefix.len()))
162 .unwrap();
163 byte_ix += ch.len_utf8();
164 char_ix += 1;
165 }
166 cur_start = match_char_ix + 1;
167 self.match_positions[i] = byte_ix;
168 }
169
170 score
171 }
172
173 fn recursive_score_match(
174 &mut self,
175 path: &[char],
176 path_cased: &[char],
177 prefix: &[char],
178 lowercase_prefix: &[char],
179 query_idx: usize,
180 path_idx: usize,
181 cur_score: f64,
182 ) -> f64 {
183 use std::path::MAIN_SEPARATOR;
184
185 if query_idx == self.query.len() {
186 return 1.0;
187 }
188
189 let path_len = prefix.len() + path.len();
190
191 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
192 return memoized;
193 }
194
195 let mut score = 0.0;
196 let mut best_position = 0;
197
198 let query_char = self.lowercase_query[query_idx];
199 let limit = self.last_positions[query_idx];
200
201 let mut last_slash = 0;
202 for j in path_idx..=limit {
203 let path_char = if j < prefix.len() {
204 lowercase_prefix[j]
205 } else {
206 path_cased[j - prefix.len()]
207 };
208 let is_path_sep = path_char == MAIN_SEPARATOR;
209
210 if query_idx == 0 && is_path_sep {
211 last_slash = j;
212 }
213
214 #[cfg(not(target_os = "windows"))]
215 let need_to_score =
216 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
217 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
218 #[cfg(target_os = "windows")]
219 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
220 if need_to_score {
221 let curr = if j < prefix.len() {
222 prefix[j]
223 } else {
224 path[j - prefix.len()]
225 };
226
227 let mut char_score = 1.0;
228 if j > path_idx {
229 let last = if j - 1 < prefix.len() {
230 prefix[j - 1]
231 } else {
232 path[j - 1 - prefix.len()]
233 };
234
235 if last == MAIN_SEPARATOR {
236 char_score = 0.9;
237 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
238 || (last.is_lowercase() && curr.is_uppercase())
239 {
240 char_score = 0.8;
241 } else if last == '.' {
242 char_score = 0.7;
243 } else if query_idx == 0 {
244 char_score = BASE_DISTANCE_PENALTY;
245 } else {
246 char_score = MIN_DISTANCE_PENALTY.max(
247 BASE_DISTANCE_PENALTY
248 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
249 );
250 }
251 }
252
253 // Apply a severe penalty if the case doesn't match.
254 // This will make the exact matches have higher score than the case-insensitive and the
255 // path insensitive matches.
256 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
257 char_score *= 0.001;
258 }
259
260 let mut multiplier = char_score;
261
262 // Scale the score based on how deep within the path we found the match.
263 if query_idx == 0 {
264 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
265 }
266
267 let mut next_score = 1.0;
268 if self.min_score > 0.0 {
269 next_score = cur_score * multiplier;
270 // Scores only decrease. If we can't pass the previous best, bail
271 if next_score < self.min_score {
272 // Ensure that score is non-zero so we use it in the memo table.
273 if score == 0.0 {
274 score = 1e-18;
275 }
276 continue;
277 }
278 }
279
280 let new_score = self.recursive_score_match(
281 path,
282 path_cased,
283 prefix,
284 lowercase_prefix,
285 query_idx + 1,
286 j + 1,
287 next_score,
288 ) * multiplier;
289
290 if new_score > score {
291 score = new_score;
292 best_position = j;
293 // Optimization: can't score better than 1.
294 if new_score == 1.0 {
295 break;
296 }
297 }
298 }
299 }
300
301 if best_position != 0 {
302 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
303 }
304
305 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
306 score
307 }
308}
309
// Unit tests for `Matcher`: last-position computation, ranking/positions over path entries, and
// handling of multi-byte and multi-char-lowercase text.
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    // Each case calls `find_last_positions(prefix, candidate)` directly. Recorded indices are
    // into the concatenation prefix + candidate.
    #[test]
    fn test_get_last_positions() {
        // The only `c` is in the prefix, ahead of every `d`, so "dc" cannot match in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // `d` is found in the candidate (1 + prefix length 3 = 4); `c` falls back to the prefix (2).
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Once a query char has to come from the prefix, every earlier query char must too.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    // Expected results are in the order returned by `match_single_path_query` (descending
    // `PathMatch` ordering), each with the byte offsets of the matched query chars.
    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// TODO(windows): on Windows, users can currently only use the backslash as a path
    /// separator. Both the backslash and the forward slash should be accepted there.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 (LATIN CAPITAL LETTER I WITH DOT ABOVE) has more chars in lower-case
        // ("i" + U+0307 COMBINING DOT ABOVE) than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    // Positions are byte offsets, so multi-byte chars (Greek letters, keycap emoji sequences)
    // before a match shift the reported positions accordingly.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // A keycap sequence is digit + U+FE0F + U+20E3: 1 + 3 + 3 bytes.
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix. Returns the matching paths with their
    /// match byte offsets, sorted in descending `PathMatch` order.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        // Reverse comparison: highest `PathMatch` first.
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` so tests can compare against literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}