1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{CodeContextRetriever, Document},
5 vector_store_settings::VectorStoreSettings,
6 VectorStore,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry};
12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::SettingsStore;
16use std::{
17 path::Path,
18 sync::{
19 atomic::{self, AtomicUsize},
20 Arc,
21 },
22};
23use unindent::Unindent;
24
25#[ctor::ctor]
26fn init_logger() {
27 if std::env::var("RUST_LOG").is_ok() {
28 env_logger::init();
29 }
30}
31
32#[gpui::test]
33async fn test_vector_store(cx: &mut TestAppContext) {
34 cx.update(|cx| {
35 cx.set_global(SettingsStore::test(cx));
36 settings::register::<VectorStoreSettings>(cx);
37 settings::register::<ProjectSettings>(cx);
38 });
39
40 let fs = FakeFs::new(cx.background());
41 fs.insert_tree(
42 "/the-root",
43 json!({
44 "src": {
45 "file1.rs": "
46 fn aaa() {
47 println!(\"aaaa!\");
48 }
49
50 fn zzzzzzzzz() {
51 println!(\"SLEEPING\");
52 }
53 ".unindent(),
54 "file2.rs": "
55 fn bbb() {
56 println!(\"bbbb!\");
57 }
58 ".unindent(),
59 }
60 }),
61 )
62 .await;
63
64 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
65 let rust_language = rust_lang();
66 languages.add(rust_language);
67
68 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
69 let db_path = db_dir.path().join("db.sqlite");
70
71 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
72 let store = VectorStore::new(
73 fs.clone(),
74 db_path,
75 embedding_provider.clone(),
76 languages,
77 cx.to_async(),
78 )
79 .await
80 .unwrap();
81
82 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
83 let worktree_id = project.read_with(cx, |project, cx| {
84 project.worktrees(cx).next().unwrap().read(cx).id()
85 });
86 let file_count = store
87 .update(cx, |store, cx| store.index_project(project.clone(), cx))
88 .await
89 .unwrap();
90 assert_eq!(file_count, 2);
91 cx.foreground().run_until_parked();
92 store.update(cx, |store, _cx| {
93 assert_eq!(
94 store.remaining_files_to_index_for_project(&project),
95 Some(0)
96 );
97 });
98
99 let search_results = store
100 .update(cx, |store, cx| {
101 store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
102 })
103 .await
104 .unwrap();
105
106 assert_eq!(search_results[0].byte_range.start, 0);
107 assert_eq!(search_results[0].name, "aaa");
108 assert_eq!(search_results[0].worktree_id, worktree_id);
109
110 fs.save(
111 "/the-root/src/file2.rs".as_ref(),
112 &"
113 fn dddd() { println!(\"ddddd!\"); }
114 struct pqpqpqp {}
115 "
116 .unindent()
117 .into(),
118 Default::default(),
119 )
120 .await
121 .unwrap();
122
123 cx.foreground().run_until_parked();
124
125 let prev_embedding_count = embedding_provider.embedding_count();
126 let file_count = store
127 .update(cx, |store, cx| store.index_project(project.clone(), cx))
128 .await
129 .unwrap();
130 assert_eq!(file_count, 1);
131
132 cx.foreground().run_until_parked();
133 store.update(cx, |store, _cx| {
134 assert_eq!(
135 store.remaining_files_to_index_for_project(&project),
136 Some(0)
137 );
138 });
139
140 assert_eq!(
141 embedding_provider.embedding_count() - prev_embedding_count,
142 2
143 );
144}
145
/// Parses a small Rust snippet with `rust_lang`'s embedding query and checks
/// that the retriever yields one `Document` per item: the `range` covers only
/// the item itself, while `content` also includes the preceding doc comment
/// and is wrapped in the "snippet is from file" template.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /// A doc comment
        /// that spans multiple lines
        fn a() {
            b
        }

        impl C for D {
        }
    "
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.rs"), &text, language)
        .unwrap();

    assert_eq!(
        parsed_files,
        &[
            Document {
                name: "a".into(),
                // Starts at `fn a` (the doc comment is context, not part of
                // the range) and ends after the first closing brace.
                range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
                content: "
                    The below code snippet is from file 'foo.rs'

                    ```rust
                    /// A doc comment
                    /// that spans multiple lines
                    fn a() {
                        b
                    }
                    ```"
                .unindent(),
                embedding: vec![],
            },
            Document {
                // The impl's name is built from its trait, `for`, and type.
                name: "C for D".into(),
                range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
                content: "
                    The below code snippet is from file 'foo.rs'

                    ```rust
                    impl C for D {
                    }
                    ```"
                .unindent(),
                embedding: vec![],
            }
        ]
    );
}
202
/// Parses a JavaScript/TypeScript snippet with `js_lang`'s embedding query and
/// checks the documents produced for plain and exported functions, classes, a
/// class method, and an exported interface.
///
/// Unlike the Rust variant, `text` is NOT unindented: several expected
/// `range`s are hard-coded absolute byte offsets into the literal below, so
/// editing its whitespace will shift them.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let mut retriever = CodeContextRetriever::new();

    // Kept flush-left on purpose; see the note on the hard-coded offsets above.
    let text = "
/* globals importScripts, backend */
function _authorize() {}

/**
 * Sometimes the frontend build is way faster than backend.
 */
export async function authorizeBank() {
    _authorize(pushModal, upgradingAccountId, {});
}

export class SettingsPage {
    /* This is a test setting */
    constructor(page) {
        this.page = page;
    }
}

/* This is a test comment */
class TestClass {}

/* Schema for editor_events in Clickhouse. */
export interface ClickhouseEditorEvent {
    installation_id: string
    operation: string
}
";

    let parsed_files = retriever
        .parse_file(Path::new("foo.js"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "function _authorize".into(),
            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            /* globals importScripts, backend */
            function _authorize() {}
            ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "async function authorizeBank".into(),
            // 224 is the byte just past the closing `}` of `authorizeBank`.
            range: text.find("export async").unwrap()..224,
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            /**
             * Sometimes the frontend build is way faster than backend.
             */
            export async function authorizeBank() {
                _authorize(pushModal, upgradingAccountId, {});
            }
            ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class SettingsPage".into(),
            range: 226..344,
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            export class SettingsPage {
                /* This is a test setting */
                constructor(page) {
                    this.page = page;
                }
            }
            ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "constructor".into(),
            range: 291..342,
            // The method body keeps the indentation it had inside the class
            // in `text` (the item text appears to be sliced verbatim), hence
            // the deeper indent relative to `constructor` after unindenting.
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            /* This is a test setting */
            constructor(page) {
                    this.page = page;
                }
            ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class TestClass".into(),
            range: 375..393,
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            /* This is a test comment */
            class TestClass {}
            ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "interface ClickhouseEditorEvent".into(),
            range: 441..533,
            content: "
            The below code snippet is from file 'foo.js'

            ```javascript
            /* Schema for editor_events in Clickhouse. */
            export interface ClickhouseEditorEvent {
                installation_id: string
                operation: string
            }
            ```"
            .unindent(),
            embedding: vec![],
        },
    ];

    // Compare document-by-document in order.
    // NOTE(review): only the first `test_documents.len()` parsed documents are
    // checked, so any extra documents in `parsed_files` go unnoticed. Consider
    // also asserting equal lengths once it is confirmed the retriever emits
    // exactly these six.
    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}
338
339#[gpui::test]
340fn test_dot_product(mut rng: StdRng) {
341 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
342 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
343
344 for _ in 0..100 {
345 let size = 1536;
346 let mut a = vec![0.; size];
347 let mut b = vec![0.; size];
348 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
349 *a = rng.gen();
350 *b = rng.gen();
351 }
352
353 assert_eq!(
354 round_to_decimals(dot(&a, &b), 1),
355 round_to_decimals(reference_dot(&a, &b), 1)
356 );
357 }
358
359 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
360 let factor = (10.0 as f32).powi(decimal_places);
361 (n * factor).round() / factor
362 }
363
364 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
365 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
366 }
367}
368
/// Deterministic stand-in for a real embedding backend, used by the tests in
/// this file. It keeps a running count of how many spans it has embedded.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Running total of spans handed to `embed_batch`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the number of spans embedded since construction.
    fn embedding_count(&self) -> usize {
        let ordering = atomic::Ordering::SeqCst;
        self.embedding_count.load(ordering)
    }
}
379
380#[async_trait]
381impl EmbeddingProvider for FakeEmbeddingProvider {
382 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
383 self.embedding_count
384 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
385 Ok(spans
386 .iter()
387 .map(|span| {
388 let mut result = vec![1.0; 26];
389 for letter in span.chars() {
390 let letter = letter.to_ascii_lowercase();
391 if letter as u32 >= 'a' as u32 {
392 let ix = (letter as u32) - ('a' as u32);
393 if ix < 26 {
394 result[ix as usize] += 1.0;
395 }
396 }
397 }
398
399 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
400 for x in &mut result {
401 *x /= norm;
402 }
403
404 result
405 })
406 .collect())
407 }
408}
409
410fn js_lang() -> Arc<Language> {
411 Arc::new(
412 Language::new(
413 LanguageConfig {
414 name: "Javascript".into(),
415 path_suffixes: vec!["js".into()],
416 ..Default::default()
417 },
418 Some(tree_sitter_typescript::language_tsx()),
419 )
420 .with_embedding_query(
421 &r#"
422
423 (
424 (comment)* @context
425 .
426 (export_statement
427 (function_declaration
428 "async"? @name
429 "function" @name
430 name: (_) @name)) @item
431 )
432
433 (
434 (comment)* @context
435 .
436 (function_declaration
437 "async"? @name
438 "function" @name
439 name: (_) @name) @item
440 )
441
442 (
443 (comment)* @context
444 .
445 (export_statement
446 (class_declaration
447 "class" @name
448 name: (_) @name)) @item
449 )
450
451 (
452 (comment)* @context
453 .
454 (class_declaration
455 "class" @name
456 name: (_) @name) @item
457 )
458
459 (
460 (comment)* @context
461 .
462 (method_definition
463 [
464 "get"
465 "set"
466 "async"
467 "*"
468 "static"
469 ]* @name
470 name: (_) @name) @item
471 )
472
473 (
474 (comment)* @context
475 .
476 (export_statement
477 (interface_declaration
478 "interface" @name
479 name: (_) @name)) @item
480 )
481
482 (
483 (comment)* @context
484 .
485 (interface_declaration
486 "interface" @name
487 name: (_) @name) @item
488 )
489
490 (
491 (comment)* @context
492 .
493 (export_statement
494 (enum_declaration
495 "enum" @name
496 name: (_) @name)) @item
497 )
498
499 (
500 (comment)* @context
501 .
502 (enum_declaration
503 "enum" @name
504 name: (_) @name) @item
505 )
506
507 "#
508 .unindent(),
509 )
510 .unwrap(),
511 )
512}
513
514fn rust_lang() -> Arc<Language> {
515 Arc::new(
516 Language::new(
517 LanguageConfig {
518 name: "Rust".into(),
519 path_suffixes: vec!["rs".into()],
520 ..Default::default()
521 },
522 Some(tree_sitter_rust::language()),
523 )
524 .with_embedding_query(
525 r#"
526 (
527 (line_comment)* @context
528 .
529 (enum_item
530 name: (_) @name) @item
531 )
532
533 (
534 (line_comment)* @context
535 .
536 (struct_item
537 name: (_) @name) @item
538 )
539
540 (
541 (line_comment)* @context
542 .
543 (impl_item
544 trait: (_)? @name
545 "for"? @name
546 type: (_) @name) @item
547 )
548
549 (
550 (line_comment)* @context
551 .
552 (trait_item
553 name: (_) @name) @item
554 )
555
556 (
557 (line_comment)* @context
558 .
559 (function_item
560 name: (_) @name) @item
561 )
562
563 (
564 (line_comment)* @context
565 .
566 (macro_definition
567 name: (_) @name) @item
568 )
569
570 (
571 (line_comment)* @context
572 .
573 (function_signature_item
574 name: (_) @name) @item
575 )
576 "#,
577 )
578 .unwrap(),
579 )
580}