1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{CodeContextRetriever, Document},
5 vector_store_settings::VectorStoreSettings,
6 VectorStore,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry};
12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::SettingsStore;
16use std::{
17 path::Path,
18 sync::{
19 atomic::{self, AtomicUsize},
20 Arc,
21 },
22};
23use unindent::Unindent;
24
25#[ctor::ctor]
26fn init_logger() {
27 if std::env::var("RUST_LOG").is_ok() {
28 env_logger::init();
29 }
30}
31
32#[gpui::test]
33async fn test_vector_store(cx: &mut TestAppContext) {
34 cx.update(|cx| {
35 cx.set_global(SettingsStore::test(cx));
36 settings::register::<VectorStoreSettings>(cx);
37 settings::register::<ProjectSettings>(cx);
38 });
39
40 let fs = FakeFs::new(cx.background());
41 fs.insert_tree(
42 "/the-root",
43 json!({
44 "src": {
45 "file1.rs": "
46 fn aaa() {
47 println!(\"aaaa!\");
48 }
49
50 fn zzzzzzzzz() {
51 println!(\"SLEEPING\");
52 }
53 ".unindent(),
54 "file2.rs": "
55 fn bbb() {
56 println!(\"bbbb!\");
57 }
58 ".unindent(),
59 }
60 }),
61 )
62 .await;
63
64 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
65 let rust_language = rust_lang();
66 languages.add(rust_language);
67
68 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
69 let db_path = db_dir.path().join("db.sqlite");
70
71 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
72 let store = VectorStore::new(
73 fs.clone(),
74 db_path,
75 embedding_provider.clone(),
76 languages,
77 cx.to_async(),
78 )
79 .await
80 .unwrap();
81
82 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
83 let worktree_id = project.read_with(cx, |project, cx| {
84 project.worktrees(cx).next().unwrap().read(cx).id()
85 });
86 let file_count = store
87 .update(cx, |store, cx| store.index_project(project.clone(), cx))
88 .await
89 .unwrap();
90 assert_eq!(file_count, 2);
91 cx.foreground().run_until_parked();
92 store.update(cx, |store, _cx| {
93 assert_eq!(
94 store.remaining_files_to_index_for_project(&project),
95 Some(0)
96 );
97 });
98
99 let search_results = store
100 .update(cx, |store, cx| {
101 store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
102 })
103 .await
104 .unwrap();
105
106 assert_eq!(search_results[0].byte_range.start, 0);
107 assert_eq!(search_results[0].name, "aaa");
108 assert_eq!(search_results[0].worktree_id, worktree_id);
109
110 fs.save(
111 "/the-root/src/file2.rs".as_ref(),
112 &"
113 fn dddd() { println!(\"ddddd!\"); }
114 struct pqpqpqp {}
115 "
116 .unindent()
117 .into(),
118 Default::default(),
119 )
120 .await
121 .unwrap();
122
123 cx.foreground().run_until_parked();
124
125 let prev_embedding_count = embedding_provider.embedding_count();
126 let file_count = store
127 .update(cx, |store, cx| store.index_project(project.clone(), cx))
128 .await
129 .unwrap();
130 assert_eq!(file_count, 1);
131
132 cx.foreground().run_until_parked();
133 store.update(cx, |store, _cx| {
134 assert_eq!(
135 store.remaining_files_to_index_for_project(&project),
136 Some(0)
137 );
138 });
139
140 assert_eq!(
141 embedding_provider.embedding_count() - prev_embedding_count,
142 2
143 );
144}
145
146#[gpui::test]
147async fn test_code_context_retrieval() {
148 let language = rust_lang();
149 let mut retriever = CodeContextRetriever::new();
150
151 let text = "
152 /// A doc comment
153 /// that spans multiple lines
154 fn a() {
155 b
156 }
157
158 impl C for D {
159 }
160 "
161 .unindent();
162
163 let parsed_files = retriever
164 .parse_file(Path::new("foo.rs"), &text, language)
165 .unwrap();
166
167 assert_eq!(
168 parsed_files,
169 &[
170 Document {
171 name: "a".into(),
172 range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
173 content: "
174 The below code snippet is from file 'foo.rs'
175
176 ```rust
177 /// A doc comment
178 /// that spans multiple lines
179 fn a() {
180 b
181 }
182 ```"
183 .unindent(),
184 embedding: vec![],
185 },
186 Document {
187 name: "C for D".into(),
188 range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
189 content: "
190 The below code snippet is from file 'foo.rs'
191
192 ```rust
193 impl C for D {
194 }
195 ```"
196 .unindent(),
197 embedding: vec![],
198 }
199 ]
200 );
201}
202
203#[gpui::test]
204fn test_dot_product(mut rng: StdRng) {
205 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
206 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
207
208 for _ in 0..100 {
209 let size = 1536;
210 let mut a = vec![0.; size];
211 let mut b = vec![0.; size];
212 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
213 *a = rng.gen();
214 *b = rng.gen();
215 }
216
217 assert_eq!(
218 round_to_decimals(dot(&a, &b), 1),
219 round_to_decimals(reference_dot(&a, &b), 1)
220 );
221 }
222
223 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
224 let factor = (10.0 as f32).powi(decimal_places);
225 (n * factor).round() / factor
226 }
227
228 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
229 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
230 }
231}
232
/// Test double for an embedding provider that keeps a count of how many
/// spans it has been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans received so far; starts at zero via `Default`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the span counter.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
243
244#[async_trait]
245impl EmbeddingProvider for FakeEmbeddingProvider {
246 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
247 self.embedding_count
248 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
249 Ok(spans
250 .iter()
251 .map(|span| {
252 let mut result = vec![1.0; 26];
253 for letter in span.chars() {
254 let letter = letter.to_ascii_lowercase();
255 if letter as u32 >= 'a' as u32 {
256 let ix = (letter as u32) - ('a' as u32);
257 if ix < 26 {
258 result[ix as usize] += 1.0;
259 }
260 }
261 }
262
263 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
264 for x in &mut result {
265 *x /= norm;
266 }
267
268 result
269 })
270 .collect())
271 }
272}
273
274fn rust_lang() -> Arc<Language> {
275 Arc::new(
276 Language::new(
277 LanguageConfig {
278 name: "Rust".into(),
279 path_suffixes: vec!["rs".into()],
280 ..Default::default()
281 },
282 Some(tree_sitter_rust::language()),
283 )
284 .with_embedding_query(
285 r#"
286 (
287 (line_comment)* @context
288 .
289 (enum_item
290 name: (_) @name) @item
291 )
292
293 (
294 (line_comment)* @context
295 .
296 (struct_item
297 name: (_) @name) @item
298 )
299
300 (
301 (line_comment)* @context
302 .
303 (impl_item
304 trait: (_)? @name
305 "for"? @name
306 type: (_) @name) @item
307 )
308
309 (
310 (line_comment)* @context
311 .
312 (trait_item
313 name: (_) @name) @item
314 )
315
316 (
317 (line_comment)* @context
318 .
319 (function_item
320 name: (_) @name) @item
321 )
322
323 (
324 (line_comment)* @context
325 .
326 (macro_definition
327 name: (_) @name) @item
328 )
329
330 (
331 (line_comment)* @context
332 .
333 (function_signature_item
334 name: (_) @name) @item
335 )
336 "#,
337 )
338 .unwrap(),
339 )
340}