1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
8use ai::{
9 embedding::{Embedding, EmbeddingProvider},
10 models::LanguageModel,
11};
12use anyhow::Result;
13use async_trait::async_trait;
14use gpui::{executor::Deterministic, Task, TestAppContext};
15use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
16use parking_lot::Mutex;
17use pretty_assertions::assert_eq;
18use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
19use rand::{rngs::StdRng, Rng};
20use serde_json::json;
21use settings::SettingsStore;
22use std::{
23 path::Path,
24 sync::{
25 atomic::{self, AtomicUsize},
26 Arc,
27 },
28 time::{Instant, SystemTime},
29};
30use unindent::Unindent;
31use util::RandomCharIter;
32
33#[ctor::ctor]
34fn init_logger() {
35 if std::env::var("RUST_LOG").is_ok() {
36 env_logger::init();
37 }
38}
39
// End-to-end test of `SemanticIndex` on a fake project: runs a search, checks
// the include/exclude path matchers, then rewrites one file and asserts that
// re-indexing computes exactly one additional embedding.
#[gpui::test]
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
 init_test(cx);

 // Fake file system with two Rust files and one TOML file under `/the-root/src`.
 let fs = FakeFs::new(cx.background());
 fs.insert_tree(
 "/the-root",
 json!({
 "src": {
 "file1.rs": "
 fn aaa() {
 println!(\"aaaaaaaaaaaa!\");
 }

 fn zzzzz() {
 println!(\"SLEEPING\");
 }
 ".unindent(),
 "file2.rs": "
 fn bbb() {
 println!(\"bbbbbbbbbbbbb!\");
 }
 struct pqpqpqp {}
 ".unindent(),
 "file3.toml": "
 ZZZZZZZZZZZZZZZZZZ = 5
 ".unindent(),
 }
 }),
 )
 .await;

 // Register the Rust and TOML test languages.
 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 let rust_language = rust_lang();
 let toml_language = toml_lang();
 languages.add(rust_language);
 languages.add(toml_language);

 // The index database is created inside a temporary directory.
 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 let db_path = db_dir.path().join("db.sqlite");

 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 let semantic_index = SemanticIndex::new(
 fs.clone(),
 db_path,
 embedding_provider.clone(),
 languages,
 cx.to_async(),
 )
 .await
 .unwrap();

 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

 // Start an unfiltered search; its result is only awaited after the pending
 // file count has been observed below.
 let search_results = semantic_index.update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 vec![],
 vec![],
 cx,
 )
 });
 let pending_file_count =
 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
 // Once the executor parks, all three files are pending; advancing the clock
 // by the queue flush timeout is expected to bring the count down to zero.
 deterministic.run_until_parked();
 assert_eq!(*pending_file_count.borrow(), 3);
 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 assert_eq!(*pending_file_count.borrow(), 0);

 let search_results = search_results.await.unwrap();
 assert_search_results(
 &search_results,
 &[
 (Path::new("src/file1.rs").into(), 0),
 (Path::new("src/file2.rs").into(), 0),
 (Path::new("src/file3.toml").into(), 0),
 (Path::new("src/file1.rs").into(), 45),
 (Path::new("src/file2.rs").into(), 45),
 ],
 cx,
 );

 // Test the include/exclude file filters. Both matchers use the same `*.rs`
 // glob: one is passed as the include list, the other as the exclude list.
 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 let rust_only_search_results = semantic_index
 .update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 include_files,
 vec![],
 cx,
 )
 })
 .await
 .unwrap();

 // With `*.rs` included, the TOML result is no longer expected.
 assert_search_results(
 &rust_only_search_results,
 &[
 (Path::new("src/file1.rs").into(), 0),
 (Path::new("src/file2.rs").into(), 0),
 (Path::new("src/file1.rs").into(), 45),
 (Path::new("src/file2.rs").into(), 45),
 ],
 cx,
 );

 let no_rust_search_results = semantic_index
 .update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 vec![],
 exclude_files,
 cx,
 )
 })
 .await
 .unwrap();

 // With `*.rs` excluded, only the TOML result is expected.
 assert_search_results(
 &no_rust_search_results,
 &[(Path::new("src/file3.toml").into(), 0)],
 cx,
 );

 // Overwrite file2.rs: `fn bbb` is replaced by `fn dddd`, while
 // `struct pqpqpqp {}` is written again.
 fs.save(
 "/the-root/src/file2.rs".as_ref(),
 &"
 fn dddd() { println!(\"ddddd!\"); }
 struct pqpqpqp {}
 "
 .unindent()
 .into(),
 Default::default(),
 )
 .await
 .unwrap();

 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

 // Re-index the project: one file is expected to be pending, and after the
 // flush the provider must have computed exactly one more embedding.
 let prev_embedding_count = embedding_provider.embedding_count();
 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 deterministic.run_until_parked();
 assert_eq!(*pending_file_count.borrow(), 1);
 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 assert_eq!(*pending_file_count.borrow(), 0);
 index.await.unwrap();

 assert_eq!(
 embedding_provider.embedding_count() - prev_embedding_count,
 1
 );
}
200
201#[gpui::test(iterations = 10)]
202async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
203 let (outstanding_job_count, _) = postage::watch::channel_with(0);
204 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
205
206 let files = (1..=3)
207 .map(|file_ix| FileToEmbed {
208 worktree_id: 5,
209 path: Path::new(&format!("path-{file_ix}")).into(),
210 mtime: SystemTime::now(),
211 spans: (0..rng.gen_range(4..22))
212 .map(|document_ix| {
213 let content_len = rng.gen_range(10..100);
214 let content = RandomCharIter::new(&mut rng)
215 .with_simple_text()
216 .take(content_len)
217 .collect::<String>();
218 let digest = SpanDigest::from(content.as_str());
219 Span {
220 range: 0..10,
221 embedding: None,
222 name: format!("document {document_ix}"),
223 content,
224 digest,
225 token_count: rng.gen_range(10..30),
226 }
227 })
228 .collect(),
229 job_handle: JobHandle::new(&outstanding_job_count),
230 })
231 .collect::<Vec<_>>();
232
233 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
234
235 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
236 for file in &files {
237 queue.push(file.clone());
238 }
239 queue.flush();
240
241 cx.foreground().run_until_parked();
242 let finished_files = queue.finished_files();
243 let mut embedded_files: Vec<_> = files
244 .iter()
245 .map(|_| finished_files.try_recv().expect("no finished file"))
246 .collect();
247
248 let expected_files: Vec<_> = files
249 .iter()
250 .map(|file| {
251 let mut file = file.clone();
252 for doc in &mut file.spans {
253 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
254 }
255 file
256 })
257 .collect();
258
259 embedded_files.sort_by_key(|f| f.path.clone());
260
261 assert_eq!(embedded_files, expected_files);
262}
263
264#[track_caller]
265fn assert_search_results(
266 actual: &[SearchResult],
267 expected: &[(Arc<Path>, usize)],
268 cx: &TestAppContext,
269) {
270 let actual = actual
271 .iter()
272 .map(|search_result| {
273 search_result.buffer.read_with(cx, |buffer, _cx| {
274 (
275 buffer.file().unwrap().path().clone(),
276 search_result.range.start.to_offset(buffer),
277 )
278 })
279 })
280 .collect::<Vec<_>>();
281 assert_eq!(actual, expected);
282}
283
// Parses a Rust snippet with `CodeContextRetriever::parse_file` and checks the
// content and start offset of every span that comes back.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
 let language = rust_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /// A doc comment
 /// that spans multiple lines
 #[gpui::test]
 fn a() {
 b
 }

 impl C for D {
 }

 impl E {
 // This is also a preceding comment
 pub fn function_1() -> Option<()> {
 todo!();
 }

 // This is a preceding comment
 fn function_2() -> Result<()> {
 todo!();
 }
 }

 #[derive(Clone)]
 struct D {
 name: String
 }
 "
 .unindent();

 let documents = retriever.parse_file(&text, language).unwrap();

 // Each expected entry is (span content, start offset). The offset is located
 // with `text.find` on the item keyword, so it points at the item itself even
 // when the expected content begins with preceding comments or attributes.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /// A doc comment
 /// that spans multiple lines
 #[gpui::test]
 fn a() {
 b
 }"
 .unindent(),
 text.find("fn a").unwrap(),
 ),
 (
 "
 impl C for D {
 }"
 .unindent(),
 text.find("impl C").unwrap(),
 ),
 // In the `impl E` span the method bodies are expected as ` /* ... */ `.
 (
 "
 impl E {
 // This is also a preceding comment
 pub fn function_1() -> Option<()> { /* ... */ }

 // This is a preceding comment
 fn function_2() -> Result<()> { /* ... */ }
 }"
 .unindent(),
 text.find("impl E").unwrap(),
 ),
 // The two methods are additionally expected as spans of their own, with
 // their full bodies.
 (
 "
 // This is also a preceding comment
 pub fn function_1() -> Option<()> {
 todo!();
 }"
 .unindent(),
 text.find("pub fn function_1").unwrap(),
 ),
 (
 "
 // This is a preceding comment
 fn function_2() -> Result<()> {
 todo!();
 }"
 .unindent(),
 text.find("fn function_2").unwrap(),
 ),
 (
 "
 #[derive(Clone)]
 struct D {
 name: String
 }"
 .unindent(),
 text.find("struct D").unwrap(),
 ),
 ],
 );
}
385
// Parses two JSON documents (an object, then an array of objects) and checks
// the single span expected for each.
#[gpui::test]
async fn test_code_context_retrieval_json() {
 let language = json_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 {
 "array": [1, 2, 3, 4],
 "string": "abcdefg",
 "nested_object": {
 "array_2": [5, 6, 7, 8],
 "string_2": "hijklmnop",
 "boolean": true,
 "none": null
 }
 }
 "#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // In the expected span the array and string values are emptied (`[]`, `""`),
 // while `true` and `null` are kept.
 assert_documents_eq(
 &documents,
 &[(
 r#"
 {
 "array": [],
 "string": "",
 "nested_object": {
 "array_2": [],
 "string_2": "",
 "boolean": true,
 "none": null
 }
 }"#
 .unindent(),
 text.find("{").unwrap(),
 )],
 );

 let text = r#"
 [
 {
 "name": "somebody",
 "age": 42
 },
 {
 "name": "somebody else",
 "age": 43
 }
 ]
 "#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // For the top-level array, only the first object appears in the expected
 // span, with its string value emptied.
 assert_documents_eq(
 &documents,
 &[(
 r#"
 [{
 "name": "",
 "age": 42
 }]"#
 .unindent(),
 text.find("[").unwrap(),
 )],
 );
}
456
457fn assert_documents_eq(
458 documents: &[Span],
459 expected_contents_and_start_offsets: &[(String, usize)],
460) {
461 assert_eq!(
462 documents
463 .iter()
464 .map(|document| (document.content.clone(), document.range.start))
465 .collect::<Vec<_>>(),
466 expected_contents_and_start_offsets
467 );
468}
469
// Parses a JavaScript/TypeScript snippet using the language from `js_lang()`
// and checks the content and start offset of every expected span.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
 let language = js_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /* globals importScripts, backend */
 function _authorize() {}

 /**
 * Sometimes the frontend build is way faster than backend.
 */
 export async function authorizeBank() {
 _authorize(pushModal, upgradingAccountId, {});
 }

 export class SettingsPage {
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }
 }

 /* This is a test comment */
 class TestClass {}

 /* Schema for editor_events in Clickhouse. */
 export interface ClickhouseEditorEvent {
 installation_id: string
 operation: string
 }
 "
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // NOTE(review): the start offsets below are hard-coded numbers compared with
 // each span's `range.start`; they presumably depend on the exact whitespace
 // of the `text` literal above, so confirm them after any edit to it.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /* globals importScripts, backend */
 function _authorize() {}"
 .unindent(),
 37,
 ),
 (
 "
 /**
 * Sometimes the frontend build is way faster than backend.
 */
 export async function authorizeBank() {
 _authorize(pushModal, upgradingAccountId, {});
 }"
 .unindent(),
 131,
 ),
 (
 "
 export class SettingsPage {
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }
 }"
 .unindent(),
 225,
 ),
 // The constructor is also expected as a span of its own.
 (
 "
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }"
 .unindent(),
 290,
 ),
 (
 "
 /* This is a test comment */
 class TestClass {}"
 .unindent(),
 374,
 ),
 (
 "
 /* Schema for editor_events in Clickhouse. */
 export interface ClickhouseEditorEvent {
 installation_id: string
 operation: string
 }"
 .unindent(),
 440,
 ),
 ],
 )
}
568
// Parses a Lua snippet containing a function nested inside another function
// and checks the two expected spans.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
 let language = lua_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 -- Creates a new class
 -- @param baseclass The Baseclass of this class, or nil.
 -- @return A new class reference.
 function classes.class(baseclass)
 -- Create the class definition and metatable.
 local classdef = {}
 -- Find the super class, either Object or user-defined.
 baseclass = baseclass or classes.Object
 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 setmetatable(classdef, { __index = baseclass })
 -- All class instances have a reference to the class object.
 classdef.class = classdef
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 -- All class instances have a reference to a superclass object.
 local instance = { super = baseclass.alloc(mastertable) }
 -- Any functions this instance does not know of will 'look up' to the superclass definition.
 setmetatable(instance, { __index = classdef, __newindex = mastertable })
 return instance
 end
 end
 "#.unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // First expected span: the outer function, in which the nested function's
 // body appears as `--[ ... ]--` lines. Second expected span: the nested
 // function with its full body. The offsets (114, 809) are hard-coded.
 assert_documents_eq(
 &documents,
 &[
 (r#"
 -- Creates a new class
 -- @param baseclass The Baseclass of this class, or nil.
 -- @return A new class reference.
 function classes.class(baseclass)
 -- Create the class definition and metatable.
 local classdef = {}
 -- Find the super class, either Object or user-defined.
 baseclass = baseclass or classes.Object
 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 setmetatable(classdef, { __index = baseclass })
 -- All class instances have a reference to the class object.
 classdef.class = classdef
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 --[ ... ]--
 --[ ... ]--
 end
 end"#.unindent(),
 114),
 (r#"
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 -- All class instances have a reference to a superclass object.
 local instance = { super = baseclass.alloc(mastertable) }
 -- Any functions this instance does not know of will 'look up' to the superclass definition.
 setmetatable(instance, { __index = classdef, __newindex = mastertable })
 return instance
 end"#.unindent(), 809),
 ]
 );
}
642
// Parses an Elixir module and checks two expected spans: the whole module at
// offset 0 and the `__build__` function (with its `@doc false` line) at the
// hard-coded offset 574.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
 let language = elixir_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 defmodule File.Stream do
 @moduledoc """
 Defines a `File.Stream` struct returned by `File.stream!/3`.

 The following fields are public:

 * `path` - the file path
 * `modes` - the file modes
 * `raw` - a boolean indicating if bin functions should be used
 * `line_or_bytes` - if reading should read lines or a given number of bytes
 * `node` - the node the file belongs to

 """

 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

 @type t :: %__MODULE__{}

 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // The module span is expected in full: the function body inside it is
 // written out rather than replaced by a placeholder.
 assert_documents_eq(
 &documents,
 &[(
 r#"
 defmodule File.Stream do
 @moduledoc """
 Defines a `File.Stream` struct returned by `File.stream!/3`.

 The following fields are public:

 * `path` - the file path
 * `modes` - the file modes
 * `raw` - a boolean indicating if bin functions should be used
 * `line_or_bytes` - if reading should read lines or a given number of bytes
 * `node` - the node the file belongs to

 """

 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

 @type t :: %__MODULE__{}

 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#
 .unindent(),
 0,
 ),(r#"
 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#.unindent(), 574)],
 );
}
759
// Parses a C++ snippet (function, class, enum, anonymous struct, templated
// class with a templated constructor) and checks the expected spans.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
 let language = cpp_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /**
 * @brief Main function
 * @returns 0 on exit
 */
 int main() { return 0; }

 /**
 * This is a test comment
 */
 class MyClass { // The class
 public: // Access specifier
 int myNum; // Attribute (int variable)
 string myString; // Attribute (string variable)
 };

 // This is a test comment
 enum Color { red, green, blue };

 /** This is a preceding block comment
 * This is the second line
 */
 struct { // Structure declaration
 int myNum; // Member (int variable)
 string myString; // Member (string variable)
 } myStructure;

 /**
 * @brief Matrix class.
 */
 template <typename T,
 typename = typename std::enable_if<
 std::is_integral<T>::value || std::is_floating_point<T>::value,
 bool>::type>
 class Matrix2 {
 std::vector<std::vector<T>> _mat;

 public:
 /**
 * @brief Constructor
 * @tparam Integer ensuring integers are being evaluated and not other
 * data types.
 * @param size denoting the size of Matrix as size x size
 */
 template <typename Integer,
 typename = typename std::enable_if<std::is_integral<Integer>::value,
 Integer>::type>
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }
 }"
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // The start offsets are hard-coded. The expected `MyClass` and `Color` spans
 // end before the trailing `;`, whereas the anonymous struct span includes
 // `} myStructure;`.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /**
 * @brief Main function
 * @returns 0 on exit
 */
 int main() { return 0; }"
 .unindent(),
 54,
 ),
 (
 "
 /**
 * This is a test comment
 */
 class MyClass { // The class
 public: // Access specifier
 int myNum; // Attribute (int variable)
 string myString; // Attribute (string variable)
 }"
 .unindent(),
 112,
 ),
 (
 "
 // This is a test comment
 enum Color { red, green, blue }"
 .unindent(),
 322,
 ),
 (
 "
 /** This is a preceding block comment
 * This is the second line
 */
 struct { // Structure declaration
 int myNum; // Member (int variable)
 string myString; // Member (string variable)
 } myStructure;"
 .unindent(),
 425,
 ),
 // The templated class is expected in full, constructor body included.
 (
 "
 /**
 * @brief Matrix class.
 */
 template <typename T,
 typename = typename std::enable_if<
 std::is_integral<T>::value || std::is_floating_point<T>::value,
 bool>::type>
 class Matrix2 {
 std::vector<std::vector<T>> _mat;

 public:
 /**
 * @brief Constructor
 * @tparam Integer ensuring integers are being evaluated and not other
 * data types.
 * @param size denoting the size of Matrix as size x size
 */
 template <typename Integer,
 typename = typename std::enable_if<std::is_integral<Integer>::value,
 Integer>::type>
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }
 }"
 .unindent(),
 612,
 ),
 // The constructor is also expected alone, without its doc comment or
 // `template <...>` header.
 (
 "
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }"
 .unindent(),
 1226,
 ),
 ],
 );
}
912
// Parses a Ruby snippet (a commented module, a class, and a singleton method)
// and checks the expected spans.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
 let language = ruby_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 # This concern is inspired by "sudo mode" on GitHub. It
 # is a way to re-authenticate a user before allowing them
 # to see or perform an action.
 #
 # Add `before_action :require_challenge!` to actions you
 # want to protect.
 #
 # The user will be shown a page to enter the challenge (which
 # is either the password, or just the username when no
 # password exists). Upon passing, there is a grace period
 # during which no challenge will be asked from the user.
 #
 # Accessing challenge-protected resources during the grace
 # period will refresh the grace period.
 module ChallengableConcern
 extend ActiveSupport::Concern

 CHALLENGE_TIMEOUT = 1.hour.freeze

 def require_challenge!
 return if skip_challenge?

 if challenge_passed_recently?
 session[:challenge_passed_at] = Time.now.utc
 return
 end

 @challenge = Form::Challenge.new(return_to: request.url)

 if params.key?(:form_challenge)
 if challenge_passed?
 session[:challenge_passed_at] = Time.now.utc
 else
 flash.now[:alert] = I18n.t('challenge.invalid_password')
 render_challenge
 end
 else
 render_challenge
 end
 end

 def challenge_passed?
 current_user.valid_password?(challenge_params[:current_password])
 end
 end

 class Animal
 include Comparable

 attr_reader :legs

 def initialize(name, legs)
 @name, @legs = name, legs
 end

 def <=>(other)
 legs <=> other.legs
 end
 end

 # Singleton method for car object
 def car.wheels
 puts "There are four wheels"
 end"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // In the expected module and class spans, each method body is `# ...`; the
 // methods are then expected again as separate spans with their full bodies.
 // All start offsets are hard-coded.
 assert_documents_eq(
 &documents,
 &[
 (
 r#"
 # This concern is inspired by "sudo mode" on GitHub. It
 # is a way to re-authenticate a user before allowing them
 # to see or perform an action.
 #
 # Add `before_action :require_challenge!` to actions you
 # want to protect.
 #
 # The user will be shown a page to enter the challenge (which
 # is either the password, or just the username when no
 # password exists). Upon passing, there is a grace period
 # during which no challenge will be asked from the user.
 #
 # Accessing challenge-protected resources during the grace
 # period will refresh the grace period.
 module ChallengableConcern
 extend ActiveSupport::Concern

 CHALLENGE_TIMEOUT = 1.hour.freeze

 def require_challenge!
 # ...
 end

 def challenge_passed?
 # ...
 end
 end"#
 .unindent(),
 558,
 ),
 (
 r#"
 def require_challenge!
 return if skip_challenge?

 if challenge_passed_recently?
 session[:challenge_passed_at] = Time.now.utc
 return
 end

 @challenge = Form::Challenge.new(return_to: request.url)

 if params.key?(:form_challenge)
 if challenge_passed?
 session[:challenge_passed_at] = Time.now.utc
 else
 flash.now[:alert] = I18n.t('challenge.invalid_password')
 render_challenge
 end
 else
 render_challenge
 end
 end"#
 .unindent(),
 663,
 ),
 (
 r#"
 def challenge_passed?
 current_user.valid_password?(challenge_params[:current_password])
 end"#
 .unindent(),
 1254,
 ),
 (
 r#"
 class Animal
 include Comparable

 attr_reader :legs

 def initialize(name, legs)
 # ...
 end

 def <=>(other)
 # ...
 end
 end"#
 .unindent(),
 1363,
 ),
 (
 r#"
 def initialize(name, legs)
 @name, @legs = name, legs
 end"#
 .unindent(),
 1427,
 ),
 (
 r#"
 def <=>(other)
 legs <=> other.legs
 end"#
 .unindent(),
 1501,
 ),
 (
 r#"
 # Singleton method for car object
 def car.wheels
 puts "There are four wheels"
 end"#
 .unindent(),
 1591,
 ),
 ],
 );
}
1103
// Parses a PHP snippet (function, trait with two methods, interface, enum)
// and checks the expected spans.
#[gpui::test]
async fn test_code_context_retrieval_php() {
 let language = php_lang();
 let embedding_provider = Arc::new(DummyEmbeddingProvider {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 <?php

 namespace LevelUp\Experience\Concerns;

 /*
 This is a multiple-lines comment block
 that spans over multiple
 lines
 */
 function functionName() {
 echo "Hello world!";
 }

 trait HasAchievements
 {
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {
 if ($progress > 100) {
 throw new Exception(message: 'Progress cannot be greater than 100');
 }

 if ($this->achievements()->find($achievement->id)) {
 throw new Exception(message: 'User already has this Achievement');
 }

 $this->achievements()->attach($achievement, [
 'progress' => $progress ?? null,
 ]);

 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
 }

 public function achievements(): BelongsToMany
 {
 return $this->belongsToMany(related: Achievement::class)
 ->withPivot(columns: 'progress')
 ->where('is_secret', false)
 ->using(AchievementUser::class);
 }
 }

 interface Multiplier
 {
 public function qualifies(array $data): bool;

 public function setMultiplier(): int;
 }

 enum AuditType: string
 {
 case Add = 'add';
 case Remove = 'remove';
 case Reset = 'reset';
 case LevelUp = 'level_up';
 }

 ?>"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // In the expected trait span, each method body is `{/* ... */}`; the two
 // methods are then expected again as separate spans with their full bodies.
 // All start offsets are hard-coded.
 assert_documents_eq(
 &documents,
 &[
 (
 r#"
 /*
 This is a multiple-lines comment block
 that spans over multiple
 lines
 */
 function functionName() {
 echo "Hello world!";
 }"#
 .unindent(),
 123,
 ),
 (
 r#"
 trait HasAchievements
 {
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {/* ... */}

 public function achievements(): BelongsToMany
 {/* ... */}
 }"#
 .unindent(),
 177,
 ),
 (r#"
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {
 if ($progress > 100) {
 throw new Exception(message: 'Progress cannot be greater than 100');
 }

 if ($this->achievements()->find($achievement->id)) {
 throw new Exception(message: 'User already has this Achievement');
 }

 $this->achievements()->attach($achievement, [
 'progress' => $progress ?? null,
 ]);

 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
 }"#.unindent(), 245),
 (r#"
 public function achievements(): BelongsToMany
 {
 return $this->belongsToMany(related: Achievement::class)
 ->withPivot(columns: 'progress')
 ->where('is_secret', false)
 ->using(AchievementUser::class);
 }"#.unindent(), 902),
 (r#"
 interface Multiplier
 {
 public function qualifies(array $data): bool;

 public function setMultiplier(): int;
 }"#.unindent(),
 1146),
 (r#"
 enum AuditType: string
 {
 case Add = 'add';
 case Remove = 'remove';
 case Reset = 'reset';
 case LevelUp = 'level_up';
 }"#.unindent(), 1265)
 ],
 );
}
1254
1255#[derive(Default)]
1256struct FakeEmbeddingProvider {
1257 embedding_count: AtomicUsize,
1258}
1259
1260impl FakeEmbeddingProvider {
1261 fn embedding_count(&self) -> usize {
1262 self.embedding_count.load(atomic::Ordering::SeqCst)
1263 }
1264
1265 fn embed_sync(&self, span: &str) -> Embedding {
1266 let mut result = vec![1.0; 26];
1267 for letter in span.chars() {
1268 let letter = letter.to_ascii_lowercase();
1269 if letter as u32 >= 'a' as u32 {
1270 let ix = (letter as u32) - ('a' as u32);
1271 if ix < 26 {
1272 result[ix as usize] += 1.0;
1273 }
1274 }
1275 }
1276
1277 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1278 for x in &mut result {
1279 *x /= norm;
1280 }
1281
1282 result.into()
1283 }
1284}
1285
1286#[async_trait]
1287impl EmbeddingProvider for FakeEmbeddingProvider {
1288 fn base_model(&self) -> Box<dyn LanguageModel> {
1289 Box::new(DummyLanguageModel {})
1290 }
1291 fn is_authenticated(&self) -> bool {
1292 true
1293 }
1294 fn truncate(&self, span: &str) -> (String, usize) {
1295 (span.to_string(), 1)
1296 }
1297
1298 fn max_tokens_per_batch(&self) -> usize {
1299 200
1300 }
1301
1302 fn rate_limit_expiration(&self) -> Option<Instant> {
1303 None
1304 }
1305
1306 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1307 self.embedding_count
1308 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1309 Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1310 }
1311}
1312
1313fn js_lang() -> Arc<Language> {
1314 Arc::new(
1315 Language::new(
1316 LanguageConfig {
1317 name: "Javascript".into(),
1318 path_suffixes: vec!["js".into()],
1319 ..Default::default()
1320 },
1321 Some(tree_sitter_typescript::language_tsx()),
1322 )
1323 .with_embedding_query(
1324 &r#"
1325
1326 (
1327 (comment)* @context
1328 .
1329 [
1330 (export_statement
1331 (function_declaration
1332 "async"? @name
1333 "function" @name
1334 name: (_) @name))
1335 (function_declaration
1336 "async"? @name
1337 "function" @name
1338 name: (_) @name)
1339 ] @item
1340 )
1341
1342 (
1343 (comment)* @context
1344 .
1345 [
1346 (export_statement
1347 (class_declaration
1348 "class" @name
1349 name: (_) @name))
1350 (class_declaration
1351 "class" @name
1352 name: (_) @name)
1353 ] @item
1354 )
1355
1356 (
1357 (comment)* @context
1358 .
1359 [
1360 (export_statement
1361 (interface_declaration
1362 "interface" @name
1363 name: (_) @name))
1364 (interface_declaration
1365 "interface" @name
1366 name: (_) @name)
1367 ] @item
1368 )
1369
1370 (
1371 (comment)* @context
1372 .
1373 [
1374 (export_statement
1375 (enum_declaration
1376 "enum" @name
1377 name: (_) @name))
1378 (enum_declaration
1379 "enum" @name
1380 name: (_) @name)
1381 ] @item
1382 )
1383
1384 (
1385 (comment)* @context
1386 .
1387 (method_definition
1388 [
1389 "get"
1390 "set"
1391 "async"
1392 "*"
1393 "static"
1394 ]* @name
1395 name: (_) @name) @item
1396 )
1397
1398 "#
1399 .unindent(),
1400 )
1401 .unwrap(),
1402 )
1403}
1404
1405fn rust_lang() -> Arc<Language> {
1406 Arc::new(
1407 Language::new(
1408 LanguageConfig {
1409 name: "Rust".into(),
1410 path_suffixes: vec!["rs".into()],
1411 collapsed_placeholder: " /* ... */ ".to_string(),
1412 ..Default::default()
1413 },
1414 Some(tree_sitter_rust::language()),
1415 )
1416 .with_embedding_query(
1417 r#"
1418 (
1419 [(line_comment) (attribute_item)]* @context
1420 .
1421 [
1422 (struct_item
1423 name: (_) @name)
1424
1425 (enum_item
1426 name: (_) @name)
1427
1428 (impl_item
1429 trait: (_)? @name
1430 "for"? @name
1431 type: (_) @name)
1432
1433 (trait_item
1434 name: (_) @name)
1435
1436 (function_item
1437 name: (_) @name
1438 body: (block
1439 "{" @keep
1440 "}" @keep) @collapse)
1441
1442 (macro_definition
1443 name: (_) @name)
1444 ] @item
1445 )
1446
1447 (attribute_item) @collapse
1448 (use_declaration) @collapse
1449 "#,
1450 )
1451 .unwrap(),
1452 )
1453}
1454
1455fn json_lang() -> Arc<Language> {
1456 Arc::new(
1457 Language::new(
1458 LanguageConfig {
1459 name: "JSON".into(),
1460 path_suffixes: vec!["json".into()],
1461 ..Default::default()
1462 },
1463 Some(tree_sitter_json::language()),
1464 )
1465 .with_embedding_query(
1466 r#"
1467 (document) @item
1468
1469 (array
1470 "[" @keep
1471 .
1472 (object)? @keep
1473 "]" @keep) @collapse
1474
1475 (pair value: (string
1476 "\"" @keep
1477 "\"" @keep) @collapse)
1478 "#,
1479 )
1480 .unwrap(),
1481 )
1482}
1483
1484fn toml_lang() -> Arc<Language> {
1485 Arc::new(Language::new(
1486 LanguageConfig {
1487 name: "TOML".into(),
1488 path_suffixes: vec!["toml".into()],
1489 ..Default::default()
1490 },
1491 Some(tree_sitter_toml::language()),
1492 ))
1493}
1494
1495fn cpp_lang() -> Arc<Language> {
1496 Arc::new(
1497 Language::new(
1498 LanguageConfig {
1499 name: "CPP".into(),
1500 path_suffixes: vec!["cpp".into()],
1501 ..Default::default()
1502 },
1503 Some(tree_sitter_cpp::language()),
1504 )
1505 .with_embedding_query(
1506 r#"
1507 (
1508 (comment)* @context
1509 .
1510 (function_definition
1511 (type_qualifier)? @name
1512 type: (_)? @name
1513 declarator: [
1514 (function_declarator
1515 declarator: (_) @name)
1516 (pointer_declarator
1517 "*" @name
1518 declarator: (function_declarator
1519 declarator: (_) @name))
1520 (pointer_declarator
1521 "*" @name
1522 declarator: (pointer_declarator
1523 "*" @name
1524 declarator: (function_declarator
1525 declarator: (_) @name)))
1526 (reference_declarator
1527 ["&" "&&"] @name
1528 (function_declarator
1529 declarator: (_) @name))
1530 ]
1531 (type_qualifier)? @name) @item
1532 )
1533
1534 (
1535 (comment)* @context
1536 .
1537 (template_declaration
1538 (class_specifier
1539 "class" @name
1540 name: (_) @name)
1541 ) @item
1542 )
1543
1544 (
1545 (comment)* @context
1546 .
1547 (class_specifier
1548 "class" @name
1549 name: (_) @name) @item
1550 )
1551
1552 (
1553 (comment)* @context
1554 .
1555 (enum_specifier
1556 "enum" @name
1557 name: (_) @name) @item
1558 )
1559
1560 (
1561 (comment)* @context
1562 .
1563 (declaration
1564 type: (struct_specifier
1565 "struct" @name)
1566 declarator: (_) @name) @item
1567 )
1568
1569 "#,
1570 )
1571 .unwrap(),
1572 )
1573}
1574
1575fn lua_lang() -> Arc<Language> {
1576 Arc::new(
1577 Language::new(
1578 LanguageConfig {
1579 name: "Lua".into(),
1580 path_suffixes: vec!["lua".into()],
1581 collapsed_placeholder: "--[ ... ]--".to_string(),
1582 ..Default::default()
1583 },
1584 Some(tree_sitter_lua::language()),
1585 )
1586 .with_embedding_query(
1587 r#"
1588 (
1589 (comment)* @context
1590 .
1591 (function_declaration
1592 "function" @name
1593 name: (_) @name
1594 (comment)* @collapse
1595 body: (block) @collapse
1596 ) @item
1597 )
1598 "#,
1599 )
1600 .unwrap(),
1601 )
1602}
1603
1604fn php_lang() -> Arc<Language> {
1605 Arc::new(
1606 Language::new(
1607 LanguageConfig {
1608 name: "PHP".into(),
1609 path_suffixes: vec!["php".into()],
1610 collapsed_placeholder: "/* ... */".into(),
1611 ..Default::default()
1612 },
1613 Some(tree_sitter_php::language()),
1614 )
1615 .with_embedding_query(
1616 r#"
1617 (
1618 (comment)* @context
1619 .
1620 [
1621 (function_definition
1622 "function" @name
1623 name: (_) @name
1624 body: (_
1625 "{" @keep
1626 "}" @keep) @collapse
1627 )
1628
1629 (trait_declaration
1630 "trait" @name
1631 name: (_) @name)
1632
1633 (method_declaration
1634 "function" @name
1635 name: (_) @name
1636 body: (_
1637 "{" @keep
1638 "}" @keep) @collapse
1639 )
1640
1641 (interface_declaration
1642 "interface" @name
1643 name: (_) @name
1644 )
1645
1646 (enum_declaration
1647 "enum" @name
1648 name: (_) @name
1649 )
1650
1651 ] @item
1652 )
1653 "#,
1654 )
1655 .unwrap(),
1656 )
1657}
1658
1659fn ruby_lang() -> Arc<Language> {
1660 Arc::new(
1661 Language::new(
1662 LanguageConfig {
1663 name: "Ruby".into(),
1664 path_suffixes: vec!["rb".into()],
1665 collapsed_placeholder: "# ...".to_string(),
1666 ..Default::default()
1667 },
1668 Some(tree_sitter_ruby::language()),
1669 )
1670 .with_embedding_query(
1671 r#"
1672 (
1673 (comment)* @context
1674 .
1675 [
1676 (module
1677 "module" @name
1678 name: (_) @name)
1679 (method
1680 "def" @name
1681 name: (_) @name
1682 body: (body_statement) @collapse)
1683 (class
1684 "class" @name
1685 name: (_) @name)
1686 (singleton_method
1687 "def" @name
1688 object: (_) @name
1689 "." @name
1690 name: (_) @name
1691 body: (body_statement) @collapse)
1692 ] @item
1693 )
1694 "#,
1695 )
1696 .unwrap(),
1697 )
1698}
1699
1700fn elixir_lang() -> Arc<Language> {
1701 Arc::new(
1702 Language::new(
1703 LanguageConfig {
1704 name: "Elixir".into(),
1705 path_suffixes: vec!["rs".into()],
1706 ..Default::default()
1707 },
1708 Some(tree_sitter_elixir::language()),
1709 )
1710 .with_embedding_query(
1711 r#"
1712 (
1713 (unary_operator
1714 operator: "@"
1715 operand: (call
1716 target: (identifier) @unary
1717 (#match? @unary "^(doc)$"))
1718 ) @context
1719 .
1720 (call
1721 target: (identifier) @name
1722 (arguments
1723 [
1724 (identifier) @name
1725 (call
1726 target: (identifier) @name)
1727 (binary_operator
1728 left: (call
1729 target: (identifier) @name)
1730 operator: "when")
1731 ])
1732 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1733 )
1734
1735 (call
1736 target: (identifier) @name
1737 (arguments (alias) @name)
1738 (#match? @name "^(defmodule|defprotocol)$")) @item
1739 "#,
1740 )
1741 .unwrap(),
1742 )
1743}
1744
1745#[gpui::test]
1746fn test_subtract_ranges() {
1747 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1748
1749 assert_eq!(
1750 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1751 vec![1..4, 10..21]
1752 );
1753
1754 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1755}
1756
1757fn init_test(cx: &mut TestAppContext) {
1758 cx.update(|cx| {
1759 cx.set_global(SettingsStore::test(cx));
1760 settings::register::<SemanticIndexSettings>(cx);
1761 settings::register::<ProjectSettings>(cx);
1762 });
1763}