1use std::sync::Arc;
2
3use crate::{dot, embedding::EmbeddingProvider, VectorStore};
4use anyhow::Result;
5use async_trait::async_trait;
6use gpui::{Task, TestAppContext};
7use language::{Language, LanguageConfig, LanguageRegistry};
8use project::{FakeFs, Project};
9use rand::Rng;
10use serde_json::json;
11use unindent::Unindent;
12
13#[gpui::test]
14async fn test_vector_store(cx: &mut TestAppContext) {
15 let fs = FakeFs::new(cx.background());
16 fs.insert_tree(
17 "/the-root",
18 json!({
19 "src": {
20 "file1.rs": "
21 fn aaa() {
22 println!(\"aaaa!\");
23 }
24
25 fn zzzzzzzzz() {
26 println!(\"SLEEPING\");
27 }
28 ".unindent(),
29 "file2.rs": "
30 fn bbb() {
31 println!(\"bbbb!\");
32 }
33 ".unindent(),
34 }
35 }),
36 )
37 .await;
38
39 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
40 let rust_language = Arc::new(
41 Language::new(
42 LanguageConfig {
43 name: "Rust".into(),
44 path_suffixes: vec!["rs".into()],
45 ..Default::default()
46 },
47 Some(tree_sitter_rust::language()),
48 )
49 .with_embedding_query(
50 r#"
51 (function_item
52 name: (identifier) @name
53 body: (block)) @item
54 "#,
55 )
56 .unwrap(),
57 );
58 languages.add(rust_language);
59
60 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
61 let db_path = db_dir.path().join("db.sqlite");
62
63 let store = VectorStore::new(
64 fs.clone(),
65 db_path,
66 Arc::new(FakeEmbeddingProvider),
67 languages,
68 cx.to_async(),
69 )
70 .await
71 .unwrap();
72
73 let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
74 let worktree_id = project.read_with(cx, |project, cx| {
75 project.worktrees(cx).next().unwrap().read(cx).id()
76 });
77 let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
78
79 add_project.await.unwrap();
80
81 let search_results = store
82 .update(cx, |store, cx| {
83 store.search(project.clone(), "aaaa".to_string(), 5, cx)
84 })
85 .await
86 .unwrap();
87
88 assert_eq!(search_results[0].offset, 0);
89 assert_eq!(search_results[0].name, "aaa");
90 assert_eq!(search_results[0].worktree_id, worktree_id);
91}
92
/// Verifies `dot` against two hand-computed cases and against a naive
/// reference implementation on 100 pairs of random 32-element vectors.
#[test]
fn test_dot_product() {
    /// Largest absolute difference tolerated between `dot` and the reference.
    ///
    /// Comparing values rounded to three decimals is flaky: two results that
    /// differ by a single f32 ulp can still land on opposite sides of a
    /// rounding boundary. An explicit tolerance expresses the same "agree to
    /// about three decimals" intent without that failure mode.
    const TOLERANCE: f32 = 1e-3;

    /// Straightforward dot product: the sum of pairwise products.
    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }

    // These inputs produce exactly representable results, so exact equality
    // is appropriate here.
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    // One generator handle for the whole test rather than one per iteration.
    let mut rng = rand::thread_rng();
    for _ in 0..100 {
        // rand's standard f32 samples lie in [0, 1), so a 32-element dot
        // product stays small and f32 rounding error is far below TOLERANCE.
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        let actual = dot(&a, &b);
        let expected = reference_dot(&a, &b);

        // The generator is unseeded, so print the inputs to make any failure
        // reproducible.
        assert!(
            (actual - expected).abs() <= TOLERANCE,
            "dot = {}, reference = {}, a = {:?}, b = {:?}",
            actual,
            expected,
            a,
            b
        );
    }
}
117
118struct FakeEmbeddingProvider;
119
120#[async_trait]
121impl EmbeddingProvider for FakeEmbeddingProvider {
122 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
123 Ok(spans
124 .iter()
125 .map(|span| {
126 let mut result = vec![1.0; 26];
127 for letter in span.chars() {
128 let letter = letter.to_ascii_lowercase();
129 if letter as u32 >= 'a' as u32 {
130 let ix = (letter as u32) - ('a' as u32);
131 if ix < 26 {
132 result[ix as usize] += 1.0;
133 }
134 }
135 }
136
137 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
138 for x in &mut result {
139 *x /= norm;
140 }
141
142 result
143 })
144 .collect())
145 }
146}