parsing.rs

  1use anyhow::{anyhow, Ok, Result};
  2use language::{Grammar, Language};
  3use sha1::{Digest, Sha1};
  4use std::{
  5    cmp::{self, Reverse},
  6    collections::HashSet,
  7    ops::Range,
  8    path::Path,
  9    sync::Arc,
 10};
 11use tree_sitter::{Parser, QueryCursor};
 12
/// A unit of text extracted from a file by [`CodeContextRetriever`].
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
    /// Display name: the text of the item's `name` capture for code items (empty if there is
    /// none), the language name for whole-file documents, or "Markdown" for Markdown files.
    pub name: String,
    /// Byte range in the source file covered by this document (`0..content.len()` for
    /// whole-file documents, the item's node range for code items).
    pub range: Range<usize>,
    /// Rendered document text, possibly wrapped in one of the context templates.
    pub content: String,
    /// Embedding vector; the parser always leaves this empty.
    pub embedding: Vec<f32>,
    /// SHA-1 digest computed by the parser over the document text.
    pub sha1: [u8; 20],
}
 21
/// Wrapper for per-item code documents. Placeholders: `<path>`, `<language>`, `<item>`.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Wrapper for whole-file documents of the languages listed in `PARSEABLE_ENTIRE_FILE_TYPES`.
/// Placeholders: `<path>`, `<language>`, `<item>`.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Wrapper for whole Markdown files; the file text is inserted without a code fence.
/// Placeholders: `<path>`, `<item>`.
// NOTE(review): "contents is" is ungrammatical, but the rendered template text is hashed into
// `Document::sha1`, so rewording it would change every Markdown digest. Left as-is on purpose.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Language names (compared against `Language::name()`) whose files are turned into a single
/// whole-file document instead of being split into items by the embedding query.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
 29
/// Extracts embeddable [`Document`]s from source files using tree-sitter.
///
/// The parser and query cursor are kept in the struct so they can be reused across calls.
pub struct CodeContextRetriever {
    /// Parser whose language is set to the file's grammar before every parse.
    pub parser: Parser,
    /// Cursor used to run the grammar's embedding query over the parsed tree.
    pub cursor: QueryCursor,
}
 34
/// One match of a grammar's embedding query, reduced to the byte ranges of its captures.
///
/// - The `item` capture is the fundamental tree-sitter node that anchors the search. Matches for
///   nodes that are collapsible but not embeddable have no item capture, so `item_range` is
///   optional.
/// - The `name` capture gives the item's display range and is used for deduplication. If a
///   match yields several name captures, the last one wins.
/// - `context` captures track preceding pieces, such as preceding comments, that are emitted
///   before the item.
/// - `collapse` captures mark pieces that should be collapsed in hierarchical queries, i.e.
///   replaced by the language's placeholder when an enclosing item is rendered.
/// - `keep` captures mark pieces inside a collapsed node that should stay visible. They are not
///   stored here; they are subtracted from `collapse_ranges` when the match is built.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Column at which the item node starts; up to this many leading spaces are stripped from
    /// every emitted line.
    pub start_col: usize,
    /// Byte range of the `item` capture, if any.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the (last) `name` capture, if any.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of all `context` captures, in capture order.
    pub context_ranges: Vec<Range<usize>>,
    /// Byte ranges of `collapse` captures with all `keep` ranges removed.
    pub collapse_ranges: Vec<Range<usize>>,
}
 48
 49impl CodeContextRetriever {
 50    pub fn new() -> Self {
 51        Self {
 52            parser: Parser::new(),
 53            cursor: QueryCursor::new(),
 54        }
 55    }
 56
 57    fn parse_entire_file(
 58        &self,
 59        relative_path: &Path,
 60        language_name: Arc<str>,
 61        content: &str,
 62    ) -> Result<Vec<Document>> {
 63        let document_span = ENTIRE_FILE_TEMPLATE
 64            .replace("<path>", relative_path.to_string_lossy().as_ref())
 65            .replace("<language>", language_name.as_ref())
 66            .replace("<item>", &content);
 67
 68        let mut sha1 = Sha1::new();
 69        sha1.update(&document_span);
 70
 71        Ok(vec![Document {
 72            range: 0..content.len(),
 73            content: document_span,
 74            embedding: Vec::new(),
 75            name: language_name.to_string(),
 76            sha1: sha1.finalize().into(),
 77        }])
 78    }
 79
 80    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
 81        let document_span = MARKDOWN_CONTEXT_TEMPLATE
 82            .replace("<path>", relative_path.to_string_lossy().as_ref())
 83            .replace("<item>", &content);
 84
 85        let mut sha1 = Sha1::new();
 86        sha1.update(&document_span);
 87
 88        Ok(vec![Document {
 89            range: 0..content.len(),
 90            content: document_span,
 91            embedding: Vec::new(),
 92            name: "Markdown".to_string(),
 93            sha1: sha1.finalize().into(),
 94        }])
 95    }
 96
 97    fn get_matches_in_file(
 98        &mut self,
 99        content: &str,
100        grammar: &Arc<Grammar>,
101    ) -> Result<Vec<CodeContextMatch>> {
102        let embedding_config = grammar
103            .embedding_config
104            .as_ref()
105            .ok_or_else(|| anyhow!("no embedding queries"))?;
106        self.parser.set_language(grammar.ts_language).unwrap();
107
108        let tree = self
109            .parser
110            .parse(&content, None)
111            .ok_or_else(|| anyhow!("parsing failed"))?;
112
113        let mut captures: Vec<CodeContextMatch> = Vec::new();
114        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
115        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
116        for mat in self.cursor.matches(
117            &embedding_config.query,
118            tree.root_node(),
119            content.as_bytes(),
120        ) {
121            let mut start_col = 0;
122            let mut item_range: Option<Range<usize>> = None;
123            let mut name_range: Option<Range<usize>> = None;
124            let mut context_ranges: Vec<Range<usize>> = Vec::new();
125            collapse_ranges.clear();
126            keep_ranges.clear();
127            for capture in mat.captures {
128                if capture.index == embedding_config.item_capture_ix {
129                    item_range = Some(capture.node.byte_range());
130                    start_col = capture.node.start_position().column;
131                } else if Some(capture.index) == embedding_config.name_capture_ix {
132                    name_range = Some(capture.node.byte_range());
133                } else if Some(capture.index) == embedding_config.context_capture_ix {
134                    context_ranges.push(capture.node.byte_range());
135                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
136                    collapse_ranges.push(capture.node.byte_range());
137                } else if Some(capture.index) == embedding_config.keep_capture_ix {
138                    keep_ranges.push(capture.node.byte_range());
139                }
140            }
141
142            captures.push(CodeContextMatch {
143                start_col,
144                item_range,
145                name_range,
146                context_ranges,
147                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
148            });
149        }
150        Ok(captures)
151    }
152
153    pub fn parse_file_with_template(
154        &mut self,
155        relative_path: &Path,
156        content: &str,
157        language: Arc<Language>,
158    ) -> Result<Vec<Document>> {
159        let language_name = language.name();
160
161        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
162            return self.parse_entire_file(relative_path, language_name, &content);
163        } else if &language_name.to_string() == &"Markdown".to_string() {
164            return self.parse_markdown_file(relative_path, &content);
165        }
166
167        let mut documents = self.parse_file(content, language)?;
168        for document in &mut documents {
169            document.content = CODE_CONTEXT_TEMPLATE
170                .replace("<path>", relative_path.to_string_lossy().as_ref())
171                .replace("<language>", language_name.as_ref())
172                .replace("item", &document.content);
173        }
174        Ok(documents)
175    }
176
177    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
178        let grammar = language
179            .grammar()
180            .ok_or_else(|| anyhow!("no grammar for language"))?;
181
182        // Iterate through query matches
183        let matches = self.get_matches_in_file(content, grammar)?;
184
185        let language_scope = language.default_scope();
186        let placeholder = language_scope.collapsed_placeholder();
187
188        let mut documents = Vec::new();
189        let mut collapsed_ranges_within = Vec::new();
190        let mut parsed_name_ranges = HashSet::new();
191        for (i, context_match) in matches.iter().enumerate() {
192            // Items which are collapsible but not embeddable have no item range
193            let item_range = if let Some(item_range) = context_match.item_range.clone() {
194                item_range
195            } else {
196                continue;
197            };
198
199            // Checks for deduplication
200            let name;
201            if let Some(name_range) = context_match.name_range.clone() {
202                name = content
203                    .get(name_range.clone())
204                    .map_or(String::new(), |s| s.to_string());
205                if parsed_name_ranges.contains(&name_range) {
206                    continue;
207                }
208                parsed_name_ranges.insert(name_range);
209            } else {
210                name = String::new();
211            }
212
213            collapsed_ranges_within.clear();
214            'outer: for remaining_match in &matches[(i + 1)..] {
215                for collapsed_range in &remaining_match.collapse_ranges {
216                    if item_range.start <= collapsed_range.start
217                        && item_range.end >= collapsed_range.end
218                    {
219                        collapsed_ranges_within.push(collapsed_range.clone());
220                    } else {
221                        break 'outer;
222                    }
223                }
224            }
225
226            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
227
228            let mut document_content = String::new();
229            for context_range in &context_match.context_ranges {
230                add_content_from_range(
231                    &mut document_content,
232                    content,
233                    context_range.clone(),
234                    context_match.start_col,
235                );
236                document_content.push_str("\n");
237            }
238
239            let mut offset = item_range.start;
240            for collapsed_range in &collapsed_ranges_within {
241                if collapsed_range.start > offset {
242                    add_content_from_range(
243                        &mut document_content,
244                        content,
245                        offset..collapsed_range.start,
246                        context_match.start_col,
247                    );
248                    offset = collapsed_range.start;
249                }
250
251                if collapsed_range.end > offset {
252                    document_content.push_str(placeholder);
253                    offset = collapsed_range.end;
254                }
255            }
256
257            if offset < item_range.end {
258                add_content_from_range(
259                    &mut document_content,
260                    content,
261                    offset..item_range.end,
262                    context_match.start_col,
263                );
264            }
265
266            let mut sha1 = Sha1::new();
267            sha1.update(&document_content);
268
269            documents.push(Document {
270                name,
271                content: document_content,
272                range: item_range.clone(),
273                embedding: vec![],
274                sha1: sha1.finalize().into(),
275            })
276        }
277
278        return Ok(documents);
279    }
280}
281
/// Removes every range in `ranges_to_subtract` from the ranges in `ranges` and returns the
/// remaining non-empty pieces, in order.
///
/// The single-pass algorithm assumes both slices are sorted by start position and that
/// `ranges_to_subtract` contains no overlapping ranges.
///
/// A subtracted range that ends before the current position is skipped without moving the
/// position backwards. One example is a range lying in the gap between two input ranges.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
                if offset < range_to_subtract.start {
                    // Keep the gap before the next subtracted range.
                    let next_offset = cmp::min(range_to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                } else {
                    // Skip over the subtracted range. `max` prevents a subtracted range that
                    // ended before `offset` from moving `offset` backwards, which would
                    // re-emit bytes outside `range`.
                    offset = cmp::max(offset, cmp::min(range_to_subtract.end, range.end));
                }

                if offset >= range_to_subtract.end {
                    ranges_to_subtract.next();
                }
            } else {
                result.push(offset..range.end);
                offset = range.end;
            }
        }
    }

    result
}
316
/// Appends the text of `content[range]` to `output`.
///
/// Up to `start_col` leading spaces are removed from every line. The lines are joined with
/// `\n`, and no trailing newline is added.
///
/// Lines are split with [`str::lines`], so a trailing newline in the range is dropped and
/// `\r\n` endings become `\n`. If `range` is empty, out of bounds, or not on a char boundary,
/// nothing is appended. Previously written output is never modified.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let text = content.get(range).unwrap_or("");
    for (index, line) in text.lines().enumerate() {
        // Separator goes *between* lines, so an empty range never touches `output`.
        if index > 0 {
            output.push('\n');
        }
        // Dedent by at most `start_col` ASCII spaces; any other character stops the strip.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}