vector_store_tests.rs

  1use std::sync::Arc;
  2
  3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::Rng;
 10use serde_json::json;
 11use unindent::Unindent;
 12
 13#[gpui::test]
 14async fn test_vector_store(cx: &mut TestAppContext) {
 15    let fs = FakeFs::new(cx.background());
 16    fs.insert_tree(
 17        "/the-root",
 18        json!({
 19            "src": {
 20                "file1.rs": "
 21                    fn aaa() {
 22                        println!(\"aaaa!\");
 23                    }
 24
 25                    fn zzzzzzzzz() {
 26                        println!(\"SLEEPING\");
 27                    }
 28                ".unindent(),
 29                "file2.rs": "
 30                    fn bbb() {
 31                        println!(\"bbbb!\");
 32                    }
 33                ".unindent(),
 34            }
 35        }),
 36    )
 37    .await;
 38
 39    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 40    let rust_language = Arc::new(
 41        Language::new(
 42            LanguageConfig {
 43                name: "Rust".into(),
 44                path_suffixes: vec!["rs".into()],
 45                ..Default::default()
 46            },
 47            Some(tree_sitter_rust::language()),
 48        )
 49        .with_outline_query(
 50            r#"
 51            (function_item
 52                name: (identifier) @name
 53                body: (block)) @item
 54            "#,
 55        )
 56        .unwrap(),
 57    );
 58    languages.add(rust_language);
 59
 60    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 61    let db_path = db_dir.path().join("db.sqlite");
 62
 63    let store = cx.add_model(|_| {
 64        VectorStore::new(
 65            fs.clone(),
 66            db_path.to_string_lossy().to_string(),
 67            Arc::new(FakeEmbeddingProvider),
 68            languages,
 69        )
 70    });
 71
 72    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 73    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
 74
 75    // TODO - remove
 76    cx.foreground()
 77        .advance_clock(std::time::Duration::from_secs(3));
 78
 79    add_project.await.unwrap();
 80
 81    let search_results = store
 82        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
 83        .await
 84        .unwrap();
 85
 86    assert_eq!(search_results[0].offset, 0);
 87    assert_eq!(search_results[0].name, "aaa");
 88}
 89
 90#[test]
 91fn test_dot_product() {
 92    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
 93    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
 94
 95    for _ in 0..100 {
 96        let mut rng = rand::thread_rng();
 97        let a: [f32; 32] = rng.gen();
 98        let b: [f32; 32] = rng.gen();
 99        assert_eq!(
100            round_to_decimals(dot(&a, &b), 3),
101            round_to_decimals(reference_dot(&a, &b), 3)
102        );
103    }
104
105    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
106        let factor = (10.0 as f32).powi(decimal_places);
107        (n * factor).round() / factor
108    }
109
110    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
111        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
112    }
113}
114
115struct FakeEmbeddingProvider;
116
117#[async_trait]
118impl EmbeddingProvider for FakeEmbeddingProvider {
119    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
120        Ok(spans
121            .iter()
122            .map(|span| {
123                let mut result = vec![1.0; 26];
124                for letter in span.chars() {
125                    let letter = letter.to_ascii_lowercase();
126                    if letter as u32 >= 'a' as u32 {
127                        let ix = (letter as u32) - ('a' as u32);
128                        if ix < 26 {
129                            result[ix as usize] += 1.0;
130                        }
131                    }
132                }
133
134                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
135                for x in &mut result {
136                    *x /= norm;
137                }
138
139                result
140            })
141            .collect())
142    }
143}