Add embedding Tree-sitter query for C++

KCaverly created

Change summary

Cargo.lock                                    |   1 
crates/vector_store/Cargo.toml                |   1 
crates/vector_store/src/vector_store_tests.rs | 312 +++++++++++++++++++-
crates/zed/src/languages/cpp/embedding.scm    |  61 ++++
4 files changed, 347 insertions(+), 28 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8518,6 +8518,7 @@ dependencies = [
  "theme",
  "tiktoken-rs 0.5.0",
  "tree-sitter",
+ "tree-sitter-cpp",
  "tree-sitter-rust",
  "tree-sitter-toml 0.20.0",
  "tree-sitter-typescript 0.20.2 (registry+https://github.com/rust-lang/crates.io-index)",

crates/vector_store/Cargo.toml 🔗

@@ -54,3 +54,4 @@ env_logger.workspace = true
 tree-sitter-typescript = "*"
 tree-sitter-rust = "*"
 tree-sitter-toml = "*"
+tree-sitter-cpp = "*"

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -211,32 +211,33 @@ async fn test_code_context_retrieval_javascript() {
     let mut retriever = CodeContextRetriever::new();
 
     let text = "
-/* globals importScripts, backend */
-function _authorize() {}
-
-/**
- * Sometimes the frontend build is way faster than backend.
- */
-export async function authorizeBank() {
-    _authorize(pushModal, upgradingAccountId, {});
-}
+        /* globals importScripts, backend */
+        function _authorize() {}
+
+        /**
+         * Sometimes the frontend build is way faster than backend.
+         */
+        export async function authorizeBank() {
+            _authorize(pushModal, upgradingAccountId, {});
+        }
 
-export class SettingsPage {
-    /* This is a test setting */
-    constructor(page) {
-        this.page = page;
-    }
-}
+        export class SettingsPage {
+            /* This is a test setting */
+            constructor(page) {
+                this.page = page;
+            }
+        }
 
-/* This is a test comment */
-class TestClass {}
+        /* This is a test comment */
+        class TestClass {}
 
-/* Schema for editor_events in Clickhouse. */
-export interface ClickhouseEditorEvent {
-    installation_id: string
-    operation: string
-}
-";
+        /* Schema for editor_events in Clickhouse. */
+        export interface ClickhouseEditorEvent {
+            installation_id: string
+            operation: string
+        }
+        "
+    .unindent();
 
     let parsed_files = retriever
         .parse_file(Path::new("foo.js"), &text, language)
@@ -258,7 +259,7 @@ export interface ClickhouseEditorEvent {
         },
         Document {
             name: "async function authorizeBank".into(),
-            range: text.find("export async").unwrap()..224,
+            range: text.find("export async").unwrap()..223,
             content: "
                     The below code snippet is from file 'foo.js'
 
@@ -275,7 +276,7 @@ export interface ClickhouseEditorEvent {
         },
         Document {
             name: "class SettingsPage".into(),
-            range: 226..344,
+            range: 225..343,
             content: "
                     The below code snippet is from file 'foo.js'
 
@@ -292,7 +293,7 @@ export interface ClickhouseEditorEvent {
         },
         Document {
             name: "constructor".into(),
-            range: 291..342,
+            range: 290..341,
             content: "
                 The below code snippet is from file 'foo.js'
 
@@ -307,7 +308,7 @@ export interface ClickhouseEditorEvent {
         },
         Document {
             name: "class TestClass".into(),
-            range: 375..393,
+            range: 374..392,
             content: "
                     The below code snippet is from file 'foo.js'
 
@@ -320,7 +321,7 @@ export interface ClickhouseEditorEvent {
         },
         Document {
             name: "interface ClickhouseEditorEvent".into(),
-            range: 441..533,
+            range: 440..532,
             content: "
                     The below code snippet is from file 'foo.js'
 
@@ -341,6 +342,181 @@ export interface ClickhouseEditorEvent {
     }
 }
 
+#[gpui::test]
+async fn test_code_context_retrieval_cpp() {
+    let language = cpp_lang();
+    let mut retriever = CodeContextRetriever::new();
+
+    let text = "
+    /**
+     * @brief Main function
+     * @returns 0 on exit
+     */
+    int main() { return 0; }
+
+    /**
+    * This is a test comment
+    */
+    class MyClass {       // The class
+        public:             // Access specifier
+        int myNum;        // Attribute (int variable)
+        string myString;  // Attribute (string variable)
+    };
+
+    // This is a test comment
+    enum Color { red, green, blue };
+
+    /** This is a preceeding block comment
+     * This is the second line
+     */
+    struct {           // Structure declaration
+        int myNum;       // Member (int variable)
+        string myString; // Member (string variable)
+    } myStructure;
+
+    /**
+    * @brief Matrix class.
+    */
+    template <typename T,
+              typename = typename std::enable_if<
+                std::is_integral<T>::value || std::is_floating_point<T>::value,
+                bool>::type>
+    class Matrix2 {
+        std::vector<std::vector<T>> _mat;
+
+    public:
+        /**
+        * @brief Constructor
+        * @tparam Integer ensuring integers are being evaluated and not other
+        * data types.
+        * @param size denoting the size of Matrix as size x size
+        */
+        template <typename Integer,
+                  typename = typename std::enable_if<std::is_integral<Integer>::value,
+                  Integer>::type>
+        explicit Matrix(const Integer size) {
+            for (size_t i = 0; i < size; ++i) {
+                _mat.emplace_back(std::vector<T>(size, 0));
+            }
+        }
+    }"
+    .unindent();
+
+    let parsed_files = retriever
+        .parse_file(Path::new("foo.cpp"), &text, language)
+        .unwrap();
+
+    let test_documents = &[
+        Document {
+            name: "int main".into(),
+            range: 54..78,
+            content: "
+                The below code snippet is from file 'foo.cpp'
+
+                ```cpp
+                /**
+                 * @brief Main function
+                 * @returns 0 on exit
+                 */
+                int main() { return 0; }
+                ```"
+            .unindent(),
+            embedding: vec![],
+        },
+        Document {
+            name: "class MyClass".into(),
+            range: 112..295,
+            content: "
+                The below code snippet is from file 'foo.cpp'
+
+                ```cpp
+                /**
+                * This is a test comment
+                */
+                class MyClass {       // The class
+                    public:             // Access specifier
+                    int myNum;        // Attribute (int variable)
+                    string myString;  // Attribute (string variable)
+                }
+                ```"
+            .unindent(),
+            embedding: vec![],
+        },
+        Document {
+            name: "enum Color".into(),
+            range: 324..355,
+            content: "
+                The below code snippet is from file 'foo.cpp'
+
+                ```cpp
+                // This is a test comment
+                enum Color { red, green, blue }
+                ```"
+            .unindent(),
+            embedding: vec![],
+        },
+        Document {
+            name: "struct myStructure".into(),
+            range: 428..581,
+            content: "
+                The below code snippet is from file 'foo.cpp'
+
+                ```cpp
+                /** This is a preceeding block comment
+                 * This is the second line
+                 */
+                struct {           // Structure declaration
+                    int myNum;       // Member (int variable)
+                    string myString; // Member (string variable)
+                } myStructure;
+                ```"
+            .unindent(),
+            embedding: vec![],
+        },
+        Document {
+            name: "class Matrix2".into(),
+            range: 613..1342,
+            content: "
+                The below code snippet is from file 'foo.cpp'
+
+                ```cpp
+                /**
+                * @brief Matrix class.
+                */
+                template <typename T,
+                          typename = typename std::enable_if<
+                            std::is_integral<T>::value || std::is_floating_point<T>::value,
+                            bool>::type>
+                class Matrix2 {
+                    std::vector<std::vector<T>> _mat;
+
+                public:
+                    /**
+                    * @brief Constructor
+                    * @tparam Integer ensuring integers are being evaluated and not other
+                    * data types.
+                    * @param size denoting the size of Matrix as size x size
+                    */
+                    template <typename Integer,
+                              typename = typename std::enable_if<std::is_integral<Integer>::value,
+                              Integer>::type>
+                    explicit Matrix(const Integer size) {
+                        for (size_t i = 0; i < size; ++i) {
+                            _mat.emplace_back(std::vector<T>(size, 0));
+                        }
+                    }
+                }
+                ```"
+            .unindent(),
+            embedding: vec![],
+        },
+    ];
+
+    for idx in 0..test_documents.len() {
+        assert_eq!(test_documents[idx], parsed_files[idx]);
+    }
+}
+
 #[gpui::test]
 fn test_dot_product(mut rng: StdRng) {
     assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
@@ -594,3 +770,83 @@ fn toml_lang() -> Arc<Language> {
         Some(tree_sitter_toml::language()),
     ))
 }
+
+fn cpp_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "CPP".into(),
+                path_suffixes: vec!["cpp".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_cpp::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                (function_definition
+                    (type_qualifier)? @name
+                    type: (_)? @name
+                    declarator: [
+                        (function_declarator
+                            declarator: (_) @name)
+                        (pointer_declarator
+                            "*" @name
+                            declarator: (function_declarator
+                            declarator: (_) @name))
+                        (pointer_declarator
+                            "*" @name
+                            declarator: (pointer_declarator
+                                "*" @name
+                            declarator: (function_declarator
+                                declarator: (_) @name)))
+                        (reference_declarator
+                            ["&" "&&"] @name
+                            (function_declarator
+                            declarator: (_) @name))
+                    ]
+                    (type_qualifier)? @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (template_declaration
+                    (class_specifier
+                        "class" @name
+                        name: (_) @name)
+                        ) @item
+            )
+
+            (
+                (comment)* @context
+                .
+                (class_specifier
+                    "class" @name
+                    name: (_) @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (enum_specifier
+                    "enum" @name
+                    name: (_) @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (declaration
+                    type: (struct_specifier
+                    "struct" @name)
+                    declarator: (_) @name) @item
+            )
+
+            "#,
+        )
+        .unwrap(),
+    )
+}

crates/zed/src/languages/cpp/embedding.scm 🔗

@@ -0,0 +1,61 @@
+(
+    (comment)* @context
+    .
+    (function_definition
+        (type_qualifier)? @name
+        type: (_)? @name
+        declarator: [
+            (function_declarator
+                declarator: (_) @name)
+            (pointer_declarator
+                "*" @name
+                declarator: (function_declarator
+                declarator: (_) @name))
+            (pointer_declarator
+                "*" @name
+                declarator: (pointer_declarator
+                    "*" @name
+                declarator: (function_declarator
+                    declarator: (_) @name)))
+            (reference_declarator
+                ["&" "&&"] @name
+                (function_declarator
+                declarator: (_) @name))
+        ]
+        (type_qualifier)? @name) @item
+    )
+
+(
+    (comment)* @context
+    .
+    (template_declaration
+        (class_specifier
+            "class" @name
+            name: (_) @name)
+            ) @item
+)
+
+(
+    (comment)* @context
+    .
+    (class_specifier
+        "class" @name
+        name: (_) @name) @item
+    )
+
+(
+    (comment)* @context
+    .
+    (enum_specifier
+        "enum" @name
+        name: (_) @name) @item
+    )
+
+(
+    (comment)* @context
+    .
+    (declaration
+        type: (struct_specifier
+        "struct" @name)
+        declarator: (_) @name) @item
+)