1use crate::{
2 embedding_queue::EmbeddingQueue,
3 parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
4 semantic_index_settings::SemanticIndexSettings,
5 FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
6};
7use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{executor::Deterministic, Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use parking_lot::Mutex;
13use pretty_assertions::assert_eq;
14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
15use rand::{rngs::StdRng, Rng};
16use serde_json::json;
17use settings::SettingsStore;
18use std::{
19 path::Path,
20 sync::{
21 atomic::{self, AtomicUsize},
22 Arc,
23 },
24 time::{Instant, SystemTime},
25};
26use unindent::Unindent;
27use util::RandomCharIter;
28
29#[ctor::ctor]
30fn init_logger() {
31 if std::env::var("RUST_LOG").is_ok() {
32 env_logger::init();
33 }
34}
35
// End-to-end test of `SemanticIndex` on a fake filesystem.
//
// Builds a three-file tree (two Rust files, one TOML file), runs
// `search_project` with no filters, with an include matcher and with an
// exclude matcher, then overwrites one file and checks that re-indexing
// produces exactly one additional embedding.
#[gpui::test]
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
 init_test(cx);

 let fs = FakeFs::new(cx.background());
 fs.insert_tree(
 "/the-root",
 json!({
 "src": {
 "file1.rs": "
 fn aaa() {
 println!(\"aaaaaaaaaaaa!\");
 }

 fn zzzzz() {
 println!(\"SLEEPING\");
 }
 ".unindent(),
 "file2.rs": "
 fn bbb() {
 println!(\"bbbbbbbbbbbbb!\");
 }
 struct pqpqpqp {}
 ".unindent(),
 "file3.toml": "
 ZZZZZZZZZZZZZZZZZZ = 5
 ".unindent(),
 }
 }),
 )
 .await;

 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
 let rust_language = rust_lang();
 let toml_language = toml_lang();
 languages.add(rust_language);
 languages.add(toml_language);

 // The index database lives in a temporary directory for this test.
 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
 let db_path = db_dir.path().join("db.sqlite");

 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 let semantic_index = SemanticIndex::new(
 fs.clone(),
 db_path,
 embedding_provider.clone(),
 languages,
 cx.to_async(),
 )
 .await
 .unwrap();

 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

 // Start the search without awaiting it yet; the assertions below observe
 // the pending-file counter while the search task is still outstanding.
 let search_results = semantic_index.update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 vec![],
 vec![],
 cx,
 )
 });
 let pending_file_count =
 semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
 deterministic.run_until_parked();
 // All three files are reported as pending until the clock is advanced by
 // the embedding queue's flush timeout.
 assert_eq!(*pending_file_count.borrow(), 3);
 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 assert_eq!(*pending_file_count.borrow(), 0);

 let search_results = search_results.await.unwrap();
 assert_search_results(
 &search_results,
 &[
 (Path::new("src/file1.rs").into(), 0),
 (Path::new("src/file2.rs").into(), 0),
 (Path::new("src/file3.toml").into(), 0),
 (Path::new("src/file1.rs").into(), 45),
 (Path::new("src/file2.rs").into(), 45),
 ],
 cx,
 );

 // Test include-files and exclude-files functionality. The same `*.rs`
 // pattern is used once as an include filter and once as an exclude filter.
 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 let rust_only_search_results = semantic_index
 .update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 include_files,
 vec![],
 cx,
 )
 })
 .await
 .unwrap();

 assert_search_results(
 &rust_only_search_results,
 &[
 (Path::new("src/file1.rs").into(), 0),
 (Path::new("src/file2.rs").into(), 0),
 (Path::new("src/file1.rs").into(), 45),
 (Path::new("src/file2.rs").into(), 45),
 ],
 cx,
 );

 let no_rust_search_results = semantic_index
 .update(cx, |store, cx| {
 store.search_project(
 project.clone(),
 "aaaaaabbbbzz".to_string(),
 5,
 vec![],
 exclude_files,
 cx,
 )
 })
 .await
 .unwrap();

 assert_search_results(
 &no_rust_search_results,
 &[(Path::new("src/file3.toml").into(), 0)],
 cx,
 );

 // Overwrite file2.rs: `bbb` is replaced by `dddd`, while the
 // `struct pqpqpqp {}` line is written back unchanged.
 fs.save(
 "/the-root/src/file2.rs".as_ref(),
 &"
 fn dddd() { println!(\"ddddd!\"); }
 struct pqpqpqp {}
 "
 .unindent()
 .into(),
 Default::default(),
 )
 .await
 .unwrap();

 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

 // Snapshot the provider's counter so only embeddings requested by the
 // re-index below are measured.
 let prev_embedding_count = embedding_provider.embedding_count();
 let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 deterministic.run_until_parked();
 assert_eq!(*pending_file_count.borrow(), 1);
 deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 assert_eq!(*pending_file_count.borrow(), 0);
 index.await.unwrap();

 // Exactly one new span is embedded after the edit.
 // NOTE(review): presumably the unchanged struct span is reused rather
 // than re-embedded — confirm against the indexing code.
 assert_eq!(
 embedding_provider.embedding_count() - prev_embedding_count,
 1
 );
}
196
197#[gpui::test(iterations = 10)]
198async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
199 let (outstanding_job_count, _) = postage::watch::channel_with(0);
200 let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
201
202 let files = (1..=3)
203 .map(|file_ix| FileToEmbed {
204 worktree_id: 5,
205 path: Path::new(&format!("path-{file_ix}")).into(),
206 mtime: SystemTime::now(),
207 spans: (0..rng.gen_range(4..22))
208 .map(|document_ix| {
209 let content_len = rng.gen_range(10..100);
210 let content = RandomCharIter::new(&mut rng)
211 .with_simple_text()
212 .take(content_len)
213 .collect::<String>();
214 let digest = SpanDigest::from(content.as_str());
215 Span {
216 range: 0..10,
217 embedding: None,
218 name: format!("document {document_ix}"),
219 content,
220 digest,
221 token_count: rng.gen_range(10..30),
222 }
223 })
224 .collect(),
225 job_handle: JobHandle::new(&outstanding_job_count),
226 })
227 .collect::<Vec<_>>();
228
229 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
230
231 let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
232 for file in &files {
233 queue.push(file.clone());
234 }
235 queue.flush();
236
237 cx.foreground().run_until_parked();
238 let finished_files = queue.finished_files();
239 let mut embedded_files: Vec<_> = files
240 .iter()
241 .map(|_| finished_files.try_recv().expect("no finished file"))
242 .collect();
243
244 let expected_files: Vec<_> = files
245 .iter()
246 .map(|file| {
247 let mut file = file.clone();
248 for doc in &mut file.spans {
249 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
250 }
251 file
252 })
253 .collect();
254
255 embedded_files.sort_by_key(|f| f.path.clone());
256
257 assert_eq!(embedded_files, expected_files);
258}
259
260#[track_caller]
261fn assert_search_results(
262 actual: &[SearchResult],
263 expected: &[(Arc<Path>, usize)],
264 cx: &TestAppContext,
265) {
266 let actual = actual
267 .iter()
268 .map(|search_result| {
269 search_result.buffer.read_with(cx, |buffer, _cx| {
270 (
271 buffer.file().unwrap().path().clone(),
272 search_result.range.start.to_offset(buffer),
273 )
274 })
275 })
276 .collect::<Vec<_>>();
277 assert_eq!(actual, expected);
278}
279
// Parses a Rust snippet with `CodeContextRetriever::parse_file` and checks
// the content and start offset of every returned span. The expected span for
// `impl E` shows the bodies of its functions replaced by ` /* ... */ `, while
// the functions also appear as their own spans with full bodies.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
 let language = rust_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /// A doc comment
 /// that spans multiple lines
 #[gpui::test]
 fn a() {
 b
 }

 impl C for D {
 }

 impl E {
 // This is also a preceding comment
 pub fn function_1() -> Option<()> {
 todo!();
 }

 // This is a preceding comment
 fn function_2() -> Result<()> {
 todo!();
 }
 }
 "
 .unindent();

 let documents = retriever.parse_file(&text, language).unwrap();

 // Expected offsets are located with `text.find(..)` on the item keyword,
 // not on the comments/attributes that are included in the span content.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /// A doc comment
 /// that spans multiple lines
 #[gpui::test]
 fn a() {
 b
 }"
 .unindent(),
 text.find("fn a").unwrap(),
 ),
 (
 "
 impl C for D {
 }"
 .unindent(),
 text.find("impl C").unwrap(),
 ),
 (
 "
 impl E {
 // This is also a preceding comment
 pub fn function_1() -> Option<()> { /* ... */ }

 // This is a preceding comment
 fn function_2() -> Result<()> { /* ... */ }
 }"
 .unindent(),
 text.find("impl E").unwrap(),
 ),
 (
 "
 // This is also a preceding comment
 pub fn function_1() -> Option<()> {
 todo!();
 }"
 .unindent(),
 text.find("pub fn function_1").unwrap(),
 ),
 (
 "
 // This is a preceding comment
 fn function_2() -> Result<()> {
 todo!();
 }"
 .unindent(),
 text.find("fn function_2").unwrap(),
 ),
 ],
 );
}
367
// Parses two JSON documents with `CodeContextRetriever::parse_file` and
// checks the single span returned for each. Per the expected values, array
// and string contents are emptied in the object case, and only the first
// element is kept in the top-level array case.
#[gpui::test]
async fn test_code_context_retrieval_json() {
 let language = json_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 // Case 1: a top-level object with nested values.
 let text = r#"
 {
 "array": [1, 2, 3, 4],
 "string": "abcdefg",
 "nested_object": {
 "array_2": [5, 6, 7, 8],
 "string_2": "hijklmnop",
 "boolean": true,
 "none": null
 }
 }
 "#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 assert_documents_eq(
 &documents,
 &[(
 r#"
 {
 "array": [],
 "string": "",
 "nested_object": {
 "array_2": [],
 "string_2": "",
 "boolean": true,
 "none": null
 }
 }"#
 .unindent(),
 text.find("{").unwrap(),
 )],
 );

 // Case 2: a top-level array of objects.
 let text = r#"
 [
 {
 "name": "somebody",
 "age": 42
 },
 {
 "name": "somebody else",
 "age": 43
 }
 ]
 "#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 assert_documents_eq(
 &documents,
 &[(
 r#"
 [{
 "name": "",
 "age": 42
 }]"#
 .unindent(),
 text.find("[").unwrap(),
 )],
 );
}
438
439fn assert_documents_eq(
440 documents: &[Span],
441 expected_contents_and_start_offsets: &[(String, usize)],
442) {
443 assert_eq!(
444 documents
445 .iter()
446 .map(|document| (document.content.clone(), document.range.start))
447 .collect::<Vec<_>>(),
448 expected_contents_and_start_offsets
449 );
450}
451
// Parses a JavaScript/TypeScript snippet using the language from `js_lang()`
// and checks the content and start offset of each returned span: functions,
// classes, a constructor method and an interface, each together with its
// preceding comment.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
 let language = js_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /* globals importScripts, backend */
 function _authorize() {}

 /**
 * Sometimes the frontend build is way faster than backend.
 */
 export async function authorizeBank() {
 _authorize(pushModal, upgradingAccountId, {});
 }

 export class SettingsPage {
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }
 }

 /* This is a test comment */
 class TestClass {}

 /* Schema for editor_events in Clickhouse. */
 export interface ClickhouseEditorEvent {
 installation_id: string
 operation: string
 }
 "
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // The numeric start offsets below are hard-coded, so they depend on the
 // exact bytes (including whitespace) of `text` above.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /* globals importScripts, backend */
 function _authorize() {}"
 .unindent(),
 37,
 ),
 (
 "
 /**
 * Sometimes the frontend build is way faster than backend.
 */
 export async function authorizeBank() {
 _authorize(pushModal, upgradingAccountId, {});
 }"
 .unindent(),
 131,
 ),
 (
 "
 export class SettingsPage {
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }
 }"
 .unindent(),
 225,
 ),
 (
 "
 /* This is a test setting */
 constructor(page) {
 this.page = page;
 }"
 .unindent(),
 290,
 ),
 (
 "
 /* This is a test comment */
 class TestClass {}"
 .unindent(),
 374,
 ),
 (
 "
 /* Schema for editor_events in Clickhouse. */
 export interface ClickhouseEditorEvent {
 installation_id: string
 operation: string
 }"
 .unindent(),
 440,
 ),
 ],
 )
}
550
// Parses a Lua snippet containing a function with a nested function and
// checks the two returned spans. In the expected outer span the nested
// function's body appears as `--[ ... ]--` placeholders; the nested function
// is also returned on its own with its full body.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
 let language = lua_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 -- Creates a new class
 -- @param baseclass The Baseclass of this class, or nil.
 -- @return A new class reference.
 function classes.class(baseclass)
 -- Create the class definition and metatable.
 local classdef = {}
 -- Find the super class, either Object or user-defined.
 baseclass = baseclass or classes.Object
 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 setmetatable(classdef, { __index = baseclass })
 -- All class instances have a reference to the class object.
 classdef.class = classdef
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 -- All class instances have a reference to a superclass object.
 local instance = { super = baseclass.alloc(mastertable) }
 -- Any functions this instance does not know of will 'look up' to the superclass definition.
 setmetatable(instance, { __index = classdef, __newindex = mastertable })
 return instance
 end
 end
 "#.unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // Start offsets (114, 809) are hard-coded and depend on the exact bytes
 // of `text` above.
 assert_documents_eq(
 &documents,
 &[
 (r#"
 -- Creates a new class
 -- @param baseclass The Baseclass of this class, or nil.
 -- @return A new class reference.
 function classes.class(baseclass)
 -- Create the class definition and metatable.
 local classdef = {}
 -- Find the super class, either Object or user-defined.
 baseclass = baseclass or classes.Object
 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 setmetatable(classdef, { __index = baseclass })
 -- All class instances have a reference to the class object.
 classdef.class = classdef
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 --[ ... ]--
 --[ ... ]--
 end
 end"#.unindent(),
 114),
 (r#"
 --- Recursivly allocates the inheritance tree of the instance.
 -- @param mastertable The 'root' of the inheritance tree.
 -- @return Returns the instance with the allocated inheritance tree.
 function classdef.alloc(mastertable)
 -- All class instances have a reference to a superclass object.
 local instance = { super = baseclass.alloc(mastertable) }
 -- Any functions this instance does not know of will 'look up' to the superclass definition.
 setmetatable(instance, { __index = classdef, __newindex = mastertable })
 return instance
 end"#.unindent(), 809),
 ]
 );
}
624
// Parses an Elixir module with `CodeContextRetriever::parse_file` and checks
// the two returned spans: the whole `defmodule` (starting at offset 0, with
// nothing collapsed in the expected text) and the `__build__` function
// together with its `@doc false` attribute.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
 let language = elixir_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 defmodule File.Stream do
 @moduledoc """
 Defines a `File.Stream` struct returned by `File.stream!/3`.

 The following fields are public:

 * `path` - the file path
 * `modes` - the file modes
 * `raw` - a boolean indicating if bin functions should be used
 * `line_or_bytes` - if reading should read lines or a given number of bytes
 * `node` - the node the file belongs to

 """

 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

 @type t :: %__MODULE__{}

 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // The second start offset (574) is hard-coded and depends on the exact
 // bytes of `text` above.
 assert_documents_eq(
 &documents,
 &[(
 r#"
 defmodule File.Stream do
 @moduledoc """
 Defines a `File.Stream` struct returned by `File.stream!/3`.

 The following fields are public:

 * `path` - the file path
 * `modes` - the file modes
 * `raw` - a boolean indicating if bin functions should be used
 * `line_or_bytes` - if reading should read lines or a given number of bytes
 * `node` - the node the file belongs to

 """

 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

 @type t :: %__MODULE__{}

 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#
 .unindent(),
 0,
 ),(r#"
 @doc false
 def __build__(path, modes, line_or_bytes) do
 raw = :lists.keyfind(:encoding, 1, modes) == false

 modes =
 case raw do
 true ->
 case :lists.keyfind(:read_ahead, 1, modes) do
 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 {:read_ahead, _} -> [:raw | modes]
 false -> [:raw, :read_ahead | modes]
 end

 false ->
 modes
 end

 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

 end"#.unindent(), 574)],
 );
}
741
// Parses a C++ snippet with `CodeContextRetriever::parse_file` and checks the
// content and start offset of each returned span: a function, a class, an
// enum, an anonymous struct declaration, a templated class, and the
// constructor nested inside that templated class.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
 let language = cpp_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = "
 /**
 * @brief Main function
 * @returns 0 on exit
 */
 int main() { return 0; }

 /**
 * This is a test comment
 */
 class MyClass { // The class
 public: // Access specifier
 int myNum; // Attribute (int variable)
 string myString; // Attribute (string variable)
 };

 // This is a test comment
 enum Color { red, green, blue };

 /** This is a preceding block comment
 * This is the second line
 */
 struct { // Structure declaration
 int myNum; // Member (int variable)
 string myString; // Member (string variable)
 } myStructure;

 /**
 * @brief Matrix class.
 */
 template <typename T,
 typename = typename std::enable_if<
 std::is_integral<T>::value || std::is_floating_point<T>::value,
 bool>::type>
 class Matrix2 {
 std::vector<std::vector<T>> _mat;

 public:
 /**
 * @brief Constructor
 * @tparam Integer ensuring integers are being evaluated and not other
 * data types.
 * @param size denoting the size of Matrix as size x size
 */
 template <typename Integer,
 typename = typename std::enable_if<std::is_integral<Integer>::value,
 Integer>::type>
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }
 }"
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // Start offsets are hard-coded and depend on the exact bytes of `text`.
 // Per the expected text, the class and enum spans end before the trailing
 // `;`, while the struct declaration span includes it.
 assert_documents_eq(
 &documents,
 &[
 (
 "
 /**
 * @brief Main function
 * @returns 0 on exit
 */
 int main() { return 0; }"
 .unindent(),
 54,
 ),
 (
 "
 /**
 * This is a test comment
 */
 class MyClass { // The class
 public: // Access specifier
 int myNum; // Attribute (int variable)
 string myString; // Attribute (string variable)
 }"
 .unindent(),
 112,
 ),
 (
 "
 // This is a test comment
 enum Color { red, green, blue }"
 .unindent(),
 322,
 ),
 (
 "
 /** This is a preceding block comment
 * This is the second line
 */
 struct { // Structure declaration
 int myNum; // Member (int variable)
 string myString; // Member (string variable)
 } myStructure;"
 .unindent(),
 425,
 ),
 (
 "
 /**
 * @brief Matrix class.
 */
 template <typename T,
 typename = typename std::enable_if<
 std::is_integral<T>::value || std::is_floating_point<T>::value,
 bool>::type>
 class Matrix2 {
 std::vector<std::vector<T>> _mat;

 public:
 /**
 * @brief Constructor
 * @tparam Integer ensuring integers are being evaluated and not other
 * data types.
 * @param size denoting the size of Matrix as size x size
 */
 template <typename Integer,
 typename = typename std::enable_if<std::is_integral<Integer>::value,
 Integer>::type>
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }
 }"
 .unindent(),
 612,
 ),
 (
 "
 explicit Matrix(const Integer size) {
 for (size_t i = 0; i < size; ++i) {
 _mat.emplace_back(std::vector<T>(size, 0));
 }
 }"
 .unindent(),
 1226,
 ),
 ],
 );
}
894
// Parses a Ruby snippet (a module, a class and a singleton method) and checks
// the content and start offset of each returned span. In the expected module
// and class spans, method bodies appear as `# ...`; each method is also
// returned as its own span with its full body.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
 let language = ruby_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 # This concern is inspired by "sudo mode" on GitHub. It
 # is a way to re-authenticate a user before allowing them
 # to see or perform an action.
 #
 # Add `before_action :require_challenge!` to actions you
 # want to protect.
 #
 # The user will be shown a page to enter the challenge (which
 # is either the password, or just the username when no
 # password exists). Upon passing, there is a grace period
 # during which no challenge will be asked from the user.
 #
 # Accessing challenge-protected resources during the grace
 # period will refresh the grace period.
 module ChallengableConcern
 extend ActiveSupport::Concern

 CHALLENGE_TIMEOUT = 1.hour.freeze

 def require_challenge!
 return if skip_challenge?

 if challenge_passed_recently?
 session[:challenge_passed_at] = Time.now.utc
 return
 end

 @challenge = Form::Challenge.new(return_to: request.url)

 if params.key?(:form_challenge)
 if challenge_passed?
 session[:challenge_passed_at] = Time.now.utc
 else
 flash.now[:alert] = I18n.t('challenge.invalid_password')
 render_challenge
 end
 else
 render_challenge
 end
 end

 def challenge_passed?
 current_user.valid_password?(challenge_params[:current_password])
 end
 end

 class Animal
 include Comparable

 attr_reader :legs

 def initialize(name, legs)
 @name, @legs = name, legs
 end

 def <=>(other)
 legs <=> other.legs
 end
 end

 # Singleton method for car object
 def car.wheels
 puts "There are four wheels"
 end"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // Start offsets are hard-coded and depend on the exact bytes of `text`.
 assert_documents_eq(
 &documents,
 &[
 (
 r#"
 # This concern is inspired by "sudo mode" on GitHub. It
 # is a way to re-authenticate a user before allowing them
 # to see or perform an action.
 #
 # Add `before_action :require_challenge!` to actions you
 # want to protect.
 #
 # The user will be shown a page to enter the challenge (which
 # is either the password, or just the username when no
 # password exists). Upon passing, there is a grace period
 # during which no challenge will be asked from the user.
 #
 # Accessing challenge-protected resources during the grace
 # period will refresh the grace period.
 module ChallengableConcern
 extend ActiveSupport::Concern

 CHALLENGE_TIMEOUT = 1.hour.freeze

 def require_challenge!
 # ...
 end

 def challenge_passed?
 # ...
 end
 end"#
 .unindent(),
 558,
 ),
 (
 r#"
 def require_challenge!
 return if skip_challenge?

 if challenge_passed_recently?
 session[:challenge_passed_at] = Time.now.utc
 return
 end

 @challenge = Form::Challenge.new(return_to: request.url)

 if params.key?(:form_challenge)
 if challenge_passed?
 session[:challenge_passed_at] = Time.now.utc
 else
 flash.now[:alert] = I18n.t('challenge.invalid_password')
 render_challenge
 end
 else
 render_challenge
 end
 end"#
 .unindent(),
 663,
 ),
 (
 r#"
 def challenge_passed?
 current_user.valid_password?(challenge_params[:current_password])
 end"#
 .unindent(),
 1254,
 ),
 (
 r#"
 class Animal
 include Comparable

 attr_reader :legs

 def initialize(name, legs)
 # ...
 end

 def <=>(other)
 # ...
 end
 end"#
 .unindent(),
 1363,
 ),
 (
 r#"
 def initialize(name, legs)
 @name, @legs = name, legs
 end"#
 .unindent(),
 1427,
 ),
 (
 r#"
 def <=>(other)
 legs <=> other.legs
 end"#
 .unindent(),
 1501,
 ),
 (
 r#"
 # Singleton method for car object
 def car.wheels
 puts "There are four wheels"
 end"#
 .unindent(),
 1591,
 ),
 ],
 );
}
1085
// Parses a PHP snippet (a function, a trait with two methods, an interface
// and an enum) and checks the content and start offset of each returned span.
// In the expected trait span, method bodies appear as `{/* ... */}`; each
// method is also returned as its own span with its full body.
#[gpui::test]
async fn test_code_context_retrieval_php() {
 let language = php_lang();
 let embedding_provider = Arc::new(DummyEmbeddings {});
 let mut retriever = CodeContextRetriever::new(embedding_provider);

 let text = r#"
 <?php

 namespace LevelUp\Experience\Concerns;

 /*
 This is a multiple-lines comment block
 that spans over multiple
 lines
 */
 function functionName() {
 echo "Hello world!";
 }

 trait HasAchievements
 {
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {
 if ($progress > 100) {
 throw new Exception(message: 'Progress cannot be greater than 100');
 }

 if ($this->achievements()->find($achievement->id)) {
 throw new Exception(message: 'User already has this Achievement');
 }

 $this->achievements()->attach($achievement, [
 'progress' => $progress ?? null,
 ]);

 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
 }

 public function achievements(): BelongsToMany
 {
 return $this->belongsToMany(related: Achievement::class)
 ->withPivot(columns: 'progress')
 ->where('is_secret', false)
 ->using(AchievementUser::class);
 }
 }

 interface Multiplier
 {
 public function qualifies(array $data): bool;

 public function setMultiplier(): int;
 }

 enum AuditType: string
 {
 case Add = 'add';
 case Remove = 'remove';
 case Reset = 'reset';
 case LevelUp = 'level_up';
 }

 ?>"#
 .unindent();

 let documents = retriever.parse_file(&text, language.clone()).unwrap();

 // Start offsets are hard-coded and depend on the exact bytes of `text`.
 assert_documents_eq(
 &documents,
 &[
 (
 r#"
 /*
 This is a multiple-lines comment block
 that spans over multiple
 lines
 */
 function functionName() {
 echo "Hello world!";
 }"#
 .unindent(),
 123,
 ),
 (
 r#"
 trait HasAchievements
 {
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {/* ... */}

 public function achievements(): BelongsToMany
 {/* ... */}
 }"#
 .unindent(),
 177,
 ),
 (r#"
 /**
 * @throws \Exception
 */
 public function grantAchievement(Achievement $achievement, $progress = null): void
 {
 if ($progress > 100) {
 throw new Exception(message: 'Progress cannot be greater than 100');
 }

 if ($this->achievements()->find($achievement->id)) {
 throw new Exception(message: 'User already has this Achievement');
 }

 $this->achievements()->attach($achievement, [
 'progress' => $progress ?? null,
 ]);

 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
 }"#.unindent(), 245),
 (r#"
 public function achievements(): BelongsToMany
 {
 return $this->belongsToMany(related: Achievement::class)
 ->withPivot(columns: 'progress')
 ->where('is_secret', false)
 ->using(AchievementUser::class);
 }"#.unindent(), 902),
 (r#"
 interface Multiplier
 {
 public function qualifies(array $data): bool;

 public function setMultiplier(): int;
 }"#.unindent(),
 1146),
 (r#"
 enum AuditType: string
 {
 case Add = 'add';
 case Remove = 'remove';
 case Reset = 'reset';
 case LevelUp = 'level_up';
 }"#.unindent(), 1265)
 ],
 );
}
1236
1237#[derive(Default)]
1238struct FakeEmbeddingProvider {
1239 embedding_count: AtomicUsize,
1240}
1241
1242impl FakeEmbeddingProvider {
1243 fn embedding_count(&self) -> usize {
1244 self.embedding_count.load(atomic::Ordering::SeqCst)
1245 }
1246
1247 fn embed_sync(&self, span: &str) -> Embedding {
1248 let mut result = vec![1.0; 26];
1249 for letter in span.chars() {
1250 let letter = letter.to_ascii_lowercase();
1251 if letter as u32 >= 'a' as u32 {
1252 let ix = (letter as u32) - ('a' as u32);
1253 if ix < 26 {
1254 result[ix as usize] += 1.0;
1255 }
1256 }
1257 }
1258
1259 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1260 for x in &mut result {
1261 *x /= norm;
1262 }
1263
1264 result.into()
1265 }
1266}
1267
1268#[async_trait]
1269impl EmbeddingProvider for FakeEmbeddingProvider {
1270 fn is_authenticated(&self) -> bool {
1271 true
1272 }
1273 fn truncate(&self, span: &str) -> (String, usize) {
1274 (span.to_string(), 1)
1275 }
1276
1277 fn max_tokens_per_batch(&self) -> usize {
1278 200
1279 }
1280
1281 fn rate_limit_expiration(&self) -> Option<Instant> {
1282 None
1283 }
1284
1285 async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1286 self.embedding_count
1287 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1288 Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1289 }
1290}
1291
// Builds the test `Language` named "Javascript" for `.js` paths, using the
// tree-sitter TSX grammar. The embedding query captures function, class,
// interface and enum declarations (exported or not) and method definitions
// as `@item`, with any immediately preceding comments captured as `@context`.
// Panics if the embedding query is rejected by `with_embedding_query`.
fn js_lang() -> Arc<Language> {
 Arc::new(
 Language::new(
 LanguageConfig {
 name: "Javascript".into(),
 path_suffixes: vec!["js".into()],
 ..Default::default()
 },
 Some(tree_sitter_typescript::language_tsx()),
 )
 .with_embedding_query(
 &r#"

 (
 (comment)* @context
 .
 [
 (export_statement
 (function_declaration
 "async"? @name
 "function" @name
 name: (_) @name))
 (function_declaration
 "async"? @name
 "function" @name
 name: (_) @name)
 ] @item
 )

 (
 (comment)* @context
 .
 [
 (export_statement
 (class_declaration
 "class" @name
 name: (_) @name))
 (class_declaration
 "class" @name
 name: (_) @name)
 ] @item
 )

 (
 (comment)* @context
 .
 [
 (export_statement
 (interface_declaration
 "interface" @name
 name: (_) @name))
 (interface_declaration
 "interface" @name
 name: (_) @name)
 ] @item
 )

 (
 (comment)* @context
 .
 [
 (export_statement
 (enum_declaration
 "enum" @name
 name: (_) @name))
 (enum_declaration
 "enum" @name
 name: (_) @name)
 ] @item
 )

 (
 (comment)* @context
 .
 (method_definition
 [
 "get"
 "set"
 "async"
 "*"
 "static"
 ]* @name
 name: (_) @name) @item
 )

 "#
 .unindent(),
 )
 .unwrap(),
 )
}
1383
1384fn rust_lang() -> Arc<Language> {
1385 Arc::new(
1386 Language::new(
1387 LanguageConfig {
1388 name: "Rust".into(),
1389 path_suffixes: vec!["rs".into()],
1390 collapsed_placeholder: " /* ... */ ".to_string(),
1391 ..Default::default()
1392 },
1393 Some(tree_sitter_rust::language()),
1394 )
1395 .with_embedding_query(
1396 r#"
1397 (
1398 [(line_comment) (attribute_item)]* @context
1399 .
1400 [
1401 (struct_item
1402 name: (_) @name)
1403
1404 (enum_item
1405 name: (_) @name)
1406
1407 (impl_item
1408 trait: (_)? @name
1409 "for"? @name
1410 type: (_) @name)
1411
1412 (trait_item
1413 name: (_) @name)
1414
1415 (function_item
1416 name: (_) @name
1417 body: (block
1418 "{" @keep
1419 "}" @keep) @collapse)
1420
1421 (macro_definition
1422 name: (_) @name)
1423 ] @item
1424 )
1425 "#,
1426 )
1427 .unwrap(),
1428 )
1429}
1430
1431fn json_lang() -> Arc<Language> {
1432 Arc::new(
1433 Language::new(
1434 LanguageConfig {
1435 name: "JSON".into(),
1436 path_suffixes: vec!["json".into()],
1437 ..Default::default()
1438 },
1439 Some(tree_sitter_json::language()),
1440 )
1441 .with_embedding_query(
1442 r#"
1443 (document) @item
1444
1445 (array
1446 "[" @keep
1447 .
1448 (object)? @keep
1449 "]" @keep) @collapse
1450
1451 (pair value: (string
1452 "\"" @keep
1453 "\"" @keep) @collapse)
1454 "#,
1455 )
1456 .unwrap(),
1457 )
1458}
1459
1460fn toml_lang() -> Arc<Language> {
1461 Arc::new(Language::new(
1462 LanguageConfig {
1463 name: "TOML".into(),
1464 path_suffixes: vec!["toml".into()],
1465 ..Default::default()
1466 },
1467 Some(tree_sitter_toml::language()),
1468 ))
1469}
1470
1471fn cpp_lang() -> Arc<Language> {
1472 Arc::new(
1473 Language::new(
1474 LanguageConfig {
1475 name: "CPP".into(),
1476 path_suffixes: vec!["cpp".into()],
1477 ..Default::default()
1478 },
1479 Some(tree_sitter_cpp::language()),
1480 )
1481 .with_embedding_query(
1482 r#"
1483 (
1484 (comment)* @context
1485 .
1486 (function_definition
1487 (type_qualifier)? @name
1488 type: (_)? @name
1489 declarator: [
1490 (function_declarator
1491 declarator: (_) @name)
1492 (pointer_declarator
1493 "*" @name
1494 declarator: (function_declarator
1495 declarator: (_) @name))
1496 (pointer_declarator
1497 "*" @name
1498 declarator: (pointer_declarator
1499 "*" @name
1500 declarator: (function_declarator
1501 declarator: (_) @name)))
1502 (reference_declarator
1503 ["&" "&&"] @name
1504 (function_declarator
1505 declarator: (_) @name))
1506 ]
1507 (type_qualifier)? @name) @item
1508 )
1509
1510 (
1511 (comment)* @context
1512 .
1513 (template_declaration
1514 (class_specifier
1515 "class" @name
1516 name: (_) @name)
1517 ) @item
1518 )
1519
1520 (
1521 (comment)* @context
1522 .
1523 (class_specifier
1524 "class" @name
1525 name: (_) @name) @item
1526 )
1527
1528 (
1529 (comment)* @context
1530 .
1531 (enum_specifier
1532 "enum" @name
1533 name: (_) @name) @item
1534 )
1535
1536 (
1537 (comment)* @context
1538 .
1539 (declaration
1540 type: (struct_specifier
1541 "struct" @name)
1542 declarator: (_) @name) @item
1543 )
1544
1545 "#,
1546 )
1547 .unwrap(),
1548 )
1549}
1550
1551fn lua_lang() -> Arc<Language> {
1552 Arc::new(
1553 Language::new(
1554 LanguageConfig {
1555 name: "Lua".into(),
1556 path_suffixes: vec!["lua".into()],
1557 collapsed_placeholder: "--[ ... ]--".to_string(),
1558 ..Default::default()
1559 },
1560 Some(tree_sitter_lua::language()),
1561 )
1562 .with_embedding_query(
1563 r#"
1564 (
1565 (comment)* @context
1566 .
1567 (function_declaration
1568 "function" @name
1569 name: (_) @name
1570 (comment)* @collapse
1571 body: (block) @collapse
1572 ) @item
1573 )
1574 "#,
1575 )
1576 .unwrap(),
1577 )
1578}
1579
1580fn php_lang() -> Arc<Language> {
1581 Arc::new(
1582 Language::new(
1583 LanguageConfig {
1584 name: "PHP".into(),
1585 path_suffixes: vec!["php".into()],
1586 collapsed_placeholder: "/* ... */".into(),
1587 ..Default::default()
1588 },
1589 Some(tree_sitter_php::language()),
1590 )
1591 .with_embedding_query(
1592 r#"
1593 (
1594 (comment)* @context
1595 .
1596 [
1597 (function_definition
1598 "function" @name
1599 name: (_) @name
1600 body: (_
1601 "{" @keep
1602 "}" @keep) @collapse
1603 )
1604
1605 (trait_declaration
1606 "trait" @name
1607 name: (_) @name)
1608
1609 (method_declaration
1610 "function" @name
1611 name: (_) @name
1612 body: (_
1613 "{" @keep
1614 "}" @keep) @collapse
1615 )
1616
1617 (interface_declaration
1618 "interface" @name
1619 name: (_) @name
1620 )
1621
1622 (enum_declaration
1623 "enum" @name
1624 name: (_) @name
1625 )
1626
1627 ] @item
1628 )
1629 "#,
1630 )
1631 .unwrap(),
1632 )
1633}
1634
1635fn ruby_lang() -> Arc<Language> {
1636 Arc::new(
1637 Language::new(
1638 LanguageConfig {
1639 name: "Ruby".into(),
1640 path_suffixes: vec!["rb".into()],
1641 collapsed_placeholder: "# ...".to_string(),
1642 ..Default::default()
1643 },
1644 Some(tree_sitter_ruby::language()),
1645 )
1646 .with_embedding_query(
1647 r#"
1648 (
1649 (comment)* @context
1650 .
1651 [
1652 (module
1653 "module" @name
1654 name: (_) @name)
1655 (method
1656 "def" @name
1657 name: (_) @name
1658 body: (body_statement) @collapse)
1659 (class
1660 "class" @name
1661 name: (_) @name)
1662 (singleton_method
1663 "def" @name
1664 object: (_) @name
1665 "." @name
1666 name: (_) @name
1667 body: (body_statement) @collapse)
1668 ] @item
1669 )
1670 "#,
1671 )
1672 .unwrap(),
1673 )
1674}
1675
1676fn elixir_lang() -> Arc<Language> {
1677 Arc::new(
1678 Language::new(
1679 LanguageConfig {
1680 name: "Elixir".into(),
1681 path_suffixes: vec!["rs".into()],
1682 ..Default::default()
1683 },
1684 Some(tree_sitter_elixir::language()),
1685 )
1686 .with_embedding_query(
1687 r#"
1688 (
1689 (unary_operator
1690 operator: "@"
1691 operand: (call
1692 target: (identifier) @unary
1693 (#match? @unary "^(doc)$"))
1694 ) @context
1695 .
1696 (call
1697 target: (identifier) @name
1698 (arguments
1699 [
1700 (identifier) @name
1701 (call
1702 target: (identifier) @name)
1703 (binary_operator
1704 left: (call
1705 target: (identifier) @name)
1706 operator: "when")
1707 ])
1708 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1709 )
1710
1711 (call
1712 target: (identifier) @name
1713 (arguments (alias) @name)
1714 (#match? @name "^(defmodule|defprotocol)$")) @item
1715 "#,
1716 )
1717 .unwrap(),
1718 )
1719}
1720
1721#[gpui::test]
1722fn test_subtract_ranges() {
1723 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1724
1725 assert_eq!(
1726 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1727 vec![1..4, 10..21]
1728 );
1729
1730 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1731}
1732
1733fn init_test(cx: &mut TestAppContext) {
1734 cx.update(|cx| {
1735 cx.set_global(SettingsStore::test(cx));
1736 settings::register::<SemanticIndexSettings>(cx);
1737 settings::register::<ProjectSettings>(cx);
1738 });
1739}