1use crate::embedding::{Embedding, EmbeddingProvider};
2use anyhow::{anyhow, Result};
3use language::{Grammar, Language};
4use rusqlite::{
5 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
6 ToSql,
7};
8use sha1::{Digest, Sha1};
9use std::{
10 cmp::{self, Reverse},
11 collections::HashSet,
12 ops::Range,
13 path::Path,
14 sync::Arc,
15};
16use tree_sitter::{Parser, QueryCursor};
17
/// A 20-byte SHA-1 digest of a span's rendered content, used for
/// deduplication and as a stable database key.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]);
20
21impl FromSql for SpanDigest {
22 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
23 let blob = value.as_blob()?;
24 let bytes =
25 blob.try_into()
26 .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
27 expected_size: 20,
28 blob_size: blob.len(),
29 })?;
30 return Ok(SpanDigest(bytes));
31 }
32}
33
impl ToSql for SpanDigest {
    /// Stores the digest's 20 raw bytes in SQLite (delegates to the byte
    /// array's own `ToSql` impl; read back by the `FromSql` impl above).
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}
39
40impl From<&'_ str> for SpanDigest {
41 fn from(value: &'_ str) -> Self {
42 let mut sha1 = Sha1::new();
43 sha1.update(value);
44 Self(sha1.finalize().into())
45 }
46}
47
/// A chunk of source text extracted for embedding, along with its location
/// in the original file and embedding bookkeeping.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    // Display name from the item's name capture (empty when absent); the
    // language/file-type name for whole-file spans.
    pub name: String,
    // Byte range of the item within the original file.
    pub range: Range<usize>,
    // Rendered (template-expanded, possibly truncated) text to embed.
    pub content: String,
    // Filled in once an embedding has been computed; `None` until then.
    pub embedding: Option<Embedding>,
    // SHA-1 digest of the rendered content, for caching/deduplication.
    pub digest: SpanDigest,
    // Token count of `content` as reported by the embedding provider
    // (0 for spans produced by `parse_file` before template rendering).
    pub token_count: usize,
}
57
// Template for individual code spans; `<path>`, `<language>`, and `<item>`
// are substituted before embedding.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template used when an entire file is embedded as a single span.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template for Markdown files, whose contents are embedded verbatim.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
// Language names whose files are always embedded whole rather than split
// into per-item spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
65
/// Extracts embeddable [`Span`]s from source files by running each
/// language's tree-sitter embedding query.
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
71
// Every match has an item; this represents the fundamental tree-sitter
// symbol and anchors the search.
// Every match has one or more 'name' captures. These indicate the display
// range of the item for deduplication.
// If there are preceding comments, we track this with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we
// capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we
// capture it with a keep capture.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    // Column at which the item starts; used later to strip indentation.
    pub start_col: usize,
    // Byte range of the item capture; `None` for collapse-only matches.
    pub item_range: Option<Range<usize>>,
    // Byte range of the name capture, when present.
    pub name_range: Option<Range<usize>>,
    // Byte ranges of preceding context (e.g. comment) captures.
    pub context_ranges: Vec<Range<usize>>,
    // Collapse ranges with any keep ranges already subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
85
86impl CodeContextRetriever {
87 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
88 Self {
89 parser: Parser::new(),
90 cursor: QueryCursor::new(),
91 embedding_provider,
92 }
93 }
94
95 fn parse_entire_file(
96 &self,
97 relative_path: &Path,
98 language_name: Arc<str>,
99 content: &str,
100 ) -> Result<Vec<Span>> {
101 let document_span = ENTIRE_FILE_TEMPLATE
102 .replace("<path>", relative_path.to_string_lossy().as_ref())
103 .replace("<language>", language_name.as_ref())
104 .replace("<item>", &content);
105 let digest = SpanDigest::from(document_span.as_str());
106 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
107 Ok(vec![Span {
108 range: 0..content.len(),
109 content: document_span,
110 embedding: Default::default(),
111 name: language_name.to_string(),
112 digest,
113 token_count,
114 }])
115 }
116
117 fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
118 let document_span = MARKDOWN_CONTEXT_TEMPLATE
119 .replace("<path>", relative_path.to_string_lossy().as_ref())
120 .replace("<item>", &content);
121 let digest = SpanDigest::from(document_span.as_str());
122 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
123 Ok(vec![Span {
124 range: 0..content.len(),
125 content: document_span,
126 embedding: None,
127 name: "Markdown".to_string(),
128 digest,
129 token_count,
130 }])
131 }
132
    /// Runs the grammar's embedding query over `content` and collects, for
    /// each tree-sitter match, the byte ranges of its item, name, context,
    /// collapse, and keep captures.
    ///
    /// Collapse ranges are returned with the keep ranges already subtracted,
    /// so callers can replace them wholesale with a placeholder.
    ///
    /// Errors if the grammar has no embedding query or if parsing fails.
    fn get_matches_in_file(
        &mut self,
        content: &str,
        grammar: &Arc<Grammar>,
    ) -> Result<Vec<CodeContextMatch>> {
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding queries"))?;
        // NOTE(review): assumes `set_language` cannot fail for grammars
        // bundled with their queries — confirm version compatibility.
        self.parser.set_language(grammar.ts_language).unwrap();

        let tree = self
            .parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut captures: Vec<CodeContextMatch> = Vec::new();
        // Scratch buffers reused across matches to avoid per-match allocation.
        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
        for mat in self.cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            let mut start_col = 0;
            let mut item_range: Option<Range<usize>> = None;
            let mut name_range: Option<Range<usize>> = None;
            let mut context_ranges: Vec<Range<usize>> = Vec::new();
            collapse_ranges.clear();
            keep_ranges.clear();
            // Sort each capture in this match into its bucket by capture index.
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                    // Record the item's starting column so its indentation
                    // can be stripped when the span text is assembled.
                    start_col = capture.node.start_position().column;
                } else if Some(capture.index) == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.context_capture_ix {
                    context_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
                    collapse_ranges.push(capture.node.byte_range());
                } else if Some(capture.index) == embedding_config.keep_capture_ix {
                    keep_ranges.push(capture.node.byte_range());
                }
            }

            captures.push(CodeContextMatch {
                start_col,
                item_range,
                name_range,
                context_ranges,
                // Keep-captures punch holes in the collapsed regions.
                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
            });
        }
        Ok(captures)
    }
188
189 pub fn parse_file_with_template(
190 &mut self,
191 relative_path: &Path,
192 content: &str,
193 language: Arc<Language>,
194 ) -> Result<Vec<Span>> {
195 let language_name = language.name();
196
197 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
198 return self.parse_entire_file(relative_path, language_name, &content);
199 } else if language_name.as_ref() == "Markdown" {
200 return self.parse_markdown_file(relative_path, &content);
201 }
202
203 let mut spans = self.parse_file(content, language)?;
204 for span in &mut spans {
205 let document_content = CODE_CONTEXT_TEMPLATE
206 .replace("<path>", relative_path.to_string_lossy().as_ref())
207 .replace("<language>", language_name.as_ref())
208 .replace("item", &span.content);
209
210 let (document_content, token_count) =
211 self.embedding_provider.truncate(&document_content);
212
213 span.content = document_content;
214 span.token_count = token_count;
215 }
216 Ok(spans)
217 }
218
219 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
220 let grammar = language
221 .grammar()
222 .ok_or_else(|| anyhow!("no grammar for language"))?;
223
224 // Iterate through query matches
225 let matches = self.get_matches_in_file(content, grammar)?;
226
227 let language_scope = language.default_scope();
228 let placeholder = language_scope.collapsed_placeholder();
229
230 let mut spans = Vec::new();
231 let mut collapsed_ranges_within = Vec::new();
232 let mut parsed_name_ranges = HashSet::new();
233 for (i, context_match) in matches.iter().enumerate() {
234 // Items which are collapsible but not embeddable have no item range
235 let item_range = if let Some(item_range) = context_match.item_range.clone() {
236 item_range
237 } else {
238 continue;
239 };
240
241 // Checks for deduplication
242 let name;
243 if let Some(name_range) = context_match.name_range.clone() {
244 name = content
245 .get(name_range.clone())
246 .map_or(String::new(), |s| s.to_string());
247 if parsed_name_ranges.contains(&name_range) {
248 continue;
249 }
250 parsed_name_ranges.insert(name_range);
251 } else {
252 name = String::new();
253 }
254
255 collapsed_ranges_within.clear();
256 'outer: for remaining_match in &matches[(i + 1)..] {
257 for collapsed_range in &remaining_match.collapse_ranges {
258 if item_range.start <= collapsed_range.start
259 && item_range.end >= collapsed_range.end
260 {
261 collapsed_ranges_within.push(collapsed_range.clone());
262 } else {
263 break 'outer;
264 }
265 }
266 }
267
268 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
269
270 let mut span_content = String::new();
271 for context_range in &context_match.context_ranges {
272 add_content_from_range(
273 &mut span_content,
274 content,
275 context_range.clone(),
276 context_match.start_col,
277 );
278 span_content.push_str("\n");
279 }
280
281 let mut offset = item_range.start;
282 for collapsed_range in &collapsed_ranges_within {
283 if collapsed_range.start > offset {
284 add_content_from_range(
285 &mut span_content,
286 content,
287 offset..collapsed_range.start,
288 context_match.start_col,
289 );
290 offset = collapsed_range.start;
291 }
292
293 if collapsed_range.end > offset {
294 span_content.push_str(placeholder);
295 offset = collapsed_range.end;
296 }
297 }
298
299 if offset < item_range.end {
300 add_content_from_range(
301 &mut span_content,
302 content,
303 offset..item_range.end,
304 context_match.start_col,
305 );
306 }
307
308 let sha1 = SpanDigest::from(span_content.as_str());
309 spans.push(Span {
310 name,
311 content: span_content,
312 range: item_range.clone(),
313 embedding: None,
314 digest: sha1,
315 token_count: 0,
316 })
317 }
318
319 return Ok(spans);
320 }
321}
322
/// Removes every region covered by `ranges_to_subtract` from `ranges`,
/// returning the surviving pieces in order.
///
/// Both inputs are consumed front-to-back, so they are expected in ascending
/// order; a subtracted range is discarded once the cursor passes its end.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut kept = Vec::new();
    let mut holes = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut cursor = range.start;

        while cursor < range.end {
            let next_hole = holes.peek().copied();
            match next_hole {
                None => {
                    // No holes remain; keep the rest of this range whole.
                    kept.push(cursor..range.end);
                    cursor = range.end;
                }
                Some(hole) => {
                    if cursor < hole.start {
                        // Keep the stretch up to the hole (or range's end).
                        let stop = cmp::min(hole.start, range.end);
                        kept.push(cursor..stop);
                        cursor = stop;
                    } else {
                        // Inside the hole: skip past it, clamped to range end.
                        cursor = cmp::min(hole.end, range.end);
                    }

                    // Once the hole is fully behind the cursor, discard it.
                    if cursor >= hole.end {
                        holes.next();
                    }
                }
            }
        }
    }

    kept
}
357
/// Appends the lines of `content[range]` to `output`, removing up to
/// `start_col` leading spaces from each line, then pops the final character
/// (the trailing newline appended after the last line).
///
/// Note: if the range yields no lines at all, the pop removes the last
/// character already present in `output` — preserved original behavior.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    // Out-of-bounds or non-char-boundary ranges degrade to an empty snippet.
    let snippet = content.get(range).unwrap_or("");
    for line in snippet.lines() {
        let mut rest = line;
        let mut stripped = 0;
        // De-indent by at most `start_col` spaces; stop early at the first
        // non-space character.
        while stripped < start_col {
            if let Some(unindented) = rest.strip_prefix(' ') {
                rest = unindented;
                stripped += 1;
            } else {
                break;
            }
        }
        output.push_str(rest);
        output.push('\n');
    }
    output.pop();
}