1use crate::{
2 db::dot,
3 embedding::{DummyEmbeddings, EmbeddingProvider},
4 parsing::{subtract_ranges, CodeContextRetriever, Document},
5 semantic_index_settings::SemanticIndexSettings,
6 SearchResult, SemanticIndex,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::SettingsStore;
17use std::{
18 path::Path,
19 sync::{
20 atomic::{self, AtomicUsize},
21 Arc,
22 },
23};
24use unindent::Unindent;
25
26#[ctor::ctor]
27fn init_logger() {
28 if std::env::var("RUST_LOG").is_ok() {
29 env_logger::init();
30 }
31}
32
33#[gpui::test]
34async fn test_semantic_index(cx: &mut TestAppContext) {
35 cx.update(|cx| {
36 cx.set_global(SettingsStore::test(cx));
37 settings::register::<SemanticIndexSettings>(cx);
38 settings::register::<ProjectSettings>(cx);
39 });
40
41 let fs = FakeFs::new(cx.background());
42 fs.insert_tree(
43 "/the-root",
44 json!({
45 "src": {
46 "file1.rs": "
47 fn aaa() {
48 println!(\"aaaaaaaaaaaa!\");
49 }
50
51 fn zzzzz() {
52 println!(\"SLEEPING\");
53 }
54 ".unindent(),
55 "file2.rs": "
56 fn bbb() {
57 println!(\"bbbbbbbbbbbbb!\");
58 }
59 ".unindent(),
60 "file3.toml": "
61 ZZZZZZZZZZZZZZZZZZ = 5
62 ".unindent(),
63 }
64 }),
65 )
66 .await;
67
68 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
69 let rust_language = rust_lang();
70 let toml_language = toml_lang();
71 languages.add(rust_language);
72 languages.add(toml_language);
73
74 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
75 let db_path = db_dir.path().join("db.sqlite");
76
77 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
78 let store = SemanticIndex::new(
79 fs.clone(),
80 db_path,
81 embedding_provider.clone(),
82 languages,
83 cx.to_async(),
84 )
85 .await
86 .unwrap();
87
88 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
89
90 let _ = store
91 .update(cx, |store, cx| {
92 store.initialize_project(project.clone(), cx)
93 })
94 .await;
95
96 let (file_count, outstanding_file_count) = store
97 .update(cx, |store, cx| store.index_project(project.clone(), cx))
98 .await
99 .unwrap();
100 assert_eq!(file_count, 3);
101 cx.foreground().run_until_parked();
102 assert_eq!(*outstanding_file_count.borrow(), 0);
103
104 let search_results = store
105 .update(cx, |store, cx| {
106 store.search_project(
107 project.clone(),
108 "aaaaaabbbbzz".to_string(),
109 5,
110 vec![],
111 vec![],
112 cx,
113 )
114 })
115 .await
116 .unwrap();
117
118 assert_search_results(
119 &search_results,
120 &[
121 (Path::new("src/file1.rs").into(), 0),
122 (Path::new("src/file2.rs").into(), 0),
123 (Path::new("src/file3.toml").into(), 0),
124 (Path::new("src/file1.rs").into(), 45),
125 ],
126 cx,
127 );
128
129 // Test Include Files Functonality
130 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
131 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
132 let rust_only_search_results = store
133 .update(cx, |store, cx| {
134 store.search_project(
135 project.clone(),
136 "aaaaaabbbbzz".to_string(),
137 5,
138 include_files,
139 vec![],
140 cx,
141 )
142 })
143 .await
144 .unwrap();
145
146 assert_search_results(
147 &rust_only_search_results,
148 &[
149 (Path::new("src/file1.rs").into(), 0),
150 (Path::new("src/file2.rs").into(), 0),
151 (Path::new("src/file1.rs").into(), 45),
152 ],
153 cx,
154 );
155
156 let no_rust_search_results = store
157 .update(cx, |store, cx| {
158 store.search_project(
159 project.clone(),
160 "aaaaaabbbbzz".to_string(),
161 5,
162 vec![],
163 exclude_files,
164 cx,
165 )
166 })
167 .await
168 .unwrap();
169
170 assert_search_results(
171 &no_rust_search_results,
172 &[(Path::new("src/file3.toml").into(), 0)],
173 cx,
174 );
175
176 fs.save(
177 "/the-root/src/file2.rs".as_ref(),
178 &"
179 fn dddd() { println!(\"ddddd!\"); }
180 struct pqpqpqp {}
181 "
182 .unindent()
183 .into(),
184 Default::default(),
185 )
186 .await
187 .unwrap();
188
189 cx.foreground().run_until_parked();
190
191 let prev_embedding_count = embedding_provider.embedding_count();
192 let (file_count, outstanding_file_count) = store
193 .update(cx, |store, cx| store.index_project(project.clone(), cx))
194 .await
195 .unwrap();
196 assert_eq!(file_count, 1);
197
198 cx.foreground().run_until_parked();
199 assert_eq!(*outstanding_file_count.borrow(), 0);
200
201 assert_eq!(
202 embedding_provider.embedding_count() - prev_embedding_count,
203 2
204 );
205}
206
207#[track_caller]
208fn assert_search_results(
209 actual: &[SearchResult],
210 expected: &[(Arc<Path>, usize)],
211 cx: &TestAppContext,
212) {
213 let actual = actual
214 .iter()
215 .map(|search_result| {
216 search_result.buffer.read_with(cx, |buffer, _cx| {
217 (
218 buffer.file().unwrap().path().clone(),
219 search_result.range.start.to_offset(buffer),
220 )
221 })
222 })
223 .collect::<Vec<_>>();
224 assert_eq!(actual, expected);
225}
226
227#[gpui::test]
228async fn test_code_context_retrieval_rust() {
229 let language = rust_lang();
230 let embedding_provider = Arc::new(DummyEmbeddings {});
231 let mut retriever = CodeContextRetriever::new(embedding_provider);
232
233 let text = "
234 /// A doc comment
235 /// that spans multiple lines
236 #[gpui::test]
237 fn a() {
238 b
239 }
240
241 impl C for D {
242 }
243
244 impl E {
245 // This is also a preceding comment
246 pub fn function_1() -> Option<()> {
247 todo!();
248 }
249
250 // This is a preceding comment
251 fn function_2() -> Result<()> {
252 todo!();
253 }
254 }
255 "
256 .unindent();
257
258 let documents = retriever.parse_file(&text, language).unwrap();
259
260 assert_documents_eq(
261 &documents,
262 &[
263 (
264 "
265 /// A doc comment
266 /// that spans multiple lines
267 #[gpui::test]
268 fn a() {
269 b
270 }"
271 .unindent(),
272 text.find("fn a").unwrap(),
273 ),
274 (
275 "
276 impl C for D {
277 }"
278 .unindent(),
279 text.find("impl C").unwrap(),
280 ),
281 (
282 "
283 impl E {
284 // This is also a preceding comment
285 pub fn function_1() -> Option<()> { /* ... */ }
286
287 // This is a preceding comment
288 fn function_2() -> Result<()> { /* ... */ }
289 }"
290 .unindent(),
291 text.find("impl E").unwrap(),
292 ),
293 (
294 "
295 // This is also a preceding comment
296 pub fn function_1() -> Option<()> {
297 todo!();
298 }"
299 .unindent(),
300 text.find("pub fn function_1").unwrap(),
301 ),
302 (
303 "
304 // This is a preceding comment
305 fn function_2() -> Result<()> {
306 todo!();
307 }"
308 .unindent(),
309 text.find("fn function_2").unwrap(),
310 ),
311 ],
312 );
313}
314
315#[gpui::test]
316async fn test_code_context_retrieval_json() {
317 let language = json_lang();
318 let embedding_provider = Arc::new(DummyEmbeddings {});
319 let mut retriever = CodeContextRetriever::new(embedding_provider);
320
321 let text = r#"
322 {
323 "array": [1, 2, 3, 4],
324 "string": "abcdefg",
325 "nested_object": {
326 "array_2": [5, 6, 7, 8],
327 "string_2": "hijklmnop",
328 "boolean": true,
329 "none": null
330 }
331 }
332 "#
333 .unindent();
334
335 let documents = retriever.parse_file(&text, language.clone()).unwrap();
336
337 assert_documents_eq(
338 &documents,
339 &[(
340 r#"
341 {
342 "array": [],
343 "string": "",
344 "nested_object": {
345 "array_2": [],
346 "string_2": "",
347 "boolean": true,
348 "none": null
349 }
350 }"#
351 .unindent(),
352 text.find("{").unwrap(),
353 )],
354 );
355
356 let text = r#"
357 [
358 {
359 "name": "somebody",
360 "age": 42
361 },
362 {
363 "name": "somebody else",
364 "age": 43
365 }
366 ]
367 "#
368 .unindent();
369
370 let documents = retriever.parse_file(&text, language.clone()).unwrap();
371
372 assert_documents_eq(
373 &documents,
374 &[(
375 r#"
376 [{
377 "name": "",
378 "age": 42
379 }]"#
380 .unindent(),
381 text.find("[").unwrap(),
382 )],
383 );
384}
385
386fn assert_documents_eq(
387 documents: &[Document],
388 expected_contents_and_start_offsets: &[(String, usize)],
389) {
390 assert_eq!(
391 documents
392 .iter()
393 .map(|document| (document.content.clone(), document.range.start))
394 .collect::<Vec<_>>(),
395 expected_contents_and_start_offsets
396 );
397}
398
399#[gpui::test]
400async fn test_code_context_retrieval_javascript() {
401 let language = js_lang();
402 let embedding_provider = Arc::new(DummyEmbeddings {});
403 let mut retriever = CodeContextRetriever::new(embedding_provider);
404
405 let text = "
406 /* globals importScripts, backend */
407 function _authorize() {}
408
409 /**
410 * Sometimes the frontend build is way faster than backend.
411 */
412 export async function authorizeBank() {
413 _authorize(pushModal, upgradingAccountId, {});
414 }
415
416 export class SettingsPage {
417 /* This is a test setting */
418 constructor(page) {
419 this.page = page;
420 }
421 }
422
423 /* This is a test comment */
424 class TestClass {}
425
426 /* Schema for editor_events in Clickhouse. */
427 export interface ClickhouseEditorEvent {
428 installation_id: string
429 operation: string
430 }
431 "
432 .unindent();
433
434 let documents = retriever.parse_file(&text, language.clone()).unwrap();
435
436 assert_documents_eq(
437 &documents,
438 &[
439 (
440 "
441 /* globals importScripts, backend */
442 function _authorize() {}"
443 .unindent(),
444 37,
445 ),
446 (
447 "
448 /**
449 * Sometimes the frontend build is way faster than backend.
450 */
451 export async function authorizeBank() {
452 _authorize(pushModal, upgradingAccountId, {});
453 }"
454 .unindent(),
455 131,
456 ),
457 (
458 "
459 export class SettingsPage {
460 /* This is a test setting */
461 constructor(page) {
462 this.page = page;
463 }
464 }"
465 .unindent(),
466 225,
467 ),
468 (
469 "
470 /* This is a test setting */
471 constructor(page) {
472 this.page = page;
473 }"
474 .unindent(),
475 290,
476 ),
477 (
478 "
479 /* This is a test comment */
480 class TestClass {}"
481 .unindent(),
482 374,
483 ),
484 (
485 "
486 /* Schema for editor_events in Clickhouse. */
487 export interface ClickhouseEditorEvent {
488 installation_id: string
489 operation: string
490 }"
491 .unindent(),
492 440,
493 ),
494 ],
495 )
496}
497
498#[gpui::test]
499async fn test_code_context_retrieval_lua() {
500 let language = lua_lang();
501 let embedding_provider = Arc::new(DummyEmbeddings {});
502 let mut retriever = CodeContextRetriever::new(embedding_provider);
503
504 let text = r#"
505 -- Creates a new class
506 -- @param baseclass The Baseclass of this class, or nil.
507 -- @return A new class reference.
508 function classes.class(baseclass)
509 -- Create the class definition and metatable.
510 local classdef = {}
511 -- Find the super class, either Object or user-defined.
512 baseclass = baseclass or classes.Object
513 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
514 setmetatable(classdef, { __index = baseclass })
515 -- All class instances have a reference to the class object.
516 classdef.class = classdef
517 --- Recursivly allocates the inheritance tree of the instance.
518 -- @param mastertable The 'root' of the inheritance tree.
519 -- @return Returns the instance with the allocated inheritance tree.
520 function classdef.alloc(mastertable)
521 -- All class instances have a reference to a superclass object.
522 local instance = { super = baseclass.alloc(mastertable) }
523 -- Any functions this instance does not know of will 'look up' to the superclass definition.
524 setmetatable(instance, { __index = classdef, __newindex = mastertable })
525 return instance
526 end
527 end
528 "#.unindent();
529
530 let documents = retriever.parse_file(&text, language.clone()).unwrap();
531
532 assert_documents_eq(
533 &documents,
534 &[
535 (r#"
536 -- Creates a new class
537 -- @param baseclass The Baseclass of this class, or nil.
538 -- @return A new class reference.
539 function classes.class(baseclass)
540 -- Create the class definition and metatable.
541 local classdef = {}
542 -- Find the super class, either Object or user-defined.
543 baseclass = baseclass or classes.Object
544 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
545 setmetatable(classdef, { __index = baseclass })
546 -- All class instances have a reference to the class object.
547 classdef.class = classdef
548 --- Recursivly allocates the inheritance tree of the instance.
549 -- @param mastertable The 'root' of the inheritance tree.
550 -- @return Returns the instance with the allocated inheritance tree.
551 function classdef.alloc(mastertable)
552 --[ ... ]--
553 --[ ... ]--
554 end
555 end"#.unindent(),
556 114),
557 (r#"
558 --- Recursivly allocates the inheritance tree of the instance.
559 -- @param mastertable The 'root' of the inheritance tree.
560 -- @return Returns the instance with the allocated inheritance tree.
561 function classdef.alloc(mastertable)
562 -- All class instances have a reference to a superclass object.
563 local instance = { super = baseclass.alloc(mastertable) }
564 -- Any functions this instance does not know of will 'look up' to the superclass definition.
565 setmetatable(instance, { __index = classdef, __newindex = mastertable })
566 return instance
567 end"#.unindent(), 809),
568 ]
569 );
570}
571
572#[gpui::test]
573async fn test_code_context_retrieval_elixir() {
574 let language = elixir_lang();
575 let embedding_provider = Arc::new(DummyEmbeddings {});
576 let mut retriever = CodeContextRetriever::new(embedding_provider);
577
578 let text = r#"
579 defmodule File.Stream do
580 @moduledoc """
581 Defines a `File.Stream` struct returned by `File.stream!/3`.
582
583 The following fields are public:
584
585 * `path` - the file path
586 * `modes` - the file modes
587 * `raw` - a boolean indicating if bin functions should be used
588 * `line_or_bytes` - if reading should read lines or a given number of bytes
589 * `node` - the node the file belongs to
590
591 """
592
593 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
594
595 @type t :: %__MODULE__{}
596
597 @doc false
598 def __build__(path, modes, line_or_bytes) do
599 raw = :lists.keyfind(:encoding, 1, modes) == false
600
601 modes =
602 case raw do
603 true ->
604 case :lists.keyfind(:read_ahead, 1, modes) do
605 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
606 {:read_ahead, _} -> [:raw | modes]
607 false -> [:raw, :read_ahead | modes]
608 end
609
610 false ->
611 modes
612 end
613
614 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
615
616 end"#
617 .unindent();
618
619 let documents = retriever.parse_file(&text, language.clone()).unwrap();
620
621 assert_documents_eq(
622 &documents,
623 &[(
624 r#"
625 defmodule File.Stream do
626 @moduledoc """
627 Defines a `File.Stream` struct returned by `File.stream!/3`.
628
629 The following fields are public:
630
631 * `path` - the file path
632 * `modes` - the file modes
633 * `raw` - a boolean indicating if bin functions should be used
634 * `line_or_bytes` - if reading should read lines or a given number of bytes
635 * `node` - the node the file belongs to
636
637 """
638
639 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
640
641 @type t :: %__MODULE__{}
642
643 @doc false
644 def __build__(path, modes, line_or_bytes) do
645 raw = :lists.keyfind(:encoding, 1, modes) == false
646
647 modes =
648 case raw do
649 true ->
650 case :lists.keyfind(:read_ahead, 1, modes) do
651 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
652 {:read_ahead, _} -> [:raw | modes]
653 false -> [:raw, :read_ahead | modes]
654 end
655
656 false ->
657 modes
658 end
659
660 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
661
662 end"#
663 .unindent(),
664 0,
665 ),(r#"
666 @doc false
667 def __build__(path, modes, line_or_bytes) do
668 raw = :lists.keyfind(:encoding, 1, modes) == false
669
670 modes =
671 case raw do
672 true ->
673 case :lists.keyfind(:read_ahead, 1, modes) do
674 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
675 {:read_ahead, _} -> [:raw | modes]
676 false -> [:raw, :read_ahead | modes]
677 end
678
679 false ->
680 modes
681 end
682
683 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
684
685 end"#.unindent(), 574)],
686 );
687}
688
689#[gpui::test]
690async fn test_code_context_retrieval_cpp() {
691 let language = cpp_lang();
692 let embedding_provider = Arc::new(DummyEmbeddings {});
693 let mut retriever = CodeContextRetriever::new(embedding_provider);
694
695 let text = "
696 /**
697 * @brief Main function
698 * @returns 0 on exit
699 */
700 int main() { return 0; }
701
702 /**
703 * This is a test comment
704 */
705 class MyClass { // The class
706 public: // Access specifier
707 int myNum; // Attribute (int variable)
708 string myString; // Attribute (string variable)
709 };
710
711 // This is a test comment
712 enum Color { red, green, blue };
713
714 /** This is a preceding block comment
715 * This is the second line
716 */
717 struct { // Structure declaration
718 int myNum; // Member (int variable)
719 string myString; // Member (string variable)
720 } myStructure;
721
722 /**
723 * @brief Matrix class.
724 */
725 template <typename T,
726 typename = typename std::enable_if<
727 std::is_integral<T>::value || std::is_floating_point<T>::value,
728 bool>::type>
729 class Matrix2 {
730 std::vector<std::vector<T>> _mat;
731
732 public:
733 /**
734 * @brief Constructor
735 * @tparam Integer ensuring integers are being evaluated and not other
736 * data types.
737 * @param size denoting the size of Matrix as size x size
738 */
739 template <typename Integer,
740 typename = typename std::enable_if<std::is_integral<Integer>::value,
741 Integer>::type>
742 explicit Matrix(const Integer size) {
743 for (size_t i = 0; i < size; ++i) {
744 _mat.emplace_back(std::vector<T>(size, 0));
745 }
746 }
747 }"
748 .unindent();
749
750 let documents = retriever.parse_file(&text, language.clone()).unwrap();
751
752 assert_documents_eq(
753 &documents,
754 &[
755 (
756 "
757 /**
758 * @brief Main function
759 * @returns 0 on exit
760 */
761 int main() { return 0; }"
762 .unindent(),
763 54,
764 ),
765 (
766 "
767 /**
768 * This is a test comment
769 */
770 class MyClass { // The class
771 public: // Access specifier
772 int myNum; // Attribute (int variable)
773 string myString; // Attribute (string variable)
774 }"
775 .unindent(),
776 112,
777 ),
778 (
779 "
780 // This is a test comment
781 enum Color { red, green, blue }"
782 .unindent(),
783 322,
784 ),
785 (
786 "
787 /** This is a preceding block comment
788 * This is the second line
789 */
790 struct { // Structure declaration
791 int myNum; // Member (int variable)
792 string myString; // Member (string variable)
793 } myStructure;"
794 .unindent(),
795 425,
796 ),
797 (
798 "
799 /**
800 * @brief Matrix class.
801 */
802 template <typename T,
803 typename = typename std::enable_if<
804 std::is_integral<T>::value || std::is_floating_point<T>::value,
805 bool>::type>
806 class Matrix2 {
807 std::vector<std::vector<T>> _mat;
808
809 public:
810 /**
811 * @brief Constructor
812 * @tparam Integer ensuring integers are being evaluated and not other
813 * data types.
814 * @param size denoting the size of Matrix as size x size
815 */
816 template <typename Integer,
817 typename = typename std::enable_if<std::is_integral<Integer>::value,
818 Integer>::type>
819 explicit Matrix(const Integer size) {
820 for (size_t i = 0; i < size; ++i) {
821 _mat.emplace_back(std::vector<T>(size, 0));
822 }
823 }
824 }"
825 .unindent(),
826 612,
827 ),
828 (
829 "
830 explicit Matrix(const Integer size) {
831 for (size_t i = 0; i < size; ++i) {
832 _mat.emplace_back(std::vector<T>(size, 0));
833 }
834 }"
835 .unindent(),
836 1226,
837 ),
838 ],
839 );
840}
841
842#[gpui::test]
843async fn test_code_context_retrieval_ruby() {
844 let language = ruby_lang();
845 let embedding_provider = Arc::new(DummyEmbeddings {});
846 let mut retriever = CodeContextRetriever::new(embedding_provider);
847
848 let text = r#"
849 # This concern is inspired by "sudo mode" on GitHub. It
850 # is a way to re-authenticate a user before allowing them
851 # to see or perform an action.
852 #
853 # Add `before_action :require_challenge!` to actions you
854 # want to protect.
855 #
856 # The user will be shown a page to enter the challenge (which
857 # is either the password, or just the username when no
858 # password exists). Upon passing, there is a grace period
859 # during which no challenge will be asked from the user.
860 #
861 # Accessing challenge-protected resources during the grace
862 # period will refresh the grace period.
863 module ChallengableConcern
864 extend ActiveSupport::Concern
865
866 CHALLENGE_TIMEOUT = 1.hour.freeze
867
868 def require_challenge!
869 return if skip_challenge?
870
871 if challenge_passed_recently?
872 session[:challenge_passed_at] = Time.now.utc
873 return
874 end
875
876 @challenge = Form::Challenge.new(return_to: request.url)
877
878 if params.key?(:form_challenge)
879 if challenge_passed?
880 session[:challenge_passed_at] = Time.now.utc
881 else
882 flash.now[:alert] = I18n.t('challenge.invalid_password')
883 render_challenge
884 end
885 else
886 render_challenge
887 end
888 end
889
890 def challenge_passed?
891 current_user.valid_password?(challenge_params[:current_password])
892 end
893 end
894
895 class Animal
896 include Comparable
897
898 attr_reader :legs
899
900 def initialize(name, legs)
901 @name, @legs = name, legs
902 end
903
904 def <=>(other)
905 legs <=> other.legs
906 end
907 end
908
909 # Singleton method for car object
910 def car.wheels
911 puts "There are four wheels"
912 end"#
913 .unindent();
914
915 let documents = retriever.parse_file(&text, language.clone()).unwrap();
916
917 assert_documents_eq(
918 &documents,
919 &[
920 (
921 r#"
922 # This concern is inspired by "sudo mode" on GitHub. It
923 # is a way to re-authenticate a user before allowing them
924 # to see or perform an action.
925 #
926 # Add `before_action :require_challenge!` to actions you
927 # want to protect.
928 #
929 # The user will be shown a page to enter the challenge (which
930 # is either the password, or just the username when no
931 # password exists). Upon passing, there is a grace period
932 # during which no challenge will be asked from the user.
933 #
934 # Accessing challenge-protected resources during the grace
935 # period will refresh the grace period.
936 module ChallengableConcern
937 extend ActiveSupport::Concern
938
939 CHALLENGE_TIMEOUT = 1.hour.freeze
940
941 def require_challenge!
942 # ...
943 end
944
945 def challenge_passed?
946 # ...
947 end
948 end"#
949 .unindent(),
950 558,
951 ),
952 (
953 r#"
954 def require_challenge!
955 return if skip_challenge?
956
957 if challenge_passed_recently?
958 session[:challenge_passed_at] = Time.now.utc
959 return
960 end
961
962 @challenge = Form::Challenge.new(return_to: request.url)
963
964 if params.key?(:form_challenge)
965 if challenge_passed?
966 session[:challenge_passed_at] = Time.now.utc
967 else
968 flash.now[:alert] = I18n.t('challenge.invalid_password')
969 render_challenge
970 end
971 else
972 render_challenge
973 end
974 end"#
975 .unindent(),
976 663,
977 ),
978 (
979 r#"
980 def challenge_passed?
981 current_user.valid_password?(challenge_params[:current_password])
982 end"#
983 .unindent(),
984 1254,
985 ),
986 (
987 r#"
988 class Animal
989 include Comparable
990
991 attr_reader :legs
992
993 def initialize(name, legs)
994 # ...
995 end
996
997 def <=>(other)
998 # ...
999 end
1000 end"#
1001 .unindent(),
1002 1363,
1003 ),
1004 (
1005 r#"
1006 def initialize(name, legs)
1007 @name, @legs = name, legs
1008 end"#
1009 .unindent(),
1010 1427,
1011 ),
1012 (
1013 r#"
1014 def <=>(other)
1015 legs <=> other.legs
1016 end"#
1017 .unindent(),
1018 1501,
1019 ),
1020 (
1021 r#"
1022 # Singleton method for car object
1023 def car.wheels
1024 puts "There are four wheels"
1025 end"#
1026 .unindent(),
1027 1591,
1028 ),
1029 ],
1030 );
1031}
1032
1033#[gpui::test]
1034async fn test_code_context_retrieval_php() {
1035 let language = php_lang();
1036 let embedding_provider = Arc::new(DummyEmbeddings {});
1037 let mut retriever = CodeContextRetriever::new(embedding_provider);
1038
1039 let text = r#"
1040 <?php
1041
1042 namespace LevelUp\Experience\Concerns;
1043
1044 /*
1045 This is a multiple-lines comment block
1046 that spans over multiple
1047 lines
1048 */
1049 function functionName() {
1050 echo "Hello world!";
1051 }
1052
1053 trait HasAchievements
1054 {
1055 /**
1056 * @throws \Exception
1057 */
1058 public function grantAchievement(Achievement $achievement, $progress = null): void
1059 {
1060 if ($progress > 100) {
1061 throw new Exception(message: 'Progress cannot be greater than 100');
1062 }
1063
1064 if ($this->achievements()->find($achievement->id)) {
1065 throw new Exception(message: 'User already has this Achievement');
1066 }
1067
1068 $this->achievements()->attach($achievement, [
1069 'progress' => $progress ?? null,
1070 ]);
1071
1072 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1073 }
1074
1075 public function achievements(): BelongsToMany
1076 {
1077 return $this->belongsToMany(related: Achievement::class)
1078 ->withPivot(columns: 'progress')
1079 ->where('is_secret', false)
1080 ->using(AchievementUser::class);
1081 }
1082 }
1083
1084 interface Multiplier
1085 {
1086 public function qualifies(array $data): bool;
1087
1088 public function setMultiplier(): int;
1089 }
1090
1091 enum AuditType: string
1092 {
1093 case Add = 'add';
1094 case Remove = 'remove';
1095 case Reset = 'reset';
1096 case LevelUp = 'level_up';
1097 }
1098
1099 ?>"#
1100 .unindent();
1101
1102 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1103
1104 assert_documents_eq(
1105 &documents,
1106 &[
1107 (
1108 r#"
1109 /*
1110 This is a multiple-lines comment block
1111 that spans over multiple
1112 lines
1113 */
1114 function functionName() {
1115 echo "Hello world!";
1116 }"#
1117 .unindent(),
1118 123,
1119 ),
1120 (
1121 r#"
1122 trait HasAchievements
1123 {
1124 /**
1125 * @throws \Exception
1126 */
1127 public function grantAchievement(Achievement $achievement, $progress = null): void
1128 {/* ... */}
1129
1130 public function achievements(): BelongsToMany
1131 {/* ... */}
1132 }"#
1133 .unindent(),
1134 177,
1135 ),
1136 (r#"
1137 /**
1138 * @throws \Exception
1139 */
1140 public function grantAchievement(Achievement $achievement, $progress = null): void
1141 {
1142 if ($progress > 100) {
1143 throw new Exception(message: 'Progress cannot be greater than 100');
1144 }
1145
1146 if ($this->achievements()->find($achievement->id)) {
1147 throw new Exception(message: 'User already has this Achievement');
1148 }
1149
1150 $this->achievements()->attach($achievement, [
1151 'progress' => $progress ?? null,
1152 ]);
1153
1154 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1155 }"#.unindent(), 245),
1156 (r#"
1157 public function achievements(): BelongsToMany
1158 {
1159 return $this->belongsToMany(related: Achievement::class)
1160 ->withPivot(columns: 'progress')
1161 ->where('is_secret', false)
1162 ->using(AchievementUser::class);
1163 }"#.unindent(), 902),
1164 (r#"
1165 interface Multiplier
1166 {
1167 public function qualifies(array $data): bool;
1168
1169 public function setMultiplier(): int;
1170 }"#.unindent(),
1171 1146),
1172 (r#"
1173 enum AuditType: string
1174 {
1175 case Add = 'add';
1176 case Remove = 'remove';
1177 case Reset = 'reset';
1178 case LevelUp = 'level_up';
1179 }"#.unindent(), 1265)
1180 ],
1181 );
1182}
1183
1184#[gpui::test]
1185fn test_dot_product(mut rng: StdRng) {
1186 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1187 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1188
1189 for _ in 0..100 {
1190 let size = 1536;
1191 let mut a = vec![0.; size];
1192 let mut b = vec![0.; size];
1193 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1194 *a = rng.gen();
1195 *b = rng.gen();
1196 }
1197
1198 assert_eq!(
1199 round_to_decimals(dot(&a, &b), 1),
1200 round_to_decimals(reference_dot(&a, &b), 1)
1201 );
1202 }
1203
1204 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1205 let factor = (10.0 as f32).powi(decimal_places);
1206 (n * factor).round() / factor
1207 }
1208
1209 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1210 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1211 }
1212}
1213
/// Test double for an embedding provider that keeps a running count of how many
/// spans it has been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Number of spans embedded so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the number of spans embedded so far.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
1224
1225#[async_trait]
1226impl EmbeddingProvider for FakeEmbeddingProvider {
1227 fn count_tokens(&self, span: &str) -> usize {
1228 span.len()
1229 }
1230
1231 fn should_truncate(&self, span: &str) -> bool {
1232 false
1233 }
1234
1235 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1236 self.embedding_count
1237 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1238 Ok(spans
1239 .iter()
1240 .map(|span| {
1241 let mut result = vec![1.0; 26];
1242 for letter in span.chars() {
1243 let letter = letter.to_ascii_lowercase();
1244 if letter as u32 >= 'a' as u32 {
1245 let ix = (letter as u32) - ('a' as u32);
1246 if ix < 26 {
1247 result[ix as usize] += 1.0;
1248 }
1249 }
1250 }
1251
1252 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1253 for x in &mut result {
1254 *x /= norm;
1255 }
1256
1257 result
1258 })
1259 .collect())
1260 }
1261}
1262
1263fn js_lang() -> Arc<Language> {
1264 Arc::new(
1265 Language::new(
1266 LanguageConfig {
1267 name: "Javascript".into(),
1268 path_suffixes: vec!["js".into()],
1269 ..Default::default()
1270 },
1271 Some(tree_sitter_typescript::language_tsx()),
1272 )
1273 .with_embedding_query(
1274 &r#"
1275
1276 (
1277 (comment)* @context
1278 .
1279 [
1280 (export_statement
1281 (function_declaration
1282 "async"? @name
1283 "function" @name
1284 name: (_) @name))
1285 (function_declaration
1286 "async"? @name
1287 "function" @name
1288 name: (_) @name)
1289 ] @item
1290 )
1291
1292 (
1293 (comment)* @context
1294 .
1295 [
1296 (export_statement
1297 (class_declaration
1298 "class" @name
1299 name: (_) @name))
1300 (class_declaration
1301 "class" @name
1302 name: (_) @name)
1303 ] @item
1304 )
1305
1306 (
1307 (comment)* @context
1308 .
1309 [
1310 (export_statement
1311 (interface_declaration
1312 "interface" @name
1313 name: (_) @name))
1314 (interface_declaration
1315 "interface" @name
1316 name: (_) @name)
1317 ] @item
1318 )
1319
1320 (
1321 (comment)* @context
1322 .
1323 [
1324 (export_statement
1325 (enum_declaration
1326 "enum" @name
1327 name: (_) @name))
1328 (enum_declaration
1329 "enum" @name
1330 name: (_) @name)
1331 ] @item
1332 )
1333
1334 (
1335 (comment)* @context
1336 .
1337 (method_definition
1338 [
1339 "get"
1340 "set"
1341 "async"
1342 "*"
1343 "static"
1344 ]* @name
1345 name: (_) @name) @item
1346 )
1347
1348 "#
1349 .unindent(),
1350 )
1351 .unwrap(),
1352 )
1353}
1354
1355fn rust_lang() -> Arc<Language> {
1356 Arc::new(
1357 Language::new(
1358 LanguageConfig {
1359 name: "Rust".into(),
1360 path_suffixes: vec!["rs".into()],
1361 collapsed_placeholder: " /* ... */ ".to_string(),
1362 ..Default::default()
1363 },
1364 Some(tree_sitter_rust::language()),
1365 )
1366 .with_embedding_query(
1367 r#"
1368 (
1369 [(line_comment) (attribute_item)]* @context
1370 .
1371 [
1372 (struct_item
1373 name: (_) @name)
1374
1375 (enum_item
1376 name: (_) @name)
1377
1378 (impl_item
1379 trait: (_)? @name
1380 "for"? @name
1381 type: (_) @name)
1382
1383 (trait_item
1384 name: (_) @name)
1385
1386 (function_item
1387 name: (_) @name
1388 body: (block
1389 "{" @keep
1390 "}" @keep) @collapse)
1391
1392 (macro_definition
1393 name: (_) @name)
1394 ] @item
1395 )
1396 "#,
1397 )
1398 .unwrap(),
1399 )
1400}
1401
1402fn json_lang() -> Arc<Language> {
1403 Arc::new(
1404 Language::new(
1405 LanguageConfig {
1406 name: "JSON".into(),
1407 path_suffixes: vec!["json".into()],
1408 ..Default::default()
1409 },
1410 Some(tree_sitter_json::language()),
1411 )
1412 .with_embedding_query(
1413 r#"
1414 (document) @item
1415
1416 (array
1417 "[" @keep
1418 .
1419 (object)? @keep
1420 "]" @keep) @collapse
1421
1422 (pair value: (string
1423 "\"" @keep
1424 "\"" @keep) @collapse)
1425 "#,
1426 )
1427 .unwrap(),
1428 )
1429}
1430
1431fn toml_lang() -> Arc<Language> {
1432 Arc::new(Language::new(
1433 LanguageConfig {
1434 name: "TOML".into(),
1435 path_suffixes: vec!["toml".into()],
1436 ..Default::default()
1437 },
1438 Some(tree_sitter_toml::language()),
1439 ))
1440}
1441
1442fn cpp_lang() -> Arc<Language> {
1443 Arc::new(
1444 Language::new(
1445 LanguageConfig {
1446 name: "CPP".into(),
1447 path_suffixes: vec!["cpp".into()],
1448 ..Default::default()
1449 },
1450 Some(tree_sitter_cpp::language()),
1451 )
1452 .with_embedding_query(
1453 r#"
1454 (
1455 (comment)* @context
1456 .
1457 (function_definition
1458 (type_qualifier)? @name
1459 type: (_)? @name
1460 declarator: [
1461 (function_declarator
1462 declarator: (_) @name)
1463 (pointer_declarator
1464 "*" @name
1465 declarator: (function_declarator
1466 declarator: (_) @name))
1467 (pointer_declarator
1468 "*" @name
1469 declarator: (pointer_declarator
1470 "*" @name
1471 declarator: (function_declarator
1472 declarator: (_) @name)))
1473 (reference_declarator
1474 ["&" "&&"] @name
1475 (function_declarator
1476 declarator: (_) @name))
1477 ]
1478 (type_qualifier)? @name) @item
1479 )
1480
1481 (
1482 (comment)* @context
1483 .
1484 (template_declaration
1485 (class_specifier
1486 "class" @name
1487 name: (_) @name)
1488 ) @item
1489 )
1490
1491 (
1492 (comment)* @context
1493 .
1494 (class_specifier
1495 "class" @name
1496 name: (_) @name) @item
1497 )
1498
1499 (
1500 (comment)* @context
1501 .
1502 (enum_specifier
1503 "enum" @name
1504 name: (_) @name) @item
1505 )
1506
1507 (
1508 (comment)* @context
1509 .
1510 (declaration
1511 type: (struct_specifier
1512 "struct" @name)
1513 declarator: (_) @name) @item
1514 )
1515
1516 "#,
1517 )
1518 .unwrap(),
1519 )
1520}
1521
1522fn lua_lang() -> Arc<Language> {
1523 Arc::new(
1524 Language::new(
1525 LanguageConfig {
1526 name: "Lua".into(),
1527 path_suffixes: vec!["lua".into()],
1528 collapsed_placeholder: "--[ ... ]--".to_string(),
1529 ..Default::default()
1530 },
1531 Some(tree_sitter_lua::language()),
1532 )
1533 .with_embedding_query(
1534 r#"
1535 (
1536 (comment)* @context
1537 .
1538 (function_declaration
1539 "function" @name
1540 name: (_) @name
1541 (comment)* @collapse
1542 body: (block) @collapse
1543 ) @item
1544 )
1545 "#,
1546 )
1547 .unwrap(),
1548 )
1549}
1550
1551fn php_lang() -> Arc<Language> {
1552 Arc::new(
1553 Language::new(
1554 LanguageConfig {
1555 name: "PHP".into(),
1556 path_suffixes: vec!["php".into()],
1557 collapsed_placeholder: "/* ... */".into(),
1558 ..Default::default()
1559 },
1560 Some(tree_sitter_php::language()),
1561 )
1562 .with_embedding_query(
1563 r#"
1564 (
1565 (comment)* @context
1566 .
1567 [
1568 (function_definition
1569 "function" @name
1570 name: (_) @name
1571 body: (_
1572 "{" @keep
1573 "}" @keep) @collapse
1574 )
1575
1576 (trait_declaration
1577 "trait" @name
1578 name: (_) @name)
1579
1580 (method_declaration
1581 "function" @name
1582 name: (_) @name
1583 body: (_
1584 "{" @keep
1585 "}" @keep) @collapse
1586 )
1587
1588 (interface_declaration
1589 "interface" @name
1590 name: (_) @name
1591 )
1592
1593 (enum_declaration
1594 "enum" @name
1595 name: (_) @name
1596 )
1597
1598 ] @item
1599 )
1600 "#,
1601 )
1602 .unwrap(),
1603 )
1604}
1605
1606fn ruby_lang() -> Arc<Language> {
1607 Arc::new(
1608 Language::new(
1609 LanguageConfig {
1610 name: "Ruby".into(),
1611 path_suffixes: vec!["rb".into()],
1612 collapsed_placeholder: "# ...".to_string(),
1613 ..Default::default()
1614 },
1615 Some(tree_sitter_ruby::language()),
1616 )
1617 .with_embedding_query(
1618 r#"
1619 (
1620 (comment)* @context
1621 .
1622 [
1623 (module
1624 "module" @name
1625 name: (_) @name)
1626 (method
1627 "def" @name
1628 name: (_) @name
1629 body: (body_statement) @collapse)
1630 (class
1631 "class" @name
1632 name: (_) @name)
1633 (singleton_method
1634 "def" @name
1635 object: (_) @name
1636 "." @name
1637 name: (_) @name
1638 body: (body_statement) @collapse)
1639 ] @item
1640 )
1641 "#,
1642 )
1643 .unwrap(),
1644 )
1645}
1646
1647fn elixir_lang() -> Arc<Language> {
1648 Arc::new(
1649 Language::new(
1650 LanguageConfig {
1651 name: "Elixir".into(),
1652 path_suffixes: vec!["rs".into()],
1653 ..Default::default()
1654 },
1655 Some(tree_sitter_elixir::language()),
1656 )
1657 .with_embedding_query(
1658 r#"
1659 (
1660 (unary_operator
1661 operator: "@"
1662 operand: (call
1663 target: (identifier) @unary
1664 (#match? @unary "^(doc)$"))
1665 ) @context
1666 .
1667 (call
1668 target: (identifier) @name
1669 (arguments
1670 [
1671 (identifier) @name
1672 (call
1673 target: (identifier) @name)
1674 (binary_operator
1675 left: (call
1676 target: (identifier) @name)
1677 operator: "when")
1678 ])
1679 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1680 )
1681
1682 (call
1683 target: (identifier) @name
1684 (arguments (alias) @name)
1685 (#match? @name "^(defmodule|defprotocol)$")) @item
1686 "#,
1687 )
1688 .unwrap(),
1689 )
1690}
1691
1692#[gpui::test]
1693fn test_subtract_ranges() {
1694 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1695
1696 assert_eq!(
1697 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1698 vec![1..4, 10..21]
1699 );
1700
1701 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1702}