1use anyhow::{anyhow, Ok, Result};
2use language::{Grammar, Language};
3use std::{
4 cmp::{self, Reverse},
5 collections::HashSet,
6 ops::Range,
7 path::Path,
8 sync::Arc,
9};
10use tree_sitter::{Parser, QueryCursor};
11
/// A chunk of a source file prepared for embedding.
#[derive(Clone, Debug, PartialEq)]
pub struct Document {
    /// Display name of the chunk (may be empty).
    pub name: String,
    /// Byte range of the chunk within the file it was extracted from.
    pub range: Range<usize>,
    /// Text of the chunk.
    pub content: String,
    /// Embedding vector for `content`.
    pub embedding: Vec<f32>,
}
19
/// Template wrapped around a single code item; contains the `<path>`, `<language>` and
/// `<item>` placeholders.
const CODE_CONTEXT_TEMPLATE: &str = concat!(
    "The below code snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);

/// Template wrapped around a whole file; contains the `<path>`, `<language>` and `<item>`
/// placeholders.
const ENTIRE_FILE_TEMPLATE: &str = concat!(
    "The below snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);

/// Names of the languages whose files are embedded as a single whole-file document.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
25
26pub struct CodeContextRetriever {
27 pub parser: Parser,
28 pub cursor: QueryCursor,
29}
30
/// The captures collected from one match of a grammar's embedding query.
///
/// * The item capture is the fundamental tree-sitter symbol and anchors the search.
/// * The name capture gives the display range of the item and is used for deduplication.
/// * Context captures track comments preceding the item.
/// * Collapse captures mark pieces that should be collapsed in hierarchical queries.
/// * Keep captures mark pieces that should stay visible inside a collapsed node.
#[derive(Clone, Debug)]
pub struct CodeContextMatch {
    /// Column at which the item capture starts (0 when there is no item capture).
    pub start_col: usize,
    /// Byte range of the item capture, if the match has one.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, if the match has one.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of the context captures.
    pub context_ranges: Vec<Range<usize>>,
    /// Byte ranges to collapse, i.e. the collapse captures minus the keep captures.
    pub collapse_ranges: Vec<Range<usize>>,
}
44
45impl CodeContextRetriever {
46 pub fn new() -> Self {
47 Self {
48 parser: Parser::new(),
49 cursor: QueryCursor::new(),
50 }
51 }
52
53 fn parse_entire_file(
54 &self,
55 relative_path: &Path,
56 language_name: Arc<str>,
57 content: &str,
58 ) -> Result<Vec<Document>> {
59 let document_span = ENTIRE_FILE_TEMPLATE
60 .replace("<path>", relative_path.to_string_lossy().as_ref())
61 .replace("<language>", language_name.as_ref())
62 .replace("item", &content);
63
64 Ok(vec![Document {
65 range: 0..content.len(),
66 content: document_span,
67 embedding: Vec::new(),
68 name: language_name.to_string(),
69 }])
70 }
71
72 fn get_matches_in_file(
73 &mut self,
74 content: &str,
75 grammar: &Arc<Grammar>,
76 ) -> Result<Vec<CodeContextMatch>> {
77 let embedding_config = grammar
78 .embedding_config
79 .as_ref()
80 .ok_or_else(|| anyhow!("no embedding queries"))?;
81 self.parser.set_language(grammar.ts_language).unwrap();
82
83 let tree = self
84 .parser
85 .parse(&content, None)
86 .ok_or_else(|| anyhow!("parsing failed"))?;
87
88 let mut captures: Vec<CodeContextMatch> = Vec::new();
89 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
90 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
91 for mat in self.cursor.matches(
92 &embedding_config.query,
93 tree.root_node(),
94 content.as_bytes(),
95 ) {
96 let mut start_col = 0;
97 let mut item_range: Option<Range<usize>> = None;
98 let mut name_range: Option<Range<usize>> = None;
99 let mut context_ranges: Vec<Range<usize>> = Vec::new();
100 collapse_ranges.clear();
101 keep_ranges.clear();
102 for capture in mat.captures {
103 if capture.index == embedding_config.item_capture_ix {
104 item_range = Some(capture.node.byte_range());
105 start_col = capture.node.start_position().column;
106 } else if Some(capture.index) == embedding_config.name_capture_ix {
107 name_range = Some(capture.node.byte_range());
108 } else if Some(capture.index) == embedding_config.context_capture_ix {
109 context_ranges.push(capture.node.byte_range());
110 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
111 collapse_ranges.push(capture.node.byte_range());
112 } else if Some(capture.index) == embedding_config.keep_capture_ix {
113 keep_ranges.push(capture.node.byte_range());
114 }
115 }
116
117 captures.push(CodeContextMatch {
118 start_col,
119 item_range,
120 name_range,
121 context_ranges,
122 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
123 });
124 }
125 Ok(captures)
126 }
127
128 pub fn parse_file_with_template(
129 &mut self,
130 relative_path: &Path,
131 content: &str,
132 language: Arc<Language>,
133 ) -> Result<Vec<Document>> {
134 let language_name = language.name();
135
136 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
137 return self.parse_entire_file(relative_path, language_name, &content);
138 }
139
140 let mut documents = self.parse_file(content, language)?;
141 for document in &mut documents {
142 document.content = CODE_CONTEXT_TEMPLATE
143 .replace("<path>", relative_path.to_string_lossy().as_ref())
144 .replace("<language>", language_name.as_ref())
145 .replace("item", &document.content);
146 }
147 Ok(documents)
148 }
149
150 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
151 let grammar = language
152 .grammar()
153 .ok_or_else(|| anyhow!("no grammar for language"))?;
154
155 // Iterate through query matches
156 let matches = self.get_matches_in_file(content, grammar)?;
157
158 let language_scope = language.default_scope();
159 let placeholder = language_scope.collapsed_placeholder();
160
161 let mut documents = Vec::new();
162 let mut collapsed_ranges_within = Vec::new();
163 let mut parsed_name_ranges = HashSet::new();
164 for (i, context_match) in matches.iter().enumerate() {
165 // Items which are collapsible but not embeddable have no item range
166 let item_range = if let Some(item_range) = context_match.item_range.clone() {
167 item_range
168 } else {
169 continue;
170 };
171
172 // Checks for deduplication
173 let name;
174 if let Some(name_range) = context_match.name_range.clone() {
175 name = content
176 .get(name_range.clone())
177 .map_or(String::new(), |s| s.to_string());
178 if parsed_name_ranges.contains(&name_range) {
179 continue;
180 }
181 parsed_name_ranges.insert(name_range);
182 } else {
183 name = String::new();
184 }
185
186 collapsed_ranges_within.clear();
187 'outer: for remaining_match in &matches[(i + 1)..] {
188 for collapsed_range in &remaining_match.collapse_ranges {
189 if item_range.start <= collapsed_range.start
190 && item_range.end >= collapsed_range.end
191 {
192 collapsed_ranges_within.push(collapsed_range.clone());
193 } else {
194 break 'outer;
195 }
196 }
197 }
198
199 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
200
201 let mut document_content = String::new();
202 for context_range in &context_match.context_ranges {
203 document_content.push_str(&content[context_range.clone()]);
204 document_content.push_str("\n");
205 }
206
207 let mut offset = item_range.start;
208 for collapsed_range in &collapsed_ranges_within {
209 if collapsed_range.start > offset {
210 add_content_from_range(
211 &mut document_content,
212 content,
213 offset..collapsed_range.start,
214 context_match.start_col,
215 );
216 offset = collapsed_range.start;
217 }
218
219 if collapsed_range.end > offset {
220 document_content.push_str(placeholder);
221 offset = collapsed_range.end;
222 }
223 }
224
225 if offset < item_range.end {
226 add_content_from_range(
227 &mut document_content,
228 content,
229 offset..item_range.end,
230 context_match.start_col,
231 );
232 }
233
234 documents.push(Document {
235 name,
236 content: document_content,
237 range: item_range.clone(),
238 embedding: vec![],
239 })
240 }
241
242 return Ok(documents);
243 }
244}
245
/// Removes every part of `ranges` that is covered by `ranges_to_subtract`.
///
/// Both slices are walked once, front to back, so each is expected to be sorted by start
/// and non-overlapping. The result holds the remaining non-empty pieces of `ranges`, in
/// order; a range that is only partially covered is split around the covered part.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();

    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;

        while offset < range.end {
            match ranges_to_subtract.peek() {
                // Nothing left to subtract: the rest of the range survives.
                None => {
                    result.push(offset..range.end);
                    offset = range.end;
                }
                Some(range_to_subtract) => {
                    let (subtract_start, subtract_end) =
                        (range_to_subtract.start, range_to_subtract.end);

                    if offset < subtract_start {
                        // Keep everything up to where the subtracted range begins.
                        let next_offset = cmp::min(subtract_start, range.end);
                        result.push(offset..next_offset);
                        offset = next_offset;
                    } else {
                        // Skip the covered part. `max` keeps the offset from moving
                        // backwards when the subtracted range ends before it.
                        offset = cmp::max(offset, cmp::min(subtract_end, range.end));
                    }

                    // Only advance once the subtracted range is fully behind us; it may
                    // still overlap the next input range otherwise.
                    if offset >= subtract_end {
                        ranges_to_subtract.next();
                    }
                }
            }
        }
    }

    result
}
280
/// Appends `content[range]` to `output`, removing up to `start_col` leading spaces from
/// every line.
///
/// Lines are joined with `\n` and no newline is written after the last one. If `range`
/// is empty, out of bounds, or not on character boundaries, `output` is left untouched.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let text = content.get(range).unwrap_or("");

    for (index, line) in text.lines().enumerate() {
        // Write separators between lines rather than trimming a trailing newline
        // afterwards, so nothing already in `output` is removed when there are no lines.
        if index > 0 {
            output.push('\n');
        }

        // Spaces are single bytes, so slicing at `indent` stays on a char boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[indent..]);
    }
}