vector_store_tests.rs

  1use crate::{
  2    db::dot,
  3    embedding::EmbeddingProvider,
  4    parsing::{CodeContextRetriever, Document},
  5    vector_store_settings::VectorStoreSettings,
  6    VectorStore,
  7};
  8use anyhow::Result;
  9use async_trait::async_trait;
 10use gpui::{Task, TestAppContext};
 11use language::{Language, LanguageConfig, LanguageRegistry};
 12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 13use rand::{rngs::StdRng, Rng};
 14use serde_json::json;
 15use settings::SettingsStore;
 16use std::{
 17    path::Path,
 18    sync::{
 19        atomic::{self, AtomicUsize},
 20        Arc,
 21    },
 22};
 23use unindent::Unindent;
 24
 25#[ctor::ctor]
 26fn init_logger() {
 27    if std::env::var("RUST_LOG").is_ok() {
 28        env_logger::init();
 29    }
 30}
 31
 32#[gpui::test]
 33async fn test_vector_store(cx: &mut TestAppContext) {
 34    cx.update(|cx| {
 35        cx.set_global(SettingsStore::test(cx));
 36        settings::register::<VectorStoreSettings>(cx);
 37        settings::register::<ProjectSettings>(cx);
 38    });
 39
 40    let fs = FakeFs::new(cx.background());
 41    fs.insert_tree(
 42        "/the-root",
 43        json!({
 44            "src": {
 45                "file1.rs": "
 46                    fn aaa() {
 47                        println!(\"aaaa!\");
 48                    }
 49
 50                    fn zzzzzzzzz() {
 51                        println!(\"SLEEPING\");
 52                    }
 53                ".unindent(),
 54                "file2.rs": "
 55                    fn bbb() {
 56                        println!(\"bbbb!\");
 57                    }
 58                ".unindent(),
 59            }
 60        }),
 61    )
 62    .await;
 63
 64    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 65    let rust_language = rust_lang();
 66    languages.add(rust_language);
 67
 68    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 69    let db_path = db_dir.path().join("db.sqlite");
 70
 71    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 72    let store = VectorStore::new(
 73        fs.clone(),
 74        db_path,
 75        embedding_provider.clone(),
 76        languages,
 77        cx.to_async(),
 78    )
 79    .await
 80    .unwrap();
 81
 82    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 83    let worktree_id = project.read_with(cx, |project, cx| {
 84        project.worktrees(cx).next().unwrap().read(cx).id()
 85    });
 86    let file_count = store
 87        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 88        .await
 89        .unwrap();
 90    assert_eq!(file_count, 2);
 91    cx.foreground().run_until_parked();
 92    store.update(cx, |store, _cx| {
 93        assert_eq!(
 94            store.remaining_files_to_index_for_project(&project),
 95            Some(0)
 96        );
 97    });
 98
 99    let search_results = store
100        .update(cx, |store, cx| {
101            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
102        })
103        .await
104        .unwrap();
105
106    assert_eq!(search_results[0].byte_range.start, 0);
107    assert_eq!(search_results[0].name, "aaa");
108    assert_eq!(search_results[0].worktree_id, worktree_id);
109
110    fs.save(
111        "/the-root/src/file2.rs".as_ref(),
112        &"
113            fn dddd() { println!(\"ddddd!\"); }
114            struct pqpqpqp {}
115        "
116        .unindent()
117        .into(),
118        Default::default(),
119    )
120    .await
121    .unwrap();
122
123    cx.foreground().run_until_parked();
124
125    let prev_embedding_count = embedding_provider.embedding_count();
126    let file_count = store
127        .update(cx, |store, cx| store.index_project(project.clone(), cx))
128        .await
129        .unwrap();
130    assert_eq!(file_count, 1);
131
132    cx.foreground().run_until_parked();
133    store.update(cx, |store, _cx| {
134        assert_eq!(
135            store.remaining_files_to_index_for_project(&project),
136            Some(0)
137        );
138    });
139
140    assert_eq!(
141        embedding_provider.embedding_count() - prev_embedding_count,
142        2
143    );
144}
145
146#[gpui::test]
147async fn test_code_context_retrieval_rust() {
148    let language = rust_lang();
149    let mut retriever = CodeContextRetriever::new();
150
151    let text = "
152        /// A doc comment
153        /// that spans multiple lines
154        fn a() {
155            b
156        }
157
158        impl C for D {
159        }
160    "
161    .unindent();
162
163    let parsed_files = retriever
164        .parse_file(Path::new("foo.rs"), &text, language)
165        .unwrap();
166
167    assert_eq!(
168        parsed_files,
169        &[
170            Document {
171                name: "a".into(),
172                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
173                content: "
174                    The below code snippet is from file 'foo.rs'
175
176                    ```rust
177                    /// A doc comment
178                    /// that spans multiple lines
179                    fn a() {
180                        b
181                    }
182                    ```"
183                .unindent(),
184                embedding: vec![],
185            },
186            Document {
187                name: "C for D".into(),
188                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
189                content: "
190                    The below code snippet is from file 'foo.rs'
191
192                    ```rust
193                    impl C for D {
194                    }
195                    ```"
196                .unindent(),
197                embedding: vec![],
198            }
199        ]
200    );
201}
202
203#[gpui::test]
204async fn test_code_context_retrieval_javascript() {
205    let language = js_lang();
206    let mut retriever = CodeContextRetriever::new();
207
208    let text = "
209/* globals importScripts, backend */
210function _authorize() {}
211
212/**
213 * Sometimes the frontend build is way faster than backend.
214 */
215export async function authorizeBank() {
216    _authorize(pushModal, upgradingAccountId, {});
217}
218
219export class SettingsPage {
220    /* This is a test setting */
221    constructor(page) {
222        this.page = page;
223    }
224}
225
226/* This is a test comment */
227class TestClass {}
228
229/* Schema for editor_events in Clickhouse. */
230export interface ClickhouseEditorEvent {
231    installation_id: string
232    operation: string
233}
234";
235
236    let parsed_files = retriever
237        .parse_file(Path::new("foo.js"), &text, language)
238        .unwrap();
239
240    let test_documents = &[
241        Document {
242            name: "function _authorize".into(),
243            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
244            content: "
245                    The below code snippet is from file 'foo.js'
246
247                    ```javascript
248                    /* globals importScripts, backend */
249                    function _authorize() {}
250                    ```"
251            .unindent(),
252            embedding: vec![],
253        },
254        Document {
255            name: "async function authorizeBank".into(),
256            range: text.find("export async").unwrap()..224,
257            content: "
258                    The below code snippet is from file 'foo.js'
259
260                    ```javascript
261                    /**
262                     * Sometimes the frontend build is way faster than backend.
263                     */
264                    export async function authorizeBank() {
265                        _authorize(pushModal, upgradingAccountId, {});
266                    }
267                    ```"
268            .unindent(),
269            embedding: vec![],
270        },
271        Document {
272            name: "class SettingsPage".into(),
273            range: 226..344,
274            content: "
275                    The below code snippet is from file 'foo.js'
276
277                    ```javascript
278                    export class SettingsPage {
279                        /* This is a test setting */
280                        constructor(page) {
281                            this.page = page;
282                        }
283                    }
284                    ```"
285            .unindent(),
286            embedding: vec![],
287        },
288        Document {
289            name: "constructor".into(),
290            range: 291..342,
291            content: "
292                The below code snippet is from file 'foo.js'
293
294                ```javascript
295                /* This is a test setting */
296                constructor(page) {
297                        this.page = page;
298                    }
299                ```"
300            .unindent(),
301            embedding: vec![],
302        },
303        Document {
304            name: "class TestClass".into(),
305            range: 375..393,
306            content: "
307                    The below code snippet is from file 'foo.js'
308
309                    ```javascript
310                    /* This is a test comment */
311                    class TestClass {}
312                    ```"
313            .unindent(),
314            embedding: vec![],
315        },
316        Document {
317            name: "interface ClickhouseEditorEvent".into(),
318            range: 441..533,
319            content: "
320                    The below code snippet is from file 'foo.js'
321
322                    ```javascript
323                    /* Schema for editor_events in Clickhouse. */
324                    export interface ClickhouseEditorEvent {
325                        installation_id: string
326                        operation: string
327                    }
328                    ```"
329            .unindent(),
330            embedding: vec![],
331        },
332    ];
333
334    for idx in 0..test_documents.len() {
335        assert_eq!(test_documents[idx], parsed_files[idx]);
336    }
337}
338
339#[gpui::test]
340fn test_dot_product(mut rng: StdRng) {
341    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
342    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
343
344    for _ in 0..100 {
345        let size = 1536;
346        let mut a = vec![0.; size];
347        let mut b = vec![0.; size];
348        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
349            *a = rng.gen();
350            *b = rng.gen();
351        }
352
353        assert_eq!(
354            round_to_decimals(dot(&a, &b), 1),
355            round_to_decimals(reference_dot(&a, &b), 1)
356        );
357    }
358
359    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
360        let factor = (10.0 as f32).powi(decimal_places);
361        (n * factor).round() / factor
362    }
363
364    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
365        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
366    }
367}
368
/// Test double for an embedding provider that keeps a running count of the spans it has
/// been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans seen so far; starts at zero via `Default`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the span counter.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
379
380#[async_trait]
381impl EmbeddingProvider for FakeEmbeddingProvider {
382    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
383        self.embedding_count
384            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
385        Ok(spans
386            .iter()
387            .map(|span| {
388                let mut result = vec![1.0; 26];
389                for letter in span.chars() {
390                    let letter = letter.to_ascii_lowercase();
391                    if letter as u32 >= 'a' as u32 {
392                        let ix = (letter as u32) - ('a' as u32);
393                        if ix < 26 {
394                            result[ix as usize] += 1.0;
395                        }
396                    }
397                }
398
399                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
400                for x in &mut result {
401                    *x /= norm;
402                }
403
404                result
405            })
406            .collect())
407    }
408}
409
410fn js_lang() -> Arc<Language> {
411    Arc::new(
412        Language::new(
413            LanguageConfig {
414                name: "Javascript".into(),
415                path_suffixes: vec!["js".into()],
416                ..Default::default()
417            },
418            Some(tree_sitter_typescript::language_tsx()),
419        )
420        .with_embedding_query(
421            &r#"
422
423            (
424                (comment)* @context
425                .
426                (export_statement
427                    (function_declaration
428                        "async"? @name
429                        "function" @name
430                        name: (_) @name)) @item
431                    )
432
433            (
434                (comment)* @context
435                .
436                (function_declaration
437                    "async"? @name
438                    "function" @name
439                    name: (_) @name) @item
440                    )
441
442            (
443                (comment)* @context
444                .
445                (export_statement
446                    (class_declaration
447                        "class" @name
448                        name: (_) @name)) @item
449                    )
450
451            (
452                (comment)* @context
453                .
454                (class_declaration
455                    "class" @name
456                    name: (_) @name) @item
457                    )
458
459            (
460                (comment)* @context
461                .
462                (method_definition
463                    [
464                        "get"
465                        "set"
466                        "async"
467                        "*"
468                        "static"
469                    ]* @name
470                    name: (_) @name) @item
471                )
472
473            (
474                (comment)* @context
475                .
476                (export_statement
477                    (interface_declaration
478                        "interface" @name
479                        name: (_) @name)) @item
480                )
481
482            (
483                (comment)* @context
484                .
485                (interface_declaration
486                    "interface" @name
487                    name: (_) @name) @item
488                )
489
490            (
491                (comment)* @context
492                .
493                (export_statement
494                    (enum_declaration
495                        "enum" @name
496                        name: (_) @name)) @item
497                )
498
499            (
500                (comment)* @context
501                .
502                (enum_declaration
503                    "enum" @name
504                    name: (_) @name) @item
505                )
506
507                    "#
508            .unindent(),
509        )
510        .unwrap(),
511    )
512}
513
514fn rust_lang() -> Arc<Language> {
515    Arc::new(
516        Language::new(
517            LanguageConfig {
518                name: "Rust".into(),
519                path_suffixes: vec!["rs".into()],
520                ..Default::default()
521            },
522            Some(tree_sitter_rust::language()),
523        )
524        .with_embedding_query(
525            r#"
526            (
527                (line_comment)* @context
528                .
529                (enum_item
530                    name: (_) @name) @item
531            )
532
533            (
534                (line_comment)* @context
535                .
536                (struct_item
537                    name: (_) @name) @item
538            )
539
540            (
541                (line_comment)* @context
542                .
543                (impl_item
544                    trait: (_)? @name
545                    "for"? @name
546                    type: (_) @name) @item
547            )
548
549            (
550                (line_comment)* @context
551                .
552                (trait_item
553                    name: (_) @name) @item
554            )
555
556            (
557                (line_comment)* @context
558                .
559                (function_item
560                    name: (_) @name) @item
561            )
562
563            (
564                (line_comment)* @context
565                .
566                (macro_definition
567                    name: (_) @name) @item
568            )
569
570            (
571                (line_comment)* @context
572                .
573                (function_signature_item
574                    name: (_) @name) @item
575            )
576            "#,
577        )
578        .unwrap(),
579    )
580}