1use ai::embedding::{Embedding, EmbeddingProvider};
2use anyhow::{anyhow, Result};
3use language::{Grammar, Language};
4use rusqlite::{
5 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
6 ToSql,
7};
8use sha1::{Digest, Sha1};
9use std::{
10 borrow::Cow,
11 cmp::{self, Reverse},
12 collections::HashSet,
13 ops::Range,
14 path::Path,
15 sync::Arc,
16};
17use tree_sitter::{Parser, QueryCursor};
18
/// SHA-1 digest (20 bytes) of a span's text. Stored as a SQLite blob and
/// used to recognize unchanged spans across indexing runs.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
21
22impl FromSql for SpanDigest {
23 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
24 let blob = value.as_blob()?;
25 let bytes =
26 blob.try_into()
27 .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
28 expected_size: 20,
29 blob_size: blob.len(),
30 })?;
31 return Ok(SpanDigest(bytes));
32 }
33}
34
35impl ToSql for SpanDigest {
36 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
37 self.0.to_sql()
38 }
39}
40
41impl From<&'_ str> for SpanDigest {
42 fn from(value: &'_ str) -> Self {
43 let mut sha1 = Sha1::new();
44 sha1.update(value);
45 Self(sha1.finalize().into())
46 }
47}
48
/// A chunk of a source file prepared for embedding, plus the metadata
/// needed to store, deduplicate, and truncate it.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    // Display name from the query's `name` capture; empty when absent.
    pub name: String,
    // Byte range of the item within the original file.
    pub range: Range<usize>,
    // The span's text (after templating/truncation, where applied).
    pub content: String,
    // Embedding vector; `None` until computed elsewhere.
    pub embedding: Option<Embedding>,
    // SHA-1 digest of the span text, used to skip unchanged spans.
    pub digest: SpanDigest,
    // Token count reported by the embedding provider (0 until templated).
    pub token_count: usize,
}
58
// Template wrapping an individual code span: file path + fenced code block.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Template used when an entire file becomes a single span.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Markdown / plain-text files are embedded without a code fence.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
// Languages whose files are always embedded whole rather than split into
// per-item spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
    &["TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML"];
66
/// Parses source files with tree-sitter and produces [`Span`]s that are
/// templated and truncated for the configured embedding provider.
pub struct CodeContextRetriever {
    pub parser: Parser,
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
72
// Every match has an item; this represents the fundamental tree-sitter symbol
// and anchors the search.
// Every match has one or more 'name' captures. These indicate the display
// range of the item for deduplication.
// If there are preceding comments, we track this with a context capture.
// If there is a piece that should be collapsed in hierarchical queries, we
// capture it with a collapse capture.
// If there is a piece that should be kept inside a collapsed node, we capture
// it with a keep capture.
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    // Column of the item's first character; used to de-indent its lines.
    pub start_col: usize,
    // Byte range of the item itself (`None` for collapse-only matches).
    pub item_range: Option<Range<usize>>,
    // Byte range of the item's name, used for deduplication.
    pub name_range: Option<Range<usize>>,
    // Byte ranges of preceding context (e.g. comments) to prepend.
    pub context_ranges: Vec<Range<usize>>,
    // Collapse ranges with any keep ranges subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
86
87impl CodeContextRetriever {
88 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
89 Self {
90 parser: Parser::new(),
91 cursor: QueryCursor::new(),
92 embedding_provider,
93 }
94 }
95
96 fn parse_entire_file(
97 &self,
98 relative_path: Option<&Path>,
99 language_name: Arc<str>,
100 content: &str,
101 ) -> Result<Vec<Span>> {
102 let document_span = ENTIRE_FILE_TEMPLATE
103 .replace(
104 "<path>",
105 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
106 )
107 .replace("<language>", language_name.as_ref())
108 .replace("<item>", &content);
109 let digest = SpanDigest::from(document_span.as_str());
110 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
111 Ok(vec![Span {
112 range: 0..content.len(),
113 content: document_span,
114 embedding: Default::default(),
115 name: language_name.to_string(),
116 digest,
117 token_count,
118 }])
119 }
120
121 fn parse_markdown_file(
122 &self,
123 relative_path: Option<&Path>,
124 content: &str,
125 ) -> Result<Vec<Span>> {
126 let document_span = MARKDOWN_CONTEXT_TEMPLATE
127 .replace(
128 "<path>",
129 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
130 )
131 .replace("<item>", &content);
132 let digest = SpanDigest::from(document_span.as_str());
133 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
134 Ok(vec![Span {
135 range: 0..content.len(),
136 content: document_span,
137 embedding: None,
138 name: "Markdown".to_string(),
139 digest,
140 token_count,
141 }])
142 }
143
144 fn get_matches_in_file(
145 &mut self,
146 content: &str,
147 grammar: &Arc<Grammar>,
148 ) -> Result<Vec<CodeContextMatch>> {
149 let embedding_config = grammar
150 .embedding_config
151 .as_ref()
152 .ok_or_else(|| anyhow!("no embedding queries"))?;
153 self.parser.set_language(grammar.ts_language).unwrap();
154
155 let tree = self
156 .parser
157 .parse(&content, None)
158 .ok_or_else(|| anyhow!("parsing failed"))?;
159
160 let mut captures: Vec<CodeContextMatch> = Vec::new();
161 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
162 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
163 for mat in self.cursor.matches(
164 &embedding_config.query,
165 tree.root_node(),
166 content.as_bytes(),
167 ) {
168 let mut start_col = 0;
169 let mut item_range: Option<Range<usize>> = None;
170 let mut name_range: Option<Range<usize>> = None;
171 let mut context_ranges: Vec<Range<usize>> = Vec::new();
172 collapse_ranges.clear();
173 keep_ranges.clear();
174 for capture in mat.captures {
175 if capture.index == embedding_config.item_capture_ix {
176 item_range = Some(capture.node.byte_range());
177 start_col = capture.node.start_position().column;
178 } else if Some(capture.index) == embedding_config.name_capture_ix {
179 name_range = Some(capture.node.byte_range());
180 } else if Some(capture.index) == embedding_config.context_capture_ix {
181 context_ranges.push(capture.node.byte_range());
182 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
183 collapse_ranges.push(capture.node.byte_range());
184 } else if Some(capture.index) == embedding_config.keep_capture_ix {
185 keep_ranges.push(capture.node.byte_range());
186 }
187 }
188
189 captures.push(CodeContextMatch {
190 start_col,
191 item_range,
192 name_range,
193 context_ranges,
194 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
195 });
196 }
197 Ok(captures)
198 }
199
200 pub fn parse_file_with_template(
201 &mut self,
202 relative_path: Option<&Path>,
203 content: &str,
204 language: Arc<Language>,
205 ) -> Result<Vec<Span>> {
206 let language_name = language.name();
207
208 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
209 return self.parse_entire_file(relative_path, language_name, &content);
210 } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
211 return self.parse_markdown_file(relative_path, &content);
212 }
213
214 let mut spans = self.parse_file(content, language)?;
215 for span in &mut spans {
216 let document_content = CODE_CONTEXT_TEMPLATE
217 .replace(
218 "<path>",
219 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
220 )
221 .replace("<language>", language_name.as_ref())
222 .replace("item", &span.content);
223
224 let (document_content, token_count) =
225 self.embedding_provider.truncate(&document_content);
226
227 span.content = document_content;
228 span.token_count = token_count;
229 }
230 Ok(spans)
231 }
232
233 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
234 let grammar = language
235 .grammar()
236 .ok_or_else(|| anyhow!("no grammar for language"))?;
237
238 // Iterate through query matches
239 let matches = self.get_matches_in_file(content, grammar)?;
240
241 let language_scope = language.default_scope();
242 let placeholder = language_scope.collapsed_placeholder();
243
244 let mut spans = Vec::new();
245 let mut collapsed_ranges_within = Vec::new();
246 let mut parsed_name_ranges = HashSet::new();
247 for (i, context_match) in matches.iter().enumerate() {
248 // Items which are collapsible but not embeddable have no item range
249 let item_range = if let Some(item_range) = context_match.item_range.clone() {
250 item_range
251 } else {
252 continue;
253 };
254
255 // Checks for deduplication
256 let name;
257 if let Some(name_range) = context_match.name_range.clone() {
258 name = content
259 .get(name_range.clone())
260 .map_or(String::new(), |s| s.to_string());
261 if parsed_name_ranges.contains(&name_range) {
262 continue;
263 }
264 parsed_name_ranges.insert(name_range);
265 } else {
266 name = String::new();
267 }
268
269 collapsed_ranges_within.clear();
270 'outer: for remaining_match in &matches[(i + 1)..] {
271 for collapsed_range in &remaining_match.collapse_ranges {
272 if item_range.start <= collapsed_range.start
273 && item_range.end >= collapsed_range.end
274 {
275 collapsed_ranges_within.push(collapsed_range.clone());
276 } else {
277 break 'outer;
278 }
279 }
280 }
281
282 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
283
284 let mut span_content = String::new();
285 for context_range in &context_match.context_ranges {
286 add_content_from_range(
287 &mut span_content,
288 content,
289 context_range.clone(),
290 context_match.start_col,
291 );
292 span_content.push_str("\n");
293 }
294
295 let mut offset = item_range.start;
296 for collapsed_range in &collapsed_ranges_within {
297 if collapsed_range.start > offset {
298 add_content_from_range(
299 &mut span_content,
300 content,
301 offset..collapsed_range.start,
302 context_match.start_col,
303 );
304 offset = collapsed_range.start;
305 }
306
307 if collapsed_range.end > offset {
308 span_content.push_str(placeholder);
309 offset = collapsed_range.end;
310 }
311 }
312
313 if offset < item_range.end {
314 add_content_from_range(
315 &mut span_content,
316 content,
317 offset..item_range.end,
318 context_match.start_col,
319 );
320 }
321
322 let sha1 = SpanDigest::from(span_content.as_str());
323 spans.push(Span {
324 name,
325 content: span_content,
326 range: item_range.clone(),
327 embedding: None,
328 digest: sha1,
329 token_count: 0,
330 })
331 }
332
333 return Ok(spans);
334 }
335}
336
/// Removes every interval in `ranges_to_subtract` from `ranges`, returning
/// the surviving pieces. Both inputs are expected to be ordered by start.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut remaining = Vec::new();
    let mut cuts = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut cursor = range.start;

        while cursor < range.end {
            // Copy the next cut's bounds so we can advance the iterator below.
            let next_cut = cuts.peek().map(|cut| (cut.start, cut.end));
            match next_cut {
                Some((cut_start, cut_end)) => {
                    if cursor < cut_start {
                        // Emit the uncovered stretch before the cut begins.
                        let stop = cmp::min(cut_start, range.end);
                        remaining.push(cursor..stop);
                        cursor = stop;
                    } else {
                        // Skip over the covered stretch.
                        cursor = cmp::min(cut_end, range.end);
                    }
                    // Once we've passed the cut's end, it can't affect
                    // anything further; move on to the next cut.
                    if cursor >= cut_end {
                        cuts.next();
                    }
                }
                None => {
                    // No cuts left; the rest of this range survives intact.
                    remaining.push(cursor..range.end);
                    cursor = range.end;
                }
            }
        }
    }

    remaining
}
371
/// Appends `content[range]` to `output`, stripping up to `start_col` leading
/// spaces from each line to de-indent it. Always pops one trailing character
/// afterwards — normally the final newline this function itself appended.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let snippet = content.get(range).unwrap_or("");
    for line in snippet.lines() {
        // Count leading spaces, capped at `start_col`, and slice them off.
        // Spaces are single-byte, so this byte index is a char boundary.
        let indent = line
            .bytes()
            .take(start_col)
            .take_while(|&byte| byte == b' ')
            .count();
        output.push_str(&line[indent..]);
        output.push('\n');
    }
    output.pop();
}