add embedding query for json with nested arrays and strings

KCaverly and maxbrunsfeld created

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

Cargo.lock                                        |   1 
crates/language/src/language.rs                   |   4 
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/parsing.rs              | 123 +++++++++-------
crates/semantic_index/src/semantic_index.rs       |   2 
crates/semantic_index/src/semantic_index_tests.rs | 103 +++++++++++++
crates/zed/src/languages/json/embedding.scm       |  14 +
7 files changed, 189 insertions(+), 59 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6502,6 +6502,7 @@ dependencies = [
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "tree-sitter-json 0.19.0",
  "tree-sitter-rust",
  "tree-sitter-toml 0.20.0",
  "tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",

crates/language/src/language.rs 🔗

@@ -526,7 +526,7 @@ pub struct OutlineConfig {
 pub struct EmbeddingConfig {
     pub query: Query,
     pub item_capture_ix: u32,
-    pub name_capture_ix: u32,
+    pub name_capture_ix: Option<u32>,
     pub context_capture_ix: Option<u32>,
     pub collapse_capture_ix: Option<u32>,
     pub keep_capture_ix: Option<u32>,
@@ -1263,7 +1263,7 @@ impl Language {
                 ("collapse", &mut collapse_capture_ix),
             ],
         );
-        if let Some((item_capture_ix, name_capture_ix)) = item_capture_ix.zip(name_capture_ix) {
+        if let Some(item_capture_ix) = item_capture_ix {
             grammar.embedding_config = Some(EmbeddingConfig {
                 query,
                 item_capture_ix,

crates/semantic_index/Cargo.toml 🔗

@@ -54,6 +54,7 @@ ctor.workspace = true
 env_logger.workspace = true
 
 tree-sitter-typescript = "*"
+tree-sitter-json = "*"
 tree-sitter-rust = "*"
 tree-sitter-toml = "*"
 tree-sitter-cpp = "*"

crates/semantic_index/src/parsing.rs 🔗

@@ -1,6 +1,12 @@
 use anyhow::{anyhow, Ok, Result};
 use language::{Grammar, Language};
-use std::{cmp, collections::HashSet, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::{self, Reverse},
+    collections::HashSet,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+};
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Clone)]
@@ -15,7 +21,7 @@ const CODE_CONTEXT_TEMPLATE: &str =
     "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
 const ENTIRE_FILE_TEMPLATE: &str =
     "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: [&str; 4] = ["TOML", "YAML", "JSON", "CSS"];
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &["TOML", "YAML", "CSS"];
 
 pub struct CodeContextRetriever {
     pub parser: Parser,
@@ -30,8 +36,8 @@ pub struct CodeContextRetriever {
 #[derive(Debug, Clone)]
 pub struct CodeContextMatch {
     pub start_col: usize,
-    pub item_range: Range<usize>,
-    pub name_range: Range<usize>,
+    pub item_range: Option<Range<usize>>,
+    pub name_range: Option<Range<usize>>,
     pub context_ranges: Vec<Range<usize>>,
     pub collapse_ranges: Vec<Range<usize>>,
 }
@@ -44,7 +50,7 @@ impl CodeContextRetriever {
         }
     }
 
-    fn _parse_entire_file(
+    fn parse_entire_file(
         &self,
         relative_path: &Path,
         language_name: Arc<str>,
@@ -97,7 +103,7 @@ impl CodeContextRetriever {
                 if capture.index == embedding_config.item_capture_ix {
                     item_range = Some(capture.node.byte_range());
                     start_col = capture.node.start_position().column;
-                } else if capture.index == embedding_config.name_capture_ix {
+                } else if Some(capture.index) == embedding_config.name_capture_ix {
                     name_range = Some(capture.node.byte_range());
                 } else if Some(capture.index) == embedding_config.context_capture_ix {
                     context_ranges.push(capture.node.byte_range());
@@ -108,16 +114,13 @@ impl CodeContextRetriever {
                 }
             }
 
-            if item_range.is_some() && name_range.is_some() {
-                let item_range = item_range.unwrap();
-                captures.push(CodeContextMatch {
-                    start_col,
-                    item_range,
-                    name_range: name_range.unwrap(),
-                    context_ranges,
-                    collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
-                });
-            }
+            captures.push(CodeContextMatch {
+                start_col,
+                item_range,
+                name_range,
+                context_ranges,
+                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+            });
         }
         Ok(captures)
     }
@@ -129,7 +132,12 @@ impl CodeContextRetriever {
         language: Arc<Language>,
     ) -> Result<Vec<Document>> {
         let language_name = language.name();
-        let mut documents = self.parse_file(relative_path, content, language)?;
+
+        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+            return self.parse_entire_file(relative_path, language_name, &content);
+        }
+
+        let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
             document.content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
@@ -139,16 +147,7 @@ impl CodeContextRetriever {
         Ok(documents)
     }
 
-    pub fn parse_file(
-        &mut self,
-        relative_path: &Path,
-        content: &str,
-        language: Arc<Language>,
-    ) -> Result<Vec<Document>> {
-        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) {
-            return self._parse_entire_file(relative_path, language.name(), &content);
-        }
-
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
         let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -163,32 +162,49 @@ impl CodeContextRetriever {
         let mut collapsed_ranges_within = Vec::new();
         let mut parsed_name_ranges = HashSet::new();
         for (i, context_match) in matches.iter().enumerate() {
-            if parsed_name_ranges.contains(&context_match.name_range) {
+            // Items which are collapsible but not embeddable have no item range
+            let item_range = if let Some(item_range) = context_match.item_range.clone() {
+                item_range
+            } else {
                 continue;
+            };
+
+            // Checks for deduplication
+            let name;
+            if let Some(name_range) = context_match.name_range.clone() {
+                name = content
+                    .get(name_range.clone())
+                    .map_or(String::new(), |s| s.to_string());
+                if parsed_name_ranges.contains(&name_range) {
+                    continue;
+                }
+                parsed_name_ranges.insert(name_range);
+            } else {
+                name = String::new();
             }
 
             collapsed_ranges_within.clear();
-            for remaining_match in &matches[(i + 1)..] {
-                if context_match
-                    .item_range
-                    .contains(&remaining_match.item_range.start)
-                    && context_match
-                        .item_range
-                        .contains(&remaining_match.item_range.end)
-                {
-                    collapsed_ranges_within.extend(remaining_match.collapse_ranges.iter().cloned());
-                } else {
-                    break;
+            'outer: for remaining_match in &matches[(i + 1)..] {
+                for collapsed_range in &remaining_match.collapse_ranges {
+                    if item_range.start <= collapsed_range.start
+                        && item_range.end >= collapsed_range.end
+                    {
+                        collapsed_ranges_within.push(collapsed_range.clone());
+                    } else {
+                        break 'outer;
+                    }
                 }
             }
 
+            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
+
             let mut document_content = String::new();
             for context_range in &context_match.context_ranges {
                 document_content.push_str(&content[context_range.clone()]);
                 document_content.push_str("\n");
             }
 
-            let mut offset = context_match.item_range.start;
+            let mut offset = item_range.start;
             for collapsed_range in &collapsed_ranges_within {
                 if collapsed_range.start > offset {
                     add_content_from_range(
@@ -197,29 +213,30 @@ impl CodeContextRetriever {
                         offset..collapsed_range.start,
                         context_match.start_col,
                     );
+                    offset = collapsed_range.start;
+                }
+
+                if collapsed_range.end > offset {
+                    document_content.push_str(placeholder);
+                    offset = collapsed_range.end;
                 }
-                document_content.push_str(placeholder);
-                offset = collapsed_range.end;
             }
 
-            if offset < context_match.item_range.end {
+            if offset < item_range.end {
                 add_content_from_range(
                     &mut document_content,
                     content,
-                    offset..context_match.item_range.end,
+                    offset..item_range.end,
                     context_match.start_col,
                 );
             }
 
-            if let Some(name) = content.get(context_match.name_range.clone()) {
-                parsed_name_ranges.insert(context_match.name_range.clone());
-                documents.push(Document {
-                    name: name.to_string(),
-                    content: document_content,
-                    range: context_match.item_range.clone(),
-                    embedding: vec![],
-                })
-            }
+            documents.push(Document {
+                name,
+                content: document_content,
+                range: item_range.clone(),
+                embedding: vec![],
+            })
         }
 
         return Ok(documents);

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -170,9 +170,7 @@ async fn test_code_context_retrieval_rust() {
     "
     .unindent();
 
-    let documents = retriever
-        .parse_file(Path::new("foo.rs"), &text, language)
-        .unwrap();
+    let documents = retriever.parse_file(&text, language).unwrap();
 
     assert_documents_eq(
         &documents,
@@ -229,6 +227,76 @@ async fn test_code_context_retrieval_rust() {
     );
 }
 
+#[gpui::test]
+async fn test_code_context_retrieval_json() {
+    let language = json_lang();
+    let mut retriever = CodeContextRetriever::new();
+
+    let text = r#"
+        {
+            "array": [1, 2, 3, 4],
+            "string": "abcdefg",
+            "nested_object": {
+                "array_2": [5, 6, 7, 8],
+                "string_2": "hijklmnop",
+                "boolean": true,
+                "none": null
+            }
+        }
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+                {
+                    "array": [],
+                    "string": "",
+                    "nested_object": {
+                        "array_2": [],
+                        "string_2": "",
+                        "boolean": true,
+                        "none": null
+                    }
+                }"#
+            .unindent(),
+            text.find("{").unwrap(),
+        )],
+    );
+
+    let text = r#"
+        [
+            {
+                "name": "somebody",
+                "age": 42
+            },
+            {
+                "name": "somebody else",
+                "age": 43
+            }
+        ]
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            [{
+                    "name": "",
+                    "age": 42
+                }]"#
+            .unindent(),
+            text.find("[").unwrap(),
+        )],
+    );
+}
+
 fn assert_documents_eq(
     documents: &[Document],
     expected_contents_and_start_offsets: &[(String, usize)],
@@ -913,6 +981,35 @@ fn rust_lang() -> Arc<Language> {
     )
 }
 
+fn json_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "JSON".into(),
+                path_suffixes: vec!["json".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_json::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (document) @item
+
+            (array
+                "[" @keep
+                .
+                (object)? @keep
+                "]" @keep) @collapse
+
+            (pair value: (string
+                "\"" @keep
+                "\"" @keep) @collapse)
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
 fn toml_lang() -> Arc<Language> {
     Arc::new(Language::new(
         LanguageConfig {

crates/zed/src/languages/json/embedding.scm 🔗

@@ -0,0 +1,14 @@
+; Only produce one embedding for the entire file.
+(document) @item
+
+; Collapse arrays, except for the first object.
+(array
+  "[" @keep
+  .
+  (object)? @keep
+  "]" @keep) @collapse
+
+; Collapse string values (but not keys).
+(pair value: (string
+  "\"" @keep
+  "\"" @keep) @collapse)