// vector_store_tests.rs

  1use crate::{db::dot, embedding::EmbeddingProvider, VectorStore};
  2use anyhow::Result;
  3use async_trait::async_trait;
  4use gpui::{Task, TestAppContext};
  5use language::{Language, LanguageConfig, LanguageRegistry};
  6use project::{FakeFs, Project};
  7use rand::Rng;
  8use serde_json::json;
  9use std::sync::Arc;
 10use unindent::Unindent;
 11
 12#[gpui::test]
 13async fn test_vector_store(cx: &mut TestAppContext) {
 14    let fs = FakeFs::new(cx.background());
 15    fs.insert_tree(
 16        "/the-root",
 17        json!({
 18            "src": {
 19                "file1.rs": "
 20                    fn aaa() {
 21                        println!(\"aaaa!\");
 22                    }
 23
 24                    fn zzzzzzzzz() {
 25                        println!(\"SLEEPING\");
 26                    }
 27                ".unindent(),
 28                "file2.rs": "
 29                    fn bbb() {
 30                        println!(\"bbbb!\");
 31                    }
 32                ".unindent(),
 33            }
 34        }),
 35    )
 36    .await;
 37
 38    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 39    let rust_language = Arc::new(
 40        Language::new(
 41            LanguageConfig {
 42                name: "Rust".into(),
 43                path_suffixes: vec!["rs".into()],
 44                ..Default::default()
 45            },
 46            Some(tree_sitter_rust::language()),
 47        )
 48        .with_embedding_query(
 49            r#"
 50            (function_item
 51                name: (identifier) @name
 52                body: (block)) @item
 53            "#,
 54        )
 55        .unwrap(),
 56    );
 57    languages.add(rust_language);
 58
 59    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 60    let db_path = db_dir.path().join("db.sqlite");
 61
 62    let store = VectorStore::new(
 63        fs.clone(),
 64        db_path,
 65        Arc::new(FakeEmbeddingProvider),
 66        languages,
 67        cx.to_async(),
 68    )
 69    .await
 70    .unwrap();
 71
 72    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 73    let worktree_id = project.read_with(cx, |project, cx| {
 74        project.worktrees(cx).next().unwrap().read(cx).id()
 75    });
 76    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 77
 78    add_project.await.unwrap();
 79
 80    let search_results = store
 81        .update(cx, |store, cx| {
 82            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 83        })
 84        .await
 85        .unwrap();
 86
 87    assert_eq!(search_results[0].offset, 0);
 88    assert_eq!(search_results[0].name, "aaa");
 89    assert_eq!(search_results[0].worktree_id, worktree_id);
 90}
 91
 92#[test]
 93fn test_dot_product() {
 94    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
 95    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
 96
 97    for _ in 0..100 {
 98        let mut rng = rand::thread_rng();
 99        let a: [f32; 32] = rng.gen();
100        let b: [f32; 32] = rng.gen();
101        assert_eq!(
102            round_to_decimals(dot(&a, &b), 3),
103            round_to_decimals(reference_dot(&a, &b), 3)
104        );
105    }
106
107    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
108        let factor = (10.0 as f32).powi(decimal_places);
109        (n * factor).round() / factor
110    }
111
112    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
113        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
114    }
115}
116
117struct FakeEmbeddingProvider;
118
119#[async_trait]
120impl EmbeddingProvider for FakeEmbeddingProvider {
121    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
122        Ok(spans
123            .iter()
124            .map(|span| {
125                let mut result = vec![1.0; 26];
126                for letter in span.chars() {
127                    let letter = letter.to_ascii_lowercase();
128                    if letter as u32 >= 'a' as u32 {
129                        let ix = (letter as u32) - ('a' as u32);
130                        if ix < 26 {
131                            result[ix as usize] += 1.0;
132                        }
133                    }
134                }
135
136                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
137                for x in &mut result {
138                    *x /= norm;
139                }
140
141                result
142            })
143            .collect())
144    }
145}