parsing.rs

  1use crate::embedding::{Embedding, EmbeddingProvider};
  2use anyhow::{anyhow, Result};
  3use language::{Grammar, Language};
  4use rusqlite::{
  5    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
  6    ToSql,
  7};
  8use sha1::{Digest, Sha1};
  9use std::{
 10    cmp::{self, Reverse},
 11    collections::HashSet,
 12    ops::Range,
 13    path::Path,
 14    sync::Arc,
 15};
 16use tree_sitter::{Parser, QueryCursor};
 17
/// 20-byte SHA-1 digest of a span's text, used to identify and deduplicate
/// spans when they are stored in and read back from the database.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]);
 20
 21impl FromSql for SpanDigest {
 22    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 23        let blob = value.as_blob()?;
 24        let bytes =
 25            blob.try_into()
 26                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
 27                    expected_size: 20,
 28                    blob_size: blob.len(),
 29                })?;
 30        return Ok(SpanDigest(bytes));
 31    }
 32}
 33
impl ToSql for SpanDigest {
    /// Delegates to the underlying byte array's `ToSql` implementation,
    /// storing the digest as a 20-byte BLOB.
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}
 39
 40impl From<&'_ str> for SpanDigest {
 41    fn from(value: &'_ str) -> Self {
 42        let mut sha1 = Sha1::new();
 43        sha1.update(value);
 44        Self(sha1.finalize().into())
 45    }
 46}
 47
/// A chunk of a source file prepared for embedding.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    /// Display name of the captured item (empty when the match had no name capture).
    pub name: String,
    /// Byte range of the item within the original file.
    pub range: Range<usize>,
    /// Text of the span; after templating, the full document sent to the provider.
    pub content: String,
    /// Populated once the span has been embedded; `None` until then.
    pub embedding: Option<Embedding>,
    /// Content digest used for deduplication.
    pub digest: SpanDigest,
    /// Token count reported by the embedding provider's truncation.
    pub token_count: usize,
}
 57
/// Template for a single code snippet; placeholders: `<path>`, `<language>`, `<item>`.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template used when an entire file is embedded as one span.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template for Markdown files (no code fence or `<language>` tag).
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Languages that are embedded whole rather than split into syntactic spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
 65
/// Converts source files into embeddable [`Span`]s using tree-sitter
/// embedding queries. Holds a parser and query cursor that are reused
/// across files.
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
 71
// Every match has an item; this represents the fundamental tree-sitter symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Column at which the item starts; used to de-indent the span's text.
    pub start_col: usize,
    /// Byte range of the item capture; `None` for collapsible-only matches.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, used for deduplication.
    pub name_range: Option<Range<usize>>,
    /// Ranges of captured context (e.g. preceding comments) to prepend.
    pub context_ranges: Vec<Range<usize>>,
    /// Collapse ranges with any keep ranges already subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
 85
 86impl CodeContextRetriever {
 87    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
 88        Self {
 89            parser: Parser::new(),
 90            cursor: QueryCursor::new(),
 91            embedding_provider,
 92        }
 93    }
 94
 95    fn parse_entire_file(
 96        &self,
 97        relative_path: &Path,
 98        language_name: Arc<str>,
 99        content: &str,
100    ) -> Result<Vec<Span>> {
101        let document_span = ENTIRE_FILE_TEMPLATE
102            .replace("<path>", relative_path.to_string_lossy().as_ref())
103            .replace("<language>", language_name.as_ref())
104            .replace("<item>", &content);
105        let digest = SpanDigest::from(document_span.as_str());
106        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
107        Ok(vec![Span {
108            range: 0..content.len(),
109            content: document_span,
110            embedding: Default::default(),
111            name: language_name.to_string(),
112            digest,
113            token_count,
114        }])
115    }
116
117    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
118        let document_span = MARKDOWN_CONTEXT_TEMPLATE
119            .replace("<path>", relative_path.to_string_lossy().as_ref())
120            .replace("<item>", &content);
121        let digest = SpanDigest::from(document_span.as_str());
122        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
123        Ok(vec![Span {
124            range: 0..content.len(),
125            content: document_span,
126            embedding: None,
127            name: "Markdown".to_string(),
128            digest,
129            token_count,
130        }])
131    }
132
133    fn get_matches_in_file(
134        &mut self,
135        content: &str,
136        grammar: &Arc<Grammar>,
137    ) -> Result<Vec<CodeContextMatch>> {
138        let embedding_config = grammar
139            .embedding_config
140            .as_ref()
141            .ok_or_else(|| anyhow!("no embedding queries"))?;
142        self.parser.set_language(grammar.ts_language).unwrap();
143
144        let tree = self
145            .parser
146            .parse(&content, None)
147            .ok_or_else(|| anyhow!("parsing failed"))?;
148
149        let mut captures: Vec<CodeContextMatch> = Vec::new();
150        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
151        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
152        for mat in self.cursor.matches(
153            &embedding_config.query,
154            tree.root_node(),
155            content.as_bytes(),
156        ) {
157            let mut start_col = 0;
158            let mut item_range: Option<Range<usize>> = None;
159            let mut name_range: Option<Range<usize>> = None;
160            let mut context_ranges: Vec<Range<usize>> = Vec::new();
161            collapse_ranges.clear();
162            keep_ranges.clear();
163            for capture in mat.captures {
164                if capture.index == embedding_config.item_capture_ix {
165                    item_range = Some(capture.node.byte_range());
166                    start_col = capture.node.start_position().column;
167                } else if Some(capture.index) == embedding_config.name_capture_ix {
168                    name_range = Some(capture.node.byte_range());
169                } else if Some(capture.index) == embedding_config.context_capture_ix {
170                    context_ranges.push(capture.node.byte_range());
171                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
172                    collapse_ranges.push(capture.node.byte_range());
173                } else if Some(capture.index) == embedding_config.keep_capture_ix {
174                    keep_ranges.push(capture.node.byte_range());
175                }
176            }
177
178            captures.push(CodeContextMatch {
179                start_col,
180                item_range,
181                name_range,
182                context_ranges,
183                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
184            });
185        }
186        Ok(captures)
187    }
188
189    pub fn parse_file_with_template(
190        &mut self,
191        relative_path: &Path,
192        content: &str,
193        language: Arc<Language>,
194    ) -> Result<Vec<Span>> {
195        let language_name = language.name();
196
197        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
198            return self.parse_entire_file(relative_path, language_name, &content);
199        } else if language_name.as_ref() == "Markdown" {
200            return self.parse_markdown_file(relative_path, &content);
201        }
202
203        let mut spans = self.parse_file(content, language)?;
204        for span in &mut spans {
205            let document_content = CODE_CONTEXT_TEMPLATE
206                .replace("<path>", relative_path.to_string_lossy().as_ref())
207                .replace("<language>", language_name.as_ref())
208                .replace("item", &span.content);
209
210            let (document_content, token_count) =
211                self.embedding_provider.truncate(&document_content);
212
213            span.content = document_content;
214            span.token_count = token_count;
215        }
216        Ok(spans)
217    }
218
219    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
220        let grammar = language
221            .grammar()
222            .ok_or_else(|| anyhow!("no grammar for language"))?;
223
224        // Iterate through query matches
225        let matches = self.get_matches_in_file(content, grammar)?;
226
227        let language_scope = language.default_scope();
228        let placeholder = language_scope.collapsed_placeholder();
229
230        let mut spans = Vec::new();
231        let mut collapsed_ranges_within = Vec::new();
232        let mut parsed_name_ranges = HashSet::new();
233        for (i, context_match) in matches.iter().enumerate() {
234            // Items which are collapsible but not embeddable have no item range
235            let item_range = if let Some(item_range) = context_match.item_range.clone() {
236                item_range
237            } else {
238                continue;
239            };
240
241            // Checks for deduplication
242            let name;
243            if let Some(name_range) = context_match.name_range.clone() {
244                name = content
245                    .get(name_range.clone())
246                    .map_or(String::new(), |s| s.to_string());
247                if parsed_name_ranges.contains(&name_range) {
248                    continue;
249                }
250                parsed_name_ranges.insert(name_range);
251            } else {
252                name = String::new();
253            }
254
255            collapsed_ranges_within.clear();
256            'outer: for remaining_match in &matches[(i + 1)..] {
257                for collapsed_range in &remaining_match.collapse_ranges {
258                    if item_range.start <= collapsed_range.start
259                        && item_range.end >= collapsed_range.end
260                    {
261                        collapsed_ranges_within.push(collapsed_range.clone());
262                    } else {
263                        break 'outer;
264                    }
265                }
266            }
267
268            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
269
270            let mut span_content = String::new();
271            for context_range in &context_match.context_ranges {
272                add_content_from_range(
273                    &mut span_content,
274                    content,
275                    context_range.clone(),
276                    context_match.start_col,
277                );
278                span_content.push_str("\n");
279            }
280
281            let mut offset = item_range.start;
282            for collapsed_range in &collapsed_ranges_within {
283                if collapsed_range.start > offset {
284                    add_content_from_range(
285                        &mut span_content,
286                        content,
287                        offset..collapsed_range.start,
288                        context_match.start_col,
289                    );
290                    offset = collapsed_range.start;
291                }
292
293                if collapsed_range.end > offset {
294                    span_content.push_str(placeholder);
295                    offset = collapsed_range.end;
296                }
297            }
298
299            if offset < item_range.end {
300                add_content_from_range(
301                    &mut span_content,
302                    content,
303                    offset..item_range.end,
304                    context_match.start_col,
305                );
306            }
307
308            let sha1 = SpanDigest::from(span_content.as_str());
309            spans.push(Span {
310                name,
311                content: span_content,
312                range: item_range.clone(),
313                embedding: None,
314                digest: sha1,
315                token_count: 0,
316            })
317        }
318
319        return Ok(spans);
320    }
321}
322
/// Removes every interval in `ranges_to_subtract` from `ranges`, returning
/// the leftover pieces in order.
///
/// Both lists are consumed front-to-back, so they are expected in ascending
/// order; each subtracted interval is dropped once its end has been passed.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut remaining = Vec::new();
    let mut holes = ranges_to_subtract.iter().peekable();

    for source in ranges {
        let mut cursor = source.start;

        while cursor < source.end {
            match holes.peek() {
                // No holes left: keep the rest of this range verbatim.
                None => {
                    remaining.push(cursor..source.end);
                    cursor = source.end;
                }
                Some(&hole) => {
                    if cursor < hole.start {
                        // Keep the stretch up to the next hole (or range end).
                        let stop = cmp::min(hole.start, source.end);
                        remaining.push(cursor..stop);
                        cursor = stop;
                    } else {
                        // Inside the hole: skip past it (clamped to range end).
                        cursor = cmp::min(hole.end, source.end);
                    }
                    // Once the hole is fully behind us, advance to the next.
                    if cursor >= hole.end {
                        holes.next();
                    }
                }
            }
        }
    }

    remaining
}
357
/// Appends the text of `range` to `output`, removing up to `start_col`
/// leading spaces from every line so the snippet is de-indented to the
/// item's starting column. No trailing newline is emitted.
///
/// Fix: if the range resolves to no text at all (empty or out-of-bounds),
/// `output` is left untouched — the previous version unconditionally popped
/// a character that was already in `output`.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut wrote_line = false;
    for mut line in content.get(range).unwrap_or("").lines() {
        // Strip at most `start_col` leading spaces; stop at the first
        // non-space so deeper relative indentation is preserved.
        for _ in 0..start_col {
            match line.strip_prefix(' ') {
                Some(rest) => line = rest,
                None => break,
            }
        }
        output.push_str(line);
        output.push('\n');
        wrote_line = true;
    }
    // Remove the '\n' added after the final line — but only if we added one.
    if wrote_line {
        output.pop();
    }
}