add tests for rust context parsing, and update rust embedding query

KCaverly and maxbrunsfeld created

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

crates/vector_store/src/parsing.rs            |   6 
crates/vector_store/src/vector_store_tests.rs | 156 +++++++++++++++++---
crates/zed/src/languages/rust/embedding.scm   |  64 ++++++--
3 files changed, 179 insertions(+), 47 deletions(-)

Detailed changes

crates/vector_store/src/parsing.rs 🔗

@@ -81,7 +81,11 @@ impl CodeContextRetriever {
 
             if let Some((item, byte_range)) = item.zip(byte_range) {
                 if !name.is_empty() {
-                    let item = format!("{}\n{}", context_spans.join("\n"), item);
+                    let item = if context_spans.is_empty() {
+                        item.to_string()
+                    } else {
+                        format!("{}\n{}", context_spans.join("\n"), item)
+                    };
 
                     let document_text = CODE_CONTEXT_TEMPLATE
                         .replace("<path>", relative_path.to_str().unwrap())

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -1,5 +1,9 @@
 use crate::{
-    db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
+    db::dot,
+    embedding::EmbeddingProvider,
+    parsing::{CodeContextRetriever, Document},
+    vector_store_settings::VectorStoreSettings,
+    VectorStore,
 };
 use anyhow::Result;
 use async_trait::async_trait;
@@ -9,7 +13,7 @@ use project::{project_settings::ProjectSettings, FakeFs, Project};
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::sync::Arc;
+use std::{path::Path, sync::Arc};
 use unindent::Unindent;
 
 #[ctor::ctor]
@@ -52,24 +56,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     .await;
 
     let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
-    let rust_language = Arc::new(
-        Language::new(
-            LanguageConfig {
-                name: "Rust".into(),
-                path_suffixes: vec!["rs".into()],
-                ..Default::default()
-            },
-            Some(tree_sitter_rust::language()),
-        )
-        .with_embedding_query(
-            r#"
-            (function_item
-                name: (identifier) @name
-                body: (block)) @item
-            "#,
-        )
-        .unwrap(),
-    );
+    let rust_language = rust_lang();
     languages.add(rust_language);
 
     let db_dir = tempdir::TempDir::new("vector-store").unwrap();
@@ -109,14 +96,59 @@ async fn test_vector_store(cx: &mut TestAppContext) {
 
 #[gpui::test]
 async fn test_code_context_retrieval(cx: &mut TestAppContext) {
-    // let mut retriever = CodeContextRetriever::new(fs);
-
-    // retriever::parse_file(
-    //     "
-    //     //
-    // ",
-    // );
-    //
+    let language = rust_lang();
+    let mut retriever = CodeContextRetriever::new();
+
+    let text = "
+        /// A doc comment
+        /// that spans multiple lines
+        fn a() {
+            b
+        }
+
+        impl C for D {
+        }
+    "
+    .unindent();
+
+    let parsed_files = retriever
+        .parse_file(Path::new("foo.rs"), &text, language)
+        .unwrap();
+
+    assert_eq!(
+        parsed_files,
+        &[
+            Document {
+                name: "a".into(),
+                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
+                content: "
+                    The below code snippet is from file 'foo.rs'
+
+                    ```rust
+                    /// A doc comment
+                    /// that spans multiple lines
+                    fn a() {
+                        b
+                    }
+                    ```"
+                .unindent(),
+                embedding: vec![],
+            },
+            Document {
+                name: "C for D".into(),
+                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
+                content: "
+                    The below code snippet is from file 'foo.rs'
+
+                    ```rust
+                    impl C for D {
+                    }
+                    ```"
+                .unindent(),
+                embedding: vec![],
+            }
+        ]
+    );
 }
 
 #[gpui::test]
@@ -178,3 +210,71 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
             .collect())
     }
 }
+
+fn rust_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (line_comment)* @context
+                .
+                (enum_item
+                    name: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (struct_item
+                    name: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (impl_item
+                    trait: (_)? @name
+                    "for"? @name
+                    type: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (trait_item
+                    name: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (function_item
+                    name: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (macro_definition
+                    name: (_) @name) @item
+            )
+
+            (
+                (line_comment)* @context
+                .
+                (function_signature_item
+                    name: (_) @name) @item
+            )
+            "#,
+        )
+        .unwrap(),
+    )
+}

crates/zed/src/languages/rust/embedding.scm 🔗

@@ -1,22 +1,50 @@
 (
     (line_comment)* @context
     .
-    [
-        (enum_item
-            name: (_) @name) @item
-        (struct_item
-            name: (_) @name) @item
-        (impl_item
-            trait: (_)? @name
-            "for"? @name
-            type: (_) @name) @item
-        (trait_item
-            name: (_) @name) @item
-        (function_item
-            name: (_) @name) @item
-        (macro_definition
-            name: (_) @name) @item
-        (function_signature_item
-            name: (_) @name) @item
-    ]
+    (enum_item
+        name: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (struct_item
+        name: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (impl_item
+        trait: (_)? @name
+        "for"? @name
+        type: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (trait_item
+        name: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (function_item
+        name: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (macro_definition
+        name: (_) @name) @item
+)
+
+(
+    (line_comment)* @context
+    .
+    (function_signature_item
+        name: (_) @name) @item
 )