vector_store_tests.rs

  1use std::sync::Arc;
  2
  3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::Rng;
 10use serde_json::json;
 11use unindent::Unindent;
 12
 13#[gpui::test]
 14async fn test_vector_store(cx: &mut TestAppContext) {
 15    let fs = FakeFs::new(cx.background());
 16    fs.insert_tree(
 17        "/the-root",
 18        json!({
 19            "src": {
 20                "file1.rs": "
 21                    fn aaa() {
 22                        println!(\"aaaa!\");
 23                    }
 24
 25                    fn zzzzzzzzz() {
 26                        println!(\"SLEEPING\");
 27                    }
 28                ".unindent(),
 29                "file2.rs": "
 30                    fn bbb() {
 31                        println!(\"bbbb!\");
 32                    }
 33                ".unindent(),
 34            }
 35        }),
 36    )
 37    .await;
 38
 39    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 40    let rust_language = Arc::new(
 41        Language::new(
 42            LanguageConfig {
 43                name: "Rust".into(),
 44                path_suffixes: vec!["rs".into()],
 45                ..Default::default()
 46            },
 47            Some(tree_sitter_rust::language()),
 48        )
 49        .with_embedding_query(
 50            r#"
 51            (function_item
 52                name: (identifier) @name
 53                body: (block)) @item
 54            "#,
 55        )
 56        .unwrap(),
 57    );
 58    languages.add(rust_language);
 59
 60    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 61    let db_path = db_dir.path().join("db.sqlite");
 62
 63    let store = cx.add_model(|_| {
 64        VectorStore::new(
 65            fs.clone(),
 66            db_path,
 67            Arc::new(FakeEmbeddingProvider),
 68            languages,
 69        )
 70    });
 71
 72    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 73    let worktree_id = project.read_with(cx, |project, cx| {
 74        project.worktrees(cx).next().unwrap().read(cx).id()
 75    });
 76    let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 77
 78    // TODO - remove
 79    cx.foreground()
 80        .advance_clock(std::time::Duration::from_secs(3));
 81
 82    add_project.await.unwrap();
 83
 84    let search_results = store
 85        .update(cx, |store, cx| {
 86            store.search(&project, "aaaa".to_string(), 5, cx)
 87        })
 88        .await
 89        .unwrap();
 90
 91    assert_eq!(search_results[0].offset, 0);
 92    assert_eq!(search_results[0].name, "aaa");
 93    assert_eq!(search_results[0].worktree_id, worktree_id);
 94}
 95
 96#[test]
 97fn test_dot_product() {
 98    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
 99    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
100
101    for _ in 0..100 {
102        let mut rng = rand::thread_rng();
103        let a: [f32; 32] = rng.gen();
104        let b: [f32; 32] = rng.gen();
105        assert_eq!(
106            round_to_decimals(dot(&a, &b), 3),
107            round_to_decimals(reference_dot(&a, &b), 3)
108        );
109    }
110
111    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
112        let factor = (10.0 as f32).powi(decimal_places);
113        (n * factor).round() / factor
114    }
115
116    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
117        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
118    }
119}
120
121struct FakeEmbeddingProvider;
122
123#[async_trait]
124impl EmbeddingProvider for FakeEmbeddingProvider {
125    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
126        Ok(spans
127            .iter()
128            .map(|span| {
129                let mut result = vec![1.0; 26];
130                for letter in span.chars() {
131                    let letter = letter.to_ascii_lowercase();
132                    if letter as u32 >= 'a' as u32 {
133                        let ix = (letter as u32) - ('a' as u32);
134                        if ix < 26 {
135                            result[ix as usize] += 1.0;
136                        }
137                    }
138                }
139
140                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
141                for x in &mut result {
142                    *x /= norm;
143                }
144
145                result
146            })
147            .collect())
148    }
149}