1use std::{path::PathBuf, sync::Arc, time::SystemTime};
2
3use anyhow::{anyhow, Ok, Result};
4use project::Fs;
5use tree_sitter::{Parser, QueryCursor};
6
7use crate::PendingFile;
8
/// One embeddable item extracted from a parsed source file.
#[derive(Clone, Debug, PartialEq)]
pub struct Document {
    /// Byte offset in the file content at which the captured item starts.
    pub offset: usize,
    /// Space-separated text of the name (and optional context) captures for the item.
    pub name: String,
    /// Embedding vector for the item; `parse_file` leaves this empty
    /// (presumably filled in by a later embedding step — confirm against callers).
    pub embedding: Vec<f32>,
}
15
16#[derive(Debug, PartialEq, Clone)]
17pub struct ParsedFile {
18 pub path: PathBuf,
19 pub mtime: SystemTime,
20 pub documents: Vec<Document>,
21}
22
/// Template for the text span generated for each extracted item.
///
/// The `<path>`, `<language>` and `<item>` placeholders are substituted with
/// the file's relative path, the lower-cased language name and the item's
/// source text respectively.
const CODE_CONTEXT_TEMPLATE: &str = concat!(
    "The below code snippet is from file '<path>'\n\n",
    "```<language>\n<item>\n```",
);
25
26pub struct CodeContextRetriever {
27 pub parser: Parser,
28 pub cursor: QueryCursor,
29 pub fs: Arc<dyn Fs>,
30}
31
32impl CodeContextRetriever {
33 pub async fn parse_file(
34 &mut self,
35 pending_file: PendingFile,
36 ) -> Result<(ParsedFile, Vec<String>)> {
37 let grammar = pending_file
38 .language
39 .grammar()
40 .ok_or_else(|| anyhow!("no grammar for language"))?;
41 let embedding_config = grammar
42 .embedding_config
43 .as_ref()
44 .ok_or_else(|| anyhow!("no embedding queries"))?;
45
46 let content = self.fs.load(&pending_file.absolute_path).await?;
47
48 self.parser.set_language(grammar.ts_language).unwrap();
49
50 let tree = self
51 .parser
52 .parse(&content, None)
53 .ok_or_else(|| anyhow!("parsing failed"))?;
54
55 let mut documents = Vec::new();
56 let mut context_spans = Vec::new();
57
58 // Iterate through query matches
59 for mat in self.cursor.matches(
60 &embedding_config.query,
61 tree.root_node(),
62 content.as_bytes(),
63 ) {
64 // log::info!("-----MATCH-----");
65
66 let mut name = Vec::new();
67 let mut item: Option<&str> = None;
68 let mut offset: Option<usize> = None;
69 for capture in mat.captures {
70 if capture.index == embedding_config.item_capture_ix {
71 offset = Some(capture.node.byte_range().start);
72 item = content.get(capture.node.byte_range());
73 } else if capture.index == embedding_config.name_capture_ix {
74 if let Some(name_content) = content.get(capture.node.byte_range()) {
75 name.push(name_content);
76 }
77 }
78
79 if let Some(context_capture_ix) = embedding_config.context_capture_ix {
80 if capture.index == context_capture_ix {
81 if let Some(context) = content.get(capture.node.byte_range()) {
82 name.push(context);
83 }
84 }
85 }
86 }
87
88 if item.is_some() && offset.is_some() && name.len() > 0 {
89 let context_span = CODE_CONTEXT_TEMPLATE
90 .replace("<path>", pending_file.relative_path.to_str().unwrap())
91 .replace("<language>", &pending_file.language.name().to_lowercase())
92 .replace("<item>", item.unwrap());
93
94 // log::info!("Name: {:?}", name);
95 // log::info!("Span: {:?}", util::truncate(&context_span, 100));
96
97 context_spans.push(context_span);
98 documents.push(Document {
99 name: name.join(" "),
100 offset: offset.unwrap(),
101 embedding: Vec::new(),
102 })
103 }
104 }
105
106 return Ok((
107 ParsedFile {
108 path: pending_file.relative_path,
109 mtime: pending_file.modified_time,
110 documents,
111 },
112 context_spans,
113 ));
114 }
115}