1use std::{
2 borrow::Cow,
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
/// Score given to a matched character that does not directly follow a separator, word
/// boundary or dot. Used as-is for the first query character, and as the starting value the
/// distance penalty is subtracted from for later query characters.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each additional candidate character
/// skipped between where the search resumed and where the query character matched.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-penalised character score, no matter how many characters
/// were skipped.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
11
/// Fuzzy matcher for a single query, reusable across many candidates.
///
/// The scratch buffers (`match_positions`, `last_positions`, `score_matrix`,
/// `best_position_matrix`) are overwritten for every candidate that gets scored.
pub struct Matcher<'a> {
    /// Query characters as typed; compared against original-case candidate characters when
    /// deciding whether to apply the case-mismatch penalty.
    query: &'a [char],
    /// Lowercase form of the query; this is what is compared against lowercase candidate
    /// characters to find matches.
    lowercase_query: &'a [char],
    /// Passed to `MatchCandidate::has_chars` to reject candidates before any scoring work.
    query_char_bag: CharBag,
    /// When true, every matched character whose case differs from the query is penalised.
    smart_case: bool,
    /// Pruning threshold used while scoring: partial scores below it are abandoned.
    /// Initialised to 0.0 by `new`, which leaves pruning disabled.
    min_score: f64,
    /// Byte offset of each matched query character for the most recently scored candidate
    /// (one entry per character of `query`).
    match_positions: Vec<usize>,
    /// For each lowercase query character, the last index (counting prefix characters first,
    /// then lowercase candidate characters) at which it may still match.
    last_positions: Vec<usize>,
    /// Memo table for the recursive scorer, indexed by `query_idx * path_len + path_idx`;
    /// `None` means "not computed yet".
    score_matrix: Vec<Option<f64>>,
    /// For each memo cell, the position that produced the best score. 0 doubles as both
    /// "position 0" and "never written".
    best_position_matrix: Vec<usize>,
}
23
/// Something a `Matcher` can score against its query.
pub trait MatchCandidate {
    /// Cheap pre-filter: `Matcher::match_candidates` skips the candidate entirely when this
    /// returns `false` for the query's character bag.
    // NOTE(review): presumably "does this candidate contain every character in `bag`";
    // confirm against the `CharBag` implementation.
    fn has_chars(&self, bag: CharBag) -> bool;
    /// The text to match against; it is iterated character by character by the matcher.
    fn to_string(&self) -> Cow<'_, str>;
}
28
29impl<'a> Matcher<'a> {
30 pub fn new(
31 query: &'a [char],
32 lowercase_query: &'a [char],
33 query_char_bag: CharBag,
34 smart_case: bool,
35 ) -> Self {
36 Self {
37 query,
38 lowercase_query,
39 query_char_bag,
40 min_score: 0.0,
41 last_positions: vec![0; lowercase_query.len()],
42 match_positions: vec![0; query.len()],
43 score_matrix: Vec::new(),
44 best_position_matrix: Vec::new(),
45 smart_case,
46 }
47 }
48
49 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
50 /// the input candidates.
51 pub fn match_candidates<C: MatchCandidate, R, F>(
52 &mut self,
53 prefix: &[char],
54 lowercase_prefix: &[char],
55 candidates: impl Iterator<Item = C>,
56 results: &mut Vec<R>,
57 cancel_flag: &AtomicBool,
58 build_match: F,
59 ) where
60 F: Fn(&C, f64, &Vec<usize>) -> R,
61 {
62 let mut candidate_chars = Vec::new();
63 let mut lowercase_candidate_chars = Vec::new();
64
65 for candidate in candidates {
66 if !candidate.has_chars(self.query_char_bag) {
67 continue;
68 }
69
70 if cancel_flag.load(atomic::Ordering::Relaxed) {
71 break;
72 }
73
74 candidate_chars.clear();
75 lowercase_candidate_chars.clear();
76 for c in candidate.to_string().chars() {
77 candidate_chars.push(c);
78 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
79 }
80
81 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
82 continue;
83 }
84
85 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
86 self.score_matrix.clear();
87 self.score_matrix.resize(matrix_len, None);
88 self.best_position_matrix.clear();
89 self.best_position_matrix.resize(matrix_len, 0);
90
91 let score = self.score_match(
92 &candidate_chars,
93 &lowercase_candidate_chars,
94 prefix,
95 lowercase_prefix,
96 );
97
98 if score > 0.0 {
99 results.push(build_match(&candidate, score, &self.match_positions));
100 }
101 }
102 }
103
104 fn find_last_positions(
105 &mut self,
106 lowercase_prefix: &[char],
107 lowercase_candidate: &[char],
108 ) -> bool {
109 let mut lowercase_prefix = lowercase_prefix.iter();
110 let mut lowercase_candidate = lowercase_candidate.iter();
111 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
112 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
113 self.last_positions[i] = j + lowercase_prefix.len();
114 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
115 self.last_positions[i] = j;
116 } else {
117 return false;
118 }
119 }
120 true
121 }
122
123 fn score_match(
124 &mut self,
125 path: &[char],
126 path_cased: &[char],
127 prefix: &[char],
128 lowercase_prefix: &[char],
129 ) -> f64 {
130 let score = self.recursive_score_match(
131 path,
132 path_cased,
133 prefix,
134 lowercase_prefix,
135 0,
136 0,
137 self.query.len() as f64,
138 ) * self.query.len() as f64;
139
140 if score <= 0.0 {
141 return 0.0;
142 }
143
144 let path_len = prefix.len() + path.len();
145 let mut cur_start = 0;
146 let mut byte_ix = 0;
147 let mut char_ix = 0;
148 for i in 0..self.query.len() {
149 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
150 while char_ix < match_char_ix {
151 let ch = prefix
152 .get(char_ix)
153 .or_else(|| path.get(char_ix - prefix.len()))
154 .unwrap();
155 byte_ix += ch.len_utf8();
156 char_ix += 1;
157 }
158 cur_start = match_char_ix + 1;
159 self.match_positions[i] = byte_ix;
160 }
161
162 score
163 }
164
165 #[allow(clippy::too_many_arguments)]
166 fn recursive_score_match(
167 &mut self,
168 path: &[char],
169 path_cased: &[char],
170 prefix: &[char],
171 lowercase_prefix: &[char],
172 query_idx: usize,
173 path_idx: usize,
174 cur_score: f64,
175 ) -> f64 {
176 if query_idx == self.query.len() {
177 return 1.0;
178 }
179
180 let path_len = prefix.len() + path.len();
181
182 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
183 return memoized;
184 }
185
186 let mut score = 0.0;
187 let mut best_position = 0;
188
189 let query_char = self.lowercase_query[query_idx];
190 let limit = self.last_positions[query_idx];
191
192 let mut last_slash = 0;
193 for j in path_idx..=limit {
194 let path_char = if j < prefix.len() {
195 lowercase_prefix[j]
196 } else {
197 path_cased[j - prefix.len()]
198 };
199 let is_path_sep = path_char == '/' || path_char == '\\';
200
201 if query_idx == 0 && is_path_sep {
202 last_slash = j;
203 }
204
205 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
206 let curr = if j < prefix.len() {
207 prefix[j]
208 } else {
209 path[j - prefix.len()]
210 };
211
212 let mut char_score = 1.0;
213 if j > path_idx {
214 let last = if j - 1 < prefix.len() {
215 prefix[j - 1]
216 } else {
217 path[j - 1 - prefix.len()]
218 };
219
220 if last == '/' {
221 char_score = 0.9;
222 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
223 || (last.is_lowercase() && curr.is_uppercase())
224 {
225 char_score = 0.8;
226 } else if last == '.' {
227 char_score = 0.7;
228 } else if query_idx == 0 {
229 char_score = BASE_DISTANCE_PENALTY;
230 } else {
231 char_score = MIN_DISTANCE_PENALTY.max(
232 BASE_DISTANCE_PENALTY
233 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
234 );
235 }
236 }
237
238 // Apply a severe penalty if the case doesn't match.
239 // This will make the exact matches have higher score than the case-insensitive and the
240 // path insensitive matches.
241 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
242 char_score *= 0.001;
243 }
244
245 let mut multiplier = char_score;
246
247 // Scale the score based on how deep within the path we found the match.
248 if query_idx == 0 {
249 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
250 }
251
252 let mut next_score = 1.0;
253 if self.min_score > 0.0 {
254 next_score = cur_score * multiplier;
255 // Scores only decrease. If we can't pass the previous best, bail
256 if next_score < self.min_score {
257 // Ensure that score is non-zero so we use it in the memo table.
258 if score == 0.0 {
259 score = 1e-18;
260 }
261 continue;
262 }
263 }
264
265 let new_score = self.recursive_score_match(
266 path,
267 path_cased,
268 prefix,
269 lowercase_prefix,
270 query_idx + 1,
271 j + 1,
272 next_score,
273 ) * multiplier;
274
275 if new_score > score {
276 score = new_score;
277 best_position = j;
278 // Optimization: can't score better than 1.
279 if new_score == 1.0 {
280 break;
281 }
282 }
283 }
284 }
285
286 if best_position != 0 {
287 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
288 }
289
290 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
291 score
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{PathMatch, PathMatchCandidate};
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    #[test]
    fn test_get_last_positions() {
        // 'c' can only be placed in the prefix, which leaves no earlier 'd' to match.
        let unmatched_query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(
            unmatched_query,
            unmatched_query,
            unmatched_query.into(),
            false,
        );
        assert!(!matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));

        // Same characters in the opposite order: 'c' from the prefix, 'd' from the candidate.
        let split_query: &[char] = &['c', 'd'];
        let mut matcher = Matcher::new(split_query, split_query, split_query.into(), false);
        assert!(matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']));
        assert_eq!(matcher.last_positions, vec![2, 4]);

        // A query that straddles the prefix/candidate boundary.
        let straddling_query: &[char] = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(
            straddling_query,
            straddling_query,
            straddling_query.into(),
            false,
        );
        assert!(matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']));
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let abc_matches = match_single_path_query("abc", false, &paths);
        assert_eq!(
            abc_matches,
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );

        let separated_matches = match_single_path_query("t/i/a/t/d", false, &paths);
        assert_eq!(
            separated_matches,
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        let initials_matches = match_single_path_query("tiatd", false, &paths);
        assert_eq!(
            initials_matches,
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 lowercases to two chars, so its lowercase form is longer than the original.
        let query = "\u{0130}";

        // The path is the very same character as the query.
        let same_char_paths = vec!["\u{0130}"];
        assert_eq!(
            match_single_path_query(query, false, &same_char_paths),
            vec![("\u{0130}", vec![0])]
        );

        // The path is the lower-case version of the query.
        let lowercased_paths = vec!["i\u{307}"];
        assert_eq!(
            match_single_path_query(query, false, &lowercased_paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "/d/🆒/h",
        ];
        // Sanity check: a keycap emoji spans seven bytes, so positions must be byte offsets.
        assert_eq!("1️⃣".len(), 7);

        let bcd_matches = match_single_path_query("bcd", false, &paths);
        assert_eq!(
            bcd_matches,
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );

        let cde_matches = match_single_path_query("cde", false, &paths);
        assert_eq!(
            cde_matches,
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Runs `query` against `paths` with an empty prefix and returns the matching paths, best
    /// match first, each paired with the byte positions of its matched characters.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query: Vec<char> = query.chars().collect();
        let query_chars = CharBag::from(&lowercase_query[..]);

        // Own the paths as `Arc<Path>` so the candidates below can borrow from them.
        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let path_entries: Vec<_> = paths
            .iter()
            .enumerate()
            .map(|(ix, path)| {
                let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
                PathMatchCandidate {
                    is_dir: false,
                    char_bag: CharBag::from(lowercase_path.as_slice()),
                    path: &path_arcs[ix],
                }
            })
            .collect();

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case);

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );

        // Best match first.
        results.sort_by(|a, b| b.cmp(a));

        // Map every result back to the caller's original string slice.
        results
            .into_iter()
            .map(|result| {
                let original_path = paths
                    .iter()
                    .copied()
                    .find(|p| result.path.as_ref() == Path::new(p))
                    .unwrap();
                (original_path, result.positions)
            })
            .collect()
    }
}