parsing.rs

use ai::{
    embedding::{Embedding, EmbeddingProvider},
    models::TruncationDirection,
};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};
use sha1::{Digest, Sha1};
use std::{
    borrow::Cow,
    cmp::{self, Reverse},
    collections::HashSet,
    ops::Range,
    path::Path,
    sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};

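/// SHA-1 digest of a span's content, stored as a 20-byte blob in SQLite.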
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);

impl FromSql for SpanDigest {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let blob = value.as_blob()?;
        let bytes =
            blob.try_into()
                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
                    expected_size: 20,
                    blob_size: blob.len(),
                })?;
        Ok(SpanDigest(bytes))
    }
}

impl ToSql for SpanDigest {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}

impl From<&'_ str> for SpanDigest {
    fn from(value: &'_ str) -> Self {
        let mut sha1 = Sha1::new();
        sha1.update(value);
        Self(sha1.finalize().into())
    }
}

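/// A chunk of a parsed file: the rendered content for one embedding-query item,
/// its byte range in the source, its content digest, and its (optionally computed)
/// embedding and token count.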
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    pub name: String,
    pub range: Range<usize>,
    pub content: String,
    pub embedding: Option<Embedding>,
    pub digest: SpanDigest,
    pub token_count: usize,
}

const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];

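/// Parses source files with tree-sitter and converts embedding-query matches
/// into `Span`s that can be embedded by the configured `EmbeddingProvider`.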
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}

// Every match has an item; this represents the fundamental tree-sitter symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track them with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    pub start_col: usize,
    pub item_range: Option<Range<usize>>,
    pub name_range: Option<Range<usize>>,
    pub context_ranges: Vec<Range<usize>>,
    pub collapse_ranges: Vec<Range<usize>>,
}

impl CodeContextRetriever {
    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
        Self {
            parser: Parser::new(),
            cursor: QueryCursor::new(),
            embedding_provider,
        }
    }

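    /// Wraps the whole file in `ENTIRE_FILE_TEMPLATE` and returns it as a single
    /// span, truncated to the embedding model's capacity.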
    fn parse_entire_file(
        &self,
        relative_path: Option<&Path>,
        language_name: Arc<str>,
        content: &str,
    ) -> Result<Vec<Span>> {
        let document_span = ENTIRE_FILE_TEMPLATE
            .replace(
                "<path>",
                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
            )
            .replace("<language>", language_name.as_ref())
            .replace("<item>", &content);
        let digest = SpanDigest::from(document_span.as_str());
        let model = self.embedding_provider.base_model();
        let document_span = model.truncate(
            &document_span,
            model.capacity()?,
            ai::models::TruncationDirection::End,
        )?;
        let token_count = model.count_tokens(&document_span)?;

        Ok(vec![Span {
            range: 0..content.len(),
            content: document_span,
            embedding: Default::default(),
            name: language_name.to_string(),
            digest,
            token_count,
        }])
    }

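    /// Wraps a Markdown or plain-text file in `MARKDOWN_CONTEXT_TEMPLATE` and
    /// returns it as a single span, truncated to the embedding model's capacity.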
    fn parse_markdown_file(
        &self,
        relative_path: Option<&Path>,
        content: &str,
    ) -> Result<Vec<Span>> {
        let document_span = MARKDOWN_CONTEXT_TEMPLATE
            .replace(
                "<path>",
                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
            )
            .replace("<item>", &content);
        let digest = SpanDigest::from(document_span.as_str());

        let model = self.embedding_provider.base_model();
        let document_span = model.truncate(
            &document_span,
            model.capacity()?,
            ai::models::TruncationDirection::End,
        )?;
        let token_count = model.count_tokens(&document_span)?;

        Ok(vec![Span {
            range: 0..content.len(),
            content: document_span,
            embedding: None,
            name: "Markdown".to_string(),
            digest,
            token_count,
        }])
    }

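    /// Runs the grammar's embedding query over the parsed file and records the
    /// item, name, context, collapse, and keep captures for every match.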
    fn get_matches_in_file(
        &mut self,
        content: &str,
        grammar: &Arc<Grammar>,
    ) -> Result<Vec<CodeContextMatch>> {
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding queries"))?;
        self.parser.set_language(&grammar.ts_language).unwrap();

        let tree = self
            .parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut captures: Vec<CodeContextMatch> = Vec::new();
        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
        for mat in self.cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut start_col = 0;
            let mut item_range: Option<Range<usize>> = None;
            let mut name_range: Option<Range<usize>> = None;
            let mut context_ranges: Vec<Range<usize>> = Vec::new();
            collapse_ranges.clear();
            keep_ranges.clear();
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                    start_col = capture.node.start_position().column;
                } else if Some(capture.index) == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.context_capture_ix {
                    context_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
                    collapse_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.keep_capture_ix {
                    keep_ranges.push(capture.node.byte_range());
                }
            }

            captures.push(CodeContextMatch {
                start_col,
                item_range,
                name_range,
                context_ranges,
                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
            });
        }
        Ok(captures)
    }

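    /// Parses a file into spans and renders each span through the template that
    /// matches its language, truncating to the model's capacity and recording
    /// the resulting token count.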
    pub fn parse_file_with_template(
        &mut self,
        relative_path: Option<&Path>,
        content: &str,
        language: Arc<Language>,
    ) -> Result<Vec<Span>> {
        let language_name = language.name();

        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
            return self.parse_entire_file(relative_path, language_name, &content);
        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
            return self.parse_markdown_file(relative_path, &content);
        }

        let mut spans = self.parse_file(content, language)?;
        for span in &mut spans {
            let document_content = CODE_CONTEXT_TEMPLATE
                .replace(
                    "<path>",
                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
                )
                .replace("<language>", language_name.as_ref())
                .replace("<item>", &span.content);

            let model = self.embedding_provider.base_model();
            let document_content = model.truncate(
                &document_content,
                model.capacity()?,
                TruncationDirection::End,
            )?;
            let token_count = model.count_tokens(&document_content)?;

            span.content = document_content;
            span.token_count = token_count;
        }
        Ok(spans)
    }

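    /// Extracts one span per embedding-query match, replacing nested collapsible
    /// ranges with the language's collapsed placeholder and skipping items whose
    /// name range has already been emitted.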
    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
        let grammar = language
            .grammar()
            .ok_or_else(|| anyhow!("no grammar for language"))?;

        // Iterate through query matches
        let matches = self.get_matches_in_file(content, grammar)?;

        let language_scope = language.default_scope();
        let placeholder = language_scope.collapsed_placeholder();

        let mut spans = Vec::new();
        let mut collapsed_ranges_within = Vec::new();
        let mut parsed_name_ranges = HashSet::new();
        for (i, context_match) in matches.iter().enumerate() {
            // Items which are collapsible but not embeddable have no item range
            let item_range = if let Some(item_range) = context_match.item_range.clone() {
                item_range
            } else {
                continue;
            };

            // Checks for deduplication
            let name;
            if let Some(name_range) = context_match.name_range.clone() {
                name = content
                    .get(name_range.clone())
                    .map_or(String::new(), |s| s.to_string());
                if parsed_name_ranges.contains(&name_range) {
                    continue;
                }
                parsed_name_ranges.insert(name_range);
            } else {
                name = String::new();
            }

            collapsed_ranges_within.clear();
            'outer: for remaining_match in &matches[(i + 1)..] {
                for collapsed_range in &remaining_match.collapse_ranges {
                    if item_range.start <= collapsed_range.start
                        && item_range.end >= collapsed_range.end
                    {
                        collapsed_ranges_within.push(collapsed_range.clone());
                    } else {
                        break 'outer;
                    }
                }
            }

            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));

            let mut span_content = String::new();
            for context_range in &context_match.context_ranges {
                add_content_from_range(
                    &mut span_content,
                    content,
                    context_range.clone(),
                    context_match.start_col,
                );
                span_content.push_str("\n");
            }

            let mut offset = item_range.start;
            for collapsed_range in &collapsed_ranges_within {
                if collapsed_range.start > offset {
                    add_content_from_range(
                        &mut span_content,
                        content,
                        offset..collapsed_range.start,
                        context_match.start_col,
                    );
                    offset = collapsed_range.start;
                }

                if collapsed_range.end > offset {
                    span_content.push_str(placeholder);
                    offset = collapsed_range.end;
                }
            }

            if offset < item_range.end {
                add_content_from_range(
                    &mut span_content,
                    content,
                    offset..item_range.end,
                    context_match.start_col,
                );
            }

            let sha1 = SpanDigest::from(span_content.as_str());
            spans.push(Span {
                name,
                content: span_content,
                range: item_range.clone(),
                embedding: None,
                digest: sha1,
                token_count: 0,
            })
        }

        Ok(spans)
    }
}

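/// Returns the portions of `ranges` that are not covered by `ranges_to_subtract`.
/// Both slices are expected to be sorted in ascending order of start position.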
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
                if offset < range_to_subtract.start {
                    let next_offset = cmp::min(range_to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                } else {
                    let next_offset = cmp::min(range_to_subtract.end, range.end);
                    offset = next_offset;
                }

                if offset >= range_to_subtract.end {
                    ranges_to_subtract.next();
                }
            } else {
                result.push(offset..range.end);
                offset = range.end;
            }
        }
    }

    result
}

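/// Appends the lines of `content[range]` to `output`, stripping up to `start_col`
/// leading spaces from each line and dropping the final trailing newline.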
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    for mut line in content.get(range.clone()).unwrap_or("").lines() {
        for _ in 0..start_col {
            if line.starts_with(' ') {
                line = &line[1..];
            } else {
                break;
            }
        }
        output.push_str(line);
        output.push('\n');
    }
    output.pop();
}