search.rs

  1use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
  2use anyhow::Result;
  3use language::{char_kind, Rope};
  4use regex::{Regex, RegexBuilder};
  5use smol::future::yield_now;
  6use std::{
  7    io::{BufRead, BufReader, Read},
  8    ops::Range,
  9    sync::Arc,
 10};
 11
 12#[derive(Clone)]
 13pub enum SearchQuery {
 14    Text {
 15        search: Arc<AhoCorasick<usize>>,
 16        query: Arc<str>,
 17        whole_word: bool,
 18        case_sensitive: bool,
 19    },
 20    Regex {
 21        regex: Regex,
 22        query: Arc<str>,
 23        multiline: bool,
 24        whole_word: bool,
 25        case_sensitive: bool,
 26    },
 27}
 28
 29impl SearchQuery {
 30    pub fn text(query: impl ToString, whole_word: bool, case_sensitive: bool) -> Self {
 31        let query = query.to_string();
 32        let search = AhoCorasickBuilder::new()
 33            .auto_configure(&[&query])
 34            .ascii_case_insensitive(!case_sensitive)
 35            .build(&[&query]);
 36        Self::Text {
 37            search: Arc::new(search),
 38            query: Arc::from(query),
 39            whole_word,
 40            case_sensitive,
 41        }
 42    }
 43
 44    pub fn regex(query: impl ToString, whole_word: bool, case_sensitive: bool) -> Result<Self> {
 45        let mut query = query.to_string();
 46        let initial_query = Arc::from(query.as_str());
 47        if whole_word {
 48            let mut word_query = String::new();
 49            word_query.push_str("\\b");
 50            word_query.push_str(&query);
 51            word_query.push_str("\\b");
 52            query = word_query
 53        }
 54
 55        let multiline = query.contains("\n") || query.contains("\\n");
 56        let regex = RegexBuilder::new(&query)
 57            .case_insensitive(!case_sensitive)
 58            .multi_line(multiline)
 59            .build()?;
 60        Ok(Self::Regex {
 61            regex,
 62            query: initial_query,
 63            multiline,
 64            whole_word,
 65            case_sensitive,
 66        })
 67    }
 68
 69    pub fn detect<T: Read>(&self, stream: T) -> Result<bool> {
 70        if self.as_str().is_empty() {
 71            return Ok(false);
 72        }
 73
 74        match self {
 75            Self::Text { search, .. } => {
 76                let mat = search.stream_find_iter(stream).next();
 77                match mat {
 78                    Some(Ok(_)) => Ok(true),
 79                    Some(Err(err)) => Err(err.into()),
 80                    None => Ok(false),
 81                }
 82            }
 83            Self::Regex {
 84                regex, multiline, ..
 85            } => {
 86                let mut reader = BufReader::new(stream);
 87                if *multiline {
 88                    let mut text = String::new();
 89                    if let Err(err) = reader.read_to_string(&mut text) {
 90                        Err(err.into())
 91                    } else {
 92                        Ok(regex.find(&text).is_some())
 93                    }
 94                } else {
 95                    for line in reader.lines() {
 96                        let line = line?;
 97                        if regex.find(&line).is_some() {
 98                            return Ok(true);
 99                        }
100                    }
101                    Ok(false)
102                }
103            }
104        }
105    }
106
107    pub async fn search(&self, rope: &Rope) -> Vec<Range<usize>> {
108        const YIELD_INTERVAL: usize = 20000;
109
110        if self.as_str().is_empty() {
111            return Default::default();
112        }
113
114        let mut matches = Vec::new();
115        match self {
116            Self::Text {
117                search, whole_word, ..
118            } => {
119                for (ix, mat) in search
120                    .stream_find_iter(rope.bytes_in_range(0..rope.len()))
121                    .enumerate()
122                {
123                    if (ix + 1) % YIELD_INTERVAL == 0 {
124                        yield_now().await;
125                    }
126
127                    let mat = mat.unwrap();
128                    if *whole_word {
129                        let prev_kind = rope.reversed_chars_at(mat.start()).next().map(char_kind);
130                        let start_kind = char_kind(rope.chars_at(mat.start()).next().unwrap());
131                        let end_kind = char_kind(rope.reversed_chars_at(mat.end()).next().unwrap());
132                        let next_kind = rope.chars_at(mat.end()).next().map(char_kind);
133                        if Some(start_kind) == prev_kind || Some(end_kind) == next_kind {
134                            continue;
135                        }
136                    }
137                    matches.push(mat.start()..mat.end())
138                }
139            }
140            Self::Regex {
141                regex, multiline, ..
142            } => {
143                if *multiline {
144                    let text = rope.to_string();
145                    for (ix, mat) in regex.find_iter(&text).enumerate() {
146                        if (ix + 1) % YIELD_INTERVAL == 0 {
147                            yield_now().await;
148                        }
149
150                        matches.push(mat.start()..mat.end());
151                    }
152                } else {
153                    let mut line = String::new();
154                    let mut line_offset = 0;
155                    for (chunk_ix, chunk) in rope.chunks().chain(["\n"]).enumerate() {
156                        if (chunk_ix + 1) % YIELD_INTERVAL == 0 {
157                            yield_now().await;
158                        }
159
160                        for (newline_ix, text) in chunk.split('\n').enumerate() {
161                            if newline_ix > 0 {
162                                for mat in regex.find_iter(&line) {
163                                    let start = line_offset + mat.start();
164                                    let end = line_offset + mat.end();
165                                    matches.push(start..end);
166                                }
167
168                                line_offset += line.len() + 1;
169                                line.clear();
170                            }
171                            line.push_str(text);
172                        }
173                    }
174                }
175            }
176        }
177        matches
178    }
179
180    pub fn as_str(&self) -> &str {
181        match self {
182            Self::Text { query, .. } => query.as_ref(),
183            Self::Regex { query, .. } => query.as_ref(),
184        }
185    }
186
187    pub fn whole_word(&self) -> bool {
188        match self {
189            Self::Text { whole_word, .. } => *whole_word,
190            Self::Regex { whole_word, .. } => *whole_word,
191        }
192    }
193
194    pub fn case_sensitive(&self) -> bool {
195        match self {
196            Self::Text { case_sensitive, .. } => *case_sensitive,
197            Self::Regex { case_sensitive, .. } => *case_sensitive,
198        }
199    }
200
201    pub fn is_regex(&self) -> bool {
202        matches!(self, Self::Regex { .. })
203    }
204}