1use std::{
2 borrow::Cow,
3 sync::atomic::{self, AtomicBool},
4};
5
6use crate::CharBag;
7
/// Char score for a match that is neither adjacent to the previous match nor at a
/// recognised word boundary. For the first query char this value is used as-is.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// For later query chars, subtracted from [`BASE_DISTANCE_PENALTY`] once per skipped
/// char beyond the first.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound of the distance penalty, however large the gap grows.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
11
12pub struct Matcher<'a> {
13 query: &'a [char],
14 lowercase_query: &'a [char],
15 query_char_bag: CharBag,
16 smart_case: bool,
17 max_results: usize,
18 min_score: f64,
19 match_positions: Vec<usize>,
20 last_positions: Vec<usize>,
21 score_matrix: Vec<Option<f64>>,
22 best_position_matrix: Vec<usize>,
23}
24
/// A scored result produced by [`Matcher::match_candidates`].
///
/// The matcher keeps results sorted in descending `Ord` order and treats higher scores as
/// better, so implementors should presumably order primarily by score — confirm per impl.
pub trait Match: Ord {
    /// Stores the byte offsets of the matched query chars.
    fn set_positions(&mut self, positions: Vec<usize>);

    /// The match score; higher is better.
    fn score(&self) -> f64;
}
29
30pub trait MatchCandidate {
31 fn has_chars(&self, bag: CharBag) -> bool;
32 fn to_string(&self) -> Cow<'_, str>;
33}
34
35impl<'a> Matcher<'a> {
36 pub fn new(
37 query: &'a [char],
38 lowercase_query: &'a [char],
39 query_char_bag: CharBag,
40 smart_case: bool,
41 max_results: usize,
42 ) -> Self {
43 Self {
44 query,
45 lowercase_query,
46 query_char_bag,
47 min_score: 0.0,
48 last_positions: vec![0; lowercase_query.len()],
49 match_positions: vec![0; query.len()],
50 score_matrix: Vec::new(),
51 best_position_matrix: Vec::new(),
52 smart_case,
53 max_results,
54 }
55 }
56
57 pub fn match_candidates<C: MatchCandidate, R, F>(
58 &mut self,
59 prefix: &[char],
60 lowercase_prefix: &[char],
61 candidates: impl Iterator<Item = C>,
62 results: &mut Vec<R>,
63 cancel_flag: &AtomicBool,
64 build_match: F,
65 ) where
66 R: Match,
67 F: Fn(&C, f64) -> R,
68 {
69 let mut candidate_chars = Vec::new();
70 let mut lowercase_candidate_chars = Vec::new();
71
72 for candidate in candidates {
73 if !candidate.has_chars(self.query_char_bag) {
74 continue;
75 }
76
77 if cancel_flag.load(atomic::Ordering::Relaxed) {
78 break;
79 }
80
81 candidate_chars.clear();
82 lowercase_candidate_chars.clear();
83 for c in candidate.to_string().chars() {
84 candidate_chars.push(c);
85 lowercase_candidate_chars.append(&mut c.to_lowercase().collect::<Vec<_>>());
86 }
87
88 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
89 continue;
90 }
91
92 let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
93 self.score_matrix.clear();
94 self.score_matrix.resize(matrix_len, None);
95 self.best_position_matrix.clear();
96 self.best_position_matrix.resize(matrix_len, 0);
97
98 let score = self.score_match(
99 &candidate_chars,
100 &lowercase_candidate_chars,
101 prefix,
102 lowercase_prefix,
103 );
104
105 if score > 0.0 {
106 let mut mat = build_match(&candidate, score);
107 if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) {
108 if results.len() < self.max_results {
109 mat.set_positions(self.match_positions.clone());
110 results.insert(i, mat);
111 } else if i < results.len() {
112 results.pop();
113 mat.set_positions(self.match_positions.clone());
114 results.insert(i, mat);
115 }
116 if results.len() == self.max_results {
117 self.min_score = results.last().unwrap().score();
118 }
119 }
120 }
121 }
122 }
123
124 fn find_last_positions(
125 &mut self,
126 lowercase_prefix: &[char],
127 lowercase_candidate: &[char],
128 ) -> bool {
129 let mut lowercase_prefix = lowercase_prefix.iter();
130 let mut lowercase_candidate = lowercase_candidate.iter();
131 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
132 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
133 self.last_positions[i] = j + lowercase_prefix.len();
134 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
135 self.last_positions[i] = j;
136 } else {
137 return false;
138 }
139 }
140 true
141 }
142
143 fn score_match(
144 &mut self,
145 path: &[char],
146 path_cased: &[char],
147 prefix: &[char],
148 lowercase_prefix: &[char],
149 ) -> f64 {
150 let score = self.recursive_score_match(
151 path,
152 path_cased,
153 prefix,
154 lowercase_prefix,
155 0,
156 0,
157 self.query.len() as f64,
158 ) * self.query.len() as f64;
159
160 if score <= 0.0 {
161 return 0.0;
162 }
163
164 let path_len = prefix.len() + path.len();
165 let mut cur_start = 0;
166 let mut byte_ix = 0;
167 let mut char_ix = 0;
168 for i in 0..self.query.len() {
169 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
170 while char_ix < match_char_ix {
171 let ch = prefix
172 .get(char_ix)
173 .or_else(|| path.get(char_ix - prefix.len()))
174 .unwrap();
175 byte_ix += ch.len_utf8();
176 char_ix += 1;
177 }
178 cur_start = match_char_ix + 1;
179 self.match_positions[i] = byte_ix;
180 }
181
182 score
183 }
184
185 #[allow(clippy::too_many_arguments)]
186 fn recursive_score_match(
187 &mut self,
188 path: &[char],
189 path_cased: &[char],
190 prefix: &[char],
191 lowercase_prefix: &[char],
192 query_idx: usize,
193 path_idx: usize,
194 cur_score: f64,
195 ) -> f64 {
196 if query_idx == self.query.len() {
197 return 1.0;
198 }
199
200 let path_len = prefix.len() + path.len();
201
202 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
203 return memoized;
204 }
205
206 let mut score = 0.0;
207 let mut best_position = 0;
208
209 let query_char = self.lowercase_query[query_idx];
210 let limit = self.last_positions[query_idx];
211
212 let mut last_slash = 0;
213 for j in path_idx..=limit {
214 let path_char = if j < prefix.len() {
215 lowercase_prefix[j]
216 } else {
217 path_cased[j - prefix.len()]
218 };
219 let is_path_sep = path_char == '/' || path_char == '\\';
220
221 if query_idx == 0 && is_path_sep {
222 last_slash = j;
223 }
224
225 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
226 let curr = if j < prefix.len() {
227 prefix[j]
228 } else {
229 path[j - prefix.len()]
230 };
231
232 let mut char_score = 1.0;
233 if j > path_idx {
234 let last = if j - 1 < prefix.len() {
235 prefix[j - 1]
236 } else {
237 path[j - 1 - prefix.len()]
238 };
239
240 if last == '/' {
241 char_score = 0.9;
242 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
243 || (last.is_lowercase() && curr.is_uppercase())
244 {
245 char_score = 0.8;
246 } else if last == '.' {
247 char_score = 0.7;
248 } else if query_idx == 0 {
249 char_score = BASE_DISTANCE_PENALTY;
250 } else {
251 char_score = MIN_DISTANCE_PENALTY.max(
252 BASE_DISTANCE_PENALTY
253 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
254 );
255 }
256 }
257
258 // Apply a severe penalty if the case doesn't match.
259 // This will make the exact matches have higher score than the case-insensitive and the
260 // path insensitive matches.
261 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
262 char_score *= 0.001;
263 }
264
265 let mut multiplier = char_score;
266
267 // Scale the score based on how deep within the path we found the match.
268 if query_idx == 0 {
269 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
270 }
271
272 let mut next_score = 1.0;
273 if self.min_score > 0.0 {
274 next_score = cur_score * multiplier;
275 // Scores only decrease. If we can't pass the previous best, bail
276 if next_score < self.min_score {
277 // Ensure that score is non-zero so we use it in the memo table.
278 if score == 0.0 {
279 score = 1e-18;
280 }
281 continue;
282 }
283 }
284
285 let new_score = self.recursive_score_match(
286 path,
287 path_cased,
288 prefix,
289 lowercase_prefix,
290 query_idx + 1,
291 j + 1,
292 next_score,
293 ) * multiplier;
294
295 if new_score > score {
296 score = new_score;
297 best_position = j;
298 // Optimization: can't score better than 1.
299 if new_score == 1.0 {
300 break;
301 }
302 }
303 }
304 }
305
306 if best_position != 0 {
307 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
308 }
309
310 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
311 score
312 }
313}
314
#[cfg(test)]
mod tests {
    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    /// Runs `find_last_positions` for `query`. Returns the computed positions, or `None`
    /// when the query cannot be placed in `prefix + candidate`.
    fn last_positions_for(
        query: &[char],
        prefix: &[char],
        candidate: &[char],
    ) -> Option<Vec<usize>> {
        let mut matcher = Matcher::new(query, query, query.into(), false, 10);
        let found = matcher.find_last_positions(prefix, candidate);
        found.then(|| matcher.last_positions.clone())
    }

    #[test]
    fn test_get_last_positions() {
        let prefix = ['a', 'b', 'c'];
        let candidate = ['b', 'd', 'e', 'f'];

        // 'c' only occurs in the prefix, before any 'd', so "dc" cannot be placed.
        assert_eq!(last_positions_for(&['d', 'c'], &prefix, &candidate), None);

        assert_eq!(
            last_positions_for(&['c', 'd'], &prefix, &candidate),
            Some(vec![2, 4])
        );

        assert_eq!(
            last_positions_for(
                &['z', '/', 'z', 'f'],
                &['z', 'e', 'd', '/'],
                &['z', 'e', 'd', '/', 'f'],
            ),
            Some(vec![0, 3, 4, 8])
        );
    }

    #[test]
    fn test_match_path_entries() {
        let paths = [
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        let cases = [
            (
                "abc",
                vec![
                    ("abC", vec![0, 1, 2]),
                    ("abcd", vec![0, 1, 2]),
                    ("AlphaBravoCharlie", vec![0, 5, 10]),
                    ("alphabravocharlie", vec![4, 5, 10]),
                ],
            ),
            (
                "t/i/a/t/d",
                vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16])],
            ),
            (
                "tiatd",
                vec![
                    ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                    ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                    ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                    ("thisisatestdir", vec![0, 2, 6, 7, 11]),
                ],
            ),
        ];

        for (query, expected) in cases {
            assert_eq!(
                match_single_path_query(query, false, &paths),
                expected,
                "query: {:?}",
                query
            );
        }
    }

    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // U+0130 has more chars in lower-case ("i" + U+0307) than in upper-case.
        let query = "\u{0130}";

        assert_eq!(
            match_single_path_query(query, false, &["\u{0130}"]),
            vec![("\u{0130}", vec![0])]
        );

        // The path is the lower-case version of the query.
        assert_eq!(
            match_single_path_query(query, false, &["i\u{307}"]),
            vec![("i\u{307}", vec![0])]
        );
    }

    #[test]
    fn test_match_multibyte_path_entries() {
        // Each keycap emoji is a digit + U+FE0F + U+20E3, i.e. 7 bytes.
        let keycaps = "c1\u{fe0f}\u{20e3}2\u{fe0f}\u{20e3}3\u{fe0f}\u{20e3}/d4\u{fe0f}\u{20e3}5\u{fe0f}\u{20e3}6\u{fe0f}\u{20e3}/e7\u{fe0f}\u{20e3}8\u{fe0f}\u{20e3}9\u{fe0f}\u{20e3}/f";
        let paths = ["aαbβ/cγdδ", "αβγδ/bcde", keycaps, "/d/🆒/h"];
        assert_eq!("1\u{fe0f}\u{20e3}".len(), 7);

        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![("αβγδ/bcde", vec![9, 10, 11]), ("aαbβ/cγdδ", vec![3, 7, 10])]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![("αβγδ/bcde", vec![10, 11, 12]), (keycaps, vec![0, 23, 46])]
        );
    }

    /// Runs the matcher over `paths`, with no prefix and every path treated as a file.
    /// Returns each matched path with its match byte positions, in result order.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let query_chars: Vec<char> = query.chars().collect();
        let lowercase_query: Vec<char> = query.to_lowercase().chars().collect();
        let query_bag = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<Path>> = paths
            .iter()
            .map(|path| Arc::from(PathBuf::from(path)))
            .collect();
        let candidates = paths.iter().zip(&path_arcs).map(|(path, arc)| {
            let lowercase_path: Vec<char> = path.to_lowercase().chars().collect();
            PathMatchCandidate {
                is_dir: false,
                char_bag: CharBag::from(lowercase_path.as_slice()),
                path: arc,
            }
        });

        let mut matcher = Matcher::new(
            &query_chars,
            &lowercase_query,
            query_bag,
            smart_case,
            100,
        );
        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            candidates,
            &mut results,
            &cancel_flag,
            |candidate, score| PathMatch {
                score,
                worktree_id: 0,
                positions: Vec::new(),
                path: Arc::from(candidate.path),
                path_prefix: "".into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );

        results
            .into_iter()
            .map(|found| {
                let original = paths
                    .iter()
                    .copied()
                    .find(|p| found.path.as_ref() == Path::new(p))
                    .unwrap();
                (original, found.positions)
            })
            .collect()
    }
}