vector_store_tests.rs

  1use crate::{
  2    db::dot,
  3    embedding::EmbeddingProvider,
  4    parsing::{CodeContextRetriever, Document},
  5    vector_store_settings::VectorStoreSettings,
  6    VectorStore,
  7};
  8use anyhow::Result;
  9use async_trait::async_trait;
 10use gpui::{Task, TestAppContext};
 11use language::{Language, LanguageConfig, LanguageRegistry};
 12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 13use rand::{rngs::StdRng, Rng};
 14use serde_json::json;
 15use settings::SettingsStore;
 16use std::{
 17    path::Path,
 18    sync::{
 19        atomic::{self, AtomicUsize},
 20        Arc,
 21    },
 22};
 23use unindent::Unindent;
 24
 25#[ctor::ctor]
 26fn init_logger() {
 27    if std::env::var("RUST_LOG").is_ok() {
 28        env_logger::init();
 29    }
 30}
 31
 32#[gpui::test]
 33async fn test_vector_store(cx: &mut TestAppContext) {
 34    cx.update(|cx| {
 35        cx.set_global(SettingsStore::test(cx));
 36        settings::register::<VectorStoreSettings>(cx);
 37        settings::register::<ProjectSettings>(cx);
 38    });
 39
 40    let fs = FakeFs::new(cx.background());
 41    fs.insert_tree(
 42        "/the-root",
 43        json!({
 44            "src": {
 45                "file1.rs": "
 46                    fn aaa() {
 47                        println!(\"aaaa!\");
 48                    }
 49
 50                    fn zzzzzzzzz() {
 51                        println!(\"SLEEPING\");
 52                    }
 53                ".unindent(),
 54                "file2.rs": "
 55                    fn bbb() {
 56                        println!(\"bbbb!\");
 57                    }
 58                ".unindent(),
 59                "file3.toml": "
 60                    ZZZZZZZ = 5
 61                    ".unindent(),
 62            }
 63        }),
 64    )
 65    .await;
 66
 67    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 68    let rust_language = rust_lang();
 69    let toml_language = toml_lang();
 70    languages.add(rust_language);
 71    languages.add(toml_language);
 72
 73    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 74    let db_path = db_dir.path().join("db.sqlite");
 75
 76    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 77    let store = VectorStore::new(
 78        fs.clone(),
 79        db_path,
 80        embedding_provider.clone(),
 81        languages,
 82        cx.to_async(),
 83    )
 84    .await
 85    .unwrap();
 86
 87    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 88    let worktree_id = project.read_with(cx, |project, cx| {
 89        project.worktrees(cx).next().unwrap().read(cx).id()
 90    });
 91    let file_count = store
 92        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 93        .await
 94        .unwrap();
 95    assert_eq!(file_count, 3);
 96    cx.foreground().run_until_parked();
 97    store.update(cx, |store, _cx| {
 98        assert_eq!(
 99            store.remaining_files_to_index_for_project(&project),
100            Some(0)
101        );
102    });
103
104    let search_results = store
105        .update(cx, |store, cx| {
106            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
107        })
108        .await
109        .unwrap();
110
111    assert_eq!(search_results[0].byte_range.start, 0);
112    assert_eq!(search_results[0].name, "aaa");
113    assert_eq!(search_results[0].worktree_id, worktree_id);
114
115    fs.save(
116        "/the-root/src/file2.rs".as_ref(),
117        &"
118            fn dddd() { println!(\"ddddd!\"); }
119            struct pqpqpqp {}
120        "
121        .unindent()
122        .into(),
123        Default::default(),
124    )
125    .await
126    .unwrap();
127
128    cx.foreground().run_until_parked();
129
130    let prev_embedding_count = embedding_provider.embedding_count();
131    let file_count = store
132        .update(cx, |store, cx| store.index_project(project.clone(), cx))
133        .await
134        .unwrap();
135    assert_eq!(file_count, 1);
136
137    cx.foreground().run_until_parked();
138    store.update(cx, |store, _cx| {
139        assert_eq!(
140            store.remaining_files_to_index_for_project(&project),
141            Some(0)
142        );
143    });
144
145    assert_eq!(
146        embedding_provider.embedding_count() - prev_embedding_count,
147        2
148    );
149}
150
151#[gpui::test]
152async fn test_code_context_retrieval_rust() {
153    let language = rust_lang();
154    let mut retriever = CodeContextRetriever::new();
155
156    let text = "
157        /// A doc comment
158        /// that spans multiple lines
159        fn a() {
160            b
161        }
162
163        impl C for D {
164        }
165    "
166    .unindent();
167
168    let parsed_files = retriever
169        .parse_file(Path::new("foo.rs"), &text, language)
170        .unwrap();
171
172    assert_eq!(
173        parsed_files,
174        &[
175            Document {
176                name: "a".into(),
177                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
178                content: "
179                    The below code snippet is from file 'foo.rs'
180
181                    ```rust
182                    /// A doc comment
183                    /// that spans multiple lines
184                    fn a() {
185                        b
186                    }
187                    ```"
188                .unindent(),
189                embedding: vec![],
190            },
191            Document {
192                name: "C for D".into(),
193                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
194                content: "
195                    The below code snippet is from file 'foo.rs'
196
197                    ```rust
198                    impl C for D {
199                    }
200                    ```"
201                .unindent(),
202                embedding: vec![],
203            }
204        ]
205    );
206}
207
208#[gpui::test]
209async fn test_code_context_retrieval_javascript() {
210    let language = js_lang();
211    let mut retriever = CodeContextRetriever::new();
212
213    let text = "
214/* globals importScripts, backend */
215function _authorize() {}
216
217/**
218 * Sometimes the frontend build is way faster than backend.
219 */
220export async function authorizeBank() {
221    _authorize(pushModal, upgradingAccountId, {});
222}
223
224export class SettingsPage {
225    /* This is a test setting */
226    constructor(page) {
227        this.page = page;
228    }
229}
230
231/* This is a test comment */
232class TestClass {}
233
234/* Schema for editor_events in Clickhouse. */
235export interface ClickhouseEditorEvent {
236    installation_id: string
237    operation: string
238}
239";
240
241    let parsed_files = retriever
242        .parse_file(Path::new("foo.js"), &text, language)
243        .unwrap();
244
245    let test_documents = &[
246        Document {
247            name: "function _authorize".into(),
248            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
249            content: "
250                    The below code snippet is from file 'foo.js'
251
252                    ```javascript
253                    /* globals importScripts, backend */
254                    function _authorize() {}
255                    ```"
256            .unindent(),
257            embedding: vec![],
258        },
259        Document {
260            name: "async function authorizeBank".into(),
261            range: text.find("export async").unwrap()..224,
262            content: "
263                    The below code snippet is from file 'foo.js'
264
265                    ```javascript
266                    /**
267                     * Sometimes the frontend build is way faster than backend.
268                     */
269                    export async function authorizeBank() {
270                        _authorize(pushModal, upgradingAccountId, {});
271                    }
272                    ```"
273            .unindent(),
274            embedding: vec![],
275        },
276        Document {
277            name: "class SettingsPage".into(),
278            range: 226..344,
279            content: "
280                    The below code snippet is from file 'foo.js'
281
282                    ```javascript
283                    export class SettingsPage {
284                        /* This is a test setting */
285                        constructor(page) {
286                            this.page = page;
287                        }
288                    }
289                    ```"
290            .unindent(),
291            embedding: vec![],
292        },
293        Document {
294            name: "constructor".into(),
295            range: 291..342,
296            content: "
297                The below code snippet is from file 'foo.js'
298
299                ```javascript
300                /* This is a test setting */
301                constructor(page) {
302                        this.page = page;
303                    }
304                ```"
305            .unindent(),
306            embedding: vec![],
307        },
308        Document {
309            name: "class TestClass".into(),
310            range: 375..393,
311            content: "
312                    The below code snippet is from file 'foo.js'
313
314                    ```javascript
315                    /* This is a test comment */
316                    class TestClass {}
317                    ```"
318            .unindent(),
319            embedding: vec![],
320        },
321        Document {
322            name: "interface ClickhouseEditorEvent".into(),
323            range: 441..533,
324            content: "
325                    The below code snippet is from file 'foo.js'
326
327                    ```javascript
328                    /* Schema for editor_events in Clickhouse. */
329                    export interface ClickhouseEditorEvent {
330                        installation_id: string
331                        operation: string
332                    }
333                    ```"
334            .unindent(),
335            embedding: vec![],
336        },
337    ];
338
339    for idx in 0..test_documents.len() {
340        assert_eq!(test_documents[idx], parsed_files[idx]);
341    }
342}
343
344#[gpui::test]
345fn test_dot_product(mut rng: StdRng) {
346    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
347    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
348
349    for _ in 0..100 {
350        let size = 1536;
351        let mut a = vec![0.; size];
352        let mut b = vec![0.; size];
353        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
354            *a = rng.gen();
355            *b = rng.gen();
356        }
357
358        assert_eq!(
359            round_to_decimals(dot(&a, &b), 1),
360            round_to_decimals(reference_dot(&a, &b), 1)
361        );
362    }
363
364    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
365        let factor = (10.0 as f32).powi(decimal_places);
366        (n * factor).round() / factor
367    }
368
369    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
370        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
371    }
372}
373
/// Test double for an embedding provider that keeps a running count of embedded spans.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans embedded so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns how many spans have been embedded so far (sequentially-consistent read).
    fn embedding_count(&self) -> usize {
        let ordering = atomic::Ordering::SeqCst;
        self.embedding_count.load(ordering)
    }
}
384
385#[async_trait]
386impl EmbeddingProvider for FakeEmbeddingProvider {
387    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
388        self.embedding_count
389            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
390        Ok(spans
391            .iter()
392            .map(|span| {
393                let mut result = vec![1.0; 26];
394                for letter in span.chars() {
395                    let letter = letter.to_ascii_lowercase();
396                    if letter as u32 >= 'a' as u32 {
397                        let ix = (letter as u32) - ('a' as u32);
398                        if ix < 26 {
399                            result[ix as usize] += 1.0;
400                        }
401                    }
402                }
403
404                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
405                for x in &mut result {
406                    *x /= norm;
407                }
408
409                result
410            })
411            .collect())
412    }
413}
414
415fn js_lang() -> Arc<Language> {
416    Arc::new(
417        Language::new(
418            LanguageConfig {
419                name: "Javascript".into(),
420                path_suffixes: vec!["js".into()],
421                ..Default::default()
422            },
423            Some(tree_sitter_typescript::language_tsx()),
424        )
425        .with_embedding_query(
426            &r#"
427
428            (
429                (comment)* @context
430                .
431                (export_statement
432                    (function_declaration
433                        "async"? @name
434                        "function" @name
435                        name: (_) @name)) @item
436                    )
437
438            (
439                (comment)* @context
440                .
441                (function_declaration
442                    "async"? @name
443                    "function" @name
444                    name: (_) @name) @item
445                    )
446
447            (
448                (comment)* @context
449                .
450                (export_statement
451                    (class_declaration
452                        "class" @name
453                        name: (_) @name)) @item
454                    )
455
456            (
457                (comment)* @context
458                .
459                (class_declaration
460                    "class" @name
461                    name: (_) @name) @item
462                    )
463
464            (
465                (comment)* @context
466                .
467                (method_definition
468                    [
469                        "get"
470                        "set"
471                        "async"
472                        "*"
473                        "static"
474                    ]* @name
475                    name: (_) @name) @item
476                )
477
478            (
479                (comment)* @context
480                .
481                (export_statement
482                    (interface_declaration
483                        "interface" @name
484                        name: (_) @name)) @item
485                )
486
487            (
488                (comment)* @context
489                .
490                (interface_declaration
491                    "interface" @name
492                    name: (_) @name) @item
493                )
494
495            (
496                (comment)* @context
497                .
498                (export_statement
499                    (enum_declaration
500                        "enum" @name
501                        name: (_) @name)) @item
502                )
503
504            (
505                (comment)* @context
506                .
507                (enum_declaration
508                    "enum" @name
509                    name: (_) @name) @item
510                )
511
512                    "#
513            .unindent(),
514        )
515        .unwrap(),
516    )
517}
518
519fn rust_lang() -> Arc<Language> {
520    Arc::new(
521        Language::new(
522            LanguageConfig {
523                name: "Rust".into(),
524                path_suffixes: vec!["rs".into()],
525                ..Default::default()
526            },
527            Some(tree_sitter_rust::language()),
528        )
529        .with_embedding_query(
530            r#"
531            (
532                (line_comment)* @context
533                .
534                (enum_item
535                    name: (_) @name) @item
536            )
537
538            (
539                (line_comment)* @context
540                .
541                (struct_item
542                    name: (_) @name) @item
543            )
544
545            (
546                (line_comment)* @context
547                .
548                (impl_item
549                    trait: (_)? @name
550                    "for"? @name
551                    type: (_) @name) @item
552            )
553
554            (
555                (line_comment)* @context
556                .
557                (trait_item
558                    name: (_) @name) @item
559            )
560
561            (
562                (line_comment)* @context
563                .
564                (function_item
565                    name: (_) @name) @item
566            )
567
568            (
569                (line_comment)* @context
570                .
571                (macro_definition
572                    name: (_) @name) @item
573            )
574
575            (
576                (line_comment)* @context
577                .
578                (function_signature_item
579                    name: (_) @name) @item
580            )
581            "#,
582        )
583        .unwrap(),
584    )
585}
586
587fn toml_lang() -> Arc<Language> {
588    Arc::new(Language::new(
589        LanguageConfig {
590            name: "TOML".into(),
591            path_suffixes: vec!["toml".into()],
592            ..Default::default()
593        },
594        Some(tree_sitter_toml::language()),
595    ))
596}