strings.rs

  1use crate::{CharBag, matcher};
  2use gpui::BackgroundExecutor;
  3use nucleo::pattern::{AtomKind, CaseMatching, Normalization, Pattern};
  4use std::{
  5    borrow::Borrow,
  6    cmp, iter,
  7    ops::Range,
  8    sync::atomic::{AtomicBool, Ordering},
  9};
 10
/// A string that can be fuzzy-matched against a query by [`match_strings`].
#[derive(Clone, Debug)]
pub struct StringMatchCandidate {
    /// Caller-chosen identifier; copied into `StringMatch::candidate_id` for every match.
    pub id: usize,
    /// The text the query is matched against.
    pub string: String,
    /// Character bag built from `string` by [`StringMatchCandidate::new`].
    // NOTE(review): nothing in this file reads `char_bag`; presumably it is used by other
    // matchers in the crate as a cheap pre-filter — confirm before removing.
    pub char_bag: CharBag,
}
 17
 18impl StringMatchCandidate {
 19    pub fn new(id: usize, string: &str) -> Self {
 20        Self {
 21            id,
 22            string: string.into(),
 23            char_bag: string.into(),
 24        }
 25    }
 26}
 27
/// A single successful match produced by [`match_strings`].
#[derive(Clone, Debug)]
pub struct StringMatch {
    /// The `id` of the `StringMatchCandidate` that produced this match.
    pub candidate_id: usize,
    /// Match score; the `Ord` implementation sorts higher scores first.
    pub score: f64,
    /// Offsets of the matched characters within `string`. `ranges` interprets these as
    /// byte offsets that fall on utf-8 character boundaries.
    pub positions: Vec<usize>,
    /// A copy of the candidate's text.
    pub string: String,
}
 35
 36impl StringMatch {
 37    pub fn ranges(&self) -> impl '_ + Iterator<Item = Range<usize>> {
 38        let mut positions = self.positions.iter().peekable();
 39        iter::from_fn(move || {
 40            if let Some(start) = positions.next().copied() {
 41                let Some(char_len) = self.char_len_at_index(start) else {
 42                    log::error!(
 43                        "Invariant violation: Index {start} out of range or not on a utf-8 boundary in string {:?}",
 44                        self.string
 45                    );
 46                    return None;
 47                };
 48                let mut end = start + char_len;
 49                while let Some(next_start) = positions.peek() {
 50                    if end == **next_start {
 51                        let Some(char_len) = self.char_len_at_index(end) else {
 52                            log::error!(
 53                                "Invariant violation: Index {end} out of range or not on a utf-8 boundary in string {:?}",
 54                                self.string
 55                            );
 56                            return None;
 57                        };
 58                        end += char_len;
 59                        positions.next();
 60                    } else {
 61                        break;
 62                    }
 63                }
 64
 65                return Some(start..end);
 66            }
 67            None
 68        })
 69    }
 70
 71    /// Gets the byte length of the utf-8 character at a byte offset. If the index is out of range
 72    /// or not on a utf-8 boundary then None is returned.
 73    fn char_len_at_index(&self, ix: usize) -> Option<usize> {
 74        self.string
 75            .get(ix..)
 76            .and_then(|slice| slice.chars().next().map(|char| char.len_utf8()))
 77    }
 78}
 79
 80impl PartialEq for StringMatch {
 81    fn eq(&self, other: &Self) -> bool {
 82        self.cmp(other).is_eq()
 83    }
 84}
 85
 86impl Eq for StringMatch {}
 87
 88impl PartialOrd for StringMatch {
 89    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
 90        Some(self.cmp(other))
 91    }
 92}
 93
 94impl Ord for StringMatch {
 95    fn cmp(&self, other: &Self) -> cmp::Ordering {
 96        // dbg!(&self.string, self.score);
 97        // dbg!(&other.string, other.score);
 98        self.score
 99            .total_cmp(&other.score)
100            .reverse()
101            .then_with(|| self.string.cmp(&other.string))
102    }
103}
104
105pub async fn match_strings<T>(
106    candidates: &[T],
107    query: &str,
108    smart_case: bool,
109    prefer_shorter: bool,
110    max_results: usize,
111    cancel_flag: &AtomicBool,
112    executor: BackgroundExecutor,
113) -> Vec<StringMatch>
114where
115    T: Borrow<StringMatchCandidate> + Sync,
116{
117    if candidates.is_empty() || max_results == 0 {
118        return Default::default();
119    }
120    // FIXME should support fzf syntax with Pattern::parse
121    let pattern = Pattern::new(
122        query,
123        if smart_case {
124            CaseMatching::Smart
125        } else {
126            CaseMatching::Ignore
127        },
128        Normalization::Smart,
129        AtomKind::Fuzzy,
130    );
131
132    if query.is_empty() {
133        return candidates
134            .iter()
135            .map(|candidate| StringMatch {
136                candidate_id: candidate.borrow().id,
137                score: 0.,
138                positions: Default::default(),
139                string: candidate.borrow().string.clone(),
140            })
141            .collect();
142    }
143
144    let num_cpus = executor.num_cpus().min(candidates.len());
145    let segment_size = candidates.len().div_ceil(num_cpus);
146    let mut segment_results = (0..num_cpus)
147        .map(|_| Vec::<StringMatch>::with_capacity(max_results.min(candidates.len())))
148        .collect::<Vec<_>>();
149
150    let mut config = nucleo::Config::DEFAULT;
151    config.prefer_prefix = true; // TODO: consider making this a setting
152    let mut matchers = matcher::get_matchers(num_cpus, config);
153
154    executor
155        .scoped(|scope| {
156            for (segment_idx, (results, matcher)) in segment_results
157                .iter_mut()
158                .zip(matchers.iter_mut())
159                .enumerate()
160            {
161                let cancel_flag = &cancel_flag;
162                let pattern = pattern.clone();
163                scope.spawn(async move {
164                    let segment_start = cmp::min(segment_idx * segment_size, candidates.len());
165                    let segment_end = cmp::min(segment_start + segment_size, candidates.len());
166
167                    for c in candidates[segment_start..segment_end].iter() {
168                        if cancel_flag.load(Ordering::Relaxed) {
169                            break;
170                        }
171                        let candidate = c.borrow();
172                        let mut indices = Vec::new();
173                        let mut buf = Vec::new();
174                        if let Some(score) = pattern.indices(
175                            nucleo::Utf32Str::new(&candidate.string, &mut buf),
176                            matcher,
177                            &mut indices,
178                        ) {
179                            let length_modifier = candidate.string.chars().count() as f64 / 10_000.;
180                            results.push(StringMatch {
181                                candidate_id: candidate.id,
182                                score: score as f64
183                                    + if prefer_shorter {
184                                        -length_modifier
185                                    } else {
186                                        length_modifier
187                                    },
188
189                                // TODO: need to convert indices/positions from char offsets to byte offsets.
190                                positions: indices.into_iter().map(|n| n as usize).collect(),
191                                string: candidate.string.clone(),
192                            })
193                        };
194                    }
195                });
196            }
197        })
198        .await;
199
200    matcher::return_matchers(matchers);
201
202    if cancel_flag.load(Ordering::Relaxed) {
203        return Vec::new();
204    }
205
206    let mut results = segment_results.concat();
207    util::truncate_to_bottom_n_sorted(&mut results, max_results);
208    for r in &mut results {
209        r.positions.sort();
210    }
211    results
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    use gpui::TestAppContext;

    /// Runs `match_strings` over `candidates` with smart-case enabled and a result limit
    /// of 100, using each candidate's index as its id.
    async fn get_matches(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        prefer_shorter: bool,
    ) -> Vec<StringMatch> {
        let candidates: Vec<_> = candidates
            .iter()
            .enumerate()
            .map(|(i, s)| StringMatchCandidate::new(i, s))
            .collect();

        let cancellation_flag = AtomicBool::new(false);
        let executor = cx.background_executor.clone();
        cx.foreground_executor
            .spawn(async move {
                super::match_strings(
                    &candidates,
                    query,
                    true,
                    prefer_shorter,
                    100,
                    &cancellation_flag,
                    executor,
                )
                .await
            })
            .await
    }

    /// Returns the matched strings in result order.
    async fn string_matches(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        prefer_shorter: bool,
    ) -> Vec<String> {
        get_matches(cx, candidates, query, prefer_shorter)
            .await
            .into_iter()
            .map(|string_match| string_match.string)
            .collect()
    }

    /// Returns the positions of the first (best) match; panics if there is no match.
    async fn match_positions(
        cx: &mut TestAppContext,
        candidates: &[&'static str],
        query: &'static str,
        prefer_shorter: bool,
    ) -> Vec<usize> {
        get_matches(cx, candidates, query, prefer_shorter)
            .await
            .into_iter()
            .next()
            .expect("expected at least one match")
            .positions
    }

    #[gpui::test]
    async fn prefer_shorter_matches(cx: &mut TestAppContext) {
        let candidates = &["a", "aa", "aaa"];
        assert_eq!(
            string_matches(cx, candidates, "a", true).await,
            ["a", "aa", "aaa"]
        );
    }

    #[gpui::test]
    async fn prefer_longer_matches(cx: &mut TestAppContext) {
        let candidates = &["unreachable", "unreachable!()"];
        assert_eq!(
            string_matches(cx, candidates, "unreac", false).await,
            ["unreachable!()", "unreachable",]
        );
    }

    #[gpui::test]
    async fn shorter_over_lexicographical(cx: &mut TestAppContext) {
        const CANDIDATES: &[&str] = &["qr", "qqqqqqqqqqqq"];
        assert_eq!(
            string_matches(cx, CANDIDATES, "q", true).await,
            ["qr", "qqqqqqqqqqqq"]
        );
    }

    #[gpui::test]
    async fn indices_are_sorted_and_correct(cx: &mut TestAppContext) {
        const CANDIDATES: &[&str] = &["hello how are you"];
        assert_eq!(
            match_positions(cx, CANDIDATES, "you hello", true).await,
            vec![0, 1, 2, 3, 4, 14, 15, 16]
        );
    }

    // NOTE(review): the original author flagged this expectation as possibly broken.
    // With `prefer_shorter == false` the longer candidate receives the length bonus, yet
    // the shorter one is expected first; confirm whether this ordering is intended.
    #[gpui::test]
    async fn broken_nucleo_matcher(cx: &mut TestAppContext) {
        let candidates = &["lsp_code_lens", "code_lens"];
        assert_eq!(
            string_matches(cx, candidates, "lens", false).await,
            ["code_lens", "lsp_code_lens",]
        );
    }
}