@@ -526,7 +526,7 @@ pub struct OutlineConfig {
pub struct EmbeddingConfig {
pub query: Query,
pub item_capture_ix: u32,
- pub name_capture_ix: u32,
+ pub name_capture_ix: Option<u32>,
pub context_capture_ix: Option<u32>,
pub collapse_capture_ix: Option<u32>,
pub keep_capture_ix: Option<u32>,
@@ -1263,7 +1263,7 @@ impl Language {
("collapse", &mut collapse_capture_ix),
],
);
- if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
+ if let Some(item_capture_ix) = item_capture_ix {
grammar.embedding_config = Some(EmbeddingConfig {
query,
item_capture_ix,
@@ -1,6 +1,12 @@
use anyhow::{anyhow, Ok, Result};
use language::{Grammar, Language};
-use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
+use std::{
+ cmp::{self, Reverse},
+ collections::HashSet,
+ ops::Range,
+ path::Path,
+ sync::Arc,
+};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Clone)]
@@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
"The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
const ENTIRE_FILE_TEMPLATE: &str =
"The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"]; // language names whose files are returned via parse_entire_file rather than parse_file
pub struct CodeContextRetriever {
pub parser: Parser,
@@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
#[derive(Debug, Clone)]
pub struct CodeContextMatch {
pub start_col: usize,
- pub item_range: Range<usize>,
- pub name_range: Range<usize>,
+ pub item_range: Option<Range<usize>>, // byte range of the item capture; None when the match captured no item
+ pub name_range: Option<Range<usize>>, // byte range of the name capture; None when the match captured no name
pub context_ranges: Vec<Range<usize>>,
pub collapse_ranges: Vec<Range<usize>>,
}
@@ -44,7 +50,7 @@ impl CodeContextRetriever {
}
}
- fn _parse_entire_file(
+ fn parse_entire_file(
&self,
relative_path: &Path,
language_name: Arc<str>,
@@ -97,7 +103,7 @@ impl CodeContextRetriever {
if capture.index == embedding_config.item_capture_ix {
item_range = Some(capture.node.byte_range());
start_col = capture.node.start_position().column;
- } else if capture.index == embedding_config.name_capture_ix {
+ } else if Some(capture.index) == embedding_config.name_capture_ix {
name_range = Some(capture.node.byte_range());
} else if Some(capture.index) == embedding_config.context_capture_ix {
context_ranges.push(capture.node.byte_range());
@@ -108,16 +114,13 @@ impl CodeContextRetriever {
}
}
- if item_range.is_some() && name_range.is_some() {
- let item_range = item_range.unwrap();
- captures.push(CodeContextMatch {
- start_col,
- item_range,
- name_range: name_range.unwrap(),
- context_ranges,
- collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
- });
- }
+ captures.push(CodeContextMatch {
+ start_col,
+ item_range,
+ name_range,
+ context_ranges,
+ collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+ });
}
Ok(captures)
}
@@ -129,7 +132,12 @@ impl CodeContextRetriever {
language: Arc<Language>,
) -> Result<Vec<Document>> {
let language_name = language.name();
- let mut documents = self.parse_file(relative_path, content, language)?;
+
+ if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+ return self.parse_entire_file(relative_path, language_name, &content);
+ }
+
+ let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
document.content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
@@ -139,16 +147,7 @@ impl CodeContextRetriever {
Ok(documents)
}
- pub fn parse_file(
- &mut self,
- relative_path: &Path,
- content: &str,
- language: Arc<Language>,
- ) -> Result<Vec<Document>> {
- if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
- return self._parse_entire_file(relative_path, language.name(), &content);
- }
-
+ pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -163,32 +162,49 @@ impl CodeContextRetriever {
let mut collapsed_ranges_within = Vec::new();
let mut parsed_name_ranges = HashSet::new();
for (i, context_match) in matches.iter().enumerate() {
- if parsed_name_ranges.contains(&context_match.name_range) {
+ // Items which are collapsible but not embeddable have no item range
+ let item_range = if let Some(item_range) = context_match.item_range.clone() {
+ item_range
+ } else {
continue;
+ };
+
+ // Checks for deduplication
+ let name;
+ if let Some(name_range) = context_match.name_range.clone() {
+ name = content
+ .get(name_range.clone())
+ .map_or(String::new(), |s| s.to_string());
+ if parsed_name_ranges.contains(&name_range) {
+ continue;
+ }
+ parsed_name_ranges.insert(name_range);
+ } else {
+ name = String::new();
}
collapsed_ranges_within.clear();
- for remaining_match in &matches[(i + 1)..] {
- if context_match
- .item_range
- .contains(&remaining_match.item_range.start)
- && context_match
- .item_range
- .contains(&remaining_match.item_range.end)
- {
- collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
- } else {
- break;
+ 'outer: for remaining_match in &matches[(i + 1)..] {
+ for collapsed_range in &remaining_match.collapse_ranges {
+ if item_range.start <= collapsed_range.start
+ && item_range.end >= collapsed_range.end
+ {
+ collapsed_ranges_within.push(collapsed_range.clone());
+ } else {
+ break 'outer;
+ }
}
}
+ collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
+
let mut document_content = String::new();
for context_range in &context_match.context_ranges {
document_content.push_str(&content[context_range.clone()]);
document_content.push_str("\n");
}
- let mut offset = context_match.item_range.start;
+ let mut offset = item_range.start;
for collapsed_range in &collapsed_ranges_within {
if collapsed_range.start > offset {
add_content_from_range(
@@ -197,29 +213,30 @@ impl CodeContextRetriever {
offset..collapsed_range.start,
context_match.start_col,
);
+ offset = collapsed_range.start;
+ }
+
+ if collapsed_range.end > offset {
+ document_content.push_str(placeholder);
+ offset = collapsed_range.end;
}
- document_content.push_str(placeholder);
- offset = collapsed_range.end;
}
- if offset < context_match.item_range.end {
+ if offset < item_range.end {
add_content_from_range(
&mut document_content,
content,
- offset..context_match.item_range.end,
+ offset..item_range.end,
context_match.start_col,
);
}
- if let Some(name) = content.get(context_match.name_range.clone()) {
- parsed_name_ranges.insert(context_match.name_range.clone());
- documents.push(Document {
- name: name.to_string(),
- content: document_content,
- range: context_match.item_range.clone(),
- embedding: vec![],
- })
- }
+ documents.push(Document {
+ name,
+ content: document_content,
+ range: item_range.clone(),
+ embedding: vec![],
+ })
}
return Ok(documents);
@@ -170,9 +170,7 @@ async fn test_code_context_retrieval_rust() {
"
.unindent();
- let documents = retriever
- .parse_file(Path::new("foo.rs"), &text, language)
- .unwrap();
+ let documents = retriever.parse_file(&text, language).unwrap();
assert_documents_eq(
&documents,
@@ -229,6 +227,76 @@ async fn test_code_context_retrieval_rust() {
);
}
+#[gpui::test] // Checks parse_file output for JSON input, using the embedding query defined in json_lang().
+async fn test_code_context_retrieval_json() {
+ let language = json_lang();
+ let mut retriever = CodeContextRetriever::new();
+
+ let text = r#"
+ {
+ "array": [1, 2, 3, 4],
+ "string": "abcdefg",
+ "nested_object": {
+ "array_2": [5, 6, 7, 8],
+ "string_2": "hijklmnop",
+ "boolean": true,
+ "none": null
+ }
+ }
+ "#
+ .unindent();
+
+ let documents = retriever.parse_file(&text, language.clone()).unwrap(); // case 1: top-level object
+
+ assert_documents_eq( // expects one document: arrays and string values emptied, booleans/null kept
+ &documents,
+ &[(
+ r#"
+ {
+ "array": [],
+ "string": "",
+ "nested_object": {
+ "array_2": [],
+ "string_2": "",
+ "boolean": true,
+ "none": null
+ }
+ }"#
+ .unindent(),
+ text.find("{").unwrap(), // expected start offset: first "{" in the input
+ )],
+ );
+
+ let text = r#"
+ [
+ {
+ "name": "somebody",
+ "age": 42
+ },
+ {
+ "name": "somebody else",
+ "age": 43
+ }
+ ]
+ "#
+ .unindent();
+
+ let documents = retriever.parse_file(&text, language.clone()).unwrap(); // case 2: top-level array of objects
+
+ assert_documents_eq( // expects one document keeping only the array's first object, with its string value emptied
+ &documents,
+ &[(
+ r#"
+ [{
+ "name": "",
+ "age": 42
+ }]"#
+ .unindent(),
+ text.find("[").unwrap(), // expected start offset: first "[" in the input
+ )],
+ );
+}
+
fn assert_documents_eq(
documents: &[Document],
expected_contents_and_start_offsets: &[(String, usize)],
@@ -913,6 +981,35 @@ fn rust_lang() -> Arc<Language> {
)
}
+fn json_lang() -> Arc<Language> { // Test helper: builds a "JSON" Language with the tree-sitter JSON grammar and an embedding query.
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "JSON".into(),
+ path_suffixes: vec!["json".into()],
+ ..Default::default()
+ },
+ Some(tree_sitter_json::language()),
+ )
+ .with_embedding_query( // query below: @item is the whole document; arrays and string pair values are @collapse, with @keep marking the parts that survive
+ r#"
+ (document) @item
+
+ (array
+ "[" @keep
+ .
+ (object)? @keep
+ "]" @keep) @collapse
+
+ (pair value: (string
+ "\"" @keep
+ "\"" @keep) @collapse)
+ "#,
+ )
+ .unwrap(), // panics if with_embedding_query fails
+ )
+}
+
fn toml_lang() -> Arc<Language> {
Arc::new(Language::new(
LanguageConfig {