@@ -20,6 +20,9 @@ pub struct ParsedFile {
pub documents: Vec<Document>,
}
+const CODE_CONTEXT_TEMPLATE: &str = // Context-span template: the <path>, <language> and <item> placeholders are filled in with `str::replace` for each matched item.
+ "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
@@ -58,28 +61,41 @@ impl CodeContextRetriever {
tree.root_node(),
content.as_bytes(),
) {
- let mut item_range: Option<Range<usize>> = None;
- let mut name_range: Option<Range<usize>> = None;
+ let mut name: Vec<&str> = vec![];
+ let mut item: Option<&str> = None;
+ let mut offset: Option<usize> = None;
for capture in mat.captures {
if capture.index == embedding_config.item_capture_ix {
- item_range = Some(capture.node.byte_range());
+ offset = Some(capture.node.byte_range().start);
+ item = content.get(capture.node.byte_range());
} else if capture.index == embedding_config.name_capture_ix {
- name_range = Some(capture.node.byte_range());
+ if let Some(name_content) = content.get(capture.node.byte_range()) {
+ name.push(name_content);
+ }
}
- }
- if let Some((item_range, name_range)) = item_range.zip(name_range) {
- if let Some((item, name)) =
- content.get(item_range.clone()).zip(content.get(name_range))
- {
- context_spans.push(item.to_string());
- documents.push(Document {
- name: name.to_string(),
- offset: item_range.start,
- embedding: Vec::new(),
- });
+ if let Some(context_capture_ix) = embedding_config.context_capture_ix {
+ if capture.index == context_capture_ix {
+ if let Some(context) = content.get(capture.node.byte_range()) {
+ name.push(context);
+ }
+ }
}
}
+
+ if item.is_some() && offset.is_some() && name.len() > 0 {
+ let context_span = CODE_CONTEXT_TEMPLATE
+ .replace("<path>", pending_file.relative_path.to_str().unwrap())
+ .replace("<language>", &pending_file.language.name().to_lowercase())
+ .replace("<item>", item.unwrap());
+
+ context_spans.push(context_span);
+ documents.push(Document {
+ name: name.join(" "),
+ offset: offset.unwrap(),
+ embedding: Vec::new(),
+ })
+ }
}
return Ok((