1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8
9use gpui::{Task, TestAppContext};
10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
11use parking_lot::Mutex;
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::{Settings, SettingsStore};
17use std::{path::Path, sync::Arc, time::SystemTime};
18use unindent::Unindent;
19use util::{paths::PathMatcher, RandomCharIter};
20
21#[ctor::ctor]
22fn init_logger() {
23 if std::env::var("RUST_LOG").is_ok() {
24 env_logger::init();
25 }
26}
27
28#[gpui::test]
29async fn test_semantic_index(cx: &mut TestAppContext) {
30 init_test(cx);
31
32 let fs = FakeFs::new(cx.background_executor.clone());
33 fs.insert_tree(
34 "/the-root",
35 json!({
36 "src": {
37 "file1.rs": "
38 fn aaa() {
39 println!(\"aaaaaaaaaaaa!\");
40 }
41
42 fn zzzzz() {
43 println!(\"SLEEPING\");
44 }
45 ".unindent(),
46 "file2.rs": "
47 fn bbb() {
48 println!(\"bbbbbbbbbbbbb!\");
49 }
50 struct pqpqpqp {}
51 ".unindent(),
52 "file3.toml": "
53 ZZZZZZZZZZZZZZZZZZ = 5
54 ".unindent(),
55 }
56 }),
57 )
58 .await;
59
60 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
61 let rust_language = rust_lang();
62 let toml_language = toml_lang();
63 languages.add(rust_language);
64 languages.add(toml_language);
65
66 let db_dir = tempfile::Builder::new()
67 .prefix("vector-store")
68 .tempdir()
69 .unwrap();
70 let db_path = db_dir.path().join("db.sqlite");
71
72 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
73 let semantic_index = SemanticIndex::new(
74 fs.clone(),
75 db_path,
76 embedding_provider.clone(),
77 languages,
78 cx.to_async(),
79 )
80 .await
81 .unwrap();
82
83 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
84
85 let search_results = semantic_index.update(cx, |store, cx| {
86 store.search_project(
87 project.clone(),
88 "aaaaaabbbbzz".to_string(),
89 5,
90 vec![],
91 vec![],
92 cx,
93 )
94 });
95 let pending_file_count =
96 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
97 cx.background_executor.run_until_parked();
98 assert_eq!(*pending_file_count.borrow(), 3);
99 cx.background_executor
100 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
101 assert_eq!(*pending_file_count.borrow(), 0);
102
103 let search_results = search_results.await.unwrap();
104 assert_search_results(
105 &search_results,
106 &[
107 (Path::new("src/file1.rs").into(), 0),
108 (Path::new("src/file2.rs").into(), 0),
109 (Path::new("src/file3.toml").into(), 0),
110 (Path::new("src/file1.rs").into(), 45),
111 (Path::new("src/file2.rs").into(), 45),
112 ],
113 cx,
114 );
115
116 // Test Include Files Functionality
117 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
118 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
119 let rust_only_search_results = semantic_index
120 .update(cx, |store, cx| {
121 store.search_project(
122 project.clone(),
123 "aaaaaabbbbzz".to_string(),
124 5,
125 include_files,
126 vec![],
127 cx,
128 )
129 })
130 .await
131 .unwrap();
132
133 assert_search_results(
134 &rust_only_search_results,
135 &[
136 (Path::new("src/file1.rs").into(), 0),
137 (Path::new("src/file2.rs").into(), 0),
138 (Path::new("src/file1.rs").into(), 45),
139 (Path::new("src/file2.rs").into(), 45),
140 ],
141 cx,
142 );
143
144 let no_rust_search_results = semantic_index
145 .update(cx, |store, cx| {
146 store.search_project(
147 project.clone(),
148 "aaaaaabbbbzz".to_string(),
149 5,
150 vec![],
151 exclude_files,
152 cx,
153 )
154 })
155 .await
156 .unwrap();
157
158 assert_search_results(
159 &no_rust_search_results,
160 &[(Path::new("src/file3.toml").into(), 0)],
161 cx,
162 );
163
164 fs.save(
165 "/the-root/src/file2.rs".as_ref(),
166 &"
167 fn dddd() { println!(\"ddddd!\"); }
168 struct pqpqpqp {}
169 "
170 .unindent()
171 .into(),
172 Default::default(),
173 )
174 .await
175 .unwrap();
176
177 cx.background_executor
178 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
179
180 let prev_embedding_count = embedding_provider.embedding_count();
181 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
182 cx.background_executor.run_until_parked();
183 assert_eq!(*pending_file_count.borrow(), 1);
184 cx.background_executor
185 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
186 assert_eq!(*pending_file_count.borrow(), 0);
187 index.await.unwrap();
188
189 assert_eq!(
190 embedding_provider.embedding_count() - prev_embedding_count,
191 1
192 );
193}
194
195#[gpui::test(iterations = 10)]
196async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
197 let (outstanding_job_count, _) = postage::watch::channel_with(0);
198 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
199
200 let files = (1..=3)
201 .map(|file_ix| FileToEmbed {
202 worktree_id: 5,
203 path: Path::new(&format!("path-{file_ix}")).into(),
204 mtime: SystemTime::now(),
205 spans: (0..rng.gen_range(4..22))
206 .map(|document_ix| {
207 let content_len = rng.gen_range(10..100);
208 let content = RandomCharIter::new(&mut rng)
209 .with_simple_text()
210 .take(content_len)
211 .collect::<String>();
212 let digest = SpanDigest::from(content.as_str());
213 Span {
214 range: 0..10,
215 embedding: None,
216 name: format!("document {document_ix}"),
217 content,
218 digest,
219 token_count: rng.gen_range(10..30),
220 }
221 })
222 .collect(),
223 job_handle: JobHandle::new(&outstanding_job_count),
224 })
225 .collect::<Vec<_>>();
226
227 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
228
229 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
230 for file in &files {
231 queue.push(file.clone());
232 }
233 queue.flush();
234
235 cx.background_executor.run_until_parked();
236 let finished_files = queue.finished_files();
237 let mut embedded_files: Vec<_> = files
238 .iter()
239 .map(|_| finished_files.try_recv().expect("no finished file"))
240 .collect();
241
242 let expected_files: Vec<_> = files
243 .iter()
244 .map(|file| {
245 let mut file = file.clone();
246 for doc in &mut file.spans {
247 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
248 }
249 file
250 })
251 .collect();
252
253 embedded_files.sort_by_key(|f| f.path.clone());
254
255 assert_eq!(embedded_files, expected_files);
256}
257
258#[track_caller]
259fn assert_search_results(
260 actual: &[SearchResult],
261 expected: &[(Arc<Path>, usize)],
262 cx: &TestAppContext,
263) {
264 let actual = actual
265 .iter()
266 .map(|search_result| {
267 search_result.buffer.read_with(cx, |buffer, _cx| {
268 (
269 buffer.file().unwrap().path().clone(),
270 search_result.range.start.to_offset(buffer),
271 )
272 })
273 })
274 .collect::<Vec<_>>();
275 assert_eq!(actual, expected);
276}
277
278#[gpui::test]
279async fn test_code_context_retrieval_rust() {
280 let language = rust_lang();
281 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
282 let mut retriever = CodeContextRetriever::new(embedding_provider);
283
284 let text = "
285 /// A doc comment
286 /// that spans multiple lines
287 #[gpui::test]
288 fn a() {
289 b
290 }
291
292 impl C for D {
293 }
294
295 impl E {
296 // This is also a preceding comment
297 pub fn function_1() -> Option<()> {
298 unimplemented!();
299 }
300
301 // This is a preceding comment
302 fn function_2() -> Result<()> {
303 unimplemented!();
304 }
305 }
306
307 #[derive(Clone)]
308 struct D {
309 name: String
310 }
311 "
312 .unindent();
313
314 let documents = retriever.parse_file(&text, language).unwrap();
315
316 assert_documents_eq(
317 &documents,
318 &[
319 (
320 "
321 /// A doc comment
322 /// that spans multiple lines
323 #[gpui::test]
324 fn a() {
325 b
326 }"
327 .unindent(),
328 text.find("fn a").unwrap(),
329 ),
330 (
331 "
332 impl C for D {
333 }"
334 .unindent(),
335 text.find("impl C").unwrap(),
336 ),
337 (
338 "
339 impl E {
340 // This is also a preceding comment
341 pub fn function_1() -> Option<()> { /* ... */ }
342
343 // This is a preceding comment
344 fn function_2() -> Result<()> { /* ... */ }
345 }"
346 .unindent(),
347 text.find("impl E").unwrap(),
348 ),
349 (
350 "
351 // This is also a preceding comment
352 pub fn function_1() -> Option<()> {
353 unimplemented!();
354 }"
355 .unindent(),
356 text.find("pub fn function_1").unwrap(),
357 ),
358 (
359 "
360 // This is a preceding comment
361 fn function_2() -> Result<()> {
362 unimplemented!();
363 }"
364 .unindent(),
365 text.find("fn function_2").unwrap(),
366 ),
367 (
368 "
369 #[derive(Clone)]
370 struct D {
371 name: String
372 }"
373 .unindent(),
374 text.find("struct D").unwrap(),
375 ),
376 ],
377 );
378}
379
380#[gpui::test]
381async fn test_code_context_retrieval_json() {
382 let language = json_lang();
383 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
384 let mut retriever = CodeContextRetriever::new(embedding_provider);
385
386 let text = r#"
387 {
388 "array": [1, 2, 3, 4],
389 "string": "abcdefg",
390 "nested_object": {
391 "array_2": [5, 6, 7, 8],
392 "string_2": "hijklmnop",
393 "boolean": true,
394 "none": null
395 }
396 }
397 "#
398 .unindent();
399
400 let documents = retriever.parse_file(&text, language.clone()).unwrap();
401
402 assert_documents_eq(
403 &documents,
404 &[(
405 r#"
406 {
407 "array": [],
408 "string": "",
409 "nested_object": {
410 "array_2": [],
411 "string_2": "",
412 "boolean": true,
413 "none": null
414 }
415 }"#
416 .unindent(),
417 text.find("{").unwrap(),
418 )],
419 );
420
421 let text = r#"
422 [
423 {
424 "name": "somebody",
425 "age": 42
426 },
427 {
428 "name": "somebody else",
429 "age": 43
430 }
431 ]
432 "#
433 .unindent();
434
435 let documents = retriever.parse_file(&text, language.clone()).unwrap();
436
437 assert_documents_eq(
438 &documents,
439 &[(
440 r#"
441 [{
442 "name": "",
443 "age": 42
444 }]"#
445 .unindent(),
446 text.find("[").unwrap(),
447 )],
448 );
449}
450
451fn assert_documents_eq(
452 documents: &[Span],
453 expected_contents_and_start_offsets: &[(String, usize)],
454) {
455 assert_eq!(
456 documents
457 .iter()
458 .map(|document| (document.content.clone(), document.range.start))
459 .collect::<Vec<_>>(),
460 expected_contents_and_start_offsets
461 );
462}
463
464#[gpui::test]
465async fn test_code_context_retrieval_javascript() {
466 let language = js_lang();
467 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
468 let mut retriever = CodeContextRetriever::new(embedding_provider);
469
470 let text = "
471 /* globals importScripts, backend */
472 function _authorize() {}
473
474 /**
475 * Sometimes the frontend build is way faster than backend.
476 */
477 export async function authorizeBank() {
478 _authorize(pushModal, upgradingAccountId, {});
479 }
480
481 export class SettingsPage {
482 /* This is a test setting */
483 constructor(page) {
484 this.page = page;
485 }
486 }
487
488 /* This is a test comment */
489 class TestClass {}
490
491 /* Schema for editor_events in Clickhouse. */
492 export interface ClickhouseEditorEvent {
493 installation_id: string
494 operation: string
495 }
496 "
497 .unindent();
498
499 let documents = retriever.parse_file(&text, language.clone()).unwrap();
500
501 assert_documents_eq(
502 &documents,
503 &[
504 (
505 "
506 /* globals importScripts, backend */
507 function _authorize() {}"
508 .unindent(),
509 37,
510 ),
511 (
512 "
513 /**
514 * Sometimes the frontend build is way faster than backend.
515 */
516 export async function authorizeBank() {
517 _authorize(pushModal, upgradingAccountId, {});
518 }"
519 .unindent(),
520 131,
521 ),
522 (
523 "
524 export class SettingsPage {
525 /* This is a test setting */
526 constructor(page) {
527 this.page = page;
528 }
529 }"
530 .unindent(),
531 225,
532 ),
533 (
534 "
535 /* This is a test setting */
536 constructor(page) {
537 this.page = page;
538 }"
539 .unindent(),
540 290,
541 ),
542 (
543 "
544 /* This is a test comment */
545 class TestClass {}"
546 .unindent(),
547 374,
548 ),
549 (
550 "
551 /* Schema for editor_events in Clickhouse. */
552 export interface ClickhouseEditorEvent {
553 installation_id: string
554 operation: string
555 }"
556 .unindent(),
557 440,
558 ),
559 ],
560 )
561}
562
563#[gpui::test]
564async fn test_code_context_retrieval_lua() {
565 let language = lua_lang();
566 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
567 let mut retriever = CodeContextRetriever::new(embedding_provider);
568
569 let text = r#"
570 -- Creates a new class
571 -- @param baseclass The Baseclass of this class, or nil.
572 -- @return A new class reference.
573 function classes.class(baseclass)
574 -- Create the class definition and metatable.
575 local classdef = {}
576 -- Find the super class, either Object or user-defined.
577 baseclass = baseclass or classes.Object
578 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
579 setmetatable(classdef, { __index = baseclass })
580 -- All class instances have a reference to the class object.
581 classdef.class = classdef
582 --- Recursively allocates the inheritance tree of the instance.
583 -- @param mastertable The 'root' of the inheritance tree.
584 -- @return Returns the instance with the allocated inheritance tree.
585 function classdef.alloc(mastertable)
586 -- All class instances have a reference to a superclass object.
587 local instance = { super = baseclass.alloc(mastertable) }
588 -- Any functions this instance does not know of will 'look up' to the superclass definition.
589 setmetatable(instance, { __index = classdef, __newindex = mastertable })
590 return instance
591 end
592 end
593 "#.unindent();
594
595 let documents = retriever.parse_file(&text, language.clone()).unwrap();
596
597 assert_documents_eq(
598 &documents,
599 &[
600 (r#"
601 -- Creates a new class
602 -- @param baseclass The Baseclass of this class, or nil.
603 -- @return A new class reference.
604 function classes.class(baseclass)
605 -- Create the class definition and metatable.
606 local classdef = {}
607 -- Find the super class, either Object or user-defined.
608 baseclass = baseclass or classes.Object
609 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
610 setmetatable(classdef, { __index = baseclass })
611 -- All class instances have a reference to the class object.
612 classdef.class = classdef
613 --- Recursively allocates the inheritance tree of the instance.
614 -- @param mastertable The 'root' of the inheritance tree.
615 -- @return Returns the instance with the allocated inheritance tree.
616 function classdef.alloc(mastertable)
617 --[ ... ]--
618 --[ ... ]--
619 end
620 end"#.unindent(),
621 114),
622 (r#"
623 --- Recursively allocates the inheritance tree of the instance.
624 -- @param mastertable The 'root' of the inheritance tree.
625 -- @return Returns the instance with the allocated inheritance tree.
626 function classdef.alloc(mastertable)
627 -- All class instances have a reference to a superclass object.
628 local instance = { super = baseclass.alloc(mastertable) }
629 -- Any functions this instance does not know of will 'look up' to the superclass definition.
630 setmetatable(instance, { __index = classdef, __newindex = mastertable })
631 return instance
632 end"#.unindent(), 810),
633 ]
634 );
635}
636
637#[gpui::test]
638async fn test_code_context_retrieval_elixir() {
639 let language = elixir_lang();
640 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
641 let mut retriever = CodeContextRetriever::new(embedding_provider);
642
643 let text = r#"
644 defmodule File.Stream do
645 @moduledoc """
646 Defines a `File.Stream` struct returned by `File.stream!/3`.
647
648 The following fields are public:
649
650 * `path` - the file path
651 * `modes` - the file modes
652 * `raw` - a boolean indicating if bin functions should be used
653 * `line_or_bytes` - if reading should read lines or a given number of bytes
654 * `node` - the node the file belongs to
655
656 """
657
658 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
659
660 @type t :: %__MODULE__{}
661
662 @doc false
663 def __build__(path, modes, line_or_bytes) do
664 raw = :lists.keyfind(:encoding, 1, modes) == false
665
666 modes =
667 case raw do
668 true ->
669 case :lists.keyfind(:read_ahead, 1, modes) do
670 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
671 {:read_ahead, _} -> [:raw | modes]
672 false -> [:raw, :read_ahead | modes]
673 end
674
675 false ->
676 modes
677 end
678
679 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
680
681 end"#
682 .unindent();
683
684 let documents = retriever.parse_file(&text, language.clone()).unwrap();
685
686 assert_documents_eq(
687 &documents,
688 &[(
689 r#"
690 defmodule File.Stream do
691 @moduledoc """
692 Defines a `File.Stream` struct returned by `File.stream!/3`.
693
694 The following fields are public:
695
696 * `path` - the file path
697 * `modes` - the file modes
698 * `raw` - a boolean indicating if bin functions should be used
699 * `line_or_bytes` - if reading should read lines or a given number of bytes
700 * `node` - the node the file belongs to
701
702 """
703
704 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
705
706 @type t :: %__MODULE__{}
707
708 @doc false
709 def __build__(path, modes, line_or_bytes) do
710 raw = :lists.keyfind(:encoding, 1, modes) == false
711
712 modes =
713 case raw do
714 true ->
715 case :lists.keyfind(:read_ahead, 1, modes) do
716 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
717 {:read_ahead, _} -> [:raw | modes]
718 false -> [:raw, :read_ahead | modes]
719 end
720
721 false ->
722 modes
723 end
724
725 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
726
727 end"#
728 .unindent(),
729 0,
730 ),(r#"
731 @doc false
732 def __build__(path, modes, line_or_bytes) do
733 raw = :lists.keyfind(:encoding, 1, modes) == false
734
735 modes =
736 case raw do
737 true ->
738 case :lists.keyfind(:read_ahead, 1, modes) do
739 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
740 {:read_ahead, _} -> [:raw | modes]
741 false -> [:raw, :read_ahead | modes]
742 end
743
744 false ->
745 modes
746 end
747
748 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
749
750 end"#.unindent(), 574)],
751 );
752}
753
754#[gpui::test]
755async fn test_code_context_retrieval_cpp() {
756 let language = cpp_lang();
757 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
758 let mut retriever = CodeContextRetriever::new(embedding_provider);
759
760 let text = "
761 /**
762 * @brief Main function
763 * @returns 0 on exit
764 */
765 int main() { return 0; }
766
767 /**
768 * This is a test comment
769 */
770 class MyClass { // The class
771 public: // Access specifier
772 int myNum; // Attribute (int variable)
773 string myString; // Attribute (string variable)
774 };
775
776 // This is a test comment
777 enum Color { red, green, blue };
778
779 /** This is a preceding block comment
780 * This is the second line
781 */
782 struct { // Structure declaration
783 int myNum; // Member (int variable)
784 string myString; // Member (string variable)
785 } myStructure;
786
787 /**
788 * @brief Matrix class.
789 */
790 template <typename T,
791 typename = typename std::enable_if<
792 std::is_integral<T>::value || std::is_floating_point<T>::value,
793 bool>::type>
794 class Matrix2 {
795 std::vector<std::vector<T>> _mat;
796
797 public:
798 /**
799 * @brief Constructor
800 * @tparam Integer ensuring integers are being evaluated and not other
801 * data types.
802 * @param size denoting the size of Matrix as size x size
803 */
804 template <typename Integer,
805 typename = typename std::enable_if<std::is_integral<Integer>::value,
806 Integer>::type>
807 explicit Matrix(const Integer size) {
808 for (size_t i = 0; i < size; ++i) {
809 _mat.emplace_back(std::vector<T>(size, 0));
810 }
811 }
812 }"
813 .unindent();
814
815 let documents = retriever.parse_file(&text, language.clone()).unwrap();
816
817 assert_documents_eq(
818 &documents,
819 &[
820 (
821 "
822 /**
823 * @brief Main function
824 * @returns 0 on exit
825 */
826 int main() { return 0; }"
827 .unindent(),
828 54,
829 ),
830 (
831 "
832 /**
833 * This is a test comment
834 */
835 class MyClass { // The class
836 public: // Access specifier
837 int myNum; // Attribute (int variable)
838 string myString; // Attribute (string variable)
839 }"
840 .unindent(),
841 112,
842 ),
843 (
844 "
845 // This is a test comment
846 enum Color { red, green, blue }"
847 .unindent(),
848 322,
849 ),
850 (
851 "
852 /** This is a preceding block comment
853 * This is the second line
854 */
855 struct { // Structure declaration
856 int myNum; // Member (int variable)
857 string myString; // Member (string variable)
858 } myStructure;"
859 .unindent(),
860 425,
861 ),
862 (
863 "
864 /**
865 * @brief Matrix class.
866 */
867 template <typename T,
868 typename = typename std::enable_if<
869 std::is_integral<T>::value || std::is_floating_point<T>::value,
870 bool>::type>
871 class Matrix2 {
872 std::vector<std::vector<T>> _mat;
873
874 public:
875 /**
876 * @brief Constructor
877 * @tparam Integer ensuring integers are being evaluated and not other
878 * data types.
879 * @param size denoting the size of Matrix as size x size
880 */
881 template <typename Integer,
882 typename = typename std::enable_if<std::is_integral<Integer>::value,
883 Integer>::type>
884 explicit Matrix(const Integer size) {
885 for (size_t i = 0; i < size; ++i) {
886 _mat.emplace_back(std::vector<T>(size, 0));
887 }
888 }
889 }"
890 .unindent(),
891 612,
892 ),
893 (
894 "
895 explicit Matrix(const Integer size) {
896 for (size_t i = 0; i < size; ++i) {
897 _mat.emplace_back(std::vector<T>(size, 0));
898 }
899 }"
900 .unindent(),
901 1226,
902 ),
903 ],
904 );
905}
906
907#[gpui::test]
908async fn test_code_context_retrieval_ruby() {
909 let language = ruby_lang();
910 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
911 let mut retriever = CodeContextRetriever::new(embedding_provider);
912
913 let text = r#"
914 # This concern is inspired by "sudo mode" on GitHub. It
915 # is a way to re-authenticate a user before allowing them
916 # to see or perform an action.
917 #
918 # Add `before_action :require_challenge!` to actions you
919 # want to protect.
920 #
921 # The user will be shown a page to enter the challenge (which
922 # is either the password, or just the username when no
923 # password exists). Upon passing, there is a grace period
924 # during which no challenge will be asked from the user.
925 #
926 # Accessing challenge-protected resources during the grace
927 # period will refresh the grace period.
928 module ChallengableConcern
929 extend ActiveSupport::Concern
930
931 CHALLENGE_TIMEOUT = 1.hour.freeze
932
933 def require_challenge!
934 return if skip_challenge?
935
936 if challenge_passed_recently?
937 session[:challenge_passed_at] = Time.now.utc
938 return
939 end
940
941 @challenge = Form::Challenge.new(return_to: request.url)
942
943 if params.key?(:form_challenge)
944 if challenge_passed?
945 session[:challenge_passed_at] = Time.now.utc
946 else
947 flash.now[:alert] = I18n.t('challenge.invalid_password')
948 render_challenge
949 end
950 else
951 render_challenge
952 end
953 end
954
955 def challenge_passed?
956 current_user.valid_password?(challenge_params[:current_password])
957 end
958 end
959
960 class Animal
961 include Comparable
962
963 attr_reader :legs
964
965 def initialize(name, legs)
966 @name, @legs = name, legs
967 end
968
969 def <=>(other)
970 legs <=> other.legs
971 end
972 end
973
974 # Singleton method for car object
975 def car.wheels
976 puts "There are four wheels"
977 end"#
978 .unindent();
979
980 let documents = retriever.parse_file(&text, language.clone()).unwrap();
981
982 assert_documents_eq(
983 &documents,
984 &[
985 (
986 r#"
987 # This concern is inspired by "sudo mode" on GitHub. It
988 # is a way to re-authenticate a user before allowing them
989 # to see or perform an action.
990 #
991 # Add `before_action :require_challenge!` to actions you
992 # want to protect.
993 #
994 # The user will be shown a page to enter the challenge (which
995 # is either the password, or just the username when no
996 # password exists). Upon passing, there is a grace period
997 # during which no challenge will be asked from the user.
998 #
999 # Accessing challenge-protected resources during the grace
1000 # period will refresh the grace period.
1001 module ChallengableConcern
1002 extend ActiveSupport::Concern
1003
1004 CHALLENGE_TIMEOUT = 1.hour.freeze
1005
1006 def require_challenge!
1007 # ...
1008 end
1009
1010 def challenge_passed?
1011 # ...
1012 end
1013 end"#
1014 .unindent(),
1015 558,
1016 ),
1017 (
1018 r#"
1019 def require_challenge!
1020 return if skip_challenge?
1021
1022 if challenge_passed_recently?
1023 session[:challenge_passed_at] = Time.now.utc
1024 return
1025 end
1026
1027 @challenge = Form::Challenge.new(return_to: request.url)
1028
1029 if params.key?(:form_challenge)
1030 if challenge_passed?
1031 session[:challenge_passed_at] = Time.now.utc
1032 else
1033 flash.now[:alert] = I18n.t('challenge.invalid_password')
1034 render_challenge
1035 end
1036 else
1037 render_challenge
1038 end
1039 end"#
1040 .unindent(),
1041 663,
1042 ),
1043 (
1044 r#"
1045 def challenge_passed?
1046 current_user.valid_password?(challenge_params[:current_password])
1047 end"#
1048 .unindent(),
1049 1254,
1050 ),
1051 (
1052 r#"
1053 class Animal
1054 include Comparable
1055
1056 attr_reader :legs
1057
1058 def initialize(name, legs)
1059 # ...
1060 end
1061
1062 def <=>(other)
1063 # ...
1064 end
1065 end"#
1066 .unindent(),
1067 1363,
1068 ),
1069 (
1070 r#"
1071 def initialize(name, legs)
1072 @name, @legs = name, legs
1073 end"#
1074 .unindent(),
1075 1427,
1076 ),
1077 (
1078 r#"
1079 def <=>(other)
1080 legs <=> other.legs
1081 end"#
1082 .unindent(),
1083 1501,
1084 ),
1085 (
1086 r#"
1087 # Singleton method for car object
1088 def car.wheels
1089 puts "There are four wheels"
1090 end"#
1091 .unindent(),
1092 1591,
1093 ),
1094 ],
1095 );
1096}
1097
1098#[gpui::test]
1099async fn test_code_context_retrieval_php() {
1100 let language = php_lang();
1101 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1102 let mut retriever = CodeContextRetriever::new(embedding_provider);
1103
1104 let text = r#"
1105 <?php
1106
1107 namespace LevelUp\Experience\Concerns;
1108
1109 /*
1110 This is a multiple-lines comment block
1111 that spans over multiple
1112 lines
1113 */
1114 function functionName() {
1115 echo "Hello world!";
1116 }
1117
1118 trait HasAchievements
1119 {
1120 /**
1121 * @throws \Exception
1122 */
1123 public function grantAchievement(Achievement $achievement, $progress = null): void
1124 {
1125 if ($progress > 100) {
1126 throw new Exception(message: 'Progress cannot be greater than 100');
1127 }
1128
1129 if ($this->achievements()->find($achievement->id)) {
1130 throw new Exception(message: 'User already has this Achievement');
1131 }
1132
1133 $this->achievements()->attach($achievement, [
1134 'progress' => $progress ?? null,
1135 ]);
1136
1137 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1138 }
1139
1140 public function achievements(): BelongsToMany
1141 {
1142 return $this->belongsToMany(related: Achievement::class)
1143 ->withPivot(columns: 'progress')
1144 ->where('is_secret', false)
1145 ->using(AchievementUser::class);
1146 }
1147 }
1148
1149 interface Multiplier
1150 {
1151 public function qualifies(array $data): bool;
1152
1153 public function setMultiplier(): int;
1154 }
1155
1156 enum AuditType: string
1157 {
1158 case Add = 'add';
1159 case Remove = 'remove';
1160 case Reset = 'reset';
1161 case LevelUp = 'level_up';
1162 }
1163
1164 ?>"#
1165 .unindent();
1166
1167 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1168
1169 assert_documents_eq(
1170 &documents,
1171 &[
1172 (
1173 r#"
1174 /*
1175 This is a multiple-lines comment block
1176 that spans over multiple
1177 lines
1178 */
1179 function functionName() {
1180 echo "Hello world!";
1181 }"#
1182 .unindent(),
1183 123,
1184 ),
1185 (
1186 r#"
1187 trait HasAchievements
1188 {
1189 /**
1190 * @throws \Exception
1191 */
1192 public function grantAchievement(Achievement $achievement, $progress = null): void
1193 {/* ... */}
1194
1195 public function achievements(): BelongsToMany
1196 {/* ... */}
1197 }"#
1198 .unindent(),
1199 177,
1200 ),
1201 (r#"
1202 /**
1203 * @throws \Exception
1204 */
1205 public function grantAchievement(Achievement $achievement, $progress = null): void
1206 {
1207 if ($progress > 100) {
1208 throw new Exception(message: 'Progress cannot be greater than 100');
1209 }
1210
1211 if ($this->achievements()->find($achievement->id)) {
1212 throw new Exception(message: 'User already has this Achievement');
1213 }
1214
1215 $this->achievements()->attach($achievement, [
1216 'progress' => $progress ?? null,
1217 ]);
1218
1219 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1220 }"#.unindent(), 245),
1221 (r#"
1222 public function achievements(): BelongsToMany
1223 {
1224 return $this->belongsToMany(related: Achievement::class)
1225 ->withPivot(columns: 'progress')
1226 ->where('is_secret', false)
1227 ->using(AchievementUser::class);
1228 }"#.unindent(), 902),
1229 (r#"
1230 interface Multiplier
1231 {
1232 public function qualifies(array $data): bool;
1233
1234 public function setMultiplier(): int;
1235 }"#.unindent(),
1236 1146),
1237 (r#"
1238 enum AuditType: string
1239 {
1240 case Add = 'add';
1241 case Remove = 'remove';
1242 case Reset = 'reset';
1243 case LevelUp = 'level_up';
1244 }"#.unindent(), 1265)
1245 ],
1246 );
1247}
1248
1249fn js_lang() -> Arc<Language> {
1250 Arc::new(
1251 Language::new(
1252 LanguageConfig {
1253 name: "Javascript".into(),
1254 path_suffixes: vec!["js".into()],
1255 ..Default::default()
1256 },
1257 Some(tree_sitter_typescript::language_tsx()),
1258 )
1259 .with_embedding_query(
1260 &r#"
1261
1262 (
1263 (comment)* @context
1264 .
1265 [
1266 (export_statement
1267 (function_declaration
1268 "async"? @name
1269 "function" @name
1270 name: (_) @name))
1271 (function_declaration
1272 "async"? @name
1273 "function" @name
1274 name: (_) @name)
1275 ] @item
1276 )
1277
1278 (
1279 (comment)* @context
1280 .
1281 [
1282 (export_statement
1283 (class_declaration
1284 "class" @name
1285 name: (_) @name))
1286 (class_declaration
1287 "class" @name
1288 name: (_) @name)
1289 ] @item
1290 )
1291
1292 (
1293 (comment)* @context
1294 .
1295 [
1296 (export_statement
1297 (interface_declaration
1298 "interface" @name
1299 name: (_) @name))
1300 (interface_declaration
1301 "interface" @name
1302 name: (_) @name)
1303 ] @item
1304 )
1305
1306 (
1307 (comment)* @context
1308 .
1309 [
1310 (export_statement
1311 (enum_declaration
1312 "enum" @name
1313 name: (_) @name))
1314 (enum_declaration
1315 "enum" @name
1316 name: (_) @name)
1317 ] @item
1318 )
1319
1320 (
1321 (comment)* @context
1322 .
1323 (method_definition
1324 [
1325 "get"
1326 "set"
1327 "async"
1328 "*"
1329 "static"
1330 ]* @name
1331 name: (_) @name) @item
1332 )
1333
1334 "#
1335 .unindent(),
1336 )
1337 .unwrap(),
1338 )
1339}
1340
1341fn rust_lang() -> Arc<Language> {
1342 Arc::new(
1343 Language::new(
1344 LanguageConfig {
1345 name: "Rust".into(),
1346 path_suffixes: vec!["rs".into()],
1347 collapsed_placeholder: " /* ... */ ".to_string(),
1348 ..Default::default()
1349 },
1350 Some(tree_sitter_rust::language()),
1351 )
1352 .with_embedding_query(
1353 r#"
1354 (
1355 [(line_comment) (attribute_item)]* @context
1356 .
1357 [
1358 (struct_item
1359 name: (_) @name)
1360
1361 (enum_item
1362 name: (_) @name)
1363
1364 (impl_item
1365 trait: (_)? @name
1366 "for"? @name
1367 type: (_) @name)
1368
1369 (trait_item
1370 name: (_) @name)
1371
1372 (function_item
1373 name: (_) @name
1374 body: (block
1375 "{" @keep
1376 "}" @keep) @collapse)
1377
1378 (macro_definition
1379 name: (_) @name)
1380 ] @item
1381 )
1382
1383 (attribute_item) @collapse
1384 (use_declaration) @collapse
1385 "#,
1386 )
1387 .unwrap(),
1388 )
1389}
1390
1391fn json_lang() -> Arc<Language> {
1392 Arc::new(
1393 Language::new(
1394 LanguageConfig {
1395 name: "JSON".into(),
1396 path_suffixes: vec!["json".into()],
1397 ..Default::default()
1398 },
1399 Some(tree_sitter_json::language()),
1400 )
1401 .with_embedding_query(
1402 r#"
1403 (document) @item
1404
1405 (array
1406 "[" @keep
1407 .
1408 (object)? @keep
1409 "]" @keep) @collapse
1410
1411 (pair value: (string
1412 "\"" @keep
1413 "\"" @keep) @collapse)
1414 "#,
1415 )
1416 .unwrap(),
1417 )
1418}
1419
1420fn toml_lang() -> Arc<Language> {
1421 Arc::new(Language::new(
1422 LanguageConfig {
1423 name: "TOML".into(),
1424 path_suffixes: vec!["toml".into()],
1425 ..Default::default()
1426 },
1427 Some(tree_sitter_toml::language()),
1428 ))
1429}
1430
1431fn cpp_lang() -> Arc<Language> {
1432 Arc::new(
1433 Language::new(
1434 LanguageConfig {
1435 name: "CPP".into(),
1436 path_suffixes: vec!["cpp".into()],
1437 ..Default::default()
1438 },
1439 Some(tree_sitter_cpp::language()),
1440 )
1441 .with_embedding_query(
1442 r#"
1443 (
1444 (comment)* @context
1445 .
1446 (function_definition
1447 (type_qualifier)? @name
1448 type: (_)? @name
1449 declarator: [
1450 (function_declarator
1451 declarator: (_) @name)
1452 (pointer_declarator
1453 "*" @name
1454 declarator: (function_declarator
1455 declarator: (_) @name))
1456 (pointer_declarator
1457 "*" @name
1458 declarator: (pointer_declarator
1459 "*" @name
1460 declarator: (function_declarator
1461 declarator: (_) @name)))
1462 (reference_declarator
1463 ["&" "&&"] @name
1464 (function_declarator
1465 declarator: (_) @name))
1466 ]
1467 (type_qualifier)? @name) @item
1468 )
1469
1470 (
1471 (comment)* @context
1472 .
1473 (template_declaration
1474 (class_specifier
1475 "class" @name
1476 name: (_) @name)
1477 ) @item
1478 )
1479
1480 (
1481 (comment)* @context
1482 .
1483 (class_specifier
1484 "class" @name
1485 name: (_) @name) @item
1486 )
1487
1488 (
1489 (comment)* @context
1490 .
1491 (enum_specifier
1492 "enum" @name
1493 name: (_) @name) @item
1494 )
1495
1496 (
1497 (comment)* @context
1498 .
1499 (declaration
1500 type: (struct_specifier
1501 "struct" @name)
1502 declarator: (_) @name) @item
1503 )
1504
1505 "#,
1506 )
1507 .unwrap(),
1508 )
1509}
1510
1511fn lua_lang() -> Arc<Language> {
1512 Arc::new(
1513 Language::new(
1514 LanguageConfig {
1515 name: "Lua".into(),
1516 path_suffixes: vec!["lua".into()],
1517 collapsed_placeholder: "--[ ... ]--".to_string(),
1518 ..Default::default()
1519 },
1520 Some(tree_sitter_lua::language()),
1521 )
1522 .with_embedding_query(
1523 r#"
1524 (
1525 (comment)* @context
1526 .
1527 (function_declaration
1528 "function" @name
1529 name: (_) @name
1530 (comment)* @collapse
1531 body: (block) @collapse
1532 ) @item
1533 )
1534 "#,
1535 )
1536 .unwrap(),
1537 )
1538}
1539
1540fn php_lang() -> Arc<Language> {
1541 Arc::new(
1542 Language::new(
1543 LanguageConfig {
1544 name: "PHP".into(),
1545 path_suffixes: vec!["php".into()],
1546 collapsed_placeholder: "/* ... */".into(),
1547 ..Default::default()
1548 },
1549 Some(tree_sitter_php::language_php()),
1550 )
1551 .with_embedding_query(
1552 r#"
1553 (
1554 (comment)* @context
1555 .
1556 [
1557 (function_definition
1558 "function" @name
1559 name: (_) @name
1560 body: (_
1561 "{" @keep
1562 "}" @keep) @collapse
1563 )
1564
1565 (trait_declaration
1566 "trait" @name
1567 name: (_) @name)
1568
1569 (method_declaration
1570 "function" @name
1571 name: (_) @name
1572 body: (_
1573 "{" @keep
1574 "}" @keep) @collapse
1575 )
1576
1577 (interface_declaration
1578 "interface" @name
1579 name: (_) @name
1580 )
1581
1582 (enum_declaration
1583 "enum" @name
1584 name: (_) @name
1585 )
1586
1587 ] @item
1588 )
1589 "#,
1590 )
1591 .unwrap(),
1592 )
1593}
1594
1595fn ruby_lang() -> Arc<Language> {
1596 Arc::new(
1597 Language::new(
1598 LanguageConfig {
1599 name: "Ruby".into(),
1600 path_suffixes: vec!["rb".into()],
1601 collapsed_placeholder: "# ...".to_string(),
1602 ..Default::default()
1603 },
1604 Some(tree_sitter_ruby::language()),
1605 )
1606 .with_embedding_query(
1607 r#"
1608 (
1609 (comment)* @context
1610 .
1611 [
1612 (module
1613 "module" @name
1614 name: (_) @name)
1615 (method
1616 "def" @name
1617 name: (_) @name
1618 body: (body_statement) @collapse)
1619 (class
1620 "class" @name
1621 name: (_) @name)
1622 (singleton_method
1623 "def" @name
1624 object: (_) @name
1625 "." @name
1626 name: (_) @name
1627 body: (body_statement) @collapse)
1628 ] @item
1629 )
1630 "#,
1631 )
1632 .unwrap(),
1633 )
1634}
1635
1636fn elixir_lang() -> Arc<Language> {
1637 Arc::new(
1638 Language::new(
1639 LanguageConfig {
1640 name: "Elixir".into(),
1641 path_suffixes: vec!["rs".into()],
1642 ..Default::default()
1643 },
1644 Some(tree_sitter_elixir::language()),
1645 )
1646 .with_embedding_query(
1647 r#"
1648 (
1649 (unary_operator
1650 operator: "@"
1651 operand: (call
1652 target: (identifier) @unary
1653 (#match? @unary "^(doc)$"))
1654 ) @context
1655 .
1656 (call
1657 target: (identifier) @name
1658 (arguments
1659 [
1660 (identifier) @name
1661 (call
1662 target: (identifier) @name)
1663 (binary_operator
1664 left: (call
1665 target: (identifier) @name)
1666 operator: "when")
1667 ])
1668 (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1669 )
1670
1671 (call
1672 target: (identifier) @name
1673 (arguments (alias) @name)
1674 (#any-match? @name "^(defmodule|defprotocol)$")) @item
1675 "#,
1676 )
1677 .unwrap(),
1678 )
1679}
1680
1681#[gpui::test]
1682fn test_subtract_ranges() {
1683 assert_eq!(
1684 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1685 vec![1..4, 10..21]
1686 );
1687
1688 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1689}
1690
1691fn init_test(cx: &mut TestAppContext) {
1692 cx.update(|cx| {
1693 let settings_store = SettingsStore::test(cx);
1694 cx.set_global(settings_store);
1695 SemanticIndexSettings::register(cx);
1696 ProjectSettings::register(cx);
1697 });
1698}