1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::test::FakeEmbeddingProvider;
8
9use gpui::{Task, TestAppContext};
10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
11use parking_lot::Mutex;
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::{Settings, SettingsStore};
17use std::{path::Path, sync::Arc, time::SystemTime};
18use unindent::Unindent;
19use util::{paths::PathMatcher, RandomCharIter};
20
21#[ctor::ctor]
22fn init_logger() {
23 if std::env::var("RUST_LOG").is_ok() {
24 env_logger::init();
25 }
26}
27
// End-to-end test of `SemanticIndex` over a fake project containing two Rust
// files and one TOML file. It covers, in order:
//   1. an unfiltered search, during which the pending-file count is observed
//      at 3 and then at 0 once the embedding-queue flush timeout has elapsed;
//   2. the same query with a "*.rs" matcher passed as the include list, then
//      as the exclude list;
//   3. re-indexing after `file2.rs` is rewritten, which must report one
//      pending file and produce exactly one additional embedding.
28#[gpui::test]
29async fn test_semantic_index(cx: &mut TestAppContext) {
30 init_test(cx);
31
32 let fs = FakeFs::new(cx.background_executor.clone());
33 fs.insert_tree(
34 "/the-root",
35 json!({
36 "src": {
37 "file1.rs": "
38 fn aaa() {
39 println!(\"aaaaaaaaaaaa!\");
40 }
41
42 fn zzzzz() {
43 println!(\"SLEEPING\");
44 }
45 ".unindent(),
46 "file2.rs": "
47 fn bbb() {
48 println!(\"bbbbbbbbbbbbb!\");
49 }
50 struct pqpqpqp {}
51 ".unindent(),
52 "file3.toml": "
53 ZZZZZZZZZZZZZZZZZZ = 5
54 ".unindent(),
55 }
56 }),
57 )
58 .await;
59
60 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
61 let rust_language = rust_lang();
62 let toml_language = toml_lang();
63 languages.add(rust_language);
64 languages.add(toml_language);
65
// The index database lives at `db.sqlite` inside a fresh temporary directory.
66 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
67 let db_path = db_dir.path().join("db.sqlite");
68
69 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
70 let semantic_index = SemanticIndex::new(
71 fs.clone(),
72 db_path,
73 embedding_provider.clone(),
74 languages,
75 cx.to_async(),
76 )
77 .await
78 .unwrap();
79
80 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
81
// The search is started here but only awaited after the pending-file counter
// has been checked before and after the flush timeout.
82 let search_results = semantic_index.update(cx, |store, cx| {
83 store.search_project(
84 project.clone(),
85 "aaaaaabbbbzz".to_string(),
86 5,
87 vec![],
88 vec![],
89 cx,
90 )
91 });
92 let pending_file_count =
93 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
94 cx.background_executor.run_until_parked();
95 assert_eq!(*pending_file_count.borrow(), 3);
96 cx.background_executor
97 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
98 assert_eq!(*pending_file_count.borrow(), 0);
99
100 let search_results = search_results.await.unwrap();
101 assert_search_results(
102 &search_results,
103 &[
104 (Path::new("src/file1.rs").into(), 0),
105 (Path::new("src/file2.rs").into(), 0),
106 (Path::new("src/file3.toml").into(), 0),
107 (Path::new("src/file1.rs").into(), 45),
108 (Path::new("src/file2.rs").into(), 45),
109 ],
110 cx,
111 );
112
113 // Test Include Files Functionality
// Both matchers use the same "*.rs" glob: `include_files` is passed as the
// fourth argument of `search_project` below, `exclude_files` as the fifth.
114 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
115 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
116 let rust_only_search_results = semantic_index
117 .update(cx, |store, cx| {
118 store.search_project(
119 project.clone(),
120 "aaaaaabbbbzz".to_string(),
121 5,
122 include_files,
123 vec![],
124 cx,
125 )
126 })
127 .await
128 .unwrap();
129
130 assert_search_results(
131 &rust_only_search_results,
132 &[
133 (Path::new("src/file1.rs").into(), 0),
134 (Path::new("src/file2.rs").into(), 0),
135 (Path::new("src/file1.rs").into(), 45),
136 (Path::new("src/file2.rs").into(), 45),
137 ],
138 cx,
139 );
140
141 let no_rust_search_results = semantic_index
142 .update(cx, |store, cx| {
143 store.search_project(
144 project.clone(),
145 "aaaaaabbbbzz".to_string(),
146 5,
147 vec![],
148 exclude_files,
149 cx,
150 )
151 })
152 .await
153 .unwrap();
154
155 assert_search_results(
156 &no_rust_search_results,
157 &[(Path::new("src/file3.toml").into(), 0)],
158 cx,
159 );
160
// Rewrite file2.rs: `fn bbb` is replaced by `fn dddd`, while the
// `struct pqpqpqp {}` line is kept.
161 fs.save(
162 "/the-root/src/file2.rs".as_ref(),
163 &"
164 fn dddd() { println!(\"ddddd!\"); }
165 struct pqpqpqp {}
166 "
167 .unindent()
168 .into(),
169 Default::default(),
170 )
171 .await
172 .unwrap();
173
174 cx.background_executor
175 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
176
177 let prev_embedding_count = embedding_provider.embedding_count();
178 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
179 cx.background_executor.run_until_parked();
180 assert_eq!(*pending_file_count.borrow(), 1);
181 cx.background_executor
182 .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
183 assert_eq!(*pending_file_count.borrow(), 0);
184 index.await.unwrap();
185
// Exactly one new embedding is expected. NOTE(review): presumably the
// unchanged struct span is not embedded again — confirm against the index.
186 assert_eq!(
187 embedding_provider.embedding_count() - prev_embedding_count,
188 1
189 );
190}
191
192#[gpui::test(iterations = 10)]
193async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
194 let (outstanding_job_count, _) = postage::watch::channel_with(0);
195 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
196
197 let files = (1..=3)
198 .map(|file_ix| FileToEmbed {
199 worktree_id: 5,
200 path: Path::new(&format!("path-{file_ix}")).into(),
201 mtime: SystemTime::now(),
202 spans: (0..rng.gen_range(4..22))
203 .map(|document_ix| {
204 let content_len = rng.gen_range(10..100);
205 let content = RandomCharIter::new(&mut rng)
206 .with_simple_text()
207 .take(content_len)
208 .collect::<String>();
209 let digest = SpanDigest::from(content.as_str());
210 Span {
211 range: 0..10,
212 embedding: None,
213 name: format!("document {document_ix}"),
214 content,
215 digest,
216 token_count: rng.gen_range(10..30),
217 }
218 })
219 .collect(),
220 job_handle: JobHandle::new(&outstanding_job_count),
221 })
222 .collect::<Vec<_>>();
223
224 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
225
226 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
227 for file in &files {
228 queue.push(file.clone());
229 }
230 queue.flush();
231
232 cx.background_executor.run_until_parked();
233 let finished_files = queue.finished_files();
234 let mut embedded_files: Vec<_> = files
235 .iter()
236 .map(|_| finished_files.try_recv().expect("no finished file"))
237 .collect();
238
239 let expected_files: Vec<_> = files
240 .iter()
241 .map(|file| {
242 let mut file = file.clone();
243 for doc in &mut file.spans {
244 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
245 }
246 file
247 })
248 .collect();
249
250 embedded_files.sort_by_key(|f| f.path.clone());
251
252 assert_eq!(embedded_files, expected_files);
253}
254
255#[track_caller]
256fn assert_search_results(
257 actual: &[SearchResult],
258 expected: &[(Arc<Path>, usize)],
259 cx: &TestAppContext,
260) {
261 let actual = actual
262 .iter()
263 .map(|search_result| {
264 search_result.buffer.read_with(cx, |buffer, _cx| {
265 (
266 buffer.file().unwrap().path().clone(),
267 search_result.range.start.to_offset(buffer),
268 )
269 })
270 })
271 .collect::<Vec<_>>();
272 assert_eq!(actual, expected);
273}
274
// Parses a Rust snippet with `CodeContextRetriever::parse_file` and checks the
// content and start offset of every returned document. Per the expectations
// below:
//   - leading comments and attributes are part of a document's content;
//   - in the `impl E` document the method bodies appear as ` /* ... */ `,
//     while each method also appears as its own document with its full body.
275#[gpui::test]
276async fn test_code_context_retrieval_rust() {
277 let language = rust_lang();
278 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
279 let mut retriever = CodeContextRetriever::new(embedding_provider);
280
281 let text = "
282 /// A doc comment
283 /// that spans multiple lines
284 #[gpui::test]
285 fn a() {
286 b
287 }
288
289 impl C for D {
290 }
291
292 impl E {
293 // This is also a preceding comment
294 pub fn function_1() -> Option<()> {
295 unimplemented!();
296 }
297
298 // This is a preceding comment
299 fn function_2() -> Result<()> {
300 unimplemented!();
301 }
302 }
303
304 #[derive(Clone)]
305 struct D {
306 name: String
307 }
308 "
309 .unindent();
310
311 let documents = retriever.parse_file(&text, language).unwrap();
312
// Expected start offsets are computed with `text.find` on the item keyword,
// so they point at the item itself rather than at its leading comment or
// attribute.
313 assert_documents_eq(
314 &documents,
315 &[
316 (
317 "
318 /// A doc comment
319 /// that spans multiple lines
320 #[gpui::test]
321 fn a() {
322 b
323 }"
324 .unindent(),
325 text.find("fn a").unwrap(),
326 ),
327 (
328 "
329 impl C for D {
330 }"
331 .unindent(),
332 text.find("impl C").unwrap(),
333 ),
334 (
335 "
336 impl E {
337 // This is also a preceding comment
338 pub fn function_1() -> Option<()> { /* ... */ }
339
340 // This is a preceding comment
341 fn function_2() -> Result<()> { /* ... */ }
342 }"
343 .unindent(),
344 text.find("impl E").unwrap(),
345 ),
346 (
347 "
348 // This is also a preceding comment
349 pub fn function_1() -> Option<()> {
350 unimplemented!();
351 }"
352 .unindent(),
353 text.find("pub fn function_1").unwrap(),
354 ),
355 (
356 "
357 // This is a preceding comment
358 fn function_2() -> Result<()> {
359 unimplemented!();
360 }"
361 .unindent(),
362 text.find("fn function_2").unwrap(),
363 ),
364 (
365 "
366 #[derive(Clone)]
367 struct D {
368 name: String
369 }"
370 .unindent(),
371 text.find("struct D").unwrap(),
372 ),
373 ],
374 );
375}
376
// Parses two JSON inputs with `CodeContextRetriever::parse_file` and checks
// the single document produced for each. Per the expectations below:
//   - for the object input, arrays appear as `[]` and string values as `""`,
//     while booleans, nulls and numbers are unchanged;
//   - for the array-of-objects input, only the first object remains inside
//     the brackets.
377#[gpui::test]
378async fn test_code_context_retrieval_json() {
379 let language = json_lang();
380 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
381 let mut retriever = CodeContextRetriever::new(embedding_provider);
382
383 let text = r#"
384 {
385 "array": [1, 2, 3, 4],
386 "string": "abcdefg",
387 "nested_object": {
388 "array_2": [5, 6, 7, 8],
389 "string_2": "hijklmnop",
390 "boolean": true,
391 "none": null
392 }
393 }
394 "#
395 .unindent();
396
397 let documents = retriever.parse_file(&text, language.clone()).unwrap();
398
399 assert_documents_eq(
400 &documents,
401 &[(
402 r#"
403 {
404 "array": [],
405 "string": "",
406 "nested_object": {
407 "array_2": [],
408 "string_2": "",
409 "boolean": true,
410 "none": null
411 }
412 }"#
413 .unindent(),
414 text.find("{").unwrap(),
415 )],
416 );
417
// Second case: the top-level value is an array of objects.
418 let text = r#"
419 [
420 {
421 "name": "somebody",
422 "age": 42
423 },
424 {
425 "name": "somebody else",
426 "age": 43
427 }
428 ]
429 "#
430 .unindent();
431
432 let documents = retriever.parse_file(&text, language.clone()).unwrap();
433
434 assert_documents_eq(
435 &documents,
436 &[(
437 r#"
438 [{
439 "name": "",
440 "age": 42
441 }]"#
442 .unindent(),
443 text.find("[").unwrap(),
444 )],
445 );
446}
447
448fn assert_documents_eq(
449 documents: &[Span],
450 expected_contents_and_start_offsets: &[(String, usize)],
451) {
452 assert_eq!(
453 documents
454 .iter()
455 .map(|document| (document.content.clone(), document.range.start))
456 .collect::<Vec<_>>(),
457 expected_contents_and_start_offsets
458 );
459}
460
// Parses a JavaScript/TypeScript snippet with
// `CodeContextRetriever::parse_file` and checks the six expected documents:
// a plain function, an exported async function, an exported class, the
// class's constructor, a non-exported class, and an exported interface.
// Each expected content includes the comment that precedes the item.
461#[gpui::test]
462async fn test_code_context_retrieval_javascript() {
463 let language = js_lang();
464 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
465 let mut retriever = CodeContextRetriever::new(embedding_provider);
466
467 let text = "
468 /* globals importScripts, backend */
469 function _authorize() {}
470
471 /**
472 * Sometimes the frontend build is way faster than backend.
473 */
474 export async function authorizeBank() {
475 _authorize(pushModal, upgradingAccountId, {});
476 }
477
478 export class SettingsPage {
479 /* This is a test setting */
480 constructor(page) {
481 this.page = page;
482 }
483 }
484
485 /* This is a test comment */
486 class TestClass {}
487
488 /* Schema for editor_events in Clickhouse. */
489 export interface ClickhouseEditorEvent {
490 installation_id: string
491 operation: string
492 }
493 "
494 .unindent();
495
496 let documents = retriever.parse_file(&text, language.clone()).unwrap();
497
// Expected start offsets are hard-coded positions into `text`; they must be
// updated whenever the fixture above changes.
498 assert_documents_eq(
499 &documents,
500 &[
501 (
502 "
503 /* globals importScripts, backend */
504 function _authorize() {}"
505 .unindent(),
506 37,
507 ),
508 (
509 "
510 /**
511 * Sometimes the frontend build is way faster than backend.
512 */
513 export async function authorizeBank() {
514 _authorize(pushModal, upgradingAccountId, {});
515 }"
516 .unindent(),
517 131,
518 ),
519 (
520 "
521 export class SettingsPage {
522 /* This is a test setting */
523 constructor(page) {
524 this.page = page;
525 }
526 }"
527 .unindent(),
528 225,
529 ),
530 (
531 "
532 /* This is a test setting */
533 constructor(page) {
534 this.page = page;
535 }"
536 .unindent(),
537 290,
538 ),
539 (
540 "
541 /* This is a test comment */
542 class TestClass {}"
543 .unindent(),
544 374,
545 ),
546 (
547 "
548 /* Schema for editor_events in Clickhouse. */
549 export interface ClickhouseEditorEvent {
550 installation_id: string
551 operation: string
552 }"
553 .unindent(),
554 440,
555 ),
556 ],
557 )
558}
559
// Parses a Lua snippet containing a function nested inside another function
// and checks the two expected documents: the outer function, in which the
// nested function's body is shown as `--[ ... ]--` placeholders, and the
// nested function on its own with its full body. Both include their
// preceding comment lines.
560#[gpui::test]
561async fn test_code_context_retrieval_lua() {
562 let language = lua_lang();
563 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
564 let mut retriever = CodeContextRetriever::new(embedding_provider);
565
566 let text = r#"
567 -- Creates a new class
568 -- @param baseclass The Baseclass of this class, or nil.
569 -- @return A new class reference.
570 function classes.class(baseclass)
571 -- Create the class definition and metatable.
572 local classdef = {}
573 -- Find the super class, either Object or user-defined.
574 baseclass = baseclass or classes.Object
575 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
576 setmetatable(classdef, { __index = baseclass })
577 -- All class instances have a reference to the class object.
578 classdef.class = classdef
579 --- Recursively allocates the inheritance tree of the instance.
580 -- @param mastertable The 'root' of the inheritance tree.
581 -- @return Returns the instance with the allocated inheritance tree.
582 function classdef.alloc(mastertable)
583 -- All class instances have a reference to a superclass object.
584 local instance = { super = baseclass.alloc(mastertable) }
585 -- Any functions this instance does not know of will 'look up' to the superclass definition.
586 setmetatable(instance, { __index = classdef, __newindex = mastertable })
587 return instance
588 end
589 end
590 "#.unindent();
591
592 let documents = retriever.parse_file(&text, language.clone()).unwrap();
593
// Expected start offsets (114 and 810) are hard-coded positions into `text`.
594 assert_documents_eq(
595 &documents,
596 &[
597 (r#"
598 -- Creates a new class
599 -- @param baseclass The Baseclass of this class, or nil.
600 -- @return A new class reference.
601 function classes.class(baseclass)
602 -- Create the class definition and metatable.
603 local classdef = {}
604 -- Find the super class, either Object or user-defined.
605 baseclass = baseclass or classes.Object
606 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
607 setmetatable(classdef, { __index = baseclass })
608 -- All class instances have a reference to the class object.
609 classdef.class = classdef
610 --- Recursively allocates the inheritance tree of the instance.
611 -- @param mastertable The 'root' of the inheritance tree.
612 -- @return Returns the instance with the allocated inheritance tree.
613 function classdef.alloc(mastertable)
614 --[ ... ]--
615 --[ ... ]--
616 end
617 end"#.unindent(),
618 114),
619 (r#"
620 --- Recursively allocates the inheritance tree of the instance.
621 -- @param mastertable The 'root' of the inheritance tree.
622 -- @return Returns the instance with the allocated inheritance tree.
623 function classdef.alloc(mastertable)
624 -- All class instances have a reference to a superclass object.
625 local instance = { super = baseclass.alloc(mastertable) }
626 -- Any functions this instance does not know of will 'look up' to the superclass definition.
627 setmetatable(instance, { __index = classdef, __newindex = mastertable })
628 return instance
629 end"#.unindent(), 810),
630 ]
631 );
632}
633
// Parses an Elixir module with `CodeContextRetriever::parse_file` and checks
// the two expected documents: the whole `defmodule` at offset 0, reproduced
// in full, and the `__build__` function together with its `@doc false`
// attribute at offset 574.
634#[gpui::test]
635async fn test_code_context_retrieval_elixir() {
636 let language = elixir_lang();
637 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
638 let mut retriever = CodeContextRetriever::new(embedding_provider);
639
640 let text = r#"
641 defmodule File.Stream do
642 @moduledoc """
643 Defines a `File.Stream` struct returned by `File.stream!/3`.
644
645 The following fields are public:
646
647 * `path` - the file path
648 * `modes` - the file modes
649 * `raw` - a boolean indicating if bin functions should be used
650 * `line_or_bytes` - if reading should read lines or a given number of bytes
651 * `node` - the node the file belongs to
652
653 """
654
655 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
656
657 @type t :: %__MODULE__{}
658
659 @doc false
660 def __build__(path, modes, line_or_bytes) do
661 raw = :lists.keyfind(:encoding, 1, modes) == false
662
663 modes =
664 case raw do
665 true ->
666 case :lists.keyfind(:read_ahead, 1, modes) do
667 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
668 {:read_ahead, _} -> [:raw | modes]
669 false -> [:raw, :read_ahead | modes]
670 end
671
672 false ->
673 modes
674 end
675
676 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
677
678 end"#
679 .unindent();
680
681 let documents = retriever.parse_file(&text, language.clone()).unwrap();
682
// The second expected offset (574) is a hard-coded position into `text`.
683 assert_documents_eq(
684 &documents,
685 &[(
686 r#"
687 defmodule File.Stream do
688 @moduledoc """
689 Defines a `File.Stream` struct returned by `File.stream!/3`.
690
691 The following fields are public:
692
693 * `path` - the file path
694 * `modes` - the file modes
695 * `raw` - a boolean indicating if bin functions should be used
696 * `line_or_bytes` - if reading should read lines or a given number of bytes
697 * `node` - the node the file belongs to
698
699 """
700
701 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
702
703 @type t :: %__MODULE__{}
704
705 @doc false
706 def __build__(path, modes, line_or_bytes) do
707 raw = :lists.keyfind(:encoding, 1, modes) == false
708
709 modes =
710 case raw do
711 true ->
712 case :lists.keyfind(:read_ahead, 1, modes) do
713 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
714 {:read_ahead, _} -> [:raw | modes]
715 false -> [:raw, :read_ahead | modes]
716 end
717
718 false ->
719 modes
720 end
721
722 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
723
724 end"#
725 .unindent(),
726 0,
727 ),(r#"
728 @doc false
729 def __build__(path, modes, line_or_bytes) do
730 raw = :lists.keyfind(:encoding, 1, modes) == false
731
732 modes =
733 case raw do
734 true ->
735 case :lists.keyfind(:read_ahead, 1, modes) do
736 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
737 {:read_ahead, _} -> [:raw | modes]
738 false -> [:raw, :read_ahead | modes]
739 end
740
741 false ->
742 modes
743 end
744
745 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
746
747 end"#.unindent(), 574)],
748 );
749}
750
// Parses a C++ snippet with `CodeContextRetriever::parse_file` and checks the
// six expected documents: a function, a class, an enum, an anonymous struct
// with a declarator, a templated class, and the templated constructor inside
// that class. Per the expectations below, the class and enum documents end
// before the trailing `;` present in the input, whereas the struct document
// keeps `} myStructure;`, and the constructor document carries neither its
// doc comment nor its `template <...>` header.
751#[gpui::test]
752async fn test_code_context_retrieval_cpp() {
753 let language = cpp_lang();
754 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
755 let mut retriever = CodeContextRetriever::new(embedding_provider);
756
757 let text = "
758 /**
759 * @brief Main function
760 * @returns 0 on exit
761 */
762 int main() { return 0; }
763
764 /**
765 * This is a test comment
766 */
767 class MyClass { // The class
768 public: // Access specifier
769 int myNum; // Attribute (int variable)
770 string myString; // Attribute (string variable)
771 };
772
773 // This is a test comment
774 enum Color { red, green, blue };
775
776 /** This is a preceding block comment
777 * This is the second line
778 */
779 struct { // Structure declaration
780 int myNum; // Member (int variable)
781 string myString; // Member (string variable)
782 } myStructure;
783
784 /**
785 * @brief Matrix class.
786 */
787 template <typename T,
788 typename = typename std::enable_if<
789 std::is_integral<T>::value || std::is_floating_point<T>::value,
790 bool>::type>
791 class Matrix2 {
792 std::vector<std::vector<T>> _mat;
793
794 public:
795 /**
796 * @brief Constructor
797 * @tparam Integer ensuring integers are being evaluated and not other
798 * data types.
799 * @param size denoting the size of Matrix as size x size
800 */
801 template <typename Integer,
802 typename = typename std::enable_if<std::is_integral<Integer>::value,
803 Integer>::type>
804 explicit Matrix(const Integer size) {
805 for (size_t i = 0; i < size; ++i) {
806 _mat.emplace_back(std::vector<T>(size, 0));
807 }
808 }
809 }"
810 .unindent();
811
812 let documents = retriever.parse_file(&text, language.clone()).unwrap();
813
// Expected start offsets are hard-coded positions into `text`.
814 assert_documents_eq(
815 &documents,
816 &[
817 (
818 "
819 /**
820 * @brief Main function
821 * @returns 0 on exit
822 */
823 int main() { return 0; }"
824 .unindent(),
825 54,
826 ),
827 (
828 "
829 /**
830 * This is a test comment
831 */
832 class MyClass { // The class
833 public: // Access specifier
834 int myNum; // Attribute (int variable)
835 string myString; // Attribute (string variable)
836 }"
837 .unindent(),
838 112,
839 ),
840 (
841 "
842 // This is a test comment
843 enum Color { red, green, blue }"
844 .unindent(),
845 322,
846 ),
847 (
848 "
849 /** This is a preceding block comment
850 * This is the second line
851 */
852 struct { // Structure declaration
853 int myNum; // Member (int variable)
854 string myString; // Member (string variable)
855 } myStructure;"
856 .unindent(),
857 425,
858 ),
859 (
860 "
861 /**
862 * @brief Matrix class.
863 */
864 template <typename T,
865 typename = typename std::enable_if<
866 std::is_integral<T>::value || std::is_floating_point<T>::value,
867 bool>::type>
868 class Matrix2 {
869 std::vector<std::vector<T>> _mat;
870
871 public:
872 /**
873 * @brief Constructor
874 * @tparam Integer ensuring integers are being evaluated and not other
875 * data types.
876 * @param size denoting the size of Matrix as size x size
877 */
878 template <typename Integer,
879 typename = typename std::enable_if<std::is_integral<Integer>::value,
880 Integer>::type>
881 explicit Matrix(const Integer size) {
882 for (size_t i = 0; i < size; ++i) {
883 _mat.emplace_back(std::vector<T>(size, 0));
884 }
885 }
886 }"
887 .unindent(),
888 612,
889 ),
890 (
891 "
892 explicit Matrix(const Integer size) {
893 for (size_t i = 0; i < size; ++i) {
894 _mat.emplace_back(std::vector<T>(size, 0));
895 }
896 }"
897 .unindent(),
898 1226,
899 ),
900 ],
901 );
902}
903
// Parses a Ruby snippet with `CodeContextRetriever::parse_file` and checks the
// seven expected documents. Per the expectations below:
//   - the module and class documents show each method body as `# ...`;
//   - every method inside them also appears as its own document with its
//     full body;
//   - the singleton method `car.wheels` appears with its preceding comment.
904#[gpui::test]
905async fn test_code_context_retrieval_ruby() {
906 let language = ruby_lang();
907 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
908 let mut retriever = CodeContextRetriever::new(embedding_provider);
909
910 let text = r#"
911 # This concern is inspired by "sudo mode" on GitHub. It
912 # is a way to re-authenticate a user before allowing them
913 # to see or perform an action.
914 #
915 # Add `before_action :require_challenge!` to actions you
916 # want to protect.
917 #
918 # The user will be shown a page to enter the challenge (which
919 # is either the password, or just the username when no
920 # password exists). Upon passing, there is a grace period
921 # during which no challenge will be asked from the user.
922 #
923 # Accessing challenge-protected resources during the grace
924 # period will refresh the grace period.
925 module ChallengableConcern
926 extend ActiveSupport::Concern
927
928 CHALLENGE_TIMEOUT = 1.hour.freeze
929
930 def require_challenge!
931 return if skip_challenge?
932
933 if challenge_passed_recently?
934 session[:challenge_passed_at] = Time.now.utc
935 return
936 end
937
938 @challenge = Form::Challenge.new(return_to: request.url)
939
940 if params.key?(:form_challenge)
941 if challenge_passed?
942 session[:challenge_passed_at] = Time.now.utc
943 else
944 flash.now[:alert] = I18n.t('challenge.invalid_password')
945 render_challenge
946 end
947 else
948 render_challenge
949 end
950 end
951
952 def challenge_passed?
953 current_user.valid_password?(challenge_params[:current_password])
954 end
955 end
956
957 class Animal
958 include Comparable
959
960 attr_reader :legs
961
962 def initialize(name, legs)
963 @name, @legs = name, legs
964 end
965
966 def <=>(other)
967 legs <=> other.legs
968 end
969 end
970
971 # Singleton method for car object
972 def car.wheels
973 puts "There are four wheels"
974 end"#
975 .unindent();
976
977 let documents = retriever.parse_file(&text, language.clone()).unwrap();
978
// Expected start offsets are hard-coded positions into `text`.
979 assert_documents_eq(
980 &documents,
981 &[
982 (
983 r#"
984 # This concern is inspired by "sudo mode" on GitHub. It
985 # is a way to re-authenticate a user before allowing them
986 # to see or perform an action.
987 #
988 # Add `before_action :require_challenge!` to actions you
989 # want to protect.
990 #
991 # The user will be shown a page to enter the challenge (which
992 # is either the password, or just the username when no
993 # password exists). Upon passing, there is a grace period
994 # during which no challenge will be asked from the user.
995 #
996 # Accessing challenge-protected resources during the grace
997 # period will refresh the grace period.
998 module ChallengableConcern
999 extend ActiveSupport::Concern
1000
1001 CHALLENGE_TIMEOUT = 1.hour.freeze
1002
1003 def require_challenge!
1004 # ...
1005 end
1006
1007 def challenge_passed?
1008 # ...
1009 end
1010 end"#
1011 .unindent(),
1012 558,
1013 ),
1014 (
1015 r#"
1016 def require_challenge!
1017 return if skip_challenge?
1018
1019 if challenge_passed_recently?
1020 session[:challenge_passed_at] = Time.now.utc
1021 return
1022 end
1023
1024 @challenge = Form::Challenge.new(return_to: request.url)
1025
1026 if params.key?(:form_challenge)
1027 if challenge_passed?
1028 session[:challenge_passed_at] = Time.now.utc
1029 else
1030 flash.now[:alert] = I18n.t('challenge.invalid_password')
1031 render_challenge
1032 end
1033 else
1034 render_challenge
1035 end
1036 end"#
1037 .unindent(),
1038 663,
1039 ),
1040 (
1041 r#"
1042 def challenge_passed?
1043 current_user.valid_password?(challenge_params[:current_password])
1044 end"#
1045 .unindent(),
1046 1254,
1047 ),
1048 (
1049 r#"
1050 class Animal
1051 include Comparable
1052
1053 attr_reader :legs
1054
1055 def initialize(name, legs)
1056 # ...
1057 end
1058
1059 def <=>(other)
1060 # ...
1061 end
1062 end"#
1063 .unindent(),
1064 1363,
1065 ),
1066 (
1067 r#"
1068 def initialize(name, legs)
1069 @name, @legs = name, legs
1070 end"#
1071 .unindent(),
1072 1427,
1073 ),
1074 (
1075 r#"
1076 def <=>(other)
1077 legs <=> other.legs
1078 end"#
1079 .unindent(),
1080 1501,
1081 ),
1082 (
1083 r#"
1084 # Singleton method for car object
1085 def car.wheels
1086 puts "There are four wheels"
1087 end"#
1088 .unindent(),
1089 1591,
1090 ),
1091 ],
1092 );
1093}
1094
// Parses a PHP snippet with `CodeContextRetriever::parse_file` and checks the
// six expected documents: a free function with its preceding block comment,
// a trait whose method bodies appear as `{/* ... */}`, each of the trait's
// two methods with its full body, an interface, and an enum.
1095#[gpui::test]
1096async fn test_code_context_retrieval_php() {
1097 let language = php_lang();
1098 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1099 let mut retriever = CodeContextRetriever::new(embedding_provider);
1100
1101 let text = r#"
1102 <?php
1103
1104 namespace LevelUp\Experience\Concerns;
1105
1106 /*
1107 This is a multiple-lines comment block
1108 that spans over multiple
1109 lines
1110 */
1111 function functionName() {
1112 echo "Hello world!";
1113 }
1114
1115 trait HasAchievements
1116 {
1117 /**
1118 * @throws \Exception
1119 */
1120 public function grantAchievement(Achievement $achievement, $progress = null): void
1121 {
1122 if ($progress > 100) {
1123 throw new Exception(message: 'Progress cannot be greater than 100');
1124 }
1125
1126 if ($this->achievements()->find($achievement->id)) {
1127 throw new Exception(message: 'User already has this Achievement');
1128 }
1129
1130 $this->achievements()->attach($achievement, [
1131 'progress' => $progress ?? null,
1132 ]);
1133
1134 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1135 }
1136
1137 public function achievements(): BelongsToMany
1138 {
1139 return $this->belongsToMany(related: Achievement::class)
1140 ->withPivot(columns: 'progress')
1141 ->where('is_secret', false)
1142 ->using(AchievementUser::class);
1143 }
1144 }
1145
1146 interface Multiplier
1147 {
1148 public function qualifies(array $data): bool;
1149
1150 public function setMultiplier(): int;
1151 }
1152
1153 enum AuditType: string
1154 {
1155 case Add = 'add';
1156 case Remove = 'remove';
1157 case Reset = 'reset';
1158 case LevelUp = 'level_up';
1159 }
1160
1161 ?>"#
1162 .unindent();
1163
1164 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1165
// Expected start offsets are hard-coded positions into `text`.
1166 assert_documents_eq(
1167 &documents,
1168 &[
1169 (
1170 r#"
1171 /*
1172 This is a multiple-lines comment block
1173 that spans over multiple
1174 lines
1175 */
1176 function functionName() {
1177 echo "Hello world!";
1178 }"#
1179 .unindent(),
1180 123,
1181 ),
1182 (
1183 r#"
1184 trait HasAchievements
1185 {
1186 /**
1187 * @throws \Exception
1188 */
1189 public function grantAchievement(Achievement $achievement, $progress = null): void
1190 {/* ... */}
1191
1192 public function achievements(): BelongsToMany
1193 {/* ... */}
1194 }"#
1195 .unindent(),
1196 177,
1197 ),
1198 (r#"
1199 /**
1200 * @throws \Exception
1201 */
1202 public function grantAchievement(Achievement $achievement, $progress = null): void
1203 {
1204 if ($progress > 100) {
1205 throw new Exception(message: 'Progress cannot be greater than 100');
1206 }
1207
1208 if ($this->achievements()->find($achievement->id)) {
1209 throw new Exception(message: 'User already has this Achievement');
1210 }
1211
1212 $this->achievements()->attach($achievement, [
1213 'progress' => $progress ?? null,
1214 ]);
1215
1216 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1217 }"#.unindent(), 245),
1218 (r#"
1219 public function achievements(): BelongsToMany
1220 {
1221 return $this->belongsToMany(related: Achievement::class)
1222 ->withPivot(columns: 'progress')
1223 ->where('is_secret', false)
1224 ->using(AchievementUser::class);
1225 }"#.unindent(), 902),
1226 (r#"
1227 interface Multiplier
1228 {
1229 public function qualifies(array $data): bool;
1230
1231 public function setMultiplier(): int;
1232 }"#.unindent(),
1233 1146),
1234 (r#"
1235 enum AuditType: string
1236 {
1237 case Add = 'add';
1238 case Remove = 'remove';
1239 case Reset = 'reset';
1240 case LevelUp = 'level_up';
1241 }"#.unindent(), 1265)
1242 ],
1243 );
1244}
1245
// Builds the language fixture named "Javascript" (path suffix `js`), parsed
// with the tree-sitter TSX grammar. Its embedding query marks as `@item`:
// function, class, interface and enum declarations (bare or wrapped in an
// export statement) and method definitions. In every pattern the comments
// immediately preceding the item are captured as `@context`.
//
// Panics if the embedding query is rejected (`unwrap`).
1246fn js_lang() -> Arc<Language> {
1247 Arc::new(
1248 Language::new(
1249 LanguageConfig {
1250 name: "Javascript".into(),
1251 path_suffixes: vec!["js".into()],
1252 ..Default::default()
1253 },
1254 Some(tree_sitter_typescript::language_tsx()),
1255 )
1256 .with_embedding_query(
1257 &r#"
1258
1259 (
1260 (comment)* @context
1261 .
1262 [
1263 (export_statement
1264 (function_declaration
1265 "async"? @name
1266 "function" @name
1267 name: (_) @name))
1268 (function_declaration
1269 "async"? @name
1270 "function" @name
1271 name: (_) @name)
1272 ] @item
1273 )
1274
1275 (
1276 (comment)* @context
1277 .
1278 [
1279 (export_statement
1280 (class_declaration
1281 "class" @name
1282 name: (_) @name))
1283 (class_declaration
1284 "class" @name
1285 name: (_) @name)
1286 ] @item
1287 )
1288
1289 (
1290 (comment)* @context
1291 .
1292 [
1293 (export_statement
1294 (interface_declaration
1295 "interface" @name
1296 name: (_) @name))
1297 (interface_declaration
1298 "interface" @name
1299 name: (_) @name)
1300 ] @item
1301 )
1302
1303 (
1304 (comment)* @context
1305 .
1306 [
1307 (export_statement
1308 (enum_declaration
1309 "enum" @name
1310 name: (_) @name))
1311 (enum_declaration
1312 "enum" @name
1313 name: (_) @name)
1314 ] @item
1315 )
1316
1317 (
1318 (comment)* @context
1319 .
1320 (method_definition
1321 [
1322 "get"
1323 "set"
1324 "async"
1325 "*"
1326 "static"
1327 ]* @name
1328 name: (_) @name) @item
1329 )
1330
1331 "#
1332 .unindent(),
1333 )
1334 .unwrap(),
1335 )
1336}
1337
1338fn rust_lang() -> Arc<Language> {
1339 Arc::new(
1340 Language::new(
1341 LanguageConfig {
1342 name: "Rust".into(),
1343 path_suffixes: vec!["rs".into()],
1344 collapsed_placeholder: " /* ... */ ".to_string(),
1345 ..Default::default()
1346 },
1347 Some(tree_sitter_rust::language()),
1348 )
1349 .with_embedding_query(
1350 r#"
1351 (
1352 [(line_comment) (attribute_item)]* @context
1353 .
1354 [
1355 (struct_item
1356 name: (_) @name)
1357
1358 (enum_item
1359 name: (_) @name)
1360
1361 (impl_item
1362 trait: (_)? @name
1363 "for"? @name
1364 type: (_) @name)
1365
1366 (trait_item
1367 name: (_) @name)
1368
1369 (function_item
1370 name: (_) @name
1371 body: (block
1372 "{" @keep
1373 "}" @keep) @collapse)
1374
1375 (macro_definition
1376 name: (_) @name)
1377 ] @item
1378 )
1379
1380 (attribute_item) @collapse
1381 (use_declaration) @collapse
1382 "#,
1383 )
1384 .unwrap(),
1385 )
1386}
1387
1388fn json_lang() -> Arc<Language> {
1389 Arc::new(
1390 Language::new(
1391 LanguageConfig {
1392 name: "JSON".into(),
1393 path_suffixes: vec!["json".into()],
1394 ..Default::default()
1395 },
1396 Some(tree_sitter_json::language()),
1397 )
1398 .with_embedding_query(
1399 r#"
1400 (document) @item
1401
1402 (array
1403 "[" @keep
1404 .
1405 (object)? @keep
1406 "]" @keep) @collapse
1407
1408 (pair value: (string
1409 "\"" @keep
1410 "\"" @keep) @collapse)
1411 "#,
1412 )
1413 .unwrap(),
1414 )
1415}
1416
1417fn toml_lang() -> Arc<Language> {
1418 Arc::new(Language::new(
1419 LanguageConfig {
1420 name: "TOML".into(),
1421 path_suffixes: vec!["toml".into()],
1422 ..Default::default()
1423 },
1424 Some(tree_sitter_toml::language()),
1425 ))
1426}
1427
1428fn cpp_lang() -> Arc<Language> {
1429 Arc::new(
1430 Language::new(
1431 LanguageConfig {
1432 name: "CPP".into(),
1433 path_suffixes: vec!["cpp".into()],
1434 ..Default::default()
1435 },
1436 Some(tree_sitter_cpp::language()),
1437 )
1438 .with_embedding_query(
1439 r#"
1440 (
1441 (comment)* @context
1442 .
1443 (function_definition
1444 (type_qualifier)? @name
1445 type: (_)? @name
1446 declarator: [
1447 (function_declarator
1448 declarator: (_) @name)
1449 (pointer_declarator
1450 "*" @name
1451 declarator: (function_declarator
1452 declarator: (_) @name))
1453 (pointer_declarator
1454 "*" @name
1455 declarator: (pointer_declarator
1456 "*" @name
1457 declarator: (function_declarator
1458 declarator: (_) @name)))
1459 (reference_declarator
1460 ["&" "&&"] @name
1461 (function_declarator
1462 declarator: (_) @name))
1463 ]
1464 (type_qualifier)? @name) @item
1465 )
1466
1467 (
1468 (comment)* @context
1469 .
1470 (template_declaration
1471 (class_specifier
1472 "class" @name
1473 name: (_) @name)
1474 ) @item
1475 )
1476
1477 (
1478 (comment)* @context
1479 .
1480 (class_specifier
1481 "class" @name
1482 name: (_) @name) @item
1483 )
1484
1485 (
1486 (comment)* @context
1487 .
1488 (enum_specifier
1489 "enum" @name
1490 name: (_) @name) @item
1491 )
1492
1493 (
1494 (comment)* @context
1495 .
1496 (declaration
1497 type: (struct_specifier
1498 "struct" @name)
1499 declarator: (_) @name) @item
1500 )
1501
1502 "#,
1503 )
1504 .unwrap(),
1505 )
1506}
1507
1508fn lua_lang() -> Arc<Language> {
1509 Arc::new(
1510 Language::new(
1511 LanguageConfig {
1512 name: "Lua".into(),
1513 path_suffixes: vec!["lua".into()],
1514 collapsed_placeholder: "--[ ... ]--".to_string(),
1515 ..Default::default()
1516 },
1517 Some(tree_sitter_lua::language()),
1518 )
1519 .with_embedding_query(
1520 r#"
1521 (
1522 (comment)* @context
1523 .
1524 (function_declaration
1525 "function" @name
1526 name: (_) @name
1527 (comment)* @collapse
1528 body: (block) @collapse
1529 ) @item
1530 )
1531 "#,
1532 )
1533 .unwrap(),
1534 )
1535}
1536
1537fn php_lang() -> Arc<Language> {
1538 Arc::new(
1539 Language::new(
1540 LanguageConfig {
1541 name: "PHP".into(),
1542 path_suffixes: vec!["php".into()],
1543 collapsed_placeholder: "/* ... */".into(),
1544 ..Default::default()
1545 },
1546 Some(tree_sitter_php::language()),
1547 )
1548 .with_embedding_query(
1549 r#"
1550 (
1551 (comment)* @context
1552 .
1553 [
1554 (function_definition
1555 "function" @name
1556 name: (_) @name
1557 body: (_
1558 "{" @keep
1559 "}" @keep) @collapse
1560 )
1561
1562 (trait_declaration
1563 "trait" @name
1564 name: (_) @name)
1565
1566 (method_declaration
1567 "function" @name
1568 name: (_) @name
1569 body: (_
1570 "{" @keep
1571 "}" @keep) @collapse
1572 )
1573
1574 (interface_declaration
1575 "interface" @name
1576 name: (_) @name
1577 )
1578
1579 (enum_declaration
1580 "enum" @name
1581 name: (_) @name
1582 )
1583
1584 ] @item
1585 )
1586 "#,
1587 )
1588 .unwrap(),
1589 )
1590}
1591
1592fn ruby_lang() -> Arc<Language> {
1593 Arc::new(
1594 Language::new(
1595 LanguageConfig {
1596 name: "Ruby".into(),
1597 path_suffixes: vec!["rb".into()],
1598 collapsed_placeholder: "# ...".to_string(),
1599 ..Default::default()
1600 },
1601 Some(tree_sitter_ruby::language()),
1602 )
1603 .with_embedding_query(
1604 r#"
1605 (
1606 (comment)* @context
1607 .
1608 [
1609 (module
1610 "module" @name
1611 name: (_) @name)
1612 (method
1613 "def" @name
1614 name: (_) @name
1615 body: (body_statement) @collapse)
1616 (class
1617 "class" @name
1618 name: (_) @name)
1619 (singleton_method
1620 "def" @name
1621 object: (_) @name
1622 "." @name
1623 name: (_) @name
1624 body: (body_statement) @collapse)
1625 ] @item
1626 )
1627 "#,
1628 )
1629 .unwrap(),
1630 )
1631}
1632
1633fn elixir_lang() -> Arc<Language> {
1634 Arc::new(
1635 Language::new(
1636 LanguageConfig {
1637 name: "Elixir".into(),
1638 path_suffixes: vec!["rs".into()],
1639 ..Default::default()
1640 },
1641 Some(tree_sitter_elixir::language()),
1642 )
1643 .with_embedding_query(
1644 r#"
1645 (
1646 (unary_operator
1647 operator: "@"
1648 operand: (call
1649 target: (identifier) @unary
1650 (#match? @unary "^(doc)$"))
1651 ) @context
1652 .
1653 (call
1654 target: (identifier) @name
1655 (arguments
1656 [
1657 (identifier) @name
1658 (call
1659 target: (identifier) @name)
1660 (binary_operator
1661 left: (call
1662 target: (identifier) @name)
1663 operator: "when")
1664 ])
1665 (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1666 )
1667
1668 (call
1669 target: (identifier) @name
1670 (arguments (alias) @name)
1671 (#any-match? @name "^(defmodule|defprotocol)$")) @item
1672 "#,
1673 )
1674 .unwrap(),
1675 )
1676}
1677
1678#[gpui::test]
1679fn test_subtract_ranges() {
1680 assert_eq!(
1681 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1682 vec![1..4, 10..21]
1683 );
1684
1685 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1686}
1687
1688fn init_test(cx: &mut TestAppContext) {
1689 cx.update(|cx| {
1690 let settings_store = SettingsStore::test(cx);
1691 cx.set_global(settings_store);
1692 SemanticIndexSettings::register(cx);
1693 ProjectSettings::register(cx);
1694 });
1695}