vector_store_tests.rs

  1use std::sync::Arc;
  2
  3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::Rng;
 10use serde_json::json;
 11use unindent::Unindent;
 12
 13#[gpui::test]
 14async fn test_vector_store(cx: &mut TestAppContext) {
 15    let fs = FakeFs::new(cx.background());
 16    fs.insert_tree(
 17        "/the-root",
 18        json!({
 19            "src": {
 20                "file1.rs": "
 21                    fn aaa() {
 22                        println!(\"aaaa!\");
 23                    }
 24
 25                    fn zzzzzzzzz() {
 26                        println!(\"SLEEPING\");
 27                    }
 28                ".unindent(),
 29                "file2.rs": "
 30                    fn bbb() {
 31                        println!(\"bbbb!\");
 32                    }
 33                ".unindent(),
 34            }
 35        }),
 36    )
 37    .await;
 38
 39    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 40    let rust_language = Arc::new(
 41        Language::new(
 42            LanguageConfig {
 43                name: "Rust".into(),
 44                path_suffixes: vec!["rs".into()],
 45                ..Default::default()
 46            },
 47            Some(tree_sitter_rust::language()),
 48        )
 49        .with_embedding_query(
 50            r#"
 51            (function_item
 52                name: (identifier) @name
 53                body: (block)) @item
 54            "#,
 55        )
 56        .unwrap(),
 57    );
 58    languages.add(rust_language);
 59
 60    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 61    let db_path = db_dir.path().join("db.sqlite");
 62
 63    let store = VectorStore::new(
 64        fs.clone(),
 65        db_path,
 66        Arc::new(FakeEmbeddingProvider),
 67        languages,
 68        cx.to_async(),
 69    )
 70    .await
 71    .unwrap();
 72
 73    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 74    let worktree_id = project.read_with(cx, |project, cx| {
 75        project.worktrees(cx).next().unwrap().read(cx).id()
 76    });
 77    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 78
 79    add_project.await.unwrap();
 80
 81    let search_results = store
 82        .update(cx, |store, cx| {
 83            store.search(project.clone(), "aaaa".to_string(), 5, cx)
 84        })
 85        .await
 86        .unwrap();
 87
 88    assert_eq!(search_results[0].offset, 0);
 89    assert_eq!(search_results[0].name, "aaa");
 90    assert_eq!(search_results[0].worktree_id, worktree_id);
 91}
 92
 93#[test]
 94fn test_dot_product() {
 95    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
 96    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
 97
 98    for _ in 0..100 {
 99        let mut rng = rand::thread_rng();
100        let a: [f32; 32] = rng.gen();
101        let b: [f32; 32] = rng.gen();
102        assert_eq!(
103            round_to_decimals(dot(&a, &b), 3),
104            round_to_decimals(reference_dot(&a, &b), 3)
105        );
106    }
107
108    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
109        let factor = (10.0 as f32).powi(decimal_places);
110        (n * factor).round() / factor
111    }
112
113    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
114        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
115    }
116}
117
118struct FakeEmbeddingProvider;
119
120#[async_trait]
121impl EmbeddingProvider for FakeEmbeddingProvider {
122    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
123        Ok(spans
124            .iter()
125            .map(|span| {
126                let mut result = vec![1.0; 26];
127                for letter in span.chars() {
128                    let letter = letter.to_ascii_lowercase();
129                    if letter as u32 >= 'a' as u32 {
130                        let ix = (letter as u32) - ('a' as u32);
131                        if ix < 26 {
132                            result[ix as usize] += 1.0;
133                        }
134                    }
135                }
136
137                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
138                for x in &mut result {
139                    *x /= norm;
140                }
141
142                result
143            })
144            .collect())
145    }
146}