1use std::sync::Arc;
2
3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::Rng;
10use serde_json::json;
11use unindent::Unindent;
12
13#[gpui::test]
14async fn test_vector_store(cx: &mut TestAppContext) {
15 let fs = FakeFs::new(cx.background());
16 fs.insert_tree(
17 "/the-root",
18 json!({
19 "src": {
20 "file1.rs": "
21 fn aaa() {
22 println!(\"aaaa!\");
23 }
24
25 fn zzzzzzzzz() {
26 println!(\"SLEEPING\");
27 }
28 ".unindent(),
29 "file2.rs": "
30 fn bbb() {
31 println!(\"bbbb!\");
32 }
33 ".unindent(),
34 }
35 }),
36 )
37 .await;
38
39 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
40 let rust_language = Arc::new(
41 Language::new(
42 LanguageConfig {
43 name: "Rust".into(),
44 path_suffixes: vec!["rs".into()],
45 ..Default::default()
46 },
47 Some(tree_sitter_rust::language()),
48 )
49 .with_outline_query(
50 r#"
51 (function_item
52 name: (identifier) @name
53 body: (block)) @item
54 "#,
55 )
56 .unwrap(),
57 );
58 languages.add(rust_language);
59
60 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
61 let db_path = db_dir.path().join("db.sqlite");
62
63 let store = cx.add_model(|_| {
64 VectorStore::new(
65 fs.clone(),
66 db_path.to_string_lossy().to_string(),
67 Arc::new(FakeEmbeddingProvider),
68 languages,
69 )
70 });
71
72 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
73 let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
74
75 // TODO - remove
76 cx.foreground()
77 .advance_clock(std::time::Duration::from_secs(3));
78
79 add_project.await.unwrap();
80
81 let search_results = store
82 .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
83 .await
84 .unwrap();
85
86 assert_eq!(search_results[0].offset, 0);
87 assert_eq!(search_results[0].name, "aaa");
88}
89
/// Verifies `dot` against two hand-computed cases and against a naive
/// scalar reference implementation on random 32-element vectors.
#[test]
fn test_dot_product() {
    /// Maximum absolute difference tolerated between `dot` and the reference.
    const TOLERANCE: f32 = 1e-3;

    /// Naive element-wise multiply-and-sum used as ground truth.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Asserts that two floats agree to within `TOLERANCE`.
    ///
    /// Rounding both sides and comparing for equality is brittle: values that
    /// differ only by floating-point noise can fall on opposite sides of a
    /// rounding boundary and fail spuriously. An absolute-difference check
    /// has no such boundary.
    fn assert_close(actual: f32, expected: f32) {
        let difference = (actual - expected).abs();
        assert!(
            difference <= TOLERANCE,
            "dot product mismatch: got {actual}, expected {expected} (difference {difference})"
        );
    }

    // These inputs are exactly representable, so exact equality is safe.
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // One generator handle is enough for all iterations.
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();
        assert_close(dot(&a, &b), reference_dot(&a, &b));
    }
}
114
115struct FakeEmbeddingProvider;
116
117#[async_trait]
118impl EmbeddingProvider for FakeEmbeddingProvider {
119 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
120 Ok(spans
121 .iter()
122 .map(|span| {
123 let mut result = vec![1.0; 26];
124 for letter in span.chars() {
125 let letter = letter.to_ascii_lowercase();
126 if letter as u32 >= 'a' as u32 {
127 let ix = (letter as u32) - ('a' as u32);
128 if ix < 26 {
129 result[ix as usize] += 1.0;
130 }
131 }
132 }
133
134 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
135 for x in &mut result {
136 *x /= norm;
137 }
138
139 result
140 })
141 .collect())
142 }
143}