1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8
9use gpui::{executor::Deterministic, Task, TestAppContext};
10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
11use parking_lot::Mutex;
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::SettingsStore;
17use std::{path::Path, sync::Arc, time::SystemTime};
18use unindent::Unindent;
19use util::{paths::PathMatcher, RandomCharIter};
20
21#[ctor::ctor]
22fn init_logger() {
23 if std::env::var("RUST_LOG").is_ok() {
24 env_logger::init();
25 }
26}
27
/// End-to-end test of `SemanticIndex` over a small fake project.
///
/// Searching the project queues all three files for embedding. The test then
/// checks the ranked results, checks include/exclude path filters, and
/// verifies that re-indexing after editing one function embeds exactly one
/// new span.
#[gpui::test]
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaaaaaaaaaa!\");
                    }

                    fn zzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbbbbbbbbbbb!\");
                    }
                    struct pqpqpqp {}
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZZZZZZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    // The index database lives in a temp dir that is removed when `db_dir` drops.
    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let semantic_index = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

    let search_results = semantic_index.update(cx, |store, cx| {
        store.search_project(
            project.clone(),
            "aaaaaabbbbzz".to_string(),
            5,
            vec![],
            vec![],
            cx,
        )
    });
    // All three files are pending once the executor parks. They are drained
    // only after the embedding queue's flush timeout elapses.
    let pending_file_count =
        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
    deterministic.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 3);
    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);

    let search_results = search_results.await.unwrap();
    assert_search_results(
        &search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file3.toml").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    // Test the include/exclude file filters: including `*.rs` keeps only the
    // Rust results, and excluding `*.rs` leaves only the TOML result.
    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
    let rust_only_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                include_files,
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &rust_only_search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    let no_rust_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                exclude_files,
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &no_rust_search_results,
        &[(Path::new("src/file3.toml").into(), 0)],
        cx,
    );

    // Replace `fn bbb` with `fn dddd` while leaving `struct pqpqpqp` untouched.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();

    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

    // Re-indexing should queue only the edited file and embed exactly one span.
    let prev_embedding_count = embedding_provider.embedding_count();
    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
    deterministic.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 1);
    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);
    index.await.unwrap();

    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        1
    );
}
188
189#[gpui::test(iterations = 10)]
190async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
191 let (outstanding_job_count, _) = postage::watch::channel_with(0);
192 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
193
194 let files = (1..=3)
195 .map(|file_ix| FileToEmbed {
196 worktree_id: 5,
197 path: Path::new(&format!("path-{file_ix}")).into(),
198 mtime: SystemTime::now(),
199 spans: (0..rng.gen_range(4..22))
200 .map(|document_ix| {
201 let content_len = rng.gen_range(10..100);
202 let content = RandomCharIter::new(&mut rng)
203 .with_simple_text()
204 .take(content_len)
205 .collect::<String>();
206 let digest = SpanDigest::from(content.as_str());
207 Span {
208 range: 0..10,
209 embedding: None,
210 name: format!("document {document_ix}"),
211 content,
212 digest,
213 token_count: rng.gen_range(10..30),
214 }
215 })
216 .collect(),
217 job_handle: JobHandle::new(&outstanding_job_count),
218 })
219 .collect::<Vec<_>>();
220
221 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
222
223 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
224 for file in &files {
225 queue.push(file.clone());
226 }
227 queue.flush();
228
229 cx.foreground().run_until_parked();
230 let finished_files = queue.finished_files();
231 let mut embedded_files: Vec<_> = files
232 .iter()
233 .map(|_| finished_files.try_recv().expect("no finished file"))
234 .collect();
235
236 let expected_files: Vec<_> = files
237 .iter()
238 .map(|file| {
239 let mut file = file.clone();
240 for doc in &mut file.spans {
241 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
242 }
243 file
244 })
245 .collect();
246
247 embedded_files.sort_by_key(|f| f.path.clone());
248
249 assert_eq!(embedded_files, expected_files);
250}
251
252#[track_caller]
253fn assert_search_results(
254 actual: &[SearchResult],
255 expected: &[(Arc<Path>, usize)],
256 cx: &TestAppContext,
257) {
258 let actual = actual
259 .iter()
260 .map(|search_result| {
261 search_result.buffer.read_with(cx, |buffer, _cx| {
262 (
263 buffer.file().unwrap().path().clone(),
264 search_result.range.start.to_offset(buffer),
265 )
266 })
267 })
268 .collect::<Vec<_>>();
269 assert_eq!(actual, expected);
270}
271
/// Checks the spans `CodeContextRetriever` extracts from a Rust snippet.
///
/// Each item's span content includes its preceding doc/line comments and
/// attributes. Inside an `impl`, method bodies are collapsed to the
/// `" /* ... */ "` placeholder. Each method is also emitted as its own span.
/// A span's expected offset is where the item itself starts (see the
/// `text.find(...)` calls), not where its leading context starts.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /// A doc comment
        /// that spans multiple lines
        #[gpui::test]
        fn a() {
            b
        }

        impl C for D {
        }

        impl E {
            // This is also a preceding comment
            pub fn function_1() -> Option<()> {
                unimplemented!();
            }

            // This is a preceding comment
            fn function_2() -> Result<()> {
                unimplemented!();
            }
        }

        #[derive(Clone)]
        struct D {
            name: String
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /// A doc comment
                /// that spans multiple lines
                #[gpui::test]
                fn a() {
                    b
                }"
                .unindent(),
                text.find("fn a").unwrap(),
            ),
            (
                "
                impl C for D {
                }"
                .unindent(),
                text.find("impl C").unwrap(),
            ),
            (
                "
                impl E {
                    // This is also a preceding comment
                    pub fn function_1() -> Option<()> { /* ... */ }

                    // This is a preceding comment
                    fn function_2() -> Result<()> { /* ... */ }
                }"
                .unindent(),
                text.find("impl E").unwrap(),
            ),
            (
                "
                // This is also a preceding comment
                pub fn function_1() -> Option<()> {
                    unimplemented!();
                }"
                .unindent(),
                text.find("pub fn function_1").unwrap(),
            ),
            (
                "
                // This is a preceding comment
                fn function_2() -> Result<()> {
                    unimplemented!();
                }"
                .unindent(),
                text.find("fn function_2").unwrap(),
            ),
            (
                "
                #[derive(Clone)]
                struct D {
                    name: String
                }"
                .unindent(),
                text.find("struct D").unwrap(),
            ),
        ],
    );
}
373
/// Checks JSON span extraction: the whole document is a single span.
///
/// String values collapse to `""`. An array collapses to its brackets plus,
/// when present, its first object (e.g. `[1, 2, 3, 4]` becomes `[]`).
#[gpui::test]
async fn test_code_context_retrieval_json() {
    let language = json_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        {
            "array": [1, 2, 3, 4],
            "string": "abcdefg",
            "nested_object": {
                "array_2": [5, 6, 7, 8],
                "string_2": "hijklmnop",
                "boolean": true,
                "none": null
            }
        }
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            {
                "array": [],
                "string": "",
                "nested_object": {
                    "array_2": [],
                    "string_2": "",
                    "boolean": true,
                    "none": null
                }
            }"#
            .unindent(),
            text.find("{").unwrap(),
        )],
    );

    let text = r#"
        [
            {
                "name": "somebody",
                "age": 42
            },
            {
                "name": "somebody else",
                "age": 43
            }
        ]
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Only the first object survives the array collapse. Its lines keep their
    // original indentation, which is why the expected literal's `[{` sits
    // left of the object body.
    assert_documents_eq(
        &documents,
        &[(
            r#"
            [{
                    "name": "",
                    "age": 42
                }]"#
            .unindent(),
            text.find("[").unwrap(),
        )],
    );
}
444
445fn assert_documents_eq(
446 documents: &[Span],
447 expected_contents_and_start_offsets: &[(String, usize)],
448) {
449 assert_eq!(
450 documents
451 .iter()
452 .map(|document| (document.content.clone(), document.range.start))
453 .collect::<Vec<_>>(),
454 expected_contents_and_start_offsets
455 );
456}
457
/// Checks JavaScript span extraction (TSX grammar): plain and exported
/// functions, classes, methods and interfaces, each with leading comments
/// attached.
///
/// The numeric offsets are byte positions in the unindented `text` where each
/// item itself starts, after its leading comment. For example, 37 is the start
/// of `function _authorize`, right after the 37-byte comment line.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /* globals importScripts, backend */
        function _authorize() {}

        /**
         * Sometimes the frontend build is way faster than backend.
         */
        export async function authorizeBank() {
            _authorize(pushModal, upgradingAccountId, {});
        }

        export class SettingsPage {
            /* This is a test setting */
            constructor(page) {
                this.page = page;
            }
        }

        /* This is a test comment */
        class TestClass {}

        /* Schema for editor_events in Clickhouse. */
        export interface ClickhouseEditorEvent {
            installation_id: string
            operation: string
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /* globals importScripts, backend */
                function _authorize() {}"
                    .unindent(),
                37,
            ),
            (
                "
                /**
                 * Sometimes the frontend build is way faster than backend.
                 */
                export async function authorizeBank() {
                    _authorize(pushModal, upgradingAccountId, {});
                }"
                .unindent(),
                131,
            ),
            (
                "
                export class SettingsPage {
                    /* This is a test setting */
                    constructor(page) {
                        this.page = page;
                    }
                }"
                .unindent(),
                225,
            ),
            (
                "
                /* This is a test setting */
                constructor(page) {
                    this.page = page;
                }"
                .unindent(),
                290,
            ),
            (
                "
                /* This is a test comment */
                class TestClass {}"
                    .unindent(),
                374,
            ),
            (
                "
                /* Schema for editor_events in Clickhouse. */
                export interface ClickhouseEditorEvent {
                    installation_id: string
                    operation: string
                }"
                .unindent(),
                440,
            ),
        ],
    )
}
556
/// Checks Lua span extraction for a function that defines a nested function.
///
/// In the outer function's span, the nested function's body is replaced by
/// `--[ ... ]--` placeholders. The nested function is also emitted as its own
/// span. Offsets (114, 809) are where each `function` keyword starts in the
/// unindented `text`, after the leading `--` comments.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
    let language = lua_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        -- Creates a new class
        -- @param baseclass The Baseclass of this class, or nil.
        -- @return A new class reference.
        function classes.class(baseclass)
            -- Create the class definition and metatable.
            local classdef = {}
            -- Find the super class, either Object or user-defined.
            baseclass = baseclass or classes.Object
            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
            setmetatable(classdef, { __index = baseclass })
            -- All class instances have a reference to the class object.
            classdef.class = classdef
            --- Recursivly allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end
        end
        "#.unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                -- Creates a new class
                -- @param baseclass The Baseclass of this class, or nil.
                -- @return A new class reference.
                function classes.class(baseclass)
                    -- Create the class definition and metatable.
                    local classdef = {}
                    -- Find the super class, either Object or user-defined.
                    baseclass = baseclass or classes.Object
                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
                    setmetatable(classdef, { __index = baseclass })
                    -- All class instances have a reference to the class object.
                    classdef.class = classdef
                    --- Recursivly allocates the inheritance tree of the instance.
                    -- @param mastertable The 'root' of the inheritance tree.
                    -- @return Returns the instance with the allocated inheritance tree.
                    function classdef.alloc(mastertable)
                        --[ ... ]--
                        --[ ... ]--
                    end
                end"#
                    .unindent(),
                114,
            ),
            (
                r#"
                --- Recursivly allocates the inheritance tree of the instance.
                -- @param mastertable The 'root' of the inheritance tree.
                -- @return Returns the instance with the allocated inheritance tree.
                function classdef.alloc(mastertable)
                    -- All class instances have a reference to a superclass object.
                    local instance = { super = baseclass.alloc(mastertable) }
                    -- Any functions this instance does not know of will 'look up' to the superclass definition.
                    setmetatable(instance, { __index = classdef, __newindex = mastertable })
                    return instance
                end"#
                    .unindent(),
                809,
            ),
        ],
    );
}
630
/// Checks Elixir span extraction.
///
/// The whole `defmodule` (offset 0) is one span. The `@doc false`-annotated
/// `def __build__` is a second span, with its `@doc` attribute included as
/// leading context.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
    let language = elixir_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        defmodule File.Stream do
          @moduledoc """
          Defines a `File.Stream` struct returned by `File.stream!/3`.

          The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

          """

          defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

          @type t :: %__MODULE__{}

          @doc false
          def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
              case raw do
                true ->
                  case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                  end

                false ->
                  modes
              end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

          end"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                defmodule File.Stream do
                  @moduledoc """
                  Defines a `File.Stream` struct returned by `File.stream!/3`.

                  The following fields are public:

                    * `path`          - the file path
                    * `modes`         - the file modes
                    * `raw`           - a boolean indicating if bin functions should be used
                    * `line_or_bytes` - if reading should read lines or a given number of bytes
                    * `node`          - the node the file belongs to

                  """

                  defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

                  @type t :: %__MODULE__{}

                  @doc false
                  def __build__(path, modes, line_or_bytes) do
                    raw = :lists.keyfind(:encoding, 1, modes) == false

                    modes =
                      case raw do
                        true ->
                          case :lists.keyfind(:read_ahead, 1, modes) do
                            {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                            {:read_ahead, _} -> [:raw | modes]
                            false -> [:raw, :read_ahead | modes]
                          end

                        false ->
                          modes
                      end

                    %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

                  end"#
                    .unindent(),
                0,
            ),
            (
                r#"
                @doc false
                def __build__(path, modes, line_or_bytes) do
                  raw = :lists.keyfind(:encoding, 1, modes) == false

                  modes =
                    case raw do
                      true ->
                        case :lists.keyfind(:read_ahead, 1, modes) do
                          {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                          {:read_ahead, _} -> [:raw | modes]
                          false -> [:raw, :read_ahead | modes]
                        end

                      false ->
                        modes
                    end

                  %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

                end"#
                    .unindent(),
                574,
            ),
        ],
    );
}
747
/// Checks C++ span extraction: functions, classes, enums, anonymous structs
/// and templated classes, each with its preceding comment as context.
///
/// A span for a class or enum ends at its closing brace, so the trailing `;`
/// is not included. A constructor inside a templated class is also emitted as
/// its own span. Offsets are where each item starts in the unindented `text`,
/// after its leading comment. The whitespace inside `text`, including the
/// aligned trailing comments, is part of what those offsets count.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
    /**
     * @brief Main function
     * @returns 0 on exit
     */
    int main() { return 0; }

    /**
    * This is a test comment
    */
    class MyClass {       // The class
      public:             // Access specifier
        int myNum;        // Attribute (int variable)
        string myString;  // Attribute (string variable)
    };

    // This is a test comment
    enum Color { red, green, blue };

    /** This is a preceding block comment
     * This is the second line
     */
    struct {             // Structure declaration
      int myNum;         // Member (int variable)
      string myString;   // Member (string variable)
    } myStructure;

    /**
    * @brief Matrix class.
    */
    template <typename T,
              typename = typename std::enable_if<
                  std::is_integral<T>::value || std::is_floating_point<T>::value,
                  bool>::type>
    class Matrix2 {
        std::vector<std::vector<T>> _mat;

     public:
        /**
        * @brief Constructor
        * @tparam Integer ensuring integers are being evaluated and not other
        * data types.
        * @param size denoting the size of Matrix as size x size
        */
        template <typename Integer,
                  typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                     Integer>::type>
        explicit Matrix(const Integer size) {
            for (size_t i = 0; i < size; ++i) {
                _mat.emplace_back(std::vector<T>(size, 0));
            }
        }
    }"
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /**
                 * @brief Main function
                 * @returns 0 on exit
                 */
                int main() { return 0; }"
                    .unindent(),
                54,
            ),
            (
                "
                /**
                * This is a test comment
                */
                class MyClass {       // The class
                  public:             // Access specifier
                    int myNum;        // Attribute (int variable)
                    string myString;  // Attribute (string variable)
                }"
                .unindent(),
                112,
            ),
            (
                "
                // This is a test comment
                enum Color { red, green, blue }"
                    .unindent(),
                322,
            ),
            (
                "
                /** This is a preceding block comment
                 * This is the second line
                 */
                struct {             // Structure declaration
                  int myNum;         // Member (int variable)
                  string myString;   // Member (string variable)
                } myStructure;"
                    .unindent(),
                425,
            ),
            (
                "
                /**
                * @brief Matrix class.
                */
                template <typename T,
                          typename = typename std::enable_if<
                              std::is_integral<T>::value || std::is_floating_point<T>::value,
                              bool>::type>
                class Matrix2 {
                    std::vector<std::vector<T>> _mat;

                 public:
                    /**
                    * @brief Constructor
                    * @tparam Integer ensuring integers are being evaluated and not other
                    * data types.
                    * @param size denoting the size of Matrix as size x size
                    */
                    template <typename Integer,
                              typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                                 Integer>::type>
                    explicit Matrix(const Integer size) {
                        for (size_t i = 0; i < size; ++i) {
                            _mat.emplace_back(std::vector<T>(size, 0));
                        }
                    }
                }"
                .unindent(),
                612,
            ),
            (
                "
                explicit Matrix(const Integer size) {
                    for (size_t i = 0; i < size; ++i) {
                        _mat.emplace_back(std::vector<T>(size, 0));
                    }
                }"
                .unindent(),
                1226,
            ),
        ],
    );
}
900
/// Checks Ruby span extraction for a module, a class and a singleton method.
///
/// In the module and class spans, method bodies are replaced by `# ...`.
/// Each method is also emitted as its own span. A comment block directly
/// above an item is included as context. Offsets are where each item starts
/// in the unindented `text`.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
    let language = ruby_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end

            def challenge_passed?
                current_user.valid_password?(challenge_params[:current_password])
            end
        end

        class Animal
            include Comparable

            attr_reader :legs

            def initialize(name, legs)
                @name, @legs = name, legs
            end

            def <=>(other)
                legs <=> other.legs
            end
        end

        # Singleton method for car object
        def car.wheels
            puts "There are four wheels"
        end"#
        .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                # This concern is inspired by "sudo mode" on GitHub. It
                # is a way to re-authenticate a user before allowing them
                # to see or perform an action.
                #
                # Add `before_action :require_challenge!` to actions you
                # want to protect.
                #
                # The user will be shown a page to enter the challenge (which
                # is either the password, or just the username when no
                # password exists). Upon passing, there is a grace period
                # during which no challenge will be asked from the user.
                #
                # Accessing challenge-protected resources during the grace
                # period will refresh the grace period.
                module ChallengableConcern
                    extend ActiveSupport::Concern

                    CHALLENGE_TIMEOUT = 1.hour.freeze

                    def require_challenge!
                        # ...
                    end

                    def challenge_passed?
                        # ...
                    end
                end"#
                    .unindent(),
                558,
            ),
            (
                r#"
                def require_challenge!
                    return if skip_challenge?

                    if challenge_passed_recently?
                        session[:challenge_passed_at] = Time.now.utc
                        return
                    end

                    @challenge = Form::Challenge.new(return_to: request.url)

                    if params.key?(:form_challenge)
                        if challenge_passed?
                            session[:challenge_passed_at] = Time.now.utc
                        else
                            flash.now[:alert] = I18n.t('challenge.invalid_password')
                            render_challenge
                        end
                    else
                        render_challenge
                    end
                end"#
                    .unindent(),
                663,
            ),
            (
                r#"
                def challenge_passed?
                    current_user.valid_password?(challenge_params[:current_password])
                end"#
                    .unindent(),
                1254,
            ),
            (
                r#"
                class Animal
                    include Comparable

                    attr_reader :legs

                    def initialize(name, legs)
                        # ...
                    end

                    def <=>(other)
                        # ...
                    end
                end"#
                    .unindent(),
                1363,
            ),
            (
                r#"
                def initialize(name, legs)
                    @name, @legs = name, legs
                end"#
                    .unindent(),
                1427,
            ),
            (
                r#"
                def <=>(other)
                    legs <=> other.legs
                end"#
                    .unindent(),
                1501,
            ),
            (
                r#"
                # Singleton method for car object
                def car.wheels
                    puts "There are four wheels"
                end"#
                    .unindent(),
                1591,
            ),
        ],
    );
}
1091
/// Checks PHP span extraction: a commented function, a trait whose method
/// bodies collapse to `{/* ... */}`, each trait method as its own span, an
/// interface, and a backed enum.
///
/// The `<?php` header and `namespace` line produce no span. Offsets are where
/// each item starts in the unindented `text`, after any leading comment.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                /*
                This is a multiple-lines comment block
                that spans over multiple
                lines
                */
                function functionName() {
                    echo "Hello world!";
                }"#
                .unindent(),
                123,
            ),
            (
                r#"
                trait HasAchievements
                {
                    /**
                    * @throws \Exception
                    */
                    public function grantAchievement(Achievement $achievement, $progress = null): void
                    {/* ... */}

                    public function achievements(): BelongsToMany
                    {/* ... */}
                }"#
                .unindent(),
                177,
            ),
            (
                r#"
                /**
                * @throws \Exception
                */
                public function grantAchievement(Achievement $achievement, $progress = null): void
                {
                    if ($progress > 100) {
                        throw new Exception(message: 'Progress cannot be greater than 100');
                    }

                    if ($this->achievements()->find($achievement->id)) {
                        throw new Exception(message: 'User already has this Achievement');
                    }

                    $this->achievements()->attach($achievement, [
                        'progress' => $progress ?? null,
                    ]);

                    $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
                }"#
                .unindent(),
                245,
            ),
            (
                r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#
                .unindent(),
                902,
            ),
            (
                r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#
                .unindent(),
                1146,
            ),
            (
                r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#
                .unindent(),
                1265,
            ),
        ],
    );
}
1242
1243fn js_lang() -> Arc<Language> {
1244 Arc::new(
1245 Language::new(
1246 LanguageConfig {
1247 name: "Javascript".into(),
1248 path_suffixes: vec!["js".into()],
1249 ..Default::default()
1250 },
1251 Some(tree_sitter_typescript::language_tsx()),
1252 )
1253 .with_embedding_query(
1254 &r#"
1255
1256 (
1257 (comment)* @context
1258 .
1259 [
1260 (export_statement
1261 (function_declaration
1262 "async"? @name
1263 "function" @name
1264 name: (_) @name))
1265 (function_declaration
1266 "async"? @name
1267 "function" @name
1268 name: (_) @name)
1269 ] @item
1270 )
1271
1272 (
1273 (comment)* @context
1274 .
1275 [
1276 (export_statement
1277 (class_declaration
1278 "class" @name
1279 name: (_) @name))
1280 (class_declaration
1281 "class" @name
1282 name: (_) @name)
1283 ] @item
1284 )
1285
1286 (
1287 (comment)* @context
1288 .
1289 [
1290 (export_statement
1291 (interface_declaration
1292 "interface" @name
1293 name: (_) @name))
1294 (interface_declaration
1295 "interface" @name
1296 name: (_) @name)
1297 ] @item
1298 )
1299
1300 (
1301 (comment)* @context
1302 .
1303 [
1304 (export_statement
1305 (enum_declaration
1306 "enum" @name
1307 name: (_) @name))
1308 (enum_declaration
1309 "enum" @name
1310 name: (_) @name)
1311 ] @item
1312 )
1313
1314 (
1315 (comment)* @context
1316 .
1317 (method_definition
1318 [
1319 "get"
1320 "set"
1321 "async"
1322 "*"
1323 "static"
1324 ]* @name
1325 name: (_) @name) @item
1326 )
1327
1328 "#
1329 .unindent(),
1330 )
1331 .unwrap(),
1332 )
1333}
1334
1335fn rust_lang() -> Arc<Language> {
1336 Arc::new(
1337 Language::new(
1338 LanguageConfig {
1339 name: "Rust".into(),
1340 path_suffixes: vec!["rs".into()],
1341 collapsed_placeholder: " /* ... */ ".to_string(),
1342 ..Default::default()
1343 },
1344 Some(tree_sitter_rust::language()),
1345 )
1346 .with_embedding_query(
1347 r#"
1348 (
1349 [(line_comment) (attribute_item)]* @context
1350 .
1351 [
1352 (struct_item
1353 name: (_) @name)
1354
1355 (enum_item
1356 name: (_) @name)
1357
1358 (impl_item
1359 trait: (_)? @name
1360 "for"? @name
1361 type: (_) @name)
1362
1363 (trait_item
1364 name: (_) @name)
1365
1366 (function_item
1367 name: (_) @name
1368 body: (block
1369 "{" @keep
1370 "}" @keep) @collapse)
1371
1372 (macro_definition
1373 name: (_) @name)
1374 ] @item
1375 )
1376
1377 (attribute_item) @collapse
1378 (use_declaration) @collapse
1379 "#,
1380 )
1381 .unwrap(),
1382 )
1383}
1384
1385fn json_lang() -> Arc<Language> {
1386 Arc::new(
1387 Language::new(
1388 LanguageConfig {
1389 name: "JSON".into(),
1390 path_suffixes: vec!["json".into()],
1391 ..Default::default()
1392 },
1393 Some(tree_sitter_json::language()),
1394 )
1395 .with_embedding_query(
1396 r#"
1397 (document) @item
1398
1399 (array
1400 "[" @keep
1401 .
1402 (object)? @keep
1403 "]" @keep) @collapse
1404
1405 (pair value: (string
1406 "\"" @keep
1407 "\"" @keep) @collapse)
1408 "#,
1409 )
1410 .unwrap(),
1411 )
1412}
1413
1414fn toml_lang() -> Arc<Language> {
1415 Arc::new(Language::new(
1416 LanguageConfig {
1417 name: "TOML".into(),
1418 path_suffixes: vec!["toml".into()],
1419 ..Default::default()
1420 },
1421 Some(tree_sitter_toml::language()),
1422 ))
1423}
1424
1425fn cpp_lang() -> Arc<Language> {
1426 Arc::new(
1427 Language::new(
1428 LanguageConfig {
1429 name: "CPP".into(),
1430 path_suffixes: vec!["cpp".into()],
1431 ..Default::default()
1432 },
1433 Some(tree_sitter_cpp::language()),
1434 )
1435 .with_embedding_query(
1436 r#"
1437 (
1438 (comment)* @context
1439 .
1440 (function_definition
1441 (type_qualifier)? @name
1442 type: (_)? @name
1443 declarator: [
1444 (function_declarator
1445 declarator: (_) @name)
1446 (pointer_declarator
1447 "*" @name
1448 declarator: (function_declarator
1449 declarator: (_) @name))
1450 (pointer_declarator
1451 "*" @name
1452 declarator: (pointer_declarator
1453 "*" @name
1454 declarator: (function_declarator
1455 declarator: (_) @name)))
1456 (reference_declarator
1457 ["&" "&&"] @name
1458 (function_declarator
1459 declarator: (_) @name))
1460 ]
1461 (type_qualifier)? @name) @item
1462 )
1463
1464 (
1465 (comment)* @context
1466 .
1467 (template_declaration
1468 (class_specifier
1469 "class" @name
1470 name: (_) @name)
1471 ) @item
1472 )
1473
1474 (
1475 (comment)* @context
1476 .
1477 (class_specifier
1478 "class" @name
1479 name: (_) @name) @item
1480 )
1481
1482 (
1483 (comment)* @context
1484 .
1485 (enum_specifier
1486 "enum" @name
1487 name: (_) @name) @item
1488 )
1489
1490 (
1491 (comment)* @context
1492 .
1493 (declaration
1494 type: (struct_specifier
1495 "struct" @name)
1496 declarator: (_) @name) @item
1497 )
1498
1499 "#,
1500 )
1501 .unwrap(),
1502 )
1503}
1504
1505fn lua_lang() -> Arc<Language> {
1506 Arc::new(
1507 Language::new(
1508 LanguageConfig {
1509 name: "Lua".into(),
1510 path_suffixes: vec!["lua".into()],
1511 collapsed_placeholder: "--[ ... ]--".to_string(),
1512 ..Default::default()
1513 },
1514 Some(tree_sitter_lua::language()),
1515 )
1516 .with_embedding_query(
1517 r#"
1518 (
1519 (comment)* @context
1520 .
1521 (function_declaration
1522 "function" @name
1523 name: (_) @name
1524 (comment)* @collapse
1525 body: (block) @collapse
1526 ) @item
1527 )
1528 "#,
1529 )
1530 .unwrap(),
1531 )
1532}
1533
1534fn php_lang() -> Arc<Language> {
1535 Arc::new(
1536 Language::new(
1537 LanguageConfig {
1538 name: "PHP".into(),
1539 path_suffixes: vec!["php".into()],
1540 collapsed_placeholder: "/* ... */".into(),
1541 ..Default::default()
1542 },
1543 Some(tree_sitter_php::language()),
1544 )
1545 .with_embedding_query(
1546 r#"
1547 (
1548 (comment)* @context
1549 .
1550 [
1551 (function_definition
1552 "function" @name
1553 name: (_) @name
1554 body: (_
1555 "{" @keep
1556 "}" @keep) @collapse
1557 )
1558
1559 (trait_declaration
1560 "trait" @name
1561 name: (_) @name)
1562
1563 (method_declaration
1564 "function" @name
1565 name: (_) @name
1566 body: (_
1567 "{" @keep
1568 "}" @keep) @collapse
1569 )
1570
1571 (interface_declaration
1572 "interface" @name
1573 name: (_) @name
1574 )
1575
1576 (enum_declaration
1577 "enum" @name
1578 name: (_) @name
1579 )
1580
1581 ] @item
1582 )
1583 "#,
1584 )
1585 .unwrap(),
1586 )
1587}
1588
1589fn ruby_lang() -> Arc<Language> {
1590 Arc::new(
1591 Language::new(
1592 LanguageConfig {
1593 name: "Ruby".into(),
1594 path_suffixes: vec!["rb".into()],
1595 collapsed_placeholder: "# ...".to_string(),
1596 ..Default::default()
1597 },
1598 Some(tree_sitter_ruby::language()),
1599 )
1600 .with_embedding_query(
1601 r#"
1602 (
1603 (comment)* @context
1604 .
1605 [
1606 (module
1607 "module" @name
1608 name: (_) @name)
1609 (method
1610 "def" @name
1611 name: (_) @name
1612 body: (body_statement) @collapse)
1613 (class
1614 "class" @name
1615 name: (_) @name)
1616 (singleton_method
1617 "def" @name
1618 object: (_) @name
1619 "." @name
1620 name: (_) @name
1621 body: (body_statement) @collapse)
1622 ] @item
1623 )
1624 "#,
1625 )
1626 .unwrap(),
1627 )
1628}
1629
1630fn elixir_lang() -> Arc<Language> {
1631 Arc::new(
1632 Language::new(
1633 LanguageConfig {
1634 name: "Elixir".into(),
1635 path_suffixes: vec!["rs".into()],
1636 ..Default::default()
1637 },
1638 Some(tree_sitter_elixir::language()),
1639 )
1640 .with_embedding_query(
1641 r#"
1642 (
1643 (unary_operator
1644 operator: "@"
1645 operand: (call
1646 target: (identifier) @unary
1647 (#match? @unary "^(doc)$"))
1648 ) @context
1649 .
1650 (call
1651 target: (identifier) @name
1652 (arguments
1653 [
1654 (identifier) @name
1655 (call
1656 target: (identifier) @name)
1657 (binary_operator
1658 left: (call
1659 target: (identifier) @name)
1660 operator: "when")
1661 ])
1662 (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1663 )
1664
1665 (call
1666 target: (identifier) @name
1667 (arguments (alias) @name)
1668 (#any-match? @name "^(defmodule|defprotocol)$")) @item
1669 "#,
1670 )
1671 .unwrap(),
1672 )
1673}
1674
1675#[gpui::test]
1676fn test_subtract_ranges() {
1677 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1678
1679 assert_eq!(
1680 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1681 vec![1..4, 10..21]
1682 );
1683
1684 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1685}
1686
1687fn init_test(cx: &mut TestAppContext) {
1688 cx.update(|cx| {
1689 cx.set_global(SettingsStore::test(cx));
1690 settings::register::<SemanticIndexSettings>(cx);
1691 settings::register::<ProjectSettings>(cx);
1692 });
1693}