vector_store_tests.rs

  1use crate::{
  2    db::dot,
  3    embedding::EmbeddingProvider,
  4    parsing::{CodeContextRetriever, Document},
  5    vector_store_settings::VectorStoreSettings,
  6    VectorStore,
  7};
  8use anyhow::Result;
  9use async_trait::async_trait;
 10use gpui::{Task, TestAppContext};
 11use language::{Language, LanguageConfig, LanguageRegistry};
 12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 13use rand::{rngs::StdRng, Rng};
 14use serde_json::json;
 15use settings::SettingsStore;
 16use std::{
 17    path::Path,
 18    sync::{
 19        atomic::{self, AtomicUsize},
 20        Arc,
 21    },
 22};
 23use unindent::Unindent;
 24
 25#[ctor::ctor]
 26fn init_logger() {
 27    if std::env::var("RUST_LOG").is_ok() {
 28        env_logger::init();
 29    }
 30}
 31
 32#[gpui::test]
 33async fn test_vector_store(cx: &mut TestAppContext) {
 34    cx.update(|cx| {
 35        cx.set_global(SettingsStore::test(cx));
 36        settings::register::<VectorStoreSettings>(cx);
 37        settings::register::<ProjectSettings>(cx);
 38    });
 39
 40    let fs = FakeFs::new(cx.background());
 41    fs.insert_tree(
 42        "/the-root",
 43        json!({
 44            "src": {
 45                "file1.rs": "
 46                    fn aaa() {
 47                        println!(\"aaaa!\");
 48                    }
 49
 50                    fn zzzzzzzzz() {
 51                        println!(\"SLEEPING\");
 52                    }
 53                ".unindent(),
 54                "file2.rs": "
 55                    fn bbb() {
 56                        println!(\"bbbb!\");
 57                    }
 58                ".unindent(),
 59                "file3.toml": "
 60                    ZZZZZZZ = 5
 61                    ".unindent(),
 62            }
 63        }),
 64    )
 65    .await;
 66
 67    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 68    let rust_language = rust_lang();
 69    let toml_language = toml_lang();
 70    languages.add(rust_language);
 71    languages.add(toml_language);
 72
 73    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 74    let db_path = db_dir.path().join("db.sqlite");
 75
 76    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 77    let store = VectorStore::new(
 78        fs.clone(),
 79        db_path,
 80        embedding_provider.clone(),
 81        languages,
 82        cx.to_async(),
 83    )
 84    .await
 85    .unwrap();
 86
 87    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 88    let worktree_id = project.read_with(cx, |project, cx| {
 89        project.worktrees(cx).next().unwrap().read(cx).id()
 90    });
 91    let file_count = store
 92        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 93        .await
 94        .unwrap();
 95    assert_eq!(file_count, 3);
 96    cx.foreground().run_until_parked();
 97    store.update(cx, |store, _cx| {
 98        assert_eq!(
 99            store.remaining_files_to_index_for_project(&project),
100            Some(0)
101        );
102    });
103
104    let search_results = store
105        .update(cx, |store, cx| {
106            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
107        })
108        .await
109        .unwrap();
110
111    assert_eq!(search_results[0].byte_range.start, 0);
112    assert_eq!(search_results[0].name, "aaa");
113    assert_eq!(search_results[0].worktree_id, worktree_id);
114
115    fs.save(
116        "/the-root/src/file2.rs".as_ref(),
117        &"
118            fn dddd() { println!(\"ddddd!\"); }
119            struct pqpqpqp {}
120        "
121        .unindent()
122        .into(),
123        Default::default(),
124    )
125    .await
126    .unwrap();
127
128    cx.foreground().run_until_parked();
129
130    let prev_embedding_count = embedding_provider.embedding_count();
131    let file_count = store
132        .update(cx, |store, cx| store.index_project(project.clone(), cx))
133        .await
134        .unwrap();
135    assert_eq!(file_count, 1);
136
137    cx.foreground().run_until_parked();
138    store.update(cx, |store, _cx| {
139        assert_eq!(
140            store.remaining_files_to_index_for_project(&project),
141            Some(0)
142        );
143    });
144
145    assert_eq!(
146        embedding_provider.embedding_count() - prev_embedding_count,
147        2
148    );
149}
150
151#[gpui::test]
152async fn test_code_context_retrieval_rust() {
153    let language = rust_lang();
154    let mut retriever = CodeContextRetriever::new();
155
156    let text = "
157        /// A doc comment
158        /// that spans multiple lines
159        fn a() {
160            b
161        }
162
163        impl C for D {
164        }
165    "
166    .unindent();
167
168    let parsed_files = retriever
169        .parse_file(Path::new("foo.rs"), &text, language)
170        .unwrap();
171
172    assert_eq!(
173        parsed_files,
174        &[
175            Document {
176                name: "a".into(),
177                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
178                content: "
179                    The below code snippet is from file 'foo.rs'
180
181                    ```rust
182                    /// A doc comment
183                    /// that spans multiple lines
184                    fn a() {
185                        b
186                    }
187                    ```"
188                .unindent(),
189                embedding: vec![],
190            },
191            Document {
192                name: "C for D".into(),
193                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
194                content: "
195                    The below code snippet is from file 'foo.rs'
196
197                    ```rust
198                    impl C for D {
199                    }
200                    ```"
201                .unindent(),
202                embedding: vec![],
203            }
204        ]
205    );
206}
207
208#[gpui::test]
209async fn test_code_context_retrieval_javascript() {
210    let language = js_lang();
211    let mut retriever = CodeContextRetriever::new();
212
213    let text = "
214        /* globals importScripts, backend */
215        function _authorize() {}
216
217        /**
218         * Sometimes the frontend build is way faster than backend.
219         */
220        export async function authorizeBank() {
221            _authorize(pushModal, upgradingAccountId, {});
222        }
223
224        export class SettingsPage {
225            /* This is a test setting */
226            constructor(page) {
227                this.page = page;
228            }
229        }
230
231        /* This is a test comment */
232        class TestClass {}
233
234        /* Schema for editor_events in Clickhouse. */
235        export interface ClickhouseEditorEvent {
236            installation_id: string
237            operation: string
238        }
239        "
240    .unindent();
241
242    let parsed_files = retriever
243        .parse_file(Path::new("foo.js"), &text, language)
244        .unwrap();
245
246    let test_documents = &[
247        Document {
248            name: "function _authorize".into(),
249            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
250            content: "
251                    The below code snippet is from file 'foo.js'
252
253                    ```javascript
254                    /* globals importScripts, backend */
255                    function _authorize() {}
256                    ```"
257            .unindent(),
258            embedding: vec![],
259        },
260        Document {
261            name: "async function authorizeBank".into(),
262            range: text.find("export async").unwrap()..223,
263            content: "
264                    The below code snippet is from file 'foo.js'
265
266                    ```javascript
267                    /**
268                     * Sometimes the frontend build is way faster than backend.
269                     */
270                    export async function authorizeBank() {
271                        _authorize(pushModal, upgradingAccountId, {});
272                    }
273                    ```"
274            .unindent(),
275            embedding: vec![],
276        },
277        Document {
278            name: "class SettingsPage".into(),
279            range: 225..343,
280            content: "
281                    The below code snippet is from file 'foo.js'
282
283                    ```javascript
284                    export class SettingsPage {
285                        /* This is a test setting */
286                        constructor(page) {
287                            this.page = page;
288                        }
289                    }
290                    ```"
291            .unindent(),
292            embedding: vec![],
293        },
294        Document {
295            name: "constructor".into(),
296            range: 290..341,
297            content: "
298                The below code snippet is from file 'foo.js'
299
300                ```javascript
301                /* This is a test setting */
302                constructor(page) {
303                        this.page = page;
304                    }
305                ```"
306            .unindent(),
307            embedding: vec![],
308        },
309        Document {
310            name: "class TestClass".into(),
311            range: 374..392,
312            content: "
313                    The below code snippet is from file 'foo.js'
314
315                    ```javascript
316                    /* This is a test comment */
317                    class TestClass {}
318                    ```"
319            .unindent(),
320            embedding: vec![],
321        },
322        Document {
323            name: "interface ClickhouseEditorEvent".into(),
324            range: 440..532,
325            content: "
326                    The below code snippet is from file 'foo.js'
327
328                    ```javascript
329                    /* Schema for editor_events in Clickhouse. */
330                    export interface ClickhouseEditorEvent {
331                        installation_id: string
332                        operation: string
333                    }
334                    ```"
335            .unindent(),
336            embedding: vec![],
337        },
338    ];
339
340    for idx in 0..test_documents.len() {
341        assert_eq!(test_documents[idx], parsed_files[idx]);
342    }
343}
344
345#[gpui::test]
346async fn test_code_context_retrieval_cpp() {
347    let language = cpp_lang();
348    let mut retriever = CodeContextRetriever::new();
349
350    let text = "
351    /**
352     * @brief Main function
353     * @returns 0 on exit
354     */
355    int main() { return 0; }
356
357    /**
358    * This is a test comment
359    */
360    class MyClass {       // The class
361        public:             // Access specifier
362        int myNum;        // Attribute (int variable)
363        string myString;  // Attribute (string variable)
364    };
365
366    // This is a test comment
367    enum Color { red, green, blue };
368
369    /** This is a preceeding block comment
370     * This is the second line
371     */
372    struct {           // Structure declaration
373        int myNum;       // Member (int variable)
374        string myString; // Member (string variable)
375    } myStructure;
376
377    /**
378    * @brief Matrix class.
379    */
380    template <typename T,
381              typename = typename std::enable_if<
382                std::is_integral<T>::value || std::is_floating_point<T>::value,
383                bool>::type>
384    class Matrix2 {
385        std::vector<std::vector<T>> _mat;
386
387    public:
388        /**
389        * @brief Constructor
390        * @tparam Integer ensuring integers are being evaluated and not other
391        * data types.
392        * @param size denoting the size of Matrix as size x size
393        */
394        template <typename Integer,
395                  typename = typename std::enable_if<std::is_integral<Integer>::value,
396                  Integer>::type>
397        explicit Matrix(const Integer size) {
398            for (size_t i = 0; i < size; ++i) {
399                _mat.emplace_back(std::vector<T>(size, 0));
400            }
401        }
402    }"
403    .unindent();
404
405    let parsed_files = retriever
406        .parse_file(Path::new("foo.cpp"), &text, language)
407        .unwrap();
408
409    let test_documents = &[
410        Document {
411            name: "int main".into(),
412            range: 54..78,
413            content: "
414                The below code snippet is from file 'foo.cpp'
415
416                ```cpp
417                /**
418                 * @brief Main function
419                 * @returns 0 on exit
420                 */
421                int main() { return 0; }
422                ```"
423            .unindent(),
424            embedding: vec![],
425        },
426        Document {
427            name: "class MyClass".into(),
428            range: 112..295,
429            content: "
430                The below code snippet is from file 'foo.cpp'
431
432                ```cpp
433                /**
434                * This is a test comment
435                */
436                class MyClass {       // The class
437                    public:             // Access specifier
438                    int myNum;        // Attribute (int variable)
439                    string myString;  // Attribute (string variable)
440                }
441                ```"
442            .unindent(),
443            embedding: vec![],
444        },
445        Document {
446            name: "enum Color".into(),
447            range: 324..355,
448            content: "
449                The below code snippet is from file 'foo.cpp'
450
451                ```cpp
452                // This is a test comment
453                enum Color { red, green, blue }
454                ```"
455            .unindent(),
456            embedding: vec![],
457        },
458        Document {
459            name: "struct myStructure".into(),
460            range: 428..581,
461            content: "
462                The below code snippet is from file 'foo.cpp'
463
464                ```cpp
465                /** This is a preceeding block comment
466                 * This is the second line
467                 */
468                struct {           // Structure declaration
469                    int myNum;       // Member (int variable)
470                    string myString; // Member (string variable)
471                } myStructure;
472                ```"
473            .unindent(),
474            embedding: vec![],
475        },
476        Document {
477            name: "class Matrix2".into(),
478            range: 613..1342,
479            content: "
480                The below code snippet is from file 'foo.cpp'
481
482                ```cpp
483                /**
484                * @brief Matrix class.
485                */
486                template <typename T,
487                          typename = typename std::enable_if<
488                            std::is_integral<T>::value || std::is_floating_point<T>::value,
489                            bool>::type>
490                class Matrix2 {
491                    std::vector<std::vector<T>> _mat;
492
493                public:
494                    /**
495                    * @brief Constructor
496                    * @tparam Integer ensuring integers are being evaluated and not other
497                    * data types.
498                    * @param size denoting the size of Matrix as size x size
499                    */
500                    template <typename Integer,
501                              typename = typename std::enable_if<std::is_integral<Integer>::value,
502                              Integer>::type>
503                    explicit Matrix(const Integer size) {
504                        for (size_t i = 0; i < size; ++i) {
505                            _mat.emplace_back(std::vector<T>(size, 0));
506                        }
507                    }
508                }
509                ```"
510            .unindent(),
511            embedding: vec![],
512        },
513    ];
514
515    for idx in 0..test_documents.len() {
516        assert_eq!(test_documents[idx], parsed_files[idx]);
517    }
518}
519
520#[gpui::test]
521fn test_dot_product(mut rng: StdRng) {
522    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
523    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
524
525    for _ in 0..100 {
526        let size = 1536;
527        let mut a = vec![0.; size];
528        let mut b = vec![0.; size];
529        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
530            *a = rng.gen();
531            *b = rng.gen();
532        }
533
534        assert_eq!(
535            round_to_decimals(dot(&a, &b), 1),
536            round_to_decimals(reference_dot(&a, &b), 1)
537        );
538    }
539
540    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
541        let factor = (10.0 as f32).powi(decimal_places);
542        (n * factor).round() / factor
543    }
544
545    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
546        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
547    }
548}
549
/// Test double for an embedding provider that records how many spans it has
/// been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Running total of spans embedded so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the span counter.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
560
561#[async_trait]
562impl EmbeddingProvider for FakeEmbeddingProvider {
563    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
564        self.embedding_count
565            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
566        Ok(spans
567            .iter()
568            .map(|span| {
569                let mut result = vec![1.0; 26];
570                for letter in span.chars() {
571                    let letter = letter.to_ascii_lowercase();
572                    if letter as u32 >= 'a' as u32 {
573                        let ix = (letter as u32) - ('a' as u32);
574                        if ix < 26 {
575                            result[ix as usize] += 1.0;
576                        }
577                    }
578                }
579
580                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
581                for x in &mut result {
582                    *x /= norm;
583                }
584
585                result
586            })
587            .collect())
588    }
589}
590
591fn js_lang() -> Arc<Language> {
592    Arc::new(
593        Language::new(
594            LanguageConfig {
595                name: "Javascript".into(),
596                path_suffixes: vec!["js".into()],
597                ..Default::default()
598            },
599            Some(tree_sitter_typescript::language_tsx()),
600        )
601        .with_embedding_query(
602            &r#"
603
604            (
605                (comment)* @context
606                .
607                (export_statement
608                    (function_declaration
609                        "async"? @name
610                        "function" @name
611                        name: (_) @name)) @item
612                    )
613
614            (
615                (comment)* @context
616                .
617                (function_declaration
618                    "async"? @name
619                    "function" @name
620                    name: (_) @name) @item
621                    )
622
623            (
624                (comment)* @context
625                .
626                (export_statement
627                    (class_declaration
628                        "class" @name
629                        name: (_) @name)) @item
630                    )
631
632            (
633                (comment)* @context
634                .
635                (class_declaration
636                    "class" @name
637                    name: (_) @name) @item
638                    )
639
640            (
641                (comment)* @context
642                .
643                (method_definition
644                    [
645                        "get"
646                        "set"
647                        "async"
648                        "*"
649                        "static"
650                    ]* @name
651                    name: (_) @name) @item
652                )
653
654            (
655                (comment)* @context
656                .
657                (export_statement
658                    (interface_declaration
659                        "interface" @name
660                        name: (_) @name)) @item
661                )
662
663            (
664                (comment)* @context
665                .
666                (interface_declaration
667                    "interface" @name
668                    name: (_) @name) @item
669                )
670
671            (
672                (comment)* @context
673                .
674                (export_statement
675                    (enum_declaration
676                        "enum" @name
677                        name: (_) @name)) @item
678                )
679
680            (
681                (comment)* @context
682                .
683                (enum_declaration
684                    "enum" @name
685                    name: (_) @name) @item
686                )
687
688                    "#
689            .unindent(),
690        )
691        .unwrap(),
692    )
693}
694
695fn rust_lang() -> Arc<Language> {
696    Arc::new(
697        Language::new(
698            LanguageConfig {
699                name: "Rust".into(),
700                path_suffixes: vec!["rs".into()],
701                ..Default::default()
702            },
703            Some(tree_sitter_rust::language()),
704        )
705        .with_embedding_query(
706            r#"
707            (
708                (line_comment)* @context
709                .
710                (enum_item
711                    name: (_) @name) @item
712            )
713
714            (
715                (line_comment)* @context
716                .
717                (struct_item
718                    name: (_) @name) @item
719            )
720
721            (
722                (line_comment)* @context
723                .
724                (impl_item
725                    trait: (_)? @name
726                    "for"? @name
727                    type: (_) @name) @item
728            )
729
730            (
731                (line_comment)* @context
732                .
733                (trait_item
734                    name: (_) @name) @item
735            )
736
737            (
738                (line_comment)* @context
739                .
740                (function_item
741                    name: (_) @name) @item
742            )
743
744            (
745                (line_comment)* @context
746                .
747                (macro_definition
748                    name: (_) @name) @item
749            )
750
751            (
752                (line_comment)* @context
753                .
754                (function_signature_item
755                    name: (_) @name) @item
756            )
757            "#,
758        )
759        .unwrap(),
760    )
761}
762
763fn toml_lang() -> Arc<Language> {
764    Arc::new(Language::new(
765        LanguageConfig {
766            name: "TOML".into(),
767            path_suffixes: vec!["toml".into()],
768            ..Default::default()
769        },
770        Some(tree_sitter_toml::language()),
771    ))
772}
773
774fn cpp_lang() -> Arc<Language> {
775    Arc::new(
776        Language::new(
777            LanguageConfig {
778                name: "CPP".into(),
779                path_suffixes: vec!["cpp".into()],
780                ..Default::default()
781            },
782            Some(tree_sitter_cpp::language()),
783        )
784        .with_embedding_query(
785            r#"
786            (
787                (comment)* @context
788                .
789                (function_definition
790                    (type_qualifier)? @name
791                    type: (_)? @name
792                    declarator: [
793                        (function_declarator
794                            declarator: (_) @name)
795                        (pointer_declarator
796                            "*" @name
797                            declarator: (function_declarator
798                            declarator: (_) @name))
799                        (pointer_declarator
800                            "*" @name
801                            declarator: (pointer_declarator
802                                "*" @name
803                            declarator: (function_declarator
804                                declarator: (_) @name)))
805                        (reference_declarator
806                            ["&" "&&"] @name
807                            (function_declarator
808                            declarator: (_) @name))
809                    ]
810                    (type_qualifier)? @name) @item
811                )
812
813            (
814                (comment)* @context
815                .
816                (template_declaration
817                    (class_specifier
818                        "class" @name
819                        name: (_) @name)
820                        ) @item
821            )
822
823            (
824                (comment)* @context
825                .
826                (class_specifier
827                    "class" @name
828                    name: (_) @name) @item
829                )
830
831            (
832                (comment)* @context
833                .
834                (enum_specifier
835                    "enum" @name
836                    name: (_) @name) @item
837                )
838
839            (
840                (comment)* @context
841                .
842                (declaration
843                    type: (struct_specifier
844                    "struct" @name)
845                    declarator: (_) @name) @item
846            )
847
848            "#,
849        )
850        .unwrap(),
851    )
852}