parsing.rs

  1use anyhow::{anyhow, Ok, Result};
  2use language::{Grammar, Language};
  3use std::{
  4    cmp::{self, Reverse},
  5    collections::HashSet,
  6    ops::Range,
  7    path::Path,
  8    sync::Arc,
  9};
 10use tree_sitter::{Parser, QueryCursor};
 11
 12#[derive(Debug, PartialEq, Clone)]
 13pub struct Document {
 14    pub name: String,
 15    pub range: Range<usize>,
 16    pub content: String,
 17    pub embedding: Vec<f32>,
 18}
 19
 20const CODE_CONTEXT_TEMPLATE: &str =
 21    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 22const ENTIRE_FILE_TEMPLATE: &str =
 23    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 24pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
 25
 26pub struct CodeContextRetriever {
 27    pub parser: Parser,
 28    pub cursor: QueryCursor,
 29}
 30
 31// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
 32// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
 33// If there are preceeding comments, we track this with a context capture
 34// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
 35// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
 36#[derive(Debug, Clone)]
 37pub struct CodeContextMatch {
 38    pub start_col: usize,
 39    pub item_range: Option<Range<usize>>,
 40    pub name_range: Option<Range<usize>>,
 41    pub context_ranges: Vec<Range<usize>>,
 42    pub collapse_ranges: Vec<Range<usize>>,
 43}
 44
 45impl CodeContextRetriever {
 46    pub fn new() -> Self {
 47        Self {
 48            parser: Parser::new(),
 49            cursor: QueryCursor::new(),
 50        }
 51    }
 52
 53    fn parse_entire_file(
 54        &self,
 55        relative_path: &Path,
 56        language_name: Arc<str>,
 57        content: &str,
 58    ) -> Result<Vec<Document>> {
 59        let document_span = ENTIRE_FILE_TEMPLATE
 60            .replace("<path>", relative_path.to_string_lossy().as_ref())
 61            .replace("<language>", language_name.as_ref())
 62            .replace("item", &content);
 63
 64        Ok(vec![Document {
 65            range: 0..content.len(),
 66            content: document_span,
 67            embedding: Vec::new(),
 68            name: language_name.to_string(),
 69        }])
 70    }
 71
 72    fn get_matches_in_file(
 73        &mut self,
 74        content: &str,
 75        grammar: &Arc<Grammar>,
 76    ) -> Result<Vec<CodeContextMatch>> {
 77        let embedding_config = grammar
 78            .embedding_config
 79            .as_ref()
 80            .ok_or_else(|| anyhow!("no embedding queries"))?;
 81        self.parser.set_language(grammar.ts_language).unwrap();
 82
 83        let tree = self
 84            .parser
 85            .parse(&content, None)
 86            .ok_or_else(|| anyhow!("parsing failed"))?;
 87
 88        let mut captures: Vec<CodeContextMatch> = Vec::new();
 89        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
 90        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
 91        for mat in self.cursor.matches(
 92            &embedding_config.query,
 93            tree.root_node(),
 94            content.as_bytes(),
 95        ) {
 96            let mut start_col = 0;
 97            let mut item_range: Option<Range<usize>> = None;
 98            let mut name_range: Option<Range<usize>> = None;
 99            let mut context_ranges: Vec<Range<usize>> = Vec::new();
100            collapse_ranges.clear();
101            keep_ranges.clear();
102            for capture in mat.captures {
103                if capture.index == embedding_config.item_capture_ix {
104                    item_range = Some(capture.node.byte_range());
105                    start_col = capture.node.start_position().column;
106                } else if Some(capture.index) == embedding_config.name_capture_ix {
107                    name_range = Some(capture.node.byte_range());
108                } else if Some(capture.index) == embedding_config.context_capture_ix {
109                    context_ranges.push(capture.node.byte_range());
110                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
111                    collapse_ranges.push(capture.node.byte_range());
112                } else if Some(capture.index) == embedding_config.keep_capture_ix {
113                    keep_ranges.push(capture.node.byte_range());
114                }
115            }
116
117            captures.push(CodeContextMatch {
118                start_col,
119                item_range,
120                name_range,
121                context_ranges,
122                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
123            });
124        }
125        Ok(captures)
126    }
127
128    pub fn parse_file_with_template(
129        &mut self,
130        relative_path: &Path,
131        content: &str,
132        language: Arc<Language>,
133    ) -> Result<Vec<Document>> {
134        let language_name = language.name();
135
136        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
137            return self.parse_entire_file(relative_path, language_name, &content);
138        }
139
140        let mut documents = self.parse_file(content, language)?;
141        for document in &mut documents {
142            document.content = CODE_CONTEXT_TEMPLATE
143                .replace("<path>", relative_path.to_string_lossy().as_ref())
144                .replace("<language>", language_name.as_ref())
145                .replace("item", &document.content);
146        }
147        Ok(documents)
148    }
149
150    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
151        let grammar = language
152            .grammar()
153            .ok_or_else(|| anyhow!("no grammar for language"))?;
154
155        // Iterate through query matches
156        let matches = self.get_matches_in_file(content, grammar)?;
157
158        let language_scope = language.default_scope();
159        let placeholder = language_scope.collapsed_placeholder();
160
161        let mut documents = Vec::new();
162        let mut collapsed_ranges_within = Vec::new();
163        let mut parsed_name_ranges = HashSet::new();
164        for (i, context_match) in matches.iter().enumerate() {
165            // Items which are collapsible but not embeddable have no item range
166            let item_range = if let Some(item_range) = context_match.item_range.clone() {
167                item_range
168            } else {
169                continue;
170            };
171
172            // Checks for deduplication
173            let name;
174            if let Some(name_range) = context_match.name_range.clone() {
175                name = content
176                    .get(name_range.clone())
177                    .map_or(String::new(), |s| s.to_string());
178                if parsed_name_ranges.contains(&name_range) {
179                    continue;
180                }
181                parsed_name_ranges.insert(name_range);
182            } else {
183                name = String::new();
184            }
185
186            collapsed_ranges_within.clear();
187            'outer: for remaining_match in &matches[(i + 1)..] {
188                for collapsed_range in &remaining_match.collapse_ranges {
189                    if item_range.start <= collapsed_range.start
190                        && item_range.end >= collapsed_range.end
191                    {
192                        collapsed_ranges_within.push(collapsed_range.clone());
193                    } else {
194                        break 'outer;
195                    }
196                }
197            }
198
199            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
200
201            let mut document_content = String::new();
202            for context_range in &context_match.context_ranges {
203                document_content.push_str(&content[context_range.clone()]);
204                document_content.push_str("\n");
205            }
206
207            let mut offset = item_range.start;
208            for collapsed_range in &collapsed_ranges_within {
209                if collapsed_range.start > offset {
210                    add_content_from_range(
211                        &mut document_content,
212                        content,
213                        offset..collapsed_range.start,
214                        context_match.start_col,
215                    );
216                    offset = collapsed_range.start;
217                }
218
219                if collapsed_range.end > offset {
220                    document_content.push_str(placeholder);
221                    offset = collapsed_range.end;
222                }
223            }
224
225            if offset < item_range.end {
226                add_content_from_range(
227                    &mut document_content,
228                    content,
229                    offset..item_range.end,
230                    context_match.start_col,
231                );
232            }
233
234            documents.push(Document {
235                name,
236                content: document_content,
237                range: item_range.clone(),
238                embedding: vec![],
239            })
240        }
241
242        return Ok(documents);
243    }
244}
245
/// Removes `ranges_to_subtract` from `ranges`, returning the remaining pieces.
///
/// Both slices are consumed in a single forward pass, so each is expected to be
/// sorted by position. Every returned range lies within one of the input
/// `ranges`; fully covered ranges produce nothing.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            match ranges_to_subtract.peek() {
                Some(range_to_subtract) => {
                    if offset < range_to_subtract.start {
                        // Keep everything up to the start of the next subtracted range.
                        let next_offset = cmp::min(range_to_subtract.start, range.end);
                        result.push(offset..next_offset);
                        offset = next_offset;
                    } else {
                        // Skip the subtracted part. Never move backwards: a subtracted
                        // range ending before `offset` must not widen the output.
                        let next_offset = cmp::min(range_to_subtract.end, range.end);
                        offset = cmp::max(offset, next_offset);
                    }

                    // This subtracted range is exhausted; advance to the next one.
                    if offset >= range_to_subtract.end {
                        ranges_to_subtract.next();
                    }
                }
                None => {
                    result.push(offset..range.end);
                    offset = range.end;
                }
            }
        }
    }

    result
}
280
/// Appends `content[range]` to `output`, removing up to `start_col` leading
/// spaces from every line.
///
/// Lines are joined with `\n` and no trailing newline is appended. If `range`
/// is empty, out of bounds, or not on character boundaries, `output` is left
/// untouched.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let text = content.get(range).unwrap_or("");
    for (line_ix, line) in text.lines().enumerate() {
        // Write the newline as a separator so nothing ever has to be popped
        // back off `output` (which would eat earlier text when no line exists).
        if line_ix > 0 {
            output.push('\n');
        }
        // Spaces are single-byte ASCII, so slicing at `indent` is always valid.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|byte| *byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}