1use crate::{
2 db::dot,
3 embedding::EmbeddingProvider,
4 parsing::{subtract_ranges, CodeContextRetriever, Document},
5 semantic_index_settings::SemanticIndexSettings,
6 SearchResult, SemanticIndex,
7};
8use anyhow::Result;
9use async_trait::async_trait;
10use gpui::{Task, TestAppContext};
11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
12use pretty_assertions::assert_eq;
13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
14use rand::{rngs::StdRng, Rng};
15use serde_json::json;
16use settings::SettingsStore;
17use std::{
18 path::Path,
19 sync::{
20 atomic::{self, AtomicUsize},
21 Arc,
22 },
23};
24use unindent::Unindent;
25
26#[ctor::ctor]
27fn init_logger() {
28 if std::env::var("RUST_LOG").is_ok() {
29 env_logger::init();
30 }
31}
32
/// End-to-end test of `SemanticIndex` over a fake file system: indexes a
/// three-file project, searches it with and without include/exclude path
/// filters, then rewrites one file and checks that only that file is
/// re-indexed and re-embedded.
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
    // Install a test settings store and register the settings types that the
    // semantic index and the project read.
    cx.update(|cx| {
        cx.set_global(SettingsStore::test(cx));
        settings::register::<SemanticIndexSettings>(cx);
        settings::register::<ProjectSettings>(cx);
    });

    // Fake project: two Rust files and one TOML file under /the-root/src.
    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaaaaaaaaaa!\");
                    }

                    fn zzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbbbbbbbbbbb!\");
                    }
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZZZZZZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    // Only Rust and TOML are registered; `rust_lang` carries an embedding
    // query, `toml_lang` does not.
    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    // The index database file lives inside a temporary directory.
    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    // The fake provider counts every span it embeds; the re-index assertions
    // at the end of this test rely on that counter.
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let store = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

    // NOTE(review): the result of `initialize_project` is discarded; an
    // initialization failure would only surface through `index_project` below.
    let _ = store
        .update(cx, |store, cx| {
            store.initialize_project(project.clone(), cx)
        })
        .await;

    // Initial indexing queues all three files; once the foreground executor
    // is parked, no file should remain outstanding.
    let (file_count, outstanding_file_count) = store
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 3);
    cx.foreground().run_until_parked();
    assert_eq!(*outstanding_file_count.borrow(), 0);

    // Unfiltered search. Arguments after the query are (presumably) the
    // result limit, the include filters and the exclude filters.
    let search_results = store
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    // Each expected entry is (file path, byte offset where the matching span
    // starts); the ranking follows the letter-histogram vectors produced by
    // `FakeEmbeddingProvider`.
    assert_search_results(
        &search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file3.toml").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
        ],
        cx,
    );

    // Test include-files functionality: only paths matching `*.rs` may be
    // returned. (`exclude_files` is used by the exclude search further down.)
    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
    let rust_only_search_results = store
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                include_files,
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &rust_only_search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
        ],
        cx,
    );

    // Excluding `*.rs` leaves only the TOML file.
    let no_rust_search_results = store
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                exclude_files,
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &no_rust_search_results,
        &[(Path::new("src/file3.toml").into(), 0)],
        cx,
    );

    // Rewrite file2.rs so that it contains two items: a function and a struct.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();

    cx.foreground().run_until_parked();

    // Re-indexing should pick up only the changed file...
    let prev_embedding_count = embedding_provider.embedding_count();
    let (file_count, outstanding_file_count) = store
        .update(cx, |store, cx| store.index_project(project.clone(), cx))
        .await
        .unwrap();
    assert_eq!(file_count, 1);

    cx.foreground().run_until_parked();
    assert_eq!(*outstanding_file_count.borrow(), 0);

    // ...and embed exactly its two new items (`fn dddd` and `struct pqpqpqp`).
    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        2
    );
}
206
207#[track_caller]
208fn assert_search_results(
209 actual: &[SearchResult],
210 expected: &[(Arc<Path>, usize)],
211 cx: &TestAppContext,
212) {
213 let actual = actual
214 .iter()
215 .map(|search_result| {
216 search_result.buffer.read_with(cx, |buffer, _cx| {
217 (
218 buffer.file().unwrap().path().clone(),
219 search_result.range.start.to_offset(buffer),
220 )
221 })
222 })
223 .collect::<Vec<_>>();
224 assert_eq!(actual, expected);
225}
226
/// Verifies the Rust embedding query. The test expects:
/// - one document per top-level item and one per method;
/// - preceding comments and attributes attached as context;
/// - method bodies collapsed to `{ /* ... */ }` inside the enclosing `impl`
///   document.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
        /// A doc comment
        /// that spans multiple lines
        #[gpui::test]
        fn a() {
            b
        }

        impl C for D {
        }

        impl E {
            // This is also a preceding comment
            pub fn function_1() -> Option<()> {
                todo!();
            }

            // This is a preceding comment
            fn function_2() -> Result<()> {
                todo!();
            }
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language).unwrap();

    // Each expected entry is (document content, byte offset of the item
    // itself). Leading comments/attributes appear in the content but the
    // offset points at the item, not at its context.
    assert_documents_eq(
        &documents,
        &[
            (
                "
                /// A doc comment
                /// that spans multiple lines
                #[gpui::test]
                fn a() {
                    b
                }"
                .unindent(),
                text.find("fn a").unwrap(),
            ),
            (
                "
                impl C for D {
                }"
                .unindent(),
                text.find("impl C").unwrap(),
            ),
            (
                "
                impl E {
                    // This is also a preceding comment
                    pub fn function_1() -> Option<()> { /* ... */ }

                    // This is a preceding comment
                    fn function_2() -> Result<()> { /* ... */ }
                }"
                .unindent(),
                text.find("impl E").unwrap(),
            ),
            (
                "
                // This is also a preceding comment
                pub fn function_1() -> Option<()> {
                    todo!();
                }"
                .unindent(),
                text.find("pub fn function_1").unwrap(),
            ),
            (
                "
                // This is a preceding comment
                fn function_2() -> Result<()> {
                    todo!();
                }"
                .unindent(),
                text.find("fn function_2").unwrap(),
            ),
        ],
    );
}
313
/// Verifies the JSON embedding query. The whole document becomes a single
/// embedding, string values are emptied to `""`, and arrays keep only their
/// first element when that element is an object.
#[gpui::test]
async fn test_code_context_retrieval_json() {
    let language = json_lang();
    let mut retriever = CodeContextRetriever::new();

    // Object at the top level: arrays of numbers collapse to `[]`.
    let text = r#"
        {
            "array": [1, 2, 3, 4],
            "string": "abcdefg",
            "nested_object": {
                "array_2": [5, 6, 7, 8],
                "string_2": "hijklmnop",
                "boolean": true,
                "none": null
            }
        }
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
                {
                    "array": [],
                    "string": "",
                    "nested_object": {
                        "array_2": [],
                        "string_2": "",
                        "boolean": true,
                        "none": null
                    }
                }"#
            .unindent(),
            text.find("{").unwrap(),
        )],
    );

    // Array of objects at the top level: only the first object is kept.
    let text = r#"
        [
            {
                "name": "somebody",
                "age": 42
            },
            {
                "name": "somebody else",
                "age": 43
            }
        ]
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            [{
                    "name": "",
                    "age": 42
                }]"#
            .unindent(),
            text.find("[").unwrap(),
        )],
    );
}
383
384fn assert_documents_eq(
385 documents: &[Document],
386 expected_contents_and_start_offsets: &[(String, usize)],
387) {
388 assert_eq!(
389 documents
390 .iter()
391 .map(|document| (document.content.clone(), document.range.start))
392 .collect::<Vec<_>>(),
393 expected_contents_and_start_offsets
394 );
395}
396
397#[gpui::test]
398async fn test_code_context_retrieval_javascript() {
399 let language = js_lang();
400 let mut retriever = CodeContextRetriever::new();
401
402 let text = "
403 /* globals importScripts, backend */
404 function _authorize() {}
405
406 /**
407 * Sometimes the frontend build is way faster than backend.
408 */
409 export async function authorizeBank() {
410 _authorize(pushModal, upgradingAccountId, {});
411 }
412
413 export class SettingsPage {
414 /* This is a test setting */
415 constructor(page) {
416 this.page = page;
417 }
418 }
419
420 /* This is a test comment */
421 class TestClass {}
422
423 /* Schema for editor_events in Clickhouse. */
424 export interface ClickhouseEditorEvent {
425 installation_id: string
426 operation: string
427 }
428 "
429 .unindent();
430
431 let documents = retriever.parse_file(&text, language.clone()).unwrap();
432
433 assert_documents_eq(
434 &documents,
435 &[
436 (
437 "
438 /* globals importScripts, backend */
439 function _authorize() {}"
440 .unindent(),
441 37,
442 ),
443 (
444 "
445 /**
446 * Sometimes the frontend build is way faster than backend.
447 */
448 export async function authorizeBank() {
449 _authorize(pushModal, upgradingAccountId, {});
450 }"
451 .unindent(),
452 131,
453 ),
454 (
455 "
456 export class SettingsPage {
457 /* This is a test setting */
458 constructor(page) {
459 this.page = page;
460 }
461 }"
462 .unindent(),
463 225,
464 ),
465 (
466 "
467 /* This is a test setting */
468 constructor(page) {
469 this.page = page;
470 }"
471 .unindent(),
472 290,
473 ),
474 (
475 "
476 /* This is a test comment */
477 class TestClass {}"
478 .unindent(),
479 374,
480 ),
481 (
482 "
483 /* Schema for editor_events in Clickhouse. */
484 export interface ClickhouseEditorEvent {
485 installation_id: string
486 operation: string
487 }"
488 .unindent(),
489 440,
490 ),
491 ],
492 )
493}
494
495#[gpui::test]
496async fn test_code_context_retrieval_lua() {
497 let language = lua_lang();
498 let mut retriever = CodeContextRetriever::new();
499
500 let text = r#"
501 -- Creates a new class
502 -- @param baseclass The Baseclass of this class, or nil.
503 -- @return A new class reference.
504 function classes.class(baseclass)
505 -- Create the class definition and metatable.
506 local classdef = {}
507 -- Find the super class, either Object or user-defined.
508 baseclass = baseclass or classes.Object
509 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
510 setmetatable(classdef, { __index = baseclass })
511 -- All class instances have a reference to the class object.
512 classdef.class = classdef
513 --- Recursivly allocates the inheritance tree of the instance.
514 -- @param mastertable The 'root' of the inheritance tree.
515 -- @return Returns the instance with the allocated inheritance tree.
516 function classdef.alloc(mastertable)
517 -- All class instances have a reference to a superclass object.
518 local instance = { super = baseclass.alloc(mastertable) }
519 -- Any functions this instance does not know of will 'look up' to the superclass definition.
520 setmetatable(instance, { __index = classdef, __newindex = mastertable })
521 return instance
522 end
523 end
524 "#.unindent();
525
526 let documents = retriever.parse_file(&text, language.clone()).unwrap();
527
528 assert_documents_eq(
529 &documents,
530 &[
531 (r#"
532 -- Creates a new class
533 -- @param baseclass The Baseclass of this class, or nil.
534 -- @return A new class reference.
535 function classes.class(baseclass)
536 -- Create the class definition and metatable.
537 local classdef = {}
538 -- Find the super class, either Object or user-defined.
539 baseclass = baseclass or classes.Object
540 -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
541 setmetatable(classdef, { __index = baseclass })
542 -- All class instances have a reference to the class object.
543 classdef.class = classdef
544 --- Recursivly allocates the inheritance tree of the instance.
545 -- @param mastertable The 'root' of the inheritance tree.
546 -- @return Returns the instance with the allocated inheritance tree.
547 function classdef.alloc(mastertable)
548 --[ ... ]--
549 --[ ... ]--
550 end
551 end"#.unindent(),
552 114),
553 (r#"
554 --- Recursivly allocates the inheritance tree of the instance.
555 -- @param mastertable The 'root' of the inheritance tree.
556 -- @return Returns the instance with the allocated inheritance tree.
557 function classdef.alloc(mastertable)
558 -- All class instances have a reference to a superclass object.
559 local instance = { super = baseclass.alloc(mastertable) }
560 -- Any functions this instance does not know of will 'look up' to the superclass definition.
561 setmetatable(instance, { __index = classdef, __newindex = mastertable })
562 return instance
563 end"#.unindent(), 809),
564 ]
565 );
566}
567
568#[gpui::test]
569async fn test_code_context_retrieval_elixir() {
570 let language = elixir_lang();
571 let mut retriever = CodeContextRetriever::new();
572
573 let text = r#"
574 defmodule File.Stream do
575 @moduledoc """
576 Defines a `File.Stream` struct returned by `File.stream!/3`.
577
578 The following fields are public:
579
580 * `path` - the file path
581 * `modes` - the file modes
582 * `raw` - a boolean indicating if bin functions should be used
583 * `line_or_bytes` - if reading should read lines or a given number of bytes
584 * `node` - the node the file belongs to
585
586 """
587
588 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
589
590 @type t :: %__MODULE__{}
591
592 @doc false
593 def __build__(path, modes, line_or_bytes) do
594 raw = :lists.keyfind(:encoding, 1, modes) == false
595
596 modes =
597 case raw do
598 true ->
599 case :lists.keyfind(:read_ahead, 1, modes) do
600 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
601 {:read_ahead, _} -> [:raw | modes]
602 false -> [:raw, :read_ahead | modes]
603 end
604
605 false ->
606 modes
607 end
608
609 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
610
611 end"#
612 .unindent();
613
614 let documents = retriever.parse_file(&text, language.clone()).unwrap();
615
616 assert_documents_eq(
617 &documents,
618 &[(
619 r#"
620 defmodule File.Stream do
621 @moduledoc """
622 Defines a `File.Stream` struct returned by `File.stream!/3`.
623
624 The following fields are public:
625
626 * `path` - the file path
627 * `modes` - the file modes
628 * `raw` - a boolean indicating if bin functions should be used
629 * `line_or_bytes` - if reading should read lines or a given number of bytes
630 * `node` - the node the file belongs to
631
632 """
633
634 defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
635
636 @type t :: %__MODULE__{}
637
638 @doc false
639 def __build__(path, modes, line_or_bytes) do
640 raw = :lists.keyfind(:encoding, 1, modes) == false
641
642 modes =
643 case raw do
644 true ->
645 case :lists.keyfind(:read_ahead, 1, modes) do
646 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
647 {:read_ahead, _} -> [:raw | modes]
648 false -> [:raw, :read_ahead | modes]
649 end
650
651 false ->
652 modes
653 end
654
655 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
656
657 end"#
658 .unindent(),
659 0,
660 ),(r#"
661 @doc false
662 def __build__(path, modes, line_or_bytes) do
663 raw = :lists.keyfind(:encoding, 1, modes) == false
664
665 modes =
666 case raw do
667 true ->
668 case :lists.keyfind(:read_ahead, 1, modes) do
669 {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
670 {:read_ahead, _} -> [:raw | modes]
671 false -> [:raw, :read_ahead | modes]
672 end
673
674 false ->
675 modes
676 end
677
678 %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
679
680 end"#.unindent(), 574)],
681 );
682}
683
684#[gpui::test]
685async fn test_code_context_retrieval_cpp() {
686 let language = cpp_lang();
687 let mut retriever = CodeContextRetriever::new();
688
689 let text = "
690 /**
691 * @brief Main function
692 * @returns 0 on exit
693 */
694 int main() { return 0; }
695
696 /**
697 * This is a test comment
698 */
699 class MyClass { // The class
700 public: // Access specifier
701 int myNum; // Attribute (int variable)
702 string myString; // Attribute (string variable)
703 };
704
705 // This is a test comment
706 enum Color { red, green, blue };
707
708 /** This is a preceding block comment
709 * This is the second line
710 */
711 struct { // Structure declaration
712 int myNum; // Member (int variable)
713 string myString; // Member (string variable)
714 } myStructure;
715
716 /**
717 * @brief Matrix class.
718 */
719 template <typename T,
720 typename = typename std::enable_if<
721 std::is_integral<T>::value || std::is_floating_point<T>::value,
722 bool>::type>
723 class Matrix2 {
724 std::vector<std::vector<T>> _mat;
725
726 public:
727 /**
728 * @brief Constructor
729 * @tparam Integer ensuring integers are being evaluated and not other
730 * data types.
731 * @param size denoting the size of Matrix as size x size
732 */
733 template <typename Integer,
734 typename = typename std::enable_if<std::is_integral<Integer>::value,
735 Integer>::type>
736 explicit Matrix(const Integer size) {
737 for (size_t i = 0; i < size; ++i) {
738 _mat.emplace_back(std::vector<T>(size, 0));
739 }
740 }
741 }"
742 .unindent();
743
744 let documents = retriever.parse_file(&text, language.clone()).unwrap();
745
746 assert_documents_eq(
747 &documents,
748 &[
749 (
750 "
751 /**
752 * @brief Main function
753 * @returns 0 on exit
754 */
755 int main() { return 0; }"
756 .unindent(),
757 54,
758 ),
759 (
760 "
761 /**
762 * This is a test comment
763 */
764 class MyClass { // The class
765 public: // Access specifier
766 int myNum; // Attribute (int variable)
767 string myString; // Attribute (string variable)
768 }"
769 .unindent(),
770 112,
771 ),
772 (
773 "
774 // This is a test comment
775 enum Color { red, green, blue }"
776 .unindent(),
777 322,
778 ),
779 (
780 "
781 /** This is a preceding block comment
782 * This is the second line
783 */
784 struct { // Structure declaration
785 int myNum; // Member (int variable)
786 string myString; // Member (string variable)
787 } myStructure;"
788 .unindent(),
789 425,
790 ),
791 (
792 "
793 /**
794 * @brief Matrix class.
795 */
796 template <typename T,
797 typename = typename std::enable_if<
798 std::is_integral<T>::value || std::is_floating_point<T>::value,
799 bool>::type>
800 class Matrix2 {
801 std::vector<std::vector<T>> _mat;
802
803 public:
804 /**
805 * @brief Constructor
806 * @tparam Integer ensuring integers are being evaluated and not other
807 * data types.
808 * @param size denoting the size of Matrix as size x size
809 */
810 template <typename Integer,
811 typename = typename std::enable_if<std::is_integral<Integer>::value,
812 Integer>::type>
813 explicit Matrix(const Integer size) {
814 for (size_t i = 0; i < size; ++i) {
815 _mat.emplace_back(std::vector<T>(size, 0));
816 }
817 }
818 }"
819 .unindent(),
820 612,
821 ),
822 (
823 "
824 explicit Matrix(const Integer size) {
825 for (size_t i = 0; i < size; ++i) {
826 _mat.emplace_back(std::vector<T>(size, 0));
827 }
828 }"
829 .unindent(),
830 1226,
831 ),
832 ],
833 );
834}
835
836#[gpui::test]
837async fn test_code_context_retrieval_ruby() {
838 let language = ruby_lang();
839 let mut retriever = CodeContextRetriever::new();
840
841 let text = r#"
842 # This concern is inspired by "sudo mode" on GitHub. It
843 # is a way to re-authenticate a user before allowing them
844 # to see or perform an action.
845 #
846 # Add `before_action :require_challenge!` to actions you
847 # want to protect.
848 #
849 # The user will be shown a page to enter the challenge (which
850 # is either the password, or just the username when no
851 # password exists). Upon passing, there is a grace period
852 # during which no challenge will be asked from the user.
853 #
854 # Accessing challenge-protected resources during the grace
855 # period will refresh the grace period.
856 module ChallengableConcern
857 extend ActiveSupport::Concern
858
859 CHALLENGE_TIMEOUT = 1.hour.freeze
860
861 def require_challenge!
862 return if skip_challenge?
863
864 if challenge_passed_recently?
865 session[:challenge_passed_at] = Time.now.utc
866 return
867 end
868
869 @challenge = Form::Challenge.new(return_to: request.url)
870
871 if params.key?(:form_challenge)
872 if challenge_passed?
873 session[:challenge_passed_at] = Time.now.utc
874 else
875 flash.now[:alert] = I18n.t('challenge.invalid_password')
876 render_challenge
877 end
878 else
879 render_challenge
880 end
881 end
882
883 def challenge_passed?
884 current_user.valid_password?(challenge_params[:current_password])
885 end
886 end
887
888 class Animal
889 include Comparable
890
891 attr_reader :legs
892
893 def initialize(name, legs)
894 @name, @legs = name, legs
895 end
896
897 def <=>(other)
898 legs <=> other.legs
899 end
900 end
901
902 # Singleton method for car object
903 def car.wheels
904 puts "There are four wheels"
905 end"#
906 .unindent();
907
908 let documents = retriever.parse_file(&text, language.clone()).unwrap();
909
910 assert_documents_eq(
911 &documents,
912 &[
913 (
914 r#"
915 # This concern is inspired by "sudo mode" on GitHub. It
916 # is a way to re-authenticate a user before allowing them
917 # to see or perform an action.
918 #
919 # Add `before_action :require_challenge!` to actions you
920 # want to protect.
921 #
922 # The user will be shown a page to enter the challenge (which
923 # is either the password, or just the username when no
924 # password exists). Upon passing, there is a grace period
925 # during which no challenge will be asked from the user.
926 #
927 # Accessing challenge-protected resources during the grace
928 # period will refresh the grace period.
929 module ChallengableConcern
930 extend ActiveSupport::Concern
931
932 CHALLENGE_TIMEOUT = 1.hour.freeze
933
934 def require_challenge!
935 # ...
936 end
937
938 def challenge_passed?
939 # ...
940 end
941 end"#
942 .unindent(),
943 558,
944 ),
945 (
946 r#"
947 def require_challenge!
948 return if skip_challenge?
949
950 if challenge_passed_recently?
951 session[:challenge_passed_at] = Time.now.utc
952 return
953 end
954
955 @challenge = Form::Challenge.new(return_to: request.url)
956
957 if params.key?(:form_challenge)
958 if challenge_passed?
959 session[:challenge_passed_at] = Time.now.utc
960 else
961 flash.now[:alert] = I18n.t('challenge.invalid_password')
962 render_challenge
963 end
964 else
965 render_challenge
966 end
967 end"#
968 .unindent(),
969 663,
970 ),
971 (
972 r#"
973 def challenge_passed?
974 current_user.valid_password?(challenge_params[:current_password])
975 end"#
976 .unindent(),
977 1254,
978 ),
979 (
980 r#"
981 class Animal
982 include Comparable
983
984 attr_reader :legs
985
986 def initialize(name, legs)
987 # ...
988 end
989
990 def <=>(other)
991 # ...
992 end
993 end"#
994 .unindent(),
995 1363,
996 ),
997 (
998 r#"
999 def initialize(name, legs)
1000 @name, @legs = name, legs
1001 end"#
1002 .unindent(),
1003 1427,
1004 ),
1005 (
1006 r#"
1007 def <=>(other)
1008 legs <=> other.legs
1009 end"#
1010 .unindent(),
1011 1501,
1012 ),
1013 (
1014 r#"
1015 # Singleton method for car object
1016 def car.wheels
1017 puts "There are four wheels"
1018 end"#
1019 .unindent(),
1020 1591,
1021 ),
1022 ],
1023 );
1024}
1025
1026#[gpui::test]
1027async fn test_code_context_retrieval_php() {
1028 let language = php_lang();
1029 let mut retriever = CodeContextRetriever::new();
1030
1031 let text = r#"
1032 <?php
1033
1034 namespace LevelUp\Experience\Concerns;
1035
1036 /*
1037 This is a multiple-lines comment block
1038 that spans over multiple
1039 lines
1040 */
1041 function functionName() {
1042 echo "Hello world!";
1043 }
1044
1045 trait HasAchievements
1046 {
1047 /**
1048 * @throws \Exception
1049 */
1050 public function grantAchievement(Achievement $achievement, $progress = null): void
1051 {
1052 if ($progress > 100) {
1053 throw new Exception(message: 'Progress cannot be greater than 100');
1054 }
1055
1056 if ($this->achievements()->find($achievement->id)) {
1057 throw new Exception(message: 'User already has this Achievement');
1058 }
1059
1060 $this->achievements()->attach($achievement, [
1061 'progress' => $progress ?? null,
1062 ]);
1063
1064 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1065 }
1066
1067 public function achievements(): BelongsToMany
1068 {
1069 return $this->belongsToMany(related: Achievement::class)
1070 ->withPivot(columns: 'progress')
1071 ->where('is_secret', false)
1072 ->using(AchievementUser::class);
1073 }
1074 }
1075
1076 interface Multiplier
1077 {
1078 public function qualifies(array $data): bool;
1079
1080 public function setMultiplier(): int;
1081 }
1082
1083 enum AuditType: string
1084 {
1085 case Add = 'add';
1086 case Remove = 'remove';
1087 case Reset = 'reset';
1088 case LevelUp = 'level_up';
1089 }
1090
1091 ?>"#
1092 .unindent();
1093
1094 let documents = retriever.parse_file(&text, language.clone()).unwrap();
1095
1096 assert_documents_eq(
1097 &documents,
1098 &[
1099 (
1100 r#"
1101 /*
1102 This is a multiple-lines comment block
1103 that spans over multiple
1104 lines
1105 */
1106 function functionName() {
1107 echo "Hello world!";
1108 }"#
1109 .unindent(),
1110 123,
1111 ),
1112 (
1113 r#"
1114 trait HasAchievements
1115 {
1116 /**
1117 * @throws \Exception
1118 */
1119 public function grantAchievement(Achievement $achievement, $progress = null): void
1120 {/* ... */}
1121
1122 public function achievements(): BelongsToMany
1123 {/* ... */}
1124 }"#
1125 .unindent(),
1126 177,
1127 ),
1128 (r#"
1129 /**
1130 * @throws \Exception
1131 */
1132 public function grantAchievement(Achievement $achievement, $progress = null): void
1133 {
1134 if ($progress > 100) {
1135 throw new Exception(message: 'Progress cannot be greater than 100');
1136 }
1137
1138 if ($this->achievements()->find($achievement->id)) {
1139 throw new Exception(message: 'User already has this Achievement');
1140 }
1141
1142 $this->achievements()->attach($achievement, [
1143 'progress' => $progress ?? null,
1144 ]);
1145
1146 $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1147 }"#.unindent(), 245),
1148 (r#"
1149 public function achievements(): BelongsToMany
1150 {
1151 return $this->belongsToMany(related: Achievement::class)
1152 ->withPivot(columns: 'progress')
1153 ->where('is_secret', false)
1154 ->using(AchievementUser::class);
1155 }"#.unindent(), 902),
1156 (r#"
1157 interface Multiplier
1158 {
1159 public function qualifies(array $data): bool;
1160
1161 public function setMultiplier(): int;
1162 }"#.unindent(),
1163 1146),
1164 (r#"
1165 enum AuditType: string
1166 {
1167 case Add = 'add';
1168 case Remove = 'remove';
1169 case Reset = 'reset';
1170 case LevelUp = 'level_up';
1171 }"#.unindent(), 1265)
1172 ],
1173 );
1174}
1175
1176#[gpui::test]
1177fn test_dot_product(mut rng: StdRng) {
1178 assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1179 assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1180
1181 for _ in 0..100 {
1182 let size = 1536;
1183 let mut a = vec![0.; size];
1184 let mut b = vec![0.; size];
1185 for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1186 *a = rng.gen();
1187 *b = rng.gen();
1188 }
1189
1190 assert_eq!(
1191 round_to_decimals(dot(&a, &b), 1),
1192 round_to_decimals(reference_dot(&a, &b), 1)
1193 );
1194 }
1195
1196 fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1197 let factor = (10.0 as f32).powi(decimal_places);
1198 (n * factor).round() / factor
1199 }
1200
1201 fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1202 a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1203 }
1204}
1205
/// Deterministic stand-in for a real embedding service, used by these tests.
///
/// It tallies every span it is asked to embed, so tests can measure how much
/// embedding work an (re)index triggered.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch` so far.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns how many spans have been embedded since construction.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
1216
1217#[async_trait]
1218impl EmbeddingProvider for FakeEmbeddingProvider {
1219 async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1220 self.embedding_count
1221 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1222 Ok(spans
1223 .iter()
1224 .map(|span| {
1225 let mut result = vec![1.0; 26];
1226 for letter in span.chars() {
1227 let letter = letter.to_ascii_lowercase();
1228 if letter as u32 >= 'a' as u32 {
1229 let ix = (letter as u32) - ('a' as u32);
1230 if ix < 26 {
1231 result[ix as usize] += 1.0;
1232 }
1233 }
1234 }
1235
1236 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1237 for x in &mut result {
1238 *x /= norm;
1239 }
1240
1241 result
1242 })
1243 .collect())
1244 }
1245}
1246
1247fn js_lang() -> Arc<Language> {
1248 Arc::new(
1249 Language::new(
1250 LanguageConfig {
1251 name: "Javascript".into(),
1252 path_suffixes: vec!["js".into()],
1253 ..Default::default()
1254 },
1255 Some(tree_sitter_typescript::language_tsx()),
1256 )
1257 .with_embedding_query(
1258 &r#"
1259
1260 (
1261 (comment)* @context
1262 .
1263 [
1264 (export_statement
1265 (function_declaration
1266 "async"? @name
1267 "function" @name
1268 name: (_) @name))
1269 (function_declaration
1270 "async"? @name
1271 "function" @name
1272 name: (_) @name)
1273 ] @item
1274 )
1275
1276 (
1277 (comment)* @context
1278 .
1279 [
1280 (export_statement
1281 (class_declaration
1282 "class" @name
1283 name: (_) @name))
1284 (class_declaration
1285 "class" @name
1286 name: (_) @name)
1287 ] @item
1288 )
1289
1290 (
1291 (comment)* @context
1292 .
1293 [
1294 (export_statement
1295 (interface_declaration
1296 "interface" @name
1297 name: (_) @name))
1298 (interface_declaration
1299 "interface" @name
1300 name: (_) @name)
1301 ] @item
1302 )
1303
1304 (
1305 (comment)* @context
1306 .
1307 [
1308 (export_statement
1309 (enum_declaration
1310 "enum" @name
1311 name: (_) @name))
1312 (enum_declaration
1313 "enum" @name
1314 name: (_) @name)
1315 ] @item
1316 )
1317
1318 (
1319 (comment)* @context
1320 .
1321 (method_definition
1322 [
1323 "get"
1324 "set"
1325 "async"
1326 "*"
1327 "static"
1328 ]* @name
1329 name: (_) @name) @item
1330 )
1331
1332 "#
1333 .unindent(),
1334 )
1335 .unwrap(),
1336 )
1337}
1338
1339fn rust_lang() -> Arc<Language> {
1340 Arc::new(
1341 Language::new(
1342 LanguageConfig {
1343 name: "Rust".into(),
1344 path_suffixes: vec!["rs".into()],
1345 collapsed_placeholder: " /* ... */ ".to_string(),
1346 ..Default::default()
1347 },
1348 Some(tree_sitter_rust::language()),
1349 )
1350 .with_embedding_query(
1351 r#"
1352 (
1353 [(line_comment) (attribute_item)]* @context
1354 .
1355 [
1356 (struct_item
1357 name: (_) @name)
1358
1359 (enum_item
1360 name: (_) @name)
1361
1362 (impl_item
1363 trait: (_)? @name
1364 "for"? @name
1365 type: (_) @name)
1366
1367 (trait_item
1368 name: (_) @name)
1369
1370 (function_item
1371 name: (_) @name
1372 body: (block
1373 "{" @keep
1374 "}" @keep) @collapse)
1375
1376 (macro_definition
1377 name: (_) @name)
1378 ] @item
1379 )
1380 "#,
1381 )
1382 .unwrap(),
1383 )
1384}
1385
1386fn json_lang() -> Arc<Language> {
1387 Arc::new(
1388 Language::new(
1389 LanguageConfig {
1390 name: "JSON".into(),
1391 path_suffixes: vec!["json".into()],
1392 ..Default::default()
1393 },
1394 Some(tree_sitter_json::language()),
1395 )
1396 .with_embedding_query(
1397 r#"
1398 (document) @item
1399
1400 (array
1401 "[" @keep
1402 .
1403 (object)? @keep
1404 "]" @keep) @collapse
1405
1406 (pair value: (string
1407 "\"" @keep
1408 "\"" @keep) @collapse)
1409 "#,
1410 )
1411 .unwrap(),
1412 )
1413}
1414
1415fn toml_lang() -> Arc<Language> {
1416 Arc::new(Language::new(
1417 LanguageConfig {
1418 name: "TOML".into(),
1419 path_suffixes: vec!["toml".into()],
1420 ..Default::default()
1421 },
1422 Some(tree_sitter_toml::language()),
1423 ))
1424}
1425
1426fn cpp_lang() -> Arc<Language> {
1427 Arc::new(
1428 Language::new(
1429 LanguageConfig {
1430 name: "CPP".into(),
1431 path_suffixes: vec!["cpp".into()],
1432 ..Default::default()
1433 },
1434 Some(tree_sitter_cpp::language()),
1435 )
1436 .with_embedding_query(
1437 r#"
1438 (
1439 (comment)* @context
1440 .
1441 (function_definition
1442 (type_qualifier)? @name
1443 type: (_)? @name
1444 declarator: [
1445 (function_declarator
1446 declarator: (_) @name)
1447 (pointer_declarator
1448 "*" @name
1449 declarator: (function_declarator
1450 declarator: (_) @name))
1451 (pointer_declarator
1452 "*" @name
1453 declarator: (pointer_declarator
1454 "*" @name
1455 declarator: (function_declarator
1456 declarator: (_) @name)))
1457 (reference_declarator
1458 ["&" "&&"] @name
1459 (function_declarator
1460 declarator: (_) @name))
1461 ]
1462 (type_qualifier)? @name) @item
1463 )
1464
1465 (
1466 (comment)* @context
1467 .
1468 (template_declaration
1469 (class_specifier
1470 "class" @name
1471 name: (_) @name)
1472 ) @item
1473 )
1474
1475 (
1476 (comment)* @context
1477 .
1478 (class_specifier
1479 "class" @name
1480 name: (_) @name) @item
1481 )
1482
1483 (
1484 (comment)* @context
1485 .
1486 (enum_specifier
1487 "enum" @name
1488 name: (_) @name) @item
1489 )
1490
1491 (
1492 (comment)* @context
1493 .
1494 (declaration
1495 type: (struct_specifier
1496 "struct" @name)
1497 declarator: (_) @name) @item
1498 )
1499
1500 "#,
1501 )
1502 .unwrap(),
1503 )
1504}
1505
1506fn lua_lang() -> Arc<Language> {
1507 Arc::new(
1508 Language::new(
1509 LanguageConfig {
1510 name: "Lua".into(),
1511 path_suffixes: vec!["lua".into()],
1512 collapsed_placeholder: "--[ ... ]--".to_string(),
1513 ..Default::default()
1514 },
1515 Some(tree_sitter_lua::language()),
1516 )
1517 .with_embedding_query(
1518 r#"
1519 (
1520 (comment)* @context
1521 .
1522 (function_declaration
1523 "function" @name
1524 name: (_) @name
1525 (comment)* @collapse
1526 body: (block) @collapse
1527 ) @item
1528 )
1529 "#,
1530 )
1531 .unwrap(),
1532 )
1533}
1534
1535fn php_lang() -> Arc<Language> {
1536 Arc::new(
1537 Language::new(
1538 LanguageConfig {
1539 name: "PHP".into(),
1540 path_suffixes: vec!["php".into()],
1541 collapsed_placeholder: "/* ... */".into(),
1542 ..Default::default()
1543 },
1544 Some(tree_sitter_php::language()),
1545 )
1546 .with_embedding_query(
1547 r#"
1548 (
1549 (comment)* @context
1550 .
1551 [
1552 (function_definition
1553 "function" @name
1554 name: (_) @name
1555 body: (_
1556 "{" @keep
1557 "}" @keep) @collapse
1558 )
1559
1560 (trait_declaration
1561 "trait" @name
1562 name: (_) @name)
1563
1564 (method_declaration
1565 "function" @name
1566 name: (_) @name
1567 body: (_
1568 "{" @keep
1569 "}" @keep) @collapse
1570 )
1571
1572 (interface_declaration
1573 "interface" @name
1574 name: (_) @name
1575 )
1576
1577 (enum_declaration
1578 "enum" @name
1579 name: (_) @name
1580 )
1581
1582 ] @item
1583 )
1584 "#,
1585 )
1586 .unwrap(),
1587 )
1588}
1589
1590fn ruby_lang() -> Arc<Language> {
1591 Arc::new(
1592 Language::new(
1593 LanguageConfig {
1594 name: "Ruby".into(),
1595 path_suffixes: vec!["rb".into()],
1596 collapsed_placeholder: "# ...".to_string(),
1597 ..Default::default()
1598 },
1599 Some(tree_sitter_ruby::language()),
1600 )
1601 .with_embedding_query(
1602 r#"
1603 (
1604 (comment)* @context
1605 .
1606 [
1607 (module
1608 "module" @name
1609 name: (_) @name)
1610 (method
1611 "def" @name
1612 name: (_) @name
1613 body: (body_statement) @collapse)
1614 (class
1615 "class" @name
1616 name: (_) @name)
1617 (singleton_method
1618 "def" @name
1619 object: (_) @name
1620 "." @name
1621 name: (_) @name
1622 body: (body_statement) @collapse)
1623 ] @item
1624 )
1625 "#,
1626 )
1627 .unwrap(),
1628 )
1629}
1630
1631fn elixir_lang() -> Arc<Language> {
1632 Arc::new(
1633 Language::new(
1634 LanguageConfig {
1635 name: "Elixir".into(),
1636 path_suffixes: vec!["rs".into()],
1637 ..Default::default()
1638 },
1639 Some(tree_sitter_elixir::language()),
1640 )
1641 .with_embedding_query(
1642 r#"
1643 (
1644 (unary_operator
1645 operator: "@"
1646 operand: (call
1647 target: (identifier) @unary
1648 (#match? @unary "^(doc)$"))
1649 ) @context
1650 .
1651 (call
1652 target: (identifier) @name
1653 (arguments
1654 [
1655 (identifier) @name
1656 (call
1657 target: (identifier) @name)
1658 (binary_operator
1659 left: (call
1660 target: (identifier) @name)
1661 operator: "when")
1662 ])
1663 (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1664 )
1665
1666 (call
1667 target: (identifier) @name
1668 (arguments (alias) @name)
1669 (#match? @name "^(defmodule|defprotocol)$")) @item
1670 "#,
1671 )
1672 .unwrap(),
1673 )
1674}
1675
1676#[gpui::test]
1677fn test_subtract_ranges() {
1678 // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1679
1680 assert_eq!(
1681 subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1682 vec![1..4, 10..21]
1683 );
1684
1685 assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1686}