vector_store_tests.rs

  1use std::sync::Arc;
  2
  3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
  4use anyhow::Result;
  5use async_trait::async_trait;
  6use gpui::{Task, TestAppContext};
  7use language::{Language, LanguageConfig, LanguageRegistry};
  8use project::{FakeFs, Project};
  9use rand::Rng;
 10use serde_json::json;
 11use unindent::Unindent;
 12
 13#[gpui::test]
 14async fn test_vector_store(cx: &mut TestAppContext) {
 15    let fs = FakeFs::new(cx.background());
 16    fs.insert_tree(
 17        "/the-root",
 18        json!({
 19            "src": {
 20                "file1.rs": "
 21                    fn aaa() {
 22                        println!(\"aaaa!\");
 23                    }
 24
 25                    fn zzzzzzzzz() {
 26                        println!(\"SLEEPING\");
 27                    }
 28                ".unindent(),
 29                "file2.rs": "
 30                    fn bbb() {
 31                        println!(\"bbbb!\");
 32                    }
 33                ".unindent(),
 34            }
 35        }),
 36    )
 37    .await;
 38
 39    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 40    let rust_language = Arc::new(
 41        Language::new(
 42            LanguageConfig {
 43                name: "Rust".into(),
 44                path_suffixes: vec!["rs".into()],
 45                ..Default::default()
 46            },
 47            Some(tree_sitter_rust::language()),
 48        )
 49        .with_outline_query(
 50            r#"
 51            (function_item
 52                name: (identifier) @name
 53                body: (block)) @item
 54            "#,
 55        )
 56        .unwrap(),
 57    );
 58    languages.add(rust_language);
 59
 60    let store = cx.add_model(|_| {
 61        VectorStore::new(
 62            fs.clone(),
 63            "foo".to_string(),
 64            Arc::new(FakeEmbeddingProvider),
 65            languages,
 66        )
 67    });
 68
 69    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
 70    store
 71        .update(cx, |store, cx| store.add_project(project, cx))
 72        .await
 73        .unwrap();
 74
 75    let search_results = store
 76        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
 77        .await
 78        .unwrap();
 79
 80    assert_eq!(search_results[0].offset, 0);
 81    assert_eq!(search_results[1].name, "aaa");
 82}
 83
 84#[test]
 85fn test_dot_product() {
 86    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
 87    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
 88
 89    for _ in 0..100 {
 90        let mut rng = rand::thread_rng();
 91        let a: [f32; 32] = rng.gen();
 92        let b: [f32; 32] = rng.gen();
 93        assert_eq!(
 94            round_to_decimals(dot(&a, &b), 3),
 95            round_to_decimals(reference_dot(&a, &b), 3)
 96        );
 97    }
 98
 99    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
100        let factor = (10.0 as f32).powi(decimal_places);
101        (n * factor).round() / factor
102    }
103
104    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
105        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
106    }
107}
108
109struct FakeEmbeddingProvider;
110
111#[async_trait]
112impl EmbeddingProvider for FakeEmbeddingProvider {
113    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
114        Ok(spans
115            .iter()
116            .map(|span| {
117                let mut result = vec![0.0; 26];
118                for letter in span.chars() {
119                    if letter as u32 > 'a' as u32 {
120                        let ix = (letter as u32) - ('a' as u32);
121                        if ix < 26 {
122                            result[ix as usize] += 1.0;
123                        }
124                    }
125                }
126
127                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
128                for x in &mut result {
129                    *x /= norm;
130                }
131
132                result
133            })
134            .collect())
135    }
136}