1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8
9use gpui::{executor::Deterministic, Task, TestAppContext};
10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
11use parking_lot::Mutex;
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::SettingsStore;
17use std::{path::Path, sync::Arc, time::SystemTime};
18use unindent::Unindent;
19use util::RandomCharIter;
20
21#[ctor::ctor]
22fn init_logger() {
23 if std::env::var("RUST_LOG").is_ok() {
24 env_logger::init();
25 }
26}
27
28#[gpui::test]
29async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
30 init_test(cx);
31
32 let fs = FakeFs::new(cx.background());
33 fs.insert_tree(
34 "/the-root",
35 json!({
36 "src": {
37 "file1.rs": "
38 fn aaa() {
39 println!(\"aaaaaaaaaaaa!\");
40 }
41
42 fn zzzzz() {
43 println!(\"SLEEPING\");
44 }
45 ".unindent(),
46 "file2.rs": "
47 fn bbb() {
48 println!(\"bbbbbbbbbbbbb!\");
49 }
50 struct pqpqpqp {}
51 ".unindent(),
52 "file3.toml": "
53 ZZZZZZZZZZZZZZZZZZ = 5
54 ".unindent(),
55 }
56 }),
57 )
58 .await;
59
60 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
61 let rust_language = rust_lang();
62 let toml_language = toml_lang();
63 languages.add(rust_language);
64 languages.add(toml_language);
65
66 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
67 let db_path = db_dir.path().join("db.sqlite");
68
69 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
70 let semantic_index = SemanticIndex::new(
71 fs.clone(),
72 db_path,
73 embedding_provider.clone(),
74 languages,
75 cx.to_async(),
76 )
77 .await
78 .unwrap();
79
80 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
81
82 let search_results = semantic_index.update(cx, |store, cx| {
83 store.search_project(
84 project.clone(),
85 "aaaaaabbbbzz".to_string(),
86 5,
87 vec![],
88 vec![],
89 cx,
90 )
91 });
92 let pending_file_count =
93 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
94 deterministic.run_until_parked();
95 assert_eq!(*pending_file_count.borrow(), 3);
96 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
97 assert_eq!(*pending_file_count.borrow(), 0);
98
99 let search_results = search_results.await.unwrap();
100 assert_search_results(
101 &search_results,
102 &[
103 (Path::new("src/file1.rs").into(), 0),
104 (Path::new("src/file2.rs").into(), 0),
105 (Path::new("src/file3.toml").into(), 0),
106 (Path::new("src/file1.rs").into(), 45),
107 (Path::new("src/file2.rs").into(), 45),
108 ],
109 cx,
110 );
111
112 // Test Include Files Functonality
113 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
114 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
115 let rust_only_search_results = semantic_index
116 .update(cx, |store, cx| {
117 store.search_project(
118 project.clone(),
119 "aaaaaabbbbzz".to_string(),
120 5,
121 include_files,
122 vec![],
123 cx,
124 )
125 })
126 .await
127 .unwrap();
128
129 assert_search_results(
130 &rust_only_search_results,
131 &[
132 (Path::new("src/file1.rs").into(), 0),
133 (Path::new("src/file2.rs").into(), 0),
134 (Path::new("src/file1.rs").into(), 45),
135 (Path::new("src/file2.rs").into(), 45),
136 ],
137 cx,
138 );
139
140 let no_rust_search_results = semantic_index
141 .update(cx, |store, cx| {
142 store.search_project(
143 project.clone(),
144 "aaaaaabbbbzz".to_string(),
145 5,
146 vec![],
147 exclude_files,
148 cx,
149 )
150 })
151 .await
152 .unwrap();
153
154 assert_search_results(
155 &no_rust_search_results,
156 &[(Path::new("src/file3.toml").into(), 0)],
157 cx,
158 );
159
160 fs.save(
161 "/the-root/src/file2.rs".as_ref(),
162 &"
163 fn dddd() { println!(\"ddddd!\"); }
164 struct pqpqpqp {}
165 "
166 .unindent()
167 .into(),
168 Default::default(),
169 )
170 .await
171 .unwrap();
172
173 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
174
175 let prev_embedding_count = embedding_provider.embedding_count();
176 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
177 deterministic.run_until_parked();
178 assert_eq!(*pending_file_count.borrow(), 1);
179 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
180 assert_eq!(*pending_file_count.borrow(), 0);
181 index.await.unwrap();
182
183 assert_eq!(
184 embedding_provider.embedding_count() - prev_embedding_count,
185 1
186 );
187}
188
189#[gpui::test(iterations = 10)]
190async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
191 let (outstanding_job_count, _) = postage::watch::channel_with(0);
192 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
193
194 let files = (1..=3)
195 .map(|file_ix| FileToEmbed {
196 worktree_id: 5,
197 path: Path::new(&format!("path-{file_ix}")).into(),
198 mtime: SystemTime::now(),
199 spans: (0..rng.gen_range(4..22))
200 .map(|document_ix| {
201 let content_len = rng.gen_range(10..100);
202 let content = RandomCharIter::new(&mut rng)
203 .with_simple_text()
204 .take(content_len)
205 .collect::<String>();
206 let digest = SpanDigest::from(content.as_str());
207 Span {
208 range: 0..10,
209 embedding: None,
210 name: format!("document {document_ix}"),
211 content,
212 digest,
213 token_count: rng.gen_range(10..30),
214 }
215 })
216 .collect(),
217 job_handle: JobHandle::new(&outstanding_job_count),
218 })
219 .collect::<Vec<_>>();
220
221 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
222
223 let mut queue = EmbeddingQueue::new(
224 embedding_provider.clone(),
225 cx.background(),
226 ai::auth::ProviderCredential::NoCredentials,
227 );
228 for file in &files {
229 queue.push(file.clone());
230 }
231 queue.flush();
232
233 cx.foreground().run_until_parked();
234 let finished_files = queue.finished_files();
235 let mut embedded_files: Vec<_> = files
236 .iter()
237 .map(|_| finished_files.try_recv().expect("no finished file"))
238 .collect();
239
240 let expected_files: Vec<_> = files
241 .iter()
242 .map(|file| {
243 let mut file = file.clone();
244 for doc in &mut file.spans {
245 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
246 }
247 file
248 })
249 .collect();
250
251 embedded_files.sort_by_key(|f| f.path.clone());
252
253 assert_eq!(embedded_files, expected_files);
254}
255
256#[track_caller]
257fn assert_search_results(
258 actual: &[SearchResult],
259 expected: &[(Arc<Path>, usize)],
260 cx: &TestAppContext,
261) {
262 let actual = actual
263 .iter()
264 .map(|search_result| {
265 search_result.buffer.read_with(cx, |buffer, _cx| {
266 (
267 buffer.file().unwrap().path().clone(),
268 search_result.range.start.to_offset(buffer),
269 )
270 })
271 })
272 .collect::<Vec<_>>();
273 assert_eq!(actual, expected);
274}
275
276#[gpui::test]
277async fn test_code_context_retrieval_rust() {
278 let language = rust_lang();
279 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
280 let mut retriever = CodeContextRetriever::new(embedding_provider);
281
282 let text = "
283 /// A doc comment
284 /// that spans multiple lines
285 #[gpui::test]
286 fn a() {
287 b
288 }
289
290 impl C for D {
291 }
292
293 impl E {
294 // This is also a preceding comment
295 pub fn function_1() -> Option<()> {
296 todo!();
297 }
298
299 // This is a preceding comment
300 fn function_2() -> Result<()> {
301 todo!();
302 }
303 }
304
305 #[derive(Clone)]
306 struct D {
307 name: String
308 }
309 "
310 .unindent();
311
312 let documents = retriever.parse_file(&text, language).unwrap();
313
314 assert_documents_eq(
315 &documents,
316 &[
317 (
318 "
319 /// A doc comment
320 /// that spans multiple lines
321 #[gpui::test]
322 fn a() {
323 b
324 }"
325 .unindent(),
326 text.find("fn a").unwrap(),
327 ),
328 (
329 "
330 impl C for D {
331 }"
332 .unindent(),
333 text.find("impl C").unwrap(),
334 ),
335 (
336 "
337 impl E {
338 // This is also a preceding comment
339 pub fn function_1() -> Option<()> { /* ... */ }
340
341 // This is a preceding comment
342 fn function_2() -> Result<()> { /* ... */ }
343 }"
344 .unindent(),
345 text.find("impl E").unwrap(),
346 ),
347 (
348 "
349 // This is also a preceding comment
350 pub fn function_1() -> Option<()> {
351 todo!();
352 }"
353 .unindent(),
354 text.find("pub fn function_1").unwrap(),
355 ),
356 (
357 "
358 // This is a preceding comment
359 fn function_2() -> Result<()> {
360 todo!();
361 }"
362 .unindent(),
363 text.find("fn function_2").unwrap(),
364 ),
365 (
366 "
367 #[derive(Clone)]
368 struct D {
369 name: String
370 }"
371 .unindent(),
372 text.find("struct D").unwrap(),
373 ),
374 ],
375 );
376}
377
378#[gpui::test]
379async fn test_code_context_retrieval_json() {
380 let language = json_lang();
381 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
382 let mut retriever = CodeContextRetriever::new(embedding_provider);
383
384 let text = r#"
385 {
386 "array": [1, 2, 3, 4],
387 "string": "abcdefg",
388 "nested_object": {
389 "array_2": [5, 6, 7, 8],
390 "string_2": "hijklmnop",
391 "boolean": true,
392 "none": null
393 }
394 }
395 "#
396 .unindent();
397
398 let documents = retriever.parse_file(&text, language.clone()).unwrap();
399
400 assert_documents_eq(
401 &documents,
402 &[(
403 r#"
404 {
405 "array": [],
406 "string": "",
407 "nested_object": {
408 "array_2": [],
409 "string_2": "",
410 "boolean": true,
411 "none": null
412 }
413 }"#
414 .unindent(),
415 text.find("{").unwrap(),
416 )],
417 );
418
419 let text = r#"
420 [
421 {
422 "name": "somebody",
423 "age": 42
424 },
425 {
426 "name": "somebody else",
427 "age": 43
428 }
429 ]
430 "#
431 .unindent();
432
433 let documents = retriever.parse_file(&text, language.clone()).unwrap();
434
435 assert_documents_eq(
436 &documents,
437 &[(
438 r#"
439 [{
440 "name": "",
441 "age": 42
442 }]"#
443 .unindent(),
444 text.find("[").unwrap(),
445 )],
446 );
447}
448
449fn assert_documents_eq(
450 documents: &[Span],
451 expected_contents_and_start_offsets: &[(String, usize)],
452) {
453 assert_eq!(
454 documents
455 .iter()
456 .map(|document| (document.content.clone(), document.range.start))
457 .collect::<Vec<_>>(),
458 expected_contents_and_start_offsets
459 );
460}
461
462#[gpui::test]
463async fn test_code_context_retrieval_javascript() {
464 let language = js_lang();
465 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
466 let mut retriever = CodeContextRetriever::new(embedding_provider);
467
468 let text = "
469 /* globals importScripts, backend */
470 function _authorize() {}
471
472 /**
473 * Sometimes the frontend build is way faster than backend.
474 */
475 export async function authorizeBank() {
476 _authorize(pushModal, upgradingAccountId, {});
477 }
478
479 export class SettingsPage {
480 /* This is a test setting */
481 constructor(page) {
482 this.page = page;
483 }
484 }
485
486 /* This is a test comment */
487 class TestClass {}
488
489 /* Schema for editor_events in Clickhouse. */
490 export interface ClickhouseEditorEvent {
491 installation_id: string
492 operation: string
493 }
494 "
495 .unindent();
496
497 let documents = retriever.parse_file(&text, language.clone()).unwrap();
498
499 assert_documents_eq(
500 &documents,
501 &[
502 (
503 "
504 /* globals importScripts, backend */
505 function _authorize() {}"
506 .unindent(),
507 37,
508 ),
509 (
510 "
511 /**
512 * Sometimes the frontend build is way faster than backend.
513 */
514 export async function authorizeBank() {
515 _authorize(pushModal, upgradingAccountId, {});
516 }"
517 .unindent(),
518 131,
519 ),
520 (
521 "
522 export class SettingsPage {
523 /* This is a test setting */
524 constructor(page) {
525 this.page = page;
526 }
527 }"
528 .unindent(),
529 225,
530 ),
531 (
532 "
533 /* This is a test setting */
534 constructor(page) {
535 this.page = page;
536 }"
537 .unindent(),
538 290,
539 ),
540 (
541 "
542 /* This is a test comment */
543 class TestClass {}"
544 .unindent(),
545 374,
546 ),
547 (
548 "
549 /* Schema for editor_events in Clickhouse. */
550 export interface ClickhouseEditorEvent {
551 installation_id: string
552 operation: string
553 }"
554 .unindent(),
555 440,
556 ),
557 ],
558 )
559}
560
561#[gpui::test]
562async fn test_code_context_retrieval_lua() {
563 let language = lua_lang();
564 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
565 let mut retriever = CodeContextRetriever::new(embedding_provider);
566
567 let text = r#"
568 -- Creates a new class
569 -- @param baseclass The Baseclass of this class, or nil.
570 -- @return A new class reference.
571 function classes.class(baseclass)
572 -- Create the class definition and metatable.
573 local classdef = {}
574 -- Find the super class, either Object or user-defined.
575 baseclass = baseclass or classes.Object
576 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
577 setmetatable(classdef, { __index = baseclass })
578 -- All class instances have a reference to the class object.
579 classdef.class = classdef
580 --- Recursivly allocates the inheritance tree of the instance.
581 -- @param mastertable The 'root' of the inheritance tree.
582 -- @return Returns the instance with the allocated inheritance tree.
583 function classdef.alloc(mastertable)
584 -- All class instances have a reference to a superclass object.
585 local instance = { super = baseclass.alloc(mastertable) }
586 -- Any functions this instance does not know of will 'look up' to the superclass definition.
587 setmetatable(instance, { __index = classdef, __newindex = mastertable })
588 return instance
589 end
590 end
591 "#.unindent();
592
593 let documents = retriever.parse_file(&text, language.clone()).unwrap();
594
595 assert_documents_eq(
596 &documents,
597 &[
598 (r#"
599 -- Creates a new class
600 -- @param baseclass The Baseclass of this class, or nil.
601 -- @return A new class reference.
602 function classes.class(baseclass)
603 -- Create the class definition and metatable.
604 local classdef = {}
605 -- Find the super class, either Object or user-defined.
606 baseclass = baseclass or classes.Object
607 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
608 setmetatable(classdef, { __index = baseclass })
609 -- All class instances have a reference to the class object.
610 classdef.class = classdef
611 --- Recursivly allocates the inheritance tree of the instance.
612 -- @param mastertable The 'root' of the inheritance tree.
613 -- @return Returns the instance with the allocated inheritance tree.
614 function classdef.alloc(mastertable)
615 --[ ... ]--
616 --[ ... ]--
617 end
618 end"#.unindent(),
619 114),
620 (r#"
621 --- Recursivly allocates the inheritance tree of the instance.
622 -- @param mastertable The 'root' of the inheritance tree.
623 -- @return Returns the instance with the allocated inheritance tree.
624 function classdef.alloc(mastertable)
625 -- All class instances have a reference to a superclass object.
626 local instance = { super = baseclass.alloc(mastertable) }
627 -- Any functions this instance does not know of will 'look up' to the superclass definition.
628 setmetatable(instance, { __index = classdef, __newindex = mastertable })
629 return instance
630 end"#.unindent(), 809),
631 ]
632 );
633}
634
635#[gpui::test]
636async fn test_code_context_retrieval_elixir() {
637 let language = elixir_lang();
638 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
639 let mut retriever = CodeContextRetriever::new(embedding_provider);
640
641 let text = r#"
642 defmodule File.Stream do
643 @moduledoc """
644 Defines a `File.Stream` struct returned by `File.stream!/3`.
645
646 The following fields are public:
647
648 * `path` - the file path
649 * `modes` - the file modes
650 * `raw` - a boolean indicating if bin functions should be used
651 * `line_or_bytes` - if reading should read lines or a given number of bytes
652 * `node` - the node the file belongs to
653
654 """
655
656 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
657
658 @type t :: %__MODULE__{}
659
660 @doc false
661 def __build__(path, modes, line_or_bytes) do
662 raw = :lists.keyfind(:encoding, 1, modes) == false
663
664 modes =
665 case raw do
666 true ->
667 case :lists.keyfind(:read_ahead, 1, modes) do
668 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
669 {:read_ahead, _} -> [:raw | modes]
670 false -> [:raw, :read_ahead | modes]
671 end
672
673 false ->
674 modes
675 end
676
677 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
678
679 end"#
680 .unindent();
681
682 let documents = retriever.parse_file(&text, language.clone()).unwrap();
683
684 assert_documents_eq(
685 &documents,
686 &[(
687 r#"
688 defmodule File.Stream do
689 @moduledoc """
690 Defines a `File.Stream` struct returned by `File.stream!/3`.
691
692 The following fields are public:
693
694 * `path` - the file path
695 * `modes` - the file modes
696 * `raw` - a boolean indicating if bin functions should be used
697 * `line_or_bytes` - if reading should read lines or a given number of bytes
698 * `node` - the node the file belongs to
699
700 """
701
702 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
703
704 @type t :: %__MODULE__{}
705
706 @doc false
707 def __build__(path, modes, line_or_bytes) do
708 raw = :lists.keyfind(:encoding, 1, modes) == false
709
710 modes =
711 case raw do
712 true ->
713 case :lists.keyfind(:read_ahead, 1, modes) do
714 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
715 {:read_ahead, _} -> [:raw | modes]
716 false -> [:raw, :read_ahead | modes]
717 end
718
719 false ->
720 modes
721 end
722
723 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
724
725 end"#
726 .unindent(),
727 0,
728 ),(r#"
729 @doc false
730 def __build__(path, modes, line_or_bytes) do
731 raw = :lists.keyfind(:encoding, 1, modes) == false
732
733 modes =
734 case raw do
735 true ->
736 case :lists.keyfind(:read_ahead, 1, modes) do
737 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
738 {:read_ahead, _} -> [:raw | modes]
739 false -> [:raw, :read_ahead | modes]
740 end
741
742 false ->
743 modes
744 end
745
746 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
747
748 end"#.unindent(), 574)],
749 );
750}
751
752#[gpui::test]
753async fn test_code_context_retrieval_cpp() {
754 let language = cpp_lang();
755 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
756 let mut retriever = CodeContextRetriever::new(embedding_provider);
757
758 let text = "
759 /**
760 * @brief Main function
761 * @returns 0 on exit
762 */
763 int main() { return 0; }
764
765 /**
766 * This is a test comment
767 */
768 class MyClass { // The class
769 public: // Access specifier
770 int myNum; // Attribute (int variable)
771 string myString; // Attribute (string variable)
772 };
773
774 // This is a test comment
775 enum Color { red, green, blue };
776
777 /** This is a preceding block comment
778 * This is the second line
779 */
780 struct { // Structure declaration
781 int myNum; // Member (int variable)
782 string myString; // Member (string variable)
783 } myStructure;
784
785 /**
786 * @brief Matrix class.
787 */
788 template <typename T,
789 typename = typename std::enable_if<
790 std::is_integral<T>::value || std::is_floating_point<T>::value,
791 bool>::type>
792 class Matrix2 {
793 std::vector<std::vector<T>> _mat;
794
795 public:
796 /**
797 * @brief Constructor
798 * @tparam Integer ensuring integers are being evaluated and not other
799 * data types.
800 * @param size denoting the size of Matrix as size x size
801 */
802 template <typename Integer,
803 typename = typename std::enable_if<std::is_integral<Integer>::value,
804 Integer>::type>
805 explicit Matrix(const Integer size) {
806 for (size_t i = 0; i < size; ++i) {
807 _mat.emplace_back(std::vector<T>(size, 0));
808 }
809 }
810 }"
811 .unindent();
812
813 let documents = retriever.parse_file(&text, language.clone()).unwrap();
814
815 assert_documents_eq(
816 &documents,
817 &[
818 (
819 "
820 /**
821 * @brief Main function
822 * @returns 0 on exit
823 */
824 int main() { return 0; }"
825 .unindent(),
826 54,
827 ),
828 (
829 "
830 /**
831 * This is a test comment
832 */
833 class MyClass { // The class
834 public: // Access specifier
835 int myNum; // Attribute (int variable)
836 string myString; // Attribute (string variable)
837 }"
838 .unindent(),
839 112,
840 ),
841 (
842 "
843 // This is a test comment
844 enum Color { red, green, blue }"
845 .unindent(),
846 322,
847 ),
848 (
849 "
850 /** This is a preceding block comment
851 * This is the second line
852 */
853 struct { // Structure declaration
854 int myNum; // Member (int variable)
855 string myString; // Member (string variable)
856 } myStructure;"
857 .unindent(),
858 425,
859 ),
860 (
861 "
862 /**
863 * @brief Matrix class.
864 */
865 template <typename T,
866 typename = typename std::enable_if<
867 std::is_integral<T>::value || std::is_floating_point<T>::value,
868 bool>::type>
869 class Matrix2 {
870 std::vector<std::vector<T>> _mat;
871
872 public:
873 /**
874 * @brief Constructor
875 * @tparam Integer ensuring integers are being evaluated and not other
876 * data types.
877 * @param size denoting the size of Matrix as size x size
878 */
879 template <typename Integer,
880 typename = typename std::enable_if<std::is_integral<Integer>::value,
881 Integer>::type>
882 explicit Matrix(const Integer size) {
883 for (size_t i = 0; i < size; ++i) {
884 _mat.emplace_back(std::vector<T>(size, 0));
885 }
886 }
887 }"
888 .unindent(),
889 612,
890 ),
891 (
892 "
893 explicit Matrix(const Integer size) {
894 for (size_t i = 0; i < size; ++i) {
895 _mat.emplace_back(std::vector<T>(size, 0));
896 }
897 }"
898 .unindent(),
899 1226,
900 ),
901 ],
902 );
903}
904
905#[gpui::test]
906async fn test_code_context_retrieval_ruby() {
907 let language = ruby_lang();
908 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
909 let mut retriever = CodeContextRetriever::new(embedding_provider);
910
911 let text = r#"
912 # This concern is inspired by "sudo mode" on GitHub. It
913 # is a way to re-authenticate a user before allowing them
914 # to see or perform an action.
915 #
916 # Add `before_action :require_challenge!` to actions you
917 # want to protect.
918 #
919 # The user will be shown a page to enter the challenge (which
920 # is either the password, or just the username when no
921 # password exists). Upon passing, there is a grace period
922 # during which no challenge will be asked from the user.
923 #
924 # Accessing challenge-protected resources during the grace
925 # period will refresh the grace period.
926 module ChallengableConcern
927 extend ActiveSupport::Concern
928
929 CHALLENGE_TIMEOUT = 1.hour.freeze
930
931 def require_challenge!
932 return if skip_challenge?
933
934 if challenge_passed_recently?
935 session[:challenge_passed_at] = Time.now.utc
936 return
937 end
938
939 @challenge = Form::Challenge.new(return_to: request.url)
940
941 if params.key?(:form_challenge)
942 if challenge_passed?
943 session[:challenge_passed_at] = Time.now.utc
944 else
945 flash.now[:alert] = I18n.t('challenge.invalid_password')
946 render_challenge
947 end
948 else
949 render_challenge
950 end
951 end
952
953 def challenge_passed?
954 current_user.valid_password?(challenge_params[:current_password])
955 end
956 end
957
958 class Animal
959 include Comparable
960
961 attr_reader :legs
962
963 def initialize(name, legs)
964 @name, @legs = name, legs
965 end
966
967 def <=>(other)
968 legs <=> other.legs
969 end
970 end
971
972 # Singleton method for car object
973 def car.wheels
974 puts "There are four wheels"
975 end"#
976 .unindent();
977
978 let documents = retriever.parse_file(&text, language.clone()).unwrap();
979
980 assert_documents_eq(
981 &documents,
982 &[
983 (
984 r#"
985 # This concern is inspired by "sudo mode" on GitHub. It
986 # is a way to re-authenticate a user before allowing them
987 # to see or perform an action.
988 #
989 # Add `before_action :require_challenge!` to actions you
990 # want to protect.
991 #
992 # The user will be shown a page to enter the challenge (which
993 # is either the password, or just the username when no
994 # password exists). Upon passing, there is a grace period
995 # during which no challenge will be asked from the user.
996 #
997 # Accessing challenge-protected resources during the grace
998 # period will refresh the grace period.
999 module ChallengableConcern
1000 extend ActiveSupport::Concern
1001
1002 CHALLENGE_TIMEOUT = 1.hour.freeze
1003
1004 def require_challenge!
1005 # ...
1006 end
1007
1008 def challenge_passed?
1009 # ...
1010 end
1011 end"#
1012 .unindent(),
1013 558,
1014 ),
1015 (
1016 r#"
1017 def require_challenge!
1018 return if skip_challenge?
1019
1020 if challenge_passed_recently?
1021 session[:challenge_passed_at] = Time.now.utc
1022 return
1023 end
1024
1025 @challenge = Form::Challenge.new(return_to: request.url)
1026
1027 if params.key?(:form_challenge)
1028 if challenge_passed?
1029 session[:challenge_passed_at] = Time.now.utc
1030 else
1031 flash.now[:alert] = I18n.t('challenge.invalid_password')
1032 render_challenge
1033 end
1034 else
1035 render_challenge
1036 end
1037 end"#
1038 .unindent(),
1039 663,
1040 ),
1041 (
1042 r#"
1043 def challenge_passed?
1044 current_user.valid_password?(challenge_params[:current_password])
1045 end"#
1046 .unindent(),
1047 1254,
1048 ),
1049 (
1050 r#"
1051 class Animal
1052 include Comparable
1053
1054 attr_reader :legs
1055
1056 def initialize(name, legs)
1057 # ...
1058 end
1059
1060 def <=>(other)
1061 # ...
1062 end
1063 end"#
1064 .unindent(),
1065 1363,
1066 ),
1067 (
1068 r#"
1069 def initialize(name, legs)
1070 @name, @legs = name, legs
1071 end"#
1072 .unindent(),
1073 1427,
1074 ),
1075 (
1076 r#"
1077 def <=>(other)
1078 legs <=> other.legs
1079 end"#
1080 .unindent(),
1081 1501,
1082 ),
1083 (
1084 r#"
1085 # Singleton method for car object
1086 def car.wheels
1087 puts "There are four wheels"
1088 end"#
1089 .unindent(),
1090 1591,
1091 ),
1092 ],
1093 );
1094}
1095
1096#[gpui::test]
1097async fn test_code_context_retrieval_php() {
1098 let language = php_lang();
1099 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1100 let mut retriever = CodeContextRetriever::new(embedding_provider);
1101
1102 let text = r#"
1103 <?php
1104
1105 namespace LevelUp\Experience\Concerns;
1106
1107 /*
1108 This is a multiple-lines comment block
1109 that spans over multiple
1110 lines
1111 */
1112 function functionName() {
1113 echo "Hello world!";
1114 }
1115
1116 trait HasAchievements
1117 {
1118 /**
1119 * @throws \Exception
1120 */
1121 public function grantAchievement(Achievement $achievement, $progress = null): void
1122 {
1123 if ($progress > 100) {
1124 throw new Exception(message: 'Progress cannot be greater than 100');
1125 }
1126
1127 if ($this->achievements()->find($achievement->id)) {
1128 throw new Exception(message: 'User already has this Achievement');
1129 }
1130
1131 $this->achievements()->attach($achievement, [
1132 'progress' => $progress ?? null,
1133 ]);
1134
1135 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1136 }
1137
1138 public function achievements(): BelongsToMany
1139 {
1140 return $this->belongsToMany(related: Achievement::class)
1141 ->withPivot(columns: 'progress')
1142 ->where('is_secret', false)
1143 ->using(AchievementUser::class);
1144 }
1145 }
1146
1147 interface Multiplier
1148 {
1149 public function qualifies(array $data): bool;
1150
1151 public function setMultiplier(): int;
1152 }
1153
1154 enum AuditType: string
1155 {
1156 case Add = 'add';
1157 case Remove = 'remove';
1158 case Reset = 'reset';
1159 case LevelUp = 'level_up';
1160 }
1161
1162 ?>"#
1163 .unindent();
1164
1165 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1166
1167 assert_documents_eq(
1168 &documents,
1169 &[
1170 (
1171 r#"
1172 /*
1173 This is a multiple-lines comment block
1174 that spans over multiple
1175 lines
1176 */
1177 function functionName() {
1178 echo "Hello world!";
1179 }"#
1180 .unindent(),
1181 123,
1182 ),
1183 (
1184 r#"
1185 trait HasAchievements
1186 {
1187 /**
1188 * @throws \Exception
1189 */
1190 public function grantAchievement(Achievement $achievement, $progress = null): void
1191 {/* ... */}
1192
1193 public function achievements(): BelongsToMany
1194 {/* ... */}
1195 }"#
1196 .unindent(),
1197 177,
1198 ),
1199 (r#"
1200 /**
1201 * @throws \Exception
1202 */
1203 public function grantAchievement(Achievement $achievement, $progress = null): void
1204 {
1205 if ($progress > 100) {
1206 throw new Exception(message: 'Progress cannot be greater than 100');
1207 }
1208
1209 if ($this->achievements()->find($achievement->id)) {
1210 throw new Exception(message: 'User already has this Achievement');
1211 }
1212
1213 $this->achievements()->attach($achievement, [
1214 'progress' => $progress ?? null,
1215 ]);
1216
1217 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1218 }"#.unindent(), 245),
1219 (r#"
1220 public function achievements(): BelongsToMany
1221 {
1222 return $this->belongsToMany(related: Achievement::class)
1223 ->withPivot(columns: 'progress')
1224 ->where('is_secret', false)
1225 ->using(AchievementUser::class);
1226 }"#.unindent(), 902),
1227 (r#"
1228 interface Multiplier
1229 {
1230 public function qualifies(array $data): bool;
1231
1232 public function setMultiplier(): int;
1233 }"#.unindent(),
1234 1146),
1235 (r#"
1236 enum AuditType: string
1237 {
1238 case Add = 'add';
1239 case Remove = 'remove';
1240 case Reset = 'reset';
1241 case LevelUp = 'level_up';
1242 }"#.unindent(), 1265)
1243 ],
1244 );
1245}
1246
1247fn js_lang() -> Arc<Language> {
1248 Arc::new(
1249 Language::new(
1250 LanguageConfig {
1251 name: "Javascript".into(),
1252 path_suffixes: vec!["js".into()],
1253 ..Default::default()
1254 },
1255 Some(tree_sitter_typescript::language_tsx()),
1256 )
1257 .with_embedding_query(
1258 &r#"
1259
1260 (
1261 (comment)* @context
1262 .
1263 [
1264 (export_statement
1265 (function_declaration
1266 "async"? @name
1267 "function" @name
1268 name: (_) @name))
1269 (function_declaration
1270 "async"? @name
1271 "function" @name
1272 name: (_) @name)
1273 ] @item
1274 )
1275
1276 (
1277 (comment)* @context
1278 .
1279 [
1280 (export_statement
1281 (class_declaration
1282 "class" @name
1283 name: (_) @name))
1284 (class_declaration
1285 "class" @name
1286 name: (_) @name)
1287 ] @item
1288 )
1289
1290 (
1291 (comment)* @context
1292 .
1293 [
1294 (export_statement
1295 (interface_declaration
1296 "interface" @name
1297 name: (_) @name))
1298 (interface_declaration
1299 "interface" @name
1300 name: (_) @name)
1301 ] @item
1302 )
1303
1304 (
1305 (comment)* @context
1306 .
1307 [
1308 (export_statement
1309 (enum_declaration
1310 "enum" @name
1311 name: (_) @name))
1312 (enum_declaration
1313 "enum" @name
1314 name: (_) @name)
1315 ] @item
1316 )
1317
1318 (
1319 (comment)* @context
1320 .
1321 (method_definition
1322 [
1323 "get"
1324 "set"
1325 "async"
1326 "*"
1327 "static"
1328 ]* @name
1329 name: (_) @name) @item
1330 )
1331
1332 "#
1333 .unindent(),
1334 )
1335 .unwrap(),
1336 )
1337}
1338
1339fn rust_lang() -> Arc<Language> {
1340 Arc::new(
1341 Language::new(
1342 LanguageConfig {
1343 name: "Rust".into(),
1344 path_suffixes: vec!["rs".into()],
1345 collapsed_placeholder: " /* ... */ ".to_string(),
1346 ..Default::default()
1347 },
1348 Some(tree_sitter_rust::language()),
1349 )
1350 .with_embedding_query(
1351 r#"
1352 (
1353 [(line_comment) (attribute_item)]* @context
1354 .
1355 [
1356 (struct_item
1357 name: (_) @name)
1358
1359 (enum_item
1360 name: (_) @name)
1361
1362 (impl_item
1363 trait: (_)? @name
1364 "for"? @name
1365 type: (_) @name)
1366
1367 (trait_item
1368 name: (_) @name)
1369
1370 (function_item
1371 name: (_) @name
1372 body: (block
1373 "{" @keep
1374 "}" @keep) @collapse)
1375
1376 (macro_definition
1377 name: (_) @name)
1378 ] @item
1379 )
1380
1381 (attribute_item) @collapse
1382 (use_declaration) @collapse
1383 "#,
1384 )
1385 .unwrap(),
1386 )
1387}
1388
1389fn json_lang() -> Arc<Language> {
1390 Arc::new(
1391 Language::new(
1392 LanguageConfig {
1393 name: "JSON".into(),
1394 path_suffixes: vec!["json".into()],
1395 ..Default::default()
1396 },
1397 Some(tree_sitter_json::language()),
1398 )
1399 .with_embedding_query(
1400 r#"
1401 (document) @item
1402
1403 (array
1404 "[" @keep
1405 .
1406 (object)? @keep
1407 "]" @keep) @collapse
1408
1409 (pair value: (string
1410 "\"" @keep
1411 "\"" @keep) @collapse)
1412 "#,
1413 )
1414 .unwrap(),
1415 )
1416}
1417
1418fn toml_lang() -> Arc<Language> {
1419 Arc::new(Language::new(
1420 LanguageConfig {
1421 name: "TOML".into(),
1422 path_suffixes: vec!["toml".into()],
1423 ..Default::default()
1424 },
1425 Some(tree_sitter_toml::language()),
1426 ))
1427}
1428
1429fn cpp_lang() -> Arc<Language> {
1430 Arc::new(
1431 Language::new(
1432 LanguageConfig {
1433 name: "CPP".into(),
1434 path_suffixes: vec!["cpp".into()],
1435 ..Default::default()
1436 },
1437 Some(tree_sitter_cpp::language()),
1438 )
1439 .with_embedding_query(
1440 r#"
1441 (
1442 (comment)* @context
1443 .
1444 (function_definition
1445 (type_qualifier)? @name
1446 type: (_)? @name
1447 declarator: [
1448 (function_declarator
1449 declarator: (_) @name)
1450 (pointer_declarator
1451 "*" @name
1452 declarator: (function_declarator
1453 declarator: (_) @name))
1454 (pointer_declarator
1455 "*" @name
1456 declarator: (pointer_declarator
1457 "*" @name
1458 declarator: (function_declarator
1459 declarator: (_) @name)))
1460 (reference_declarator
1461 ["&" "&&"] @name
1462 (function_declarator
1463 declarator: (_) @name))
1464 ]
1465 (type_qualifier)? @name) @item
1466 )
1467
1468 (
1469 (comment)* @context
1470 .
1471 (template_declaration
1472 (class_specifier
1473 "class" @name
1474 name: (_) @name)
1475 ) @item
1476 )
1477
1478 (
1479 (comment)* @context
1480 .
1481 (class_specifier
1482 "class" @name
1483 name: (_) @name) @item
1484 )
1485
1486 (
1487 (comment)* @context
1488 .
1489 (enum_specifier
1490 "enum" @name
1491 name: (_) @name) @item
1492 )
1493
1494 (
1495 (comment)* @context
1496 .
1497 (declaration
1498 type: (struct_specifier
1499 "struct" @name)
1500 declarator: (_) @name) @item
1501 )
1502
1503 "#,
1504 )
1505 .unwrap(),
1506 )
1507}
1508
1509fn lua_lang() -> Arc<Language> {
1510 Arc::new(
1511 Language::new(
1512 LanguageConfig {
1513 name: "Lua".into(),
1514 path_suffixes: vec!["lua".into()],
1515 collapsed_placeholder: "--[ ... ]--".to_string(),
1516 ..Default::default()
1517 },
1518 Some(tree_sitter_lua::language()),
1519 )
1520 .with_embedding_query(
1521 r#"
1522 (
1523 (comment)* @context
1524 .
1525 (function_declaration
1526 "function" @name
1527 name: (_) @name
1528 (comment)* @collapse
1529 body: (block) @collapse
1530 ) @item
1531 )
1532 "#,
1533 )
1534 .unwrap(),
1535 )
1536}
1537
1538fn php_lang() -> Arc<Language> {
1539 Arc::new(
1540 Language::new(
1541 LanguageConfig {
1542 name: "PHP".into(),
1543 path_suffixes: vec!["php".into()],
1544 collapsed_placeholder: "/* ... */".into(),
1545 ..Default::default()
1546 },
1547 Some(tree_sitter_php::language()),
1548 )
1549 .with_embedding_query(
1550 r#"
1551 (
1552 (comment)* @context
1553 .
1554 [
1555 (function_definition
1556 "function" @name
1557 name: (_) @name
1558 body: (_
1559 "{" @keep
1560 "}" @keep) @collapse
1561 )
1562
1563 (trait_declaration
1564 "trait" @name
1565 name: (_) @name)
1566
1567 (method_declaration
1568 "function" @name
1569 name: (_) @name
1570 body: (_
1571 "{" @keep
1572 "}" @keep) @collapse
1573 )
1574
1575 (interface_declaration
1576 "interface" @name
1577 name: (_) @name
1578 )
1579
1580 (enum_declaration
1581 "enum" @name
1582 name: (_) @name
1583 )
1584
1585 ] @item
1586 )
1587 "#,
1588 )
1589 .unwrap(),
1590 )
1591}
1592
1593fn ruby_lang() -> Arc<Language> {
1594 Arc::new(
1595 Language::new(
1596 LanguageConfig {
1597 name: "Ruby".into(),
1598 path_suffixes: vec!["rb".into()],
1599 collapsed_placeholder: "# ...".to_string(),
1600 ..Default::default()
1601 },
1602 Some(tree_sitter_ruby::language()),
1603 )
1604 .with_embedding_query(
1605 r#"
1606 (
1607 (comment)* @context
1608 .
1609 [
1610 (module
1611 "module" @name
1612 name: (_) @name)
1613 (method
1614 "def" @name
1615 name: (_) @name
1616 body: (body_statement) @collapse)
1617 (class
1618 "class" @name
1619 name: (_) @name)
1620 (singleton_method
1621 "def" @name
1622 object: (_) @name
1623 "." @name
1624 name: (_) @name
1625 body: (body_statement) @collapse)
1626 ] @item
1627 )
1628 "#,
1629 )
1630 .unwrap(),
1631 )
1632}
1633
1634fn elixir_lang() -> Arc<Language> {
1635 Arc::new(
1636 Language::new(
1637 LanguageConfig {
1638 name: "Elixir".into(),
1639 path_suffixes: vec!["rs".into()],
1640 ..Default::default()
1641 },
1642 Some(tree_sitter_elixir::language()),
1643 )
1644 .with_embedding_query(
1645 r#"
1646 (
1647 (unary_operator
1648 operator: "@"
1649 operand: (call
1650 target: (identifier) @unary
1651 (#match? @unary "^(doc)$"))
1652 ) @context
1653 .
1654 (call
1655 target: (identifier) @name
1656 (arguments
1657 [
1658 (identifier) @name
1659 (call
1660 target: (identifier) @name)
1661 (binary_operator
1662 left: (call
1663 target: (identifier) @name)
1664 operator: "when")
1665 ])
1666 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1667 )
1668
1669 (call
1670 target: (identifier) @name
1671 (arguments (alias) @name)
1672 (#match? @name "^(defmodule|defprotocol)$")) @item
1673 "#,
1674 )
1675 .unwrap(),
1676 )
1677}
1678
1679#[gpui::test]
1680fn test_subtract_ranges() {
1681 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1682
1683 assert_eq!(
1684 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1685 vec![1..4, 10..21]
1686 );
1687
1688 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1689}
1690
1691fn init_test(cx: &mut TestAppContext) {
1692 cx.update(|cx| {
1693 cx.set_global(SettingsStore::test(cx));
1694 settings::register::<SemanticIndexSettings>(cx);
1695 settings::register::<ProjectSettings>(cx);
1696 });
1697}