vector_store_tests.rs

  1use crate::{
  2    db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
  3};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::{rngs::StdRng, Rng};
 10use serde_json::json;
 11use settings::SettingsStore;
 12use std::sync::Arc;
 13use unindent::Unindent;
 14
 15#[gpui::test]
 16async fn test_vector_store(cx: &mut TestAppContext) {
 17    cx.update(|cx| {
 18        cx.set_global(SettingsStore::test(cx));
 19        settings::register::<VectorStoreSettings>(cx);
 20    });
 21
 22    let fs = FakeFs::new(cx.background());
 23    fs.insert_tree(
 24        "/the-root",
 25        json!({
 26            "src": {
 27                "file1.rs": "
 28                    fn aaa() {
 29                        println!(\"aaaa!\");
 30                    }
 31
 32                    fn zzzzzzzzz() {
 33                        println!(\"SLEEPING\");
 34                    }
 35                ".unindent(),
 36                "file2.rs": "
 37                    fn bbb() {
 38                        println!(\"bbbb!\");
 39                    }
 40                ".unindent(),
 41            }
 42        }),
 43    )
 44    .await;
 45
 46    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 47    let rust_language = Arc::new(
 48        Language::new(
 49            LanguageConfig {
 50                name: "Rust".into(),
 51                path_suffixes: vec!["rs".into()],
 52                ..Default::default()
 53            },
 54            Some(tree_sitter_rust::language()),
 55        )
 56        .with_embedding_query(
 57            r#"
 58            (function_item
 59                name: (identifier) @name
 60                body: (block)) @item
 61            "#,
 62        )
 63        .unwrap(),
 64    );
 65    languages.add(rust_language);
 66
 67    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 68    let db_path = db_dir.path().join("db.sqlite");
 69
 70    let store = VectorStore::new(
 71        fs.clone(),
 72        db_path,
 73        Arc::new(FakeEmbeddingProvider),
 74        languages,
 75        cx.to_async(),
 76    )
 77    .await
 78    .unwrap();
 79
 80    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 81    let worktree_id = project.read_with(cx, |project, cx| {
 82        project.worktrees(cx).next().unwrap().read(cx).id()
 83    });
 84    store
 85        .update(cx, |store, cx| store.add_project(project.clone(), cx))
 86        .await
 87        .unwrap();
 88    cx.foreground().run_until_parked();
 89
 90    let search_results = store
 91        .update(cx, |store, cx| {
 92            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 93        })
 94        .await
 95        .unwrap();
 96
 97    assert_eq!(search_results[0].offset, 0);
 98    assert_eq!(search_results[0].name, "aaa");
 99    assert_eq!(search_results[0].worktree_id, worktree_id);
100}
101
102#[gpui::test]
103fn test_dot_product(mut rng: StdRng) {
104    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
105    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
106
107    for _ in 0..100 {
108        let size = 1536;
109        let mut a = vec![0.; size];
110        let mut b = vec![0.; size];
111        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
112            *a = rng.gen();
113            *b = rng.gen();
114        }
115
116        assert_eq!(
117            round_to_decimals(dot(&a, &b), 1),
118            round_to_decimals(reference_dot(&a, &b), 1)
119        );
120    }
121
122    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
123        let factor = (10.0 as f32).powi(decimal_places);
124        (n * factor).round() / factor
125    }
126
127    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
128        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
129    }
130}
131
132struct FakeEmbeddingProvider;
133
134#[async_trait]
135impl EmbeddingProvider for FakeEmbeddingProvider {
136    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
137        Ok(spans
138            .iter()
139            .map(|span| {
140                let mut result = vec![1.0; 26];
141                for letter in span.chars() {
142                    let letter = letter.to_ascii_lowercase();
143                    if letter as u32 >= 'a' as u32 {
144                        let ix = (letter as u32) - ('a' as u32);
145                        if ix < 26 {
146                            result[ix as usize] += 1.0;
147                        }
148                    }
149                }
150
151                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
152                for x in &mut result {
153                    *x /= norm;
154                }
155
156                result
157            })
158            .collect())
159    }
160}