parsing.rs

  1use anyhow::{anyhow, Ok, Result};
  2use language::{Grammar, Language};
  3use std::{
  4    cmp::{self, Reverse},
  5    collections::HashSet,
  6    ops::Range,
  7    path::Path,
  8    sync::Arc,
  9};
 10use tree_sitter::{Parser, QueryCursor};
 11
 12#[derive(Debug, PartialEq, Clone)]
 13pub struct Document {
 14    pub name: String,
 15    pub range: Range<usize>,
 16    pub content: String,
 17    pub embedding: Vec<f32>,
 18}
 19
 20const CODE_CONTEXT_TEMPLATE: &str =
 21    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 22const ENTIRE_FILE_TEMPLATE: &str =
 23    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 24const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
 25pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
 26    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
 27
 28pub struct CodeContextRetriever {
 29    pub parser: Parser,
 30    pub cursor: QueryCursor,
 31}
 32
 33// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
 34// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
 35// If there are preceeding comments, we track this with a context capture
 36// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
 37// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
 38#[derive(Debug, Clone)]
 39pub struct CodeContextMatch {
 40    pub start_col: usize,
 41    pub item_range: Option<Range<usize>>,
 42    pub name_range: Option<Range<usize>>,
 43    pub context_ranges: Vec<Range<usize>>,
 44    pub collapse_ranges: Vec<Range<usize>>,
 45}
 46
 47impl CodeContextRetriever {
 48    pub fn new() -> Self {
 49        Self {
 50            parser: Parser::new(),
 51            cursor: QueryCursor::new(),
 52        }
 53    }
 54
 55    fn parse_entire_file(
 56        &self,
 57        relative_path: &Path,
 58        language_name: Arc<str>,
 59        content: &str,
 60    ) -> Result<Vec<Document>> {
 61        let document_span = ENTIRE_FILE_TEMPLATE
 62            .replace("<path>", relative_path.to_string_lossy().as_ref())
 63            .replace("<language>", language_name.as_ref())
 64            .replace("<item>", &content);
 65
 66        Ok(vec![Document {
 67            range: 0..content.len(),
 68            content: document_span,
 69            embedding: Vec::new(),
 70            name: language_name.to_string(),
 71        }])
 72    }
 73
 74    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
 75        let document_span = MARKDOWN_CONTEXT_TEMPLATE
 76            .replace("<path>", relative_path.to_string_lossy().as_ref())
 77            .replace("<item>", &content);
 78
 79        Ok(vec![Document {
 80            range: 0..content.len(),
 81            content: document_span,
 82            embedding: Vec::new(),
 83            name: "Markdown".to_string(),
 84        }])
 85    }
 86
 87    fn get_matches_in_file(
 88        &mut self,
 89        content: &str,
 90        grammar: &Arc<Grammar>,
 91    ) -> Result<Vec<CodeContextMatch>> {
 92        let embedding_config = grammar
 93            .embedding_config
 94            .as_ref()
 95            .ok_or_else(|| anyhow!("no embedding queries"))?;
 96        self.parser.set_language(grammar.ts_language).unwrap();
 97
 98        let tree = self
 99            .parser
100            .parse(&content, None)
101            .ok_or_else(|| anyhow!("parsing failed"))?;
102
103        let mut captures: Vec<CodeContextMatch> = Vec::new();
104        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
105        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
106        for mat in self.cursor.matches(
107            &embedding_config.query,
108            tree.root_node(),
109            content.as_bytes(),
110        ) {
111            let mut start_col = 0;
112            let mut item_range: Option<Range<usize>> = None;
113            let mut name_range: Option<Range<usize>> = None;
114            let mut context_ranges: Vec<Range<usize>> = Vec::new();
115            collapse_ranges.clear();
116            keep_ranges.clear();
117            for capture in mat.captures {
118                if capture.index == embedding_config.item_capture_ix {
119                    item_range = Some(capture.node.byte_range());
120                    start_col = capture.node.start_position().column;
121                } else if Some(capture.index) == embedding_config.name_capture_ix {
122                    name_range = Some(capture.node.byte_range());
123                } else if Some(capture.index) == embedding_config.context_capture_ix {
124                    context_ranges.push(capture.node.byte_range());
125                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
126                    collapse_ranges.push(capture.node.byte_range());
127                } else if Some(capture.index) == embedding_config.keep_capture_ix {
128                    keep_ranges.push(capture.node.byte_range());
129                }
130            }
131
132            captures.push(CodeContextMatch {
133                start_col,
134                item_range,
135                name_range,
136                context_ranges,
137                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
138            });
139        }
140        Ok(captures)
141    }
142
143    pub fn parse_file_with_template(
144        &mut self,
145        relative_path: &Path,
146        content: &str,
147        language: Arc<Language>,
148    ) -> Result<Vec<Document>> {
149        let language_name = language.name();
150
151        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
152            return self.parse_entire_file(relative_path, language_name, &content);
153        } else if &language_name.to_string() == &"Markdown".to_string() {
154            return self.parse_markdown_file(relative_path, &content);
155        }
156
157        let mut documents = self.parse_file(content, language)?;
158        for document in &mut documents {
159            document.content = CODE_CONTEXT_TEMPLATE
160                .replace("<path>", relative_path.to_string_lossy().as_ref())
161                .replace("<language>", language_name.as_ref())
162                .replace("item", &document.content);
163        }
164        Ok(documents)
165    }
166
167    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
168        let grammar = language
169            .grammar()
170            .ok_or_else(|| anyhow!("no grammar for language"))?;
171
172        // Iterate through query matches
173        let matches = self.get_matches_in_file(content, grammar)?;
174
175        let language_scope = language.default_scope();
176        let placeholder = language_scope.collapsed_placeholder();
177
178        let mut documents = Vec::new();
179        let mut collapsed_ranges_within = Vec::new();
180        let mut parsed_name_ranges = HashSet::new();
181        for (i, context_match) in matches.iter().enumerate() {
182            // Items which are collapsible but not embeddable have no item range
183            let item_range = if let Some(item_range) = context_match.item_range.clone() {
184                item_range
185            } else {
186                continue;
187            };
188
189            // Checks for deduplication
190            let name;
191            if let Some(name_range) = context_match.name_range.clone() {
192                name = content
193                    .get(name_range.clone())
194                    .map_or(String::new(), |s| s.to_string());
195                if parsed_name_ranges.contains(&name_range) {
196                    continue;
197                }
198                parsed_name_ranges.insert(name_range);
199            } else {
200                name = String::new();
201            }
202
203            collapsed_ranges_within.clear();
204            'outer: for remaining_match in &matches[(i + 1)..] {
205                for collapsed_range in &remaining_match.collapse_ranges {
206                    if item_range.start <= collapsed_range.start
207                        && item_range.end >= collapsed_range.end
208                    {
209                        collapsed_ranges_within.push(collapsed_range.clone());
210                    } else {
211                        break 'outer;
212                    }
213                }
214            }
215
216            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
217
218            let mut document_content = String::new();
219            for context_range in &context_match.context_ranges {
220                add_content_from_range(
221                    &mut document_content,
222                    content,
223                    context_range.clone(),
224                    context_match.start_col,
225                );
226                document_content.push_str("\n");
227            }
228
229            let mut offset = item_range.start;
230            for collapsed_range in &collapsed_ranges_within {
231                if collapsed_range.start > offset {
232                    add_content_from_range(
233                        &mut document_content,
234                        content,
235                        offset..collapsed_range.start,
236                        context_match.start_col,
237                    );
238                    offset = collapsed_range.start;
239                }
240
241                if collapsed_range.end > offset {
242                    document_content.push_str(placeholder);
243                    offset = collapsed_range.end;
244                }
245            }
246
247            if offset < item_range.end {
248                add_content_from_range(
249                    &mut document_content,
250                    content,
251                    offset..item_range.end,
252                    context_match.start_col,
253                );
254            }
255
256            documents.push(Document {
257                name,
258                content: document_content,
259                range: item_range.clone(),
260                embedding: vec![],
261            })
262        }
263
264        return Ok(documents);
265    }
266}
267
/// Removes `ranges_to_subtract` from `ranges` and returns the remaining pieces in order.
///
/// Both slices are walked once from left to right, so each is expected to be sorted by
/// position. A subtract range may span several input ranges, and subtract ranges that lie
/// entirely before the current position are skipped without affecting the output.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
                if offset < range_to_subtract.start {
                    // Keep the piece in front of the next subtract range.
                    let next_offset = cmp::min(range_to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                } else {
                    // Skip over the subtracted part. The `max` keeps `offset` from moving
                    // backwards when the subtract range ends before the current position,
                    // which would otherwise emit text outside of `range`.
                    let next_offset = cmp::min(range_to_subtract.end, range.end);
                    offset = cmp::max(offset, next_offset);
                }

                // Advance only once the subtract range is fully consumed; it may still
                // overlap the next input range.
                if offset >= range_to_subtract.end {
                    ranges_to_subtract.next();
                }
            } else {
                result.push(offset..range.end);
                offset = range.end;
            }
        }
    }

    result
}
302
/// Appends the text of `content[range]` to `output`, removing up to `start_col` leading
/// spaces from every line.
///
/// Lines are joined with `\n` and no trailing newline is appended. When the range is empty,
/// out of bounds, or not on a character boundary, nothing is appended and `output` is left
/// untouched.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let text = content.get(range).unwrap_or("");
    for (line_ix, line) in text.lines().enumerate() {
        // Write the separator before every line but the first, so nothing ever has to be
        // popped afterwards (popping would eat existing output when there are no lines).
        if line_ix > 0 {
            output.push('\n');
        }

        // Spaces are single-byte, so the count is always a valid slice boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}