1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use parking_lot::Mutex;
13use pretty_assertions::assert_eq;
14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
15use rand::{rngs::StdRng, Rng};
16use serde_json::json;
17use settings::SettingsStore;
18use std::{
19 path::Path,
20 sync::{
21 atomic::{self, AtomicUsize},
22 Arc,
23 },
24 time::{Instant, SystemTime},
25};
26use unindent::Unindent;
27use util::RandomCharIter;
28
29#[ctor::ctor]
30fn init_logger() {
31 if std::env::var("RUST_LOG").is_ok() {
32 env_logger::init();
33 }
34}
35
36#[gpui::test]
37async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
38 init_test(cx);
39
40 let fs = FakeFs::new(cx.background());
41 fs.insert_tree(
42 "/the-root",
43 json!({
44 "src": {
45 "file1.rs": "
46 fn aaa() {
47 println!(\"aaaaaaaaaaaa!\");
48 }
49
50 fn zzzzz() {
51 println!(\"SLEEPING\");
52 }
53 ".unindent(),
54 "file2.rs": "
55 fn bbb() {
56 println!(\"bbbbbbbbbbbbb!\");
57 }
58 struct pqpqpqp {}
59 ".unindent(),
60 "file3.toml": "
61 ZZZZZZZZZZZZZZZZZZ = 5
62 ".unindent(),
63 }
64 }),
65 )
66 .await;
67
68 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
69 let rust_language = rust_lang();
70 let toml_language = toml_lang();
71 languages.add(rust_language);
72 languages.add(toml_language);
73
74 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
75 let db_path = db_dir.path().join("db.sqlite");
76
77 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
78 let semantic_index = SemanticIndex::new(
79 fs.clone(),
80 db_path,
81 embedding_provider.clone(),
82 languages,
83 cx.to_async(),
84 )
85 .await
86 .unwrap();
87
88 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
89
90 let search_results = semantic_index.update(cx, |store, cx| {
91 store.search_project(
92 project.clone(),
93 "aaaaaabbbbzz".to_string(),
94 5,
95 vec![],
96 vec![],
97 cx,
98 )
99 });
100 let pending_file_count =
101 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
102 deterministic.run_until_parked();
103 assert_eq!(*pending_file_count.borrow(), 3);
104 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
105 assert_eq!(*pending_file_count.borrow(), 0);
106
107 let search_results = search_results.await.unwrap();
108 assert_search_results(
109 &search_results,
110 &[
111 (Path::new("src/file1.rs").into(), 0),
112 (Path::new("src/file2.rs").into(), 0),
113 (Path::new("src/file3.toml").into(), 0),
114 (Path::new("src/file1.rs").into(), 45),
115 (Path::new("src/file2.rs").into(), 45),
116 ],
117 cx,
118 );
119
120 // Test Include Files Functonality
121 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
122 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
123 let rust_only_search_results = semantic_index
124 .update(cx, |store, cx| {
125 store.search_project(
126 project.clone(),
127 "aaaaaabbbbzz".to_string(),
128 5,
129 include_files,
130 vec![],
131 cx,
132 )
133 })
134 .await
135 .unwrap();
136
137 assert_search_results(
138 &rust_only_search_results,
139 &[
140 (Path::new("src/file1.rs").into(), 0),
141 (Path::new("src/file2.rs").into(), 0),
142 (Path::new("src/file1.rs").into(), 45),
143 (Path::new("src/file2.rs").into(), 45),
144 ],
145 cx,
146 );
147
148 let no_rust_search_results = semantic_index
149 .update(cx, |store, cx| {
150 store.search_project(
151 project.clone(),
152 "aaaaaabbbbzz".to_string(),
153 5,
154 vec![],
155 exclude_files,
156 cx,
157 )
158 })
159 .await
160 .unwrap();
161
162 assert_search_results(
163 &no_rust_search_results,
164 &[(Path::new("src/file3.toml").into(), 0)],
165 cx,
166 );
167
168 fs.save(
169 "/the-root/src/file2.rs".as_ref(),
170 &"
171 fn dddd() { println!(\"ddddd!\"); }
172 struct pqpqpqp {}
173 "
174 .unindent()
175 .into(),
176 Default::default(),
177 )
178 .await
179 .unwrap();
180
181 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
182
183 let prev_embedding_count = embedding_provider.embedding_count();
184 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
185 deterministic.run_until_parked();
186 assert_eq!(*pending_file_count.borrow(), 1);
187 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
188 assert_eq!(*pending_file_count.borrow(), 0);
189 index.await.unwrap();
190
191 assert_eq!(
192 embedding_provider.embedding_count() - prev_embedding_count,
193 1
194 );
195}
196
197#[gpui::test(iterations = 10)]
198async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
199 let (outstanding_job_count, _) = postage::watch::channel_with(0);
200 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
201
202 let files = (1..=3)
203 .map(|file_ix| FileToEmbed {
204 worktree_id: 5,
205 path: Path::new(&format!("path-{file_ix}")).into(),
206 mtime: SystemTime::now(),
207 spans: (0..rng.gen_range(4..22))
208 .map(|document_ix| {
209 let content_len = rng.gen_range(10..100);
210 let content = RandomCharIter::new(&mut rng)
211 .with_simple_text()
212 .take(content_len)
213 .collect::<String>();
214 let digest = SpanDigest::from(content.as_str());
215 Span {
216 range: 0..10,
217 embedding: None,
218 name: format!("document {document_ix}"),
219 content,
220 digest,
221 token_count: rng.gen_range(10..30),
222 }
223 })
224 .collect(),
225 job_handle: JobHandle::new(&outstanding_job_count),
226 })
227 .collect::<Vec<_>>();
228
229 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
230
231 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
232 for file in &files {
233 queue.push(file.clone());
234 }
235 queue.flush();
236
237 cx.foreground().run_until_parked();
238 let finished_files = queue.finished_files();
239 let mut embedded_files: Vec<_> = files
240 .iter()
241 .map(|_| finished_files.try_recv().expect("no finished file"))
242 .collect();
243
244 let expected_files: Vec<_> = files
245 .iter()
246 .map(|file| {
247 let mut file = file.clone();
248 for doc in &mut file.spans {
249 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
250 }
251 file
252 })
253 .collect();
254
255 embedded_files.sort_by_key(|f| f.path.clone());
256
257 assert_eq!(embedded_files, expected_files);
258}
259
260#[track_caller]
261fn assert_search_results(
262 actual: &[SearchResult],
263 expected: &[(Arc<Path>, usize)],
264 cx: &TestAppContext,
265) {
266 let actual = actual
267 .iter()
268 .map(|search_result| {
269 search_result.buffer.read_with(cx, |buffer, _cx| {
270 (
271 buffer.file().unwrap().path().clone(),
272 search_result.range.start.to_offset(buffer),
273 )
274 })
275 })
276 .collect::<Vec<_>>();
277 assert_eq!(actual, expected);
278}
279
280#[gpui::test]
281async fn test_code_context_retrieval_rust() {
282 let language = rust_lang();
283 let embedding_provider = Arc::new(DummyEmbeddings {});
284 let mut retriever = CodeContextRetriever::new(embedding_provider);
285
286 let text = "
287 /// A doc comment
288 /// that spans multiple lines
289 #[gpui::test]
290 fn a() {
291 b
292 }
293
294 impl C for D {
295 }
296
297 impl E {
298 // This is also a preceding comment
299 pub fn function_1() -> Option<()> {
300 todo!();
301 }
302
303 // This is a preceding comment
304 fn function_2() -> Result<()> {
305 todo!();
306 }
307 }
308
309 #[derive(Clone)]
310 struct D {
311 name: String
312 }
313 "
314 .unindent();
315
316 let documents = retriever.parse_file(&text, language).unwrap();
317
318 assert_documents_eq(
319 &documents,
320 &[
321 (
322 "
323 /// A doc comment
324 /// that spans multiple lines
325 #[gpui::test]
326 fn a() {
327 b
328 }"
329 .unindent(),
330 text.find("fn a").unwrap(),
331 ),
332 (
333 "
334 impl C for D {
335 }"
336 .unindent(),
337 text.find("impl C").unwrap(),
338 ),
339 (
340 "
341 impl E {
342 // This is also a preceding comment
343 pub fn function_1() -> Option<()> { /* ... */ }
344
345 // This is a preceding comment
346 fn function_2() -> Result<()> { /* ... */ }
347 }"
348 .unindent(),
349 text.find("impl E").unwrap(),
350 ),
351 (
352 "
353 // This is also a preceding comment
354 pub fn function_1() -> Option<()> {
355 todo!();
356 }"
357 .unindent(),
358 text.find("pub fn function_1").unwrap(),
359 ),
360 (
361 "
362 // This is a preceding comment
363 fn function_2() -> Result<()> {
364 todo!();
365 }"
366 .unindent(),
367 text.find("fn function_2").unwrap(),
368 ),
369 (
370 "
371 #[derive(Clone)]
372 struct D {
373 name: String
374 }"
375 .unindent(),
376 text.find("struct D").unwrap(),
377 ),
378 ],
379 );
380}
381
// Runs `CodeContextRetriever::parse_file` over two JSON fixtures (a nested object,
// then an array of objects) and checks the resulting spans with
// `assert_documents_eq`. In the expected content, arrays and strings appear
// emptied (`[]`, `""`) and only the first array element of the second fixture is
// kept; start offsets are looked up with `text.find("{")` / `text.find("[")`.
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// exact whitespace of the expected strings cannot be recovered from here.
382#[gpui::test]
383async fn test_code_context_retrieval_json() {
384 let language = json_lang();
385 let embedding_provider = Arc::new(DummyEmbeddings {});
386 let mut retriever = CodeContextRetriever::new(embedding_provider);
387
// Fixture 1: a top-level object with a nested object.
388 let text = r#"
389 {
390 "array": [1, 2, 3, 4],
391 "string": "abcdefg",
392 "nested_object": {
393 "array_2": [5, 6, 7, 8],
394 "string_2": "hijklmnop",
395 "boolean": true,
396 "none": null
397 }
398 }
399 "#
400 .unindent();
401
402 let documents = retriever.parse_file(&text, language.clone()).unwrap();
403
// A single span is expected for the whole object.
404 assert_documents_eq(
405 &documents,
406 &[(
407 r#"
408 {
409 "array": [],
410 "string": "",
411 "nested_object": {
412 "array_2": [],
413 "string_2": "",
414 "boolean": true,
415 "none": null
416 }
417 }"#
418 .unindent(),
419 text.find("{").unwrap(),
420 )],
421 );
422
// Fixture 2: a top-level array of two objects (shadows `text` and `documents`).
423 let text = r#"
424 [
425 {
426 "name": "somebody",
427 "age": 42
428 },
429 {
430 "name": "somebody else",
431 "age": 43
432 }
433 ]
434 "#
435 .unindent();
436
437 let documents = retriever.parse_file(&text, language.clone()).unwrap();
438
439 assert_documents_eq(
440 &documents,
441 &[(
442 r#"
443 [{
444 "name": "",
445 "age": 42
446 }]"#
447 .unindent(),
448 text.find("[").unwrap(),
449 )],
450 );
451}
452
453fn assert_documents_eq(
454 documents: &[Span],
455 expected_contents_and_start_offsets: &[(String, usize)],
456) {
457 assert_eq!(
458 documents
459 .iter()
460 .map(|document| (document.content.clone(), document.range.start))
461 .collect::<Vec<_>>(),
462 expected_contents_and_start_offsets
463 );
464}
465
466#[gpui::test]
467async fn test_code_context_retrieval_javascript() {
468 let language = js_lang();
469 let embedding_provider = Arc::new(DummyEmbeddings {});
470 let mut retriever = CodeContextRetriever::new(embedding_provider);
471
472 let text = "
473 /* globals importScripts, backend */
474 function _authorize() {}
475
476 /**
477 * Sometimes the frontend build is way faster than backend.
478 */
479 export async function authorizeBank() {
480 _authorize(pushModal, upgradingAccountId, {});
481 }
482
483 export class SettingsPage {
484 /* This is a test setting */
485 constructor(page) {
486 this.page = page;
487 }
488 }
489
490 /* This is a test comment */
491 class TestClass {}
492
493 /* Schema for editor_events in Clickhouse. */
494 export interface ClickhouseEditorEvent {
495 installation_id: string
496 operation: string
497 }
498 "
499 .unindent();
500
501 let documents = retriever.parse_file(&text, language.clone()).unwrap();
502
503 assert_documents_eq(
504 &documents,
505 &[
506 (
507 "
508 /* globals importScripts, backend */
509 function _authorize() {}"
510 .unindent(),
511 37,
512 ),
513 (
514 "
515 /**
516 * Sometimes the frontend build is way faster than backend.
517 */
518 export async function authorizeBank() {
519 _authorize(pushModal, upgradingAccountId, {});
520 }"
521 .unindent(),
522 131,
523 ),
524 (
525 "
526 export class SettingsPage {
527 /* This is a test setting */
528 constructor(page) {
529 this.page = page;
530 }
531 }"
532 .unindent(),
533 225,
534 ),
535 (
536 "
537 /* This is a test setting */
538 constructor(page) {
539 this.page = page;
540 }"
541 .unindent(),
542 290,
543 ),
544 (
545 "
546 /* This is a test comment */
547 class TestClass {}"
548 .unindent(),
549 374,
550 ),
551 (
552 "
553 /* Schema for editor_events in Clickhouse. */
554 export interface ClickhouseEditorEvent {
555 installation_id: string
556 operation: string
557 }"
558 .unindent(),
559 440,
560 ),
561 ],
562 )
563}
564
// Runs `CodeContextRetriever::parse_file` over a Lua fixture containing a
// function with a nested function, and checks the two resulting spans with
// `assert_documents_eq`: the outer function (start offset 114) whose expected
// content shows the nested function's body replaced by `--[ ... ]--` lines, and
// the nested function on its own (start offset 809). Both spans include the
// comment lines directly above the function.
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// hard-coded offsets depend on fixture whitespace that cannot be recovered here.
565#[gpui::test]
566async fn test_code_context_retrieval_lua() {
567 let language = lua_lang();
568 let embedding_provider = Arc::new(DummyEmbeddings {});
569 let mut retriever = CodeContextRetriever::new(embedding_provider);
570
571 let text = r#"
572 -- Creates a new class
573 -- @param baseclass The Baseclass of this class, or nil.
574 -- @return A new class reference.
575 function classes.class(baseclass)
576 -- Create the class definition and metatable.
577 local classdef = {}
578 -- Find the super class, either Object or user-defined.
579 baseclass = baseclass or classes.Object
580 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
581 setmetatable(classdef, { __index = baseclass })
582 -- All class instances have a reference to the class object.
583 classdef.class = classdef
584 --- Recursivly allocates the inheritance tree of the instance.
585 -- @param mastertable The 'root' of the inheritance tree.
586 -- @return Returns the instance with the allocated inheritance tree.
587 function classdef.alloc(mastertable)
588 -- All class instances have a reference to a superclass object.
589 local instance = { super = baseclass.alloc(mastertable) }
590 -- Any functions this instance does not know of will 'look up' to the superclass definition.
591 setmetatable(instance, { __index = classdef, __newindex = mastertable })
592 return instance
593 end
594 end
595 "#.unindent();
596
597 let documents = retriever.parse_file(&text, language.clone()).unwrap();
598
// Expected `(content, start offset)` pairs, outer function first.
599 assert_documents_eq(
600 &documents,
601 &[
602 (r#"
603 -- Creates a new class
604 -- @param baseclass The Baseclass of this class, or nil.
605 -- @return A new class reference.
606 function classes.class(baseclass)
607 -- Create the class definition and metatable.
608 local classdef = {}
609 -- Find the super class, either Object or user-defined.
610 baseclass = baseclass or classes.Object
611 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
612 setmetatable(classdef, { __index = baseclass })
613 -- All class instances have a reference to the class object.
614 classdef.class = classdef
615 --- Recursivly allocates the inheritance tree of the instance.
616 -- @param mastertable The 'root' of the inheritance tree.
617 -- @return Returns the instance with the allocated inheritance tree.
618 function classdef.alloc(mastertable)
619 --[ ... ]--
620 --[ ... ]--
621 end
622 end"#.unindent(),
623 114),
624 (r#"
625 --- Recursivly allocates the inheritance tree of the instance.
626 -- @param mastertable The 'root' of the inheritance tree.
627 -- @return Returns the instance with the allocated inheritance tree.
628 function classdef.alloc(mastertable)
629 -- All class instances have a reference to a superclass object.
630 local instance = { super = baseclass.alloc(mastertable) }
631 -- Any functions this instance does not know of will 'look up' to the superclass definition.
632 setmetatable(instance, { __index = classdef, __newindex = mastertable })
633 return instance
634 end"#.unindent(), 809),
635 ]
636 );
637}
638
// Runs `CodeContextRetriever::parse_file` over an Elixir module fixture and
// checks the two resulting spans with `assert_documents_eq`: the whole
// `defmodule` (start offset 0, content reproduced in full) and the
// `@doc false` / `def __build__` definition (start offset 574).
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// hard-coded offset 574 depends on fixture whitespace that cannot be recovered
// here.
639#[gpui::test]
640async fn test_code_context_retrieval_elixir() {
641 let language = elixir_lang();
642 let embedding_provider = Arc::new(DummyEmbeddings {});
643 let mut retriever = CodeContextRetriever::new(embedding_provider);
644
645 let text = r#"
646 defmodule File.Stream do
647 @moduledoc """
648 Defines a `File.Stream` struct returned by `File.stream!/3`.
649
650 The following fields are public:
651
652 * `path` - the file path
653 * `modes` - the file modes
654 * `raw` - a boolean indicating if bin functions should be used
655 * `line_or_bytes` - if reading should read lines or a given number of bytes
656 * `node` - the node the file belongs to
657
658 """
659
660 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
661
662 @type t :: %__MODULE__{}
663
664 @doc false
665 def __build__(path, modes, line_or_bytes) do
666 raw = :lists.keyfind(:encoding, 1, modes) == false
667
668 modes =
669 case raw do
670 true ->
671 case :lists.keyfind(:read_ahead, 1, modes) do
672 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
673 {:read_ahead, _} -> [:raw | modes]
674 false -> [:raw, :read_ahead | modes]
675 end
676
677 false ->
678 modes
679 end
680
681 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
682
683 end"#
684 .unindent();
685
686 let documents = retriever.parse_file(&text, language.clone()).unwrap();
687
// Expected `(content, start offset)` pairs: the module first, then the function.
688 assert_documents_eq(
689 &documents,
690 &[(
691 r#"
692 defmodule File.Stream do
693 @moduledoc """
694 Defines a `File.Stream` struct returned by `File.stream!/3`.
695
696 The following fields are public:
697
698 * `path` - the file path
699 * `modes` - the file modes
700 * `raw` - a boolean indicating if bin functions should be used
701 * `line_or_bytes` - if reading should read lines or a given number of bytes
702 * `node` - the node the file belongs to
703
704 """
705
706 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
707
708 @type t :: %__MODULE__{}
709
710 @doc false
711 def __build__(path, modes, line_or_bytes) do
712 raw = :lists.keyfind(:encoding, 1, modes) == false
713
714 modes =
715 case raw do
716 true ->
717 case :lists.keyfind(:read_ahead, 1, modes) do
718 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
719 {:read_ahead, _} -> [:raw | modes]
720 false -> [:raw, :read_ahead | modes]
721 end
722
723 false ->
724 modes
725 end
726
727 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
728
729 end"#
730 .unindent(),
731 0,
732 ),(r#"
733 @doc false
734 def __build__(path, modes, line_or_bytes) do
735 raw = :lists.keyfind(:encoding, 1, modes) == false
736
737 modes =
738 case raw do
739 true ->
740 case :lists.keyfind(:read_ahead, 1, modes) do
741 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
742 {:read_ahead, _} -> [:raw | modes]
743 false -> [:raw, :read_ahead | modes]
744 end
745
746 false ->
747 modes
748 end
749
750 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
751
752 end"#.unindent(), 574)],
753 );
754}
755
// Runs `CodeContextRetriever::parse_file` over a C++ fixture and checks the six
// resulting spans with `assert_documents_eq`, in order: `main`, `MyClass`,
// `enum Color`, the anonymous struct `myStructure`, the templated `Matrix2`
// class, and the `explicit Matrix` constructor inside it. The expected start
// offsets are hard-coded (54, 112, 322, 425, 612, 1226). In the expected
// content the class and enum spans end at the closing brace without the
// trailing `;`, while the struct span keeps `} myStructure;`.
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// offsets depend on fixture whitespace that cannot be recovered here.
756#[gpui::test]
757async fn test_code_context_retrieval_cpp() {
758 let language = cpp_lang();
759 let embedding_provider = Arc::new(DummyEmbeddings {});
760 let mut retriever = CodeContextRetriever::new(embedding_provider);
761
762 let text = "
763 /**
764 * @brief Main function
765 * @returns 0 on exit
766 */
767 int main() { return 0; }
768
769 /**
770* This is a test comment
771*/
772 class MyClass { // The class
773 public: // Access specifier
774 int myNum; // Attribute (int variable)
775 string myString; // Attribute (string variable)
776 };
777
778 // This is a test comment
779 enum Color { red, green, blue };
780
781 /** This is a preceding block comment
782 * This is the second line
783 */
784 struct { // Structure declaration
785 int myNum; // Member (int variable)
786 string myString; // Member (string variable)
787 } myStructure;
788
789 /**
790 * @brief Matrix class.
791 */
792 template <typename T,
793 typename = typename std::enable_if<
794 std::is_integral<T>::value || std::is_floating_point<T>::value,
795 bool>::type>
796 class Matrix2 {
797 std::vector<std::vector<T>> _mat;
798
799 public:
800 /**
801 * @brief Constructor
802 * @tparam Integer ensuring integers are being evaluated and not other
803 * data types.
804 * @param size denoting the size of Matrix as size x size
805 */
806 template <typename Integer,
807 typename = typename std::enable_if<std::is_integral<Integer>::value,
808 Integer>::type>
809 explicit Matrix(const Integer size) {
810 for (size_t i = 0; i < size; ++i) {
811 _mat.emplace_back(std::vector<T>(size, 0));
812 }
813 }
814 }"
815 .unindent();
816
817 let documents = retriever.parse_file(&text, language.clone()).unwrap();
818
// Expected `(content, start offset)` pairs, in source order.
819 assert_documents_eq(
820 &documents,
821 &[
822 (
823 "
824 /**
825 * @brief Main function
826 * @returns 0 on exit
827 */
828 int main() { return 0; }"
829 .unindent(),
830 54,
831 ),
832 (
833 "
834 /**
835 * This is a test comment
836 */
837 class MyClass { // The class
838 public: // Access specifier
839 int myNum; // Attribute (int variable)
840 string myString; // Attribute (string variable)
841 }"
842 .unindent(),
843 112,
844 ),
845 (
846 "
847 // This is a test comment
848 enum Color { red, green, blue }"
849 .unindent(),
850 322,
851 ),
852 (
853 "
854 /** This is a preceding block comment
855 * This is the second line
856 */
857 struct { // Structure declaration
858 int myNum; // Member (int variable)
859 string myString; // Member (string variable)
860 } myStructure;"
861 .unindent(),
862 425,
863 ),
864 (
865 "
866 /**
867 * @brief Matrix class.
868 */
869 template <typename T,
870 typename = typename std::enable_if<
871 std::is_integral<T>::value || std::is_floating_point<T>::value,
872 bool>::type>
873 class Matrix2 {
874 std::vector<std::vector<T>> _mat;
875
876 public:
877 /**
878 * @brief Constructor
879 * @tparam Integer ensuring integers are being evaluated and not other
880 * data types.
881 * @param size denoting the size of Matrix as size x size
882 */
883 template <typename Integer,
884 typename = typename std::enable_if<std::is_integral<Integer>::value,
885 Integer>::type>
886 explicit Matrix(const Integer size) {
887 for (size_t i = 0; i < size; ++i) {
888 _mat.emplace_back(std::vector<T>(size, 0));
889 }
890 }
891 }"
892 .unindent(),
893 612,
894 ),
895 (
896 "
897 explicit Matrix(const Integer size) {
898 for (size_t i = 0; i < size; ++i) {
899 _mat.emplace_back(std::vector<T>(size, 0));
900 }
901 }"
902 .unindent(),
903 1226,
904 ),
905 ],
906 );
907}
908
// Runs `CodeContextRetriever::parse_file` over a Ruby fixture and checks the
// seven resulting spans with `assert_documents_eq`, in order: the
// `ChallengableConcern` module (with its leading comment block), its two
// methods, the `Animal` class, its two methods, and the singleton method
// `car.wheels`. In the expected module and class content each method body is
// shown as `# ...`. Start offsets are hard-coded.
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// offsets depend on fixture whitespace that cannot be recovered here.
909#[gpui::test]
910async fn test_code_context_retrieval_ruby() {
911 let language = ruby_lang();
912 let embedding_provider = Arc::new(DummyEmbeddings {});
913 let mut retriever = CodeContextRetriever::new(embedding_provider);
914
915 let text = r#"
916 # This concern is inspired by "sudo mode" on GitHub. It
917 # is a way to re-authenticate a user before allowing them
918 # to see or perform an action.
919 #
920 # Add `before_action :require_challenge!` to actions you
921 # want to protect.
922 #
923 # The user will be shown a page to enter the challenge (which
924 # is either the password, or just the username when no
925 # password exists). Upon passing, there is a grace period
926 # during which no challenge will be asked from the user.
927 #
928 # Accessing challenge-protected resources during the grace
929 # period will refresh the grace period.
930 module ChallengableConcern
931 extend ActiveSupport::Concern
932
933 CHALLENGE_TIMEOUT = 1.hour.freeze
934
935 def require_challenge!
936 return if skip_challenge?
937
938 if challenge_passed_recently?
939 session[:challenge_passed_at] = Time.now.utc
940 return
941 end
942
943 @challenge = Form::Challenge.new(return_to: request.url)
944
945 if params.key?(:form_challenge)
946 if challenge_passed?
947 session[:challenge_passed_at] = Time.now.utc
948 else
949 flash.now[:alert] = I18n.t('challenge.invalid_password')
950 render_challenge
951 end
952 else
953 render_challenge
954 end
955 end
956
957 def challenge_passed?
958 current_user.valid_password?(challenge_params[:current_password])
959 end
960 end
961
962 class Animal
963 include Comparable
964
965 attr_reader :legs
966
967 def initialize(name, legs)
968 @name, @legs = name, legs
969 end
970
971 def <=>(other)
972 legs <=> other.legs
973 end
974 end
975
976 # Singleton method for car object
977 def car.wheels
978 puts "There are four wheels"
979 end"#
980 .unindent();
981
982 let documents = retriever.parse_file(&text, language.clone()).unwrap();
983
// Expected `(content, start offset)` pairs, in source order.
984 assert_documents_eq(
985 &documents,
986 &[
987 (
988 r#"
989 # This concern is inspired by "sudo mode" on GitHub. It
990 # is a way to re-authenticate a user before allowing them
991 # to see or perform an action.
992 #
993 # Add `before_action :require_challenge!` to actions you
994 # want to protect.
995 #
996 # The user will be shown a page to enter the challenge (which
997 # is either the password, or just the username when no
998 # password exists). Upon passing, there is a grace period
999 # during which no challenge will be asked from the user.
1000 #
1001 # Accessing challenge-protected resources during the grace
1002 # period will refresh the grace period.
1003 module ChallengableConcern
1004 extend ActiveSupport::Concern
1005
1006 CHALLENGE_TIMEOUT = 1.hour.freeze
1007
1008 def require_challenge!
1009 # ...
1010 end
1011
1012 def challenge_passed?
1013 # ...
1014 end
1015 end"#
1016 .unindent(),
1017 558,
1018 ),
1019 (
1020 r#"
1021 def require_challenge!
1022 return if skip_challenge?
1023
1024 if challenge_passed_recently?
1025 session[:challenge_passed_at] = Time.now.utc
1026 return
1027 end
1028
1029 @challenge = Form::Challenge.new(return_to: request.url)
1030
1031 if params.key?(:form_challenge)
1032 if challenge_passed?
1033 session[:challenge_passed_at] = Time.now.utc
1034 else
1035 flash.now[:alert] = I18n.t('challenge.invalid_password')
1036 render_challenge
1037 end
1038 else
1039 render_challenge
1040 end
1041 end"#
1042 .unindent(),
1043 663,
1044 ),
1045 (
1046 r#"
1047 def challenge_passed?
1048 current_user.valid_password?(challenge_params[:current_password])
1049 end"#
1050 .unindent(),
1051 1254,
1052 ),
1053 (
1054 r#"
1055 class Animal
1056 include Comparable
1057
1058 attr_reader :legs
1059
1060 def initialize(name, legs)
1061 # ...
1062 end
1063
1064 def <=>(other)
1065 # ...
1066 end
1067 end"#
1068 .unindent(),
1069 1363,
1070 ),
1071 (
1072 r#"
1073 def initialize(name, legs)
1074 @name, @legs = name, legs
1075 end"#
1076 .unindent(),
1077 1427,
1078 ),
1079 (
1080 r#"
1081 def <=>(other)
1082 legs <=> other.legs
1083 end"#
1084 .unindent(),
1085 1501,
1086 ),
1087 (
1088 r#"
1089 # Singleton method for car object
1090 def car.wheels
1091 puts "There are four wheels"
1092 end"#
1093 .unindent(),
1094 1591,
1095 ),
1096 ],
1097 );
1098}
1099
// Runs `CodeContextRetriever::parse_file` over a PHP fixture and checks the six
// resulting spans with `assert_documents_eq`, in order: the free function (with
// its block comment), the `HasAchievements` trait, its two methods, the
// `Multiplier` interface and the `AuditType` enum. In the expected trait
// content each method body is shown as `{/* ... */}`. Start offsets are
// hard-coded.
// NOTE(review): every line of this copy starts with a fused decimal counter and
// has lost its original indentation; the code is left untouched because the
// offsets depend on fixture whitespace that cannot be recovered here.
1100#[gpui::test]
1101async fn test_code_context_retrieval_php() {
1102 let language = php_lang();
1103 let embedding_provider = Arc::new(DummyEmbeddings {});
1104 let mut retriever = CodeContextRetriever::new(embedding_provider);
1105
1106 let text = r#"
1107 <?php
1108
1109 namespace LevelUp\Experience\Concerns;
1110
1111 /*
1112 This is a multiple-lines comment block
1113 that spans over multiple
1114 lines
1115 */
1116 function functionName() {
1117 echo "Hello world!";
1118 }
1119
1120 trait HasAchievements
1121 {
1122 /**
1123 * @throws \Exception
1124 */
1125 public function grantAchievement(Achievement $achievement, $progress = null): void
1126 {
1127 if ($progress > 100) {
1128 throw new Exception(message: 'Progress cannot be greater than 100');
1129 }
1130
1131 if ($this->achievements()->find($achievement->id)) {
1132 throw new Exception(message: 'User already has this Achievement');
1133 }
1134
1135 $this->achievements()->attach($achievement, [
1136 'progress' => $progress ?? null,
1137 ]);
1138
1139 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1140 }
1141
1142 public function achievements(): BelongsToMany
1143 {
1144 return $this->belongsToMany(related: Achievement::class)
1145 ->withPivot(columns: 'progress')
1146 ->where('is_secret', false)
1147 ->using(AchievementUser::class);
1148 }
1149 }
1150
1151 interface Multiplier
1152 {
1153 public function qualifies(array $data): bool;
1154
1155 public function setMultiplier(): int;
1156 }
1157
1158 enum AuditType: string
1159 {
1160 case Add = 'add';
1161 case Remove = 'remove';
1162 case Reset = 'reset';
1163 case LevelUp = 'level_up';
1164 }
1165
1166 ?>"#
1167 .unindent();
1168
1169 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1170
// Expected `(content, start offset)` pairs, in source order.
1171 assert_documents_eq(
1172 &documents,
1173 &[
1174 (
1175 r#"
1176 /*
1177This is a multiple-lines comment block
1178that spans over multiple
1179lines
1180*/
1181 function functionName() {
1182 echo "Hello world!";
1183 }"#
1184 .unindent(),
1185 123,
1186 ),
1187 (
1188 r#"
1189 trait HasAchievements
1190 {
1191 /**
1192 * @throws \Exception
1193 */
1194 public function grantAchievement(Achievement $achievement, $progress = null): void
1195 {/* ... */}
1196
1197 public function achievements(): BelongsToMany
1198 {/* ... */}
1199 }"#
1200 .unindent(),
1201 177,
1202 ),
1203 (r#"
1204 /**
1205 * @throws \Exception
1206 */
1207 public function grantAchievement(Achievement $achievement, $progress = null): void
1208 {
1209 if ($progress > 100) {
1210 throw new Exception(message: 'Progress cannot be greater than 100');
1211 }
1212
1213 if ($this->achievements()->find($achievement->id)) {
1214 throw new Exception(message: 'User already has this Achievement');
1215 }
1216
1217 $this->achievements()->attach($achievement, [
1218 'progress' => $progress ?? null,
1219 ]);
1220
1221 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1222 }"#.unindent(), 245),
1223 (r#"
1224 public function achievements(): BelongsToMany
1225 {
1226 return $this->belongsToMany(related: Achievement::class)
1227 ->withPivot(columns: 'progress')
1228 ->where('is_secret', false)
1229 ->using(AchievementUser::class);
1230 }"#.unindent(), 902),
1231 (r#"
1232 interface Multiplier
1233 {
1234 public function qualifies(array $data): bool;
1235
1236 public function setMultiplier(): int;
1237 }"#.unindent(),
1238 1146),
1239 (r#"
1240 enum AuditType: string
1241 {
1242 case Add = 'add';
1243 case Remove = 'remove';
1244 case Reset = 'reset';
1245 case LevelUp = 'level_up';
1246 }"#.unindent(), 1265)
1247 ],
1248 );
1249}
1250
/// Test double for `EmbeddingProvider` that computes embeddings locally and
/// counts how many spans it has been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Running total of spans passed to `embed_batch`.
    embedding_count: AtomicUsize,
}
1255
1256impl FakeEmbeddingProvider {
1257 fn embedding_count(&self) -> usize {
1258 self.embedding_count.load(atomic::Ordering::SeqCst)
1259 }
1260
1261 fn embed_sync(&self, span: &str) -> Embedding {
1262 let mut result = vec![1.0; 26];
1263 for letter in span.chars() {
1264 let letter = letter.to_ascii_lowercase();
1265 if letter as u32 >= 'a' as u32 {
1266 let ix = (letter as u32) - ('a' as u32);
1267 if ix < 26 {
1268 result[ix as usize] += 1.0;
1269 }
1270 }
1271 }
1272
1273 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1274 for x in &mut result {
1275 *x /= norm;
1276 }
1277
1278 result.into()
1279 }
1280}
1281
1282#[async_trait]
1283impl EmbeddingProvider for FakeEmbeddingProvider {
1284 fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
1285 Some("Fake Credentials".to_string())
1286 }
1287 fn truncate(&self, span: &str) -> (String, usize) {
1288 (span.to_string(), 1)
1289 }
1290
1291 fn max_tokens_per_batch(&self) -> usize {
1292 200
1293 }
1294
1295 fn rate_limit_expiration(&self) -> Option<Instant> {
1296 None
1297 }
1298
1299 async fn embed_batch(
1300 &self,
1301 spans: Vec<String>,
1302 _api_key: Option<String>,
1303 ) -> Result<Vec<Embedding>> {
1304 self.embedding_count
1305 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1306 Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1307 }
1308}
1309
1310fn js_lang() -> Arc<Language> {
1311 Arc::new(
1312 Language::new(
1313 LanguageConfig {
1314 name: "Javascript".into(),
1315 path_suffixes: vec!["js".into()],
1316 ..Default::default()
1317 },
1318 Some(tree_sitter_typescript::language_tsx()),
1319 )
1320 .with_embedding_query(
1321 &r#"
1322
1323 (
1324 (comment)* @context
1325 .
1326 [
1327 (export_statement
1328 (function_declaration
1329 "async"? @name
1330 "function" @name
1331 name: (_) @name))
1332 (function_declaration
1333 "async"? @name
1334 "function" @name
1335 name: (_) @name)
1336 ] @item
1337 )
1338
1339 (
1340 (comment)* @context
1341 .
1342 [
1343 (export_statement
1344 (class_declaration
1345 "class" @name
1346 name: (_) @name))
1347 (class_declaration
1348 "class" @name
1349 name: (_) @name)
1350 ] @item
1351 )
1352
1353 (
1354 (comment)* @context
1355 .
1356 [
1357 (export_statement
1358 (interface_declaration
1359 "interface" @name
1360 name: (_) @name))
1361 (interface_declaration
1362 "interface" @name
1363 name: (_) @name)
1364 ] @item
1365 )
1366
1367 (
1368 (comment)* @context
1369 .
1370 [
1371 (export_statement
1372 (enum_declaration
1373 "enum" @name
1374 name: (_) @name))
1375 (enum_declaration
1376 "enum" @name
1377 name: (_) @name)
1378 ] @item
1379 )
1380
1381 (
1382 (comment)* @context
1383 .
1384 (method_definition
1385 [
1386 "get"
1387 "set"
1388 "async"
1389 "*"
1390 "static"
1391 ]* @name
1392 name: (_) @name) @item
1393 )
1394
1395 "#
1396 .unindent(),
1397 )
1398 .unwrap(),
1399 )
1400}
1401
1402fn rust_lang() -> Arc<Language> {
1403 Arc::new(
1404 Language::new(
1405 LanguageConfig {
1406 name: "Rust".into(),
1407 path_suffixes: vec!["rs".into()],
1408 collapsed_placeholder: " /* ... */ ".to_string(),
1409 ..Default::default()
1410 },
1411 Some(tree_sitter_rust::language()),
1412 )
1413 .with_embedding_query(
1414 r#"
1415 (
1416 [(line_comment) (attribute_item)]* @context
1417 .
1418 [
1419 (struct_item
1420 name: (_) @name)
1421
1422 (enum_item
1423 name: (_) @name)
1424
1425 (impl_item
1426 trait: (_)? @name
1427 "for"? @name
1428 type: (_) @name)
1429
1430 (trait_item
1431 name: (_) @name)
1432
1433 (function_item
1434 name: (_) @name
1435 body: (block
1436 "{" @keep
1437 "}" @keep) @collapse)
1438
1439 (macro_definition
1440 name: (_) @name)
1441 ] @item
1442 )
1443
1444 (attribute_item) @collapse
1445 (use_declaration) @collapse
1446 "#,
1447 )
1448 .unwrap(),
1449 )
1450}
1451
1452fn json_lang() -> Arc<Language> {
1453 Arc::new(
1454 Language::new(
1455 LanguageConfig {
1456 name: "JSON".into(),
1457 path_suffixes: vec!["json".into()],
1458 ..Default::default()
1459 },
1460 Some(tree_sitter_json::language()),
1461 )
1462 .with_embedding_query(
1463 r#"
1464 (document) @item
1465
1466 (array
1467 "[" @keep
1468 .
1469 (object)? @keep
1470 "]" @keep) @collapse
1471
1472 (pair value: (string
1473 "\"" @keep
1474 "\"" @keep) @collapse)
1475 "#,
1476 )
1477 .unwrap(),
1478 )
1479}
1480
1481fn toml_lang() -> Arc<Language> {
1482 Arc::new(Language::new(
1483 LanguageConfig {
1484 name: "TOML".into(),
1485 path_suffixes: vec!["toml".into()],
1486 ..Default::default()
1487 },
1488 Some(tree_sitter_toml::language()),
1489 ))
1490}
1491
1492fn cpp_lang() -> Arc<Language> {
1493 Arc::new(
1494 Language::new(
1495 LanguageConfig {
1496 name: "CPP".into(),
1497 path_suffixes: vec!["cpp".into()],
1498 ..Default::default()
1499 },
1500 Some(tree_sitter_cpp::language()),
1501 )
1502 .with_embedding_query(
1503 r#"
1504 (
1505 (comment)* @context
1506 .
1507 (function_definition
1508 (type_qualifier)? @name
1509 type: (_)? @name
1510 declarator: [
1511 (function_declarator
1512 declarator: (_) @name)
1513 (pointer_declarator
1514 "*" @name
1515 declarator: (function_declarator
1516 declarator: (_) @name))
1517 (pointer_declarator
1518 "*" @name
1519 declarator: (pointer_declarator
1520 "*" @name
1521 declarator: (function_declarator
1522 declarator: (_) @name)))
1523 (reference_declarator
1524 ["&" "&&"] @name
1525 (function_declarator
1526 declarator: (_) @name))
1527 ]
1528 (type_qualifier)? @name) @item
1529 )
1530
1531 (
1532 (comment)* @context
1533 .
1534 (template_declaration
1535 (class_specifier
1536 "class" @name
1537 name: (_) @name)
1538 ) @item
1539 )
1540
1541 (
1542 (comment)* @context
1543 .
1544 (class_specifier
1545 "class" @name
1546 name: (_) @name) @item
1547 )
1548
1549 (
1550 (comment)* @context
1551 .
1552 (enum_specifier
1553 "enum" @name
1554 name: (_) @name) @item
1555 )
1556
1557 (
1558 (comment)* @context
1559 .
1560 (declaration
1561 type: (struct_specifier
1562 "struct" @name)
1563 declarator: (_) @name) @item
1564 )
1565
1566 "#,
1567 )
1568 .unwrap(),
1569 )
1570}
1571
1572fn lua_lang() -> Arc<Language> {
1573 Arc::new(
1574 Language::new(
1575 LanguageConfig {
1576 name: "Lua".into(),
1577 path_suffixes: vec!["lua".into()],
1578 collapsed_placeholder: "--[ ... ]--".to_string(),
1579 ..Default::default()
1580 },
1581 Some(tree_sitter_lua::language()),
1582 )
1583 .with_embedding_query(
1584 r#"
1585 (
1586 (comment)* @context
1587 .
1588 (function_declaration
1589 "function" @name
1590 name: (_) @name
1591 (comment)* @collapse
1592 body: (block) @collapse
1593 ) @item
1594 )
1595 "#,
1596 )
1597 .unwrap(),
1598 )
1599}
1600
1601fn php_lang() -> Arc<Language> {
1602 Arc::new(
1603 Language::new(
1604 LanguageConfig {
1605 name: "PHP".into(),
1606 path_suffixes: vec!["php".into()],
1607 collapsed_placeholder: "/* ... */".into(),
1608 ..Default::default()
1609 },
1610 Some(tree_sitter_php::language()),
1611 )
1612 .with_embedding_query(
1613 r#"
1614 (
1615 (comment)* @context
1616 .
1617 [
1618 (function_definition
1619 "function" @name
1620 name: (_) @name
1621 body: (_
1622 "{" @keep
1623 "}" @keep) @collapse
1624 )
1625
1626 (trait_declaration
1627 "trait" @name
1628 name: (_) @name)
1629
1630 (method_declaration
1631 "function" @name
1632 name: (_) @name
1633 body: (_
1634 "{" @keep
1635 "}" @keep) @collapse
1636 )
1637
1638 (interface_declaration
1639 "interface" @name
1640 name: (_) @name
1641 )
1642
1643 (enum_declaration
1644 "enum" @name
1645 name: (_) @name
1646 )
1647
1648 ] @item
1649 )
1650 "#,
1651 )
1652 .unwrap(),
1653 )
1654}
1655
1656fn ruby_lang() -> Arc<Language> {
1657 Arc::new(
1658 Language::new(
1659 LanguageConfig {
1660 name: "Ruby".into(),
1661 path_suffixes: vec!["rb".into()],
1662 collapsed_placeholder: "# ...".to_string(),
1663 ..Default::default()
1664 },
1665 Some(tree_sitter_ruby::language()),
1666 )
1667 .with_embedding_query(
1668 r#"
1669 (
1670 (comment)* @context
1671 .
1672 [
1673 (module
1674 "module" @name
1675 name: (_) @name)
1676 (method
1677 "def" @name
1678 name: (_) @name
1679 body: (body_statement) @collapse)
1680 (class
1681 "class" @name
1682 name: (_) @name)
1683 (singleton_method
1684 "def" @name
1685 object: (_) @name
1686 "." @name
1687 name: (_) @name
1688 body: (body_statement) @collapse)
1689 ] @item
1690 )
1691 "#,
1692 )
1693 .unwrap(),
1694 )
1695}
1696
1697fn elixir_lang() -> Arc<Language> {
1698 Arc::new(
1699 Language::new(
1700 LanguageConfig {
1701 name: "Elixir".into(),
1702 path_suffixes: vec!["rs".into()],
1703 ..Default::default()
1704 },
1705 Some(tree_sitter_elixir::language()),
1706 )
1707 .with_embedding_query(
1708 r#"
1709 (
1710 (unary_operator
1711 operator: "@"
1712 operand: (call
1713 target: (identifier) @unary
1714 (#match? @unary "^(doc)$"))
1715 ) @context
1716 .
1717 (call
1718 target: (identifier) @name
1719 (arguments
1720 [
1721 (identifier) @name
1722 (call
1723 target: (identifier) @name)
1724 (binary_operator
1725 left: (call
1726 target: (identifier) @name)
1727 operator: "when")
1728 ])
1729 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1730 )
1731
1732 (call
1733 target: (identifier) @name
1734 (arguments (alias) @name)
1735 (#match? @name "^(defmodule|defprotocol)$")) @item
1736 "#,
1737 )
1738 .unwrap(),
1739 )
1740}
1741
1742#[gpui::test]
1743fn test_subtract_ranges() {
1744 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1745
1746 assert_eq!(
1747 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1748 vec![1..4, 10..21]
1749 );
1750
1751 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1752}
1753
1754fn init_test(cx: &mut TestAppContext) {
1755 cx.update(|cx| {
1756 cx.set_global(SettingsStore::test(cx));
1757 settings::register::<SemanticIndexSettings>(cx);
1758 settings::register::<ProjectSettings>(cx);
1759 });
1760}