1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8use gpui::TestAppContext;
9use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
10use parking_lot::Mutex;
11use pretty_assertions::assert_eq;
12use project::{FakeFs, Fs, Project};
13use rand::{rngs::StdRng, Rng};
14use serde_json::json;
15use settings::{Settings, SettingsStore};
16use std::{path::Path, sync::Arc, time::SystemTime};
17use unindent::Unindent;
18use util::{paths::PathMatcher, RandomCharIter};
19
20#[ctor::ctor]
21fn init_logger() {
22 if std::env::var("RUST_LOG").is_ok() {
23 env_logger::init();
24 }
25}
26
27#[gpui::test]
28async fn test_semantic_index(cx: &mut TestAppContext) {
29 init_test(cx);
30
31 let fs = FakeFs::new(cx.background_executor.clone());
32 fs.insert_tree(
33 "/the-root",
34 json!({
35 "src": {
36 "file1.rs": "
37 fn aaa() {
38 println!(\"aaaaaaaaaaaa!\");
39 }
40
41 fn zzzzz() {
42 println!(\"SLEEPING\");
43 }
44 ".unindent(),
45 "file2.rs": "
46 fn bbb() {
47 println!(\"bbbbbbbbbbbbb!\");
48 }
49 struct pqpqpqp {}
50 ".unindent(),
51 "file3.toml": "
52 ZZZZZZZZZZZZZZZZZZ = 5
53 ".unindent(),
54 }
55 }),
56 )
57 .await;
58
59 let languages = Arc::new(LanguageRegistry::test(cx.executor().clone()));
60 let rust_language = rust_lang();
61 let toml_language = toml_lang();
62 languages.add(rust_language);
63 languages.add(toml_language);
64
65 let db_dir = tempfile::Builder::new()
66 .prefix("vector-store")
67 .tempdir()
68 .unwrap();
69 let db_path = db_dir.path().join("db.sqlite");
70
71 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
72 let semantic_index = SemanticIndex::new(
73 fs.clone(),
74 db_path,
75 embedding_provider.clone(),
76 languages,
77 cx.to_async(),
78 )
79 .await
80 .unwrap();
81
82 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
83
84 let search_results = semantic_index.update(cx, |store, cx| {
85 store.search_project(
86 project.clone(),
87 "aaaaaabbbbzz".to_string(),
88 5,
89 vec![],
90 vec![],
91 cx,
92 )
93 });
94 let pending_file_count =
95 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
96 cx.background_executor.run_until_parked();
97 assert_eq!(*pending_file_count.borrow(), 3);
98 cx.background_executor
99 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
100 assert_eq!(*pending_file_count.borrow(), 0);
101
102 let search_results = search_results.await.unwrap();
103 assert_search_results(
104 &search_results,
105 &[
106 (Path::new("src/file1.rs").into(), 0),
107 (Path::new("src/file2.rs").into(), 0),
108 (Path::new("src/file3.toml").into(), 0),
109 (Path::new("src/file1.rs").into(), 45),
110 (Path::new("src/file2.rs").into(), 45),
111 ],
112 cx,
113 );
114
115 // Test Include Files Functionality
116 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
117 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
118 let rust_only_search_results = semantic_index
119 .update(cx, |store, cx| {
120 store.search_project(
121 project.clone(),
122 "aaaaaabbbbzz".to_string(),
123 5,
124 include_files,
125 vec![],
126 cx,
127 )
128 })
129 .await
130 .unwrap();
131
132 assert_search_results(
133 &rust_only_search_results,
134 &[
135 (Path::new("src/file1.rs").into(), 0),
136 (Path::new("src/file2.rs").into(), 0),
137 (Path::new("src/file1.rs").into(), 45),
138 (Path::new("src/file2.rs").into(), 45),
139 ],
140 cx,
141 );
142
143 let no_rust_search_results = semantic_index
144 .update(cx, |store, cx| {
145 store.search_project(
146 project.clone(),
147 "aaaaaabbbbzz".to_string(),
148 5,
149 vec![],
150 exclude_files,
151 cx,
152 )
153 })
154 .await
155 .unwrap();
156
157 assert_search_results(
158 &no_rust_search_results,
159 &[(Path::new("src/file3.toml").into(), 0)],
160 cx,
161 );
162
163 fs.save(
164 "/the-root/src/file2.rs".as_ref(),
165 &"
166 fn dddd() { println!(\"ddddd!\"); }
167 struct pqpqpqp {}
168 "
169 .unindent()
170 .into(),
171 Default::default(),
172 )
173 .await
174 .unwrap();
175
176 cx.background_executor
177 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
178
179 let prev_embedding_count = embedding_provider.embedding_count();
180 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
181 cx.background_executor.run_until_parked();
182 assert_eq!(*pending_file_count.borrow(), 1);
183 cx.background_executor
184 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
185 assert_eq!(*pending_file_count.borrow(), 0);
186 index.await.unwrap();
187
188 assert_eq!(
189 embedding_provider.embedding_count() - prev_embedding_count,
190 1
191 );
192}
193
194#[gpui::test(iterations = 10)]
195async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
196 let (outstanding_job_count, _) = postage::watch::channel_with(0);
197 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
198
199 let files = (1..=3)
200 .map(|file_ix| FileToEmbed {
201 worktree_id: 5,
202 path: Path::new(&format!("path-{file_ix}")).into(),
203 mtime: SystemTime::now(),
204 spans: (0..rng.gen_range(4..22))
205 .map(|document_ix| {
206 let content_len = rng.gen_range(10..100);
207 let content = RandomCharIter::new(&mut rng)
208 .with_simple_text()
209 .take(content_len)
210 .collect::<String>();
211 let digest = SpanDigest::from(content.as_str());
212 Span {
213 range: 0..10,
214 embedding: None,
215 name: format!("document {document_ix}"),
216 content,
217 digest,
218 token_count: rng.gen_range(10..30),
219 }
220 })
221 .collect(),
222 job_handle: JobHandle::new(&outstanding_job_count),
223 })
224 .collect::<Vec<_>>();
225
226 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
227
228 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
229 for file in &files {
230 queue.push(file.clone());
231 }
232 queue.flush();
233
234 cx.background_executor.run_until_parked();
235 let finished_files = queue.finished_files();
236 let mut embedded_files: Vec<_> = files
237 .iter()
238 .map(|_| finished_files.try_recv().expect("no finished file"))
239 .collect();
240
241 let expected_files: Vec<_> = files
242 .iter()
243 .map(|file| {
244 let mut file = file.clone();
245 for doc in &mut file.spans {
246 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
247 }
248 file
249 })
250 .collect();
251
252 embedded_files.sort_by_key(|f| f.path.clone());
253
254 assert_eq!(embedded_files, expected_files);
255}
256
257#[track_caller]
258fn assert_search_results(
259 actual: &[SearchResult],
260 expected: &[(Arc<Path>, usize)],
261 cx: &TestAppContext,
262) {
263 let actual = actual
264 .iter()
265 .map(|search_result| {
266 search_result.buffer.read_with(cx, |buffer, _cx| {
267 (
268 buffer.file().unwrap().path().clone(),
269 search_result.range.start.to_offset(buffer),
270 )
271 })
272 })
273 .collect::<Vec<_>>();
274 assert_eq!(actual, expected);
275}
276
277#[gpui::test]
278async fn test_code_context_retrieval_rust() {
279 let language = rust_lang();
280 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
281 let mut retriever = CodeContextRetriever::new(embedding_provider);
282
283 let text = "
284 /// A doc comment
285 /// that spans multiple lines
286 #[gpui::test]
287 fn a() {
288 b
289 }
290
291 impl C for D {
292 }
293
294 impl E {
295 // This is also a preceding comment
296 pub fn function_1() -> Option<()> {
297 unimplemented!();
298 }
299
300 // This is a preceding comment
301 fn function_2() -> Result<()> {
302 unimplemented!();
303 }
304 }
305
306 #[derive(Clone)]
307 struct D {
308 name: String
309 }
310 "
311 .unindent();
312
313 let documents = retriever.parse_file(&text, language).unwrap();
314
315 assert_documents_eq(
316 &documents,
317 &[
318 (
319 "
320 /// A doc comment
321 /// that spans multiple lines
322 #[gpui::test]
323 fn a() {
324 b
325 }"
326 .unindent(),
327 text.find("fn a").unwrap(),
328 ),
329 (
330 "
331 impl C for D {
332 }"
333 .unindent(),
334 text.find("impl C").unwrap(),
335 ),
336 (
337 "
338 impl E {
339 // This is also a preceding comment
340 pub fn function_1() -> Option<()> { /* ... */ }
341
342 // This is a preceding comment
343 fn function_2() -> Result<()> { /* ... */ }
344 }"
345 .unindent(),
346 text.find("impl E").unwrap(),
347 ),
348 (
349 "
350 // This is also a preceding comment
351 pub fn function_1() -> Option<()> {
352 unimplemented!();
353 }"
354 .unindent(),
355 text.find("pub fn function_1").unwrap(),
356 ),
357 (
358 "
359 // This is a preceding comment
360 fn function_2() -> Result<()> {
361 unimplemented!();
362 }"
363 .unindent(),
364 text.find("fn function_2").unwrap(),
365 ),
366 (
367 "
368 #[derive(Clone)]
369 struct D {
370 name: String
371 }"
372 .unindent(),
373 text.find("struct D").unwrap(),
374 ),
375 ],
376 );
377}
378
379#[gpui::test]
380async fn test_code_context_retrieval_json() {
381 let language = json_lang();
382 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
383 let mut retriever = CodeContextRetriever::new(embedding_provider);
384
385 let text = r#"
386 {
387 "array": [1, 2, 3, 4],
388 "string": "abcdefg",
389 "nested_object": {
390 "array_2": [5, 6, 7, 8],
391 "string_2": "hijklmnop",
392 "boolean": true,
393 "none": null
394 }
395 }
396 "#
397 .unindent();
398
399 let documents = retriever.parse_file(&text, language.clone()).unwrap();
400
401 assert_documents_eq(
402 &documents,
403 &[(
404 r#"
405 {
406 "array": [],
407 "string": "",
408 "nested_object": {
409 "array_2": [],
410 "string_2": "",
411 "boolean": true,
412 "none": null
413 }
414 }"#
415 .unindent(),
416 text.find('{').unwrap(),
417 )],
418 );
419
420 let text = r#"
421 [
422 {
423 "name": "somebody",
424 "age": 42
425 },
426 {
427 "name": "somebody else",
428 "age": 43
429 }
430 ]
431 "#
432 .unindent();
433
434 let documents = retriever.parse_file(&text, language.clone()).unwrap();
435
436 assert_documents_eq(
437 &documents,
438 &[(
439 r#"
440 [{
441 "name": "",
442 "age": 42
443 }]"#
444 .unindent(),
445 text.find('[').unwrap(),
446 )],
447 );
448}
449
450fn assert_documents_eq(
451 documents: &[Span],
452 expected_contents_and_start_offsets: &[(String, usize)],
453) {
454 assert_eq!(
455 documents
456 .iter()
457 .map(|document| (document.content.clone(), document.range.start))
458 .collect::<Vec<_>>(),
459 expected_contents_and_start_offsets
460 );
461}
462
463#[gpui::test]
464async fn test_code_context_retrieval_javascript() {
465 let language = js_lang();
466 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
467 let mut retriever = CodeContextRetriever::new(embedding_provider);
468
469 let text = "
470 /* globals importScripts, backend */
471 function _authorize() {}
472
473 /**
474 * Sometimes the frontend build is way faster than backend.
475 */
476 export async function authorizeBank() {
477 _authorize(pushModal, upgradingAccountId, {});
478 }
479
480 export class SettingsPage {
481 /* This is a test setting */
482 constructor(page) {
483 this.page = page;
484 }
485 }
486
487 /* This is a test comment */
488 class TestClass {}
489
490 /* Schema for editor_events in Clickhouse. */
491 export interface ClickhouseEditorEvent {
492 installation_id: string
493 operation: string
494 }
495 "
496 .unindent();
497
498 let documents = retriever.parse_file(&text, language.clone()).unwrap();
499
500 assert_documents_eq(
501 &documents,
502 &[
503 (
504 "
505 /* globals importScripts, backend */
506 function _authorize() {}"
507 .unindent(),
508 37,
509 ),
510 (
511 "
512 /**
513 * Sometimes the frontend build is way faster than backend.
514 */
515 export async function authorizeBank() {
516 _authorize(pushModal, upgradingAccountId, {});
517 }"
518 .unindent(),
519 131,
520 ),
521 (
522 "
523 export class SettingsPage {
524 /* This is a test setting */
525 constructor(page) {
526 this.page = page;
527 }
528 }"
529 .unindent(),
530 225,
531 ),
532 (
533 "
534 /* This is a test setting */
535 constructor(page) {
536 this.page = page;
537 }"
538 .unindent(),
539 290,
540 ),
541 (
542 "
543 /* This is a test comment */
544 class TestClass {}"
545 .unindent(),
546 374,
547 ),
548 (
549 "
550 /* Schema for editor_events in Clickhouse. */
551 export interface ClickhouseEditorEvent {
552 installation_id: string
553 operation: string
554 }"
555 .unindent(),
556 440,
557 ),
558 ],
559 )
560}
561
562#[gpui::test]
563async fn test_code_context_retrieval_lua() {
564 let language = lua_lang();
565 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
566 let mut retriever = CodeContextRetriever::new(embedding_provider);
567
568 let text = r#"
569 -- Creates a new class
570 -- @param baseclass The Baseclass of this class, or nil.
571 -- @return A new class reference.
572 function classes.class(baseclass)
573 -- Create the class definition and metatable.
574 local classdef = {}
575 -- Find the super class, either Object or user-defined.
576 baseclass = baseclass or classes.Object
577 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
578 setmetatable(classdef, { __index = baseclass })
579 -- All class instances have a reference to the class object.
580 classdef.class = classdef
581 --- Recursively allocates the inheritance tree of the instance.
582 -- @param mastertable The 'root' of the inheritance tree.
583 -- @return Returns the instance with the allocated inheritance tree.
584 function classdef.alloc(mastertable)
585 -- All class instances have a reference to a superclass object.
586 local instance = { super = baseclass.alloc(mastertable) }
587 -- Any functions this instance does not know of will 'look up' to the superclass definition.
588 setmetatable(instance, { __index = classdef, __newindex = mastertable })
589 return instance
590 end
591 end
592 "#.unindent();
593
594 let documents = retriever.parse_file(&text, language.clone()).unwrap();
595
596 assert_documents_eq(
597 &documents,
598 &[
599 (r#"
600 -- Creates a new class
601 -- @param baseclass The Baseclass of this class, or nil.
602 -- @return A new class reference.
603 function classes.class(baseclass)
604 -- Create the class definition and metatable.
605 local classdef = {}
606 -- Find the super class, either Object or user-defined.
607 baseclass = baseclass or classes.Object
608 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
609 setmetatable(classdef, { __index = baseclass })
610 -- All class instances have a reference to the class object.
611 classdef.class = classdef
612 --- Recursively allocates the inheritance tree of the instance.
613 -- @param mastertable The 'root' of the inheritance tree.
614 -- @return Returns the instance with the allocated inheritance tree.
615 function classdef.alloc(mastertable)
616 --[ ... ]--
617 --[ ... ]--
618 end
619 end"#.unindent(),
620 114),
621 (r#"
622 --- Recursively allocates the inheritance tree of the instance.
623 -- @param mastertable The 'root' of the inheritance tree.
624 -- @return Returns the instance with the allocated inheritance tree.
625 function classdef.alloc(mastertable)
626 -- All class instances have a reference to a superclass object.
627 local instance = { super = baseclass.alloc(mastertable) }
628 -- Any functions this instance does not know of will 'look up' to the superclass definition.
629 setmetatable(instance, { __index = classdef, __newindex = mastertable })
630 return instance
631 end"#.unindent(), 810),
632 ]
633 );
634}
635
636#[gpui::test]
637async fn test_code_context_retrieval_elixir() {
638 let language = elixir_lang();
639 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
640 let mut retriever = CodeContextRetriever::new(embedding_provider);
641
642 let text = r#"
643 defmodule File.Stream do
644 @moduledoc """
645 Defines a `File.Stream` struct returned by `File.stream!/3`.
646
647 The following fields are public:
648
649 * `path` - the file path
650 * `modes` - the file modes
651 * `raw` - a boolean indicating if bin functions should be used
652 * `line_or_bytes` - if reading should read lines or a given number of bytes
653 * `node` - the node the file belongs to
654
655 """
656
657 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
658
659 @type t :: %__MODULE__{}
660
661 @doc false
662 def __build__(path, modes, line_or_bytes) do
663 raw = :lists.keyfind(:encoding, 1, modes) == false
664
665 modes =
666 case raw do
667 true ->
668 case :lists.keyfind(:read_ahead, 1, modes) do
669 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
670 {:read_ahead, _} -> [:raw | modes]
671 false -> [:raw, :read_ahead | modes]
672 end
673
674 false ->
675 modes
676 end
677
678 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
679
680 end"#
681 .unindent();
682
683 let documents = retriever.parse_file(&text, language.clone()).unwrap();
684
685 assert_documents_eq(
686 &documents,
687 &[(
688 r#"
689 defmodule File.Stream do
690 @moduledoc """
691 Defines a `File.Stream` struct returned by `File.stream!/3`.
692
693 The following fields are public:
694
695 * `path` - the file path
696 * `modes` - the file modes
697 * `raw` - a boolean indicating if bin functions should be used
698 * `line_or_bytes` - if reading should read lines or a given number of bytes
699 * `node` - the node the file belongs to
700
701 """
702
703 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
704
705 @type t :: %__MODULE__{}
706
707 @doc false
708 def __build__(path, modes, line_or_bytes) do
709 raw = :lists.keyfind(:encoding, 1, modes) == false
710
711 modes =
712 case raw do
713 true ->
714 case :lists.keyfind(:read_ahead, 1, modes) do
715 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
716 {:read_ahead, _} -> [:raw | modes]
717 false -> [:raw, :read_ahead | modes]
718 end
719
720 false ->
721 modes
722 end
723
724 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
725
726 end"#
727 .unindent(),
728 0,
729 ),(r#"
730 @doc false
731 def __build__(path, modes, line_or_bytes) do
732 raw = :lists.keyfind(:encoding, 1, modes) == false
733
734 modes =
735 case raw do
736 true ->
737 case :lists.keyfind(:read_ahead, 1, modes) do
738 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
739 {:read_ahead, _} -> [:raw | modes]
740 false -> [:raw, :read_ahead | modes]
741 end
742
743 false ->
744 modes
745 end
746
747 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
748
749 end"#.unindent(), 574)],
750 );
751}
752
753#[gpui::test]
754async fn test_code_context_retrieval_cpp() {
755 let language = cpp_lang();
756 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
757 let mut retriever = CodeContextRetriever::new(embedding_provider);
758
759 let text = "
760 /**
761 * @brief Main function
762 * @returns 0 on exit
763 */
764 int main() { return 0; }
765
766 /**
767 * This is a test comment
768 */
769 class MyClass { // The class
770 public: // Access specifier
771 int myNum; // Attribute (int variable)
772 string myString; // Attribute (string variable)
773 };
774
775 // This is a test comment
776 enum Color { red, green, blue };
777
778 /** This is a preceding block comment
779 * This is the second line
780 */
781 struct { // Structure declaration
782 int myNum; // Member (int variable)
783 string myString; // Member (string variable)
784 } myStructure;
785
786 /**
787 * @brief Matrix class.
788 */
789 template <typename T,
790 typename = typename std::enable_if<
791 std::is_integral<T>::value || std::is_floating_point<T>::value,
792 bool>::type>
793 class Matrix2 {
794 std::vector<std::vector<T>> _mat;
795
796 public:
797 /**
798 * @brief Constructor
799 * @tparam Integer ensuring integers are being evaluated and not other
800 * data types.
801 * @param size denoting the size of Matrix as size x size
802 */
803 template <typename Integer,
804 typename = typename std::enable_if<std::is_integral<Integer>::value,
805 Integer>::type>
806 explicit Matrix(const Integer size) {
807 for (size_t i = 0; i < size; ++i) {
808 _mat.emplace_back(std::vector<T>(size, 0));
809 }
810 }
811 }"
812 .unindent();
813
814 let documents = retriever.parse_file(&text, language.clone()).unwrap();
815
816 assert_documents_eq(
817 &documents,
818 &[
819 (
820 "
821 /**
822 * @brief Main function
823 * @returns 0 on exit
824 */
825 int main() { return 0; }"
826 .unindent(),
827 54,
828 ),
829 (
830 "
831 /**
832 * This is a test comment
833 */
834 class MyClass { // The class
835 public: // Access specifier
836 int myNum; // Attribute (int variable)
837 string myString; // Attribute (string variable)
838 }"
839 .unindent(),
840 112,
841 ),
842 (
843 "
844 // This is a test comment
845 enum Color { red, green, blue }"
846 .unindent(),
847 322,
848 ),
849 (
850 "
851 /** This is a preceding block comment
852 * This is the second line
853 */
854 struct { // Structure declaration
855 int myNum; // Member (int variable)
856 string myString; // Member (string variable)
857 } myStructure;"
858 .unindent(),
859 425,
860 ),
861 (
862 "
863 /**
864 * @brief Matrix class.
865 */
866 template <typename T,
867 typename = typename std::enable_if<
868 std::is_integral<T>::value || std::is_floating_point<T>::value,
869 bool>::type>
870 class Matrix2 {
871 std::vector<std::vector<T>> _mat;
872
873 public:
874 /**
875 * @brief Constructor
876 * @tparam Integer ensuring integers are being evaluated and not other
877 * data types.
878 * @param size denoting the size of Matrix as size x size
879 */
880 template <typename Integer,
881 typename = typename std::enable_if<std::is_integral<Integer>::value,
882 Integer>::type>
883 explicit Matrix(const Integer size) {
884 for (size_t i = 0; i < size; ++i) {
885 _mat.emplace_back(std::vector<T>(size, 0));
886 }
887 }
888 }"
889 .unindent(),
890 612,
891 ),
892 (
893 "
894 explicit Matrix(const Integer size) {
895 for (size_t i = 0; i < size; ++i) {
896 _mat.emplace_back(std::vector<T>(size, 0));
897 }
898 }"
899 .unindent(),
900 1226,
901 ),
902 ],
903 );
904}
905
906#[gpui::test]
907async fn test_code_context_retrieval_ruby() {
908 let language = ruby_lang();
909 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
910 let mut retriever = CodeContextRetriever::new(embedding_provider);
911
912 let text = r#"
913 # This concern is inspired by "sudo mode" on GitHub. It
914 # is a way to re-authenticate a user before allowing them
915 # to see or perform an action.
916 #
917 # Add `before_action :require_challenge!` to actions you
918 # want to protect.
919 #
920 # The user will be shown a page to enter the challenge (which
921 # is either the password, or just the username when no
922 # password exists). Upon passing, there is a grace period
923 # during which no challenge will be asked from the user.
924 #
925 # Accessing challenge-protected resources during the grace
926 # period will refresh the grace period.
927 module ChallengableConcern
928 extend ActiveSupport::Concern
929
930 CHALLENGE_TIMEOUT = 1.hour.freeze
931
932 def require_challenge!
933 return if skip_challenge?
934
935 if challenge_passed_recently?
936 session[:challenge_passed_at] = Time.now.utc
937 return
938 end
939
940 @challenge = Form::Challenge.new(return_to: request.url)
941
942 if params.key?(:form_challenge)
943 if challenge_passed?
944 session[:challenge_passed_at] = Time.now.utc
945 else
946 flash.now[:alert] = I18n.t('challenge.invalid_password')
947 render_challenge
948 end
949 else
950 render_challenge
951 end
952 end
953
954 def challenge_passed?
955 current_user.valid_password?(challenge_params[:current_password])
956 end
957 end
958
959 class Animal
960 include Comparable
961
962 attr_reader :legs
963
964 def initialize(name, legs)
965 @name, @legs = name, legs
966 end
967
968 def <=>(other)
969 legs <=> other.legs
970 end
971 end
972
973 # Singleton method for car object
974 def car.wheels
975 puts "There are four wheels"
976 end"#
977 .unindent();
978
979 let documents = retriever.parse_file(&text, language.clone()).unwrap();
980
981 assert_documents_eq(
982 &documents,
983 &[
984 (
985 r#"
986 # This concern is inspired by "sudo mode" on GitHub. It
987 # is a way to re-authenticate a user before allowing them
988 # to see or perform an action.
989 #
990 # Add `before_action :require_challenge!` to actions you
991 # want to protect.
992 #
993 # The user will be shown a page to enter the challenge (which
994 # is either the password, or just the username when no
995 # password exists). Upon passing, there is a grace period
996 # during which no challenge will be asked from the user.
997 #
998 # Accessing challenge-protected resources during the grace
999 # period will refresh the grace period.
1000 module ChallengableConcern
1001 extend ActiveSupport::Concern
1002
1003 CHALLENGE_TIMEOUT = 1.hour.freeze
1004
1005 def require_challenge!
1006 # ...
1007 end
1008
1009 def challenge_passed?
1010 # ...
1011 end
1012 end"#
1013 .unindent(),
1014 558,
1015 ),
1016 (
1017 r#"
1018 def require_challenge!
1019 return if skip_challenge?
1020
1021 if challenge_passed_recently?
1022 session[:challenge_passed_at] = Time.now.utc
1023 return
1024 end
1025
1026 @challenge = Form::Challenge.new(return_to: request.url)
1027
1028 if params.key?(:form_challenge)
1029 if challenge_passed?
1030 session[:challenge_passed_at] = Time.now.utc
1031 else
1032 flash.now[:alert] = I18n.t('challenge.invalid_password')
1033 render_challenge
1034 end
1035 else
1036 render_challenge
1037 end
1038 end"#
1039 .unindent(),
1040 663,
1041 ),
1042 (
1043 r#"
1044 def challenge_passed?
1045 current_user.valid_password?(challenge_params[:current_password])
1046 end"#
1047 .unindent(),
1048 1254,
1049 ),
1050 (
1051 r#"
1052 class Animal
1053 include Comparable
1054
1055 attr_reader :legs
1056
1057 def initialize(name, legs)
1058 # ...
1059 end
1060
1061 def <=>(other)
1062 # ...
1063 end
1064 end"#
1065 .unindent(),
1066 1363,
1067 ),
1068 (
1069 r#"
1070 def initialize(name, legs)
1071 @name, @legs = name, legs
1072 end"#
1073 .unindent(),
1074 1427,
1075 ),
1076 (
1077 r#"
1078 def <=>(other)
1079 legs <=> other.legs
1080 end"#
1081 .unindent(),
1082 1501,
1083 ),
1084 (
1085 r#"
1086 # Singleton method for car object
1087 def car.wheels
1088 puts "There are four wheels"
1089 end"#
1090 .unindent(),
1091 1591,
1092 ),
1093 ],
1094 );
1095}
1096
1097#[gpui::test]
1098async fn test_code_context_retrieval_php() {
1099 let language = php_lang();
1100 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1101 let mut retriever = CodeContextRetriever::new(embedding_provider);
1102
1103 let text = r#"
1104 <?php
1105
1106 namespace LevelUp\Experience\Concerns;
1107
1108 /*
1109 This is a multiple-lines comment block
1110 that spans over multiple
1111 lines
1112 */
1113 function functionName() {
1114 echo "Hello world!";
1115 }
1116
1117 trait HasAchievements
1118 {
1119 /**
1120 * @throws \Exception
1121 */
1122 public function grantAchievement(Achievement $achievement, $progress = null): void
1123 {
1124 if ($progress > 100) {
1125 throw new Exception(message: 'Progress cannot be greater than 100');
1126 }
1127
1128 if ($this->achievements()->find($achievement->id)) {
1129 throw new Exception(message: 'User already has this Achievement');
1130 }
1131
1132 $this->achievements()->attach($achievement, [
1133 'progress' => $progress ?? null,
1134 ]);
1135
1136 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1137 }
1138
1139 public function achievements(): BelongsToMany
1140 {
1141 return $this->belongsToMany(related: Achievement::class)
1142 ->withPivot(columns: 'progress')
1143 ->where('is_secret', false)
1144 ->using(AchievementUser::class);
1145 }
1146 }
1147
1148 interface Multiplier
1149 {
1150 public function qualifies(array $data): bool;
1151
1152 public function setMultiplier(): int;
1153 }
1154
1155 enum AuditType: string
1156 {
1157 case Add = 'add';
1158 case Remove = 'remove';
1159 case Reset = 'reset';
1160 case LevelUp = 'level_up';
1161 }
1162
1163 ?>"#
1164 .unindent();
1165
1166 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1167
1168 assert_documents_eq(
1169 &documents,
1170 &[
1171 (
1172 r#"
1173 /*
1174 This is a multiple-lines comment block
1175 that spans over multiple
1176 lines
1177 */
1178 function functionName() {
1179 echo "Hello world!";
1180 }"#
1181 .unindent(),
1182 123,
1183 ),
1184 (
1185 r#"
1186 trait HasAchievements
1187 {
1188 /**
1189 * @throws \Exception
1190 */
1191 public function grantAchievement(Achievement $achievement, $progress = null): void
1192 {/* ... */}
1193
1194 public function achievements(): BelongsToMany
1195 {/* ... */}
1196 }"#
1197 .unindent(),
1198 177,
1199 ),
1200 (r#"
1201 /**
1202 * @throws \Exception
1203 */
1204 public function grantAchievement(Achievement $achievement, $progress = null): void
1205 {
1206 if ($progress > 100) {
1207 throw new Exception(message: 'Progress cannot be greater than 100');
1208 }
1209
1210 if ($this->achievements()->find($achievement->id)) {
1211 throw new Exception(message: 'User already has this Achievement');
1212 }
1213
1214 $this->achievements()->attach($achievement, [
1215 'progress' => $progress ?? null,
1216 ]);
1217
1218 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1219 }"#.unindent(), 245),
1220 (r#"
1221 public function achievements(): BelongsToMany
1222 {
1223 return $this->belongsToMany(related: Achievement::class)
1224 ->withPivot(columns: 'progress')
1225 ->where('is_secret', false)
1226 ->using(AchievementUser::class);
1227 }"#.unindent(), 902),
1228 (r#"
1229 interface Multiplier
1230 {
1231 public function qualifies(array $data): bool;
1232
1233 public function setMultiplier(): int;
1234 }"#.unindent(),
1235 1146),
1236 (r#"
1237 enum AuditType: string
1238 {
1239 case Add = 'add';
1240 case Remove = 'remove';
1241 case Reset = 'reset';
1242 case LevelUp = 'level_up';
1243 }"#.unindent(), 1265)
1244 ],
1245 );
1246}
1247
1248fn js_lang() -> Arc<Language> {
1249 Arc::new(
1250 Language::new(
1251 LanguageConfig {
1252 name: "Javascript".into(),
1253 matcher: LanguageMatcher {
1254 path_suffixes: vec!["js".into()],
1255 ..Default::default()
1256 },
1257 ..Default::default()
1258 },
1259 Some(tree_sitter_typescript::language_tsx()),
1260 )
1261 .with_embedding_query(
1262 &r#"
1263
1264 (
1265 (comment)* @context
1266 .
1267 [
1268 (export_statement
1269 (function_declaration
1270 "async"? @name
1271 "function" @name
1272 name: (_) @name))
1273 (function_declaration
1274 "async"? @name
1275 "function" @name
1276 name: (_) @name)
1277 ] @item
1278 )
1279
1280 (
1281 (comment)* @context
1282 .
1283 [
1284 (export_statement
1285 (class_declaration
1286 "class" @name
1287 name: (_) @name))
1288 (class_declaration
1289 "class" @name
1290 name: (_) @name)
1291 ] @item
1292 )
1293
1294 (
1295 (comment)* @context
1296 .
1297 [
1298 (export_statement
1299 (interface_declaration
1300 "interface" @name
1301 name: (_) @name))
1302 (interface_declaration
1303 "interface" @name
1304 name: (_) @name)
1305 ] @item
1306 )
1307
1308 (
1309 (comment)* @context
1310 .
1311 [
1312 (export_statement
1313 (enum_declaration
1314 "enum" @name
1315 name: (_) @name))
1316 (enum_declaration
1317 "enum" @name
1318 name: (_) @name)
1319 ] @item
1320 )
1321
1322 (
1323 (comment)* @context
1324 .
1325 (method_definition
1326 [
1327 "get"
1328 "set"
1329 "async"
1330 "*"
1331 "static"
1332 ]* @name
1333 name: (_) @name) @item
1334 )
1335
1336 "#
1337 .unindent(),
1338 )
1339 .unwrap(),
1340 )
1341}
1342
1343fn rust_lang() -> Arc<Language> {
1344 Arc::new(
1345 Language::new(
1346 LanguageConfig {
1347 name: "Rust".into(),
1348 matcher: LanguageMatcher {
1349 path_suffixes: vec!["rs".into()],
1350 ..Default::default()
1351 },
1352 collapsed_placeholder: " /* ... */ ".to_string(),
1353 ..Default::default()
1354 },
1355 Some(tree_sitter_rust::language()),
1356 )
1357 .with_embedding_query(
1358 r#"
1359 (
1360 [(line_comment) (attribute_item)]* @context
1361 .
1362 [
1363 (struct_item
1364 name: (_) @name)
1365
1366 (enum_item
1367 name: (_) @name)
1368
1369 (impl_item
1370 trait: (_)? @name
1371 "for"? @name
1372 type: (_) @name)
1373
1374 (trait_item
1375 name: (_) @name)
1376
1377 (function_item
1378 name: (_) @name
1379 body: (block
1380 "{" @keep
1381 "}" @keep) @collapse)
1382
1383 (macro_definition
1384 name: (_) @name)
1385 ] @item
1386 )
1387
1388 (attribute_item) @collapse
1389 (use_declaration) @collapse
1390 "#,
1391 )
1392 .unwrap(),
1393 )
1394}
1395
1396fn json_lang() -> Arc<Language> {
1397 Arc::new(
1398 Language::new(
1399 LanguageConfig {
1400 name: "JSON".into(),
1401 matcher: LanguageMatcher {
1402 path_suffixes: vec!["json".into()],
1403 ..Default::default()
1404 },
1405 ..Default::default()
1406 },
1407 Some(tree_sitter_json::language()),
1408 )
1409 .with_embedding_query(
1410 r#"
1411 (document) @item
1412
1413 (array
1414 "[" @keep
1415 .
1416 (object)? @keep
1417 "]" @keep) @collapse
1418
1419 (pair value: (string
1420 "\"" @keep
1421 "\"" @keep) @collapse)
1422 "#,
1423 )
1424 .unwrap(),
1425 )
1426}
1427
1428fn toml_lang() -> Arc<Language> {
1429 Arc::new(Language::new(
1430 LanguageConfig {
1431 name: "TOML".into(),
1432 matcher: LanguageMatcher {
1433 path_suffixes: vec!["toml".into()],
1434 ..Default::default()
1435 },
1436 ..Default::default()
1437 },
1438 Some(tree_sitter_toml::language()),
1439 ))
1440}
1441
1442fn cpp_lang() -> Arc<Language> {
1443 Arc::new(
1444 Language::new(
1445 LanguageConfig {
1446 name: "CPP".into(),
1447 matcher: LanguageMatcher {
1448 path_suffixes: vec!["cpp".into()],
1449 ..Default::default()
1450 },
1451 ..Default::default()
1452 },
1453 Some(tree_sitter_cpp::language()),
1454 )
1455 .with_embedding_query(
1456 r#"
1457 (
1458 (comment)* @context
1459 .
1460 (function_definition
1461 (type_qualifier)? @name
1462 type: (_)? @name
1463 declarator: [
1464 (function_declarator
1465 declarator: (_) @name)
1466 (pointer_declarator
1467 "*" @name
1468 declarator: (function_declarator
1469 declarator: (_) @name))
1470 (pointer_declarator
1471 "*" @name
1472 declarator: (pointer_declarator
1473 "*" @name
1474 declarator: (function_declarator
1475 declarator: (_) @name)))
1476 (reference_declarator
1477 ["&" "&&"] @name
1478 (function_declarator
1479 declarator: (_) @name))
1480 ]
1481 (type_qualifier)? @name) @item
1482 )
1483
1484 (
1485 (comment)* @context
1486 .
1487 (template_declaration
1488 (class_specifier
1489 "class" @name
1490 name: (_) @name)
1491 ) @item
1492 )
1493
1494 (
1495 (comment)* @context
1496 .
1497 (class_specifier
1498 "class" @name
1499 name: (_) @name) @item
1500 )
1501
1502 (
1503 (comment)* @context
1504 .
1505 (enum_specifier
1506 "enum" @name
1507 name: (_) @name) @item
1508 )
1509
1510 (
1511 (comment)* @context
1512 .
1513 (declaration
1514 type: (struct_specifier
1515 "struct" @name)
1516 declarator: (_) @name) @item
1517 )
1518
1519 "#,
1520 )
1521 .unwrap(),
1522 )
1523}
1524
1525fn lua_lang() -> Arc<Language> {
1526 Arc::new(
1527 Language::new(
1528 LanguageConfig {
1529 name: "Lua".into(),
1530 matcher: LanguageMatcher {
1531 path_suffixes: vec!["lua".into()],
1532 ..Default::default()
1533 },
1534 collapsed_placeholder: "--[ ... ]--".to_string(),
1535 ..Default::default()
1536 },
1537 Some(tree_sitter_lua::language()),
1538 )
1539 .with_embedding_query(
1540 r#"
1541 (
1542 (comment)* @context
1543 .
1544 (function_declaration
1545 "function" @name
1546 name: (_) @name
1547 (comment)* @collapse
1548 body: (block) @collapse
1549 ) @item
1550 )
1551 "#,
1552 )
1553 .unwrap(),
1554 )
1555}
1556
1557fn php_lang() -> Arc<Language> {
1558 Arc::new(
1559 Language::new(
1560 LanguageConfig {
1561 name: "PHP".into(),
1562 matcher: LanguageMatcher {
1563 path_suffixes: vec!["php".into()],
1564 ..Default::default()
1565 },
1566 collapsed_placeholder: "/* ... */".into(),
1567 ..Default::default()
1568 },
1569 Some(tree_sitter_php::language_php()),
1570 )
1571 .with_embedding_query(
1572 r#"
1573 (
1574 (comment)* @context
1575 .
1576 [
1577 (function_definition
1578 "function" @name
1579 name: (_) @name
1580 body: (_
1581 "{" @keep
1582 "}" @keep) @collapse
1583 )
1584
1585 (trait_declaration
1586 "trait" @name
1587 name: (_) @name)
1588
1589 (method_declaration
1590 "function" @name
1591 name: (_) @name
1592 body: (_
1593 "{" @keep
1594 "}" @keep) @collapse
1595 )
1596
1597 (interface_declaration
1598 "interface" @name
1599 name: (_) @name
1600 )
1601
1602 (enum_declaration
1603 "enum" @name
1604 name: (_) @name
1605 )
1606
1607 ] @item
1608 )
1609 "#,
1610 )
1611 .unwrap(),
1612 )
1613}
1614
1615fn ruby_lang() -> Arc<Language> {
1616 Arc::new(
1617 Language::new(
1618 LanguageConfig {
1619 name: "Ruby".into(),
1620 matcher: LanguageMatcher {
1621 path_suffixes: vec!["rb".into()],
1622 ..Default::default()
1623 },
1624 collapsed_placeholder: "# ...".to_string(),
1625 ..Default::default()
1626 },
1627 Some(tree_sitter_ruby::language()),
1628 )
1629 .with_embedding_query(
1630 r#"
1631 (
1632 (comment)* @context
1633 .
1634 [
1635 (module
1636 "module" @name
1637 name: (_) @name)
1638 (method
1639 "def" @name
1640 name: (_) @name
1641 body: (body_statement) @collapse)
1642 (class
1643 "class" @name
1644 name: (_) @name)
1645 (singleton_method
1646 "def" @name
1647 object: (_) @name
1648 "." @name
1649 name: (_) @name
1650 body: (body_statement) @collapse)
1651 ] @item
1652 )
1653 "#,
1654 )
1655 .unwrap(),
1656 )
1657}
1658
1659fn elixir_lang() -> Arc<Language> {
1660 Arc::new(
1661 Language::new(
1662 LanguageConfig {
1663 name: "Elixir".into(),
1664 matcher: LanguageMatcher {
1665 path_suffixes: vec!["rs".into()],
1666 ..Default::default()
1667 },
1668 ..Default::default()
1669 },
1670 Some(tree_sitter_elixir::language()),
1671 )
1672 .with_embedding_query(
1673 r#"
1674 (
1675 (unary_operator
1676 operator: "@"
1677 operand: (call
1678 target: (identifier) @unary
1679 (#match? @unary "^(doc)$"))
1680 ) @context
1681 .
1682 (call
1683 target: (identifier) @name
1684 (arguments
1685 [
1686 (identifier) @name
1687 (call
1688 target: (identifier) @name)
1689 (binary_operator
1690 left: (call
1691 target: (identifier) @name)
1692 operator: "when")
1693 ])
1694 (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1695 )
1696
1697 (call
1698 target: (identifier) @name
1699 (arguments (alias) @name)
1700 (#any-match? @name "^(defmodule|defprotocol)$")) @item
1701 "#,
1702 )
1703 .unwrap(),
1704 )
1705}
1706
1707#[gpui::test]
1708fn test_subtract_ranges() {
1709 assert_eq!(
1710 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1711 vec![1..4, 10..21]
1712 );
1713
1714 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1715}
1716
1717fn init_test(cx: &mut TestAppContext) {
1718 cx.update(|cx| {
1719 let settings_store = SettingsStore::test(cx);
1720 cx.set_global(settings_store);
1721 SemanticIndexSettings::register(cx);
1722 language::init(cx);
1723 Project::init_settings(cx);
1724 });
1725}