search.rs

  1use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
  2use anyhow::{Context, Result};
  3use client::proto;
  4use itertools::Itertools;
  5use language::{char_kind, BufferSnapshot};
  6use regex::{Captures, Regex, RegexBuilder};
  7use smol::future::yield_now;
  8use std::{
  9    borrow::Cow,
 10    io::{BufRead, BufReader, Read},
 11    ops::Range,
 12    path::Path,
 13    sync::{Arc, OnceLock},
 14};
 15use util::paths::PathMatcher;
 16
/// Lazily-initialised regex that recognises the escape sequences `\\`, `\n` and `\t`
/// inside a replacement string; it is compiled on first use by
/// `SearchQuery::replacement_for`.
static TEXT_REPLACEMENT_SPECIAL_CHARACTERS_REGEX: OnceLock<Regex> = OnceLock::new();
 18
 19#[derive(Clone, Debug)]
 20pub struct SearchInputs {
 21    query: Arc<str>,
 22    files_to_include: Vec<PathMatcher>,
 23    files_to_exclude: Vec<PathMatcher>,
 24}
 25
 26impl SearchInputs {
 27    pub fn as_str(&self) -> &str {
 28        self.query.as_ref()
 29    }
 30    pub fn files_to_include(&self) -> &[PathMatcher] {
 31        &self.files_to_include
 32    }
 33    pub fn files_to_exclude(&self) -> &[PathMatcher] {
 34        &self.files_to_exclude
 35    }
 36}
/// A compiled search query, either a literal-text search or a regular-expression
/// search, together with the options and inputs it was created from.
#[derive(Clone, Debug)]
pub enum SearchQuery {
    /// Literal text search backed by an Aho-Corasick automaton.
    Text {
        /// Automaton built from the query text.
        search: Arc<AhoCorasick>,
        /// Replacement text; `None` until `with_replacement` is called.
        replacement: Option<String>,
        /// When set, `search` drops matches that are not delimited by a change
        /// of character kind on both sides.
        whole_word: bool,
        /// When `false`, the automaton is built ASCII-case-insensitive.
        case_sensitive: bool,
        // NOTE(review): not consumed in this file; presumably tells callers
        // whether ignored files should be searched too — confirm at call sites.
        include_ignored: bool,
        /// The original query text and path matchers.
        inner: SearchInputs,
    },

    /// Regular-expression search.
    Regex {
        /// Compiled pattern; when `whole_word` was requested the query is
        /// wrapped in word-boundary assertions before compilation.
        regex: Regex,
        /// Replacement template; `None` until `with_replacement` is called.
        replacement: Option<String>,
        /// `true` when the pattern contains a newline or the two-character
        /// sequence `\n`; selects whole-text instead of line-by-line matching.
        multiline: bool,
        /// Whether the pattern was wrapped in word-boundary assertions.
        whole_word: bool,
        /// When `false`, the regex is compiled case-insensitive.
        case_sensitive: bool,
        // NOTE(review): not consumed in this file; presumably tells callers
        // whether ignored files should be searched too — confirm at call sites.
        include_ignored: bool,
        /// The original (unwrapped) query text and path matchers.
        inner: SearchInputs,
    },
}
 58
 59impl SearchQuery {
 60    pub fn text(
 61        query: impl ToString,
 62        whole_word: bool,
 63        case_sensitive: bool,
 64        include_ignored: bool,
 65        files_to_include: Vec<PathMatcher>,
 66        files_to_exclude: Vec<PathMatcher>,
 67    ) -> Result<Self> {
 68        let query = query.to_string();
 69        let search = AhoCorasickBuilder::new()
 70            .ascii_case_insensitive(!case_sensitive)
 71            .build(&[&query])?;
 72        let inner = SearchInputs {
 73            query: query.into(),
 74            files_to_exclude,
 75            files_to_include,
 76        };
 77        Ok(Self::Text {
 78            search: Arc::new(search),
 79            replacement: None,
 80            whole_word,
 81            case_sensitive,
 82            include_ignored,
 83            inner,
 84        })
 85    }
 86
 87    pub fn regex(
 88        query: impl ToString,
 89        whole_word: bool,
 90        case_sensitive: bool,
 91        include_ignored: bool,
 92        files_to_include: Vec<PathMatcher>,
 93        files_to_exclude: Vec<PathMatcher>,
 94    ) -> Result<Self> {
 95        let mut query = query.to_string();
 96        let initial_query = Arc::from(query.as_str());
 97        if whole_word {
 98            let mut word_query = String::new();
 99            word_query.push_str("\\b");
100            word_query.push_str(&query);
101            word_query.push_str("\\b");
102            query = word_query
103        }
104
105        let multiline = query.contains('\n') || query.contains("\\n");
106        let regex = RegexBuilder::new(&query)
107            .case_insensitive(!case_sensitive)
108            .multi_line(multiline)
109            .build()?;
110        let inner = SearchInputs {
111            query: initial_query,
112            files_to_exclude,
113            files_to_include,
114        };
115        Ok(Self::Regex {
116            regex,
117            replacement: None,
118            multiline,
119            whole_word,
120            case_sensitive,
121            include_ignored,
122            inner,
123        })
124    }
125
126    pub fn from_proto(message: proto::SearchProject) -> Result<Self> {
127        if message.regex {
128            Self::regex(
129                message.query,
130                message.whole_word,
131                message.case_sensitive,
132                message.include_ignored,
133                deserialize_path_matches(&message.files_to_include)?,
134                deserialize_path_matches(&message.files_to_exclude)?,
135            )
136        } else {
137            Self::text(
138                message.query,
139                message.whole_word,
140                message.case_sensitive,
141                message.include_ignored,
142                deserialize_path_matches(&message.files_to_include)?,
143                deserialize_path_matches(&message.files_to_exclude)?,
144            )
145        }
146    }
147    pub fn with_replacement(mut self, new_replacement: String) -> Self {
148        match self {
149            Self::Text {
150                ref mut replacement,
151                ..
152            }
153            | Self::Regex {
154                ref mut replacement,
155                ..
156            } => {
157                *replacement = Some(new_replacement);
158                self
159            }
160        }
161    }
162    pub fn to_proto(&self, project_id: u64) -> proto::SearchProject {
163        proto::SearchProject {
164            project_id,
165            query: self.as_str().to_string(),
166            regex: self.is_regex(),
167            whole_word: self.whole_word(),
168            case_sensitive: self.case_sensitive(),
169            include_ignored: self.include_ignored(),
170            files_to_include: self
171                .files_to_include()
172                .iter()
173                .map(|matcher| matcher.to_string())
174                .join(","),
175            files_to_exclude: self
176                .files_to_exclude()
177                .iter()
178                .map(|matcher| matcher.to_string())
179                .join(","),
180        }
181    }
182
183    pub fn detect<T: Read>(&self, stream: T) -> Result<bool> {
184        if self.as_str().is_empty() {
185            return Ok(false);
186        }
187
188        match self {
189            Self::Text { search, .. } => {
190                let mat = search.stream_find_iter(stream).next();
191                match mat {
192                    Some(Ok(_)) => Ok(true),
193                    Some(Err(err)) => Err(err.into()),
194                    None => Ok(false),
195                }
196            }
197            Self::Regex {
198                regex, multiline, ..
199            } => {
200                let mut reader = BufReader::new(stream);
201                if *multiline {
202                    let mut text = String::new();
203                    if let Err(err) = reader.read_to_string(&mut text) {
204                        Err(err.into())
205                    } else {
206                        Ok(regex.find(&text).is_some())
207                    }
208                } else {
209                    for line in reader.lines() {
210                        let line = line?;
211                        if regex.find(&line).is_some() {
212                            return Ok(true);
213                        }
214                    }
215                    Ok(false)
216                }
217            }
218        }
219    }
220    /// Returns the replacement text for this `SearchQuery`.
221    pub fn replacement(&self) -> Option<&str> {
222        match self {
223            SearchQuery::Text { replacement, .. } | SearchQuery::Regex { replacement, .. } => {
224                replacement.as_deref()
225            }
226        }
227    }
228    /// Replaces search hits if replacement is set. `text` is assumed to be a string that matches this `SearchQuery` exactly, without any leftovers on either side.
229    pub fn replacement_for<'a>(&self, text: &'a str) -> Option<Cow<'a, str>> {
230        match self {
231            SearchQuery::Text { replacement, .. } => replacement.clone().map(Cow::from),
232            SearchQuery::Regex {
233                regex, replacement, ..
234            } => {
235                if let Some(replacement) = replacement {
236                    let replacement = TEXT_REPLACEMENT_SPECIAL_CHARACTERS_REGEX
237                        .get_or_init(|| Regex::new(r"\\\\|\\n|\\t").unwrap())
238                        .replace_all(replacement, |c: &Captures| {
239                            match c.get(0).unwrap().as_str() {
240                                r"\\" => "\\",
241                                r"\n" => "\n",
242                                r"\t" => "\t",
243                                x => unreachable!("Unexpected escape sequence: {}", x),
244                            }
245                        });
246                    Some(regex.replace(text, replacement))
247                } else {
248                    None
249                }
250            }
251        }
252    }
253
    /// Finds every match of this query in `buffer`, optionally restricted to
    /// `subrange`.
    ///
    /// Returns the matched ranges in the order they were found. The ranges are
    /// offsets into the searched text, i.e. relative to the start of `subrange`
    /// when one is given; this function does not add `subrange.start` back.
    /// An empty query yields no matches.
    ///
    /// The future periodically yields to the executor (every `YIELD_INTERVAL`
    /// matches, or chunks for line-by-line regex search) so long searches do
    /// not monopolise it.
    ///
    /// # Panics
    ///
    /// Panics if the Aho-Corasick stream iterator reports an error while
    /// reading from the rope.
    pub async fn search(
        &self,
        buffer: &BufferSnapshot,
        subrange: Option<Range<usize>>,
    ) -> Vec<Range<usize>> {
        // Number of loop iterations between cooperative yields.
        const YIELD_INTERVAL: usize = 20000;

        if self.as_str().is_empty() {
            return Default::default();
        }

        // Offset of the searched slice within the buffer; only needed to look
        // up the language scope, which is addressed in buffer coordinates.
        let range_offset = subrange.as_ref().map(|r| r.start).unwrap_or(0);
        let rope = if let Some(range) = subrange {
            buffer.as_rope().slice(range)
        } else {
            buffer.as_rope().clone()
        };

        let mut matches = Vec::new();
        match self {
            Self::Text {
                search, whole_word, ..
            } => {
                for (ix, mat) in search
                    .stream_find_iter(rope.bytes_in_range(0..rope.len()))
                    .enumerate()
                {
                    if (ix + 1) % YIELD_INTERVAL == 0 {
                        yield_now().await;
                    }

                    let mat = mat.unwrap();
                    if *whole_word {
                        let scope = buffer.language_scope_at(range_offset + mat.start());
                        let kind = |c| char_kind(&scope, c);

                        // Compare the character kinds just outside the match
                        // with those of its first and last characters: if the
                        // kind does not change at an edge, the match continues
                        // into a neighbouring "word" and is rejected.
                        let prev_kind = rope.reversed_chars_at(mat.start()).next().map(kind);
                        let start_kind = kind(rope.chars_at(mat.start()).next().unwrap());
                        let end_kind = kind(rope.reversed_chars_at(mat.end()).next().unwrap());
                        let next_kind = rope.chars_at(mat.end()).next().map(kind);
                        if Some(start_kind) == prev_kind || Some(end_kind) == next_kind {
                            continue;
                        }
                    }
                    matches.push(mat.start()..mat.end())
                }
            }

            Self::Regex {
                regex, multiline, ..
            } => {
                if *multiline {
                    // Multiline patterns need the whole text in one string.
                    let text = rope.to_string();
                    for (ix, mat) in regex.find_iter(&text).enumerate() {
                        if (ix + 1) % YIELD_INTERVAL == 0 {
                            yield_now().await;
                        }

                        matches.push(mat.start()..mat.end());
                    }
                } else {
                    // Reassemble lines from the rope's chunks and match each
                    // completed line on its own. `line_offset` is the offset of
                    // the current line's first byte within the searched text.
                    let mut line = String::new();
                    let mut line_offset = 0;
                    // The extra trailing "\n" chunk flushes the final line even
                    // when the text does not end with a newline.
                    for (chunk_ix, chunk) in rope.chunks().chain(["\n"]).enumerate() {
                        if (chunk_ix + 1) % YIELD_INTERVAL == 0 {
                            yield_now().await;
                        }

                        for (newline_ix, text) in chunk.split('\n').enumerate() {
                            // Every piece after the first follows a '\n' in
                            // this chunk, so the accumulated line is complete.
                            if newline_ix > 0 {
                                for mat in regex.find_iter(&line) {
                                    let start = line_offset + mat.start();
                                    let end = line_offset + mat.end();
                                    matches.push(start..end);
                                }

                                // +1 accounts for the newline that was split off.
                                line_offset += line.len() + 1;
                                line.clear();
                            }
                            line.push_str(text);
                        }
                    }
                }
            }
        }

        matches
    }
342
343    pub fn is_empty(&self) -> bool {
344        self.as_str().is_empty()
345    }
346
347    pub fn as_str(&self) -> &str {
348        self.as_inner().as_str()
349    }
350
351    pub fn whole_word(&self) -> bool {
352        match self {
353            Self::Text { whole_word, .. } => *whole_word,
354            Self::Regex { whole_word, .. } => *whole_word,
355        }
356    }
357
358    pub fn case_sensitive(&self) -> bool {
359        match self {
360            Self::Text { case_sensitive, .. } => *case_sensitive,
361            Self::Regex { case_sensitive, .. } => *case_sensitive,
362        }
363    }
364
365    pub fn include_ignored(&self) -> bool {
366        match self {
367            Self::Text {
368                include_ignored, ..
369            } => *include_ignored,
370            Self::Regex {
371                include_ignored, ..
372            } => *include_ignored,
373        }
374    }
375
376    pub fn is_regex(&self) -> bool {
377        matches!(self, Self::Regex { .. })
378    }
379
380    pub fn files_to_include(&self) -> &[PathMatcher] {
381        self.as_inner().files_to_include()
382    }
383
384    pub fn files_to_exclude(&self) -> &[PathMatcher] {
385        self.as_inner().files_to_exclude()
386    }
387
388    pub fn file_matches(&self, file_path: Option<&Path>) -> bool {
389        match file_path {
390            Some(file_path) => {
391                let mut path = file_path.to_path_buf();
392                loop {
393                    if self
394                        .files_to_exclude()
395                        .iter()
396                        .any(|exclude_glob| exclude_glob.is_match(&path))
397                    {
398                        return false;
399                    } else if self.files_to_include().is_empty()
400                        || self
401                            .files_to_include()
402                            .iter()
403                            .any(|include_glob| include_glob.is_match(&path))
404                    {
405                        return true;
406                    } else if !path.pop() {
407                        return false;
408                    }
409                }
410            }
411            None => self.files_to_include().is_empty(),
412        }
413    }
414    pub fn as_inner(&self) -> &SearchInputs {
415        match self {
416            Self::Regex { inner, .. } | Self::Text { inner, .. } => inner,
417        }
418    }
419}
420
421fn deserialize_path_matches(glob_set: &str) -> anyhow::Result<Vec<PathMatcher>> {
422    glob_set
423        .split(',')
424        .map(str::trim)
425        .filter(|glob_str| !glob_str.is_empty())
426        .map(|glob_str| {
427            PathMatcher::new(glob_str)
428                .with_context(|| format!("deserializing path match glob {glob_str}"))
429        })
430        .collect()
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Every listed path must be accepted by `PathMatcher::new` and must match
    /// itself.
    #[test]
    fn path_matcher_creation_for_valid_paths() {
        let valid_paths = [
            "file",
            "Cargo.toml",
            ".DS_Store",
            "~/dir/another_dir/",
            "./dir/file",
            "dir/[a-z].txt",
            "../dir/filé",
        ];
        for valid_path in valid_paths {
            let path_matcher = match PathMatcher::new(valid_path) {
                Ok(matcher) => matcher,
                Err(e) => panic!("Valid path {valid_path} should be accepted, but got: {e}"),
            };
            assert!(
                path_matcher.is_match(valid_path),
                "Path matcher for valid path {valid_path} should match itself"
            );
        }
    }

    /// Malformed globs must be rejected and well-formed globs accepted.
    #[test]
    fn path_matcher_creation_for_globs() {
        for invalid_glob in ["dir/[].txt", "dir/[a-z.txt", "dir/{file"] {
            if PathMatcher::new(invalid_glob).is_ok() {
                panic!("Invalid glob {invalid_glob} should not be accepted");
            }
        }

        let valid_globs = [
            "dir/?ile",
            "dir/*.txt",
            "dir/**/file",
            "dir/[a-z].txt",
            "{dir,file}",
        ];
        for valid_glob in valid_globs {
            if let Err(e) = PathMatcher::new(valid_glob) {
                panic!("Valid glob {valid_glob} should be accepted, but got: {e}");
            }
        }
    }
}