1use crate::{
2 embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
3 embedding_queue::EmbeddingQueue,
4 parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
5 semantic_index_settings::SemanticIndexSettings,
6 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{executor::Deterministic, Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use parking_lot::Mutex;
13use pretty_assertions::assert_eq;
14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
15use rand::{rngs::StdRng, Rng};
16use serde_json::json;
17use settings::SettingsStore;
18use std::{
19 path::Path,
20 sync::{
21 atomic::{self, AtomicUsize},
22 Arc,
23 },
24 time::SystemTime,
25};
26use unindent::Unindent;
27use util::RandomCharIter;
28
29#[ctor::ctor]
30fn init_logger() {
31 if std::env::var("RUST_LOG").is_ok() {
32 env_logger::init();
33 }
34}
35
/// End-to-end check of `SemanticIndex`: indexes a small fake project, runs
/// searches with and without include/exclude path filters, then edits one file
/// and verifies that re-indexing only processes (and embeds) what changed.
#[gpui::test]
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
    init_test(cx);

    // Three files: two Rust sources with two items each, plus one TOML file.
    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaaaaaaaaaa!\");
                    }

                    fn zzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbbbbbbbbbbb!\");
                    }
                    struct pqpqpqp {}
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZZZZZZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    // `db_dir` stays bound for the whole test: dropping a `TempDir` removes the directory.
    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    // Keep a handle on the fake provider so the test can count embedded spans later.
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let semantic_index = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

    let _ = semantic_index
        .update(cx, |store, cx| {
            store.initialize_project(project.clone(), cx)
        })
        .await;

    // First indexing pass: all three files need work.
    let (file_count, outstanding_file_count) = semantic_index
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 3);
    // Advance the fake clock past the embedding queue's flush timeout (presumably so a
    // partially filled batch gets flushed); afterwards nothing should be outstanding.
    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*outstanding_file_count.borrow(), 0);

    // Unfiltered search, top 5 results as (path, start offset of the matched document).
    // Offset 45 is where the second item of each Rust file begins.
    let search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file3.toml").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    // Test Include Files Functionality: with `*.rs` included, the TOML match disappears.
    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
    let rust_only_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                include_files,
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &rust_only_search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    // Exclude-files counterpart: excluding `*.rs` leaves only the TOML document.
    let no_rust_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                exclude_files,
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &no_rust_search_results,
        &[(Path::new("src/file3.toml").into(), 0)],
        cx,
    );

    // Rewrite file2.rs: `fn bbb` becomes `fn dddd`, while `struct pqpqpqp {}` is unchanged.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();

    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

    // Second indexing pass: only the modified file should be re-indexed.
    let prev_embedding_count = embedding_provider.embedding_count();
    let (file_count, outstanding_file_count) = semantic_index
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 1);

    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*outstanding_file_count.borrow(), 0);

    // Exactly one new span is embedded; the unchanged struct's embedding is
    // presumably reused (matched by digest) rather than recomputed.
    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        1
    );
}
208
209#[gpui::test(iterations = 10)]
210async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
211 let (outstanding_job_count, _) = postage::watch::channel_with(0);
212 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
213
214 let files = (1..=3)
215 .map(|file_ix| FileToEmbed {
216 worktree_id: 5,
217 path: format!("path-{file_ix}").into(),
218 mtime: SystemTime::now(),
219 documents: (0..rng.gen_range(4..22))
220 .map(|document_ix| {
221 let content_len = rng.gen_range(10..100);
222 let content = RandomCharIter::new(&mut rng)
223 .with_simple_text()
224 .take(content_len)
225 .collect::<String>();
226 let digest = DocumentDigest::from(content.as_str());
227 Document {
228 range: 0..10,
229 embedding: None,
230 name: format!("document {document_ix}"),
231 content,
232 digest,
233 token_count: rng.gen_range(10..30),
234 }
235 })
236 .collect(),
237 job_handle: JobHandle::new(&outstanding_job_count),
238 })
239 .collect::<Vec<_>>();
240
241 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
242
243 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
244 for file in &files {
245 queue.push(file.clone());
246 }
247 queue.flush();
248
249 cx.foreground().run_until_parked();
250 let finished_files = queue.finished_files();
251 let mut embedded_files: Vec<_> = files
252 .iter()
253 .map(|_| finished_files.try_recv().expect("no finished file"))
254 .collect();
255
256 let expected_files: Vec<_> = files
257 .iter()
258 .map(|file| {
259 let mut file = file.clone();
260 for doc in &mut file.documents {
261 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
262 }
263 file
264 })
265 .collect();
266
267 embedded_files.sort_by_key(|f| f.path.clone());
268
269 assert_eq!(embedded_files, expected_files);
270}
271
272#[track_caller]
273fn assert_search_results(
274 actual: &[SearchResult],
275 expected: &[(Arc<Path>, usize)],
276 cx: &TestAppContext,
277) {
278 let actual = actual
279 .iter()
280 .map(|search_result| {
281 search_result.buffer.read_with(cx, |buffer, _cx| {
282 (
283 buffer.file().unwrap().path().clone(),
284 search_result.range.start.to_offset(buffer),
285 )
286 })
287 })
288 .collect::<Vec<_>>();
289 assert_eq!(actual, expected);
290}
291
/// Checks the Rust embedding query: one document per item, preceding doc comments,
/// line comments and attributes included as context, and function bodies nested in
/// an `impl` collapsed to the ` /* ... */ ` placeholder in the impl's own document.
/// Each document's range starts at the item itself, not at its leading context.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /// A doc comment
        /// that spans multiple lines
        #[gpui::test]
        fn a() {
            b
        }

        impl C for D {
        }

        impl E {
            // This is also a preceding comment
            pub fn function_1() -> Option<()> {
                todo!();
            }

            // This is a preceding comment
            fn function_2() -> Result<()> {
                todo!();
            }
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /// A doc comment
                /// that spans multiple lines
                #[gpui::test]
                fn a() {
                    b
                }"
                .unindent(),
                text.find("fn a").unwrap(),
            ),
            (
                "
                impl C for D {
                }"
                .unindent(),
                text.find("impl C").unwrap(),
            ),
            (
                // The impl's document shows its methods with collapsed bodies.
                "
                impl E {
                    // This is also a preceding comment
                    pub fn function_1() -> Option<()> { /* ... */ }

                    // This is a preceding comment
                    fn function_2() -> Result<()> { /* ... */ }
                }"
                .unindent(),
                text.find("impl E").unwrap(),
            ),
            (
                // Nested items also get their own full, dedented documents.
                "
                // This is also a preceding comment
                pub fn function_1() -> Option<()> {
                    todo!();
                }"
                .unindent(),
                text.find("pub fn function_1").unwrap(),
            ),
            (
                "
                // This is a preceding comment
                fn function_2() -> Result<()> {
                    todo!();
                }"
                .unindent(),
                text.find("fn function_2").unwrap(),
            ),
        ],
    );
}
379
/// Checks JSON document extraction. As the expected output shows, the object or
/// array structure is kept while string values and array contents are emptied, and
/// a top-level array of objects is reduced to its first element.
#[gpui::test]
async fn test_code_context_retrieval_json() {
    let language = json_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Case 1: a top-level object with nested values.
    let text = r#"
        {
            "array": [1, 2, 3, 4],
            "string": "abcdefg",
            "nested_object": {
                "array_2": [5, 6, 7, 8],
                "string_2": "hijklmnop",
                "boolean": true,
                "none": null
            }
        }
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            {
                "array": [],
                "string": "",
                "nested_object": {
                    "array_2": [],
                    "string_2": "",
                    "boolean": true,
                    "none": null
                }
            }"#
            .unindent(),
            text.find("{").unwrap(),
        )],
    );

    // Case 2: a top-level array of similar objects.
    let text = r#"
        [
            {
                "name": "somebody",
                "age": 42
            },
            {
                "name": "somebody else",
                "age": 43
            }
        ]
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            [{
                "name": "",
                "age": 42
            }]"#
            .unindent(),
            text.find("[").unwrap(),
        )],
    );
}
450
451fn assert_documents_eq(
452 documents: &[Document],
453 expected_contents_and_start_offsets: &[(String, usize)],
454) {
455 assert_eq!(
456 documents
457 .iter()
458 .map(|document| (document.content.clone(), document.range.start))
459 .collect::<Vec<_>>(),
460 expected_contents_and_start_offsets
461 );
462}
463
/// Checks the JavaScript/TSX embedding query (see `js_lang`) on functions, an
/// exported class and its constructor, a plain class, and an exported interface.
///
/// NOTE: the start offsets are hard-coded byte positions of each item's first
/// token (after its leading comment) in the unindented `text`. Any whitespace
/// change inside `text` breaks them.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /* globals importScripts, backend */
        function _authorize() {}

        /**
         * Sometimes the frontend build is way faster than backend.
         */
        export async function authorizeBank() {
            _authorize(pushModal, upgradingAccountId, {});
        }

        export class SettingsPage {
            /* This is a test setting */
            constructor(page) {
                this.page = page;
            }
        }

        /* This is a test comment */
        class TestClass {}

        /* Schema for editor_events in Clickhouse. */
        export interface ClickhouseEditorEvent {
            installation_id: string
            operation: string
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /* globals importScripts, backend */
                function _authorize() {}"
                    .unindent(),
                // Offset of `function _authorize`.
                37,
            ),
            (
                "
                /**
                 * Sometimes the frontend build is way faster than backend.
                 */
                export async function authorizeBank() {
                    _authorize(pushModal, upgradingAccountId, {});
                }"
                .unindent(),
                // Offset of `export async function authorizeBank`.
                131,
            ),
            (
                "
                export class SettingsPage {
                    /* This is a test setting */
                    constructor(page) {
                        this.page = page;
                    }
                }"
                .unindent(),
                // Offset of `export class SettingsPage`.
                225,
            ),
            (
                "
                /* This is a test setting */
                constructor(page) {
                    this.page = page;
                }"
                .unindent(),
                // Offset of `constructor` (indented inside the class).
                290,
            ),
            (
                "
                /* This is a test comment */
                class TestClass {}"
                    .unindent(),
                // Offset of `class TestClass`.
                374,
            ),
            (
                "
                /* Schema for editor_events in Clickhouse. */
                export interface ClickhouseEditorEvent {
                    installation_id: string
                    operation: string
                }"
                .unindent(),
                // Offset of `export interface ClickhouseEditorEvent`.
                440,
            ),
        ],
    )
}
562
/// Checks Lua document extraction for a function that contains a nested function.
/// The outer document keeps the nested function's signature and replaces its
/// statements with `--[ ... ]--` placeholders; the nested function also gets its
/// own full document.
///
/// NOTE: 114 and 809 are hard-coded byte offsets of the two `function` keywords
/// in the unindented `text`. Any whitespace change inside `text` breaks them.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
    let language = lua_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        -- Creates a new class
        -- @param baseclass The Baseclass of this class, or nil.
        -- @return A new class reference.
        function classes.class(baseclass)
            -- Create the class definition and metatable.
            local classdef = {}
            -- Find the super class, either Object or user-defined.
            baseclass = baseclass or classes.Object
            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
            setmetatable(classdef, { __index = baseclass })
            -- All class instances have a reference to the class object.
            classdef.class = classdef
            --- Recursivly allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end
        end
    "#.unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                -- Creates a new class
                -- @param baseclass The Baseclass of this class, or nil.
                -- @return A new class reference.
                function classes.class(baseclass)
                    -- Create the class definition and metatable.
                    local classdef = {}
                    -- Find the super class, either Object or user-defined.
                    baseclass = baseclass or classes.Object
                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
                    setmetatable(classdef, { __index = baseclass })
                    -- All class instances have a reference to the class object.
                    classdef.class = classdef
                    --- Recursivly allocates the inheritance tree of the instance.
                    -- @param mastertable The 'root' of the inheritance tree.
                    -- @return Returns the instance with the allocated inheritance tree.
                    function classdef.alloc(mastertable)
                        --[ ... ]--
                        --[ ... ]--
                    end
                end"#
                    .unindent(),
                // Offset of `function classes.class`.
                114,
            ),
            (
                r#"
                --- Recursivly allocates the inheritance tree of the instance.
                -- @param mastertable The 'root' of the inheritance tree.
                -- @return Returns the instance with the allocated inheritance tree.
                function classdef.alloc(mastertable)
                    -- All class instances have a reference to a superclass object.
                    local instance = { super = baseclass.alloc(mastertable) }
                    -- Any functions this instance does not know of will 'look up' to the superclass definition.
                    setmetatable(instance, { __index = classdef, __newindex = mastertable })
                    return instance
                end"#
                    .unindent(),
                // Offset of the nested `function classdef.alloc`.
                809,
            ),
        ],
    );
}
636
/// Checks Elixir document extraction. The module yields one document (the whole
/// text), and `def __build__` yields a second one that includes its preceding
/// `@doc false` attribute as context.
///
/// NOTE: 574 is the hard-coded byte offset of `def __build__` in the unindented
/// `text`; it depends on every character before it, including the
/// column-aligned bullet list in `@moduledoc`.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
    let language = elixir_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        defmodule File.Stream do
            @moduledoc """
            Defines a `File.Stream` struct returned by `File.stream!/3`.

            The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

            """

            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

            @type t :: %__MODULE__{}

            @doc false
            def __build__(path, modes, line_or_bytes) do
                raw = :lists.keyfind(:encoding, 1, modes) == false

                modes =
                    case raw do
                        true ->
                            case :lists.keyfind(:read_ahead, 1, modes) do
                                {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                                {:read_ahead, _} -> [:raw | modes]
                                false -> [:raw, :read_ahead | modes]
                            end

                        false ->
                            modes
                    end

                %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                defmodule File.Stream do
                    @moduledoc """
                    Defines a `File.Stream` struct returned by `File.stream!/3`.

                    The following fields are public:

                    * `path`          - the file path
                    * `modes`         - the file modes
                    * `raw`           - a boolean indicating if bin functions should be used
                    * `line_or_bytes` - if reading should read lines or a given number of bytes
                    * `node`          - the node the file belongs to

                    """

                    defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

                    @type t :: %__MODULE__{}

                    @doc false
                    def __build__(path, modes, line_or_bytes) do
                        raw = :lists.keyfind(:encoding, 1, modes) == false

                        modes =
                            case raw do
                                true ->
                                    case :lists.keyfind(:read_ahead, 1, modes) do
                                        {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                                        {:read_ahead, _} -> [:raw | modes]
                                        false -> [:raw, :read_ahead | modes]
                                    end

                                false ->
                                    modes
                            end

                        %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

                    end"#
                    .unindent(),
                // The module document starts at the very beginning (`defmodule`).
                0,
            ),
            (
                r#"
                @doc false
                def __build__(path, modes, line_or_bytes) do
                    raw = :lists.keyfind(:encoding, 1, modes) == false

                    modes =
                        case raw do
                            true ->
                                case :lists.keyfind(:read_ahead, 1, modes) do
                                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                                    {:read_ahead, _} -> [:raw | modes]
                                    false -> [:raw, :read_ahead | modes]
                                end

                            false ->
                                modes
                        end

                    %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

                end"#
                    .unindent(),
                // Offset of `def __build__`, after its `@doc false` context line.
                574,
            ),
        ],
    );
}
753
/// Checks C++ document extraction for a function, a class, an enum, an anonymous
/// struct declaration, and a class template with a nested constructor template.
/// Leading comments are kept as context, and each range starts at the item itself.
///
/// NOTE: all start offsets are hard-coded byte positions in the unindented `text`
/// (including the column-aligned trailing `//` comments). Any whitespace change
/// inside `text` breaks them.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /**
         * @brief Main function
         * @returns 0 on exit
         */
        int main() { return 0; }

        /**
        * This is a test comment
        */
        class MyClass {       // The class
          public:             // Access specifier
            int myNum;        // Attribute (int variable)
            string myString;  // Attribute (string variable)
        };

        // This is a test comment
        enum Color { red, green, blue };

        /** This is a preceding block comment
         * This is the second line
         */
        struct {             // Structure declaration
          int myNum;         // Member (int variable)
          string myString;   // Member (string variable)
        } myStructure;

        /**
        * @brief Matrix class.
        */
        template <typename T,
                  typename = typename std::enable_if<
                      std::is_integral<T>::value || std::is_floating_point<T>::value,
                      bool>::type>
        class Matrix2 {
            std::vector<std::vector<T>> _mat;

         public:
            /**
            * @brief Constructor
            * @tparam Integer ensuring integers are being evaluated and not other
            * data types.
            * @param size denoting the size of Matrix as size x size
            */
            template <typename Integer,
                      typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                         Integer>::type>
            explicit Matrix(const Integer size) {
                for (size_t i = 0; i < size; ++i) {
                    _mat.emplace_back(std::vector<T>(size, 0));
                }
            }
        }"
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                "
                /**
                 * @brief Main function
                 * @returns 0 on exit
                 */
                int main() { return 0; }"
                    .unindent(),
                // Offset of `int main`.
                54,
            ),
            (
                // The class document ends at `}`; the trailing `;` is not part of it.
                "
                /**
                * This is a test comment
                */
                class MyClass {       // The class
                  public:             // Access specifier
                    int myNum;        // Attribute (int variable)
                    string myString;  // Attribute (string variable)
                }"
                .unindent(),
                // Offset of `class MyClass`.
                112,
            ),
            (
                "
                // This is a test comment
                enum Color { red, green, blue }"
                    .unindent(),
                // Offset of `enum Color`.
                322,
            ),
            (
                // The struct document is the whole declaration, including `myStructure;`.
                "
                /** This is a preceding block comment
                 * This is the second line
                 */
                struct {             // Structure declaration
                  int myNum;         // Member (int variable)
                  string myString;   // Member (string variable)
                } myStructure;"
                    .unindent(),
                // Offset of `struct`.
                425,
            ),
            (
                "
                /**
                * @brief Matrix class.
                */
                template <typename T,
                          typename = typename std::enable_if<
                              std::is_integral<T>::value || std::is_floating_point<T>::value,
                              bool>::type>
                class Matrix2 {
                    std::vector<std::vector<T>> _mat;

                 public:
                    /**
                    * @brief Constructor
                    * @tparam Integer ensuring integers are being evaluated and not other
                    * data types.
                    * @param size denoting the size of Matrix as size x size
                    */
                    template <typename Integer,
                              typename = typename std::enable_if<std::is_integral<Integer>::value,
                                                                 Integer>::type>
                    explicit Matrix(const Integer size) {
                        for (size_t i = 0; i < size; ++i) {
                            _mat.emplace_back(std::vector<T>(size, 0));
                        }
                    }
                }"
                .unindent(),
                // Offset of the outer `template` keyword.
                612,
            ),
            (
                // The constructor's document has no comment/template context.
                "
                explicit Matrix(const Integer size) {
                    for (size_t i = 0; i < size; ++i) {
                        _mat.emplace_back(std::vector<T>(size, 0));
                    }
                }"
                .unindent(),
                // Offset of `explicit Matrix`.
                1226,
            ),
        ],
    );
}
906
/// Checks Ruby document extraction. Modules and classes produce a summary document
/// whose method bodies are collapsed to `# ...`, each method produces its own full
/// document, and a leading `#` comment block is kept as context.
///
/// NOTE: start offsets are hard-coded byte positions of each `module`, `class` or
/// `def` keyword in the unindented `text`. Any whitespace change inside `text`
/// breaks them.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
    let language = ruby_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end

            def challenge_passed?
                current_user.valid_password?(challenge_params[:current_password])
            end
        end

        class Animal
            include Comparable

            attr_reader :legs

            def initialize(name, legs)
                @name, @legs = name, legs
            end

            def <=>(other)
                legs <=> other.legs
            end
        end

        # Singleton method for car object
        def car.wheels
            puts "There are four wheels"
        end"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                # This concern is inspired by "sudo mode" on GitHub. It
                # is a way to re-authenticate a user before allowing them
                # to see or perform an action.
                #
                # Add `before_action :require_challenge!` to actions you
                # want to protect.
                #
                # The user will be shown a page to enter the challenge (which
                # is either the password, or just the username when no
                # password exists). Upon passing, there is a grace period
                # during which no challenge will be asked from the user.
                #
                # Accessing challenge-protected resources during the grace
                # period will refresh the grace period.
                module ChallengableConcern
                    extend ActiveSupport::Concern

                    CHALLENGE_TIMEOUT = 1.hour.freeze

                    def require_challenge!
                        # ...
                    end

                    def challenge_passed?
                        # ...
                    end
                end"#
                    .unindent(),
                // Offset of `module ChallengableConcern`.
                558,
            ),
            (
                r#"
                def require_challenge!
                    return if skip_challenge?

                    if challenge_passed_recently?
                        session[:challenge_passed_at] = Time.now.utc
                        return
                    end

                    @challenge = Form::Challenge.new(return_to: request.url)

                    if params.key?(:form_challenge)
                        if challenge_passed?
                            session[:challenge_passed_at] = Time.now.utc
                        else
                            flash.now[:alert] = I18n.t('challenge.invalid_password')
                            render_challenge
                        end
                    else
                        render_challenge
                    end
                end"#
                    .unindent(),
                663,
            ),
            (
                r#"
                def challenge_passed?
                    current_user.valid_password?(challenge_params[:current_password])
                end"#
                    .unindent(),
                1254,
            ),
            (
                r#"
                class Animal
                    include Comparable

                    attr_reader :legs

                    def initialize(name, legs)
                        # ...
                    end

                    def <=>(other)
                        # ...
                    end
                end"#
                    .unindent(),
                1363,
            ),
            (
                r#"
                def initialize(name, legs)
                    @name, @legs = name, legs
                end"#
                    .unindent(),
                1427,
            ),
            (
                r#"
                def <=>(other)
                    legs <=> other.legs
                end"#
                    .unindent(),
                1501,
            ),
            (
                r#"
                # Singleton method for car object
                def car.wheels
                    puts "There are four wheels"
                end"#
                    .unindent(),
                1591,
            ),
        ],
    );
}
1097
/// Checks PHP document extraction for a function, a trait, an interface and an
/// enum. The trait's summary document collapses method bodies to `{/* ... */}`,
/// each trait method gets its own full document, and leading comments (block or
/// doc) are kept as context.
///
/// NOTE: start offsets are hard-coded byte positions of each item's first token in
/// the unindented `text`. Any whitespace change inside `text` breaks them.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            (
                r#"
                /*
                This is a multiple-lines comment block
                that spans over multiple
                lines
                */
                function functionName() {
                    echo "Hello world!";
                }"#
                .unindent(),
                // Offset of `function functionName`.
                123,
            ),
            (
                r#"
                trait HasAchievements
                {
                    /**
                    * @throws \Exception
                    */
                    public function grantAchievement(Achievement $achievement, $progress = null): void
                    {/* ... */}

                    public function achievements(): BelongsToMany
                    {/* ... */}
                }"#
                .unindent(),
                // Offset of `trait HasAchievements`.
                177,
            ),
            (
                r#"
                /**
                * @throws \Exception
                */
                public function grantAchievement(Achievement $achievement, $progress = null): void
                {
                    if ($progress > 100) {
                        throw new Exception(message: 'Progress cannot be greater than 100');
                    }

                    if ($this->achievements()->find($achievement->id)) {
                        throw new Exception(message: 'User already has this Achievement');
                    }

                    $this->achievements()->attach($achievement, [
                        'progress' => $progress ?? null,
                    ]);

                    $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
                }"#
                .unindent(),
                // Offset of `public function grantAchievement`.
                245,
            ),
            (
                r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#
                .unindent(),
                // Offset of `public function achievements`.
                902,
            ),
            (
                r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#
                .unindent(),
                // Offset of `interface Multiplier`.
                1146,
            ),
            (
                r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#
                .unindent(),
                // Offset of `enum AuditType`.
                1265,
            ),
        ],
    );
}
1248
/// Deterministic, in-process stand-in for a real embedding provider, used by these tests.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch`, so tests can assert how
    /// much embedding work an operation triggered.
    embedding_count: AtomicUsize,
}
1253
1254impl FakeEmbeddingProvider {
1255 fn embedding_count(&self) -> usize {
1256 self.embedding_count.load(atomic::Ordering::SeqCst)
1257 }
1258
1259 fn embed_sync(&self, span: &str) -> Embedding {
1260 let mut result = vec![1.0; 26];
1261 for letter in span.chars() {
1262 let letter = letter.to_ascii_lowercase();
1263 if letter as u32 >= 'a' as u32 {
1264 let ix = (letter as u32) - ('a' as u32);
1265 if ix < 26 {
1266 result[ix as usize] += 1.0;
1267 }
1268 }
1269 }
1270
1271 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1272 for x in &mut result {
1273 *x /= norm;
1274 }
1275
1276 result.into()
1277 }
1278}
1279
1280#[async_trait]
1281impl EmbeddingProvider for FakeEmbeddingProvider {
1282 fn truncate(&self, span: &str) -> (String, usize) {
1283 (span.to_string(), 1)
1284 }
1285
1286 fn max_tokens_per_batch(&self) -> usize {
1287 200
1288 }
1289
1290 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1291 self.embedding_count
1292 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1293 Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1294 }
1295}
1296
/// Builds a test JavaScript `Language` (file suffix `js`) backed by the TSX grammar.
///
/// Its embedding query captures function, class, interface and enum declarations
/// (each either bare or wrapped in an `export_statement`), plus class methods. Runs
/// of comments immediately preceding an item (the `.` anchor) are captured as
/// `@context`. Keywords such as `async`/`function`/`class` and the declared name are
/// captured as `@name`, and the whole declaration as `@item`.
fn js_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Javascript".into(),
                path_suffixes: vec!["js".into()],
                ..Default::default()
            },
            Some(tree_sitter_typescript::language_tsx()),
        )
        .with_embedding_query(
            &r#"

            (
                (comment)* @context
                .
                [
                (export_statement
                    (function_declaration
                        "async"? @name
                        "function" @name
                        name: (_) @name))
                (function_declaration
                    "async"? @name
                    "function" @name
                    name: (_) @name)
                ] @item
            )

            (
                (comment)* @context
                .
                [
                (export_statement
                    (class_declaration
                        "class" @name
                        name: (_) @name))
                (class_declaration
                    "class" @name
                    name: (_) @name)
                ] @item
            )

            (
                (comment)* @context
                .
                [
                (export_statement
                    (interface_declaration
                        "interface" @name
                        name: (_) @name))
                (interface_declaration
                    "interface" @name
                    name: (_) @name)
                ] @item
            )

            (
                (comment)* @context
                .
                [
                (export_statement
                    (enum_declaration
                        "enum" @name
                        name: (_) @name))
                (enum_declaration
                    "enum" @name
                    name: (_) @name)
                ] @item
            )

            (
                (comment)* @context
                .
                (method_definition
                    [
                        "get"
                        "set"
                        "async"
                        "*"
                        "static"
                    ]* @name
                    name: (_) @name) @item
            )

            "#
            .unindent(),
        )
        .unwrap(),
    )
}
1388
/// Builds a Rust `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-rust grammar, the `rs` suffix, and
/// `" /* ... */ "` as the collapsed placeholder.
///
/// The embedding query captures structs, enums, impls, traits, functions and
/// macro definitions as `@item`, with their names as `@name`. For impls, the
/// trait, `for` and the self type are all captured as `@name`. Immediately
/// preceding line comments and attributes are captured as `@context`.
/// Function bodies are marked `@collapse`, with their opening and closing
/// braces marked `@keep`.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn rust_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".into()],
                collapsed_placeholder: " /* ... */ ".to_string(),
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_embedding_query(
            r#"
            (
                [(line_comment) (attribute_item)]* @context
                .
                [
                    (struct_item
                        name: (_) @name)

                    (enum_item
                        name: (_) @name)

                    (impl_item
                        trait: (_)? @name
                        "for"? @name
                        type: (_) @name)

                    (trait_item
                        name: (_) @name)

                    (function_item
                        name: (_) @name
                        body: (block
                            "{" @keep
                            "}" @keep) @collapse)

                    (macro_definition
                        name: (_) @name)
                ] @item
            )
            "#,
        )
        .unwrap(),
    )
}
1435
/// Builds a JSON `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-json grammar and the `json` suffix.
///
/// The embedding query does three things:
/// - It treats the whole `document` as a single `@item`.
/// - It marks arrays `@collapse`, keeping `[`, `]` and an object that
///   immediately follows `[` (if any).
/// - It marks string values of key/value pairs `@collapse`, keeping their
///   quotes.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn json_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "JSON".into(),
                path_suffixes: vec!["json".into()],
                ..Default::default()
            },
            Some(tree_sitter_json::language()),
        )
        .with_embedding_query(
            r#"
            (document) @item

            (array
                "[" @keep
                .
                (object)? @keep
                "]" @keep) @collapse

            (pair value: (string
                "\"" @keep
                "\"" @keep) @collapse)
            "#,
        )
        .unwrap(),
    )
}
1464
1465fn toml_lang() -> Arc<Language> {
1466 Arc::new(Language::new(
1467 LanguageConfig {
1468 name: "TOML".into(),
1469 path_suffixes: vec!["toml".into()],
1470 ..Default::default()
1471 },
1472 Some(tree_sitter_toml::language()),
1473 ))
1474}
1475
/// Builds a C++ `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-cpp grammar and the `cpp` suffix.
///
/// The embedding query captures the following as `@item`:
/// - function definitions whose declarator is a plain function declarator,
///   one or two levels of pointer declarator, or a `&`/`&&` reference
///   declarator;
/// - templated class declarations;
/// - class specifiers;
/// - enum specifiers;
/// - declarations whose type is a `struct` specifier.
///
/// Type qualifiers, return type, `*`/`&`/`&&` tokens, the `class`/`enum`/
/// `struct` keywords and the declared names are captured as `@name`. Comments
/// immediately preceding an item are captured as `@context`.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn cpp_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "CPP".into(),
                path_suffixes: vec!["cpp".into()],
                ..Default::default()
            },
            Some(tree_sitter_cpp::language()),
        )
        .with_embedding_query(
            r#"
            (
                (comment)* @context
                .
                (function_definition
                    (type_qualifier)? @name
                    type: (_)? @name
                    declarator: [
                        (function_declarator
                            declarator: (_) @name)
                        (pointer_declarator
                            "*" @name
                            declarator: (function_declarator
                            declarator: (_) @name))
                        (pointer_declarator
                            "*" @name
                            declarator: (pointer_declarator
                                "*" @name
                            declarator: (function_declarator
                                declarator: (_) @name)))
                        (reference_declarator
                            ["&" "&&"] @name
                            (function_declarator
                            declarator: (_) @name))
                    ]
                    (type_qualifier)? @name) @item
                )

            (
                (comment)* @context
                .
                (template_declaration
                    (class_specifier
                        "class" @name
                        name: (_) @name)
                        ) @item
            )

            (
                (comment)* @context
                .
                (class_specifier
                    "class" @name
                    name: (_) @name) @item
                )

            (
                (comment)* @context
                .
                (enum_specifier
                    "enum" @name
                    name: (_) @name) @item
                )

            (
                (comment)* @context
                .
                (declaration
                    type: (struct_specifier
                    "struct" @name)
                    declarator: (_) @name) @item
            )

            "#,
        )
        .unwrap(),
    )
}
1555
/// Builds a Lua `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-lua grammar, the `lua` suffix, and
/// `"--[ ... ]--"` as the collapsed placeholder.
///
/// The embedding query captures function declarations as `@item`, with the
/// `function` keyword and the function name as `@name`. Comments immediately
/// preceding the function are captured as `@context`. Comments directly
/// inside the declaration node and the body block are both marked
/// `@collapse`.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn lua_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Lua".into(),
                path_suffixes: vec!["lua".into()],
                collapsed_placeholder: "--[ ... ]--".to_string(),
                ..Default::default()
            },
            Some(tree_sitter_lua::language()),
        )
        .with_embedding_query(
            r#"
            (
                (comment)* @context
                .
                (function_declaration
                    "function" @name
                    name: (_) @name
                    (comment)* @collapse
                    body: (block) @collapse
                ) @item
            )
            "#,
        )
        .unwrap(),
    )
}
1584
/// Builds a PHP `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-php grammar, the `php` suffix, and
/// `"/* ... */"` as the collapsed placeholder.
///
/// The embedding query captures function definitions, trait declarations,
/// method declarations, interface declarations and enum declarations as
/// `@item`. The declaring keyword and the declared name are captured as
/// `@name`, and comments immediately preceding an item as `@context`.
/// Function and method bodies are marked `@collapse`, keeping their `{`/`}`.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn php_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "PHP".into(),
                path_suffixes: vec!["php".into()],
                collapsed_placeholder: "/* ... */".into(),
                ..Default::default()
            },
            Some(tree_sitter_php::language()),
        )
        .with_embedding_query(
            r#"
            (
                (comment)* @context
                .
                [
                    (function_definition
                        "function" @name
                        name: (_) @name
                        body: (_
                            "{" @keep
                            "}" @keep) @collapse
                    )

                    (trait_declaration
                        "trait" @name
                        name: (_) @name)

                    (method_declaration
                        "function" @name
                        name: (_) @name
                        body: (_
                            "{" @keep
                            "}" @keep) @collapse
                        )

                    (interface_declaration
                        "interface" @name
                        name: (_) @name
                        )

                    (enum_declaration
                        "enum" @name
                        name: (_) @name
                        )

                ] @item
            )
            "#,
        )
        .unwrap(),
    )
}
1639
/// Builds a Ruby `Language` for the retrieval tests.
///
/// The language uses the tree-sitter-ruby grammar, the `rb` suffix, and
/// `"# ..."` as the collapsed placeholder.
///
/// The embedding query captures modules, methods, classes and singleton
/// methods as `@item`. Comments immediately preceding an item are captured as
/// `@context`. The `module`/`def`/`class` keyword and the name are captured
/// as `@name`; for singleton methods the receiver object and the `.` are
/// captured too. Method and singleton-method bodies are marked `@collapse`.
///
/// # Panics
///
/// Panics if `with_embedding_query` fails for the query below.
fn ruby_lang() -> Arc<Language> {
    Arc::new(
        Language::new(
            LanguageConfig {
                name: "Ruby".into(),
                path_suffixes: vec!["rb".into()],
                collapsed_placeholder: "# ...".to_string(),
                ..Default::default()
            },
            Some(tree_sitter_ruby::language()),
        )
        .with_embedding_query(
            r#"
            (
                (comment)* @context
                .
                [
                (module
                    "module" @name
                    name: (_) @name)
                (method
                    "def" @name
                    name: (_) @name
                    body: (body_statement) @collapse)
                (class
                    "class" @name
                    name: (_) @name)
                (singleton_method
                    "def" @name
                    object: (_) @name
                    "." @name
                    name: (_) @name
                    body: (body_statement) @collapse)
                ] @item
            )
            "#,
        )
        .unwrap(),
    )
}
1680
1681fn elixir_lang() -> Arc<Language> {
1682 Arc::new(
1683 Language::new(
1684 LanguageConfig {
1685 name: "Elixir".into(),
1686 path_suffixes: vec!["rs".into()],
1687 ..Default::default()
1688 },
1689 Some(tree_sitter_elixir::language()),
1690 )
1691 .with_embedding_query(
1692 r#"
1693 (
1694 (unary_operator
1695 operator: "@"
1696 operand: (call
1697 target: (identifier) @unary
1698 (#match? @unary "^(doc)$"))
1699 ) @context
1700 .
1701 (call
1702 target: (identifier) @name
1703 (arguments
1704 [
1705 (identifier) @name
1706 (call
1707 target: (identifier) @name)
1708 (binary_operator
1709 left: (call
1710 target: (identifier) @name)
1711 operator: "when")
1712 ])
1713 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1714 )
1715
1716 (call
1717 target: (identifier) @name
1718 (arguments (alias) @name)
1719 (#match? @name "^(defmodule|defprotocol)$")) @item
1720 "#,
1721 )
1722 .unwrap(),
1723 )
1724}
1725
1726#[gpui::test]
1727fn test_subtract_ranges() {
1728 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1729
1730 assert_eq!(
1731 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1732 vec![1..4, 10..21]
1733 );
1734
1735 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1736}
1737
1738fn init_test(cx: &mut TestAppContext) {
1739 cx.update(|cx| {
1740 cx.set_global(SettingsStore::test(cx));
1741 settings::register::<SemanticIndexSettings>(cx);
1742 settings::register::<ProjectSettings>(cx);
1743 });
1744}