parsing.rs

  1use ai::embedding::{Embedding, EmbeddingProvider};
  2use anyhow::{anyhow, Result};
  3use language::{Grammar, Language};
  4use rusqlite::{
  5    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
  6    ToSql,
  7};
  8use sha1::{Digest, Sha1};
  9use std::{
 10    borrow::Cow,
 11    cmp::{self, Reverse},
 12    collections::HashSet,
 13    ops::Range,
 14    path::Path,
 15    sync::Arc,
 16};
 17use tree_sitter::{Parser, QueryCursor};
 18
/// SHA-1 digest (20 raw bytes) of a span's content.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
 21
 22impl FromSql for SpanDigest {
 23    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 24        let blob = value.as_blob()?;
 25        let bytes =
 26            blob.try_into()
 27                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
 28                    expected_size: 20,
 29                    blob_size: blob.len(),
 30                })?;
 31        return Ok(SpanDigest(bytes));
 32    }
 33}
 34
impl ToSql for SpanDigest {
    /// Stores the digest as its raw 20-byte array (a SQLite blob).
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}
 40
 41impl From<&'_ str> for SpanDigest {
 42    fn from(value: &'_ str) -> Self {
 43        let mut sha1 = Sha1::new();
 44        sha1.update(value);
 45        Self(sha1.finalize().into())
 46    }
 47}
 48
/// A chunk of a source file prepared for embedding.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    /// Display name taken from the match's name capture; empty when the
    /// match had no name capture.
    pub name: String,
    /// Byte range of the item in the original file content.
    pub range: Range<usize>,
    /// The (templated, possibly truncated) text that gets embedded.
    pub content: String,
    /// The computed embedding; `None` until one has been produced.
    pub embedding: Option<Embedding>,
    /// SHA-1 digest of the span's content.
    pub digest: SpanDigest,
    /// Token count of `content` as reported by the embedding provider.
    pub token_count: usize,
}
 58
/// Template wrapped around each extracted code span before embedding.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template used when an entire file is embedded as one span
/// (see `PARSEABLE_ENTIRE_FILE_TYPES`).
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template for Markdown and plain-text files; no code fence is added.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Language names whose files are embedded whole rather than split into
/// per-symbol spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
 67
/// Extracts embeddable spans from source files by running each language's
/// tree-sitter embedding query.
pub struct CodeContextRetriever {
    /// Tree-sitter parser, re-targeted to each file's grammar before use.
    pub parser: Parser,
    /// Reusable cursor for running embedding queries over parse trees.
    pub cursor: QueryCursor,
    /// Provider used to truncate span text to a token budget (and, elsewhere,
    /// to compute embeddings).
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
 73
// Every match has an item capture; this represents the fundamental tree-sitter
// symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display
// range of the item for deduplication.
// If there are preceding comments, we track them with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we
// capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture
// it with a keep capture.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Column of the item's first character; used to de-indent extracted lines.
    pub start_col: usize,
    /// Byte range of the item capture; `None` for collapsible-only matches.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, used for deduplication.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of preceding context (e.g. comments) to prepend.
    pub context_ranges: Vec<Range<usize>>,
    /// Collapse ranges with their keep ranges already subtracted.
    pub collapse_ranges: Vec<Range<usize>>,
}
 87
 88impl CodeContextRetriever {
 89    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
 90        Self {
 91            parser: Parser::new(),
 92            cursor: QueryCursor::new(),
 93            embedding_provider,
 94        }
 95    }
 96
 97    fn parse_entire_file(
 98        &self,
 99        relative_path: Option<&Path>,
100        language_name: Arc<str>,
101        content: &str,
102    ) -> Result<Vec<Span>> {
103        let document_span = ENTIRE_FILE_TEMPLATE
104            .replace(
105                "<path>",
106                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
107            )
108            .replace("<language>", language_name.as_ref())
109            .replace("<item>", &content);
110        let digest = SpanDigest::from(document_span.as_str());
111        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
112        Ok(vec![Span {
113            range: 0..content.len(),
114            content: document_span,
115            embedding: Default::default(),
116            name: language_name.to_string(),
117            digest,
118            token_count,
119        }])
120    }
121
122    fn parse_markdown_file(
123        &self,
124        relative_path: Option<&Path>,
125        content: &str,
126    ) -> Result<Vec<Span>> {
127        let document_span = MARKDOWN_CONTEXT_TEMPLATE
128            .replace(
129                "<path>",
130                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
131            )
132            .replace("<item>", &content);
133        let digest = SpanDigest::from(document_span.as_str());
134        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
135        Ok(vec![Span {
136            range: 0..content.len(),
137            content: document_span,
138            embedding: None,
139            name: "Markdown".to_string(),
140            digest,
141            token_count,
142        }])
143    }
144
145    fn get_matches_in_file(
146        &mut self,
147        content: &str,
148        grammar: &Arc<Grammar>,
149    ) -> Result<Vec<CodeContextMatch>> {
150        let embedding_config = grammar
151            .embedding_config
152            .as_ref()
153            .ok_or_else(|| anyhow!("no embedding queries"))?;
154        self.parser.set_language(grammar.ts_language).unwrap();
155
156        let tree = self
157            .parser
158            .parse(&content, None)
159            .ok_or_else(|| anyhow!("parsing failed"))?;
160
161        let mut captures: Vec<CodeContextMatch> = Vec::new();
162        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
163        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
164        for mat in self.cursor.matches(
165            &embedding_config.query,
166            tree.root_node(),
167            content.as_bytes(),
168        ) {
169            let mut start_col = 0;
170            let mut item_range: Option<Range<usize>> = None;
171            let mut name_range: Option<Range<usize>> = None;
172            let mut context_ranges: Vec<Range<usize>> = Vec::new();
173            collapse_ranges.clear();
174            keep_ranges.clear();
175            for capture in mat.captures {
176                if capture.index == embedding_config.item_capture_ix {
177                    item_range = Some(capture.node.byte_range());
178                    start_col = capture.node.start_position().column;
179                } else if Some(capture.index) == embedding_config.name_capture_ix {
180                    name_range = Some(capture.node.byte_range());
181                } else if Some(capture.index) == embedding_config.context_capture_ix {
182                    context_ranges.push(capture.node.byte_range());
183                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
184                    collapse_ranges.push(capture.node.byte_range());
185                } else if Some(capture.index) == embedding_config.keep_capture_ix {
186                    keep_ranges.push(capture.node.byte_range());
187                }
188            }
189
190            captures.push(CodeContextMatch {
191                start_col,
192                item_range,
193                name_range,
194                context_ranges,
195                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
196            });
197        }
198        Ok(captures)
199    }
200
201    pub fn parse_file_with_template(
202        &mut self,
203        relative_path: Option<&Path>,
204        content: &str,
205        language: Arc<Language>,
206    ) -> Result<Vec<Span>> {
207        let language_name = language.name();
208
209        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
210            return self.parse_entire_file(relative_path, language_name, &content);
211        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
212            return self.parse_markdown_file(relative_path, &content);
213        }
214
215        let mut spans = self.parse_file(content, language)?;
216        for span in &mut spans {
217            let document_content = CODE_CONTEXT_TEMPLATE
218                .replace(
219                    "<path>",
220                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
221                )
222                .replace("<language>", language_name.as_ref())
223                .replace("item", &span.content);
224
225            let (document_content, token_count) =
226                self.embedding_provider.truncate(&document_content);
227
228            span.content = document_content;
229            span.token_count = token_count;
230        }
231        Ok(spans)
232    }
233
234    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
235        let grammar = language
236            .grammar()
237            .ok_or_else(|| anyhow!("no grammar for language"))?;
238
239        // Iterate through query matches
240        let matches = self.get_matches_in_file(content, grammar)?;
241
242        let language_scope = language.default_scope();
243        let placeholder = language_scope.collapsed_placeholder();
244
245        let mut spans = Vec::new();
246        let mut collapsed_ranges_within = Vec::new();
247        let mut parsed_name_ranges = HashSet::new();
248        for (i, context_match) in matches.iter().enumerate() {
249            // Items which are collapsible but not embeddable have no item range
250            let item_range = if let Some(item_range) = context_match.item_range.clone() {
251                item_range
252            } else {
253                continue;
254            };
255
256            // Checks for deduplication
257            let name;
258            if let Some(name_range) = context_match.name_range.clone() {
259                name = content
260                    .get(name_range.clone())
261                    .map_or(String::new(), |s| s.to_string());
262                if parsed_name_ranges.contains(&name_range) {
263                    continue;
264                }
265                parsed_name_ranges.insert(name_range);
266            } else {
267                name = String::new();
268            }
269
270            collapsed_ranges_within.clear();
271            'outer: for remaining_match in &matches[(i + 1)..] {
272                for collapsed_range in &remaining_match.collapse_ranges {
273                    if item_range.start <= collapsed_range.start
274                        && item_range.end >= collapsed_range.end
275                    {
276                        collapsed_ranges_within.push(collapsed_range.clone());
277                    } else {
278                        break 'outer;
279                    }
280                }
281            }
282
283            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
284
285            let mut span_content = String::new();
286            for context_range in &context_match.context_ranges {
287                add_content_from_range(
288                    &mut span_content,
289                    content,
290                    context_range.clone(),
291                    context_match.start_col,
292                );
293                span_content.push_str("\n");
294            }
295
296            let mut offset = item_range.start;
297            for collapsed_range in &collapsed_ranges_within {
298                if collapsed_range.start > offset {
299                    add_content_from_range(
300                        &mut span_content,
301                        content,
302                        offset..collapsed_range.start,
303                        context_match.start_col,
304                    );
305                    offset = collapsed_range.start;
306                }
307
308                if collapsed_range.end > offset {
309                    span_content.push_str(placeholder);
310                    offset = collapsed_range.end;
311                }
312            }
313
314            if offset < item_range.end {
315                add_content_from_range(
316                    &mut span_content,
317                    content,
318                    offset..item_range.end,
319                    context_match.start_col,
320                );
321            }
322
323            let sha1 = SpanDigest::from(span_content.as_str());
324            spans.push(Span {
325                name,
326                content: span_content,
327                range: item_range.clone(),
328                embedding: None,
329                digest: sha1,
330                token_count: 0,
331            })
332        }
333
334        return Ok(spans);
335    }
336}
337
/// Returns the portions of `ranges` that are not covered by
/// `ranges_to_subtract`.
///
/// Both lists are walked front-to-back in a single pass; callers are expected
/// to pass them sorted by start position — TODO confirm with call sites.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut kept = Vec::new();
    let mut pending = ranges_to_subtract.iter().peekable();

    for source in ranges {
        let mut cursor = source.start;

        while cursor < source.end {
            match pending.peek() {
                None => {
                    // Nothing left to subtract: keep the rest of `source`.
                    kept.push(cursor..source.end);
                    cursor = source.end;
                }
                Some(sub) => {
                    let (sub_start, sub_end) = (sub.start, sub.end);
                    if cursor < sub_start {
                        // Keep the gap before the next subtracted range.
                        let gap_end = cmp::min(sub_start, source.end);
                        kept.push(cursor..gap_end);
                        cursor = gap_end;
                    } else {
                        // Inside a subtracted range: skip past it (or to the
                        // end of `source`, whichever comes first).
                        cursor = cmp::min(sub_end, source.end);
                    }

                    // Once we've passed this subtraction's end it can never
                    // affect later positions, so advance to the next one.
                    if cursor >= sub_end {
                        pending.next();
                    }
                }
            }
        }
    }

    kept
}
372
/// Appends the lines of `content[range]` to `output`, stripping up to
/// `start_col` leading spaces from each line so the snippet is de-indented
/// to the item's starting column.
///
/// Lines are joined with `'\n'` and no trailing newline is appended.
/// Fix: if the range yields no lines (empty range, or a range that is out of
/// bounds / splits a UTF-8 character), `output` is left untouched — the old
/// unconditional `output.pop()` removed a character that was already in
/// `output` in that case.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut wrote_line = false;
    // `get` returns None for an invalid range; treat that as empty rather
    // than panicking.
    for mut line in content.get(range).unwrap_or("").lines() {
        // Strip at most `start_col` leading spaces, stopping early at the
        // first non-space character.
        let mut to_strip = start_col;
        while to_strip > 0 {
            match line.strip_prefix(' ') {
                Some(rest) => {
                    line = rest;
                    to_strip -= 1;
                }
                None => break,
            }
        }
        output.push_str(line);
        output.push('\n');
        wrote_line = true;
    }
    if wrote_line {
        // Drop the final '\n' pushed above so the output has no trailing
        // newline.
        output.pop();
    }
}