1mod char_bag;
2
3use std::{
4 borrow::Cow,
5 cmp::Ordering,
6 path::Path,
7 sync::atomic::{self, AtomicBool},
8 sync::Arc,
9};
10
11pub use char_bag::CharBag;
12
/// Score given to a matched character that does not follow a word boundary. It is used
/// directly for the first query character, and as the starting point of the distance decay
/// for later query characters.
const BASE_DISTANCE_PENALTY: f64 = 0.6;

/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every candidate character skipped
/// between the previous match and the current one.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;

/// Lower bound for the distance-decayed score of a non-boundary matched character.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
16
17pub struct Matcher<'a> {
18 query: &'a [char],
19 lowercase_query: &'a [char],
20 query_char_bag: CharBag,
21 smart_case: bool,
22 max_results: usize,
23 min_score: f64,
24 match_positions: Vec<usize>,
25 last_positions: Vec<usize>,
26 score_matrix: Vec<Option<f64>>,
27 best_position_matrix: Vec<usize>,
28}
29
30trait Match: Ord {
31 fn score(&self) -> f64;
32 fn set_positions(&mut self, positions: Vec<usize>);
33}
34
35trait MatchCandidate {
36 fn has_chars(&self, bag: CharBag) -> bool;
37 fn to_string<'a>(&'a self) -> Cow<'a, str>;
38}
39
40#[derive(Clone, Debug)]
41pub struct PathMatchCandidate<'a> {
42 pub path: &'a Arc<Path>,
43 pub char_bag: CharBag,
44}
45
/// A successful fuzzy match produced by [`Matcher::match_paths`].
#[derive(Clone, Debug)]
pub struct PathMatch {
    /// Match quality; higher is better.
    pub score: f64,
    /// Byte offsets of the matched query characters within `path_prefix` followed by the
    /// path's (lossy UTF-8) text.
    pub positions: Vec<usize>,
    /// The `tree_id` that was passed to `match_paths` for this batch of candidates.
    pub tree_id: usize,
    /// The matched path.
    pub path: Arc<Path>,
    /// The prefix that was matched in front of `path`.
    pub path_prefix: Arc<str>,
}
54
55#[derive(Clone, Debug)]
56pub struct StringMatchCandidate {
57 pub string: String,
58 pub char_bag: CharBag,
59}
60
61impl Match for PathMatch {
62 fn score(&self) -> f64 {
63 self.score
64 }
65
66 fn set_positions(&mut self, positions: Vec<usize>) {
67 self.positions = positions;
68 }
69}
70
71impl Match for StringMatch {
72 fn score(&self) -> f64 {
73 self.score
74 }
75
76 fn set_positions(&mut self, positions: Vec<usize>) {
77 self.positions = positions;
78 }
79}
80
81impl<'a> MatchCandidate for PathMatchCandidate<'a> {
82 fn has_chars(&self, bag: CharBag) -> bool {
83 self.char_bag.is_superset(bag)
84 }
85
86 fn to_string(&self) -> Cow<'a, str> {
87 self.path.to_string_lossy()
88 }
89}
90
91impl<'a> MatchCandidate for &'a StringMatchCandidate {
92 fn has_chars(&self, bag: CharBag) -> bool {
93 self.char_bag.is_superset(bag)
94 }
95
96 fn to_string(&self) -> Cow<'a, str> {
97 self.string.as_str().into()
98 }
99}
100
/// A successful fuzzy match produced by [`Matcher::match_strings`].
///
/// Matches are ordered by `score`, with ties broken by `string`. Equality is derived from
/// that same ordering, so `a == b` holds exactly when `a.cmp(&b) == Ordering::Equal`, as the
/// `Eq`/`Ord` contract requires.
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// Match quality; higher is better.
    pub score: f64,
    /// Byte offsets into `string` of the matched query characters.
    pub positions: Vec<usize>,
    /// The matched candidate text.
    pub string: String,
}

impl PartialEq for StringMatch {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`: comparing only `score` here made two matches with the
        // same score but different strings "equal" while `cmp` reported them as different.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for StringMatch {}

impl PartialOrd for StringMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for StringMatch {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score
            .partial_cmp(&other.score)
            // Scores are not expected to be NaN; treat incomparable scores as a tie.
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.string.cmp(&other.string))
    }
}
130
131impl PartialEq for PathMatch {
132 fn eq(&self, other: &Self) -> bool {
133 self.score.eq(&other.score)
134 }
135}
136
137impl Eq for PathMatch {}
138
139impl PartialOrd for PathMatch {
140 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
141 Some(self.cmp(other))
142 }
143}
144
145impl Ord for PathMatch {
146 fn cmp(&self, other: &Self) -> Ordering {
147 self.score
148 .partial_cmp(&other.score)
149 .unwrap_or(Ordering::Equal)
150 .then_with(|| self.tree_id.cmp(&other.tree_id))
151 .then_with(|| Arc::as_ptr(&self.path).cmp(&Arc::as_ptr(&other.path)))
152 }
153}
154
155impl<'a> Matcher<'a> {
156 pub fn new(
157 query: &'a [char],
158 lowercase_query: &'a [char],
159 query_char_bag: CharBag,
160 smart_case: bool,
161 max_results: usize,
162 ) -> Self {
163 Self {
164 query,
165 lowercase_query,
166 query_char_bag,
167 min_score: 0.0,
168 last_positions: vec![0; query.len()],
169 match_positions: vec![0; query.len()],
170 score_matrix: Vec::new(),
171 best_position_matrix: Vec::new(),
172 smart_case,
173 max_results,
174 }
175 }
176
177 pub fn match_strings(
178 &mut self,
179 candidates: &[StringMatchCandidate],
180 results: &mut Vec<StringMatch>,
181 cancel_flag: &AtomicBool,
182 ) {
183 self.match_internal(
184 &[],
185 &[],
186 candidates.iter(),
187 results,
188 cancel_flag,
189 |candidate, score| StringMatch {
190 score,
191 positions: Vec::new(),
192 string: candidate.string.to_string(),
193 },
194 )
195 }
196
197 pub fn match_paths(
198 &mut self,
199 tree_id: usize,
200 path_prefix: Arc<str>,
201 path_entries: impl Iterator<Item = PathMatchCandidate<'a>>,
202 results: &mut Vec<PathMatch>,
203 cancel_flag: &AtomicBool,
204 ) {
205 let prefix = path_prefix.chars().collect::<Vec<_>>();
206 let lowercase_prefix = prefix
207 .iter()
208 .map(|c| c.to_ascii_lowercase())
209 .collect::<Vec<_>>();
210 self.match_internal(
211 &prefix,
212 &lowercase_prefix,
213 path_entries,
214 results,
215 cancel_flag,
216 |candidate, score| PathMatch {
217 score,
218 tree_id,
219 positions: Vec::new(),
220 path: candidate.path.clone(),
221 path_prefix: path_prefix.clone(),
222 },
223 )
224 }
225
226 fn match_internal<C: MatchCandidate, R, F>(
227 &mut self,
228 prefix: &[char],
229 lowercase_prefix: &[char],
230 candidates: impl Iterator<Item = C>,
231 results: &mut Vec<R>,
232 cancel_flag: &AtomicBool,
233 build_match: F,
234 ) where
235 R: Match,
236 F: Fn(&C, f64) -> R,
237 {
238 let mut candidate_chars = Vec::new();
239 let mut lowercase_candidate_chars = Vec::new();
240
241 for candidate in candidates {
242 if !candidate.has_chars(self.query_char_bag) {
243 continue;
244 }
245
246 if cancel_flag.load(atomic::Ordering::Relaxed) {
247 break;
248 }
249
250 candidate_chars.clear();
251 lowercase_candidate_chars.clear();
252 for c in candidate.to_string().chars() {
253 candidate_chars.push(c);
254 lowercase_candidate_chars.push(c.to_ascii_lowercase());
255 }
256
257 if !self.find_last_positions(&lowercase_prefix, &lowercase_candidate_chars) {
258 continue;
259 }
260
261 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
262 self.score_matrix.clear();
263 self.score_matrix.resize(matrix_len, None);
264 self.best_position_matrix.clear();
265 self.best_position_matrix.resize(matrix_len, 0);
266
267 let score = self.score_match(
268 &candidate_chars,
269 &lowercase_candidate_chars,
270 &prefix,
271 &lowercase_prefix,
272 );
273
274 if score > 0.0 {
275 let mut mat = build_match(&candidate, score);
276 if let Err(i) = results.binary_search_by(|m| mat.cmp(&m)) {
277 if results.len() < self.max_results {
278 mat.set_positions(self.match_positions.clone());
279 results.insert(i, mat);
280 } else if i < results.len() {
281 results.pop();
282 mat.set_positions(self.match_positions.clone());
283 results.insert(i, mat);
284 }
285 if results.len() == self.max_results {
286 self.min_score = results.last().unwrap().score();
287 }
288 }
289 }
290 }
291 }
292
293 fn find_last_positions(&mut self, prefix: &[char], path: &[char]) -> bool {
294 let mut path = path.iter();
295 let mut prefix_iter = prefix.iter();
296 for (i, char) in self.query.iter().enumerate().rev() {
297 if let Some(j) = path.rposition(|c| c == char) {
298 self.last_positions[i] = j + prefix.len();
299 } else if let Some(j) = prefix_iter.rposition(|c| c == char) {
300 self.last_positions[i] = j;
301 } else {
302 return false;
303 }
304 }
305 true
306 }
307
308 fn score_match(
309 &mut self,
310 path: &[char],
311 path_cased: &[char],
312 prefix: &[char],
313 lowercase_prefix: &[char],
314 ) -> f64 {
315 let score = self.recursive_score_match(
316 path,
317 path_cased,
318 prefix,
319 lowercase_prefix,
320 0,
321 0,
322 self.query.len() as f64,
323 ) * self.query.len() as f64;
324
325 if score <= 0.0 {
326 return 0.0;
327 }
328
329 let path_len = prefix.len() + path.len();
330 let mut cur_start = 0;
331 let mut byte_ix = 0;
332 let mut char_ix = 0;
333 for i in 0..self.query.len() {
334 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
335 while char_ix < match_char_ix {
336 let ch = prefix
337 .get(char_ix)
338 .or_else(|| path.get(char_ix - prefix.len()))
339 .unwrap();
340 byte_ix += ch.len_utf8();
341 char_ix += 1;
342 }
343 cur_start = match_char_ix + 1;
344 self.match_positions[i] = byte_ix;
345 }
346
347 score
348 }
349
350 fn recursive_score_match(
351 &mut self,
352 path: &[char],
353 path_cased: &[char],
354 prefix: &[char],
355 lowercase_prefix: &[char],
356 query_idx: usize,
357 path_idx: usize,
358 cur_score: f64,
359 ) -> f64 {
360 if query_idx == self.query.len() {
361 return 1.0;
362 }
363
364 let path_len = prefix.len() + path.len();
365
366 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
367 return memoized;
368 }
369
370 let mut score = 0.0;
371 let mut best_position = 0;
372
373 let query_char = self.lowercase_query[query_idx];
374 let limit = self.last_positions[query_idx];
375
376 let mut last_slash = 0;
377 for j in path_idx..=limit {
378 let path_char = if j < prefix.len() {
379 lowercase_prefix[j]
380 } else {
381 path_cased[j - prefix.len()]
382 };
383 let is_path_sep = path_char == '/' || path_char == '\\';
384
385 if query_idx == 0 && is_path_sep {
386 last_slash = j;
387 }
388
389 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
390 let curr = if j < prefix.len() {
391 prefix[j]
392 } else {
393 path[j - prefix.len()]
394 };
395
396 let mut char_score = 1.0;
397 if j > path_idx {
398 let last = if j - 1 < prefix.len() {
399 prefix[j - 1]
400 } else {
401 path[j - 1 - prefix.len()]
402 };
403
404 if last == '/' {
405 char_score = 0.9;
406 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
407 char_score = 0.8;
408 } else if last.is_lowercase() && curr.is_uppercase() {
409 char_score = 0.8;
410 } else if last == '.' {
411 char_score = 0.7;
412 } else if query_idx == 0 {
413 char_score = BASE_DISTANCE_PENALTY;
414 } else {
415 char_score = MIN_DISTANCE_PENALTY.max(
416 BASE_DISTANCE_PENALTY
417 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
418 );
419 }
420 }
421
422 // Apply a severe penalty if the case doesn't match.
423 // This will make the exact matches have higher score than the case-insensitive and the
424 // path insensitive matches.
425 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
426 char_score *= 0.001;
427 }
428
429 let mut multiplier = char_score;
430
431 // Scale the score based on how deep within the path we found the match.
432 if query_idx == 0 {
433 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
434 }
435
436 let mut next_score = 1.0;
437 if self.min_score > 0.0 {
438 next_score = cur_score * multiplier;
439 // Scores only decrease. If we can't pass the previous best, bail
440 if next_score < self.min_score {
441 // Ensure that score is non-zero so we use it in the memo table.
442 if score == 0.0 {
443 score = 1e-18;
444 }
445 continue;
446 }
447 }
448
449 let new_score = self.recursive_score_match(
450 path,
451 path_cased,
452 prefix,
453 lowercase_prefix,
454 query_idx + 1,
455 j + 1,
456 next_score,
457 ) * multiplier;
458
459 if new_score > score {
460 score = new_score;
461 best_position = j;
462 // Optimization: can't score better than 1.
463 if new_score == 1.0 {
464 break;
465 }
466 }
467 }
468 }
469
470 if best_position != 0 {
471 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
472 }
473
474 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
475 score
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // `find_last_positions` treats the prefix followed by the candidate as one sequence;
    // the asserted `last_positions` are indices into that combined sequence.
    #[test]
    fn test_get_last_positions() {
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        // 'c' exists only in the prefix, which precedes the only 'd' (in the candidate),
        // so "dc" cannot be matched in order.
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, false);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        // 'd' is at candidate index 1 (combined index 4); 'c' falls back to prefix index 2.
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        // After placing the second 'z' at the candidate's start, the remaining query chars
        // must come earlier, i.e. in the prefix.
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert_eq!(result, true);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    // Results are asserted in the order returned (best match first), with the byte offset
    // of every matched query character.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    // Positions are byte offsets, so multi-byte characters before a match shift them by
    // their UTF-8 length (e.g. each keycap emoji below is 7 bytes).
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    // Test helper: runs `Matcher::match_paths` over `paths` (with an empty prefix and a
    // limit of 100 results) and returns each match as (original path string, positions).
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &Vec<&'a str>,
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            // Char bags are built from the lowercased path, mirroring the lowercased query bag.
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                char_bag,
                path: path_arcs.get(i).unwrap(),
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();
        matcher.match_paths(
            0,
            "".into(),
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
        );

        // Map each matched `Arc<Path>` back to the input string it came from.
        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == Path::new(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }
}