1use anyhow::{anyhow, Ok, Result};
2use language::{Grammar, Language};
3use std::{
4 cmp::{self, Reverse},
5 collections::HashSet,
6 ops::Range,
7 path::Path,
8 sync::Arc,
9};
10use tree_sitter::{Parser, QueryCursor};
11
/// A single embeddable unit extracted from a source file.
#[derive(Clone, Debug, PartialEq)]
pub struct Document {
    /// Display name of the document: the text of the item's name capture for
    /// parsed code (possibly empty), or the language name for whole-file documents.
    pub name: String,
    /// Byte range within the original file content that this document covers.
    pub range: Range<usize>,
    /// The text of the document, as produced by the parsing routines in this file.
    pub content: String,
    /// Embedding vector; the parsing routines in this file always leave it empty.
    pub embedding: Vec<f32>,
}
19
/// Template wrapped around each extracted code item.
///
/// Placeholders: `<path>` (file path), `<language>` (language name used as the
/// code-fence tag) and `<item>` (the extracted snippet).
const CODE_CONTEXT_TEMPLATE: &str = concat!(
    "The below code snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);

/// Template used when an entire file is embedded as one document.
///
/// Uses the same `<path>`, `<language>` and `<item>` placeholders as
/// `CODE_CONTEXT_TEMPLATE`.
const ENTIRE_FILE_TEMPLATE: &str = concat!(
    "The below snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);

/// Template used for Markdown files; the content is inserted without a code fence.
///
/// Placeholders: `<path>` and `<item>`.
const MARKDOWN_CONTEXT_TEMPLATE: &str = concat!(
    "The below file contents is from file '<path>'\n\n",
    "<item>",
);

/// Language names whose files are embedded whole (as a single document) rather
/// than being split into items via an embedding query.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML",
];
27
/// Extracts embeddable [`Document`]s from file contents using tree-sitter.
///
/// Holds the parser and query cursor that the parsing methods operate on.
pub struct CodeContextRetriever {
    /// Tree-sitter parser; its language is set per call in `get_matches_in_file`.
    pub parser: Parser,
    /// Query cursor used to run a grammar's embedding query over the parsed tree.
    pub cursor: QueryCursor,
}
32
// Every match has an item; this represents the fundamental tree-sitter symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track them with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
/// The captures collected from a single embedding-query match.
#[derive(Clone, Debug)]
pub struct CodeContextMatch {
    /// Column at which the item capture starts (0 when the match has no item capture).
    pub start_col: usize,
    /// Byte range of the item capture, if the match has one.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, if the match has one.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of the context captures belonging to this match.
    pub context_ranges: Vec<Range<usize>>,
    /// Byte ranges to collapse: the collapse captures with the keep captures subtracted.
    pub collapse_ranges: Vec<Range<usize>>,
}
46
47impl CodeContextRetriever {
48 pub fn new() -> Self {
49 Self {
50 parser: Parser::new(),
51 cursor: QueryCursor::new(),
52 }
53 }
54
55 fn parse_entire_file(
56 &self,
57 relative_path: &Path,
58 language_name: Arc<str>,
59 content: &str,
60 ) -> Result<Vec<Document>> {
61 let document_span = ENTIRE_FILE_TEMPLATE
62 .replace("<path>", relative_path.to_string_lossy().as_ref())
63 .replace("<language>", language_name.as_ref())
64 .replace("<item>", &content);
65
66 Ok(vec![Document {
67 range: 0..content.len(),
68 content: document_span,
69 embedding: Vec::new(),
70 name: language_name.to_string(),
71 }])
72 }
73
74 fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
75 let document_span = MARKDOWN_CONTEXT_TEMPLATE
76 .replace("<path>", relative_path.to_string_lossy().as_ref())
77 .replace("<item>", &content);
78
79 Ok(vec![Document {
80 range: 0..content.len(),
81 content: document_span,
82 embedding: Vec::new(),
83 name: "Markdown".to_string(),
84 }])
85 }
86
87 fn get_matches_in_file(
88 &mut self,
89 content: &str,
90 grammar: &Arc<Grammar>,
91 ) -> Result<Vec<CodeContextMatch>> {
92 let embedding_config = grammar
93 .embedding_config
94 .as_ref()
95 .ok_or_else(|| anyhow!("no embedding queries"))?;
96 self.parser.set_language(grammar.ts_language).unwrap();
97
98 let tree = self
99 .parser
100 .parse(&content, None)
101 .ok_or_else(|| anyhow!("parsing failed"))?;
102
103 let mut captures: Vec<CodeContextMatch> = Vec::new();
104 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
105 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
106 for mat in self.cursor.matches(
107 &embedding_config.query,
108 tree.root_node(),
109 content.as_bytes(),
110 ) {
111 let mut start_col = 0;
112 let mut item_range: Option<Range<usize>> = None;
113 let mut name_range: Option<Range<usize>> = None;
114 let mut context_ranges: Vec<Range<usize>> = Vec::new();
115 collapse_ranges.clear();
116 keep_ranges.clear();
117 for capture in mat.captures {
118 if capture.index == embedding_config.item_capture_ix {
119 item_range = Some(capture.node.byte_range());
120 start_col = capture.node.start_position().column;
121 } else if Some(capture.index) == embedding_config.name_capture_ix {
122 name_range = Some(capture.node.byte_range());
123 } else if Some(capture.index) == embedding_config.context_capture_ix {
124 context_ranges.push(capture.node.byte_range());
125 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
126 collapse_ranges.push(capture.node.byte_range());
127 } else if Some(capture.index) == embedding_config.keep_capture_ix {
128 keep_ranges.push(capture.node.byte_range());
129 }
130 }
131
132 captures.push(CodeContextMatch {
133 start_col,
134 item_range,
135 name_range,
136 context_ranges,
137 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
138 });
139 }
140 Ok(captures)
141 }
142
143 pub fn parse_file_with_template(
144 &mut self,
145 relative_path: &Path,
146 content: &str,
147 language: Arc<Language>,
148 ) -> Result<Vec<Document>> {
149 let language_name = language.name();
150
151 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
152 return self.parse_entire_file(relative_path, language_name, &content);
153 } else if &language_name.to_string() == &"Markdown".to_string() {
154 return self.parse_markdown_file(relative_path, &content);
155 }
156
157 let mut documents = self.parse_file(content, language)?;
158 for document in &mut documents {
159 document.content = CODE_CONTEXT_TEMPLATE
160 .replace("<path>", relative_path.to_string_lossy().as_ref())
161 .replace("<language>", language_name.as_ref())
162 .replace("item", &document.content);
163 }
164 Ok(documents)
165 }
166
167 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
168 let grammar = language
169 .grammar()
170 .ok_or_else(|| anyhow!("no grammar for language"))?;
171
172 // Iterate through query matches
173 let matches = self.get_matches_in_file(content, grammar)?;
174
175 let language_scope = language.default_scope();
176 let placeholder = language_scope.collapsed_placeholder();
177
178 let mut documents = Vec::new();
179 let mut collapsed_ranges_within = Vec::new();
180 let mut parsed_name_ranges = HashSet::new();
181 for (i, context_match) in matches.iter().enumerate() {
182 // Items which are collapsible but not embeddable have no item range
183 let item_range = if let Some(item_range) = context_match.item_range.clone() {
184 item_range
185 } else {
186 continue;
187 };
188
189 // Checks for deduplication
190 let name;
191 if let Some(name_range) = context_match.name_range.clone() {
192 name = content
193 .get(name_range.clone())
194 .map_or(String::new(), |s| s.to_string());
195 if parsed_name_ranges.contains(&name_range) {
196 continue;
197 }
198 parsed_name_ranges.insert(name_range);
199 } else {
200 name = String::new();
201 }
202
203 collapsed_ranges_within.clear();
204 'outer: for remaining_match in &matches[(i + 1)..] {
205 for collapsed_range in &remaining_match.collapse_ranges {
206 if item_range.start <= collapsed_range.start
207 && item_range.end >= collapsed_range.end
208 {
209 collapsed_ranges_within.push(collapsed_range.clone());
210 } else {
211 break 'outer;
212 }
213 }
214 }
215
216 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
217
218 let mut document_content = String::new();
219 for context_range in &context_match.context_ranges {
220 add_content_from_range(
221 &mut document_content,
222 content,
223 context_range.clone(),
224 context_match.start_col,
225 );
226 document_content.push_str("\n");
227 }
228
229 let mut offset = item_range.start;
230 for collapsed_range in &collapsed_ranges_within {
231 if collapsed_range.start > offset {
232 add_content_from_range(
233 &mut document_content,
234 content,
235 offset..collapsed_range.start,
236 context_match.start_col,
237 );
238 offset = collapsed_range.start;
239 }
240
241 if collapsed_range.end > offset {
242 document_content.push_str(placeholder);
243 offset = collapsed_range.end;
244 }
245 }
246
247 if offset < item_range.end {
248 add_content_from_range(
249 &mut document_content,
250 content,
251 offset..item_range.end,
252 context_match.start_col,
253 );
254 }
255
256 documents.push(Document {
257 name,
258 content: document_content,
259 range: item_range.clone(),
260 embedding: vec![],
261 })
262 }
263
264 return Ok(documents);
265 }
266}
267
/// Removes every part of `ranges` that is covered by `ranges_to_subtract`.
///
/// Both slices are expected to be sorted by position. The result contains the
/// remaining pieces of `ranges`, in order; a range that is fully covered
/// contributes nothing. A subtraction range that extends past the end of one
/// range is kept and applied to the following ranges as well.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            // Copy the bounds out so the iterator can be advanced inside the match.
            let next_subtraction = ranges_to_subtract.peek().map(|r| (r.start, r.end));
            match next_subtraction {
                // Lies entirely before the current position: it cannot affect this or
                // any later range, so drop it. (Clamping `offset` to its end, as was
                // done previously, would move `offset` backwards.)
                Some((_, end)) if end <= offset => {
                    ranges_to_subtract.next();
                }
                // Overlaps or lies inside the rest of this range.
                Some((start, end)) if start < range.end => {
                    if offset < start {
                        result.push(offset..start);
                    }
                    offset = cmp::min(end, range.end);
                    // Keep it for the next range if it extends beyond this one.
                    if end <= range.end {
                        ranges_to_subtract.next();
                    }
                }
                // Nothing left to subtract from this range.
                _ => {
                    result.push(offset..range.end);
                    offset = range.end;
                }
            }
        }
    }

    result
}
302
/// Appends the text of `content[range]` to `output`, removing up to `start_col`
/// leading spaces from every line.
///
/// Lines are joined with `'\n'` and no trailing newline is appended. If `range`
/// is out of bounds, does not fall on character boundaries, or selects no lines,
/// `output` is left untouched.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let snippet = content.get(range).unwrap_or("");
    for (index, line) in snippet.lines().enumerate() {
        // Separate lines instead of terminating them. Removing a trailing newline
        // afterwards would delete a character of pre-existing output whenever the
        // snippet has no lines.
        if index > 0 {
            output.push('\n');
        }
        // Count leading spaces within the first `start_col` bytes. Spaces are ASCII,
        // so slicing at that count is always on a character boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}