1use ai::embedding::{Embedding, EmbeddingProvider};
2use anyhow::{anyhow, Result};
3use language::{Grammar, Language};
4use rusqlite::{
5 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
6 ToSql,
7};
8use sha1::{Digest, Sha1};
9use std::{
10 borrow::Cow,
11 cmp::{self, Reverse},
12 collections::HashSet,
13 ops::Range,
14 path::Path,
15 sync::Arc,
16};
17use tree_sitter::{Parser, QueryCursor};
18
/// A 20-byte SHA-1 digest of a span's content, used for change detection,
/// deduplication, and as a key when persisting spans to the database.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
21
22impl FromSql for SpanDigest {
23 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
24 let blob = value.as_blob()?;
25 let bytes =
26 blob.try_into()
27 .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
28 expected_size: 20,
29 blob_size: blob.len(),
30 })?;
31 return Ok(SpanDigest(bytes));
32 }
33}
34
impl ToSql for SpanDigest {
    /// Stores the digest as a raw 20-byte BLOB by delegating to the
    /// inner array's `ToSql` implementation.
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}
40
41impl From<&'_ str> for SpanDigest {
42 fn from(value: &'_ str) -> Self {
43 let mut sha1 = Sha1::new();
44 sha1.update(value);
45 Self(sha1.finalize().into())
46 }
47}
48
/// A chunk of a source file selected for embedding.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    /// Display name of the item (e.g. a function name); empty when the
    /// match had no name capture.
    pub name: String,
    /// Byte range of the item within the original file.
    pub range: Range<usize>,
    /// Rendered text that will be embedded (templated and possibly
    /// truncated by the embedding provider).
    pub content: String,
    /// Populated later by the embedding pipeline; `None` until computed.
    pub embedding: Option<Embedding>,
    /// SHA-1 of `content`, used for deduplication and change detection.
    pub digest: SpanDigest,
    /// Token count of `content` as reported by the embedding provider.
    pub token_count: usize,
}
58
/// Template wrapping a single code span with its file path and language.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template wrapping a whole file's contents, for languages that are
/// indexed whole rather than split into per-item spans.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
/// Template for Markdown / plain-text files; content is embedded without
/// a fenced code block.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
/// Language names whose files are embedded whole via `ENTIRE_FILE_TEMPLATE`
/// instead of being parsed into per-item spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
67
/// Extracts embeddable spans from source files using per-language
/// tree-sitter embedding queries.
pub struct CodeContextRetriever {
    /// Reused tree-sitter parser.
    pub parser: Parser,
    /// Reused cursor for running embedding queries.
    pub cursor: QueryCursor,
    /// Provider used to truncate span text and count tokens.
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
73
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Starting column of the item node; used to strip its indentation.
    pub start_col: usize,
    /// Byte range of the item capture; `None` for collapse-only matches.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, used for deduplication.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of context captures (e.g. preceding comments).
    pub context_ranges: Vec<Range<usize>>,
    /// Collapse ranges with the keep ranges already subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
87
88impl CodeContextRetriever {
89 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
90 Self {
91 parser: Parser::new(),
92 cursor: QueryCursor::new(),
93 embedding_provider,
94 }
95 }
96
97 fn parse_entire_file(
98 &self,
99 relative_path: Option<&Path>,
100 language_name: Arc<str>,
101 content: &str,
102 ) -> Result<Vec<Span>> {
103 let document_span = ENTIRE_FILE_TEMPLATE
104 .replace(
105 "<path>",
106 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
107 )
108 .replace("<language>", language_name.as_ref())
109 .replace("<item>", &content);
110 let digest = SpanDigest::from(document_span.as_str());
111 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
112 Ok(vec![Span {
113 range: 0..content.len(),
114 content: document_span,
115 embedding: Default::default(),
116 name: language_name.to_string(),
117 digest,
118 token_count,
119 }])
120 }
121
122 fn parse_markdown_file(
123 &self,
124 relative_path: Option<&Path>,
125 content: &str,
126 ) -> Result<Vec<Span>> {
127 let document_span = MARKDOWN_CONTEXT_TEMPLATE
128 .replace(
129 "<path>",
130 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
131 )
132 .replace("<item>", &content);
133 let digest = SpanDigest::from(document_span.as_str());
134 let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
135 Ok(vec![Span {
136 range: 0..content.len(),
137 content: document_span,
138 embedding: None,
139 name: "Markdown".to_string(),
140 digest,
141 token_count,
142 }])
143 }
144
145 fn get_matches_in_file(
146 &mut self,
147 content: &str,
148 grammar: &Arc<Grammar>,
149 ) -> Result<Vec<CodeContextMatch>> {
150 let embedding_config = grammar
151 .embedding_config
152 .as_ref()
153 .ok_or_else(|| anyhow!("no embedding queries"))?;
154 self.parser.set_language(grammar.ts_language).unwrap();
155
156 let tree = self
157 .parser
158 .parse(&content, None)
159 .ok_or_else(|| anyhow!("parsing failed"))?;
160
161 let mut captures: Vec<CodeContextMatch> = Vec::new();
162 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
163 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
164 for mat in self.cursor.matches(
165 &embedding_config.query,
166 tree.root_node(),
167 content.as_bytes(),
168 ) {
169 let mut start_col = 0;
170 let mut item_range: Option<Range<usize>> = None;
171 let mut name_range: Option<Range<usize>> = None;
172 let mut context_ranges: Vec<Range<usize>> = Vec::new();
173 collapse_ranges.clear();
174 keep_ranges.clear();
175 for capture in mat.captures {
176 if capture.index == embedding_config.item_capture_ix {
177 item_range = Some(capture.node.byte_range());
178 start_col = capture.node.start_position().column;
179 } else if Some(capture.index) == embedding_config.name_capture_ix {
180 name_range = Some(capture.node.byte_range());
181 } else if Some(capture.index) == embedding_config.context_capture_ix {
182 context_ranges.push(capture.node.byte_range());
183 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
184 collapse_ranges.push(capture.node.byte_range());
185 } else if Some(capture.index) == embedding_config.keep_capture_ix {
186 keep_ranges.push(capture.node.byte_range());
187 }
188 }
189
190 captures.push(CodeContextMatch {
191 start_col,
192 item_range,
193 name_range,
194 context_ranges,
195 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
196 });
197 }
198 Ok(captures)
199 }
200
201 pub fn parse_file_with_template(
202 &mut self,
203 relative_path: Option<&Path>,
204 content: &str,
205 language: Arc<Language>,
206 ) -> Result<Vec<Span>> {
207 let language_name = language.name();
208
209 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
210 return self.parse_entire_file(relative_path, language_name, &content);
211 } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
212 return self.parse_markdown_file(relative_path, &content);
213 }
214
215 let mut spans = self.parse_file(content, language)?;
216 for span in &mut spans {
217 let document_content = CODE_CONTEXT_TEMPLATE
218 .replace(
219 "<path>",
220 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
221 )
222 .replace("<language>", language_name.as_ref())
223 .replace("item", &span.content);
224
225 let (document_content, token_count) =
226 self.embedding_provider.truncate(&document_content);
227
228 span.content = document_content;
229 span.token_count = token_count;
230 }
231 Ok(spans)
232 }
233
234 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
235 let grammar = language
236 .grammar()
237 .ok_or_else(|| anyhow!("no grammar for language"))?;
238
239 // Iterate through query matches
240 let matches = self.get_matches_in_file(content, grammar)?;
241
242 let language_scope = language.default_scope();
243 let placeholder = language_scope.collapsed_placeholder();
244
245 let mut spans = Vec::new();
246 let mut collapsed_ranges_within = Vec::new();
247 let mut parsed_name_ranges = HashSet::new();
248 for (i, context_match) in matches.iter().enumerate() {
249 // Items which are collapsible but not embeddable have no item range
250 let item_range = if let Some(item_range) = context_match.item_range.clone() {
251 item_range
252 } else {
253 continue;
254 };
255
256 // Checks for deduplication
257 let name;
258 if let Some(name_range) = context_match.name_range.clone() {
259 name = content
260 .get(name_range.clone())
261 .map_or(String::new(), |s| s.to_string());
262 if parsed_name_ranges.contains(&name_range) {
263 continue;
264 }
265 parsed_name_ranges.insert(name_range);
266 } else {
267 name = String::new();
268 }
269
270 collapsed_ranges_within.clear();
271 'outer: for remaining_match in &matches[(i + 1)..] {
272 for collapsed_range in &remaining_match.collapse_ranges {
273 if item_range.start <= collapsed_range.start
274 && item_range.end >= collapsed_range.end
275 {
276 collapsed_ranges_within.push(collapsed_range.clone());
277 } else {
278 break 'outer;
279 }
280 }
281 }
282
283 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
284
285 let mut span_content = String::new();
286 for context_range in &context_match.context_ranges {
287 add_content_from_range(
288 &mut span_content,
289 content,
290 context_range.clone(),
291 context_match.start_col,
292 );
293 span_content.push_str("\n");
294 }
295
296 let mut offset = item_range.start;
297 for collapsed_range in &collapsed_ranges_within {
298 if collapsed_range.start > offset {
299 add_content_from_range(
300 &mut span_content,
301 content,
302 offset..collapsed_range.start,
303 context_match.start_col,
304 );
305 offset = collapsed_range.start;
306 }
307
308 if collapsed_range.end > offset {
309 span_content.push_str(placeholder);
310 offset = collapsed_range.end;
311 }
312 }
313
314 if offset < item_range.end {
315 add_content_from_range(
316 &mut span_content,
317 content,
318 offset..item_range.end,
319 context_match.start_col,
320 );
321 }
322
323 let sha1 = SpanDigest::from(span_content.as_str());
324 spans.push(Span {
325 name,
326 content: span_content,
327 range: item_range.clone(),
328 embedding: None,
329 digest: sha1,
330 token_count: 0,
331 })
332 }
333
334 return Ok(spans);
335 }
336}
337
/// Removes `ranges_to_subtract` from `ranges`, returning the surviving
/// sub-ranges in order.
///
/// Both inputs are expected to be sorted by start position; empty results
/// are omitted entirely rather than returned as zero-length ranges.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut kept = Vec::new();
    let mut holes = ranges_to_subtract.iter().peekable();

    for source in ranges {
        let mut cursor = source.start;

        while cursor < source.end {
            match holes.peek() {
                // No holes remain: the rest of this range survives intact.
                None => {
                    kept.push(cursor..source.end);
                    cursor = source.end;
                }
                Some(hole) => {
                    if cursor < hole.start {
                        // Emit the surviving stretch up to the hole (clipped
                        // to the end of the current range).
                        let stop = cmp::min(hole.start, source.end);
                        kept.push(cursor..stop);
                        cursor = stop;
                    } else {
                        // Inside the hole: skip forward without emitting.
                        cursor = cmp::min(hole.end, source.end);
                    }

                    // Once the cursor has passed a hole, it can never apply
                    // again; move on to the next one.
                    if cursor >= hole.end {
                        holes.next();
                    }
                }
            }
        }
    }

    kept
}
372
/// Appends the lines of `content[range]` to `output`, stripping up to
/// `start_col` leading spaces from each line (de-indenting the item to
/// column zero). No trailing newline is appended after the last line.
///
/// An out-of-range or non-char-boundary `range` contributes nothing.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut wrote_line = false;
    for mut line in content.get(range.clone()).unwrap_or("").lines() {
        // Strip at most `start_col` leading spaces; stop early at the first
        // non-space so deeper content keeps its relative indentation.
        for _ in 0..start_col {
            if let Some(rest) = line.strip_prefix(' ') {
                line = rest;
            } else {
                break;
            }
        }
        output.push_str(line);
        output.push('\n');
        wrote_line = true;
    }
    // Drop the final newline we added — but only if we actually added one.
    // Fixed: the previous unconditional pop() deleted a pre-existing
    // character of `output` whenever the range yielded no lines.
    if wrote_line {
        output.pop();
    }
}