parsing.rs

  1use ai::embedding::{Embedding, EmbeddingProvider};
  2use anyhow::{anyhow, Result};
  3use language::{Grammar, Language};
  4use rusqlite::{
  5    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
  6    ToSql,
  7};
  8use sha1::{Digest, Sha1};
  9use std::{
 10    borrow::Cow,
 11    cmp::{self, Reverse},
 12    collections::HashSet,
 13    ops::Range,
 14    path::Path,
 15    sync::Arc,
 16};
 17use tree_sitter::{Parser, QueryCursor};
 18
/// SHA-1 digest of a span's text, used to detect unchanged spans; stored
/// in SQLite as a 20-byte BLOB (see the `FromSql`/`ToSql` impls below).
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
 21
 22impl FromSql for SpanDigest {
 23    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 24        let blob = value.as_blob()?;
 25        let bytes =
 26            blob.try_into()
 27                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
 28                    expected_size: 20,
 29                    blob_size: blob.len(),
 30                })?;
 31        return Ok(SpanDigest(bytes));
 32    }
 33}
 34
 35impl ToSql for SpanDigest {
 36    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
 37        self.0.to_sql()
 38    }
 39}
 40
 41impl From<&'_ str> for SpanDigest {
 42    fn from(value: &'_ str) -> Self {
 43        let mut sha1 = Sha1::new();
 44        sha1.update(value);
 45        Self(sha1.finalize().into())
 46    }
 47}
 48
/// A chunk of a source file prepared for embedding.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    // Display name of the captured item; empty when the match had no name.
    pub name: String,
    // Byte range of the item within the original file.
    pub range: Range<usize>,
    // Templated (and possibly truncated) text sent to the embedding provider.
    pub content: String,
    // `None` until the span has actually been embedded.
    pub embedding: Option<Embedding>,
    // SHA-1 digest of the span's text, for change detection.
    pub digest: SpanDigest,
    // Token count reported by the provider when truncating `content`.
    pub token_count: usize,
}
 58
// Template for individual code spans: the item is fenced with its language.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template used when an entire file is embedded as a single span.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Markdown / plain-text files are embedded without a code fence.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
// Languages whose files are always embedded whole rather than split into spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
 66
/// Parses source files into embeddable [`Span`]s using tree-sitter
/// embedding queries and the configured embedding provider.
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
 72
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    // Column at which the item starts; used to dedent its lines.
    pub start_col: usize,
    // Byte range of the whole item; `None` for collapse-only matches.
    pub item_range: Option<Range<usize>>,
    // Byte range of the item's name, used to deduplicate matches.
    pub name_range: Option<Range<usize>>,
    // Ranges of preceding context (e.g. comments) to prepend to the span.
    pub context_ranges: Vec<Range<usize>>,
    // Collapse ranges with any `keep` ranges already subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
 86
 87impl CodeContextRetriever {
 88    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
 89        Self {
 90            parser: Parser::new(),
 91            cursor: QueryCursor::new(),
 92            embedding_provider,
 93        }
 94    }
 95
 96    fn parse_entire_file(
 97        &self,
 98        relative_path: Option<&Path>,
 99        language_name: Arc<str>,
100        content: &str,
101    ) -> Result<Vec<Span>> {
102        let document_span = ENTIRE_FILE_TEMPLATE
103            .replace(
104                "<path>",
105                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
106            )
107            .replace("<language>", language_name.as_ref())
108            .replace("<item>", &content);
109        let digest = SpanDigest::from(document_span.as_str());
110        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
111        Ok(vec![Span {
112            range: 0..content.len(),
113            content: document_span,
114            embedding: Default::default(),
115            name: language_name.to_string(),
116            digest,
117            token_count,
118        }])
119    }
120
121    fn parse_markdown_file(
122        &self,
123        relative_path: Option<&Path>,
124        content: &str,
125    ) -> Result<Vec<Span>> {
126        let document_span = MARKDOWN_CONTEXT_TEMPLATE
127            .replace(
128                "<path>",
129                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
130            )
131            .replace("<item>", &content);
132        let digest = SpanDigest::from(document_span.as_str());
133        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
134        Ok(vec![Span {
135            range: 0..content.len(),
136            content: document_span,
137            embedding: None,
138            name: "Markdown".to_string(),
139            digest,
140            token_count,
141        }])
142    }
143
144    fn get_matches_in_file(
145        &mut self,
146        content: &str,
147        grammar: &Arc<Grammar>,
148    ) -> Result<Vec<CodeContextMatch>> {
149        let embedding_config = grammar
150            .embedding_config
151            .as_ref()
152            .ok_or_else(|| anyhow!("no embedding queries"))?;
153        self.parser.set_language(grammar.ts_language).unwrap();
154
155        let tree = self
156            .parser
157            .parse(&content, None)
158            .ok_or_else(|| anyhow!("parsing failed"))?;
159
160        let mut captures: Vec<CodeContextMatch> = Vec::new();
161        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
162        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
163        for mat in self.cursor.matches(
164            &embedding_config.query,
165            tree.root_node(),
166            content.as_bytes(),
167        ) {
168            let mut start_col = 0;
169            let mut item_range: Option<Range<usize>> = None;
170            let mut name_range: Option<Range<usize>> = None;
171            let mut context_ranges: Vec<Range<usize>> = Vec::new();
172            collapse_ranges.clear();
173            keep_ranges.clear();
174            for capture in mat.captures {
175                if capture.index == embedding_config.item_capture_ix {
176                    item_range = Some(capture.node.byte_range());
177                    start_col = capture.node.start_position().column;
178                } else if Some(capture.index) == embedding_config.name_capture_ix {
179                    name_range = Some(capture.node.byte_range());
180                } else if Some(capture.index) == embedding_config.context_capture_ix {
181                    context_ranges.push(capture.node.byte_range());
182                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
183                    collapse_ranges.push(capture.node.byte_range());
184                } else if Some(capture.index) == embedding_config.keep_capture_ix {
185                    keep_ranges.push(capture.node.byte_range());
186                }
187            }
188
189            captures.push(CodeContextMatch {
190                start_col,
191                item_range,
192                name_range,
193                context_ranges,
194                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
195            });
196        }
197        Ok(captures)
198    }
199
200    pub fn parse_file_with_template(
201        &mut self,
202        relative_path: Option<&Path>,
203        content: &str,
204        language: Arc<Language>,
205    ) -> Result<Vec<Span>> {
206        let language_name = language.name();
207
208        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
209            return self.parse_entire_file(relative_path, language_name, &content);
210        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
211            return self.parse_markdown_file(relative_path, &content);
212        }
213
214        let mut spans = self.parse_file(content, language)?;
215        for span in &mut spans {
216            let document_content = CODE_CONTEXT_TEMPLATE
217                .replace(
218                    "<path>",
219                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
220                )
221                .replace("<language>", language_name.as_ref())
222                .replace("item", &span.content);
223
224            let (document_content, token_count) =
225                self.embedding_provider.truncate(&document_content);
226
227            span.content = document_content;
228            span.token_count = token_count;
229        }
230        Ok(spans)
231    }
232
233    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
234        let grammar = language
235            .grammar()
236            .ok_or_else(|| anyhow!("no grammar for language"))?;
237
238        // Iterate through query matches
239        let matches = self.get_matches_in_file(content, grammar)?;
240
241        let language_scope = language.default_scope();
242        let placeholder = language_scope.collapsed_placeholder();
243
244        let mut spans = Vec::new();
245        let mut collapsed_ranges_within = Vec::new();
246        let mut parsed_name_ranges = HashSet::new();
247        for (i, context_match) in matches.iter().enumerate() {
248            // Items which are collapsible but not embeddable have no item range
249            let item_range = if let Some(item_range) = context_match.item_range.clone() {
250                item_range
251            } else {
252                continue;
253            };
254
255            // Checks for deduplication
256            let name;
257            if let Some(name_range) = context_match.name_range.clone() {
258                name = content
259                    .get(name_range.clone())
260                    .map_or(String::new(), |s| s.to_string());
261                if parsed_name_ranges.contains(&name_range) {
262                    continue;
263                }
264                parsed_name_ranges.insert(name_range);
265            } else {
266                name = String::new();
267            }
268
269            collapsed_ranges_within.clear();
270            'outer: for remaining_match in &matches[(i + 1)..] {
271                for collapsed_range in &remaining_match.collapse_ranges {
272                    if item_range.start <= collapsed_range.start
273                        && item_range.end >= collapsed_range.end
274                    {
275                        collapsed_ranges_within.push(collapsed_range.clone());
276                    } else {
277                        break 'outer;
278                    }
279                }
280            }
281
282            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
283
284            let mut span_content = String::new();
285            for context_range in &context_match.context_ranges {
286                add_content_from_range(
287                    &mut span_content,
288                    content,
289                    context_range.clone(),
290                    context_match.start_col,
291                );
292                span_content.push_str("\n");
293            }
294
295            let mut offset = item_range.start;
296            for collapsed_range in &collapsed_ranges_within {
297                if collapsed_range.start > offset {
298                    add_content_from_range(
299                        &mut span_content,
300                        content,
301                        offset..collapsed_range.start,
302                        context_match.start_col,
303                    );
304                    offset = collapsed_range.start;
305                }
306
307                if collapsed_range.end > offset {
308                    span_content.push_str(placeholder);
309                    offset = collapsed_range.end;
310                }
311            }
312
313            if offset < item_range.end {
314                add_content_from_range(
315                    &mut span_content,
316                    content,
317                    offset..item_range.end,
318                    context_match.start_col,
319                );
320            }
321
322            let sha1 = SpanDigest::from(span_content.as_str());
323            spans.push(Span {
324                name,
325                content: span_content,
326                range: item_range.clone(),
327                embedding: None,
328                digest: sha1,
329                token_count: 0,
330            })
331        }
332
333        return Ok(spans);
334    }
335}
336
/// Removes every portion of `ranges` covered by `ranges_to_subtract`,
/// returning the remaining sub-ranges in order.
///
/// Both inputs are expected to be sorted by start position; a subtracted
/// range may straddle multiple input ranges.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            match ranges_to_subtract.peek() {
                // BUGFIX: a subtracted range lying entirely before `offset`
                // used to pull `offset` backwards, emitting bytes outside
                // `range`; it is now simply discarded.
                Some(to_subtract) if to_subtract.end <= offset => {
                    ranges_to_subtract.next();
                }
                // Gap before the next subtracted range: keep it.
                Some(to_subtract) if offset < to_subtract.start => {
                    let next_offset = cmp::min(to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                }
                // Overlap: skip past the subtracted portion; only consume the
                // subtracted range once we have moved past its end.
                Some(to_subtract) => {
                    offset = cmp::min(to_subtract.end, range.end);
                    if offset >= to_subtract.end {
                        ranges_to_subtract.next();
                    }
                }
                // Nothing left to subtract: keep the rest of `range`.
                None => {
                    result.push(offset..range.end);
                    offset = range.end;
                }
            }
        }
    }

    result
}
371
/// Appends the text of `range` within `content` to `output`, dedenting each
/// line by at most `start_col` leading spaces (the column where the captured
/// item begins).
///
/// A range that is out of bounds or not on char boundaries contributes
/// nothing. No trailing newline is appended.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut first = true;
    for line in content.get(range).unwrap_or("").lines() {
        // BUGFIX: separate lines with '\n' instead of trimming a trailing one
        // afterwards — the old `output.pop()` removed a byte that was already
        // in `output` whenever the range yielded no lines.
        if !first {
            output.push('\n');
        }
        first = false;
        // Strip at most `start_col` leading spaces (never tabs, never
        // non-space content); space-count slicing stays on char boundaries.
        let leading_spaces = line.len() - line.trim_start_matches(' ').len();
        output.push_str(&line[leading_spaces.min(start_col)..]);
    }
}