// parsing.rs

use anyhow::{anyhow, Ok, Result};
use language::Language;
use std::{collections::HashSet, ops::Range, path::Path, sync::Arc};
use tree_sitter::{Parser, QueryCursor};
  5
  6#[derive(Debug, PartialEq, Clone)]
  7pub struct Document {
  8    pub name: String,
  9    pub range: Range<usize>,
 10    pub content: String,
 11    pub embedding: Vec<f32>,
 12}
 13
 14const CODE_CONTEXT_TEMPLATE: &str =
 15    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 16
 17pub struct CodeContextRetriever {
 18    pub parser: Parser,
 19    pub cursor: QueryCursor,
 20}
 21
 22impl CodeContextRetriever {
 23    pub fn new() -> Self {
 24        Self {
 25            parser: Parser::new(),
 26            cursor: QueryCursor::new(),
 27        }
 28    }
 29
 30    pub fn parse_file(
 31        &mut self,
 32        relative_path: &Path,
 33        content: &str,
 34        language: Arc<Language>,
 35    ) -> Result<Vec<Document>> {
 36        let grammar = language
 37            .grammar()
 38            .ok_or_else(|| anyhow!("no grammar for language"))?;
 39        let embedding_config = grammar
 40            .embedding_config
 41            .as_ref()
 42            .ok_or_else(|| anyhow!("no embedding queries"))?;
 43
 44        self.parser.set_language(grammar.ts_language).unwrap();
 45
 46        let tree = self
 47            .parser
 48            .parse(&content, None)
 49            .ok_or_else(|| anyhow!("parsing failed"))?;
 50
 51        let mut documents = Vec::new();
 52
 53        // Iterate through query matches
 54        let mut name_ranges: Vec<Range<usize>> = vec![];
 55        for mat in self.cursor.matches(
 56            &embedding_config.query,
 57            tree.root_node(),
 58            content.as_bytes(),
 59        ) {
 60            let mut name: Vec<&str> = vec![];
 61            let mut item: Option<&str> = None;
 62            let mut byte_range: Option<Range<usize>> = None;
 63            let mut context_spans: Vec<&str> = vec![];
 64            for capture in mat.captures {
 65                if capture.index == embedding_config.item_capture_ix {
 66                    byte_range = Some(capture.node.byte_range());
 67                    item = content.get(capture.node.byte_range());
 68                } else if capture.index == embedding_config.name_capture_ix {
 69                    let name_range = capture.node.byte_range();
 70                    if name_ranges.contains(&name_range) {
 71                        continue;
 72                    }
 73                    name_ranges.push(name_range.clone());
 74                    if let Some(name_content) = content.get(name_range.clone()) {
 75                        name.push(name_content);
 76                    }
 77                }
 78
 79                if let Some(context_capture_ix) = embedding_config.context_capture_ix {
 80                    if capture.index == context_capture_ix {
 81                        if let Some(context) = content.get(capture.node.byte_range()) {
 82                            context_spans.push(context);
 83                        }
 84                    }
 85                }
 86            }
 87
 88            if let Some((item, byte_range)) = item.zip(byte_range) {
 89                if !name.is_empty() {
 90                    let item = if context_spans.is_empty() {
 91                        item.to_string()
 92                    } else {
 93                        format!("{}\n{}", context_spans.join("\n"), item)
 94                    };
 95
 96                    let document_text = CODE_CONTEXT_TEMPLATE
 97                        .replace("<path>", relative_path.to_str().unwrap())
 98                        .replace("<language>", &language.name().to_lowercase())
 99                        .replace("<item>", item.as_str());
100
101                    documents.push(Document {
102                        range: byte_range,
103                        content: document_text,
104                        embedding: Vec::new(),
105                        name: name.join(" ").to_string(),
106                    });
107                }
108            }
109        }
110
111        return Ok(documents);
112    }
113}