use anyhow::{anyhow, Ok, Result};
use language::{Grammar, Language};
use sha1::{Digest, Sha1};
use std::{
    cmp::{self, Reverse},
    collections::HashSet,
    ops::Range,
    path::Path,
    sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};

#[derive(Debug, PartialEq, Clone)]
pub struct Document {
    pub name: String,
    pub range: Range<usize>,
    pub content: String,
    pub embedding: Vec<f32>,
    pub sha1: [u8; 20],
}

const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents are from file '<path>'\n\n<item>";
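// As an illustrative example only: for a hypothetical file `src/main.rs` containing
// `fn main() {}`, CODE_CONTEXT_TEMPLATE renders as:
//
//     The below code snippet is from file 'src/main.rs'
//
//     ```Rust
//     fn main() {}
//     ```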
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];

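// Illustrative usage sketch (assumes `source: &str` and a `language: Arc<Language>`
// with an embedding query configured, neither of which is constructed in this file):
//
//     let mut retriever = CodeContextRetriever::new();
//     let documents =
//         retriever.parse_file_with_template(Path::new("src/main.rs"), source, language)?;
//
// Each returned `Document` holds the templated snippet, its byte range in the source,
// and a SHA-1 of the content; the `embedding` field is left empty for a later pass.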
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
}

// Every match has an item; this represents the fundamental tree-sitter symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display range of the item, used for deduplication.
// If there are preceding comments, we track them with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
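//
// As an illustrative sketch (not the actual query shipped for any language), an
// embedding query for Rust functions could capture along these lines:
//
//     (
//         (line_comment)* @context
//         .
//         (function_item
//             name: (identifier) @name
//             body: (block) @collapse) @item
//     )
//
// where a nested `@keep` capture would preserve pieces inside the collapsed body.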
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    pub start_col: usize,
    pub item_range: Option<Range<usize>>,
    pub name_range: Option<Range<usize>>,
    pub context_ranges: Vec<Range<usize>>,
    pub collapse_ranges: Vec<Range<usize>>,
}

impl CodeContextRetriever {
    pub fn new() -> Self {
        Self {
            parser: Parser::new(),
            cursor: QueryCursor::new(),
        }
    }

    fn parse_entire_file(
        &self,
        relative_path: &Path,
        language_name: Arc<str>,
        content: &str,
    ) -> Result<Vec<Document>> {
        let document_span = ENTIRE_FILE_TEMPLATE
            .replace("<path>", relative_path.to_string_lossy().as_ref())
            .replace("<language>", language_name.as_ref())
            .replace("<item>", content);

        let mut sha1 = Sha1::new();
        sha1.update(&document_span);

        Ok(vec![Document {
            range: 0..content.len(),
            content: document_span,
            embedding: Vec::new(),
            name: language_name.to_string(),
            sha1: sha1.finalize().into(),
        }])
    }

    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
        let document_span = MARKDOWN_CONTEXT_TEMPLATE
            .replace("<path>", relative_path.to_string_lossy().as_ref())
            .replace("<item>", content);

        let mut sha1 = Sha1::new();
        sha1.update(&document_span);

        Ok(vec![Document {
            range: 0..content.len(),
            content: document_span,
            embedding: Vec::new(),
            name: "Markdown".to_string(),
            sha1: sha1.finalize().into(),
        }])
    }

    fn get_matches_in_file(
        &mut self,
        content: &str,
        grammar: &Arc<Grammar>,
    ) -> Result<Vec<CodeContextMatch>> {
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding queries"))?;
        self.parser.set_language(grammar.ts_language).unwrap();

        let tree = self
            .parser
            .parse(content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut captures: Vec<CodeContextMatch> = Vec::new();
        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
        for mat in self.cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut start_col = 0;
            let mut item_range: Option<Range<usize>> = None;
            let mut name_range: Option<Range<usize>> = None;
            let mut context_ranges: Vec<Range<usize>> = Vec::new();
            collapse_ranges.clear();
            keep_ranges.clear();
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                    start_col = capture.node.start_position().column;
                } else if Some(capture.index) == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.context_capture_ix {
                    context_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
                    collapse_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.keep_capture_ix {
                    keep_ranges.push(capture.node.byte_range());
                }
            }

            captures.push(CodeContextMatch {
                start_col,
                item_range,
                name_range,
                context_ranges,
                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
            });
        }
        Ok(captures)
    }

    pub fn parse_file_with_template(
        &mut self,
        relative_path: &Path,
        content: &str,
        language: Arc<Language>,
    ) -> Result<Vec<Document>> {
        let language_name = language.name();

        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
            return self.parse_entire_file(relative_path, language_name, content);
        } else if language_name.as_ref() == "Markdown" {
            return self.parse_markdown_file(relative_path, content);
        }

        let mut documents = self.parse_file(content, language)?;
        for document in &mut documents {
            document.content = CODE_CONTEXT_TEMPLATE
                .replace("<path>", relative_path.to_string_lossy().as_ref())
                .replace("<language>", language_name.as_ref())
                .replace("<item>", &document.content);
        }
        Ok(documents)
    }

    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
        let grammar = language
            .grammar()
            .ok_or_else(|| anyhow!("no grammar for language"))?;

        // Iterate through query matches
        let matches = self.get_matches_in_file(content, grammar)?;

        let language_scope = language.default_scope();
        let placeholder = language_scope.collapsed_placeholder();

        let mut documents = Vec::new();
        let mut collapsed_ranges_within = Vec::new();
        let mut parsed_name_ranges = HashSet::new();
        for (i, context_match) in matches.iter().enumerate() {
            // Items which are collapsible but not embeddable have no item range
            let item_range = if let Some(item_range) = context_match.item_range.clone() {
                item_range
            } else {
                continue;
            };

            // Checks for deduplication
            let name;
            if let Some(name_range) = context_match.name_range.clone() {
                name = content
                    .get(name_range.clone())
                    .map_or(String::new(), |s| s.to_string());
                if parsed_name_ranges.contains(&name_range) {
                    continue;
                }
                parsed_name_ranges.insert(name_range);
            } else {
                name = String::new();
            }

            collapsed_ranges_within.clear();
            'outer: for remaining_match in &matches[(i + 1)..] {
                for collapsed_range in &remaining_match.collapse_ranges {
                    if item_range.start <= collapsed_range.start
                        && item_range.end >= collapsed_range.end
                    {
                        collapsed_ranges_within.push(collapsed_range.clone());
                    } else {
                        break 'outer;
                    }
                }
            }

            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));

            let mut document_content = String::new();
            for context_range in &context_match.context_ranges {
                add_content_from_range(
                    &mut document_content,
                    content,
                    context_range.clone(),
                    context_match.start_col,
                );
                document_content.push('\n');
            }

            let mut offset = item_range.start;
            for collapsed_range in &collapsed_ranges_within {
                if collapsed_range.start > offset {
                    add_content_from_range(
                        &mut document_content,
                        content,
                        offset..collapsed_range.start,
                        context_match.start_col,
                    );
                    offset = collapsed_range.start;
                }

                if collapsed_range.end > offset {
                    document_content.push_str(placeholder);
                    offset = collapsed_range.end;
                }
            }

            if offset < item_range.end {
                add_content_from_range(
                    &mut document_content,
                    content,
                    offset..item_range.end,
                    context_match.start_col,
                );
            }

            let mut sha1 = Sha1::new();
            sha1.update(&document_content);

            documents.push(Document {
                name,
                content: document_content,
                range: item_range.clone(),
                embedding: vec![],
                sha1: sha1.finalize().into(),
            })
        }

        Ok(documents)
    }
}

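// Removes `ranges_to_subtract` from `ranges`, splitting the surviving pieces as needed.
// Both inputs are assumed to be sorted and non-overlapping. A minimal sketch of the
// expected behavior:
//
//     subtract_ranges(&[0..10], &[3..5]) == vec![0..3, 5..10]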
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
                if offset < range_to_subtract.start {
                    let next_offset = cmp::min(range_to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                } else {
                    let next_offset = cmp::min(range_to_subtract.end, range.end);
                    offset = next_offset;
                }

                if offset >= range_to_subtract.end {
                    ranges_to_subtract.next();
                }
            } else {
                result.push(offset..range.end);
                offset = range.end;
            }
        }
    }

    result
}

fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    for mut line in content.get(range.clone()).unwrap_or("").lines() {
        for _ in 0..start_col {
            if line.starts_with(' ') {
                line = &line[1..];
            } else {
                break;
            }
        }
        output.push_str(line);
        output.push('\n');
    }
    output.pop();
}
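
// A minimal test sketch for the two helpers above (expected values are derived by
// hand from the logic in this file; adjust or fold into an existing test module if
// the crate already has one).
#[cfg(test)]
mod parsing_helper_tests {
    use super::*;

    #[test]
    fn subtract_ranges_splits_around_subtracted_spans() {
        // Subtracting an inner span splits the surrounding range in two.
        assert_eq!(subtract_ranges(&[0..10], &[3..5]), vec![0..3, 5..10]);
        // A span that overhangs the end of a range is clipped at the range end.
        assert_eq!(subtract_ranges(&[4..10], &[6..20]), vec![4..6]);
    }

    #[test]
    fn add_content_from_range_strips_leading_indentation() {
        let content = "    fn foo() {\n        bar();\n    }";
        let mut output = String::new();
        add_content_from_range(&mut output, content, 0..content.len(), 4);
        assert_eq!(output, "fn foo() {\n    bar();\n}");
    }
}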