1use crate::{CharBag, matcher};
2use gpui::BackgroundExecutor;
3use nucleo::pattern::{AtomKind, CaseMatching, Normalization, Pattern};
4use std::{
5 borrow::Borrow,
6 cmp, iter,
7 ops::Range,
8 sync::atomic::{AtomicBool, Ordering},
9};
10
11#[derive(Clone, Debug)]
12pub struct StringMatchCandidate {
13 pub id: usize,
14 pub string: String,
15 pub char_bag: CharBag,
16}
17
18impl StringMatchCandidate {
19 pub fn new(id: usize, string: &str) -> Self {
20 Self {
21 id,
22 string: string.into(),
23 char_bag: string.into(),
24 }
25 }
26}
27
/// A single result produced by matching a query against a [`StringMatchCandidate`].
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the candidate this match was produced from.
    pub candidate_id: usize,
    /// Match quality; higher scores order first (see the `Ord` implementation).
    pub score: f64,
    /// Matched positions within `string`. [`StringMatch::ranges`] interprets these as byte
    /// offsets that must lie on UTF-8 character boundaries.
    pub positions: Vec<usize>,
    /// A copy of the candidate's text.
    pub string: String,
}
35
36impl StringMatch {
37 pub fn ranges(&self) -> impl '_ + Iterator<Item = Range<usize>> {
38 let mut positions = self.positions.iter().peekable();
39 iter::from_fn(move || {
40 if let Some(start) = positions.next().copied() {
41 let Some(char_len) = self.char_len_at_index(start) else {
42 log::error!(
43 "Invariant violation: Index {start} out of range or not on a utf-8 boundary in string {:?}",
44 self.string
45 );
46 return None;
47 };
48 let mut end = start + char_len;
49 while let Some(next_start) = positions.peek() {
50 if end == **next_start {
51 let Some(char_len) = self.char_len_at_index(end) else {
52 log::error!(
53 "Invariant violation: Index {end} out of range or not on a utf-8 boundary in string {:?}",
54 self.string
55 );
56 return None;
57 };
58 end += char_len;
59 positions.next();
60 } else {
61 break;
62 }
63 }
64
65 return Some(start..end);
66 }
67 None
68 })
69 }
70
71 /// Gets the byte length of the utf-8 character at a byte offset. If the index is out of range
72 /// or not on a utf-8 boundary then None is returned.
73 fn char_len_at_index(&self, ix: usize) -> Option<usize> {
74 self.string
75 .get(ix..)
76 .and_then(|slice| slice.chars().next().map(|char| char.len_utf8()))
77 }
78}
79
80impl PartialEq for StringMatch {
81 fn eq(&self, other: &Self) -> bool {
82 self.cmp(other).is_eq()
83 }
84}
85
86impl Eq for StringMatch {}
87
88impl PartialOrd for StringMatch {
89 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
90 Some(self.cmp(other))
91 }
92}
93
94impl Ord for StringMatch {
95 fn cmp(&self, other: &Self) -> cmp::Ordering {
96 // dbg!(&self.string, self.score);
97 // dbg!(&other.string, other.score);
98 self.score
99 .total_cmp(&other.score)
100 .reverse()
101 .then_with(|| self.string.cmp(&other.string))
102 }
103}
104
105pub async fn match_strings<T>(
106 candidates: &[T],
107 query: &str,
108 smart_case: bool,
109 prefer_shorter: bool,
110 max_results: usize,
111 cancel_flag: &AtomicBool,
112 executor: BackgroundExecutor,
113) -> Vec<StringMatch>
114where
115 T: Borrow<StringMatchCandidate> + Sync,
116{
117 if candidates.is_empty() || max_results == 0 {
118 return Default::default();
119 }
120 // FIXME should support fzf syntax with Pattern::parse
121 let pattern = Pattern::new(
122 query,
123 if smart_case {
124 CaseMatching::Smart
125 } else {
126 CaseMatching::Ignore
127 },
128 Normalization::Smart,
129 AtomKind::Fuzzy,
130 );
131
132 if query.is_empty() {
133 return candidates
134 .iter()
135 .map(|candidate| StringMatch {
136 candidate_id: candidate.borrow().id,
137 score: 0.,
138 positions: Default::default(),
139 string: candidate.borrow().string.clone(),
140 })
141 .collect();
142 }
143
144 let num_cpus = executor.num_cpus().min(candidates.len());
145 let segment_size = candidates.len().div_ceil(num_cpus);
146 let mut segment_results = (0..num_cpus)
147 .map(|_| Vec::<StringMatch>::with_capacity(max_results.min(candidates.len())))
148 .collect::<Vec<_>>();
149
150 let mut config = nucleo::Config::DEFAULT;
151 config.prefer_prefix = true; // TODO: consider making this a setting
152 let mut matchers = matcher::get_matchers(num_cpus, config);
153
154 executor
155 .scoped(|scope| {
156 for (segment_idx, (results, matcher)) in segment_results
157 .iter_mut()
158 .zip(matchers.iter_mut())
159 .enumerate()
160 {
161 let cancel_flag = &cancel_flag;
162 let pattern = pattern.clone();
163 scope.spawn(async move {
164 let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
165 let segment_end = cmp::min(segment_start + segment_size, candidates.len());
166
167 for c in candidates[segment_start..segment_end].iter() {
168 if cancel_flag.load(Ordering::Relaxed) {
169 break;
170 }
171 let candidate = c.borrow();
172 let mut indices = Vec::new();
173 let mut buf = Vec::new();
174 if let Some(score) = pattern.indices(
175 nucleo::Utf32Str::new(&candidate.string, &mut buf),
176 matcher,
177 &mut indices,
178 ) {
179 let length_modifier = candidate.string.chars().count() as f64 / 10_000.;
180 results.push(StringMatch {
181 candidate_id: candidate.id,
182 score: score as f64
183 + if prefer_shorter {
184 -length_modifier
185 } else {
186 length_modifier
187 },
188
189 // TODO: need to convert indices/positions from char offsets to byte offsets.
190 positions: indices.into_iter().map(|n| n as usize).collect(),
191 string: candidate.string.clone(),
192 })
193 };
194 }
195 });
196 }
197 })
198 .await;
199
200 matcher::return_matchers(matchers);
201
202 if cancel_flag.load(Ordering::Relaxed) {
203 return Vec::new();
204 }
205
206 let mut results = segment_results.concat();
207 util::truncate_to_bottom_n_sorted(&mut results, max_results);
208 for r in &mut results {
209 r.positions.sort();
210 }
211 results
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    use gpui::TestAppContext;

    /// Runs `match_strings` over `candidates` with smart-case enabled and a limit of 100
    /// results, using each candidate's slice index as its id. `penalize_length` is passed as
    /// the `prefer_shorter` argument.
    async fn get_matches(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        penalize_length: bool,
    ) -> Vec<StringMatch> {
        let candidates: Vec<_> = candidates
            .iter()
            .enumerate()
            .map(|(i, s)| StringMatchCandidate::new(i, s))
            .collect();

        let cancellation_flag = AtomicBool::new(false);
        let executor = cx.background_executor.clone();
        cx.foreground_executor
            .spawn(async move {
                super::match_strings(
                    &candidates,
                    query,
                    true,
                    penalize_length,
                    100,
                    &cancellation_flag,
                    executor,
                )
                .await
            })
            .await
    }

    /// Returns the matched strings in result order.
    async fn string_matches(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        penalize_length: bool,
    ) -> Vec<String> {
        get_matches(cx, candidates, query, penalize_length)
            .await
            .into_iter()
            .map(|string_match| string_match.string)
            .collect()
    }

    /// Returns the positions of the best match; panics when nothing matched.
    async fn match_positions(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        penalize_length: bool,
    ) -> Vec<usize> {
        get_matches(cx, candidates, query, penalize_length)
            .await
            .into_iter()
            .next()
            .expect("expected at least one match")
            .positions
    }

    #[gpui::test]
    async fn prefer_shorter_matches(cx: &mut TestAppContext) {
        let candidates = &["a", "aa", "aaa"];
        assert_eq!(
            string_matches(cx, candidates, "a", true).await,
            ["a", "aa", "aaa"]
        );
    }

    #[gpui::test]
    async fn prefer_longer_matches(cx: &mut TestAppContext) {
        let candidates = &["unreachable", "unreachable!()"];
        assert_eq!(
            string_matches(cx, candidates, "unreac", false).await,
            ["unreachable!()", "unreachable",]
        );
    }

    #[gpui::test]
    async fn shorter_over_lexicographical(cx: &mut TestAppContext) {
        const CANDIDATES: &[&str] = &["qr", "qqqqqqqqqqqq"];
        assert_eq!(
            string_matches(cx, CANDIDATES, "q", true).await,
            ["qr", "qqqqqqqqqqqq"]
        );
    }

    #[gpui::test]
    async fn indices_are_sorted_and_correct(cx: &mut TestAppContext) {
        const CANDIDATES: &[&str] = &["hello how are you"];
        assert_eq!(
            match_positions(cx, CANDIDATES, "you hello", true).await,
            vec![0, 1, 2, 3, 4, 14, 15, 16]
        );

        // TODO: add a case for a long path candidate such as
        // "crates/livekit_api/vendored/protocol/README.md".
    }

    // NOTE(review): the original author flagged this expectation as possibly broken ("This is
    // broken?"). It pins the ordering for the `lens` query; confirm the intended ranking.
    #[gpui::test]
    async fn broken_nucleo_matcher(cx: &mut TestAppContext) {
        let candidates = &["lsp_code_lens", "code_lens"];
        assert_eq!(
            string_matches(cx, candidates, "lens", false).await,
            ["code_lens", "lsp_code_lens",]
        );
    }
}