parsing.rs

  1use crate::embedding::EmbeddingProvider;
  2use anyhow::{anyhow, Ok, Result};
  3use language::{Grammar, Language};
  4use sha1::{Digest, Sha1};
  5use std::{
  6    cmp::{self, Reverse},
  7    collections::HashSet,
  8    ops::Range,
  9    path::Path,
 10    sync::Arc,
 11};
 12use tree_sitter::{Parser, QueryCursor};
 13
/// A single embeddable unit produced by `CodeContextRetriever`.
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
    /// Display name: the text of the item's name capture (empty when there is
    /// none), or the language name / `"Markdown"` for whole-file documents.
    pub name: String,
    /// Byte range within the original file content that this document covers.
    pub range: Range<usize>,
    /// Text of the document.
    pub content: String,
    /// Embedding vector; created empty by every constructor in this file.
    pub embedding: Vec<f32>,
    /// SHA-1 digest of the document text as it was when the document was built.
    pub sha1: [u8; 20],
    /// Token count reported by the embedding provider (0 until it is computed).
    pub token_count: usize,
}
 23
/// Template for a single code item. Placeholders: `<path>`, `<language>`, `<item>`.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template for files embedded as a whole. Placeholders: `<path>`, `<language>`, `<item>`.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template for Markdown files. Placeholders: `<path>`, `<item>`.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Language names whose files are embedded as one document rather than being
/// split into items by a tree-sitter query.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
 31
/// Splits source files into `Document`s using tree-sitter queries.
pub struct CodeContextRetriever {
    /// Tree-sitter parser; its language is set on every call that parses a file.
    pub parser: Parser,
    /// Cursor used to run the grammar's embedding query.
    pub cursor: QueryCursor,
    /// Provider used here to count the tokens of templated documents.
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
 37
// Every match has an item capture; it represents the fundamental tree-sitter symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track them with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
/// The captures of one embedding-query match, reduced to byte ranges.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Column at which the item node starts (0 when the match has no item capture).
    pub start_col: usize,
    /// Byte range of the item capture, if the match has one.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, if the match has one.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of the context captures, in capture order.
    pub context_ranges: Vec<Range<usize>>,
    /// Byte ranges of the collapse captures with the keep captures subtracted.
    pub collapse_ranges: Vec<Range<usize>>,
}
 51
 52impl CodeContextRetriever {
 53    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
 54        Self {
 55            parser: Parser::new(),
 56            cursor: QueryCursor::new(),
 57            embedding_provider,
 58        }
 59    }
 60
 61    fn parse_entire_file(
 62        &self,
 63        relative_path: &Path,
 64        language_name: Arc<str>,
 65        content: &str,
 66    ) -> Result<Vec<Document>> {
 67        let document_span = ENTIRE_FILE_TEMPLATE
 68            .replace("<path>", relative_path.to_string_lossy().as_ref())
 69            .replace("<language>", language_name.as_ref())
 70            .replace("<item>", &content);
 71
 72        let mut sha1 = Sha1::new();
 73        sha1.update(&document_span);
 74
 75        let token_count = self.embedding_provider.count_tokens(&document_span);
 76
 77        Ok(vec![Document {
 78            range: 0..content.len(),
 79            content: document_span,
 80            embedding: Vec::new(),
 81            name: language_name.to_string(),
 82            sha1: sha1.finalize().into(),
 83            token_count,
 84        }])
 85    }
 86
 87    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
 88        let document_span = MARKDOWN_CONTEXT_TEMPLATE
 89            .replace("<path>", relative_path.to_string_lossy().as_ref())
 90            .replace("<item>", &content);
 91
 92        let mut sha1 = Sha1::new();
 93        sha1.update(&document_span);
 94
 95        let token_count = self.embedding_provider.count_tokens(&document_span);
 96
 97        Ok(vec![Document {
 98            range: 0..content.len(),
 99            content: document_span,
100            embedding: Vec::new(),
101            name: "Markdown".to_string(),
102            sha1: sha1.finalize().into(),
103            token_count,
104        }])
105    }
106
107    fn get_matches_in_file(
108        &mut self,
109        content: &str,
110        grammar: &Arc<Grammar>,
111    ) -> Result<Vec<CodeContextMatch>> {
112        let embedding_config = grammar
113            .embedding_config
114            .as_ref()
115            .ok_or_else(|| anyhow!("no embedding queries"))?;
116        self.parser.set_language(grammar.ts_language).unwrap();
117
118        let tree = self
119            .parser
120            .parse(&content, None)
121            .ok_or_else(|| anyhow!("parsing failed"))?;
122
123        let mut captures: Vec<CodeContextMatch> = Vec::new();
124        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
125        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
126        for mat in self.cursor.matches(
127            &embedding_config.query,
128            tree.root_node(),
129            content.as_bytes(),
130        ) {
131            let mut start_col = 0;
132            let mut item_range: Option<Range<usize>> = None;
133            let mut name_range: Option<Range<usize>> = None;
134            let mut context_ranges: Vec<Range<usize>> = Vec::new();
135            collapse_ranges.clear();
136            keep_ranges.clear();
137            for capture in mat.captures {
138                if capture.index == embedding_config.item_capture_ix {
139                    item_range = Some(capture.node.byte_range());
140                    start_col = capture.node.start_position().column;
141                } else if Some(capture.index) == embedding_config.name_capture_ix {
142                    name_range = Some(capture.node.byte_range());
143                } else if Some(capture.index) == embedding_config.context_capture_ix {
144                    context_ranges.push(capture.node.byte_range());
145                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
146                    collapse_ranges.push(capture.node.byte_range());
147                } else if Some(capture.index) == embedding_config.keep_capture_ix {
148                    keep_ranges.push(capture.node.byte_range());
149                }
150            }
151
152            captures.push(CodeContextMatch {
153                start_col,
154                item_range,
155                name_range,
156                context_ranges,
157                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
158            });
159        }
160        Ok(captures)
161    }
162
163    pub fn parse_file_with_template(
164        &mut self,
165        relative_path: &Path,
166        content: &str,
167        language: Arc<Language>,
168    ) -> Result<Vec<Document>> {
169        let language_name = language.name();
170
171        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
172            return self.parse_entire_file(relative_path, language_name, &content);
173        } else if &language_name.to_string() == &"Markdown".to_string() {
174            return self.parse_markdown_file(relative_path, &content);
175        }
176
177        let mut documents = self.parse_file(content, language)?;
178        for document in &mut documents {
179            let document_content = CODE_CONTEXT_TEMPLATE
180                .replace("<path>", relative_path.to_string_lossy().as_ref())
181                .replace("<language>", language_name.as_ref())
182                .replace("item", &document.content);
183
184            let token_count = self.embedding_provider.count_tokens(&document_content);
185            document.content = document_content;
186            document.token_count = token_count;
187        }
188        Ok(documents)
189    }
190
191    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
192        let grammar = language
193            .grammar()
194            .ok_or_else(|| anyhow!("no grammar for language"))?;
195
196        // Iterate through query matches
197        let matches = self.get_matches_in_file(content, grammar)?;
198
199        let language_scope = language.default_scope();
200        let placeholder = language_scope.collapsed_placeholder();
201
202        let mut documents = Vec::new();
203        let mut collapsed_ranges_within = Vec::new();
204        let mut parsed_name_ranges = HashSet::new();
205        for (i, context_match) in matches.iter().enumerate() {
206            // Items which are collapsible but not embeddable have no item range
207            let item_range = if let Some(item_range) = context_match.item_range.clone() {
208                item_range
209            } else {
210                continue;
211            };
212
213            // Checks for deduplication
214            let name;
215            if let Some(name_range) = context_match.name_range.clone() {
216                name = content
217                    .get(name_range.clone())
218                    .map_or(String::new(), |s| s.to_string());
219                if parsed_name_ranges.contains(&name_range) {
220                    continue;
221                }
222                parsed_name_ranges.insert(name_range);
223            } else {
224                name = String::new();
225            }
226
227            collapsed_ranges_within.clear();
228            'outer: for remaining_match in &matches[(i + 1)..] {
229                for collapsed_range in &remaining_match.collapse_ranges {
230                    if item_range.start <= collapsed_range.start
231                        && item_range.end >= collapsed_range.end
232                    {
233                        collapsed_ranges_within.push(collapsed_range.clone());
234                    } else {
235                        break 'outer;
236                    }
237                }
238            }
239
240            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
241
242            let mut document_content = String::new();
243            for context_range in &context_match.context_ranges {
244                add_content_from_range(
245                    &mut document_content,
246                    content,
247                    context_range.clone(),
248                    context_match.start_col,
249                );
250                document_content.push_str("\n");
251            }
252
253            let mut offset = item_range.start;
254            for collapsed_range in &collapsed_ranges_within {
255                if collapsed_range.start > offset {
256                    add_content_from_range(
257                        &mut document_content,
258                        content,
259                        offset..collapsed_range.start,
260                        context_match.start_col,
261                    );
262                    offset = collapsed_range.start;
263                }
264
265                if collapsed_range.end > offset {
266                    document_content.push_str(placeholder);
267                    offset = collapsed_range.end;
268                }
269            }
270
271            if offset < item_range.end {
272                add_content_from_range(
273                    &mut document_content,
274                    content,
275                    offset..item_range.end,
276                    context_match.start_col,
277                );
278            }
279
280            let mut sha1 = Sha1::new();
281            sha1.update(&document_content);
282
283            documents.push(Document {
284                name,
285                content: document_content,
286                range: item_range.clone(),
287                embedding: vec![],
288                sha1: sha1.finalize().into(),
289                token_count: 0,
290            })
291        }
292
293        return Ok(documents);
294    }
295}
296
/// Removes every part of `ranges` that is covered by `ranges_to_subtract`.
///
/// Both slices are expected to be sorted by start and non-overlapping within
/// themselves. The result lists, in order, the pieces of each input range that
/// remain; a range that is only partly covered is split into several pieces.
///
/// Subtracted ranges that lie entirely before the current position are
/// skipped, so they never pull the position backwards and never widen the
/// output beyond the input range.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            let range_to_subtract = match ranges_to_subtract.peek().copied() {
                Some(range_to_subtract) => range_to_subtract,
                None => {
                    // Nothing left to subtract: the rest of the range survives.
                    result.push(offset..range.end);
                    break;
                }
            };

            if range_to_subtract.end <= offset {
                // Entirely behind the current position: irrelevant from now on.
                ranges_to_subtract.next();
                continue;
            }

            if range_to_subtract.start >= range.end {
                // Entirely ahead of this range: keep it for the following ranges.
                result.push(offset..range.end);
                break;
            }

            if offset < range_to_subtract.start {
                result.push(offset..range_to_subtract.start);
            }
            offset = cmp::min(range_to_subtract.end, range.end);

            // A subtracted range reaching past this range may still cut into
            // the next one, so it is consumed only once fully passed.
            if range_to_subtract.end <= range.end {
                ranges_to_subtract.next();
            }
        }
    }

    result
}
331
/// Appends the text of `content[range]` to `output`, removing up to
/// `start_col` leading spaces from every line.
///
/// Lines are joined with `\n` and no trailing newline is added. When `range`
/// is empty, out of bounds, or does not fall on character boundaries, nothing
/// is appended and `output` is left untouched.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let text = content.get(range).unwrap_or("");
    for (index, line) in text.lines().enumerate() {
        // Separate lines instead of appending a newline and popping it later,
        // so existing output is never truncated when there are no lines.
        if index > 0 {
            output.push('\n');
        }

        // Spaces are single bytes, so the count is a valid slice boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|byte| *byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}