1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{CodeContextRetriever, Document},
5 vector_store_settings::VectorStoreSettings,
6 VectorStore,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry};
12use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::SettingsStore;
16use std::{
17 path::Path,
18 sync::{
19 atomic::{self, AtomicUsize},
20 Arc,
21 },
22};
23use unindent::Unindent;
24
25#[ctor::ctor]
26fn init_logger() {
27 if std::env::var("RUST_LOG").is_ok() {
28 env_logger::init();
29 }
30}
31
32#[gpui::test]
33async fn test_vector_store(cx: &mut TestAppContext) {
34 cx.update(|cx| {
35 cx.set_global(SettingsStore::test(cx));
36 settings::register::<VectorStoreSettings>(cx);
37 settings::register::<ProjectSettings>(cx);
38 });
39
40 let fs = FakeFs::new(cx.background());
41 fs.insert_tree(
42 "/the-root",
43 json!({
44 "src": {
45 "file1.rs": "
46 fn aaa() {
47 println!(\"aaaa!\");
48 }
49
50 fn zzzzzzzzz() {
51 println!(\"SLEEPING\");
52 }
53 ".unindent(),
54 "file2.rs": "
55 fn bbb() {
56 println!(\"bbbb!\");
57 }
58 ".unindent(),
59 "file3.toml": "
60 ZZZZZZZ = 5
61 ".unindent(),
62 }
63 }),
64 )
65 .await;
66
67 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
68 let rust_language = rust_lang();
69 let toml_language = toml_lang();
70 languages.add(rust_language);
71 languages.add(toml_language);
72
73 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
74 let db_path = db_dir.path().join("db.sqlite");
75
76 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
77 let store = VectorStore::new(
78 fs.clone(),
79 db_path,
80 embedding_provider.clone(),
81 languages,
82 cx.to_async(),
83 )
84 .await
85 .unwrap();
86
87 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
88 let worktree_id = project.read_with(cx, |project, cx| {
89 project.worktrees(cx).next().unwrap().read(cx).id()
90 });
91 let file_count = store
92 .update(cx, |store, cx| store.index_project(project.clone(), cx))
93 .await
94 .unwrap();
95 assert_eq!(file_count, 3);
96 cx.foreground().run_until_parked();
97 store.update(cx, |store, _cx| {
98 assert_eq!(
99 store.remaining_files_to_index_for_project(&project),
100 Some(0)
101 );
102 });
103
104 let search_results = store
105 .update(cx, |store, cx| {
106 store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
107 })
108 .await
109 .unwrap();
110
111 assert_eq!(search_results[0].byte_range.start, 0);
112 assert_eq!(search_results[0].name, "aaa");
113 assert_eq!(search_results[0].worktree_id, worktree_id);
114
115 fs.save(
116 "/the-root/src/file2.rs".as_ref(),
117 &"
118 fn dddd() { println!(\"ddddd!\"); }
119 struct pqpqpqp {}
120 "
121 .unindent()
122 .into(),
123 Default::default(),
124 )
125 .await
126 .unwrap();
127
128 cx.foreground().run_until_parked();
129
130 let prev_embedding_count = embedding_provider.embedding_count();
131 let file_count = store
132 .update(cx, |store, cx| store.index_project(project.clone(), cx))
133 .await
134 .unwrap();
135 assert_eq!(file_count, 1);
136
137 cx.foreground().run_until_parked();
138 store.update(cx, |store, _cx| {
139 assert_eq!(
140 store.remaining_files_to_index_for_project(&project),
141 Some(0)
142 );
143 });
144
145 assert_eq!(
146 embedding_provider.embedding_count() - prev_embedding_count,
147 2
148 );
149}
150
151#[gpui::test]
152async fn test_code_context_retrieval_rust() {
153 let language = rust_lang();
154 let mut retriever = CodeContextRetriever::new();
155
156 let text = "
157 /// A doc comment
158 /// that spans multiple lines
159 fn a() {
160 b
161 }
162
163 impl C for D {
164 }
165 "
166 .unindent();
167
168 let parsed_files = retriever
169 .parse_file(Path::new("foo.rs"), &text, language)
170 .unwrap();
171
172 assert_eq!(
173 parsed_files,
174 &[
175 Document {
176 name: "a".into(),
177 range: text.find("fn a").unwrap()..(text.find("}").unwrap() + 1),
178 content: "
179 The below code snippet is from file 'foo.rs'
180
181 ```rust
182 /// A doc comment
183 /// that spans multiple lines
184 fn a() {
185 b
186 }
187 ```"
188 .unindent(),
189 embedding: vec![],
190 },
191 Document {
192 name: "C for D".into(),
193 range: text.find("impl C").unwrap()..(text.rfind("}").unwrap() + 1),
194 content: "
195 The below code snippet is from file 'foo.rs'
196
197 ```rust
198 impl C for D {
199 }
200 ```"
201 .unindent(),
202 embedding: vec![],
203 }
204 ]
205 );
206}
207
/// Checks that `CodeContextRetriever` turns a JavaScript/TypeScript fixture into one
/// `Document` per function, class, method and interface. Comments that directly precede
/// an item are prepended to its snippet.
///
/// Each expected `range` covers only the item itself, not its leading comments (see
/// `authorizeBank`, whose range starts at `export async` while its content includes the
/// `/** ... */` block). The ranges are byte offsets into the unindented `text`, so
/// editing `text` shifts the hard-coded values.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /* globals importScripts, backend */
        function _authorize() {}

        /**
         * Sometimes the frontend build is way faster than backend.
         */
        export async function authorizeBank() {
            _authorize(pushModal, upgradingAccountId, {});
        }

        export class SettingsPage {
            /* This is a test setting */
            constructor(page) {
                this.page = page;
            }
        }

        /* This is a test comment */
        class TestClass {}

        /* Schema for editor_events in Clickhouse. */
        export interface ClickhouseEditorEvent {
            installation_id: string
            operation: string
        }
    "
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.js"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "function _authorize".into(),
            // The first `}` in `text` closes `function _authorize() {}`.
            range: text.find("function _authorize").unwrap()..(text.find("}").unwrap() + 1),
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* globals importScripts, backend */
                function _authorize() {}
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "async function authorizeBank".into(),
            range: text.find("export async").unwrap()..223,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /**
                 * Sometimes the frontend build is way faster than backend.
                 */
                export async function authorizeBank() {
                    _authorize(pushModal, upgradingAccountId, {});
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class SettingsPage".into(),
            range: 225..343,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                export class SettingsPage {
                    /* This is a test setting */
                    constructor(page) {
                        this.page = page;
                    }
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "constructor".into(),
            range: 290..341,
            // NOTE(review): the body lines keep the indentation they have inside the class in
            // `text`, presumably because the item text is copied verbatim from the source —
            // confirm against `CodeContextRetriever::parse_file`.
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* This is a test setting */
                constructor(page) {
                        this.page = page;
                    }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class TestClass".into(),
            range: 374..392,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* This is a test comment */
                class TestClass {}
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "interface ClickhouseEditorEvent".into(),
            range: 440..532,
            content: "
                The below code snippet is from file 'foo.js'

                ```javascript
                /* Schema for editor_events in Clickhouse. */
                export interface ClickhouseEditorEvent {
                    installation_id: string
                    operation: string
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
    ];

    // NOTE(review): only the first `test_documents.len()` parsed documents are compared, so
    // any extra trailing documents would go unnoticed — confirm whether that is intended.
    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}
344
/// Checks that `CodeContextRetriever` extracts C++ functions, classes (plain and
/// templated), enums and anonymous-struct declarations as `Document`s. Comments that
/// directly precede an item are prepended to its snippet.
///
/// The expected `range`s are hard-coded byte offsets into the unindented `text` and cover
/// only the item, not its leading comments. Any edit to `text`, including whitespace,
/// shifts them.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /**
         * @brief Main function
         * @returns 0 on exit
         */
        int main() { return 0; }

        /**
        * This is a test comment
        */
        class MyClass {       // The class
            public:             // Access specifier
            int myNum;        // Attribute (int variable)
            string myString;  // Attribute (string variable)
        };

        // This is a test comment
        enum Color { red, green, blue };

        /** This is a preceeding block comment
         * This is the second line
         */
        struct {           // Structure declaration
            int myNum;       // Member (int variable)
            string myString; // Member (string variable)
        } myStructure;

        /**
        * @brief Matrix class.
        */
        template <typename T,
                  typename = typename std::enable_if<
                      std::is_integral<T>::value || std::is_floating_point<T>::value,
                      bool>::type>
        class Matrix2 {
          std::vector<std::vector<T>> _mat;

         public:
          /**
           * @brief Constructor
           * @tparam Integer ensuring integers are being evaluated and not other
           * data types.
           * @param size denoting the size of Matrix as size x size
           */
          template <typename Integer,
                    typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                Integer>::type>
          explicit Matrix(const Integer size) {
            for (size_t i = 0; i < size; ++i) {
              _mat.emplace_back(std::vector<T>(size, 0));
            }
          }
        }"
    .unindent();

    let parsed_files = retriever
        .parse_file(Path::new("foo.cpp"), &text, language)
        .unwrap();

    let test_documents = &[
        Document {
            name: "int main".into(),
            range: 54..78,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                 * @brief Main function
                 * @returns 0 on exit
                 */
                int main() { return 0; }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class MyClass".into(),
            // The class item ends at its closing `}`; the trailing `;` is not included.
            range: 112..295,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                * This is a test comment
                */
                class MyClass {       // The class
                    public:             // Access specifier
                    int myNum;        // Attribute (int variable)
                    string myString;  // Attribute (string variable)
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "enum Color".into(),
            range: 324..355,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                // This is a test comment
                enum Color { red, green, blue }
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "struct myStructure".into(),
            // Matched through the enclosing `declaration`, so the range includes the `;`.
            range: 428..581,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /** This is a preceeding block comment
                 * This is the second line
                 */
                struct {           // Structure declaration
                    int myNum;       // Member (int variable)
                    string myString; // Member (string variable)
                } myStructure;
                ```"
            .unindent(),
            embedding: vec![],
        },
        Document {
            name: "class Matrix2".into(),
            // Spans the whole `template_declaration`, from `template` to the final `}`.
            range: 613..1342,
            content: "
                The below code snippet is from file 'foo.cpp'

                ```cpp
                /**
                * @brief Matrix class.
                */
                template <typename T,
                          typename = typename std::enable_if<
                              std::is_integral<T>::value || std::is_floating_point<T>::value,
                              bool>::type>
                class Matrix2 {
                  std::vector<std::vector<T>> _mat;

                 public:
                  /**
                   * @brief Constructor
                   * @tparam Integer ensuring integers are being evaluated and not other
                   * data types.
                   * @param size denoting the size of Matrix as size x size
                   */
                  template <typename Integer,
                            typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                        Integer>::type>
                  explicit Matrix(const Integer size) {
                    for (size_t i = 0; i < size; ++i) {
                      _mat.emplace_back(std::vector<T>(size, 0));
                    }
                  }
                }
                ```"
            .unindent(),
            embedding: vec![],
        },
    ];

    // NOTE(review): only the first `test_documents.len()` parsed documents are compared; any
    // extra documents (possibly one for the nested `explicit Matrix` constructor) are not
    // checked — confirm whether that is intended.
    for idx in 0..test_documents.len() {
        assert_eq!(test_documents[idx], parsed_files[idx]);
    }
}
519
520#[gpui::test]
521fn test_dot_product(mut rng: StdRng) {
522 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
523 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
524
525 for _ in 0..100 {
526 let size = 1536;
527 let mut a = vec![0.; size];
528 let mut b = vec![0.; size];
529 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
530 *a = rng.gen();
531 *b = rng.gen();
532 }
533
534 assert_eq!(
535 round_to_decimals(dot(&a, &b), 1),
536 round_to_decimals(reference_dot(&a, &b), 1)
537 );
538 }
539
540 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
541 let factor = (10.0 as f32).powi(decimal_places);
542 (n * factor).round() / factor
543 }
544
545 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
546 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
547 }
548}
549
/// Test double for `EmbeddingProvider` that counts how many spans it has embedded.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Running total of spans handed to `embed_batch`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the total number of spans embedded so far.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
560
561#[async_trait]
562impl EmbeddingProvider for FakeEmbeddingProvider {
563 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
564 self.embedding_count
565 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
566 Ok(spans
567 .iter()
568 .map(|span| {
569 let mut result = vec![1.0; 26];
570 for letter in span.chars() {
571 let letter = letter.to_ascii_lowercase();
572 if letter as u32 >= 'a' as u32 {
573 let ix = (letter as u32) - ('a' as u32);
574 if ix < 26 {
575 result[ix as usize] += 1.0;
576 }
577 }
578 }
579
580 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
581 for x in &mut result {
582 *x /= norm;
583 }
584
585 result
586 })
587 .collect())
588 }
589}
590
591fn js_lang() -> Arc<Language> {
592 Arc::new(
593 Language::new(
594 LanguageConfig {
595 name: "Javascript".into(),
596 path_suffixes: vec!["js".into()],
597 ..Default::default()
598 },
599 Some(tree_sitter_typescript::language_tsx()),
600 )
601 .with_embedding_query(
602 &r#"
603
604 (
605 (comment)* @context
606 .
607 (export_statement
608 (function_declaration
609 "async"? @name
610 "function" @name
611 name: (_) @name)) @item
612 )
613
614 (
615 (comment)* @context
616 .
617 (function_declaration
618 "async"? @name
619 "function" @name
620 name: (_) @name) @item
621 )
622
623 (
624 (comment)* @context
625 .
626 (export_statement
627 (class_declaration
628 "class" @name
629 name: (_) @name)) @item
630 )
631
632 (
633 (comment)* @context
634 .
635 (class_declaration
636 "class" @name
637 name: (_) @name) @item
638 )
639
640 (
641 (comment)* @context
642 .
643 (method_definition
644 [
645 "get"
646 "set"
647 "async"
648 "*"
649 "static"
650 ]* @name
651 name: (_) @name) @item
652 )
653
654 (
655 (comment)* @context
656 .
657 (export_statement
658 (interface_declaration
659 "interface" @name
660 name: (_) @name)) @item
661 )
662
663 (
664 (comment)* @context
665 .
666 (interface_declaration
667 "interface" @name
668 name: (_) @name) @item
669 )
670
671 (
672 (comment)* @context
673 .
674 (export_statement
675 (enum_declaration
676 "enum" @name
677 name: (_) @name)) @item
678 )
679
680 (
681 (comment)* @context
682 .
683 (enum_declaration
684 "enum" @name
685 name: (_) @name) @item
686 )
687
688 "#
689 .unindent(),
690 )
691 .unwrap(),
692 )
693}
694
695fn rust_lang() -> Arc<Language> {
696 Arc::new(
697 Language::new(
698 LanguageConfig {
699 name: "Rust".into(),
700 path_suffixes: vec!["rs".into()],
701 ..Default::default()
702 },
703 Some(tree_sitter_rust::language()),
704 )
705 .with_embedding_query(
706 r#"
707 (
708 (line_comment)* @context
709 .
710 (enum_item
711 name: (_) @name) @item
712 )
713
714 (
715 (line_comment)* @context
716 .
717 (struct_item
718 name: (_) @name) @item
719 )
720
721 (
722 (line_comment)* @context
723 .
724 (impl_item
725 trait: (_)? @name
726 "for"? @name
727 type: (_) @name) @item
728 )
729
730 (
731 (line_comment)* @context
732 .
733 (trait_item
734 name: (_) @name) @item
735 )
736
737 (
738 (line_comment)* @context
739 .
740 (function_item
741 name: (_) @name) @item
742 )
743
744 (
745 (line_comment)* @context
746 .
747 (macro_definition
748 name: (_) @name) @item
749 )
750
751 (
752 (line_comment)* @context
753 .
754 (function_signature_item
755 name: (_) @name) @item
756 )
757 "#,
758 )
759 .unwrap(),
760 )
761}
762
763fn toml_lang() -> Arc<Language> {
764 Arc::new(Language::new(
765 LanguageConfig {
766 name: "TOML".into(),
767 path_suffixes: vec!["toml".into()],
768 ..Default::default()
769 },
770 Some(tree_sitter_toml::language()),
771 ))
772}
773
774fn cpp_lang() -> Arc<Language> {
775 Arc::new(
776 Language::new(
777 LanguageConfig {
778 name: "CPP".into(),
779 path_suffixes: vec!["cpp".into()],
780 ..Default::default()
781 },
782 Some(tree_sitter_cpp::language()),
783 )
784 .with_embedding_query(
785 r#"
786 (
787 (comment)* @context
788 .
789 (function_definition
790 (type_qualifier)? @name
791 type: (_)? @name
792 declarator: [
793 (function_declarator
794 declarator: (_) @name)
795 (pointer_declarator
796 "*" @name
797 declarator: (function_declarator
798 declarator: (_) @name))
799 (pointer_declarator
800 "*" @name
801 declarator: (pointer_declarator
802 "*" @name
803 declarator: (function_declarator
804 declarator: (_) @name)))
805 (reference_declarator
806 ["&" "&&"] @name
807 (function_declarator
808 declarator: (_) @name))
809 ]
810 (type_qualifier)? @name) @item
811 )
812
813 (
814 (comment)* @context
815 .
816 (template_declaration
817 (class_specifier
818 "class" @name
819 name: (_) @name)
820 ) @item
821 )
822
823 (
824 (comment)* @context
825 .
826 (class_specifier
827 "class" @name
828 name: (_) @name) @item
829 )
830
831 (
832 (comment)* @context
833 .
834 (enum_specifier
835 "enum" @name
836 name: (_) @name) @item
837 )
838
839 (
840 (comment)* @context
841 .
842 (declaration
843 type: (struct_specifier
844 "struct" @name)
845 declarator: (_) @name) @item
846 )
847
848 "#,
849 )
850 .unwrap(),
851 )
852}