1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{subtract_ranges, CodeContextRetriever, Document},
5 semantic_index_settings::SemanticIndexSettings,
6 SearchResult, SemanticIndex,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::SettingsStore;
17use std::{
18 path::Path,
19 sync::{
20 atomic::{self, AtomicUsize},
21 Arc,
22 },
23};
24use unindent::Unindent;
25
26#[ctor::ctor]
27fn init_logger() {
28 if std::env::var("RUST_LOG").is_ok() {
29 env_logger::init();
30 }
31}
32
// End-to-end test of `SemanticIndex`: indexes a fake three-file project,
// searches it with and without include/exclude path filters, and then
// re-indexes after one of the files has been rewritten.
33#[gpui::test]
34async fn test_semantic_index(cx: &mut TestAppContext) {
 // Install a test settings store and register the settings types this test needs.
35 cx.update(|cx| {
36 cx.set_global(SettingsStore::test(cx));
37 settings::register::<SemanticIndexSettings>(cx);
38 settings::register::<ProjectSettings>(cx);
39 });
40
 // Fake file system containing two Rust files and one TOML file under /the-root/src.
41 let fs = FakeFs::new(cx.background());
42 fs.insert_tree(
43 "/the-root",
44 json!({
45 "src": {
46 "file1.rs": "
47 fn aaa() {
48 println!(\"aaaaaaaaaaaa!\");
49 }
50
51 fn zzzzz() {
52 println!(\"SLEEPING\");
53 }
54 ".unindent(),
55 "file2.rs": "
56 fn bbb() {
57 println!(\"bbbbbbbbbbbbb!\");
58 }
59 ".unindent(),
60 "file3.toml": "
61 ZZZZZZZZZZZZZZZZZZ = 5
62 ".unindent(),
63 }
64 }),
65 )
66 .await;
67
 // Register the Rust and TOML test languages (built by helpers in this file).
68 let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
69 let rust_language = rust_lang();
70 let toml_language = toml_lang();
71 languages.add(rust_language);
72 languages.add(toml_language);
73
 // The index database is placed inside a temporary directory.
74 let db_dir = tempdir::TempDir::new("vector-store").unwrap();
75 let db_path = db_dir.path().join("db.sqlite");
76
 // Use the fake embedding provider from this file so the number of embedded
 // spans can be inspected later through `embedding_count()`.
77 let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
78 let store = SemanticIndex::new(
79 fs.clone(),
80 db_path,
81 embedding_provider.clone(),
82 languages,
83 cx.to_async(),
84 )
85 .await
86 .unwrap();
87
 // First indexing pass: three files are expected to be reported, and no file
 // should remain outstanding once the foreground executor has parked.
88 let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
89 let (file_count, outstanding_file_count) = store
90 .update(cx, |store, cx| store.index_project(project.clone(), cx))
91 .await
92 .unwrap();
93 assert_eq!(file_count, 3);
94 cx.foreground().run_until_parked();
95 assert_eq!(*outstanding_file_count.borrow(), 0);
96
 // Unfiltered search: both the include and exclude matcher lists are empty.
97 let search_results = store
98 .update(cx, |store, cx| {
99 store.search_project(
100 project.clone(),
101 "aaaaaabbbbzz".to_string(),
102 5,
103 vec![],
104 vec![],
105 cx,
106 )
107 })
108 .await
109 .unwrap();
110
 // Each expected entry is (file path, start offset of the result's range),
 // which is how `assert_search_results` flattens a `SearchResult`.
111 assert_search_results(
112 &search_results,
113 &[
114 (Path::new("src/file1.rs").into(), 0),
115 (Path::new("src/file2.rs").into(), 0),
116 (Path::new("src/file3.toml").into(), 0),
117 (Path::new("src/file1.rs").into(), 45),
118 ],
119 cx,
120 );
121
122 // Test the include-files and exclude-files functionality (same `*.rs` glob for both).
123 let include_files = vec![PathMatcher::new("*.rs").unwrap()];
124 let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
125 let rust_only_search_results = store
126 .update(cx, |store, cx| {
127 store.search_project(
128 project.clone(),
129 "aaaaaabbbbzz".to_string(),
130 5,
131 include_files,
132 vec![],
133 cx,
134 )
135 })
136 .await
137 .unwrap();
138
 // With `*.rs` as the include filter, the TOML result is expected to drop out.
139 assert_search_results(
140 &rust_only_search_results,
141 &[
142 (Path::new("src/file1.rs").into(), 0),
143 (Path::new("src/file2.rs").into(), 0),
144 (Path::new("src/file1.rs").into(), 45),
145 ],
146 cx,
147 );
148
 // With `*.rs` as the exclude filter, only the TOML result is expected.
149 let no_rust_search_results = store
150 .update(cx, |store, cx| {
151 store.search_project(
152 project.clone(),
153 "aaaaaabbbbzz".to_string(),
154 5,
155 vec![],
156 exclude_files,
157 cx,
158 )
159 })
160 .await
161 .unwrap();
162
163 assert_search_results(
164 &no_rust_search_results,
165 &[(Path::new("src/file3.toml").into(), 0)],
166 cx,
167 );
168
 // Overwrite file2.rs with different contents (one fn and one struct).
169 fs.save(
170 "/the-root/src/file2.rs".as_ref(),
171 &"
172 fn dddd() { println!(\"ddddd!\"); }
173 struct pqpqpqp {}
174 "
175 .unindent()
176 .into(),
177 Default::default(),
178 )
179 .await
180 .unwrap();
181
182 cx.foreground().run_until_parked();
183
 // Second indexing pass: only one file is expected to be reported this time.
184 let prev_embedding_count = embedding_provider.embedding_count();
185 let (file_count, outstanding_file_count) = store
186 .update(cx, |store, cx| store.index_project(project.clone(), cx))
187 .await
188 .unwrap();
189 assert_eq!(file_count, 1);
190
191 cx.foreground().run_until_parked();
192 assert_eq!(*outstanding_file_count.borrow(), 0);
193
 // Exactly two additional spans must have been embedded since the snapshot;
 // presumably one per item in the rewritten file (the fn and the struct).
194 assert_eq!(
195 embedding_provider.embedding_count() - prev_embedding_count,
196 2
197 );
198}
199
200#[track_caller]
201fn assert_search_results(
202 actual: &[SearchResult],
203 expected: &[(Arc<Path>, usize)],
204 cx: &TestAppContext,
205) {
206 let actual = actual
207 .iter()
208 .map(|search_result| {
209 search_result.buffer.read_with(cx, |buffer, _cx| {
210 (
211 buffer.file().unwrap().path().clone(),
212 search_result.range.start.to_offset(buffer),
213 )
214 })
215 })
216 .collect::<Vec<_>>();
217 assert_eq!(actual, expected);
218}
219
// Checks that `CodeContextRetriever::parse_file` splits a Rust source text
// into the expected documents, compared as (content, start offset) pairs.
220#[gpui::test]
221async fn test_code_context_retrieval_rust() {
222 let language = rust_lang();
223 let mut retriever = CodeContextRetriever::new();
224
 // Input: a documented free function, an empty trait impl, and an impl
 // block holding two commented methods.
225 let text = "
226 /// A doc comment
227 /// that spans multiple lines
228 #[gpui::test]
229 fn a() {
230 b
231 }
232
233 impl C for D {
234 }
235
236 impl E {
237 // This is also a preceding comment
238 pub fn function_1() -> Option<()> {
239 todo!();
240 }
241
242 // This is a preceding comment
243 fn function_2() -> Result<()> {
244 todo!();
245 }
246 }
247 "
248 .unindent();
249
250 let documents = retriever.parse_file(&text, language).unwrap();
251
 // Expected documents. Preceding comments/attributes are part of the content,
 // while the expected start offset is located with `text.find` on the item
 // itself. Inside the `impl E` document the method bodies appear as the
 // ` /* ... */ ` placeholder configured in `rust_lang`; the methods are also
 // expected as standalone documents with their full bodies.
252 assert_documents_eq(
253 &documents,
254 &[
255 (
256 "
257 /// A doc comment
258 /// that spans multiple lines
259 #[gpui::test]
260 fn a() {
261 b
262 }"
263 .unindent(),
264 text.find("fn a").unwrap(),
265 ),
266 (
267 "
268 impl C for D {
269 }"
270 .unindent(),
271 text.find("impl C").unwrap(),
272 ),
273 (
274 "
275 impl E {
276 // This is also a preceding comment
277 pub fn function_1() -> Option<()> { /* ... */ }
278
279 // This is a preceding comment
280 fn function_2() -> Result<()> { /* ... */ }
281 }"
282 .unindent(),
283 text.find("impl E").unwrap(),
284 ),
285 (
286 "
287 // This is also a preceding comment
288 pub fn function_1() -> Option<()> {
289 todo!();
290 }"
291 .unindent(),
292 text.find("pub fn function_1").unwrap(),
293 ),
294 (
295 "
296 // This is a preceding comment
297 fn function_2() -> Result<()> {
298 todo!();
299 }"
300 .unindent(),
301 text.find("fn function_2").unwrap(),
302 ),
303 ],
304 );
305}
306
// Checks that `CodeContextRetriever::parse_file` turns a JSON text into a
// single document whose string values and arrays appear collapsed.
307#[gpui::test]
308async fn test_code_context_retrieval_json() {
309 let language = json_lang();
310 let mut retriever = CodeContextRetriever::new();
311
 // Case 1: a top-level object with arrays, strings and a nested object.
312 let text = r#"
313 {
314 "array": [1, 2, 3, 4],
315 "string": "abcdefg",
316 "nested_object": {
317 "array_2": [5, 6, 7, 8],
318 "string_2": "hijklmnop",
319 "boolean": true,
320 "none": null
321 }
322 }
323 "#
324 .unindent();
325
326 let documents = retriever.parse_file(&text, language.clone()).unwrap();
327
 // Expected: one document where arrays show as `[]` and string values as
 // `""`, while booleans and nulls are left as written.
328 assert_documents_eq(
329 &documents,
330 &[(
331 r#"
332 {
333 "array": [],
334 "string": "",
335 "nested_object": {
336 "array_2": [],
337 "string_2": "",
338 "boolean": true,
339 "none": null
340 }
341 }"#
342 .unindent(),
343 text.find("{").unwrap(),
344 )],
345 );
346
 // Case 2: a top-level array holding two objects.
347 let text = r#"
348 [
349 {
350 "name": "somebody",
351 "age": 42
352 },
353 {
354 "name": "somebody else",
355 "age": 43
356 }
357 ]
358 "#
359 .unindent();
360
361 let documents = retriever.parse_file(&text, language.clone()).unwrap();
362
 // Expected: only the first object of the array remains, with its string
 // value emptied.
363 assert_documents_eq(
364 &documents,
365 &[(
366 r#"
367 [{
368 "name": "",
369 "age": 42
370 }]"#
371 .unindent(),
372 text.find("[").unwrap(),
373 )],
374 );
375}
376
377fn assert_documents_eq(
378 documents: &[Document],
379 expected_contents_and_start_offsets: &[(String, usize)],
380) {
381 assert_eq!(
382 documents
383 .iter()
384 .map(|document| (document.content.clone(), document.range.start))
385 .collect::<Vec<_>>(),
386 expected_contents_and_start_offsets
387 );
388}
389
// Checks that `CodeContextRetriever::parse_file`, using the language from
// `js_lang`, produces the expected documents for functions, classes, a
// method and an interface, compared as (content, start offset) pairs.
390#[gpui::test]
391async fn test_code_context_retrieval_javascript() {
392 let language = js_lang();
393 let mut retriever = CodeContextRetriever::new();
394
395 let text = "
396 /* globals importScripts, backend */
397 function _authorize() {}
398
399 /**
400 * Sometimes the frontend build is way faster than backend.
401 */
402 export async function authorizeBank() {
403 _authorize(pushModal, upgradingAccountId, {});
404 }
405
406 export class SettingsPage {
407 /* This is a test setting */
408 constructor(page) {
409 this.page = page;
410 }
411 }
412
413 /* This is a test comment */
414 class TestClass {}
415
416 /* Schema for editor_events in Clickhouse. */
417 export interface ClickhouseEditorEvent {
418 installation_id: string
419 operation: string
420 }
421 "
422 .unindent();
423
424 let documents = retriever.parse_file(&text, language.clone()).unwrap();
425
 // Expected documents, with hard-coded start offsets into `text`. The
 // constructor is expected both inside the `SettingsPage` document and as a
 // document of its own.
426 assert_documents_eq(
427 &documents,
428 &[
429 (
430 "
431 /* globals importScripts, backend */
432 function _authorize() {}"
433 .unindent(),
434 37,
435 ),
436 (
437 "
438 /**
439 * Sometimes the frontend build is way faster than backend.
440 */
441 export async function authorizeBank() {
442 _authorize(pushModal, upgradingAccountId, {});
443 }"
444 .unindent(),
445 131,
446 ),
447 (
448 "
449 export class SettingsPage {
450 /* This is a test setting */
451 constructor(page) {
452 this.page = page;
453 }
454 }"
455 .unindent(),
456 225,
457 ),
458 (
459 "
460 /* This is a test setting */
461 constructor(page) {
462 this.page = page;
463 }"
464 .unindent(),
465 290,
466 ),
467 (
468 "
469 /* This is a test comment */
470 class TestClass {}"
471 .unindent(),
472 374,
473 ),
474 (
475 "
476 /* Schema for editor_events in Clickhouse. */
477 export interface ClickhouseEditorEvent {
478 installation_id: string
479 operation: string
480 }"
481 .unindent(),
482 440,
483 ),
484 ],
485 )
486}
487
// Checks that `CodeContextRetriever::parse_file` produces the expected
// documents for a Lua function that contains a nested function definition.
488#[gpui::test]
489async fn test_code_context_retrieval_lua() {
490 let language = lua_lang();
491 let mut retriever = CodeContextRetriever::new();
492
493 let text = r#"
494 -- Creates a new class
495 -- @param baseclass The Baseclass of this class, or nil.
496 -- @return A new class reference.
497 function classes.class(baseclass)
498 -- Create the class definition and metatable.
499 local classdef = {}
500 -- Find the super class, either Object or user-defined.
501 baseclass = baseclass or classes.Object
502 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
503 setmetatable(classdef, { __index = baseclass })
504 -- All class instances have a reference to the class object.
505 classdef.class = classdef
506 --- Recursivly allocates the inheritance tree of the instance.
507 -- @param mastertable The 'root' of the inheritance tree.
508 -- @return Returns the instance with the allocated inheritance tree.
509 function classdef.alloc(mastertable)
510 -- All class instances have a reference to a superclass object.
511 local instance = { super = baseclass.alloc(mastertable) }
512 -- Any functions this instance does not know of will 'look up' to the superclass definition.
513 setmetatable(instance, { __index = classdef, __newindex = mastertable })
514 return instance
515 end
516 end
517 "#.unindent();
518
519 let documents = retriever.parse_file(&text, language.clone()).unwrap();
520
 // Two documents are expected: the outer function, in which the nested
 // function's body appears as `--[ ... ]--` lines, and the nested function
 // itself with its full body. Start offsets are hard-coded.
521 assert_documents_eq(
522 &documents,
523 &[
524 (r#"
525 -- Creates a new class
526 -- @param baseclass The Baseclass of this class, or nil.
527 -- @return A new class reference.
528 function classes.class(baseclass)
529 -- Create the class definition and metatable.
530 local classdef = {}
531 -- Find the super class, either Object or user-defined.
532 baseclass = baseclass or classes.Object
533 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
534 setmetatable(classdef, { __index = baseclass })
535 -- All class instances have a reference to the class object.
536 classdef.class = classdef
537 --- Recursivly allocates the inheritance tree of the instance.
538 -- @param mastertable The 'root' of the inheritance tree.
539 -- @return Returns the instance with the allocated inheritance tree.
540 function classdef.alloc(mastertable)
541 --[ ... ]--
542 --[ ... ]--
543 end
544 end"#.unindent(),
545 114),
546 (r#"
547 --- Recursivly allocates the inheritance tree of the instance.
548 -- @param mastertable The 'root' of the inheritance tree.
549 -- @return Returns the instance with the allocated inheritance tree.
550 function classdef.alloc(mastertable)
551 -- All class instances have a reference to a superclass object.
552 local instance = { super = baseclass.alloc(mastertable) }
553 -- Any functions this instance does not know of will 'look up' to the superclass definition.
554 setmetatable(instance, { __index = classdef, __newindex = mastertable })
555 return instance
556 end"#.unindent(), 809),
557 ]
558 );
559}
560
// Checks that `CodeContextRetriever::parse_file` produces the expected
// documents for an Elixir module that defines a single function.
561#[gpui::test]
562async fn test_code_context_retrieval_elixir() {
563 let language = elixir_lang();
564 let mut retriever = CodeContextRetriever::new();
565
566 let text = r#"
567 defmodule File.Stream do
568 @moduledoc """
569 Defines a `File.Stream` struct returned by `File.stream!/3`.
570
571 The following fields are public:
572
573 * `path` - the file path
574 * `modes` - the file modes
575 * `raw` - a boolean indicating if bin functions should be used
576 * `line_or_bytes` - if reading should read lines or a given number of bytes
577 * `node` - the node the file belongs to
578
579 """
580
581 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
582
583 @type t :: %__MODULE__{}
584
585 @doc false
586 def __build__(path, modes, line_or_bytes) do
587 raw = :lists.keyfind(:encoding, 1, modes) == false
588
589 modes =
590 case raw do
591 true ->
592 case :lists.keyfind(:read_ahead, 1, modes) do
593 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
594 {:read_ahead, _} -> [:raw | modes]
595 false -> [:raw, :read_ahead | modes]
596 end
597
598 false ->
599 modes
600 end
601
602 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
603
604 end"#
605 .unindent();
606
607 let documents = retriever.parse_file(&text, language.clone()).unwrap();
608
 // Two documents are expected: the whole module starting at offset 0, and
 // the `__build__` function together with its preceding `@doc false`.
609 assert_documents_eq(
610 &documents,
611 &[(
612 r#"
613 defmodule File.Stream do
614 @moduledoc """
615 Defines a `File.Stream` struct returned by `File.stream!/3`.
616
617 The following fields are public:
618
619 * `path` - the file path
620 * `modes` - the file modes
621 * `raw` - a boolean indicating if bin functions should be used
622 * `line_or_bytes` - if reading should read lines or a given number of bytes
623 * `node` - the node the file belongs to
624
625 """
626
627 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
628
629 @type t :: %__MODULE__{}
630
631 @doc false
632 def __build__(path, modes, line_or_bytes) do
633 raw = :lists.keyfind(:encoding, 1, modes) == false
634
635 modes =
636 case raw do
637 true ->
638 case :lists.keyfind(:read_ahead, 1, modes) do
639 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
640 {:read_ahead, _} -> [:raw | modes]
641 false -> [:raw, :read_ahead | modes]
642 end
643
644 false ->
645 modes
646 end
647
648 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
649
650 end"#
651 .unindent(),
652 0,
653 ),(r#"
654 @doc false
655 def __build__(path, modes, line_or_bytes) do
656 raw = :lists.keyfind(:encoding, 1, modes) == false
657
658 modes =
659 case raw do
660 true ->
661 case :lists.keyfind(:read_ahead, 1, modes) do
662 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
663 {:read_ahead, _} -> [:raw | modes]
664 false -> [:raw, :read_ahead | modes]
665 end
666
667 false ->
668 modes
669 end
670
671 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
672
673 end"#.unindent(), 574)],
674 );
675}
676
// Checks that `CodeContextRetriever::parse_file` produces the expected
// documents for C++ source containing a function, a class, an enum, a
// struct declaration and a templated class with a templated constructor.
677#[gpui::test]
678async fn test_code_context_retrieval_cpp() {
679 let language = cpp_lang();
680 let mut retriever = CodeContextRetriever::new();
681
682 let text = "
683 /**
684 * @brief Main function
685 * @returns 0 on exit
686 */
687 int main() { return 0; }
688
689 /**
690 * This is a test comment
691 */
692 class MyClass { // The class
693 public: // Access specifier
694 int myNum; // Attribute (int variable)
695 string myString; // Attribute (string variable)
696 };
697
698 // This is a test comment
699 enum Color { red, green, blue };
700
701 /** This is a preceding block comment
702 * This is the second line
703 */
704 struct { // Structure declaration
705 int myNum; // Member (int variable)
706 string myString; // Member (string variable)
707 } myStructure;
708
709 /**
710 * @brief Matrix class.
711 */
712 template <typename T,
713 typename = typename std::enable_if<
714 std::is_integral<T>::value || std::is_floating_point<T>::value,
715 bool>::type>
716 class Matrix2 {
717 std::vector<std::vector<T>> _mat;
718
719 public:
720 /**
721 * @brief Constructor
722 * @tparam Integer ensuring integers are being evaluated and not other
723 * data types.
724 * @param size denoting the size of Matrix as size x size
725 */
726 template <typename Integer,
727 typename = typename std::enable_if<std::is_integral<Integer>::value,
728 Integer>::type>
729 explicit Matrix(const Integer size) {
730 for (size_t i = 0; i < size; ++i) {
731 _mat.emplace_back(std::vector<T>(size, 0));
732 }
733 }
734 }"
735 .unindent();
736
737 let documents = retriever.parse_file(&text, language.clone()).unwrap();
738
 // Six documents are expected, with hard-coded start offsets into `text`.
 // The last one is the constructor on its own, without its doc comment or
 // `template <...>` header.
739 assert_documents_eq(
740 &documents,
741 &[
742 (
743 "
744 /**
745 * @brief Main function
746 * @returns 0 on exit
747 */
748 int main() { return 0; }"
749 .unindent(),
750 54,
751 ),
752 (
753 "
754 /**
755 * This is a test comment
756 */
757 class MyClass { // The class
758 public: // Access specifier
759 int myNum; // Attribute (int variable)
760 string myString; // Attribute (string variable)
761 }"
762 .unindent(),
763 112,
764 ),
765 (
766 "
767 // This is a test comment
768 enum Color { red, green, blue }"
769 .unindent(),
770 322,
771 ),
772 (
773 "
774 /** This is a preceding block comment
775 * This is the second line
776 */
777 struct { // Structure declaration
778 int myNum; // Member (int variable)
779 string myString; // Member (string variable)
780 } myStructure;"
781 .unindent(),
782 425,
783 ),
784 (
785 "
786 /**
787 * @brief Matrix class.
788 */
789 template <typename T,
790 typename = typename std::enable_if<
791 std::is_integral<T>::value || std::is_floating_point<T>::value,
792 bool>::type>
793 class Matrix2 {
794 std::vector<std::vector<T>> _mat;
795
796 public:
797 /**
798 * @brief Constructor
799 * @tparam Integer ensuring integers are being evaluated and not other
800 * data types.
801 * @param size denoting the size of Matrix as size x size
802 */
803 template <typename Integer,
804 typename = typename std::enable_if<std::is_integral<Integer>::value,
805 Integer>::type>
806 explicit Matrix(const Integer size) {
807 for (size_t i = 0; i < size; ++i) {
808 _mat.emplace_back(std::vector<T>(size, 0));
809 }
810 }
811 }"
812 .unindent(),
813 612,
814 ),
815 (
816 "
817 explicit Matrix(const Integer size) {
818 for (size_t i = 0; i < size; ++i) {
819 _mat.emplace_back(std::vector<T>(size, 0));
820 }
821 }"
822 .unindent(),
823 1226,
824 ),
825 ],
826 );
827}
828
// Checks that `CodeContextRetriever::parse_file` produces the expected
// documents for Ruby source containing a module, a class and a singleton
// method, compared as (content, start offset) pairs.
829#[gpui::test]
830async fn test_code_context_retrieval_ruby() {
831 let language = ruby_lang();
832 let mut retriever = CodeContextRetriever::new();
833
834 let text = r#"
835 # This concern is inspired by "sudo mode" on GitHub. It
836 # is a way to re-authenticate a user before allowing them
837 # to see or perform an action.
838 #
839 # Add `before_action :require_challenge!` to actions you
840 # want to protect.
841 #
842 # The user will be shown a page to enter the challenge (which
843 # is either the password, or just the username when no
844 # password exists). Upon passing, there is a grace period
845 # during which no challenge will be asked from the user.
846 #
847 # Accessing challenge-protected resources during the grace
848 # period will refresh the grace period.
849 module ChallengableConcern
850 extend ActiveSupport::Concern
851
852 CHALLENGE_TIMEOUT = 1.hour.freeze
853
854 def require_challenge!
855 return if skip_challenge?
856
857 if challenge_passed_recently?
858 session[:challenge_passed_at] = Time.now.utc
859 return
860 end
861
862 @challenge = Form::Challenge.new(return_to: request.url)
863
864 if params.key?(:form_challenge)
865 if challenge_passed?
866 session[:challenge_passed_at] = Time.now.utc
867 else
868 flash.now[:alert] = I18n.t('challenge.invalid_password')
869 render_challenge
870 end
871 else
872 render_challenge
873 end
874 end
875
876 def challenge_passed?
877 current_user.valid_password?(challenge_params[:current_password])
878 end
879 end
880
881 class Animal
882 include Comparable
883
884 attr_reader :legs
885
886 def initialize(name, legs)
887 @name, @legs = name, legs
888 end
889
890 def <=>(other)
891 legs <=> other.legs
892 end
893 end
894
895 # Singleton method for car object
896 def car.wheels
897 puts "There are four wheels"
898 end"#
899 .unindent();
900
901 let documents = retriever.parse_file(&text, language.clone()).unwrap();
902
 // Expected documents, with hard-coded start offsets. In the module and
 // class documents every method body appears as `# ...`; each method is
 // also expected as its own document with the full body.
903 assert_documents_eq(
904 &documents,
905 &[
906 (
907 r#"
908 # This concern is inspired by "sudo mode" on GitHub. It
909 # is a way to re-authenticate a user before allowing them
910 # to see or perform an action.
911 #
912 # Add `before_action :require_challenge!` to actions you
913 # want to protect.
914 #
915 # The user will be shown a page to enter the challenge (which
916 # is either the password, or just the username when no
917 # password exists). Upon passing, there is a grace period
918 # during which no challenge will be asked from the user.
919 #
920 # Accessing challenge-protected resources during the grace
921 # period will refresh the grace period.
922 module ChallengableConcern
923 extend ActiveSupport::Concern
924
925 CHALLENGE_TIMEOUT = 1.hour.freeze
926
927 def require_challenge!
928 # ...
929 end
930
931 def challenge_passed?
932 # ...
933 end
934 end"#
935 .unindent(),
936 558,
937 ),
938 (
939 r#"
940 def require_challenge!
941 return if skip_challenge?
942
943 if challenge_passed_recently?
944 session[:challenge_passed_at] = Time.now.utc
945 return
946 end
947
948 @challenge = Form::Challenge.new(return_to: request.url)
949
950 if params.key?(:form_challenge)
951 if challenge_passed?
952 session[:challenge_passed_at] = Time.now.utc
953 else
954 flash.now[:alert] = I18n.t('challenge.invalid_password')
955 render_challenge
956 end
957 else
958 render_challenge
959 end
960 end"#
961 .unindent(),
962 663,
963 ),
964 (
965 r#"
966 def challenge_passed?
967 current_user.valid_password?(challenge_params[:current_password])
968 end"#
969 .unindent(),
970 1254,
971 ),
972 (
973 r#"
974 class Animal
975 include Comparable
976
977 attr_reader :legs
978
979 def initialize(name, legs)
980 # ...
981 end
982
983 def <=>(other)
984 # ...
985 end
986 end"#
987 .unindent(),
988 1363,
989 ),
990 (
991 r#"
992 def initialize(name, legs)
993 @name, @legs = name, legs
994 end"#
995 .unindent(),
996 1427,
997 ),
998 (
999 r#"
1000 def <=>(other)
1001 legs <=> other.legs
1002 end"#
1003 .unindent(),
1004 1501,
1005 ),
1006 (
1007 r#"
1008 # Singleton method for car object
1009 def car.wheels
1010 puts "There are four wheels"
1011 end"#
1012 .unindent(),
1013 1591,
1014 ),
1015 ],
1016 );
1017}
1018
// Checks that `CodeContextRetriever::parse_file` produces the expected
// documents for PHP source containing a function, a trait with two
// methods, an interface and an enum.
1019#[gpui::test]
1020async fn test_code_context_retrieval_php() {
1021 let language = php_lang();
1022 let mut retriever = CodeContextRetriever::new();
1023
1024 let text = r#"
1025 <?php
1026
1027 namespace LevelUp\Experience\Concerns;
1028
1029 /*
1030 This is a multiple-lines comment block
1031 that spans over multiple
1032 lines
1033 */
1034 function functionName() {
1035 echo "Hello world!";
1036 }
1037
1038 trait HasAchievements
1039 {
1040 /**
1041 * @throws \Exception
1042 */
1043 public function grantAchievement(Achievement $achievement, $progress = null): void
1044 {
1045 if ($progress > 100) {
1046 throw new Exception(message: 'Progress cannot be greater than 100');
1047 }
1048
1049 if ($this->achievements()->find($achievement->id)) {
1050 throw new Exception(message: 'User already has this Achievement');
1051 }
1052
1053 $this->achievements()->attach($achievement, [
1054 'progress' => $progress ?? null,
1055 ]);
1056
1057 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1058 }
1059
1060 public function achievements(): BelongsToMany
1061 {
1062 return $this->belongsToMany(related: Achievement::class)
1063 ->withPivot(columns: 'progress')
1064 ->where('is_secret', false)
1065 ->using(AchievementUser::class);
1066 }
1067 }
1068
1069 interface Multiplier
1070 {
1071 public function qualifies(array $data): bool;
1072
1073 public function setMultiplier(): int;
1074 }
1075
1076 enum AuditType: string
1077 {
1078 case Add = 'add';
1079 case Remove = 'remove';
1080 case Reset = 'reset';
1081 case LevelUp = 'level_up';
1082 }
1083
1084 ?>"#
1085 .unindent();
1086
1087 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1088
 // Expected documents, with hard-coded start offsets. In the trait document
 // the method bodies appear as `{/* ... */}`; both methods are also expected
 // as standalone documents with their full bodies.
1089 assert_documents_eq(
1090 &documents,
1091 &[
1092 (
1093 r#"
1094 /*
1095 This is a multiple-lines comment block
1096 that spans over multiple
1097 lines
1098 */
1099 function functionName() {
1100 echo "Hello world!";
1101 }"#
1102 .unindent(),
1103 123,
1104 ),
1105 (
1106 r#"
1107 trait HasAchievements
1108 {
1109 /**
1110 * @throws \Exception
1111 */
1112 public function grantAchievement(Achievement $achievement, $progress = null): void
1113 {/* ... */}
1114
1115 public function achievements(): BelongsToMany
1116 {/* ... */}
1117 }"#
1118 .unindent(),
1119 177,
1120 ),
1121 (r#"
1122 /**
1123 * @throws \Exception
1124 */
1125 public function grantAchievement(Achievement $achievement, $progress = null): void
1126 {
1127 if ($progress > 100) {
1128 throw new Exception(message: 'Progress cannot be greater than 100');
1129 }
1130
1131 if ($this->achievements()->find($achievement->id)) {
1132 throw new Exception(message: 'User already has this Achievement');
1133 }
1134
1135 $this->achievements()->attach($achievement, [
1136 'progress' => $progress ?? null,
1137 ]);
1138
1139 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1140 }"#.unindent(), 245),
1141 (r#"
1142 public function achievements(): BelongsToMany
1143 {
1144 return $this->belongsToMany(related: Achievement::class)
1145 ->withPivot(columns: 'progress')
1146 ->where('is_secret', false)
1147 ->using(AchievementUser::class);
1148 }"#.unindent(), 902),
1149 (r#"
1150 interface Multiplier
1151 {
1152 public function qualifies(array $data): bool;
1153
1154 public function setMultiplier(): int;
1155 }"#.unindent(),
1156 1146),
1157 (r#"
1158 enum AuditType: string
1159 {
1160 case Add = 'add';
1161 case Remove = 'remove';
1162 case Reset = 'reset';
1163 case LevelUp = 'level_up';
1164 }"#.unindent(), 1265)
1165 ],
1166 );
1167}
1168
1169#[gpui::test]
1170fn test_dot_product(mut rng: StdRng) {
1171 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1172 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1173
1174 for _ in 0..100 {
1175 let size = 1536;
1176 let mut a = vec![0.; size];
1177 let mut b = vec![0.; size];
1178 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1179 *a = rng.gen();
1180 *b = rng.gen();
1181 }
1182
1183 assert_eq!(
1184 round_to_decimals(dot(&a, &b), 1),
1185 round_to_decimals(reference_dot(&a, &b), 1)
1186 );
1187 }
1188
1189 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1190 let factor = (10.0 as f32).powi(decimal_places);
1191 (n * factor).round() / factor
1192 }
1193
1194 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1195 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1196 }
1197}
1198
/// Test double for an embedding provider that keeps a running total of how
/// many spans it has been asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch` so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the current value of the span counter.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
1209
1210#[async_trait]
1211impl EmbeddingProvider for FakeEmbeddingProvider {
1212 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1213 self.embedding_count
1214 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1215 Ok(spans
1216 .iter()
1217 .map(|span| {
1218 let mut result = vec![1.0; 26];
1219 for letter in span.chars() {
1220 let letter = letter.to_ascii_lowercase();
1221 if letter as u32 >= 'a' as u32 {
1222 let ix = (letter as u32) - ('a' as u32);
1223 if ix < 26 {
1224 result[ix as usize] += 1.0;
1225 }
1226 }
1227 }
1228
1229 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1230 for x in &mut result {
1231 *x /= norm;
1232 }
1233
1234 result
1235 })
1236 .collect())
1237 }
1238}
1239
// Builds the test language named "Javascript" (path suffix `js`). Note that
// it is backed by the TSX tree-sitter grammar
// (`tree_sitter_typescript::language_tsx`). Its embedding query matches
// function, class, interface and enum declarations (bare or wrapped in an
// export statement) plus method definitions, each optionally preceded by
// comments captured as `@context`. Panics if the query is rejected.
1240fn js_lang() -> Arc<Language> {
1241 Arc::new(
1242 Language::new(
1243 LanguageConfig {
1244 name: "Javascript".into(),
1245 path_suffixes: vec!["js".into()],
1246 ..Default::default()
1247 },
1248 Some(tree_sitter_typescript::language_tsx()),
1249 )
1250 .with_embedding_query(
1251 &r#"
1252
1253 (
1254 (comment)* @context
1255 .
1256 [
1257 (export_statement
1258 (function_declaration
1259 "async"? @name
1260 "function" @name
1261 name: (_) @name))
1262 (function_declaration
1263 "async"? @name
1264 "function" @name
1265 name: (_) @name)
1266 ] @item
1267 )
1268
1269 (
1270 (comment)* @context
1271 .
1272 [
1273 (export_statement
1274 (class_declaration
1275 "class" @name
1276 name: (_) @name))
1277 (class_declaration
1278 "class" @name
1279 name: (_) @name)
1280 ] @item
1281 )
1282
1283 (
1284 (comment)* @context
1285 .
1286 [
1287 (export_statement
1288 (interface_declaration
1289 "interface" @name
1290 name: (_) @name))
1291 (interface_declaration
1292 "interface" @name
1293 name: (_) @name)
1294 ] @item
1295 )
1296
1297 (
1298 (comment)* @context
1299 .
1300 [
1301 (export_statement
1302 (enum_declaration
1303 "enum" @name
1304 name: (_) @name))
1305 (enum_declaration
1306 "enum" @name
1307 name: (_) @name)
1308 ] @item
1309 )
1310
1311 (
1312 (comment)* @context
1313 .
1314 (method_definition
1315 [
1316 "get"
1317 "set"
1318 "async"
1319 "*"
1320 "static"
1321 ]* @name
1322 name: (_) @name) @item
1323 )
1324
1325 "#
1326 .unindent(),
1327 )
1328 .unwrap(),
1329 )
1330}
1331
// Builds the Rust test language (path suffix `rs`) on the tree-sitter Rust
// grammar, with ` /* ... */ ` as its collapsed placeholder. The embedding
// query matches struct, enum, impl, trait, function and macro definitions,
// each optionally preceded by line comments or attributes captured as
// `@context`; function bodies are tagged `@collapse` with their braces
// tagged `@keep`. Panics if the query is rejected.
1332fn rust_lang() -> Arc<Language> {
1333 Arc::new(
1334 Language::new(
1335 LanguageConfig {
1336 name: "Rust".into(),
1337 path_suffixes: vec!["rs".into()],
1338 collapsed_placeholder: " /* ... */ ".to_string(),
1339 ..Default::default()
1340 },
1341 Some(tree_sitter_rust::language()),
1342 )
1343 .with_embedding_query(
1344 r#"
1345 (
1346 [(line_comment) (attribute_item)]* @context
1347 .
1348 [
1349 (struct_item
1350 name: (_) @name)
1351
1352 (enum_item
1353 name: (_) @name)
1354
1355 (impl_item
1356 trait: (_)? @name
1357 "for"? @name
1358 type: (_) @name)
1359
1360 (trait_item
1361 name: (_) @name)
1362
1363 (function_item
1364 name: (_) @name
1365 body: (block
1366 "{" @keep
1367 "}" @keep) @collapse)
1368
1369 (macro_definition
1370 name: (_) @name)
1371 ] @item
1372 )
1373 "#,
1374 )
1375 .unwrap(),
1376 )
1377}
1378
// Builds the JSON test language (path suffix `json`) on the tree-sitter
// JSON grammar. The embedding query tags the whole document as `@item`,
// and tags arrays and the string values of pairs as `@collapse`, keeping
// the brackets, an optional leading object, and the quotes via `@keep`.
// Panics if the query is rejected.
1379fn json_lang() -> Arc<Language> {
1380 Arc::new(
1381 Language::new(
1382 LanguageConfig {
1383 name: "JSON".into(),
1384 path_suffixes: vec!["json".into()],
1385 ..Default::default()
1386 },
1387 Some(tree_sitter_json::language()),
1388 )
1389 .with_embedding_query(
1390 r#"
1391 (document) @item
1392
1393 (array
1394 "[" @keep
1395 .
1396 (object)? @keep
1397 "]" @keep) @collapse
1398
1399 (pair value: (string
1400 "\"" @keep
1401 "\"" @keep) @collapse)
1402 "#,
1403 )
1404 .unwrap(),
1405 )
1406}
1407
1408fn toml_lang() -> Arc<Language> {
1409 Arc::new(Language::new(
1410 LanguageConfig {
1411 name: "TOML".into(),
1412 path_suffixes: vec!["toml".into()],
1413 ..Default::default()
1414 },
1415 Some(tree_sitter_toml::language()),
1416 ))
1417}
1418
1419fn cpp_lang() -> Arc<Language> {
1420 Arc::new(
1421 Language::new(
1422 LanguageConfig {
1423 name: "CPP".into(),
1424 path_suffixes: vec!["cpp".into()],
1425 ..Default::default()
1426 },
1427 Some(tree_sitter_cpp::language()),
1428 )
1429 .with_embedding_query(
1430 r#"
1431 (
1432 (comment)* @context
1433 .
1434 (function_definition
1435 (type_qualifier)? @name
1436 type: (_)? @name
1437 declarator: [
1438 (function_declarator
1439 declarator: (_) @name)
1440 (pointer_declarator
1441 "*" @name
1442 declarator: (function_declarator
1443 declarator: (_) @name))
1444 (pointer_declarator
1445 "*" @name
1446 declarator: (pointer_declarator
1447 "*" @name
1448 declarator: (function_declarator
1449 declarator: (_) @name)))
1450 (reference_declarator
1451 ["&" "&&"] @name
1452 (function_declarator
1453 declarator: (_) @name))
1454 ]
1455 (type_qualifier)? @name) @item
1456 )
1457
1458 (
1459 (comment)* @context
1460 .
1461 (template_declaration
1462 (class_specifier
1463 "class" @name
1464 name: (_) @name)
1465 ) @item
1466 )
1467
1468 (
1469 (comment)* @context
1470 .
1471 (class_specifier
1472 "class" @name
1473 name: (_) @name) @item
1474 )
1475
1476 (
1477 (comment)* @context
1478 .
1479 (enum_specifier
1480 "enum" @name
1481 name: (_) @name) @item
1482 )
1483
1484 (
1485 (comment)* @context
1486 .
1487 (declaration
1488 type: (struct_specifier
1489 "struct" @name)
1490 declarator: (_) @name) @item
1491 )
1492
1493 "#,
1494 )
1495 .unwrap(),
1496 )
1497}
1498
1499fn lua_lang() -> Arc<Language> {
1500 Arc::new(
1501 Language::new(
1502 LanguageConfig {
1503 name: "Lua".into(),
1504 path_suffixes: vec!["lua".into()],
1505 collapsed_placeholder: "--[ ... ]--".to_string(),
1506 ..Default::default()
1507 },
1508 Some(tree_sitter_lua::language()),
1509 )
1510 .with_embedding_query(
1511 r#"
1512 (
1513 (comment)* @context
1514 .
1515 (function_declaration
1516 "function" @name
1517 name: (_) @name
1518 (comment)* @collapse
1519 body: (block) @collapse
1520 ) @item
1521 )
1522 "#,
1523 )
1524 .unwrap(),
1525 )
1526}
1527
1528fn php_lang() -> Arc<Language> {
1529 Arc::new(
1530 Language::new(
1531 LanguageConfig {
1532 name: "PHP".into(),
1533 path_suffixes: vec!["php".into()],
1534 collapsed_placeholder: "/* ... */".into(),
1535 ..Default::default()
1536 },
1537 Some(tree_sitter_php::language()),
1538 )
1539 .with_embedding_query(
1540 r#"
1541 (
1542 (comment)* @context
1543 .
1544 [
1545 (function_definition
1546 "function" @name
1547 name: (_) @name
1548 body: (_
1549 "{" @keep
1550 "}" @keep) @collapse
1551 )
1552
1553 (trait_declaration
1554 "trait" @name
1555 name: (_) @name)
1556
1557 (method_declaration
1558 "function" @name
1559 name: (_) @name
1560 body: (_
1561 "{" @keep
1562 "}" @keep) @collapse
1563 )
1564
1565 (interface_declaration
1566 "interface" @name
1567 name: (_) @name
1568 )
1569
1570 (enum_declaration
1571 "enum" @name
1572 name: (_) @name
1573 )
1574
1575 ] @item
1576 )
1577 "#,
1578 )
1579 .unwrap(),
1580 )
1581}
1582
1583fn ruby_lang() -> Arc<Language> {
1584 Arc::new(
1585 Language::new(
1586 LanguageConfig {
1587 name: "Ruby".into(),
1588 path_suffixes: vec!["rb".into()],
1589 collapsed_placeholder: "# ...".to_string(),
1590 ..Default::default()
1591 },
1592 Some(tree_sitter_ruby::language()),
1593 )
1594 .with_embedding_query(
1595 r#"
1596 (
1597 (comment)* @context
1598 .
1599 [
1600 (module
1601 "module" @name
1602 name: (_) @name)
1603 (method
1604 "def" @name
1605 name: (_) @name
1606 body: (body_statement) @collapse)
1607 (class
1608 "class" @name
1609 name: (_) @name)
1610 (singleton_method
1611 "def" @name
1612 object: (_) @name
1613 "." @name
1614 name: (_) @name
1615 body: (body_statement) @collapse)
1616 ] @item
1617 )
1618 "#,
1619 )
1620 .unwrap(),
1621 )
1622}
1623
1624fn elixir_lang() -> Arc<Language> {
1625 Arc::new(
1626 Language::new(
1627 LanguageConfig {
1628 name: "Elixir".into(),
1629 path_suffixes: vec!["rs".into()],
1630 ..Default::default()
1631 },
1632 Some(tree_sitter_elixir::language()),
1633 )
1634 .with_embedding_query(
1635 r#"
1636 (
1637 (unary_operator
1638 operator: "@"
1639 operand: (call
1640 target: (identifier) @unary
1641 (#match? @unary "^(doc)$"))
1642 ) @context
1643 .
1644 (call
1645 target: (identifier) @name
1646 (arguments
1647 [
1648 (identifier) @name
1649 (call
1650 target: (identifier) @name)
1651 (binary_operator
1652 left: (call
1653 target: (identifier) @name)
1654 operator: "when")
1655 ])
1656 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1657 )
1658
1659 (call
1660 target: (identifier) @name
1661 (arguments (alias) @name)
1662 (#match? @name "^(defmodule|defprotocol)$")) @item
1663 "#,
1664 )
1665 .unwrap(),
1666 )
1667}
1668
1669#[gpui::test]
1670fn test_subtract_ranges() {
1671 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1672
1673 assert_eq!(
1674 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1675 vec![1..4, 10..21]
1676 );
1677
1678 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1679}