1use crate::{
2 db::dot,
3 embedding::{DummyEmbeddings, EmbeddingProvider},
4 embedding_queue::EmbeddingQueue,
5 parsing::{subtract_ranges, CodeContextRetriever, Document},
6 semantic_index_settings::SemanticIndexSettings,
7 FileToEmbed, JobHandle, SearchResult, SemanticIndex,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use gpui::{Task, TestAppContext};
12use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
13use parking_lot::Mutex;
14use pretty_assertions::assert_eq;
15use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
16use rand::{rngs::StdRng, Rng};
17use serde_json::json;
18use settings::SettingsStore;
19use std::{
20 path::Path,
21 sync::{
22 atomic::{self, AtomicUsize},
23 Arc,
24 },
25 time::SystemTime,
26};
27use unindent::Unindent;
28use util::RandomCharIter;
29
30#[ctor::ctor]
31fn init_logger() {
32 if std::env::var("RUST_LOG").is_ok() {
33 env_logger::init();
34 }
35}
36
37#[gpui::test]
38async fn test_semantic_index(cx: &mut TestAppContext) {
39 init_test(cx);
40
41 let fs = FakeFs::new(cx.background());
42 fs.insert_tree(
43 "/the-root",
44 json!({
45 "src": {
46 "file1.rs": "
47 fn aaa() {
48 println!(\"aaaaaaaaaaaa!\");
49 }
50
51 fn zzzzz() {
52 println!(\"SLEEPING\");
53 }
54 ".unindent(),
55 "file2.rs": "
56 fn bbb() {
57 println!(\"bbbbbbbbbbbbb!\");
58 }
59 ".unindent(),
60 "file3.toml": "
61 ZZZZZZZZZZZZZZZZZZ = 5
62 ".unindent(),
63 }
64 }),
65 )
66 .await;
67
68 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
69 let rust_language = rust_lang();
70 let toml_language = toml_lang();
71 languages.add(rust_language);
72 languages.add(toml_language);
73
74 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
75 let db_path = db_dir.path().join("db.sqlite");
76
77 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
78 let semantic_index = SemanticIndex::new(
79 fs.clone(),
80 db_path,
81 embedding_provider.clone(),
82 languages,
83 cx.to_async(),
84 )
85 .await
86 .unwrap();
87
88 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
89
90 let _ = semantic_index
91 .update(cx, |store, cx| {
92 store.initialize_project(project.clone(), cx)
93 })
94 .await;
95
96 let (file_count, outstanding_file_count) = semantic_index
97 .update(cx, |store, cx| store.index_project(project.clone(), cx))
98 .await
99 .unwrap();
100 assert_eq!(file_count, 3);
101 cx.foreground().run_until_parked();
102 assert_eq!(*outstanding_file_count.borrow(), 0);
103
104 let search_results = semantic_index
105 .update(cx, |store, cx| {
106 store.search_project(
107 project.clone(),
108 "aaaaaabbbbzz".to_string(),
109 5,
110 vec![],
111 vec![],
112 cx,
113 )
114 })
115 .await
116 .unwrap();
117
118 assert_search_results(
119 &search_results,
120 &[
121 (Path::new("src/file1.rs").into(), 0),
122 (Path::new("src/file2.rs").into(), 0),
123 (Path::new("src/file3.toml").into(), 0),
124 (Path::new("src/file1.rs").into(), 45),
125 ],
126 cx,
127 );
128
129 // Test Include Files Functonality
130 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
131 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
132 let rust_only_search_results = semantic_index
133 .update(cx, |store, cx| {
134 store.search_project(
135 project.clone(),
136 "aaaaaabbbbzz".to_string(),
137 5,
138 include_files,
139 vec![],
140 cx,
141 )
142 })
143 .await
144 .unwrap();
145
146 assert_search_results(
147 &rust_only_search_results,
148 &[
149 (Path::new("src/file1.rs").into(), 0),
150 (Path::new("src/file2.rs").into(), 0),
151 (Path::new("src/file1.rs").into(), 45),
152 ],
153 cx,
154 );
155
156 let no_rust_search_results = semantic_index
157 .update(cx, |store, cx| {
158 store.search_project(
159 project.clone(),
160 "aaaaaabbbbzz".to_string(),
161 5,
162 vec![],
163 exclude_files,
164 cx,
165 )
166 })
167 .await
168 .unwrap();
169
170 assert_search_results(
171 &no_rust_search_results,
172 &[(Path::new("src/file3.toml").into(), 0)],
173 cx,
174 );
175
176 fs.save(
177 "/the-root/src/file2.rs".as_ref(),
178 &"
179 fn dddd() { println!(\"ddddd!\"); }
180 struct pqpqpqp {}
181 "
182 .unindent()
183 .into(),
184 Default::default(),
185 )
186 .await
187 .unwrap();
188
189 cx.foreground().run_until_parked();
190
191 let prev_embedding_count = embedding_provider.embedding_count();
192 let (file_count, outstanding_file_count) = semantic_index
193 .update(cx, |store, cx| store.index_project(project.clone(), cx))
194 .await
195 .unwrap();
196 assert_eq!(file_count, 1);
197
198 cx.foreground().run_until_parked();
199 assert_eq!(*outstanding_file_count.borrow(), 0);
200
201 assert_eq!(
202 embedding_provider.embedding_count() - prev_embedding_count,
203 2
204 );
205}
206
207#[gpui::test(iterations = 10)]
208async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
209 let (outstanding_job_count, _) = postage::watch::channel_with(0);
210 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
211
212 let files = (1..=3)
213 .map(|file_ix| FileToEmbed {
214 worktree_id: 5,
215 path: format!("path-{file_ix}").into(),
216 mtime: SystemTime::now(),
217 documents: (0..rng.gen_range(4..22))
218 .map(|document_ix| {
219 let content_len = rng.gen_range(10..100);
220 Document {
221 range: 0..10,
222 embedding: Vec::new(),
223 name: format!("document {document_ix}"),
224 content: RandomCharIter::new(&mut rng)
225 .with_simple_text()
226 .take(content_len)
227 .collect(),
228 sha1: rng.gen(),
229 token_count: rng.gen_range(10..30),
230 }
231 })
232 .collect(),
233 job_handle: JobHandle::new(&outstanding_job_count),
234 })
235 .collect::<Vec<_>>();
236
237 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
238
239 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
240 for file in &files {
241 queue.push(file.clone());
242 }
243 queue.flush();
244
245 cx.foreground().run_until_parked();
246 let finished_files = queue.finished_files();
247 let mut embedded_files: Vec<_> = files
248 .iter()
249 .map(|_| finished_files.try_recv().expect("no finished file"))
250 .collect();
251
252 let expected_files: Vec<_> = files
253 .iter()
254 .map(|file| {
255 let mut file = file.clone();
256 for doc in &mut file.documents {
257 doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
258 }
259 file
260 })
261 .collect();
262
263 embedded_files.sort_by_key(|f| f.path.clone());
264
265 assert_eq!(embedded_files, expected_files);
266}
267
268#[track_caller]
269fn assert_search_results(
270 actual: &[SearchResult],
271 expected: &[(Arc<Path>, usize)],
272 cx: &TestAppContext,
273) {
274 let actual = actual
275 .iter()
276 .map(|search_result| {
277 search_result.buffer.read_with(cx, |buffer, _cx| {
278 (
279 buffer.file().unwrap().path().clone(),
280 search_result.range.start.to_offset(buffer),
281 )
282 })
283 })
284 .collect::<Vec<_>>();
285 assert_eq!(actual, expected);
286}
287
288#[gpui::test]
289async fn test_code_context_retrieval_rust() {
290 let language = rust_lang();
291 let embedding_provider = Arc::new(DummyEmbeddings {});
292 let mut retriever = CodeContextRetriever::new(embedding_provider);
293
294 let text = "
295 /// A doc comment
296 /// that spans multiple lines
297 #[gpui::test]
298 fn a() {
299 b
300 }
301
302 impl C for D {
303 }
304
305 impl E {
306 // This is also a preceding comment
307 pub fn function_1() -> Option<()> {
308 todo!();
309 }
310
311 // This is a preceding comment
312 fn function_2() -> Result<()> {
313 todo!();
314 }
315 }
316 "
317 .unindent();
318
319 let documents = retriever.parse_file(&text, language).unwrap();
320
321 assert_documents_eq(
322 &documents,
323 &[
324 (
325 "
326 /// A doc comment
327 /// that spans multiple lines
328 #[gpui::test]
329 fn a() {
330 b
331 }"
332 .unindent(),
333 text.find("fn a").unwrap(),
334 ),
335 (
336 "
337 impl C for D {
338 }"
339 .unindent(),
340 text.find("impl C").unwrap(),
341 ),
342 (
343 "
344 impl E {
345 // This is also a preceding comment
346 pub fn function_1() -> Option<()> { /* ... */ }
347
348 // This is a preceding comment
349 fn function_2() -> Result<()> { /* ... */ }
350 }"
351 .unindent(),
352 text.find("impl E").unwrap(),
353 ),
354 (
355 "
356 // This is also a preceding comment
357 pub fn function_1() -> Option<()> {
358 todo!();
359 }"
360 .unindent(),
361 text.find("pub fn function_1").unwrap(),
362 ),
363 (
364 "
365 // This is a preceding comment
366 fn function_2() -> Result<()> {
367 todo!();
368 }"
369 .unindent(),
370 text.find("fn function_2").unwrap(),
371 ),
372 ],
373 );
374}
375
376#[gpui::test]
377async fn test_code_context_retrieval_json() {
378 let language = json_lang();
379 let embedding_provider = Arc::new(DummyEmbeddings {});
380 let mut retriever = CodeContextRetriever::new(embedding_provider);
381
382 let text = r#"
383 {
384 "array": [1, 2, 3, 4],
385 "string": "abcdefg",
386 "nested_object": {
387 "array_2": [5, 6, 7, 8],
388 "string_2": "hijklmnop",
389 "boolean": true,
390 "none": null
391 }
392 }
393 "#
394 .unindent();
395
396 let documents = retriever.parse_file(&text, language.clone()).unwrap();
397
398 assert_documents_eq(
399 &documents,
400 &[(
401 r#"
402 {
403 "array": [],
404 "string": "",
405 "nested_object": {
406 "array_2": [],
407 "string_2": "",
408 "boolean": true,
409 "none": null
410 }
411 }"#
412 .unindent(),
413 text.find("{").unwrap(),
414 )],
415 );
416
417 let text = r#"
418 [
419 {
420 "name": "somebody",
421 "age": 42
422 },
423 {
424 "name": "somebody else",
425 "age": 43
426 }
427 ]
428 "#
429 .unindent();
430
431 let documents = retriever.parse_file(&text, language.clone()).unwrap();
432
433 assert_documents_eq(
434 &documents,
435 &[(
436 r#"
437 [{
438 "name": "",
439 "age": 42
440 }]"#
441 .unindent(),
442 text.find("[").unwrap(),
443 )],
444 );
445}
446
447fn assert_documents_eq(
448 documents: &[Document],
449 expected_contents_and_start_offsets: &[(String, usize)],
450) {
451 assert_eq!(
452 documents
453 .iter()
454 .map(|document| (document.content.clone(), document.range.start))
455 .collect::<Vec<_>>(),
456 expected_contents_and_start_offsets
457 );
458}
459
460#[gpui::test]
461async fn test_code_context_retrieval_javascript() {
462 let language = js_lang();
463 let embedding_provider = Arc::new(DummyEmbeddings {});
464 let mut retriever = CodeContextRetriever::new(embedding_provider);
465
466 let text = "
467 /* globals importScripts, backend */
468 function _authorize() {}
469
470 /**
471 * Sometimes the frontend build is way faster than backend.
472 */
473 export async function authorizeBank() {
474 _authorize(pushModal, upgradingAccountId, {});
475 }
476
477 export class SettingsPage {
478 /* This is a test setting */
479 constructor(page) {
480 this.page = page;
481 }
482 }
483
484 /* This is a test comment */
485 class TestClass {}
486
487 /* Schema for editor_events in Clickhouse. */
488 export interface ClickhouseEditorEvent {
489 installation_id: string
490 operation: string
491 }
492 "
493 .unindent();
494
495 let documents = retriever.parse_file(&text, language.clone()).unwrap();
496
497 assert_documents_eq(
498 &documents,
499 &[
500 (
501 "
502 /* globals importScripts, backend */
503 function _authorize() {}"
504 .unindent(),
505 37,
506 ),
507 (
508 "
509 /**
510 * Sometimes the frontend build is way faster than backend.
511 */
512 export async function authorizeBank() {
513 _authorize(pushModal, upgradingAccountId, {});
514 }"
515 .unindent(),
516 131,
517 ),
518 (
519 "
520 export class SettingsPage {
521 /* This is a test setting */
522 constructor(page) {
523 this.page = page;
524 }
525 }"
526 .unindent(),
527 225,
528 ),
529 (
530 "
531 /* This is a test setting */
532 constructor(page) {
533 this.page = page;
534 }"
535 .unindent(),
536 290,
537 ),
538 (
539 "
540 /* This is a test comment */
541 class TestClass {}"
542 .unindent(),
543 374,
544 ),
545 (
546 "
547 /* Schema for editor_events in Clickhouse. */
548 export interface ClickhouseEditorEvent {
549 installation_id: string
550 operation: string
551 }"
552 .unindent(),
553 440,
554 ),
555 ],
556 )
557}
558
559#[gpui::test]
560async fn test_code_context_retrieval_lua() {
561 let language = lua_lang();
562 let embedding_provider = Arc::new(DummyEmbeddings {});
563 let mut retriever = CodeContextRetriever::new(embedding_provider);
564
565 let text = r#"
566 -- Creates a new class
567 -- @param baseclass The Baseclass of this class, or nil.
568 -- @return A new class reference.
569 function classes.class(baseclass)
570 -- Create the class definition and metatable.
571 local classdef = {}
572 -- Find the super class, either Object or user-defined.
573 baseclass = baseclass or classes.Object
574 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
575 setmetatable(classdef, { __index = baseclass })
576 -- All class instances have a reference to the class object.
577 classdef.class = classdef
578 --- Recursivly allocates the inheritance tree of the instance.
579 -- @param mastertable The 'root' of the inheritance tree.
580 -- @return Returns the instance with the allocated inheritance tree.
581 function classdef.alloc(mastertable)
582 -- All class instances have a reference to a superclass object.
583 local instance = { super = baseclass.alloc(mastertable) }
584 -- Any functions this instance does not know of will 'look up' to the superclass definition.
585 setmetatable(instance, { __index = classdef, __newindex = mastertable })
586 return instance
587 end
588 end
589 "#.unindent();
590
591 let documents = retriever.parse_file(&text, language.clone()).unwrap();
592
593 assert_documents_eq(
594 &documents,
595 &[
596 (r#"
597 -- Creates a new class
598 -- @param baseclass The Baseclass of this class, or nil.
599 -- @return A new class reference.
600 function classes.class(baseclass)
601 -- Create the class definition and metatable.
602 local classdef = {}
603 -- Find the super class, either Object or user-defined.
604 baseclass = baseclass or classes.Object
605 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
606 setmetatable(classdef, { __index = baseclass })
607 -- All class instances have a reference to the class object.
608 classdef.class = classdef
609 --- Recursivly allocates the inheritance tree of the instance.
610 -- @param mastertable The 'root' of the inheritance tree.
611 -- @return Returns the instance with the allocated inheritance tree.
612 function classdef.alloc(mastertable)
613 --[ ... ]--
614 --[ ... ]--
615 end
616 end"#.unindent(),
617 114),
618 (r#"
619 --- Recursivly allocates the inheritance tree of the instance.
620 -- @param mastertable The 'root' of the inheritance tree.
621 -- @return Returns the instance with the allocated inheritance tree.
622 function classdef.alloc(mastertable)
623 -- All class instances have a reference to a superclass object.
624 local instance = { super = baseclass.alloc(mastertable) }
625 -- Any functions this instance does not know of will 'look up' to the superclass definition.
626 setmetatable(instance, { __index = classdef, __newindex = mastertable })
627 return instance
628 end"#.unindent(), 809),
629 ]
630 );
631}
632
633#[gpui::test]
634async fn test_code_context_retrieval_elixir() {
635 let language = elixir_lang();
636 let embedding_provider = Arc::new(DummyEmbeddings {});
637 let mut retriever = CodeContextRetriever::new(embedding_provider);
638
639 let text = r#"
640 defmodule File.Stream do
641 @moduledoc """
642 Defines a `File.Stream` struct returned by `File.stream!/3`.
643
644 The following fields are public:
645
646 * `path` - the file path
647 * `modes` - the file modes
648 * `raw` - a boolean indicating if bin functions should be used
649 * `line_or_bytes` - if reading should read lines or a given number of bytes
650 * `node` - the node the file belongs to
651
652 """
653
654 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
655
656 @type t :: %__MODULE__{}
657
658 @doc false
659 def __build__(path, modes, line_or_bytes) do
660 raw = :lists.keyfind(:encoding, 1, modes) == false
661
662 modes =
663 case raw do
664 true ->
665 case :lists.keyfind(:read_ahead, 1, modes) do
666 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
667 {:read_ahead, _} -> [:raw | modes]
668 false -> [:raw, :read_ahead | modes]
669 end
670
671 false ->
672 modes
673 end
674
675 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
676
677 end"#
678 .unindent();
679
680 let documents = retriever.parse_file(&text, language.clone()).unwrap();
681
682 assert_documents_eq(
683 &documents,
684 &[(
685 r#"
686 defmodule File.Stream do
687 @moduledoc """
688 Defines a `File.Stream` struct returned by `File.stream!/3`.
689
690 The following fields are public:
691
692 * `path` - the file path
693 * `modes` - the file modes
694 * `raw` - a boolean indicating if bin functions should be used
695 * `line_or_bytes` - if reading should read lines or a given number of bytes
696 * `node` - the node the file belongs to
697
698 """
699
700 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
701
702 @type t :: %__MODULE__{}
703
704 @doc false
705 def __build__(path, modes, line_or_bytes) do
706 raw = :lists.keyfind(:encoding, 1, modes) == false
707
708 modes =
709 case raw do
710 true ->
711 case :lists.keyfind(:read_ahead, 1, modes) do
712 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
713 {:read_ahead, _} -> [:raw | modes]
714 false -> [:raw, :read_ahead | modes]
715 end
716
717 false ->
718 modes
719 end
720
721 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
722
723 end"#
724 .unindent(),
725 0,
726 ),(r#"
727 @doc false
728 def __build__(path, modes, line_or_bytes) do
729 raw = :lists.keyfind(:encoding, 1, modes) == false
730
731 modes =
732 case raw do
733 true ->
734 case :lists.keyfind(:read_ahead, 1, modes) do
735 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
736 {:read_ahead, _} -> [:raw | modes]
737 false -> [:raw, :read_ahead | modes]
738 end
739
740 false ->
741 modes
742 end
743
744 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
745
746 end"#.unindent(), 574)],
747 );
748}
749
750#[gpui::test]
751async fn test_code_context_retrieval_cpp() {
752 let language = cpp_lang();
753 let embedding_provider = Arc::new(DummyEmbeddings {});
754 let mut retriever = CodeContextRetriever::new(embedding_provider);
755
756 let text = "
757 /**
758 * @brief Main function
759 * @returns 0 on exit
760 */
761 int main() { return 0; }
762
763 /**
764 * This is a test comment
765 */
766 class MyClass { // The class
767 public: // Access specifier
768 int myNum; // Attribute (int variable)
769 string myString; // Attribute (string variable)
770 };
771
772 // This is a test comment
773 enum Color { red, green, blue };
774
775 /** This is a preceding block comment
776 * This is the second line
777 */
778 struct { // Structure declaration
779 int myNum; // Member (int variable)
780 string myString; // Member (string variable)
781 } myStructure;
782
783 /**
784 * @brief Matrix class.
785 */
786 template <typename T,
787 typename = typename std::enable_if<
788 std::is_integral<T>::value || std::is_floating_point<T>::value,
789 bool>::type>
790 class Matrix2 {
791 std::vector<std::vector<T>> _mat;
792
793 public:
794 /**
795 * @brief Constructor
796 * @tparam Integer ensuring integers are being evaluated and not other
797 * data types.
798 * @param size denoting the size of Matrix as size x size
799 */
800 template <typename Integer,
801 typename = typename std::enable_if<std::is_integral<Integer>::value,
802 Integer>::type>
803 explicit Matrix(const Integer size) {
804 for (size_t i = 0; i < size; ++i) {
805 _mat.emplace_back(std::vector<T>(size, 0));
806 }
807 }
808 }"
809 .unindent();
810
811 let documents = retriever.parse_file(&text, language.clone()).unwrap();
812
813 assert_documents_eq(
814 &documents,
815 &[
816 (
817 "
818 /**
819 * @brief Main function
820 * @returns 0 on exit
821 */
822 int main() { return 0; }"
823 .unindent(),
824 54,
825 ),
826 (
827 "
828 /**
829 * This is a test comment
830 */
831 class MyClass { // The class
832 public: // Access specifier
833 int myNum; // Attribute (int variable)
834 string myString; // Attribute (string variable)
835 }"
836 .unindent(),
837 112,
838 ),
839 (
840 "
841 // This is a test comment
842 enum Color { red, green, blue }"
843 .unindent(),
844 322,
845 ),
846 (
847 "
848 /** This is a preceding block comment
849 * This is the second line
850 */
851 struct { // Structure declaration
852 int myNum; // Member (int variable)
853 string myString; // Member (string variable)
854 } myStructure;"
855 .unindent(),
856 425,
857 ),
858 (
859 "
860 /**
861 * @brief Matrix class.
862 */
863 template <typename T,
864 typename = typename std::enable_if<
865 std::is_integral<T>::value || std::is_floating_point<T>::value,
866 bool>::type>
867 class Matrix2 {
868 std::vector<std::vector<T>> _mat;
869
870 public:
871 /**
872 * @brief Constructor
873 * @tparam Integer ensuring integers are being evaluated and not other
874 * data types.
875 * @param size denoting the size of Matrix as size x size
876 */
877 template <typename Integer,
878 typename = typename std::enable_if<std::is_integral<Integer>::value,
879 Integer>::type>
880 explicit Matrix(const Integer size) {
881 for (size_t i = 0; i < size; ++i) {
882 _mat.emplace_back(std::vector<T>(size, 0));
883 }
884 }
885 }"
886 .unindent(),
887 612,
888 ),
889 (
890 "
891 explicit Matrix(const Integer size) {
892 for (size_t i = 0; i < size; ++i) {
893 _mat.emplace_back(std::vector<T>(size, 0));
894 }
895 }"
896 .unindent(),
897 1226,
898 ),
899 ],
900 );
901}
902
903#[gpui::test]
904async fn test_code_context_retrieval_ruby() {
905 let language = ruby_lang();
906 let embedding_provider = Arc::new(DummyEmbeddings {});
907 let mut retriever = CodeContextRetriever::new(embedding_provider);
908
909 let text = r#"
910 # This concern is inspired by "sudo mode" on GitHub. It
911 # is a way to re-authenticate a user before allowing them
912 # to see or perform an action.
913 #
914 # Add `before_action :require_challenge!` to actions you
915 # want to protect.
916 #
917 # The user will be shown a page to enter the challenge (which
918 # is either the password, or just the username when no
919 # password exists). Upon passing, there is a grace period
920 # during which no challenge will be asked from the user.
921 #
922 # Accessing challenge-protected resources during the grace
923 # period will refresh the grace period.
924 module ChallengableConcern
925 extend ActiveSupport::Concern
926
927 CHALLENGE_TIMEOUT = 1.hour.freeze
928
929 def require_challenge!
930 return if skip_challenge?
931
932 if challenge_passed_recently?
933 session[:challenge_passed_at] = Time.now.utc
934 return
935 end
936
937 @challenge = Form::Challenge.new(return_to: request.url)
938
939 if params.key?(:form_challenge)
940 if challenge_passed?
941 session[:challenge_passed_at] = Time.now.utc
942 else
943 flash.now[:alert] = I18n.t('challenge.invalid_password')
944 render_challenge
945 end
946 else
947 render_challenge
948 end
949 end
950
951 def challenge_passed?
952 current_user.valid_password?(challenge_params[:current_password])
953 end
954 end
955
956 class Animal
957 include Comparable
958
959 attr_reader :legs
960
961 def initialize(name, legs)
962 @name, @legs = name, legs
963 end
964
965 def <=>(other)
966 legs <=> other.legs
967 end
968 end
969
970 # Singleton method for car object
971 def car.wheels
972 puts "There are four wheels"
973 end"#
974 .unindent();
975
976 let documents = retriever.parse_file(&text, language.clone()).unwrap();
977
978 assert_documents_eq(
979 &documents,
980 &[
981 (
982 r#"
983 # This concern is inspired by "sudo mode" on GitHub. It
984 # is a way to re-authenticate a user before allowing them
985 # to see or perform an action.
986 #
987 # Add `before_action :require_challenge!` to actions you
988 # want to protect.
989 #
990 # The user will be shown a page to enter the challenge (which
991 # is either the password, or just the username when no
992 # password exists). Upon passing, there is a grace period
993 # during which no challenge will be asked from the user.
994 #
995 # Accessing challenge-protected resources during the grace
996 # period will refresh the grace period.
997 module ChallengableConcern
998 extend ActiveSupport::Concern
999
1000 CHALLENGE_TIMEOUT = 1.hour.freeze
1001
1002 def require_challenge!
1003 # ...
1004 end
1005
1006 def challenge_passed?
1007 # ...
1008 end
1009 end"#
1010 .unindent(),
1011 558,
1012 ),
1013 (
1014 r#"
1015 def require_challenge!
1016 return if skip_challenge?
1017
1018 if challenge_passed_recently?
1019 session[:challenge_passed_at] = Time.now.utc
1020 return
1021 end
1022
1023 @challenge = Form::Challenge.new(return_to: request.url)
1024
1025 if params.key?(:form_challenge)
1026 if challenge_passed?
1027 session[:challenge_passed_at] = Time.now.utc
1028 else
1029 flash.now[:alert] = I18n.t('challenge.invalid_password')
1030 render_challenge
1031 end
1032 else
1033 render_challenge
1034 end
1035 end"#
1036 .unindent(),
1037 663,
1038 ),
1039 (
1040 r#"
1041 def challenge_passed?
1042 current_user.valid_password?(challenge_params[:current_password])
1043 end"#
1044 .unindent(),
1045 1254,
1046 ),
1047 (
1048 r#"
1049 class Animal
1050 include Comparable
1051
1052 attr_reader :legs
1053
1054 def initialize(name, legs)
1055 # ...
1056 end
1057
1058 def <=>(other)
1059 # ...
1060 end
1061 end"#
1062 .unindent(),
1063 1363,
1064 ),
1065 (
1066 r#"
1067 def initialize(name, legs)
1068 @name, @legs = name, legs
1069 end"#
1070 .unindent(),
1071 1427,
1072 ),
1073 (
1074 r#"
1075 def <=>(other)
1076 legs <=> other.legs
1077 end"#
1078 .unindent(),
1079 1501,
1080 ),
1081 (
1082 r#"
1083 # Singleton method for car object
1084 def car.wheels
1085 puts "There are four wheels"
1086 end"#
1087 .unindent(),
1088 1591,
1089 ),
1090 ],
1091 );
1092}
1093
1094#[gpui::test]
1095async fn test_code_context_retrieval_php() {
1096 let language = php_lang();
1097 let embedding_provider = Arc::new(DummyEmbeddings {});
1098 let mut retriever = CodeContextRetriever::new(embedding_provider);
1099
1100 let text = r#"
1101 <?php
1102
1103 namespace LevelUp\Experience\Concerns;
1104
1105 /*
1106 This is a multiple-lines comment block
1107 that spans over multiple
1108 lines
1109 */
1110 function functionName() {
1111 echo "Hello world!";
1112 }
1113
1114 trait HasAchievements
1115 {
1116 /**
1117 * @throws \Exception
1118 */
1119 public function grantAchievement(Achievement $achievement, $progress = null): void
1120 {
1121 if ($progress > 100) {
1122 throw new Exception(message: 'Progress cannot be greater than 100');
1123 }
1124
1125 if ($this->achievements()->find($achievement->id)) {
1126 throw new Exception(message: 'User already has this Achievement');
1127 }
1128
1129 $this->achievements()->attach($achievement, [
1130 'progress' => $progress ?? null,
1131 ]);
1132
1133 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1134 }
1135
1136 public function achievements(): BelongsToMany
1137 {
1138 return $this->belongsToMany(related: Achievement::class)
1139 ->withPivot(columns: 'progress')
1140 ->where('is_secret', false)
1141 ->using(AchievementUser::class);
1142 }
1143 }
1144
1145 interface Multiplier
1146 {
1147 public function qualifies(array $data): bool;
1148
1149 public function setMultiplier(): int;
1150 }
1151
1152 enum AuditType: string
1153 {
1154 case Add = 'add';
1155 case Remove = 'remove';
1156 case Reset = 'reset';
1157 case LevelUp = 'level_up';
1158 }
1159
1160 ?>"#
1161 .unindent();
1162
1163 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1164
1165 assert_documents_eq(
1166 &documents,
1167 &[
1168 (
1169 r#"
1170 /*
1171 This is a multiple-lines comment block
1172 that spans over multiple
1173 lines
1174 */
1175 function functionName() {
1176 echo "Hello world!";
1177 }"#
1178 .unindent(),
1179 123,
1180 ),
1181 (
1182 r#"
1183 trait HasAchievements
1184 {
1185 /**
1186 * @throws \Exception
1187 */
1188 public function grantAchievement(Achievement $achievement, $progress = null): void
1189 {/* ... */}
1190
1191 public function achievements(): BelongsToMany
1192 {/* ... */}
1193 }"#
1194 .unindent(),
1195 177,
1196 ),
1197 (r#"
1198 /**
1199 * @throws \Exception
1200 */
1201 public function grantAchievement(Achievement $achievement, $progress = null): void
1202 {
1203 if ($progress > 100) {
1204 throw new Exception(message: 'Progress cannot be greater than 100');
1205 }
1206
1207 if ($this->achievements()->find($achievement->id)) {
1208 throw new Exception(message: 'User already has this Achievement');
1209 }
1210
1211 $this->achievements()->attach($achievement, [
1212 'progress' => $progress ?? null,
1213 ]);
1214
1215 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1216 }"#.unindent(), 245),
1217 (r#"
1218 public function achievements(): BelongsToMany
1219 {
1220 return $this->belongsToMany(related: Achievement::class)
1221 ->withPivot(columns: 'progress')
1222 ->where('is_secret', false)
1223 ->using(AchievementUser::class);
1224 }"#.unindent(), 902),
1225 (r#"
1226 interface Multiplier
1227 {
1228 public function qualifies(array $data): bool;
1229
1230 public function setMultiplier(): int;
1231 }"#.unindent(),
1232 1146),
1233 (r#"
1234 enum AuditType: string
1235 {
1236 case Add = 'add';
1237 case Remove = 'remove';
1238 case Reset = 'reset';
1239 case LevelUp = 'level_up';
1240 }"#.unindent(), 1265)
1241 ],
1242 );
1243}
1244
1245#[gpui::test]
1246fn test_dot_product(mut rng: StdRng) {
1247 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1248 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1249
1250 for _ in 0..100 {
1251 let size = 1536;
1252 let mut a = vec![0.; size];
1253 let mut b = vec![0.; size];
1254 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1255 *a = rng.gen();
1256 *b = rng.gen();
1257 }
1258
1259 assert_eq!(
1260 round_to_decimals(dot(&a, &b), 1),
1261 round_to_decimals(reference_dot(&a, &b), 1)
1262 );
1263 }
1264
1265 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1266 let factor = (10.0 as f32).powi(decimal_places);
1267 (n * factor).round() / factor
1268 }
1269
1270 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1271 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1272 }
1273}
1274
/// Deterministic embedding provider for tests.
///
/// It embeds text as a normalized bag-of-letters vector and keeps a running
/// count of embedded spans, so tests can assert how much work was re-done.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch` so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Number of spans embedded so far.
    fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }

    /// Embeds `span` as a 26-dimensional unit vector of letter frequencies.
    ///
    /// Every bucket starts at 1.0, so the vector is never zero and can always
    /// be normalized. ASCII letters are counted case-insensitively; every
    /// other character is ignored.
    fn embed_sync(&self, span: &str) -> Vec<f32> {
        let mut buckets = vec![1.0_f32; 26];
        let letters = span
            .chars()
            .map(|ch| ch.to_ascii_lowercase())
            .filter(char::is_ascii_lowercase);
        for letter in letters {
            buckets[usize::from(letter as u8 - b'a')] += 1.0;
        }

        let magnitude = buckets.iter().map(|x| x * x).sum::<f32>().sqrt();
        buckets.iter_mut().for_each(|x| *x /= magnitude);
        buckets
    }
}
1305
1306#[async_trait]
1307impl EmbeddingProvider for FakeEmbeddingProvider {
1308 fn truncate(&self, span: &str) -> (String, usize) {
1309 (span.to_string(), 1)
1310 }
1311
1312 fn max_tokens_per_batch(&self) -> usize {
1313 200
1314 }
1315
1316 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
1317 self.embedding_count
1318 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1319 Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1320 }
1321}
1322
1323fn js_lang() -> Arc<Language> {
1324 Arc::new(
1325 Language::new(
1326 LanguageConfig {
1327 name: "Javascript".into(),
1328 path_suffixes: vec!["js".into()],
1329 ..Default::default()
1330 },
1331 Some(tree_sitter_typescript::language_tsx()),
1332 )
1333 .with_embedding_query(
1334 &r#"
1335
1336 (
1337 (comment)* @context
1338 .
1339 [
1340 (export_statement
1341 (function_declaration
1342 "async"? @name
1343 "function" @name
1344 name: (_) @name))
1345 (function_declaration
1346 "async"? @name
1347 "function" @name
1348 name: (_) @name)
1349 ] @item
1350 )
1351
1352 (
1353 (comment)* @context
1354 .
1355 [
1356 (export_statement
1357 (class_declaration
1358 "class" @name
1359 name: (_) @name))
1360 (class_declaration
1361 "class" @name
1362 name: (_) @name)
1363 ] @item
1364 )
1365
1366 (
1367 (comment)* @context
1368 .
1369 [
1370 (export_statement
1371 (interface_declaration
1372 "interface" @name
1373 name: (_) @name))
1374 (interface_declaration
1375 "interface" @name
1376 name: (_) @name)
1377 ] @item
1378 )
1379
1380 (
1381 (comment)* @context
1382 .
1383 [
1384 (export_statement
1385 (enum_declaration
1386 "enum" @name
1387 name: (_) @name))
1388 (enum_declaration
1389 "enum" @name
1390 name: (_) @name)
1391 ] @item
1392 )
1393
1394 (
1395 (comment)* @context
1396 .
1397 (method_definition
1398 [
1399 "get"
1400 "set"
1401 "async"
1402 "*"
1403 "static"
1404 ]* @name
1405 name: (_) @name) @item
1406 )
1407
1408 "#
1409 .unindent(),
1410 )
1411 .unwrap(),
1412 )
1413}
1414
1415fn rust_lang() -> Arc<Language> {
1416 Arc::new(
1417 Language::new(
1418 LanguageConfig {
1419 name: "Rust".into(),
1420 path_suffixes: vec!["rs".into()],
1421 collapsed_placeholder: " /* ... */ ".to_string(),
1422 ..Default::default()
1423 },
1424 Some(tree_sitter_rust::language()),
1425 )
1426 .with_embedding_query(
1427 r#"
1428 (
1429 [(line_comment) (attribute_item)]* @context
1430 .
1431 [
1432 (struct_item
1433 name: (_) @name)
1434
1435 (enum_item
1436 name: (_) @name)
1437
1438 (impl_item
1439 trait: (_)? @name
1440 "for"? @name
1441 type: (_) @name)
1442
1443 (trait_item
1444 name: (_) @name)
1445
1446 (function_item
1447 name: (_) @name
1448 body: (block
1449 "{" @keep
1450 "}" @keep) @collapse)
1451
1452 (macro_definition
1453 name: (_) @name)
1454 ] @item
1455 )
1456 "#,
1457 )
1458 .unwrap(),
1459 )
1460}
1461
1462fn json_lang() -> Arc<Language> {
1463 Arc::new(
1464 Language::new(
1465 LanguageConfig {
1466 name: "JSON".into(),
1467 path_suffixes: vec!["json".into()],
1468 ..Default::default()
1469 },
1470 Some(tree_sitter_json::language()),
1471 )
1472 .with_embedding_query(
1473 r#"
1474 (document) @item
1475
1476 (array
1477 "[" @keep
1478 .
1479 (object)? @keep
1480 "]" @keep) @collapse
1481
1482 (pair value: (string
1483 "\"" @keep
1484 "\"" @keep) @collapse)
1485 "#,
1486 )
1487 .unwrap(),
1488 )
1489}
1490
1491fn toml_lang() -> Arc<Language> {
1492 Arc::new(Language::new(
1493 LanguageConfig {
1494 name: "TOML".into(),
1495 path_suffixes: vec!["toml".into()],
1496 ..Default::default()
1497 },
1498 Some(tree_sitter_toml::language()),
1499 ))
1500}
1501
1502fn cpp_lang() -> Arc<Language> {
1503 Arc::new(
1504 Language::new(
1505 LanguageConfig {
1506 name: "CPP".into(),
1507 path_suffixes: vec!["cpp".into()],
1508 ..Default::default()
1509 },
1510 Some(tree_sitter_cpp::language()),
1511 )
1512 .with_embedding_query(
1513 r#"
1514 (
1515 (comment)* @context
1516 .
1517 (function_definition
1518 (type_qualifier)? @name
1519 type: (_)? @name
1520 declarator: [
1521 (function_declarator
1522 declarator: (_) @name)
1523 (pointer_declarator
1524 "*" @name
1525 declarator: (function_declarator
1526 declarator: (_) @name))
1527 (pointer_declarator
1528 "*" @name
1529 declarator: (pointer_declarator
1530 "*" @name
1531 declarator: (function_declarator
1532 declarator: (_) @name)))
1533 (reference_declarator
1534 ["&" "&&"] @name
1535 (function_declarator
1536 declarator: (_) @name))
1537 ]
1538 (type_qualifier)? @name) @item
1539 )
1540
1541 (
1542 (comment)* @context
1543 .
1544 (template_declaration
1545 (class_specifier
1546 "class" @name
1547 name: (_) @name)
1548 ) @item
1549 )
1550
1551 (
1552 (comment)* @context
1553 .
1554 (class_specifier
1555 "class" @name
1556 name: (_) @name) @item
1557 )
1558
1559 (
1560 (comment)* @context
1561 .
1562 (enum_specifier
1563 "enum" @name
1564 name: (_) @name) @item
1565 )
1566
1567 (
1568 (comment)* @context
1569 .
1570 (declaration
1571 type: (struct_specifier
1572 "struct" @name)
1573 declarator: (_) @name) @item
1574 )
1575
1576 "#,
1577 )
1578 .unwrap(),
1579 )
1580}
1581
1582fn lua_lang() -> Arc<Language> {
1583 Arc::new(
1584 Language::new(
1585 LanguageConfig {
1586 name: "Lua".into(),
1587 path_suffixes: vec!["lua".into()],
1588 collapsed_placeholder: "--[ ... ]--".to_string(),
1589 ..Default::default()
1590 },
1591 Some(tree_sitter_lua::language()),
1592 )
1593 .with_embedding_query(
1594 r#"
1595 (
1596 (comment)* @context
1597 .
1598 (function_declaration
1599 "function" @name
1600 name: (_) @name
1601 (comment)* @collapse
1602 body: (block) @collapse
1603 ) @item
1604 )
1605 "#,
1606 )
1607 .unwrap(),
1608 )
1609}
1610
1611fn php_lang() -> Arc<Language> {
1612 Arc::new(
1613 Language::new(
1614 LanguageConfig {
1615 name: "PHP".into(),
1616 path_suffixes: vec!["php".into()],
1617 collapsed_placeholder: "/* ... */".into(),
1618 ..Default::default()
1619 },
1620 Some(tree_sitter_php::language()),
1621 )
1622 .with_embedding_query(
1623 r#"
1624 (
1625 (comment)* @context
1626 .
1627 [
1628 (function_definition
1629 "function" @name
1630 name: (_) @name
1631 body: (_
1632 "{" @keep
1633 "}" @keep) @collapse
1634 )
1635
1636 (trait_declaration
1637 "trait" @name
1638 name: (_) @name)
1639
1640 (method_declaration
1641 "function" @name
1642 name: (_) @name
1643 body: (_
1644 "{" @keep
1645 "}" @keep) @collapse
1646 )
1647
1648 (interface_declaration
1649 "interface" @name
1650 name: (_) @name
1651 )
1652
1653 (enum_declaration
1654 "enum" @name
1655 name: (_) @name
1656 )
1657
1658 ] @item
1659 )
1660 "#,
1661 )
1662 .unwrap(),
1663 )
1664}
1665
1666fn ruby_lang() -> Arc<Language> {
1667 Arc::new(
1668 Language::new(
1669 LanguageConfig {
1670 name: "Ruby".into(),
1671 path_suffixes: vec!["rb".into()],
1672 collapsed_placeholder: "# ...".to_string(),
1673 ..Default::default()
1674 },
1675 Some(tree_sitter_ruby::language()),
1676 )
1677 .with_embedding_query(
1678 r#"
1679 (
1680 (comment)* @context
1681 .
1682 [
1683 (module
1684 "module" @name
1685 name: (_) @name)
1686 (method
1687 "def" @name
1688 name: (_) @name
1689 body: (body_statement) @collapse)
1690 (class
1691 "class" @name
1692 name: (_) @name)
1693 (singleton_method
1694 "def" @name
1695 object: (_) @name
1696 "." @name
1697 name: (_) @name
1698 body: (body_statement) @collapse)
1699 ] @item
1700 )
1701 "#,
1702 )
1703 .unwrap(),
1704 )
1705}
1706
1707fn elixir_lang() -> Arc<Language> {
1708 Arc::new(
1709 Language::new(
1710 LanguageConfig {
1711 name: "Elixir".into(),
1712 path_suffixes: vec!["rs".into()],
1713 ..Default::default()
1714 },
1715 Some(tree_sitter_elixir::language()),
1716 )
1717 .with_embedding_query(
1718 r#"
1719 (
1720 (unary_operator
1721 operator: "@"
1722 operand: (call
1723 target: (identifier) @unary
1724 (#match? @unary "^(doc)$"))
1725 ) @context
1726 .
1727 (call
1728 target: (identifier) @name
1729 (arguments
1730 [
1731 (identifier) @name
1732 (call
1733 target: (identifier) @name)
1734 (binary_operator
1735 left: (call
1736 target: (identifier) @name)
1737 operator: "when")
1738 ])
1739 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1740 )
1741
1742 (call
1743 target: (identifier) @name
1744 (arguments (alias) @name)
1745 (#match? @name "^(defmodule|defprotocol)$")) @item
1746 "#,
1747 )
1748 .unwrap(),
1749 )
1750}
1751
1752#[gpui::test]
1753fn test_subtract_ranges() {
1754 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1755
1756 assert_eq!(
1757 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1758 vec![1..4, 10..21]
1759 );
1760
1761 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1762}
1763
1764fn init_test(cx: &mut TestAppContext) {
1765 cx.update(|cx| {
1766 cx.set_global(SettingsStore::test(cx));
1767 settings::register::<SemanticIndexSettings>(cx);
1768 settings::register::<ProjectSettings>(cx);
1769 });
1770}