1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8
9use gpui::{Task, TestAppContext};
10use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
11use parking_lot::Mutex;
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::{Settings, SettingsStore};
17use std::{path::Path, sync::Arc, time::SystemTime};
18use unindent::Unindent;
19use util::{paths::PathMatcher, RandomCharIter};
20
21#[ctor::ctor]
22fn init_logger() {
23 if std::env::var("RUST_LOG").is_ok() {
24 env_logger::init();
25 }
26}
27
/// End-to-end test of `SemanticIndex` against a fake file system.
///
/// The test builds a small worktree (two Rust files and one TOML file), runs a
/// project search backed by a `FakeEmbeddingProvider`, and checks:
/// - the pending-file count reported while files wait for the embedding queue,
/// - the ordered search results (path, start offset) for an unfiltered search,
/// - the effect of the include and exclude path matchers,
/// - that re-indexing after rewriting one file produces exactly one new embedding.
28#[gpui::test]
29async fn test_semantic_index(cx: &mut TestAppContext) {
30 init_test(cx);
31
    // Fake worktree rooted at /the-root with three files under src/.
32 let fs = FakeFs::new(cx.background_executor.clone());
33 fs.insert_tree(
34 "/the-root",
35 json!({
36 "src": {
37 "file1.rs": "
38 fn aaa() {
39 println!(\"aaaaaaaaaaaa!\");
40 }
41
42 fn zzzzz() {
43 println!(\"SLEEPING\");
44 }
45 ".unindent(),
46 "file2.rs": "
47 fn bbb() {
48 println!(\"bbbbbbbbbbbbb!\");
49 }
50 struct pqpqpqp {}
51 ".unindent(),
52 "file3.toml": "
53 ZZZZZZZZZZZZZZZZZZ = 5
54 ".unindent(),
55 }
56 }),
57 )
58 .await;
59
    // Register the Rust and TOML test languages with a fresh registry.
60 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
61 let rust_language = rust_lang();
62 let toml_language = toml_lang();
63 languages.add(rust_language);
64 languages.add(toml_language);
65
    // The index database is created inside a temporary directory.
66 let db_dir = tempfile::Builder::new()
67 .prefix("vector-store")
68 .tempdir()
69 .unwrap();
70 let db_path = db_dir.path().join("db.sqlite");
71
72 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
73 let semantic_index = SemanticIndex::new(
74 fs.clone(),
75 db_path,
76 embedding_provider.clone(),
77 languages,
78 cx.to_async(),
79 )
80 .await
81 .unwrap();
82
83 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
84
    // Start a search but do not await it yet: first observe the pending-file
    // count, which is expected to be 3 once the executor parks and to drop to 0
    // after the clock advances by EMBEDDING_QUEUE_FLUSH_TIMEOUT.
85 let search_results = semantic_index.update(cx, |store, cx| {
86 store.search_project(
87 project.clone(),
88 "aaaaaabbbbzz".to_string(),
89 5,
90 vec![],
91 vec![],
92 cx,
93 )
94 });
95 let pending_file_count =
96 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
97 cx.background_executor.run_until_parked();
98 assert_eq!(*pending_file_count.borrow(), 3);
99 cx.background_executor
100 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
101 assert_eq!(*pending_file_count.borrow(), 0);
102
    // Expected results are (worktree-relative path, start offset of the match).
103 let search_results = search_results.await.unwrap();
104 assert_search_results(
105 &search_results,
106 &[
107 (Path::new("src/file1.rs").into(), 0),
108 (Path::new("src/file2.rs").into(), 0),
109 (Path::new("src/file3.toml").into(), 0),
110 (Path::new("src/file1.rs").into(), 45),
111 (Path::new("src/file2.rs").into(), 45),
112 ],
113 cx,
114 );
115
116 // Test Include Files Functionality
    // The same `*.rs` matcher is used once as an include filter and once as an
    // exclude filter in the two searches below.
117 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
118 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
119 let rust_only_search_results = semantic_index
120 .update(cx, |store, cx| {
121 store.search_project(
122 project.clone(),
123 "aaaaaabbbbzz".to_string(),
124 5,
125 include_files,
126 vec![],
127 cx,
128 )
129 })
130 .await
131 .unwrap();
132
133 assert_search_results(
134 &rust_only_search_results,
135 &[
136 (Path::new("src/file1.rs").into(), 0),
137 (Path::new("src/file2.rs").into(), 0),
138 (Path::new("src/file1.rs").into(), 45),
139 (Path::new("src/file2.rs").into(), 45),
140 ],
141 cx,
142 );
143
    // With `*.rs` excluded, only the TOML file is expected to match.
144 let no_rust_search_results = semantic_index
145 .update(cx, |store, cx| {
146 store.search_project(
147 project.clone(),
148 "aaaaaabbbbzz".to_string(),
149 5,
150 vec![],
151 exclude_files,
152 cx,
153 )
154 })
155 .await
156 .unwrap();
157
158 assert_search_results(
159 &no_rust_search_results,
160 &[(Path::new("src/file3.toml").into(), 0)],
161 cx,
162 );
163
    // Rewrite file2.rs: `bbb` is replaced by `dddd`, `pqpqpqp` is left as it was.
164 fs.save(
165 "/the-root/src/file2.rs".as_ref(),
166 &"
167 fn dddd() { println!(\"ddddd!\"); }
168 struct pqpqpqp {}
169 "
170 .unindent()
171 .into(),
172 Default::default(),
173 )
174 .await
175 .unwrap();
176
177 cx.background_executor
178 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
179
    // Re-index and check that only the rewritten file is reported as pending.
180 let prev_embedding_count = embedding_provider.embedding_count();
181 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
182 cx.background_executor.run_until_parked();
183 assert_eq!(*pending_file_count.borrow(), 1);
184 cx.background_executor
185 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
186 assert_eq!(*pending_file_count.borrow(), 0);
187 index.await.unwrap();
188
    // Exactly one additional embedding is expected.
    // NOTE(review): presumably the unchanged `pqpqpqp` span is reused rather
    // than re-embedded - confirm against the indexing implementation.
189 assert_eq!(
190 embedding_provider.embedding_count() - prev_embedding_count,
191 1
192 );
193}
194
195#[gpui::test(iterations = 10)]
196async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
197 let (outstanding_job_count, _) = postage::watch::channel_with(0);
198 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
199
200 let files = (1..=3)
201 .map(|file_ix| FileToEmbed {
202 worktree_id: 5,
203 path: Path::new(&format!("path-{file_ix}")).into(),
204 mtime: SystemTime::now(),
205 spans: (0..rng.gen_range(4..22))
206 .map(|document_ix| {
207 let content_len = rng.gen_range(10..100);
208 let content = RandomCharIter::new(&mut rng)
209 .with_simple_text()
210 .take(content_len)
211 .collect::<String>();
212 let digest = SpanDigest::from(content.as_str());
213 Span {
214 range: 0..10,
215 embedding: None,
216 name: format!("document {document_ix}"),
217 content,
218 digest,
219 token_count: rng.gen_range(10..30),
220 }
221 })
222 .collect(),
223 job_handle: JobHandle::new(&outstanding_job_count),
224 })
225 .collect::<Vec<_>>();
226
227 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
228
229 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
230 for file in &files {
231 queue.push(file.clone());
232 }
233 queue.flush();
234
235 cx.background_executor.run_until_parked();
236 let finished_files = queue.finished_files();
237 let mut embedded_files: Vec<_> = files
238 .iter()
239 .map(|_| finished_files.try_recv().expect("no finished file"))
240 .collect();
241
242 let expected_files: Vec<_> = files
243 .iter()
244 .map(|file| {
245 let mut file = file.clone();
246 for doc in &mut file.spans {
247 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
248 }
249 file
250 })
251 .collect();
252
253 embedded_files.sort_by_key(|f| f.path.clone());
254
255 assert_eq!(embedded_files, expected_files);
256}
257
258#[track_caller]
259fn assert_search_results(
260 actual: &[SearchResult],
261 expected: &[(Arc<Path>, usize)],
262 cx: &TestAppContext,
263) {
264 let actual = actual
265 .iter()
266 .map(|search_result| {
267 search_result.buffer.read_with(cx, |buffer, _cx| {
268 (
269 buffer.file().unwrap().path().clone(),
270 search_result.range.start.to_offset(buffer),
271 )
272 })
273 })
274 .collect::<Vec<_>>();
275 assert_eq!(actual, expected);
276}
277
/// Checks how `CodeContextRetriever::parse_file` splits Rust source into
/// documents when driven by the embedding query from `rust_lang()`.
///
/// Each expected entry is `(document content, start offset)`; the offset is
/// located with `text.find(..)` on the item keyword, not on the leading comment.
/// The expectations show that:
/// - preceding doc comments and attributes are part of the document content,
/// - the `impl E` document has its method bodies replaced by `/* ... */`,
/// - each method of `impl E` is also emitted as its own, uncollapsed document.
278#[gpui::test]
279async fn test_code_context_retrieval_rust() {
280 let language = rust_lang();
281 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
282 let mut retriever = CodeContextRetriever::new(embedding_provider);
283
284 let text = "
285 /// A doc comment
286 /// that spans multiple lines
287 #[gpui::test]
288 fn a() {
289 b
290 }
291
292 impl C for D {
293 }
294
295 impl E {
296 // This is also a preceding comment
297 pub fn function_1() -> Option<()> {
298 unimplemented!();
299 }
300
301 // This is a preceding comment
302 fn function_2() -> Result<()> {
303 unimplemented!();
304 }
305 }
306
307 #[derive(Clone)]
308 struct D {
309 name: String
310 }
311 "
312 .unindent();
313
314 let documents = retriever.parse_file(&text, language).unwrap();
315
    // Documents are expected in source order, with nested items following
    // their parent.
316 assert_documents_eq(
317 &documents,
318 &[
319 (
320 "
321 /// A doc comment
322 /// that spans multiple lines
323 #[gpui::test]
324 fn a() {
325 b
326 }"
327 .unindent(),
328 text.find("fn a").unwrap(),
329 ),
330 (
331 "
332 impl C for D {
333 }"
334 .unindent(),
335 text.find("impl C").unwrap(),
336 ),
337 (
338 "
339 impl E {
340 // This is also a preceding comment
341 pub fn function_1() -> Option<()> { /* ... */ }
342
343 // This is a preceding comment
344 fn function_2() -> Result<()> { /* ... */ }
345 }"
346 .unindent(),
347 text.find("impl E").unwrap(),
348 ),
349 (
350 "
351 // This is also a preceding comment
352 pub fn function_1() -> Option<()> {
353 unimplemented!();
354 }"
355 .unindent(),
356 text.find("pub fn function_1").unwrap(),
357 ),
358 (
359 "
360 // This is a preceding comment
361 fn function_2() -> Result<()> {
362 unimplemented!();
363 }"
364 .unindent(),
365 text.find("fn function_2").unwrap(),
366 ),
367 (
368 "
369 #[derive(Clone)]
370 struct D {
371 name: String
372 }"
373 .unindent(),
374 text.find("struct D").unwrap(),
375 ),
376 ],
377 );
378}
379
/// Checks how `CodeContextRetriever::parse_file` condenses JSON when driven by
/// the embedding query from `json_lang()`.
///
/// Two inputs are parsed with the same retriever:
/// - an object: a single document is expected in which arrays become `[]` and
///   string values become `""`, while booleans, nulls and keys are kept;
/// - an array of objects: a single document is expected that keeps only the
///   first object (with its string value emptied).
380#[gpui::test]
381async fn test_code_context_retrieval_json() {
382 let language = json_lang();
383 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
384 let mut retriever = CodeContextRetriever::new(embedding_provider);
385
386 let text = r#"
387 {
388 "array": [1, 2, 3, 4],
389 "string": "abcdefg",
390 "nested_object": {
391 "array_2": [5, 6, 7, 8],
392 "string_2": "hijklmnop",
393 "boolean": true,
394 "none": null
395 }
396 }
397 "#
398 .unindent();
399
400 let documents = retriever.parse_file(&text, language.clone()).unwrap();
401
402 assert_documents_eq(
403 &documents,
404 &[(
405 r#"
406 {
407 "array": [],
408 "string": "",
409 "nested_object": {
410 "array_2": [],
411 "string_2": "",
412 "boolean": true,
413 "none": null
414 }
415 }"#
416 .unindent(),
417 text.find("{").unwrap(),
418 )],
419 );
420
    // Second case: a top-level array of objects.
421 let text = r#"
422 [
423 {
424 "name": "somebody",
425 "age": 42
426 },
427 {
428 "name": "somebody else",
429 "age": 43
430 }
431 ]
432 "#
433 .unindent();
434
435 let documents = retriever.parse_file(&text, language.clone()).unwrap();
436
437 assert_documents_eq(
438 &documents,
439 &[(
440 r#"
441 [{
442 "name": "",
443 "age": 42
444 }]"#
445 .unindent(),
446 text.find("[").unwrap(),
447 )],
448 );
449}
450
451fn assert_documents_eq(
452 documents: &[Span],
453 expected_contents_and_start_offsets: &[(String, usize)],
454) {
455 assert_eq!(
456 documents
457 .iter()
458 .map(|document| (document.content.clone(), document.range.start))
459 .collect::<Vec<_>>(),
460 expected_contents_and_start_offsets
461 );
462}
463
/// Checks how `CodeContextRetriever::parse_file` splits JavaScript/TypeScript
/// source into documents when driven by the embedding query from `js_lang()`.
///
/// Expected entries are `(document content, start offset)`. The offsets are
/// hard-coded and are compared against each document's `range.start`, so they
/// must be kept in sync with `text`. The expectations cover plain and exported
/// functions, an exported class and its constructor, a plain class, and an
/// exported interface, each together with its preceding comment.
464#[gpui::test]
465async fn test_code_context_retrieval_javascript() {
466 let language = js_lang();
467 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
468 let mut retriever = CodeContextRetriever::new(embedding_provider);
469
470 let text = "
471 /* globals importScripts, backend */
472 function _authorize() {}
473
474 /**
475 * Sometimes the frontend build is way faster than backend.
476 */
477 export async function authorizeBank() {
478 _authorize(pushModal, upgradingAccountId, {});
479 }
480
481 export class SettingsPage {
482 /* This is a test setting */
483 constructor(page) {
484 this.page = page;
485 }
486 }
487
488 /* This is a test comment */
489 class TestClass {}
490
491 /* Schema for editor_events in Clickhouse. */
492 export interface ClickhouseEditorEvent {
493 installation_id: string
494 operation: string
495 }
496 "
497 .unindent();
498
499 let documents = retriever.parse_file(&text, language.clone()).unwrap();
500
501 assert_documents_eq(
502 &documents,
503 &[
504 (
505 "
506 /* globals importScripts, backend */
507 function _authorize() {}"
508 .unindent(),
509 37,
510 ),
511 (
512 "
513 /**
514 * Sometimes the frontend build is way faster than backend.
515 */
516 export async function authorizeBank() {
517 _authorize(pushModal, upgradingAccountId, {});
518 }"
519 .unindent(),
520 131,
521 ),
522 (
523 "
524 export class SettingsPage {
525 /* This is a test setting */
526 constructor(page) {
527 this.page = page;
528 }
529 }"
530 .unindent(),
531 225,
532 ),
533 (
534 "
535 /* This is a test setting */
536 constructor(page) {
537 this.page = page;
538 }"
539 .unindent(),
540 290,
541 ),
542 (
543 "
544 /* This is a test comment */
545 class TestClass {}"
546 .unindent(),
547 374,
548 ),
549 (
550 "
551 /* Schema for editor_events in Clickhouse. */
552 export interface ClickhouseEditorEvent {
553 installation_id: string
554 operation: string
555 }"
556 .unindent(),
557 440,
558 ),
559 ],
560 )
561}
562
/// Checks how `CodeContextRetriever::parse_file` splits Lua source into
/// documents, using the language returned by `lua_lang()` (defined outside this
/// view).
///
/// Two documents are expected, each as `(content, hard-coded start offset)`:
/// - the outer `classes.class` function with its leading comments, in which the
///   body of the nested `classdef.alloc` function is replaced by `--[ ... ]--`
///   placeholders;
/// - the nested `classdef.alloc` function on its own, with its full body.
563#[gpui::test]
564async fn test_code_context_retrieval_lua() {
565 let language = lua_lang();
566 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
567 let mut retriever = CodeContextRetriever::new(embedding_provider);
568
569 let text = r#"
570 -- Creates a new class
571 -- @param baseclass The Baseclass of this class, or nil.
572 -- @return A new class reference.
573 function classes.class(baseclass)
574 -- Create the class definition and metatable.
575 local classdef = {}
576 -- Find the super class, either Object or user-defined.
577 baseclass = baseclass or classes.Object
578 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
579 setmetatable(classdef, { __index = baseclass })
580 -- All class instances have a reference to the class object.
581 classdef.class = classdef
582 --- Recursively allocates the inheritance tree of the instance.
583 -- @param mastertable The 'root' of the inheritance tree.
584 -- @return Returns the instance with the allocated inheritance tree.
585 function classdef.alloc(mastertable)
586 -- All class instances have a reference to a superclass object.
587 local instance = { super = baseclass.alloc(mastertable) }
588 -- Any functions this instance does not know of will 'look up' to the superclass definition.
589 setmetatable(instance, { __index = classdef, __newindex = mastertable })
590 return instance
591 end
592 end
593 "#.unindent();
594
595 let documents = retriever.parse_file(&text, language.clone()).unwrap();
596
597 assert_documents_eq(
598 &documents,
599 &[
600 (r#"
601 -- Creates a new class
602 -- @param baseclass The Baseclass of this class, or nil.
603 -- @return A new class reference.
604 function classes.class(baseclass)
605 -- Create the class definition and metatable.
606 local classdef = {}
607 -- Find the super class, either Object or user-defined.
608 baseclass = baseclass or classes.Object
609 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
610 setmetatable(classdef, { __index = baseclass })
611 -- All class instances have a reference to the class object.
612 classdef.class = classdef
613 --- Recursively allocates the inheritance tree of the instance.
614 -- @param mastertable The 'root' of the inheritance tree.
615 -- @return Returns the instance with the allocated inheritance tree.
616 function classdef.alloc(mastertable)
617 --[ ... ]--
618 --[ ... ]--
619 end
620 end"#.unindent(),
621 114),
622 (r#"
623 --- Recursively allocates the inheritance tree of the instance.
624 -- @param mastertable The 'root' of the inheritance tree.
625 -- @return Returns the instance with the allocated inheritance tree.
626 function classdef.alloc(mastertable)
627 -- All class instances have a reference to a superclass object.
628 local instance = { super = baseclass.alloc(mastertable) }
629 -- Any functions this instance does not know of will 'look up' to the superclass definition.
630 setmetatable(instance, { __index = classdef, __newindex = mastertable })
631 return instance
632 end"#.unindent(), 810),
633 ]
634 );
635}
636
/// Checks how `CodeContextRetriever::parse_file` splits Elixir source into
/// documents, using the language returned by `elixir_lang()` (defined outside
/// this view).
///
/// Two documents are expected, each as `(content, hard-coded start offset)`:
/// - the whole `File.Stream` module at offset 0, reproduced without any
///   collapsing;
/// - the `__build__` function together with its `@doc false` attribute.
637#[gpui::test]
638async fn test_code_context_retrieval_elixir() {
639 let language = elixir_lang();
640 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
641 let mut retriever = CodeContextRetriever::new(embedding_provider);
642
643 let text = r#"
644 defmodule File.Stream do
645 @moduledoc """
646 Defines a `File.Stream` struct returned by `File.stream!/3`.
647
648 The following fields are public:
649
650 * `path` - the file path
651 * `modes` - the file modes
652 * `raw` - a boolean indicating if bin functions should be used
653 * `line_or_bytes` - if reading should read lines or a given number of bytes
654 * `node` - the node the file belongs to
655
656 """
657
658 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
659
660 @type t :: %__MODULE__{}
661
662 @doc false
663 def __build__(path, modes, line_or_bytes) do
664 raw = :lists.keyfind(:encoding, 1, modes) == false
665
666 modes =
667 case raw do
668 true ->
669 case :lists.keyfind(:read_ahead, 1, modes) do
670 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
671 {:read_ahead, _} -> [:raw | modes]
672 false -> [:raw, :read_ahead | modes]
673 end
674
675 false ->
676 modes
677 end
678
679 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
680
681 end"#
682 .unindent();
683
684 let documents = retriever.parse_file(&text, language.clone()).unwrap();
685
686 assert_documents_eq(
687 &documents,
688 &[(
689 r#"
690 defmodule File.Stream do
691 @moduledoc """
692 Defines a `File.Stream` struct returned by `File.stream!/3`.
693
694 The following fields are public:
695
696 * `path` - the file path
697 * `modes` - the file modes
698 * `raw` - a boolean indicating if bin functions should be used
699 * `line_or_bytes` - if reading should read lines or a given number of bytes
700 * `node` - the node the file belongs to
701
702 """
703
704 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
705
706 @type t :: %__MODULE__{}
707
708 @doc false
709 def __build__(path, modes, line_or_bytes) do
710 raw = :lists.keyfind(:encoding, 1, modes) == false
711
712 modes =
713 case raw do
714 true ->
715 case :lists.keyfind(:read_ahead, 1, modes) do
716 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
717 {:read_ahead, _} -> [:raw | modes]
718 false -> [:raw, :read_ahead | modes]
719 end
720
721 false ->
722 modes
723 end
724
725 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
726
727 end"#
728 .unindent(),
729 0,
730 ),(r#"
731 @doc false
732 def __build__(path, modes, line_or_bytes) do
733 raw = :lists.keyfind(:encoding, 1, modes) == false
734
735 modes =
736 case raw do
737 true ->
738 case :lists.keyfind(:read_ahead, 1, modes) do
739 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
740 {:read_ahead, _} -> [:raw | modes]
741 false -> [:raw, :read_ahead | modes]
742 end
743
744 false ->
745 modes
746 end
747
748 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
749
750 end"#.unindent(), 574)],
751 );
752}
753
/// Checks how `CodeContextRetriever::parse_file` splits C++ source into
/// documents, using the language returned by `cpp_lang()` (defined outside this
/// view).
///
/// Expected entries are `(document content, hard-coded start offset)` and cover
/// a free function, a class, an enum, an anonymous struct declaration, a
/// templated class, and the constructor nested inside that templated class.
/// Preceding block and line comments are part of each expected document, except
/// for the nested constructor, which is expected without its comment and
/// template header.
754#[gpui::test]
755async fn test_code_context_retrieval_cpp() {
756 let language = cpp_lang();
757 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
758 let mut retriever = CodeContextRetriever::new(embedding_provider);
759
760 let text = "
761 /**
762 * @brief Main function
763 * @returns 0 on exit
764 */
765 int main() { return 0; }
766
767 /**
768 * This is a test comment
769 */
770 class MyClass { // The class
771 public: // Access specifier
772 int myNum; // Attribute (int variable)
773 string myString; // Attribute (string variable)
774 };
775
776 // This is a test comment
777 enum Color { red, green, blue };
778
779 /** This is a preceding block comment
780 * This is the second line
781 */
782 struct { // Structure declaration
783 int myNum; // Member (int variable)
784 string myString; // Member (string variable)
785 } myStructure;
786
787 /**
788 * @brief Matrix class.
789 */
790 template <typename T,
791 typename = typename std::enable_if<
792 std::is_integral<T>::value || std::is_floating_point<T>::value,
793 bool>::type>
794 class Matrix2 {
795 std::vector<std::vector<T>> _mat;
796
797 public:
798 /**
799 * @brief Constructor
800 * @tparam Integer ensuring integers are being evaluated and not other
801 * data types.
802 * @param size denoting the size of Matrix as size x size
803 */
804 template <typename Integer,
805 typename = typename std::enable_if<std::is_integral<Integer>::value,
806 Integer>::type>
807 explicit Matrix(const Integer size) {
808 for (size_t i = 0; i < size; ++i) {
809 _mat.emplace_back(std::vector<T>(size, 0));
810 }
811 }
812 }"
813 .unindent();
814
815 let documents = retriever.parse_file(&text, language.clone()).unwrap();
816
817 assert_documents_eq(
818 &documents,
819 &[
820 (
821 "
822 /**
823 * @brief Main function
824 * @returns 0 on exit
825 */
826 int main() { return 0; }"
827 .unindent(),
828 54,
829 ),
830 (
831 "
832 /**
833 * This is a test comment
834 */
835 class MyClass { // The class
836 public: // Access specifier
837 int myNum; // Attribute (int variable)
838 string myString; // Attribute (string variable)
839 }"
840 .unindent(),
841 112,
842 ),
843 (
844 "
845 // This is a test comment
846 enum Color { red, green, blue }"
847 .unindent(),
848 322,
849 ),
850 (
851 "
852 /** This is a preceding block comment
853 * This is the second line
854 */
855 struct { // Structure declaration
856 int myNum; // Member (int variable)
857 string myString; // Member (string variable)
858 } myStructure;"
859 .unindent(),
860 425,
861 ),
862 (
863 "
864 /**
865 * @brief Matrix class.
866 */
867 template <typename T,
868 typename = typename std::enable_if<
869 std::is_integral<T>::value || std::is_floating_point<T>::value,
870 bool>::type>
871 class Matrix2 {
872 std::vector<std::vector<T>> _mat;
873
874 public:
875 /**
876 * @brief Constructor
877 * @tparam Integer ensuring integers are being evaluated and not other
878 * data types.
879 * @param size denoting the size of Matrix as size x size
880 */
881 template <typename Integer,
882 typename = typename std::enable_if<std::is_integral<Integer>::value,
883 Integer>::type>
884 explicit Matrix(const Integer size) {
885 for (size_t i = 0; i < size; ++i) {
886 _mat.emplace_back(std::vector<T>(size, 0));
887 }
888 }
889 }"
890 .unindent(),
891 612,
892 ),
893 (
894 "
895 explicit Matrix(const Integer size) {
896 for (size_t i = 0; i < size; ++i) {
897 _mat.emplace_back(std::vector<T>(size, 0));
898 }
899 }"
900 .unindent(),
901 1226,
902 ),
903 ],
904 );
905}
906
/// Checks how `CodeContextRetriever::parse_file` splits Ruby source into
/// documents, using the language returned by `ruby_lang()` (defined outside this
/// view).
///
/// Expected entries are `(document content, hard-coded start offset)`:
/// - the `ChallengableConcern` module with its leading comment block, in which
///   each method body is replaced by `# ...`;
/// - each of the module's methods as its own document with the full body;
/// - the `Animal` class with method bodies replaced by `# ...`, followed by
///   each of its methods in full;
/// - the singleton method `car.wheels` with its preceding comment.
907#[gpui::test]
908async fn test_code_context_retrieval_ruby() {
909 let language = ruby_lang();
910 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
911 let mut retriever = CodeContextRetriever::new(embedding_provider);
912
913 let text = r#"
914 # This concern is inspired by "sudo mode" on GitHub. It
915 # is a way to re-authenticate a user before allowing them
916 # to see or perform an action.
917 #
918 # Add `before_action :require_challenge!` to actions you
919 # want to protect.
920 #
921 # The user will be shown a page to enter the challenge (which
922 # is either the password, or just the username when no
923 # password exists). Upon passing, there is a grace period
924 # during which no challenge will be asked from the user.
925 #
926 # Accessing challenge-protected resources during the grace
927 # period will refresh the grace period.
928 module ChallengableConcern
929 extend ActiveSupport::Concern
930
931 CHALLENGE_TIMEOUT = 1.hour.freeze
932
933 def require_challenge!
934 return if skip_challenge?
935
936 if challenge_passed_recently?
937 session[:challenge_passed_at] = Time.now.utc
938 return
939 end
940
941 @challenge = Form::Challenge.new(return_to: request.url)
942
943 if params.key?(:form_challenge)
944 if challenge_passed?
945 session[:challenge_passed_at] = Time.now.utc
946 else
947 flash.now[:alert] = I18n.t('challenge.invalid_password')
948 render_challenge
949 end
950 else
951 render_challenge
952 end
953 end
954
955 def challenge_passed?
956 current_user.valid_password?(challenge_params[:current_password])
957 end
958 end
959
960 class Animal
961 include Comparable
962
963 attr_reader :legs
964
965 def initialize(name, legs)
966 @name, @legs = name, legs
967 end
968
969 def <=>(other)
970 legs <=> other.legs
971 end
972 end
973
974 # Singleton method for car object
975 def car.wheels
976 puts "There are four wheels"
977 end"#
978 .unindent();
979
980 let documents = retriever.parse_file(&text, language.clone()).unwrap();
981
982 assert_documents_eq(
983 &documents,
984 &[
985 (
986 r#"
987 # This concern is inspired by "sudo mode" on GitHub. It
988 # is a way to re-authenticate a user before allowing them
989 # to see or perform an action.
990 #
991 # Add `before_action :require_challenge!` to actions you
992 # want to protect.
993 #
994 # The user will be shown a page to enter the challenge (which
995 # is either the password, or just the username when no
996 # password exists). Upon passing, there is a grace period
997 # during which no challenge will be asked from the user.
998 #
999 # Accessing challenge-protected resources during the grace
1000 # period will refresh the grace period.
1001 module ChallengableConcern
1002 extend ActiveSupport::Concern
1003
1004 CHALLENGE_TIMEOUT = 1.hour.freeze
1005
1006 def require_challenge!
1007 # ...
1008 end
1009
1010 def challenge_passed?
1011 # ...
1012 end
1013 end"#
1014 .unindent(),
1015 558,
1016 ),
1017 (
1018 r#"
1019 def require_challenge!
1020 return if skip_challenge?
1021
1022 if challenge_passed_recently?
1023 session[:challenge_passed_at] = Time.now.utc
1024 return
1025 end
1026
1027 @challenge = Form::Challenge.new(return_to: request.url)
1028
1029 if params.key?(:form_challenge)
1030 if challenge_passed?
1031 session[:challenge_passed_at] = Time.now.utc
1032 else
1033 flash.now[:alert] = I18n.t('challenge.invalid_password')
1034 render_challenge
1035 end
1036 else
1037 render_challenge
1038 end
1039 end"#
1040 .unindent(),
1041 663,
1042 ),
1043 (
1044 r#"
1045 def challenge_passed?
1046 current_user.valid_password?(challenge_params[:current_password])
1047 end"#
1048 .unindent(),
1049 1254,
1050 ),
1051 (
1052 r#"
1053 class Animal
1054 include Comparable
1055
1056 attr_reader :legs
1057
1058 def initialize(name, legs)
1059 # ...
1060 end
1061
1062 def <=>(other)
1063 # ...
1064 end
1065 end"#
1066 .unindent(),
1067 1363,
1068 ),
1069 (
1070 r#"
1071 def initialize(name, legs)
1072 @name, @legs = name, legs
1073 end"#
1074 .unindent(),
1075 1427,
1076 ),
1077 (
1078 r#"
1079 def <=>(other)
1080 legs <=> other.legs
1081 end"#
1082 .unindent(),
1083 1501,
1084 ),
1085 (
1086 r#"
1087 # Singleton method for car object
1088 def car.wheels
1089 puts "There are four wheels"
1090 end"#
1091 .unindent(),
1092 1591,
1093 ),
1094 ],
1095 );
1096}
1097
/// Checks how `CodeContextRetriever::parse_file` splits PHP source into
/// documents, using the language returned by `php_lang()` (defined outside this
/// view).
///
/// Expected entries are `(document content, hard-coded start offset)`:
/// - a free function with its preceding block comment;
/// - the `HasAchievements` trait, in which each method body is replaced by
///   `{/* ... */}`;
/// - each of the trait's methods as its own document with the full body;
/// - the `Multiplier` interface and the `AuditType` enum, both unmodified.
1098#[gpui::test]
1099async fn test_code_context_retrieval_php() {
1100 let language = php_lang();
1101 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1102 let mut retriever = CodeContextRetriever::new(embedding_provider);
1103
1104 let text = r#"
1105 <?php
1106
1107 namespace LevelUp\Experience\Concerns;
1108
1109 /*
1110 This is a multiple-lines comment block
1111 that spans over multiple
1112 lines
1113 */
1114 function functionName() {
1115 echo "Hello world!";
1116 }
1117
1118 trait HasAchievements
1119 {
1120 /**
1121 * @throws \Exception
1122 */
1123 public function grantAchievement(Achievement $achievement, $progress = null): void
1124 {
1125 if ($progress > 100) {
1126 throw new Exception(message: 'Progress cannot be greater than 100');
1127 }
1128
1129 if ($this->achievements()->find($achievement->id)) {
1130 throw new Exception(message: 'User already has this Achievement');
1131 }
1132
1133 $this->achievements()->attach($achievement, [
1134 'progress' => $progress ?? null,
1135 ]);
1136
1137 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1138 }
1139
1140 public function achievements(): BelongsToMany
1141 {
1142 return $this->belongsToMany(related: Achievement::class)
1143 ->withPivot(columns: 'progress')
1144 ->where('is_secret', false)
1145 ->using(AchievementUser::class);
1146 }
1147 }
1148
1149 interface Multiplier
1150 {
1151 public function qualifies(array $data): bool;
1152
1153 public function setMultiplier(): int;
1154 }
1155
1156 enum AuditType: string
1157 {
1158 case Add = 'add';
1159 case Remove = 'remove';
1160 case Reset = 'reset';
1161 case LevelUp = 'level_up';
1162 }
1163
1164 ?>"#
1165 .unindent();
1166
1167 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1168
1169 assert_documents_eq(
1170 &documents,
1171 &[
1172 (
1173 r#"
1174 /*
1175 This is a multiple-lines comment block
1176 that spans over multiple
1177 lines
1178 */
1179 function functionName() {
1180 echo "Hello world!";
1181 }"#
1182 .unindent(),
1183 123,
1184 ),
1185 (
1186 r#"
1187 trait HasAchievements
1188 {
1189 /**
1190 * @throws \Exception
1191 */
1192 public function grantAchievement(Achievement $achievement, $progress = null): void
1193 {/* ... */}
1194
1195 public function achievements(): BelongsToMany
1196 {/* ... */}
1197 }"#
1198 .unindent(),
1199 177,
1200 ),
1201 (r#"
1202 /**
1203 * @throws \Exception
1204 */
1205 public function grantAchievement(Achievement $achievement, $progress = null): void
1206 {
1207 if ($progress > 100) {
1208 throw new Exception(message: 'Progress cannot be greater than 100');
1209 }
1210
1211 if ($this->achievements()->find($achievement->id)) {
1212 throw new Exception(message: 'User already has this Achievement');
1213 }
1214
1215 $this->achievements()->attach($achievement, [
1216 'progress' => $progress ?? null,
1217 ]);
1218
1219 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1220 }"#.unindent(), 245),
1221 (r#"
1222 public function achievements(): BelongsToMany
1223 {
1224 return $this->belongsToMany(related: Achievement::class)
1225 ->withPivot(columns: 'progress')
1226 ->where('is_secret', false)
1227 ->using(AchievementUser::class);
1228 }"#.unindent(), 902),
1229 (r#"
1230 interface Multiplier
1231 {
1232 public function qualifies(array $data): bool;
1233
1234 public function setMultiplier(): int;
1235 }"#.unindent(),
1236 1146),
1237 (r#"
1238 enum AuditType: string
1239 {
1240 case Add = 'add';
1241 case Remove = 'remove';
1242 case Reset = 'reset';
1243 case LevelUp = 'level_up';
1244 }"#.unindent(), 1265)
1245 ],
1246 );
1247}
1248
/// Builds the test `Language` named "Javascript", matched by the `js` path
/// suffix and backed by the TSX tree-sitter grammar.
///
/// The attached embedding query has five patterns. Each captures any run of
/// comments directly before the item as `@context` and marks the item itself as
/// `@item`:
/// - function declarations, optionally `async`, plain or inside an export,
/// - class declarations, plain or exported,
/// - interface declarations, plain or exported,
/// - enum declarations, plain or exported,
/// - method definitions, with any `get`/`set`/`async`/`*`/`static` modifiers.
///
/// Panics if the query fails to compile against the grammar.
1249fn js_lang() -> Arc<Language> {
1250 Arc::new(
1251 Language::new(
1252 LanguageConfig {
1253 name: "Javascript".into(),
1254 matcher: LanguageMatcher {
1255 path_suffixes: vec!["js".into()],
1256 ..Default::default()
1257 },
1258 ..Default::default()
1259 },
1260 Some(tree_sitter_typescript::language_tsx()),
1261 )
1262 .with_embedding_query(
1263 &r#"
1264
1265 (
1266 (comment)* @context
1267 .
1268 [
1269 (export_statement
1270 (function_declaration
1271 "async"? @name
1272 "function" @name
1273 name: (_) @name))
1274 (function_declaration
1275 "async"? @name
1276 "function" @name
1277 name: (_) @name)
1278 ] @item
1279 )
1280
1281 (
1282 (comment)* @context
1283 .
1284 [
1285 (export_statement
1286 (class_declaration
1287 "class" @name
1288 name: (_) @name))
1289 (class_declaration
1290 "class" @name
1291 name: (_) @name)
1292 ] @item
1293 )
1294
1295 (
1296 (comment)* @context
1297 .
1298 [
1299 (export_statement
1300 (interface_declaration
1301 "interface" @name
1302 name: (_) @name))
1303 (interface_declaration
1304 "interface" @name
1305 name: (_) @name)
1306 ] @item
1307 )
1308
1309 (
1310 (comment)* @context
1311 .
1312 [
1313 (export_statement
1314 (enum_declaration
1315 "enum" @name
1316 name: (_) @name))
1317 (enum_declaration
1318 "enum" @name
1319 name: (_) @name)
1320 ] @item
1321 )
1322
1323 (
1324 (comment)* @context
1325 .
1326 (method_definition
1327 [
1328 "get"
1329 "set"
1330 "async"
1331 "*"
1332 "static"
1333 ]* @name
1334 name: (_) @name) @item
1335 )
1336
1337 "#
1338 .unindent(),
1339 )
1340 .unwrap(),
1341 )
1342}
1343
1344fn rust_lang() -> Arc<Language> {
1345 Arc::new(
1346 Language::new(
1347 LanguageConfig {
1348 name: "Rust".into(),
1349 matcher: LanguageMatcher {
1350 path_suffixes: vec!["rs".into()],
1351 ..Default::default()
1352 },
1353 collapsed_placeholder: " /* ... */ ".to_string(),
1354 ..Default::default()
1355 },
1356 Some(tree_sitter_rust::language()),
1357 )
1358 .with_embedding_query(
1359 r#"
1360 (
1361 [(line_comment) (attribute_item)]* @context
1362 .
1363 [
1364 (struct_item
1365 name: (_) @name)
1366
1367 (enum_item
1368 name: (_) @name)
1369
1370 (impl_item
1371 trait: (_)? @name
1372 "for"? @name
1373 type: (_) @name)
1374
1375 (trait_item
1376 name: (_) @name)
1377
1378 (function_item
1379 name: (_) @name
1380 body: (block
1381 "{" @keep
1382 "}" @keep) @collapse)
1383
1384 (macro_definition
1385 name: (_) @name)
1386 ] @item
1387 )
1388
1389 (attribute_item) @collapse
1390 (use_declaration) @collapse
1391 "#,
1392 )
1393 .unwrap(),
1394 )
1395}
1396
1397fn json_lang() -> Arc<Language> {
1398 Arc::new(
1399 Language::new(
1400 LanguageConfig {
1401 name: "JSON".into(),
1402 matcher: LanguageMatcher {
1403 path_suffixes: vec!["json".into()],
1404 ..Default::default()
1405 },
1406 ..Default::default()
1407 },
1408 Some(tree_sitter_json::language()),
1409 )
1410 .with_embedding_query(
1411 r#"
1412 (document) @item
1413
1414 (array
1415 "[" @keep
1416 .
1417 (object)? @keep
1418 "]" @keep) @collapse
1419
1420 (pair value: (string
1421 "\"" @keep
1422 "\"" @keep) @collapse)
1423 "#,
1424 )
1425 .unwrap(),
1426 )
1427}
1428
1429fn toml_lang() -> Arc<Language> {
1430 Arc::new(Language::new(
1431 LanguageConfig {
1432 name: "TOML".into(),
1433 matcher: LanguageMatcher {
1434 path_suffixes: vec!["toml".into()],
1435 ..Default::default()
1436 },
1437 ..Default::default()
1438 },
1439 Some(tree_sitter_toml::language()),
1440 ))
1441}
1442
1443fn cpp_lang() -> Arc<Language> {
1444 Arc::new(
1445 Language::new(
1446 LanguageConfig {
1447 name: "CPP".into(),
1448 matcher: LanguageMatcher {
1449 path_suffixes: vec!["cpp".into()],
1450 ..Default::default()
1451 },
1452 ..Default::default()
1453 },
1454 Some(tree_sitter_cpp::language()),
1455 )
1456 .with_embedding_query(
1457 r#"
1458 (
1459 (comment)* @context
1460 .
1461 (function_definition
1462 (type_qualifier)? @name
1463 type: (_)? @name
1464 declarator: [
1465 (function_declarator
1466 declarator: (_) @name)
1467 (pointer_declarator
1468 "*" @name
1469 declarator: (function_declarator
1470 declarator: (_) @name))
1471 (pointer_declarator
1472 "*" @name
1473 declarator: (pointer_declarator
1474 "*" @name
1475 declarator: (function_declarator
1476 declarator: (_) @name)))
1477 (reference_declarator
1478 ["&" "&&"] @name
1479 (function_declarator
1480 declarator: (_) @name))
1481 ]
1482 (type_qualifier)? @name) @item
1483 )
1484
1485 (
1486 (comment)* @context
1487 .
1488 (template_declaration
1489 (class_specifier
1490 "class" @name
1491 name: (_) @name)
1492 ) @item
1493 )
1494
1495 (
1496 (comment)* @context
1497 .
1498 (class_specifier
1499 "class" @name
1500 name: (_) @name) @item
1501 )
1502
1503 (
1504 (comment)* @context
1505 .
1506 (enum_specifier
1507 "enum" @name
1508 name: (_) @name) @item
1509 )
1510
1511 (
1512 (comment)* @context
1513 .
1514 (declaration
1515 type: (struct_specifier
1516 "struct" @name)
1517 declarator: (_) @name) @item
1518 )
1519
1520 "#,
1521 )
1522 .unwrap(),
1523 )
1524}
1525
1526fn lua_lang() -> Arc<Language> {
1527 Arc::new(
1528 Language::new(
1529 LanguageConfig {
1530 name: "Lua".into(),
1531 matcher: LanguageMatcher {
1532 path_suffixes: vec!["lua".into()],
1533 ..Default::default()
1534 },
1535 collapsed_placeholder: "--[ ... ]--".to_string(),
1536 ..Default::default()
1537 },
1538 Some(tree_sitter_lua::language()),
1539 )
1540 .with_embedding_query(
1541 r#"
1542 (
1543 (comment)* @context
1544 .
1545 (function_declaration
1546 "function" @name
1547 name: (_) @name
1548 (comment)* @collapse
1549 body: (block) @collapse
1550 ) @item
1551 )
1552 "#,
1553 )
1554 .unwrap(),
1555 )
1556}
1557
1558fn php_lang() -> Arc<Language> {
1559 Arc::new(
1560 Language::new(
1561 LanguageConfig {
1562 name: "PHP".into(),
1563 matcher: LanguageMatcher {
1564 path_suffixes: vec!["php".into()],
1565 ..Default::default()
1566 },
1567 collapsed_placeholder: "/* ... */".into(),
1568 ..Default::default()
1569 },
1570 Some(tree_sitter_php::language_php()),
1571 )
1572 .with_embedding_query(
1573 r#"
1574 (
1575 (comment)* @context
1576 .
1577 [
1578 (function_definition
1579 "function" @name
1580 name: (_) @name
1581 body: (_
1582 "{" @keep
1583 "}" @keep) @collapse
1584 )
1585
1586 (trait_declaration
1587 "trait" @name
1588 name: (_) @name)
1589
1590 (method_declaration
1591 "function" @name
1592 name: (_) @name
1593 body: (_
1594 "{" @keep
1595 "}" @keep) @collapse
1596 )
1597
1598 (interface_declaration
1599 "interface" @name
1600 name: (_) @name
1601 )
1602
1603 (enum_declaration
1604 "enum" @name
1605 name: (_) @name
1606 )
1607
1608 ] @item
1609 )
1610 "#,
1611 )
1612 .unwrap(),
1613 )
1614}
1615
1616fn ruby_lang() -> Arc<Language> {
1617 Arc::new(
1618 Language::new(
1619 LanguageConfig {
1620 name: "Ruby".into(),
1621 matcher: LanguageMatcher {
1622 path_suffixes: vec!["rb".into()],
1623 ..Default::default()
1624 },
1625 collapsed_placeholder: "# ...".to_string(),
1626 ..Default::default()
1627 },
1628 Some(tree_sitter_ruby::language()),
1629 )
1630 .with_embedding_query(
1631 r#"
1632 (
1633 (comment)* @context
1634 .
1635 [
1636 (module
1637 "module" @name
1638 name: (_) @name)
1639 (method
1640 "def" @name
1641 name: (_) @name
1642 body: (body_statement) @collapse)
1643 (class
1644 "class" @name
1645 name: (_) @name)
1646 (singleton_method
1647 "def" @name
1648 object: (_) @name
1649 "." @name
1650 name: (_) @name
1651 body: (body_statement) @collapse)
1652 ] @item
1653 )
1654 "#,
1655 )
1656 .unwrap(),
1657 )
1658}
1659
1660fn elixir_lang() -> Arc<Language> {
1661 Arc::new(
1662 Language::new(
1663 LanguageConfig {
1664 name: "Elixir".into(),
1665 matcher: LanguageMatcher {
1666 path_suffixes: vec!["rs".into()],
1667 ..Default::default()
1668 },
1669 ..Default::default()
1670 },
1671 Some(tree_sitter_elixir::language()),
1672 )
1673 .with_embedding_query(
1674 r#"
1675 (
1676 (unary_operator
1677 operator: "@"
1678 operand: (call
1679 target: (identifier) @unary
1680 (#match? @unary "^(doc)$"))
1681 ) @context
1682 .
1683 (call
1684 target: (identifier) @name
1685 (arguments
1686 [
1687 (identifier) @name
1688 (call
1689 target: (identifier) @name)
1690 (binary_operator
1691 left: (call
1692 target: (identifier) @name)
1693 operator: "when")
1694 ])
1695 (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1696 )
1697
1698 (call
1699 target: (identifier) @name
1700 (arguments (alias) @name)
1701 (#any-match? @name "^(defmodule|defprotocol)$")) @item
1702 "#,
1703 )
1704 .unwrap(),
1705 )
1706}
1707
1708#[gpui::test]
1709fn test_subtract_ranges() {
1710 assert_eq!(
1711 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1712 vec![1..4, 10..21]
1713 );
1714
1715 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1716}
1717
1718fn init_test(cx: &mut TestAppContext) {
1719 cx.update(|cx| {
1720 let settings_store = SettingsStore::test(cx);
1721 cx.set_global(settings_store);
1722 SemanticIndexSettings::register(cx);
1723 ProjectSettings::register(cx);
1724 });
1725}