1use anyhow::{anyhow, Ok, Result};
2use language::Language;
3use std::{ops::Range, path::Path, sync::Arc};
4use tree_sitter::{Parser, QueryCursor};
5
/// A single embeddable unit extracted from a source file.
///
/// Instances are produced by `CodeContextRetriever::parse_file`, one per
/// query match that yields both an item capture and at least one new name.
#[derive(Clone, Debug, PartialEq)]
pub struct Document {
    /// The text of the match's name captures, joined with single spaces.
    pub name: String,
    /// Byte range of the item capture within the parsed file content.
    pub range: Range<usize>,
    /// The item text (preceded by any context captures) rendered into the
    /// code-context template.
    pub content: String,
    /// Embedding vector for `content`; `parse_file` leaves it empty.
    pub embedding: Vec<f32>,
}
13
/// Template for the text stored in `Document::content`.
///
/// `parse_file` substitutes the `<path>`, `<language>` and `<item>`
/// placeholders, in that order, for each extracted item.
const CODE_CONTEXT_TEMPLATE: &str = concat!(
    "The below code snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);
16
/// Extracts [`Document`]s from source text by running a language's
/// embedding query over a tree-sitter parse of the file.
///
/// The parser and cursor are kept as fields so that repeated `parse_file`
/// calls use the same instances.
pub struct CodeContextRetriever {
    /// Tree-sitter parser; `parse_file` sets its language on every call.
    pub parser: Parser,
    /// Cursor used to execute the embedding query against the parsed tree.
    pub cursor: QueryCursor,
}
21
22impl CodeContextRetriever {
23 pub fn new() -> Self {
24 Self {
25 parser: Parser::new(),
26 cursor: QueryCursor::new(),
27 }
28 }
29
30 pub fn parse_file(
31 &mut self,
32 relative_path: &Path,
33 content: &str,
34 language: Arc<Language>,
35 ) -> Result<Vec<Document>> {
36 let grammar = language
37 .grammar()
38 .ok_or_else(|| anyhow!("no grammar for language"))?;
39 let embedding_config = grammar
40 .embedding_config
41 .as_ref()
42 .ok_or_else(|| anyhow!("no embedding queries"))?;
43
44 self.parser.set_language(grammar.ts_language).unwrap();
45
46 let tree = self
47 .parser
48 .parse(&content, None)
49 .ok_or_else(|| anyhow!("parsing failed"))?;
50
51 let mut documents = Vec::new();
52
53 // Iterate through query matches
54 let mut name_ranges: Vec<Range<usize>> = vec![];
55 for mat in self.cursor.matches(
56 &embedding_config.query,
57 tree.root_node(),
58 content.as_bytes(),
59 ) {
60 let mut name: Vec<&str> = vec![];
61 let mut item: Option<&str> = None;
62 let mut byte_range: Option<Range<usize>> = None;
63 let mut context_spans: Vec<&str> = vec![];
64 for capture in mat.captures {
65 if capture.index == embedding_config.item_capture_ix {
66 byte_range = Some(capture.node.byte_range());
67 item = content.get(capture.node.byte_range());
68 } else if capture.index == embedding_config.name_capture_ix {
69 let name_range = capture.node.byte_range();
70 if name_ranges.contains(&name_range) {
71 continue;
72 }
73 name_ranges.push(name_range.clone());
74 if let Some(name_content) = content.get(name_range.clone()) {
75 name.push(name_content);
76 }
77 }
78
79 if let Some(context_capture_ix) = embedding_config.context_capture_ix {
80 if capture.index == context_capture_ix {
81 if let Some(context) = content.get(capture.node.byte_range()) {
82 context_spans.push(context);
83 }
84 }
85 }
86 }
87
88 if let Some((item, byte_range)) = item.zip(byte_range) {
89 if !name.is_empty() {
90 let item = if context_spans.is_empty() {
91 item.to_string()
92 } else {
93 format!("{}\n{}", context_spans.join("\n"), item)
94 };
95
96 let document_text = CODE_CONTEXT_TEMPLATE
97 .replace("<path>", relative_path.to_str().unwrap())
98 .replace("<language>", &language.name().to_lowercase())
99 .replace("<item>", item.as_str());
100
101 documents.push(Document {
102 range: byte_range,
103 content: document_text,
104 embedding: Vec::new(),
105 name: name.join(" ").to_string(),
106 });
107 }
108 }
109 }
110
111 return Ok(documents);
112 }
113}