vector_store_tests.rs

  1use crate::{
  2    db::dot, embedding::EmbeddingProvider, vector_store_settings::VectorStoreSettings, VectorStore,
  3};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::Rng;
 10use serde_json::json;
 11use settings::SettingsStore;
 12use std::sync::Arc;
 13use unindent::Unindent;
 14
 15#[gpui::test]
 16async fn test_vector_store(cx: &mut TestAppContext) {
 17    cx.update(|cx| {
 18        cx.set_global(SettingsStore::test(cx));
 19        settings::register::<VectorStoreSettings>(cx);
 20    });
 21
 22    let fs = FakeFs::new(cx.background());
 23    fs.insert_tree(
 24        "/the-root",
 25        json!({
 26            "src": {
 27                "file1.rs": "
 28                    fn aaa() {
 29                        println!(\"aaaa!\");
 30                    }
 31
 32                    fn zzzzzzzzz() {
 33                        println!(\"SLEEPING\");
 34                    }
 35                ".unindent(),
 36                "file2.rs": "
 37                    fn bbb() {
 38                        println!(\"bbbb!\");
 39                    }
 40                ".unindent(),
 41            }
 42        }),
 43    )
 44    .await;
 45
 46    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 47    let rust_language = Arc::new(
 48        Language::new(
 49            LanguageConfig {
 50                name: "Rust".into(),
 51                path_suffixes: vec!["rs".into()],
 52                ..Default::default()
 53            },
 54            Some(tree_sitter_rust::language()),
 55        )
 56        .with_embedding_query(
 57            r#"
 58            (function_item
 59                name: (identifier) @name
 60                body: (block)) @item
 61            "#,
 62        )
 63        .unwrap(),
 64    );
 65    languages.add(rust_language);
 66
 67    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 68    let db_path = db_dir.path().join("db.sqlite");
 69
 70    let store = VectorStore::new(
 71        fs.clone(),
 72        db_path,
 73        Arc::new(FakeEmbeddingProvider),
 74        languages,
 75        cx.to_async(),
 76    )
 77    .await
 78    .unwrap();
 79
 80    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 81    let worktree_id = project.read_with(cx, |project, cx| {
 82        project.worktrees(cx).next().unwrap().read(cx).id()
 83    });
 84    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 85
 86    add_project.await.unwrap();
 87
 88    let search_results = store
 89        .update(cx, |store, cx| {
 90            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 91        })
 92        .await
 93        .unwrap();
 94
 95    assert_eq!(search_results[0].offset, 0);
 96    assert_eq!(search_results[0].name, "aaa");
 97    assert_eq!(search_results[0].worktree_id, worktree_id);
 98}
 99
/// Checks `dot` against hand-computed values and against a naive reference
/// implementation on random 32-element inputs.
#[test]
fn test_dot_product() {
    /// Largest absolute difference tolerated between `dot` and the reference.
    ///
    /// The two implementations may sum in a different order, so their results
    /// can differ by a few ULPs. Comparing values after rounding to a fixed
    /// number of decimals is unreliable here: two nearly identical numbers can
    /// fall on opposite sides of a rounding boundary and compare unequal.
    const TOLERANCE: f32 = 1e-3;

    /// Straightforward multiply-and-sum used as the expected value.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }

    // These products are exactly representable, so exact equality is safe.
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // Acquire the generator once instead of on every iteration.
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "dot returned {} but the reference implementation returned {}",
            actual,
            expected
        );
    }
}
124
125struct FakeEmbeddingProvider;
126
127#[async_trait]
128impl EmbeddingProvider for FakeEmbeddingProvider {
129    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
130        Ok(spans
131            .iter()
132            .map(|span| {
133                let mut result = vec![1.0; 26];
134                for letter in span.chars() {
135                    let letter = letter.to_ascii_lowercase();
136                    if letter as u32 >= 'a' as u32 {
137                        let ix = (letter as u32) - ('a' as u32);
138                        if ix < 26 {
139                            result[ix as usize] += 1.0;
140                        }
141                    }
142                }
143
144                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
145                for x in &mut result {
146                    *x /= norm;
147                }
148
149                result
150            })
151            .collect())
152    }
153}