1use crate::embedding::EmbeddingProvider;
2use anyhow::{anyhow, Ok, Result};
3use language::{Grammar, Language};
4use sha1::{Digest, Sha1};
5use std::{
6 cmp::{self, Reverse},
7 collections::HashSet,
8 ops::Range,
9 path::Path,
10 sync::Arc,
11};
12use tree_sitter::{Parser, QueryCursor};
13
/// A chunk of a source file prepared for embedding: the rendered text, its
/// byte range in the original file, and bookkeeping for deduplication and
/// token accounting.
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
    // Display name for the item (from the query's name capture); may be empty.
    pub name: String,
    // Byte range of the item within the original file content.
    pub range: Range<usize>,
    // Rendered text that is sent to the embedding provider.
    pub content: String,
    // Embedding vector; left empty here and filled in by the provider later.
    pub embedding: Vec<f32>,
    // SHA-1 digest of `content`, used to detect unchanged documents.
    pub sha1: [u8; 20],
    // Token count of `content` as reported by the embedding provider.
    pub token_count: usize,
}
23
// Template for a single parsed item extracted from a source file.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template used when a file is embedded in its entirety.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template for Markdown files, which are embedded whole without a code fence.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
// Languages embedded as one whole-file document instead of being split by an
// embedding query.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
31
/// Splits source files into embeddable documents using each language's
/// tree-sitter embedding query.
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
37
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    // Column of the item's first character; used to strip indentation from
    // extracted lines.
    pub start_col: usize,
    // Byte range of the item capture; `None` for matches that are collapsible
    // but not embeddable.
    pub item_range: Option<Range<usize>>,
    // Byte range of the name capture, used for deduplication.
    pub name_range: Option<Range<usize>>,
    // Byte ranges of preceding context (e.g. doc comments).
    pub context_ranges: Vec<Range<usize>>,
    // Collapsible byte ranges, with any keep-capture ranges subtracted.
    pub collapse_ranges: Vec<Range<usize>>,
}
51
52impl CodeContextRetriever {
53 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
54 Self {
55 parser: Parser::new(),
56 cursor: QueryCursor::new(),
57 embedding_provider,
58 }
59 }
60
61 fn parse_entire_file(
62 &self,
63 relative_path: &Path,
64 language_name: Arc<str>,
65 content: &str,
66 ) -> Result<Vec<Document>> {
67 let document_span = ENTIRE_FILE_TEMPLATE
68 .replace("<path>", relative_path.to_string_lossy().as_ref())
69 .replace("<language>", language_name.as_ref())
70 .replace("<item>", &content);
71
72 let mut sha1 = Sha1::new();
73 sha1.update(&document_span);
74
75 let token_count = self.embedding_provider.count_tokens(&document_span);
76
77 Ok(vec![Document {
78 range: 0..content.len(),
79 content: document_span,
80 embedding: Vec::new(),
81 name: language_name.to_string(),
82 sha1: sha1.finalize().into(),
83 token_count,
84 }])
85 }
86
87 fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
88 let document_span = MARKDOWN_CONTEXT_TEMPLATE
89 .replace("<path>", relative_path.to_string_lossy().as_ref())
90 .replace("<item>", &content);
91
92 let mut sha1 = Sha1::new();
93 sha1.update(&document_span);
94
95 let token_count = self.embedding_provider.count_tokens(&document_span);
96
97 Ok(vec![Document {
98 range: 0..content.len(),
99 content: document_span,
100 embedding: Vec::new(),
101 name: "Markdown".to_string(),
102 sha1: sha1.finalize().into(),
103 token_count,
104 }])
105 }
106
107 fn get_matches_in_file(
108 &mut self,
109 content: &str,
110 grammar: &Arc<Grammar>,
111 ) -> Result<Vec<CodeContextMatch>> {
112 let embedding_config = grammar
113 .embedding_config
114 .as_ref()
115 .ok_or_else(|| anyhow!("no embedding queries"))?;
116 self.parser.set_language(grammar.ts_language).unwrap();
117
118 let tree = self
119 .parser
120 .parse(&content, None)
121 .ok_or_else(|| anyhow!("parsing failed"))?;
122
123 let mut captures: Vec<CodeContextMatch> = Vec::new();
124 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
125 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
126 for mat in self.cursor.matches(
127 &embedding_config.query,
128 tree.root_node(),
129 content.as_bytes(),
130 ) {
131 let mut start_col = 0;
132 let mut item_range: Option<Range<usize>> = None;
133 let mut name_range: Option<Range<usize>> = None;
134 let mut context_ranges: Vec<Range<usize>> = Vec::new();
135 collapse_ranges.clear();
136 keep_ranges.clear();
137 for capture in mat.captures {
138 if capture.index == embedding_config.item_capture_ix {
139 item_range = Some(capture.node.byte_range());
140 start_col = capture.node.start_position().column;
141 } else if Some(capture.index) == embedding_config.name_capture_ix {
142 name_range = Some(capture.node.byte_range());
143 } else if Some(capture.index) == embedding_config.context_capture_ix {
144 context_ranges.push(capture.node.byte_range());
145 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
146 collapse_ranges.push(capture.node.byte_range());
147 } else if Some(capture.index) == embedding_config.keep_capture_ix {
148 keep_ranges.push(capture.node.byte_range());
149 }
150 }
151
152 captures.push(CodeContextMatch {
153 start_col,
154 item_range,
155 name_range,
156 context_ranges,
157 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
158 });
159 }
160 Ok(captures)
161 }
162
163 pub fn parse_file_with_template(
164 &mut self,
165 relative_path: &Path,
166 content: &str,
167 language: Arc<Language>,
168 ) -> Result<Vec<Document>> {
169 let language_name = language.name();
170
171 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
172 return self.parse_entire_file(relative_path, language_name, &content);
173 } else if &language_name.to_string() == &"Markdown".to_string() {
174 return self.parse_markdown_file(relative_path, &content);
175 }
176
177 let mut documents = self.parse_file(content, language)?;
178 for document in &mut documents {
179 let document_content = CODE_CONTEXT_TEMPLATE
180 .replace("<path>", relative_path.to_string_lossy().as_ref())
181 .replace("<language>", language_name.as_ref())
182 .replace("item", &document.content);
183
184 let token_count = self.embedding_provider.count_tokens(&document_content);
185 document.content = document_content;
186 document.token_count = token_count;
187 }
188 Ok(documents)
189 }
190
191 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
192 let grammar = language
193 .grammar()
194 .ok_or_else(|| anyhow!("no grammar for language"))?;
195
196 // Iterate through query matches
197 let matches = self.get_matches_in_file(content, grammar)?;
198
199 let language_scope = language.default_scope();
200 let placeholder = language_scope.collapsed_placeholder();
201
202 let mut documents = Vec::new();
203 let mut collapsed_ranges_within = Vec::new();
204 let mut parsed_name_ranges = HashSet::new();
205 for (i, context_match) in matches.iter().enumerate() {
206 // Items which are collapsible but not embeddable have no item range
207 let item_range = if let Some(item_range) = context_match.item_range.clone() {
208 item_range
209 } else {
210 continue;
211 };
212
213 // Checks for deduplication
214 let name;
215 if let Some(name_range) = context_match.name_range.clone() {
216 name = content
217 .get(name_range.clone())
218 .map_or(String::new(), |s| s.to_string());
219 if parsed_name_ranges.contains(&name_range) {
220 continue;
221 }
222 parsed_name_ranges.insert(name_range);
223 } else {
224 name = String::new();
225 }
226
227 collapsed_ranges_within.clear();
228 'outer: for remaining_match in &matches[(i + 1)..] {
229 for collapsed_range in &remaining_match.collapse_ranges {
230 if item_range.start <= collapsed_range.start
231 && item_range.end >= collapsed_range.end
232 {
233 collapsed_ranges_within.push(collapsed_range.clone());
234 } else {
235 break 'outer;
236 }
237 }
238 }
239
240 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
241
242 let mut document_content = String::new();
243 for context_range in &context_match.context_ranges {
244 add_content_from_range(
245 &mut document_content,
246 content,
247 context_range.clone(),
248 context_match.start_col,
249 );
250 document_content.push_str("\n");
251 }
252
253 let mut offset = item_range.start;
254 for collapsed_range in &collapsed_ranges_within {
255 if collapsed_range.start > offset {
256 add_content_from_range(
257 &mut document_content,
258 content,
259 offset..collapsed_range.start,
260 context_match.start_col,
261 );
262 offset = collapsed_range.start;
263 }
264
265 if collapsed_range.end > offset {
266 document_content.push_str(placeholder);
267 offset = collapsed_range.end;
268 }
269 }
270
271 if offset < item_range.end {
272 add_content_from_range(
273 &mut document_content,
274 content,
275 offset..item_range.end,
276 context_match.start_col,
277 );
278 }
279
280 let mut sha1 = Sha1::new();
281 sha1.update(&document_content);
282
283 documents.push(Document {
284 name,
285 content: document_content,
286 range: item_range.clone(),
287 embedding: vec![],
288 sha1: sha1.finalize().into(),
289 token_count: 0,
290 })
291 }
292
293 return Ok(documents);
294 }
295}
296
297pub(crate) fn subtract_ranges(
298 ranges: &[Range<usize>],
299 ranges_to_subtract: &[Range<usize>],
300) -> Vec<Range<usize>> {
301 let mut result = Vec::new();
302
303 let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
304
305 for range in ranges {
306 let mut offset = range.start;
307
308 while offset < range.end {
309 if let Some(range_to_subtract) = ranges_to_subtract.peek() {
310 if offset < range_to_subtract.start {
311 let next_offset = cmp::min(range_to_subtract.start, range.end);
312 result.push(offset..next_offset);
313 offset = next_offset;
314 } else {
315 let next_offset = cmp::min(range_to_subtract.end, range.end);
316 offset = next_offset;
317 }
318
319 if offset >= range_to_subtract.end {
320 ranges_to_subtract.next();
321 }
322 } else {
323 result.push(offset..range.end);
324 offset = range.end;
325 }
326 }
327 }
328
329 result
330}
331
/// Appends the lines of `content[range]` to `output`, removing up to
/// `start_col` leading spaces from each line (so an item captured at an
/// indented column is emitted flush with its own first line).
///
/// Lines are joined with `'\n'` and no trailing newline is appended. An
/// out-of-bounds or non-char-boundary `range` is treated as empty.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut wrote_line = false;
    for line in content.get(range).unwrap_or("").lines() {
        // Strip at most `start_col` leading spaces, stopping at the first
        // non-space byte so deeper-indented content keeps its shape.
        let strip = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[strip..]);
        output.push('\n');
        wrote_line = true;
    }
    // Drop the trailing newline — but only if we actually wrote one.
    // Previously an empty range unconditionally popped a character that
    // belonged to the caller's existing output.
    if wrote_line {
        output.pop();
    }
}