1use std::{
2 borrow::Borrow,
3 collections::BTreeMap,
4 sync::atomic::{self, AtomicBool},
5};
6
7use crate::CharBag;
8
/// Score multiplier for a matched character that does not directly follow the previous match
/// and is not preceded by a path separator, word separator, digit, case boundary or `.`.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from [`BASE_DISTANCE_PENALTY`] for every additional character skipped
/// between the position where the search resumed and the matched character.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-based multiplier, so that very distant matches still score
/// above zero.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
12
13// TODO:
14// Use `Path` instead of `&str` for paths.
15pub struct Matcher<'a> {
16 query: &'a [char],
17 lowercase_query: &'a [char],
18 query_char_bag: CharBag,
19 smart_case: bool,
20 penalize_length: bool,
21 min_score: f64,
22 match_positions: Vec<usize>,
23 last_positions: Vec<usize>,
24 score_matrix: Vec<Option<f64>>,
25 best_position_matrix: Vec<usize>,
26}
27
28pub trait MatchCandidate {
29 fn has_chars(&self, bag: CharBag) -> bool;
30 fn candidate_chars(&self) -> impl Iterator<Item = char>;
31}
32
33impl<'a> Matcher<'a> {
34 pub fn new(
35 query: &'a [char],
36 lowercase_query: &'a [char],
37 query_char_bag: CharBag,
38 smart_case: bool,
39 penalize_length: bool,
40 ) -> Self {
41 Self {
42 query,
43 lowercase_query,
44 query_char_bag,
45 min_score: 0.0,
46 last_positions: vec![0; lowercase_query.len()],
47 match_positions: vec![0; query.len()],
48 score_matrix: Vec::new(),
49 best_position_matrix: Vec::new(),
50 smart_case,
51 penalize_length,
52 }
53 }
54
55 /// Filter and score fuzzy match candidates. Results are returned unsorted, in the same order as
56 /// the input candidates.
57 pub(crate) fn match_candidates<C, R, F, T>(
58 &mut self,
59 prefix: &[char],
60 lowercase_prefix: &[char],
61 candidates: impl Iterator<Item = T>,
62 results: &mut Vec<R>,
63 cancel_flag: &AtomicBool,
64 build_match: F,
65 ) where
66 C: MatchCandidate,
67 T: Borrow<C>,
68 F: Fn(&C, f64, &Vec<usize>) -> R,
69 {
70 let mut candidate_chars = Vec::new();
71 let mut lowercase_candidate_chars = Vec::new();
72 let mut extra_lowercase_chars = BTreeMap::new();
73
74 for candidate in candidates {
75 if !candidate.borrow().has_chars(self.query_char_bag) {
76 continue;
77 }
78
79 if cancel_flag.load(atomic::Ordering::Acquire) {
80 break;
81 }
82
83 candidate_chars.clear();
84 lowercase_candidate_chars.clear();
85 extra_lowercase_chars.clear();
86 for (i, c) in candidate.borrow().candidate_chars().enumerate() {
87 candidate_chars.push(c);
88 let mut char_lowercased = c.to_lowercase().collect::<Vec<_>>();
89 if char_lowercased.len() > 1 {
90 extra_lowercase_chars.insert(i, char_lowercased.len() - 1);
91 }
92 lowercase_candidate_chars.append(&mut char_lowercased);
93 }
94
95 if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) {
96 continue;
97 }
98
99 let matrix_len =
100 self.query.len() * (lowercase_prefix.len() + lowercase_candidate_chars.len());
101 self.score_matrix.clear();
102 self.score_matrix.resize(matrix_len, None);
103 self.best_position_matrix.clear();
104 self.best_position_matrix.resize(matrix_len, 0);
105
106 let score = self.score_match(
107 &candidate_chars,
108 &lowercase_candidate_chars,
109 prefix,
110 lowercase_prefix,
111 &extra_lowercase_chars,
112 );
113
114 if score > 0.0 {
115 results.push(build_match(
116 candidate.borrow(),
117 score,
118 &self.match_positions,
119 ));
120 }
121 }
122 }
123
124 fn find_last_positions(
125 &mut self,
126 lowercase_prefix: &[char],
127 lowercase_candidate: &[char],
128 ) -> bool {
129 let mut lowercase_prefix = lowercase_prefix.iter();
130 let mut lowercase_candidate = lowercase_candidate.iter();
131 for (i, char) in self.lowercase_query.iter().enumerate().rev() {
132 if let Some(j) = lowercase_candidate.rposition(|c| c == char) {
133 self.last_positions[i] = j + lowercase_prefix.len();
134 } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) {
135 self.last_positions[i] = j;
136 } else {
137 return false;
138 }
139 }
140 true
141 }
142
143 fn score_match(
144 &mut self,
145 path: &[char],
146 path_lowercased: &[char],
147 prefix: &[char],
148 lowercase_prefix: &[char],
149 extra_lowercase_chars: &BTreeMap<usize, usize>,
150 ) -> f64 {
151 let score = self.recursive_score_match(
152 path,
153 path_lowercased,
154 prefix,
155 lowercase_prefix,
156 0,
157 0,
158 self.query.len() as f64,
159 extra_lowercase_chars,
160 ) * self.query.len() as f64;
161
162 if score <= 0.0 {
163 return 0.0;
164 }
165 let path_len = prefix.len() + path.len();
166 let mut cur_start = 0;
167 let mut byte_ix = 0;
168 let mut char_ix = 0;
169 for i in 0..self.query.len() {
170 let match_char_ix = self.best_position_matrix[i * path_len + cur_start];
171 while char_ix < match_char_ix {
172 let ch = prefix
173 .get(char_ix)
174 .or_else(|| path.get(char_ix - prefix.len()))
175 .unwrap();
176 byte_ix += ch.len_utf8();
177 char_ix += 1;
178 }
179
180 self.match_positions[i] = byte_ix;
181
182 let matched_ch = prefix
183 .get(match_char_ix)
184 .or_else(|| path.get(match_char_ix - prefix.len()))
185 .unwrap();
186 byte_ix += matched_ch.len_utf8();
187
188 cur_start = match_char_ix + 1;
189 char_ix = match_char_ix + 1;
190 }
191
192 score
193 }
194
195 fn recursive_score_match(
196 &mut self,
197 path: &[char],
198 path_lowercased: &[char],
199 prefix: &[char],
200 lowercase_prefix: &[char],
201 query_idx: usize,
202 path_idx: usize,
203 cur_score: f64,
204 extra_lowercase_chars: &BTreeMap<usize, usize>,
205 ) -> f64 {
206 if query_idx == self.query.len() {
207 return 1.0;
208 }
209
210 let limit = self.last_positions[query_idx];
211 let max_valid_index = (prefix.len() + path_lowercased.len()).saturating_sub(1);
212 let safe_limit = limit.min(max_valid_index);
213
214 if path_idx > safe_limit {
215 return 0.0;
216 }
217
218 let path_len = prefix.len() + path.len();
219 if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] {
220 return memoized;
221 }
222
223 let mut score = 0.0;
224 let mut best_position = 0;
225
226 let query_char = self.lowercase_query[query_idx];
227
228 let mut last_slash = 0;
229
230 for j in path_idx..=safe_limit {
231 let extra_lowercase_chars_count = extra_lowercase_chars
232 .iter()
233 .take_while(|&(&i, _)| i < j)
234 .map(|(_, increment)| increment)
235 .sum::<usize>();
236 let j_regular = j - extra_lowercase_chars_count;
237
238 let path_char = if j < prefix.len() {
239 lowercase_prefix[j]
240 } else {
241 let path_index = j - prefix.len();
242 match path_lowercased.get(path_index) {
243 Some(&char) => char,
244 None => continue,
245 }
246 };
247 let is_path_sep = path_char == '/';
248
249 if query_idx == 0 && is_path_sep {
250 last_slash = j_regular;
251 }
252 let need_to_score = query_char == path_char || (is_path_sep && query_char == '_');
253 if need_to_score {
254 let curr = match prefix.get(j_regular) {
255 Some(&curr) => curr,
256 None => path[j_regular - prefix.len()],
257 };
258
259 let mut char_score = 1.0;
260 if j > path_idx {
261 let last = match prefix.get(j_regular - 1) {
262 Some(&last) => last,
263 None => path[j_regular - 1 - prefix.len()],
264 };
265
266 if last == '/' {
267 char_score = 0.9;
268 } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric())
269 || (last.is_lowercase() && curr.is_uppercase())
270 {
271 char_score = 0.8;
272 } else if last == '.' {
273 char_score = 0.7;
274 } else if query_idx == 0 {
275 char_score = BASE_DISTANCE_PENALTY;
276 } else {
277 char_score = MIN_DISTANCE_PENALTY.max(
278 BASE_DISTANCE_PENALTY
279 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
280 );
281 }
282 }
283
284 // Apply a severe penalty if the case doesn't match.
285 // This will make the exact matches have higher score than the case-insensitive and the
286 // path insensitive matches.
287 if (self.smart_case || curr == '/') && self.query[query_idx] != curr {
288 char_score *= 0.001;
289 }
290
291 let mut multiplier = char_score;
292
293 // Scale the score based on how deep within the path we found the match.
294 if self.penalize_length && query_idx == 0 {
295 multiplier /= ((prefix.len() + path.len()) - last_slash) as f64;
296 }
297
298 let mut next_score = 1.0;
299 if self.min_score > 0.0 {
300 next_score = cur_score * multiplier;
301 // Scores only decrease. If we can't pass the previous best, bail
302 if next_score < self.min_score {
303 // Ensure that score is non-zero so we use it in the memo table.
304 if score == 0.0 {
305 score = 1e-18;
306 }
307 continue;
308 }
309 }
310
311 let new_score = self.recursive_score_match(
312 path,
313 path_lowercased,
314 prefix,
315 lowercase_prefix,
316 query_idx + 1,
317 j + 1,
318 next_score,
319 extra_lowercase_chars,
320 ) * multiplier;
321
322 if new_score > score {
323 score = new_score;
324 best_position = j_regular;
325 // Optimization: can't score better than 1.
326 if new_score == 1.0 {
327 break;
328 }
329 }
330 }
331 }
332
333 if best_position != 0 {
334 self.best_position_matrix[query_idx * path_len + path_idx] = best_position;
335 }
336
337 self.score_matrix[query_idx * path_len + path_idx] = Some(score);
338 score
339 }
340}
341
#[cfg(test)]
mod tests {
    use util::rel_path::{RelPath, rel_path};

    use crate::{PathMatch, PathMatchCandidate};

    use super::*;
    use std::sync::Arc;

    /// `find_last_positions` searches the candidate first and falls back to the prefix, and
    /// reports failure when the query characters cannot be placed in order.
    #[test]
    fn test_get_last_positions() {
        let mut query: &[char] = &['d', 'c'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(!result);

        query = &['c', 'd'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![2, 4]);

        query = &['z', '/', 'z', 'f'];
        let mut matcher = Matcher::new(query, query, query.into(), false, true);
        let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']);
        assert!(result);
        assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]);
    }

    /// Plain ASCII paths: checks which candidates match and in which order they rank.
    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "ThisIsATestDir",
            "this/is/a/test/dir",
            "test/tiatd",
        ];

        assert_eq!(
            match_single_path_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("t/i/a/t/d", false, &paths),
            vec![("this/is/a/test/dir", vec![0, 4, 5, 7, 8, 9, 10, 14, 15]),]
        );

        assert_eq!(
            match_single_path_query("tiatd", false, &paths),
            vec![
                ("test/tiatd", vec![5, 6, 7, 8, 9]),
                ("ThisIsATestDir", vec![0, 4, 6, 7, 11]),
                ("this/is/a/test/dir", vec![0, 5, 8, 10, 15]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Characters whose lowercase form is longer than their uppercase form must still match.
    #[test]
    fn test_lowercase_longer_than_uppercase() {
        // This character has more chars in lower-case than in upper-case.
        let paths = vec!["\u{0130}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("\u{0130}", vec![0])]
        );

        // Path is the lower-case version of the query
        let paths = vec!["i\u{307}"];
        let query = "\u{0130}";
        assert_eq!(
            match_single_path_query(query, false, &paths),
            vec![("i\u{307}", vec![0])]
        );
    }

    /// Match positions are byte offsets, so multi-byte characters shift later positions.
    #[test]
    fn test_match_multibyte_path_entries() {
        let paths = vec![
            "aαbβ/cγdδ",
            "αβγδ/bcde",
            "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
            "d/🆒/h",
        ];
        assert_eq!("1️⃣".len(), 7);
        assert_eq!(
            match_single_path_query("bcd", false, &paths),
            vec![
                ("αβγδ/bcde", vec![9, 10, 11]),
                ("aαbβ/cγdδ", vec![3, 7, 10]),
            ]
        );
        assert_eq!(
            match_single_path_query("cde", false, &paths),
            vec![
                ("αβγδ/bcde", vec![10, 11, 12]),
                ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]),
            ]
        );
    }

    /// Mixed scripts, accented characters and emoji in both queries and paths.
    #[test]
    fn match_unicode_path_entries() {
        let mixed_unicode_paths = vec![
            "İolu/oluş",
            "İstanbul/code",
            "Athens/Şanlıurfa",
            "Çanakkale/scripts",
            "paris/Düzce_İl",
            "Berlin_Önemli_Ğündem",
            "KİTAPLIK/london/dosya",
            "tokyo/kyoto/fuji",
            "new_york/san_francisco",
        ];

        assert_eq!(
            match_single_path_query("İo/oluş", false, &mixed_unicode_paths),
            vec![("İolu/oluş", vec![0, 2, 4, 6, 8, 10, 12])]
        );

        assert_eq!(
            match_single_path_query("İst/code", false, &mixed_unicode_paths),
            vec![("İstanbul/code", vec![0, 2, 4, 6, 8, 10, 12, 14])]
        );

        assert_eq!(
            match_single_path_query("athens/şa", false, &mixed_unicode_paths),
            vec![("Athens/Şanlıurfa", vec![0, 1, 2, 3, 4, 5, 6, 7, 9])]
        );

        assert_eq!(
            match_single_path_query("BerlinÖĞ", false, &mixed_unicode_paths),
            vec![("Berlin_Önemli_Ğündem", vec![0, 1, 2, 3, 4, 5, 7, 15])]
        );

        assert_eq!(
            match_single_path_query("tokyo/fuji", false, &mixed_unicode_paths),
            vec![("tokyo/kyoto/fuji", vec![0, 1, 2, 3, 4, 5, 12, 13, 14, 15])]
        );

        let mixed_script_paths = vec![
            "résumé_Москва",
            "naïve_київ_implementation",
            "café_北京_app",
            "東京_über_driver",
            "déjà_vu_cairo",
            "seoul_piñata_game",
            "voilà_istanbul_result",
        ];

        assert_eq!(
            match_single_path_query("résmé", false, &mixed_script_paths),
            vec![("résumé_Москва", vec![0, 1, 3, 5, 6])]
        );

        assert_eq!(
            match_single_path_query("café北京", false, &mixed_script_paths),
            vec![("café_北京_app", vec![0, 1, 2, 3, 6, 9])]
        );

        assert_eq!(
            match_single_path_query("ista", false, &mixed_script_paths),
            vec![("voilà_istanbul_result", vec![7, 8, 9, 10])]
        );

        let complex_paths = vec![
            "document_📚_library",
            "project_👨👩👧👦_family",
            "flags_🇯🇵🇺🇸🇪🇺_world",
            "code_😀😃😄😁_happy",
            "photo_👩👩👧👦_album",
        ];

        assert_eq!(
            match_single_path_query("doc📚lib", false, &complex_paths),
            vec![("document_📚_library", vec![0, 1, 2, 9, 14, 15, 16])]
        );

        assert_eq!(
            match_single_path_query("codehappy", false, &complex_paths),
            vec![("code_😀😃😄😁_happy", vec![0, 1, 2, 3, 22, 23, 24, 25, 26])]
        );
    }

    /// Matches `query` against `paths` with an empty prefix and returns the matching paths,
    /// best match first, each paired with its match positions.
    fn match_single_path_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_arcs: Vec<Arc<RelPath>> = paths
            .iter()
            .map(|path| Arc::from(rel_path(path)))
            .collect::<Vec<_>>();
        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let char_bag = CharBag::from(lowercase_path.as_slice());
            path_entries.push(PathMatchCandidate {
                is_dir: false,
                char_bag,
                path: &path_arcs[i],
            });
        }

        let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, true);

        let cancel_flag = AtomicBool::new(false);
        let mut results = Vec::new();

        matcher.match_candidates(
            &[],
            &[],
            path_entries.into_iter(),
            &mut results,
            &cancel_flag,
            |candidate, score, positions| PathMatch {
                score,
                worktree_id: 0,
                positions: positions.clone(),
                path: candidate.path.into(),
                path_prefix: RelPath::empty().into(),
                distance_to_relative_ancestor: usize::MAX,
                is_dir: false,
            },
        );
        // Highest-ranked match first.
        results.sort_by(|a, b| b.cmp(a));

        results
            .into_iter()
            .map(|result| {
                (
                    paths
                        .iter()
                        .copied()
                        .find(|p| result.path.as_ref() == rel_path(p))
                        .unwrap(),
                    result.positions,
                )
            })
            .collect()
    }

    /// Test for https://github.com/zed-industries/zed/issues/44324
    #[test]
    fn test_recursive_score_match_index_out_of_bounds() {
        let paths = vec!["İ/İ/İ/İ"];
        let query = "İ/İ";

        // This panicked with "index out of bounds: the len is 21 but the index is 22"
        let result = match_single_path_query(query, false, &paths);
        let _ = result;
    }
}