vector_store_tests.rs

  1use crate::{
  2    db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
  3};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{project_settings::ProjectSettings, FakeFs, Project};
  9use rand::{rngs::StdRng, Rng};
 10use serde_json::json;
 11use settings::SettingsStore;
 12use std::sync::Arc;
 13use unindent::Unindent;
 14
 15#[gpui::test]
 16async fn test_vector_store(cx: &mut TestAppContext) {
 17    cx.update(|cx| {
 18        cx.set_global(SettingsStore::test(cx));
 19        settings::register::<VectorStoreSettings>(cx);
 20        settings::register::<ProjectSettings>(cx);
 21    });
 22
 23    let fs = FakeFs::new(cx.background());
 24    fs.insert_tree(
 25        "/the-root",
 26        json!({
 27            "src": {
 28                "file1.rs": "
 29                    fn aaa() {
 30                        println!(\"aaaa!\");
 31                    }
 32
 33                    fn zzzzzzzzz() {
 34                        println!(\"SLEEPING\");
 35                    }
 36                ".unindent(),
 37                "file2.rs": "
 38                    fn bbb() {
 39                        println!(\"bbbb!\");
 40                    }
 41                ".unindent(),
 42            }
 43        }),
 44    )
 45    .await;
 46
 47    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 48    let rust_language = Arc::new(
 49        Language::new(
 50            LanguageConfig {
 51                name: "Rust".into(),
 52                path_suffixes: vec!["rs".into()],
 53                ..Default::default()
 54            },
 55            Some(tree_sitter_rust::language()),
 56        )
 57        .with_embedding_query(
 58            r#"
 59            (function_item
 60                name: (identifier) @name
 61                body: (block)) @item
 62            "#,
 63        )
 64        .unwrap(),
 65    );
 66    languages.add(rust_language);
 67
 68    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 69    let db_path = db_dir.path().join("db.sqlite");
 70
 71    let store = VectorStore::new(
 72        fs.clone(),
 73        db_path,
 74        Arc::new(FakeEmbeddingProvider),
 75        languages,
 76        cx.to_async(),
 77    )
 78    .await
 79    .unwrap();
 80
 81    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 82    let worktree_id = project.read_with(cx, |project, cx| {
 83        project.worktrees(cx).next().unwrap().read(cx).id()
 84    });
 85    store
 86        .update(cx, |store, cx| store.add_project(project.clone(), cx))
 87        .await
 88        .unwrap();
 89    cx.foreground().run_until_parked();
 90
 91    let search_results = store
 92        .update(cx, |store, cx| {
 93            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 94        })
 95        .await
 96        .unwrap();
 97
 98    assert_eq!(search_results[0].offset, 0);
 99    assert_eq!(search_results[0].name, "aaa");
100    assert_eq!(search_results[0].worktree_id, worktree_id);
101}
102
103#[gpui::test]
104fn test_dot_product(mut rng: StdRng) {
105    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
106    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
107
108    for _ in 0..100 {
109        let size = 1536;
110        let mut a = vec![0.; size];
111        let mut b = vec![0.; size];
112        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
113            *a = rng.gen();
114            *b = rng.gen();
115        }
116
117        assert_eq!(
118            round_to_decimals(dot(&a, &b), 1),
119            round_to_decimals(reference_dot(&a, &b), 1)
120        );
121    }
122
123    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
124        let factor = (10.0 as f32).powi(decimal_places);
125        (n * factor).round() / factor
126    }
127
128    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
129        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
130    }
131}
132
133struct FakeEmbeddingProvider;
134
135#[async_trait]
136impl EmbeddingProvider for FakeEmbeddingProvider {
137    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
138        Ok(spans
139            .iter()
140            .map(|span| {
141                let mut result = vec![1.0; 26];
142                for letter in span.chars() {
143                    let letter = letter.to_ascii_lowercase();
144                    if letter as u32 >= 'a' as u32 {
145                        let ix = (letter as u32) - ('a' as u32);
146                        if ix < 26 {
147                            result[ix as usize] += 1.0;
148                        }
149                    }
150                }
151
152                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
153                for x in &mut result {
154                    *x /= norm;
155                }
156
157                result
158            })
159            .collect())
160    }
161}