search.rs

  1use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
  2use anyhow::Result;
  3use client::proto;
  4use language::{char_kind, Rope};
  5use regex::{Regex, RegexBuilder};
  6use smol::future::yield_now;
  7use std::{
  8    io::{BufRead, BufReader, Read},
  9    ops::Range,
 10    sync::Arc,
 11};
 12
 13#[derive(Clone)]
 14pub enum SearchQuery {
 15    Text {
 16        search: Arc<AhoCorasick<usize>>,
 17        query: Arc<str>,
 18        whole_word: bool,
 19        case_sensitive: bool,
 20    },
 21    Regex {
 22        regex: Regex,
 23        query: Arc<str>,
 24        multiline: bool,
 25        whole_word: bool,
 26        case_sensitive: bool,
 27    },
 28}
 29
 30impl SearchQuery {
 31    pub fn text(query: impl ToString, whole_word: bool, case_sensitive: bool) -> Self {
 32        let query = query.to_string();
 33        let search = AhoCorasickBuilder::new()
 34            .auto_configure(&[&query])
 35            .ascii_case_insensitive(!case_sensitive)
 36            .build(&[&query]);
 37        Self::Text {
 38            search: Arc::new(search),
 39            query: Arc::from(query),
 40            whole_word,
 41            case_sensitive,
 42        }
 43    }
 44
 45    pub fn regex(query: impl ToString, whole_word: bool, case_sensitive: bool) -> Result<Self> {
 46        let mut query = query.to_string();
 47        let initial_query = Arc::from(query.as_str());
 48        if whole_word {
 49            let mut word_query = String::new();
 50            word_query.push_str("\\b");
 51            word_query.push_str(&query);
 52            word_query.push_str("\\b");
 53            query = word_query
 54        }
 55
 56        let multiline = query.contains('\n') || query.contains("\\n");
 57        let regex = RegexBuilder::new(&query)
 58            .case_insensitive(!case_sensitive)
 59            .multi_line(multiline)
 60            .build()?;
 61        Ok(Self::Regex {
 62            regex,
 63            query: initial_query,
 64            multiline,
 65            whole_word,
 66            case_sensitive,
 67        })
 68    }
 69
 70    pub fn from_proto(message: proto::SearchProject) -> Result<Self> {
 71        if message.regex {
 72            Self::regex(message.query, message.whole_word, message.case_sensitive)
 73        } else {
 74            Ok(Self::text(
 75                message.query,
 76                message.whole_word,
 77                message.case_sensitive,
 78            ))
 79        }
 80    }
 81
 82    pub fn to_proto(&self, project_id: u64) -> proto::SearchProject {
 83        proto::SearchProject {
 84            project_id,
 85            query: self.as_str().to_string(),
 86            regex: self.is_regex(),
 87            whole_word: self.whole_word(),
 88            case_sensitive: self.case_sensitive(),
 89        }
 90    }
 91
 92    pub fn detect<T: Read>(&self, stream: T) -> Result<bool> {
 93        if self.as_str().is_empty() {
 94            return Ok(false);
 95        }
 96
 97        match self {
 98            Self::Text { search, .. } => {
 99                let mat = search.stream_find_iter(stream).next();
100                match mat {
101                    Some(Ok(_)) => Ok(true),
102                    Some(Err(err)) => Err(err.into()),
103                    None => Ok(false),
104                }
105            }
106            Self::Regex {
107                regex, multiline, ..
108            } => {
109                let mut reader = BufReader::new(stream);
110                if *multiline {
111                    let mut text = String::new();
112                    if let Err(err) = reader.read_to_string(&mut text) {
113                        Err(err.into())
114                    } else {
115                        Ok(regex.find(&text).is_some())
116                    }
117                } else {
118                    for line in reader.lines() {
119                        let line = line?;
120                        if regex.find(&line).is_some() {
121                            return Ok(true);
122                        }
123                    }
124                    Ok(false)
125                }
126            }
127        }
128    }
129
130    pub async fn search(&self, rope: &Rope) -> Vec<Range<usize>> {
131        const YIELD_INTERVAL: usize = 20000;
132
133        if self.as_str().is_empty() {
134            return Default::default();
135        }
136
137        let mut matches = Vec::new();
138        match self {
139            Self::Text {
140                search, whole_word, ..
141            } => {
142                for (ix, mat) in search
143                    .stream_find_iter(rope.bytes_in_range(0..rope.len()))
144                    .enumerate()
145                {
146                    if (ix + 1) % YIELD_INTERVAL == 0 {
147                        yield_now().await;
148                    }
149
150                    let mat = mat.unwrap();
151                    if *whole_word {
152                        let prev_kind = rope.reversed_chars_at(mat.start()).next().map(char_kind);
153                        let start_kind = char_kind(rope.chars_at(mat.start()).next().unwrap());
154                        let end_kind = char_kind(rope.reversed_chars_at(mat.end()).next().unwrap());
155                        let next_kind = rope.chars_at(mat.end()).next().map(char_kind);
156                        if Some(start_kind) == prev_kind || Some(end_kind) == next_kind {
157                            continue;
158                        }
159                    }
160                    matches.push(mat.start()..mat.end())
161                }
162            }
163            Self::Regex {
164                regex, multiline, ..
165            } => {
166                if *multiline {
167                    let text = rope.to_string();
168                    for (ix, mat) in regex.find_iter(&text).enumerate() {
169                        if (ix + 1) % YIELD_INTERVAL == 0 {
170                            yield_now().await;
171                        }
172
173                        matches.push(mat.start()..mat.end());
174                    }
175                } else {
176                    let mut line = String::new();
177                    let mut line_offset = 0;
178                    for (chunk_ix, chunk) in rope.chunks().chain(["\n"]).enumerate() {
179                        if (chunk_ix + 1) % YIELD_INTERVAL == 0 {
180                            yield_now().await;
181                        }
182
183                        for (newline_ix, text) in chunk.split('\n').enumerate() {
184                            if newline_ix > 0 {
185                                for mat in regex.find_iter(&line) {
186                                    let start = line_offset + mat.start();
187                                    let end = line_offset + mat.end();
188                                    matches.push(start..end);
189                                }
190
191                                line_offset += line.len() + 1;
192                                line.clear();
193                            }
194                            line.push_str(text);
195                        }
196                    }
197                }
198            }
199        }
200        matches
201    }
202
203    pub fn as_str(&self) -> &str {
204        match self {
205            Self::Text { query, .. } => query.as_ref(),
206            Self::Regex { query, .. } => query.as_ref(),
207        }
208    }
209
210    pub fn whole_word(&self) -> bool {
211        match self {
212            Self::Text { whole_word, .. } => *whole_word,
213            Self::Regex { whole_word, .. } => *whole_word,
214        }
215    }
216
217    pub fn case_sensitive(&self) -> bool {
218        match self {
219            Self::Text { case_sensitive, .. } => *case_sensitive,
220            Self::Regex { case_sensitive, .. } => *case_sensitive,
221        }
222    }
223
224    pub fn is_regex(&self) -> bool {
225        matches!(self, Self::Regex { .. })
226    }
227}