1use std::{
2 borrow::{Borrow, Cow},
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
/// Per-character score multiplier used when a matched character does not directly follow the
/// previous match and is not preceded by a separator, word boundary, or `.`. Applied as-is to
/// the first query character; for later characters it is the starting value of the gap decay.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every skipped character beyond the first
/// between two consecutive matched characters (query characters after the first only).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Floor for the gap-decayed multiplier, so long gaps are never scored below this value.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
11
// TODO:
// Use `Path` instead of `&str` for paths.
/// Fuzzy matcher that scores candidates against a fixed query.
///
/// Besides the query itself, it owns scratch buffers that are cleared and resized for every
/// candidate, so their allocations are reused across a whole `match_candidates` run.
pub struct Matcher<'a> {
    /// Original-case query; used for the query length and for case-sensitive comparisons.
    query: &'a [char],
    /// Lowercased query; used to locate and score matching characters.
    lowercase_query: &'a [char],
    /// Passed to `MatchCandidate::has_chars` to cheaply reject candidates before scoring.
    query_char_bag: CharBag,
    /// When true, candidate characters whose case differs from `query` are heavily penalized.
    smart_case: bool,
    /// Pruning threshold for the recursive search; only consulted when positive.
    min_score: f64,
    /// Byte offsets of the matched characters for the most recently scored candidate,
    /// one entry per `query` char.
    match_positions: Vec<usize>,
    /// For each lowercase query char, the last `prefix + candidate` position at which it can be
    /// matched while still leaving room for the remaining query chars.
    last_positions: Vec<usize>,
    /// Memoized scores, indexed by `query_idx * path_len + path_idx`.
    score_matrix: Vec<Option<f64>>,
    /// Best match position chosen for each memoized `(query_idx, path_idx)` entry.
    best_position_matrix: Vec<usize>,
}
25
/// Something that can be fuzzy-matched by `Matcher::match_candidates`.
pub trait MatchCandidate {
    /// Cheap pre-filter called with the matcher's query char bag; candidates for which this
    /// returns `false` are skipped without being scored. Presumably reports whether the
    /// candidate contains every character in `bag` — confirm against implementations.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The text to match against. Reported match positions are byte offsets into the
    /// matcher's prefix followed by this string.
    fn to_string(&self) -> Cow<'_, str>;
}
30
31impl<'a> Matcher<'a> {
32 pub fn new(
33 query: &'a [char],
34 lowercase_query: &'a [char],
35 query_char_bag: CharBag,
36 smart_case: bool,
37 ) -> Self {
38 Self {
39 query,
40 lowercase_query,
41 query_char_bag,
42 min_score: 0.0,
43 last_positions: vec![0; lowercase_query.len()],
44 match_positions: vec![0; query.len()],
45 score_matrix: Vec::new(),
46 best_position_matrix: Vec::new(),
47 smart_case,
48 }
49 }
50
51 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
52 /// the input candidates.
53 pub fn match_candidates<C, R, F, T>(
54 &mut self,
55 prefix: &[char],
56 lowercase_prefix: &[char],
57 candidates: impl Iterator<Item = T>,
58 results: &mut Vec<R>,
59 cancel_flag: &AtomicBool,
60 build_match: F,
61 ) where
62 C: MatchCandidate,
63 T: Borrow<C>,
64 F: Fn(&C, f64, &Vec<usize>) -> R,
65 {
66 let mut candidate_chars = Vec::new();
67 let mut lowercase_candidate_chars = Vec::new();
68
69 for candidate in candidates {
70 if !candidate.borrow().has_chars(self.query_char_bag) {
71 continue;
72 }
73
74 if cancel_flag.load(atomic::Ordering::Relaxed) {
75 break;
76 }
77
78 candidate_chars.clear();
79 lowercase_candidate_chars.clear();
80 for c in candidate.borrow().to_string().chars() {
81 candidate_chars.push(c);
82 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
83 }
84
85 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
86 continue;
87 }
88
89 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
90 self.score_matrix.clear();
91 self.score_matrix.resize(matrix_len, None);
92 self.best_position_matrix.clear();
93 self.best_position_matrix.resize(matrix_len, 0);
94
95 let score = self.score_match(
96 &candidate_chars,
97 &lowercase_candidate_chars,
98 prefix,
99 lowercase_prefix,
100 );
101
102 if score > 0.0 {
103 results.push(build_match(
104 candidate.borrow(),
105 score,
106 &self.match_positions,
107 ));
108 }
109 }
110 }
111
112 fn find_last_positions(
113 &mut self,
114 lowercase_prefix: &[char],
115 lowercase_candidate: &[char],
116 ) -> bool {
117 let mut lowercase_prefix = lowercase_prefix.iter();
118 let mut lowercase_candidate = lowercase_candidate.iter();
119 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
120 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
121 self.last_positions[i] = j + lowercase_prefix.len();
122 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
123 self.last_positions[i] = j;
124 } else {
125 return false;
126 }
127 }
128 true
129 }
130
131 fn score_match(
132 &mut self,
133 path: &[char],
134 path_cased: &[char],
135 prefix: &[char],
136 lowercase_prefix: &[char],
137 ) -> f64 {
138 let score = self.recursive_score_match(
139 path,
140 path_cased,
141 prefix,
142 lowercase_prefix,
143 0,
144 0,
145 self.query.len() as f64,
146 ) * self.query.len() as f64;
147
148 if score <= 0.0 {
149 return 0.0;
150 }
151
152 let path_len = prefix.len() + path.len();
153 let mut cur_start = 0;
154 let mut byte_ix = 0;
155 let mut char_ix = 0;
156 for i in 0..self.query.len() {
157 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
158 while char_ix < match_char_ix {
159 let ch = prefix
160 .get(char_ix)
161 .or_else(|| path.get(char_ix - prefix.len()))
162 .unwrap();
163 byte_ix += ch.len_utf8();
164 char_ix += 1;
165 }
166 cur_start = match_char_ix + 1;
167 self.match_positions[i] = byte_ix;
168 }
169
170 score
171 }
172
173 #[allow(clippy::too_many_arguments)]
174 fn recursive_score_match(
175 &mut self,
176 path: &[char],
177 path_cased: &[char],
178 prefix: &[char],
179 lowercase_prefix: &[char],
180 query_idx: usize,
181 path_idx: usize,
182 cur_score: f64,
183 ) -> f64 {
184 use std::path::MAIN_SEPARATOR;
185
186 if query_idx == self.query.len() {
187 return 1.0;
188 }
189
190 let path_len = prefix.len() + path.len();
191
192 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
193 return memoized;
194 }
195
196 let mut score = 0.0;
197 let mut best_position = 0;
198
199 let query_char = self.lowercase_query[query_idx];
200 let limit = self.last_positions[query_idx];
201
202 let mut last_slash = 0;
203 for j in path_idx..=limit {
204 let path_char = if j < prefix.len() {
205 lowercase_prefix[j]
206 } else {
207 path_cased[j - prefix.len()]
208 };
209 let is_path_sep = path_char == MAIN_SEPARATOR;
210
211 if query_idx == 0 && is_path_sep {
212 last_slash = j;
213 }
214
215 #[cfg(not(target_os = "windows"))]
216 let need_to_score =
217 query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\');
218 // `query_char == '\\'` breaks `test_match_path_entries` on Windows, `\` is only used as a path separator on Windows.
219 #[cfg(target_os = "windows")]
220 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
221 if need_to_score {
222 let curr = if j < prefix.len() {
223 prefix[j]
224 } else {
225 path[j - prefix.len()]
226 };
227
228 let mut char_score = 1.0;
229 if j > path_idx {
230 let last = if j - 1 < prefix.len() {
231 prefix[j - 1]
232 } else {
233 path[j - 1 - prefix.len()]
234 };
235
236 if last == MAIN_SEPARATOR {
237 char_score = 0.9;
238 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
239 || (last.is_lowercase() && curr.is_uppercase())
240 {
241 char_score = 0.8;
242 } else if last == '.' {
243 char_score = 0.7;
244 } else if query_idx == 0 {
245 char_score = BASE_DISTANCE_PENALTY;
246 } else {
247 char_score = MIN_DISTANCE_PENALTY.max(
248 BASE_DISTANCE_PENALTY
249 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
250 );
251 }
252 }
253
254 // Apply a severe penalty if the case doesn't match.
255 // This will make the exact matches have higher score than the case-insensitive and the
256 // path insensitive matches.
257 if (self.smart_case || curr == MAIN_SEPARATOR) && self.query[query_idx] != curr {
258 char_score *= 0.001;
259 }
260
261 let mut multiplier = char_score;
262
263 // Scale the score based on how deep within the path we found the match.
264 if query_idx == 0 {
265 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
266 }
267
268 let mut next_score = 1.0;
269 if self.min_score > 0.0 {
270 next_score = cur_score * multiplier;
271 // Scores only decrease. If we can't pass the previous best, bail
272 if next_score < self.min_score {
273 // Ensure that score is non-zero so we use it in the memo table.
274 if score == 0.0 {
275 score = 1e-18;
276 }
277 continue;
278 }
279 }
280
281 let new_score = self.recursive_score_match(
282 path,
283 path_cased,
284 prefix,
285 lowercase_prefix,
286 query_idx + 1,
287 j + 1,
288 next_score,
289 ) * multiplier;
290
291 if new_score > score {
292 score = new_score;
293 best_position = j;
294 // Optimization: can't score better than 1.
295 if new_score == 1.0 {
296 break;
297 }
298 }
299 }
300 }
301
302 if best_position != 0 {
303 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
304 }
305
306 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
307 score
308 }
309}
310
/// Unit tests for `Matcher`: last-position computation, ASCII/separator matching, Unicode
/// case expansion and multi-byte offsets.
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn test_get_last_positions() {
        // `c` only occurs in the prefix, before every `d`, so `dc` cannot be matched in order.
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        // Positions index `prefix + candidate`: `c` is at prefix index 2, and `d` at candidate
        // index 1 becomes 3 + 1 = 4.
        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // Each query char gets its last position that still leaves room for the chars after
        // it, falling back to the prefix once the candidate has been used up.
        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[cfg(not(target_os = "windows"))]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        // Results are ordered best-first; positions are byte offsets of the matched chars.
        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// TODO(windows): on Windows, the backslash is currently the only path separator users can
    /// type in a query. Both the backslash and the forward slash should be supported.
    #[cfg(target_os = "windows")]
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "\\\\\\\\\\ThisIsATestDir",
            "\\this\\is\\a\\test\\dir",
            "\\test\\tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t\\i\\a\\t\\d", false, &paths),
            vec![(
                "\\this\\is\\a\\test\\dir",
                vec![1, 5, 6, 8, 9, 10, 11, 15, 16]
            ),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("\\test\\tiatd", vec![6, 7, 8, 9, 10]),
                ("\\this\\is\\a\\test\\dir", vec![1, 6, 9, 11, 16]),
                ("\\\\\\\\\\ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        // ('\u{0130}' lowercases to "i\u{307}", so the lowercase query is longer than the query.)
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        // Positions are byte offsets, not char indices, so multi-byte chars shift them.
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns the matching paths with
    /// their match byte offsets. Results are sorted in descending `PathMatch` order
    /// (presumably best score first — confirm against `PathMatch`'s `Ord`).
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        results.sort_by(|a, b| b.cmp(a));

        // Map each result back to the original `&str` so assertions can compare literals.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}