1use ai::{
2 embedding::{Embedding, EmbeddingProvider},
3 models::TruncationDirection,
4};
5use anyhow::{anyhow, Result};
6use language::{Grammar, Language};
7use rusqlite::{
8 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
9 ToSql,
10};
11use sha1::{Digest, Sha1};
12use std::{
13 borrow::Cow,
14 cmp::{self, Reverse},
15 collections::HashSet,
16 ops::Range,
17 path::Path,
18 sync::Arc,
19};
20use tree_sitter::{Parser, QueryCursor};
21
/// A SHA-1 digest (20 raw bytes) of a span's rendered content, used as a
/// stable key for deduplication and database storage.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest(pub [u8; 20]);
24
25impl FromSql for SpanDigest {
26 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
27 let blob = value.as_blob()?;
28 let bytes =
29 blob.try_into()
30 .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
31 expected_size: 20,
32 blob_size: blob.len(),
33 })?;
34 return Ok(SpanDigest(bytes));
35 }
36}
37
impl ToSql for SpanDigest {
    // Stores the digest as a raw 20-byte BLOB by delegating to the
    // byte-array `ToSql` implementation.
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        self.0.to_sql()
    }
}
43
44impl From<&'_ str> for SpanDigest {
45 fn from(value: &'_ str) -> Self {
46 let mut sha1 = Sha1::new();
47 sha1.update(value);
48 Self(sha1.finalize().into())
49 }
50}
51
/// A single embeddable unit of a file: either one tree-sitter item (function,
/// type, etc.) or an entire file, rendered through a prompt template.
#[derive(Debug, PartialEq, Clone)]
pub struct Span {
    /// Display name taken from the item's `name` capture (may be empty).
    pub name: String,
    /// Byte range of the item within the original file content.
    pub range: Range<usize>,
    /// Templated (and possibly truncated) text that gets embedded.
    pub content: String,
    /// Populated later by the embedding pipeline; `None` until computed.
    pub embedding: Option<Embedding>,
    /// SHA-1 of `content`, used for deduplication/caching.
    pub digest: SpanDigest,
    /// Token count of `content` under the provider's base model.
    pub token_count: usize,
}
61
// Prompt wrapper for an individual code span extracted via tree-sitter.
const CODE_CONTEXT_TEMPLATE: &str =
    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Prompt wrapper used when an entire file is embedded as a single span.
const ENTIRE_FILE_TEMPLATE: &str =
    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
// Markdown/plain text is embedded verbatim, without a fenced code block.
const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
// Languages whose files are always embedded whole rather than split into spans.
pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
];
70
/// Extracts embeddable [`Span`]s from source files by running each language's
/// tree-sitter embedding query and rendering matches through prompt templates.
pub struct CodeContextRetriever {
    // Reused across files; `set_language` is called per grammar.
    pub parser: Parser,
    // Reused query cursor for running embedding queries.
    pub cursor: QueryCursor,
    pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
76
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
// If there are preceding comments, we track this with a context capture
// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
    /// Start column of the item node; used to de-indent extracted lines.
    pub start_col: usize,
    /// Byte range of the item capture; `None` for collapse-only matches.
    pub item_range: Option<Range<usize>>,
    /// Byte range of the name capture, used for deduplication.
    pub name_range: Option<Range<usize>>,
    /// Byte ranges of preceding context (e.g. doc comments).
    pub context_ranges: Vec<Range<usize>>,
    /// Collapse ranges with any keep ranges already subtracted out.
    pub collapse_ranges: Vec<Range<usize>>,
}
90
91impl CodeContextRetriever {
92 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
93 Self {
94 parser: Parser::new(),
95 cursor: QueryCursor::new(),
96 embedding_provider,
97 }
98 }
99
100 fn parse_entire_file(
101 &self,
102 relative_path: Option<&Path>,
103 language_name: Arc<str>,
104 content: &str,
105 ) -> Result<Vec<Span>> {
106 let document_span = ENTIRE_FILE_TEMPLATE
107 .replace(
108 "<path>",
109 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
110 )
111 .replace("<language>", language_name.as_ref())
112 .replace("<item>", &content);
113 let digest = SpanDigest::from(document_span.as_str());
114 let model = self.embedding_provider.base_model();
115 let document_span = model.truncate(
116 &document_span,
117 model.capacity()?,
118 ai::models::TruncationDirection::End,
119 )?;
120 let token_count = model.count_tokens(&document_span)?;
121
122 Ok(vec![Span {
123 range: 0..content.len(),
124 content: document_span,
125 embedding: Default::default(),
126 name: language_name.to_string(),
127 digest,
128 token_count,
129 }])
130 }
131
132 fn parse_markdown_file(
133 &self,
134 relative_path: Option<&Path>,
135 content: &str,
136 ) -> Result<Vec<Span>> {
137 let document_span = MARKDOWN_CONTEXT_TEMPLATE
138 .replace(
139 "<path>",
140 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
141 )
142 .replace("<item>", &content);
143 let digest = SpanDigest::from(document_span.as_str());
144
145 let model = self.embedding_provider.base_model();
146 let document_span = model.truncate(
147 &document_span,
148 model.capacity()?,
149 ai::models::TruncationDirection::End,
150 )?;
151 let token_count = model.count_tokens(&document_span)?;
152
153 Ok(vec![Span {
154 range: 0..content.len(),
155 content: document_span,
156 embedding: None,
157 name: "Markdown".to_string(),
158 digest,
159 token_count,
160 }])
161 }
162
163 fn get_matches_in_file(
164 &mut self,
165 content: &str,
166 grammar: &Arc<Grammar>,
167 ) -> Result<Vec<CodeContextMatch>> {
168 let embedding_config = grammar
169 .embedding_config
170 .as_ref()
171 .ok_or_else(|| anyhow!("no embedding queries"))?;
172 self.parser.set_language(&grammar.ts_language).unwrap();
173
174 let tree = self
175 .parser
176 .parse(&content, None)
177 .ok_or_else(|| anyhow!("parsing failed"))?;
178
179 let mut captures: Vec<CodeContextMatch> = Vec::new();
180 let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
181 let mut keep_ranges: Vec<Range<usize>> = Vec::new();
182 for mat in self.cursor.matches(
183 &embedding_config.query,
184 tree.root_node(),
185 content.as_bytes(),
186 ) {
187 let mut start_col = 0;
188 let mut item_range: Option<Range<usize>> = None;
189 let mut name_range: Option<Range<usize>> = None;
190 let mut context_ranges: Vec<Range<usize>> = Vec::new();
191 collapse_ranges.clear();
192 keep_ranges.clear();
193 for capture in mat.captures {
194 if capture.index == embedding_config.item_capture_ix {
195 item_range = Some(capture.node.byte_range());
196 start_col = capture.node.start_position().column;
197 } else if Some(capture.index) == embedding_config.name_capture_ix {
198 name_range = Some(capture.node.byte_range());
199 } else if Some(capture.index) == embedding_config.context_capture_ix {
200 context_ranges.push(capture.node.byte_range());
201 } else if Some(capture.index) == embedding_config.collapse_capture_ix {
202 collapse_ranges.push(capture.node.byte_range());
203 } else if Some(capture.index) == embedding_config.keep_capture_ix {
204 keep_ranges.push(capture.node.byte_range());
205 }
206 }
207
208 captures.push(CodeContextMatch {
209 start_col,
210 item_range,
211 name_range,
212 context_ranges,
213 collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
214 });
215 }
216 Ok(captures)
217 }
218
219 pub fn parse_file_with_template(
220 &mut self,
221 relative_path: Option<&Path>,
222 content: &str,
223 language: Arc<Language>,
224 ) -> Result<Vec<Span>> {
225 let language_name = language.name();
226
227 if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
228 return self.parse_entire_file(relative_path, language_name, &content);
229 } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
230 return self.parse_markdown_file(relative_path, &content);
231 }
232
233 let mut spans = self.parse_file(content, language)?;
234 for span in &mut spans {
235 let document_content = CODE_CONTEXT_TEMPLATE
236 .replace(
237 "<path>",
238 &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
239 )
240 .replace("<language>", language_name.as_ref())
241 .replace("item", &span.content);
242
243 let model = self.embedding_provider.base_model();
244 let document_content = model.truncate(
245 &document_content,
246 model.capacity()?,
247 TruncationDirection::End,
248 )?;
249 let token_count = model.count_tokens(&document_content)?;
250
251 span.content = document_content;
252 span.token_count = token_count;
253 }
254 Ok(spans)
255 }
256
257 pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
258 let grammar = language
259 .grammar()
260 .ok_or_else(|| anyhow!("no grammar for language"))?;
261
262 // Iterate through query matches
263 let matches = self.get_matches_in_file(content, grammar)?;
264
265 let language_scope = language.default_scope();
266 let placeholder = language_scope.collapsed_placeholder();
267
268 let mut spans = Vec::new();
269 let mut collapsed_ranges_within = Vec::new();
270 let mut parsed_name_ranges = HashSet::new();
271 for (i, context_match) in matches.iter().enumerate() {
272 // Items which are collapsible but not embeddable have no item range
273 let item_range = if let Some(item_range) = context_match.item_range.clone() {
274 item_range
275 } else {
276 continue;
277 };
278
279 // Checks for deduplication
280 let name;
281 if let Some(name_range) = context_match.name_range.clone() {
282 name = content
283 .get(name_range.clone())
284 .map_or(String::new(), |s| s.to_string());
285 if parsed_name_ranges.contains(&name_range) {
286 continue;
287 }
288 parsed_name_ranges.insert(name_range);
289 } else {
290 name = String::new();
291 }
292
293 collapsed_ranges_within.clear();
294 'outer: for remaining_match in &matches[(i + 1)..] {
295 for collapsed_range in &remaining_match.collapse_ranges {
296 if item_range.start <= collapsed_range.start
297 && item_range.end >= collapsed_range.end
298 {
299 collapsed_ranges_within.push(collapsed_range.clone());
300 } else {
301 break 'outer;
302 }
303 }
304 }
305
306 collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
307
308 let mut span_content = String::new();
309 for context_range in &context_match.context_ranges {
310 add_content_from_range(
311 &mut span_content,
312 content,
313 context_range.clone(),
314 context_match.start_col,
315 );
316 span_content.push_str("\n");
317 }
318
319 let mut offset = item_range.start;
320 for collapsed_range in &collapsed_ranges_within {
321 if collapsed_range.start > offset {
322 add_content_from_range(
323 &mut span_content,
324 content,
325 offset..collapsed_range.start,
326 context_match.start_col,
327 );
328 offset = collapsed_range.start;
329 }
330
331 if collapsed_range.end > offset {
332 span_content.push_str(placeholder);
333 offset = collapsed_range.end;
334 }
335 }
336
337 if offset < item_range.end {
338 add_content_from_range(
339 &mut span_content,
340 content,
341 offset..item_range.end,
342 context_match.start_col,
343 );
344 }
345
346 let sha1 = SpanDigest::from(span_content.as_str());
347 spans.push(Span {
348 name,
349 content: span_content,
350 range: item_range.clone(),
351 embedding: None,
352 digest: sha1,
353 token_count: 0,
354 })
355 }
356
357 return Ok(spans);
358 }
359}
360
/// Returns the portions of `ranges` not covered by `ranges_to_subtract`.
///
/// Both inputs are expected in ascending order; subtract ranges are consumed
/// once their end has been passed, even across successive input ranges.
pub(crate) fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut kept = Vec::new();
    let mut pending = ranges_to_subtract.iter().peekable();

    for source in ranges {
        let mut cursor = source.start;
        while cursor < source.end {
            match pending.peek() {
                None => {
                    // No holes remain; keep the rest of this range wholesale.
                    kept.push(cursor..source.end);
                    cursor = source.end;
                }
                Some(hole) => {
                    let (hole_start, hole_end) = (hole.start, hole.end);
                    if cursor < hole_start {
                        // Keep the gap before the next hole (clipped to source).
                        let stop = cmp::min(hole_start, source.end);
                        kept.push(cursor..stop);
                        cursor = stop;
                    } else {
                        // Inside a hole: skip forward without emitting.
                        cursor = cmp::min(hole_end, source.end);
                    }
                    // Once we've passed a hole's end it can never apply again.
                    if cursor >= hole_end {
                        pending.next();
                    }
                }
            }
        }
    }

    kept
}
395
/// Appends the lines of `content[range]` to `output`, stripping up to
/// `start_col` leading spaces from each line so nested items align at the
/// item's starting column. No trailing newline is appended.
///
/// An out-of-bounds or non-char-boundary `range` is treated as empty.
fn add_content_from_range(
    output: &mut String,
    content: &str,
    range: Range<usize>,
    start_col: usize,
) {
    let mut wrote_line = false;
    for mut line in content.get(range).unwrap_or("").lines() {
        // De-indent by at most `start_col` spaces; stop early at the first
        // non-space character so deeper-indented text keeps its relative shape.
        for _ in 0..start_col {
            match line.strip_prefix(' ') {
                Some(rest) => line = rest,
                None => break,
            }
        }
        output.push_str(line);
        output.push('\n');
        wrote_line = true;
    }
    // BUG FIX: only remove the newline we appended. The original popped
    // unconditionally, which chomped a pre-existing character from `output`
    // when the range produced no lines.
    if wrote_line {
        output.pop();
    }
}