semantic_index_tests.rs

   1use crate::{
   2    db::dot,
   3    embedding::EmbeddingProvider,
   4    parsing::{subtract_ranges, CodeContextRetriever, Document},
   5    semantic_index_settings::SemanticIndexSettings,
   6    SearchResult, SemanticIndex,
   7};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::SettingsStore;
  17use std::{
  18    path::Path,
  19    sync::{
  20        atomic::{self, AtomicUsize},
  21        Arc,
  22    },
  23};
  24use unindent::Unindent;
  25
  26#[ctor::ctor]
  27fn init_logger() {
  28    if std::env::var("RUST_LOG").is_ok() {
  29        env_logger::init();
  30    }
  31}
  32
  33#[gpui::test]
  34async fn test_semantic_index(cx: &mut TestAppContext) {
  35    cx.update(|cx| {
  36        cx.set_global(SettingsStore::test(cx));
  37        settings::register::<SemanticIndexSettings>(cx);
  38        settings::register::<ProjectSettings>(cx);
  39    });
  40
  41    let fs = FakeFs::new(cx.background());
  42    fs.insert_tree(
  43        "/the-root",
  44        json!({
  45            "src": {
  46                "file1.rs": "
  47                    fn aaa() {
  48                        println!(\"aaaaaaaaaaaa!\");
  49                    }
  50
  51                    fn zzzzz() {
  52                        println!(\"SLEEPING\");
  53                    }
  54                ".unindent(),
  55                "file2.rs": "
  56                    fn bbb() {
  57                        println!(\"bbbbbbbbbbbbb!\");
  58                    }
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let store = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89    let (file_count, outstanding_file_count) = store
  90        .update(cx, |store, cx| store.index_project(project.clone(), cx))
  91        .await
  92        .unwrap();
  93    assert_eq!(file_count, 3);
  94    cx.foreground().run_until_parked();
  95    assert_eq!(*outstanding_file_count.borrow(), 0);
  96
  97    let search_results = store
  98        .update(cx, |store, cx| {
  99            store.search_project(
 100                project.clone(),
 101                "aaaaaabbbbzz".to_string(),
 102                5,
 103                vec![],
 104                vec![],
 105                cx,
 106            )
 107        })
 108        .await
 109        .unwrap();
 110
 111    assert_search_results(
 112        &search_results,
 113        &[
 114            (Path::new("src/file1.rs").into(), 0),
 115            (Path::new("src/file2.rs").into(), 0),
 116            (Path::new("src/file3.toml").into(), 0),
 117            (Path::new("src/file1.rs").into(), 45),
 118        ],
 119        cx,
 120    );
 121
 122    // Test Include Files Functonality
 123    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 124    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 125    let rust_only_search_results = store
 126        .update(cx, |store, cx| {
 127            store.search_project(
 128                project.clone(),
 129                "aaaaaabbbbzz".to_string(),
 130                5,
 131                include_files,
 132                vec![],
 133                cx,
 134            )
 135        })
 136        .await
 137        .unwrap();
 138
 139    assert_search_results(
 140        &rust_only_search_results,
 141        &[
 142            (Path::new("src/file1.rs").into(), 0),
 143            (Path::new("src/file2.rs").into(), 0),
 144            (Path::new("src/file1.rs").into(), 45),
 145        ],
 146        cx,
 147    );
 148
 149    let no_rust_search_results = store
 150        .update(cx, |store, cx| {
 151            store.search_project(
 152                project.clone(),
 153                "aaaaaabbbbzz".to_string(),
 154                5,
 155                vec![],
 156                exclude_files,
 157                cx,
 158            )
 159        })
 160        .await
 161        .unwrap();
 162
 163    assert_search_results(
 164        &no_rust_search_results,
 165        &[(Path::new("src/file3.toml").into(), 0)],
 166        cx,
 167    );
 168
 169    fs.save(
 170        "/the-root/src/file2.rs".as_ref(),
 171        &"
 172            fn dddd() { println!(\"ddddd!\"); }
 173            struct pqpqpqp {}
 174        "
 175        .unindent()
 176        .into(),
 177        Default::default(),
 178    )
 179    .await
 180    .unwrap();
 181
 182    cx.foreground().run_until_parked();
 183
 184    let prev_embedding_count = embedding_provider.embedding_count();
 185    let (file_count, outstanding_file_count) = store
 186        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 187        .await
 188        .unwrap();
 189    assert_eq!(file_count, 1);
 190
 191    cx.foreground().run_until_parked();
 192    assert_eq!(*outstanding_file_count.borrow(), 0);
 193
 194    assert_eq!(
 195        embedding_provider.embedding_count() - prev_embedding_count,
 196        2
 197    );
 198}
 199
 200#[track_caller]
 201fn assert_search_results(
 202    actual: &[SearchResult],
 203    expected: &[(Arc<Path>, usize)],
 204    cx: &TestAppContext,
 205) {
 206    let actual = actual
 207        .iter()
 208        .map(|search_result| {
 209            search_result.buffer.read_with(cx, |buffer, _cx| {
 210                (
 211                    buffer.file().unwrap().path().clone(),
 212                    search_result.range.start.to_offset(buffer),
 213                )
 214            })
 215        })
 216        .collect::<Vec<_>>();
 217    assert_eq!(actual, expected);
 218}
 219
/// Checks how `CodeContextRetriever` splits a Rust file into documents:
/// - one for the doc-commented, attributed free function `a`;
/// - one per `impl` block, with method bodies collapsed to `{ /* ... */ }`;
/// - one per method inside `impl E`, each keeping its preceding comment.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let mut retriever = CodeContextRetriever::new();

    // Fixture source that the retriever splits into documents.
    let text = "
        /// A doc comment
        /// that spans multiple lines
        #[gpui::test]
        fn a() {
            b
        }

        impl C for D {
        }

        impl E {
            // This is also a preceding comment
            pub fn function_1() -> Option<()> {
                todo!();
            }

            // This is a preceding comment
            fn function_2() -> Result<()> {
                todo!();
            }
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language).unwrap();

    // Expected documents, in emission order. Start offsets come from
    // `text.find` on the item keyword (`fn`/`impl`), so a document's start
    // points at the item itself even when its content begins with leading
    // comments or attributes.
    assert_documents_eq(
        &documents,
        &[
            (
                "
                /// A doc comment
                /// that spans multiple lines
                #[gpui::test]
                fn a() {
                    b
                }"
                .unindent(),
                text.find("fn a").unwrap(),
            ),
            (
                "
                impl C for D {
                }"
                .unindent(),
                text.find("impl C").unwrap(),
            ),
            (
                "
                impl E {
                    // This is also a preceding comment
                    pub fn function_1() -> Option<()> { /* ... */ }

                    // This is a preceding comment
                    fn function_2() -> Result<()> { /* ... */ }
                }"
                .unindent(),
                text.find("impl E").unwrap(),
            ),
            (
                "
                // This is also a preceding comment
                pub fn function_1() -> Option<()> {
                    todo!();
                }"
                .unindent(),
                text.find("pub fn function_1").unwrap(),
            ),
            (
                "
                // This is a preceding comment
                fn function_2() -> Result<()> {
                    todo!();
                }"
                .unindent(),
                text.find("fn function_2").unwrap(),
            ),
        ],
    );
}
 306
/// Checks how `CodeContextRetriever` summarizes JSON files.
///
/// - In a top-level object, string values are emptied to `""` and arrays of
///   numbers collapse to `[]`, while booleans and `null` are kept.
/// - A top-level array is reduced to its first element, whose strings are
///   emptied but whose numbers are kept.
#[gpui::test]
async fn test_code_context_retrieval_json() {
    let language = json_lang();
    let mut retriever = CodeContextRetriever::new();

    // Case 1: a top-level object with nested values.
    let text = r#"
        {
            "array": [1, 2, 3, 4],
            "string": "abcdefg",
            "nested_object": {
                "array_2": [5, 6, 7, 8],
                "string_2": "hijklmnop",
                "boolean": true,
                "none": null
            }
        }
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // A single document covering the whole object, starting at its `{`.
    assert_documents_eq(
        &documents,
        &[(
            r#"
                {
                    "array": [],
                    "string": "",
                    "nested_object": {
                        "array_2": [],
                        "string_2": "",
                        "boolean": true,
                        "none": null
                    }
                }"#
            .unindent(),
            text.find("{").unwrap(),
        )],
    );

    // Case 2: a top-level array of objects. Only the first element survives
    // in the summary.
    let text = r#"
        [
            {
                "name": "somebody",
                "age": 42
            },
            {
                "name": "somebody else",
                "age": 43
            }
        ]
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            [{
                    "name": "",
                    "age": 42
                }]"#
            .unindent(),
            text.find("[").unwrap(),
        )],
    );
}
 376
 377fn assert_documents_eq(
 378    documents: &[Document],
 379    expected_contents_and_start_offsets: &[(String, usize)],
 380) {
 381    assert_eq!(
 382        documents
 383            .iter()
 384            .map(|document| (document.content.clone(), document.range.start))
 385            .collect::<Vec<_>>(),
 386        expected_contents_and_start_offsets
 387    );
 388}
 389
 390#[gpui::test]
 391async fn test_code_context_retrieval_javascript() {
 392    let language = js_lang();
 393    let mut retriever = CodeContextRetriever::new();
 394
 395    let text = "
 396        /* globals importScripts, backend */
 397        function _authorize() {}
 398
 399        /**
 400         * Sometimes the frontend build is way faster than backend.
 401         */
 402        export async function authorizeBank() {
 403            _authorize(pushModal, upgradingAccountId, {});
 404        }
 405
 406        export class SettingsPage {
 407            /* This is a test setting */
 408            constructor(page) {
 409                this.page = page;
 410            }
 411        }
 412
 413        /* This is a test comment */
 414        class TestClass {}
 415
 416        /* Schema for editor_events in Clickhouse. */
 417        export interface ClickhouseEditorEvent {
 418            installation_id: string
 419            operation: string
 420        }
 421        "
 422    .unindent();
 423
 424    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 425
 426    assert_documents_eq(
 427        &documents,
 428        &[
 429            (
 430                "
 431            /* globals importScripts, backend */
 432            function _authorize() {}"
 433                    .unindent(),
 434                37,
 435            ),
 436            (
 437                "
 438            /**
 439             * Sometimes the frontend build is way faster than backend.
 440             */
 441            export async function authorizeBank() {
 442                _authorize(pushModal, upgradingAccountId, {});
 443            }"
 444                .unindent(),
 445                131,
 446            ),
 447            (
 448                "
 449                export class SettingsPage {
 450                    /* This is a test setting */
 451                    constructor(page) {
 452                        this.page = page;
 453                    }
 454                }"
 455                .unindent(),
 456                225,
 457            ),
 458            (
 459                "
 460                /* This is a test setting */
 461                constructor(page) {
 462                    this.page = page;
 463                }"
 464                .unindent(),
 465                290,
 466            ),
 467            (
 468                "
 469                /* This is a test comment */
 470                class TestClass {}"
 471                    .unindent(),
 472                374,
 473            ),
 474            (
 475                "
 476                /* Schema for editor_events in Clickhouse. */
 477                export interface ClickhouseEditorEvent {
 478                    installation_id: string
 479                    operation: string
 480                }"
 481                .unindent(),
 482                440,
 483            ),
 484        ],
 485    )
 486}
 487
/// Checks Lua retrieval on a function that contains a nested function.
///
/// - The outer `function classes.class` becomes a document that keeps its
///   leading `--` comments, but the nested function's body is collapsed to
///   `--[ ... ]--` placeholder lines.
/// - The nested `function classdef.alloc` is also emitted in full as its own
///   document.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
    let language = lua_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = r#"
        -- Creates a new class
        -- @param baseclass The Baseclass of this class, or nil.
        -- @return A new class reference.
        function classes.class(baseclass)
            -- Create the class definition and metatable.
            local classdef = {}
            -- Find the super class, either Object or user-defined.
            baseclass = baseclass or classes.Object
            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
            setmetatable(classdef, { __index = baseclass })
            -- All class instances have a reference to the class object.
            classdef.class = classdef
            --- Recursivly allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end
        end
        "#.unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Start offsets are hard-coded byte offsets into `text` and must be
    // updated whenever the fixture changes. 114 is where
    // `function classes.class` begins, just past its three comment lines.
    assert_documents_eq(
        &documents,
        &[
            (r#"
                -- Creates a new class
                -- @param baseclass The Baseclass of this class, or nil.
                -- @return A new class reference.
                function classes.class(baseclass)
                    -- Create the class definition and metatable.
                    local classdef = {}
                    -- Find the super class, either Object or user-defined.
                    baseclass = baseclass or classes.Object
                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
                    setmetatable(classdef, { __index = baseclass })
                    -- All class instances have a reference to the class object.
                    classdef.class = classdef
                    --- Recursivly allocates the inheritance tree of the instance.
                    -- @param mastertable The 'root' of the inheritance tree.
                    -- @return Returns the instance with the allocated inheritance tree.
                    function classdef.alloc(mastertable)
                        --[ ... ]--
                        --[ ... ]--
                    end
                end"#.unindent(),
            114),
            (r#"
            --- Recursivly allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end"#.unindent(), 809),
        ]
    );
}
 560
/// Checks Elixir retrieval.
///
/// - The whole `defmodule` is emitted verbatim as one document starting at
///   offset 0.
/// - The `def __build__` function, together with its `@doc false`
///   attribute, is emitted again as a separate document.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
    let language = elixir_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = r#"
        defmodule File.Stream do
            @moduledoc """
            Defines a `File.Stream` struct returned by `File.stream!/3`.

            The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

            """

            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

            @type t :: %__MODULE__{}

            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Start offsets (0 and 574) are hard-coded byte offsets into `text` and
    // must be updated whenever the fixture changes.
    assert_documents_eq(
        &documents,
        &[(
            r#"
        defmodule File.Stream do
            @moduledoc """
            Defines a `File.Stream` struct returned by `File.stream!/3`.

            The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

            """

            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

            @type t :: %__MODULE__{}

            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#
                .unindent(),
            0,
        ),(r#"
            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#.unindent(), 574)],
    );
}
 676
/// Checks C++ retrieval.
///
/// - Each top-level declaration (function, class, enum, anonymous-struct
///   variable, class template) becomes a document that includes its
///   preceding comment.
/// - The templated constructor inside `Matrix2` is also emitted on its own,
///   without its comment or `template <...>` header.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let mut retriever = CodeContextRetriever::new();

    let text = "
    /**
     * @brief Main function
     * @returns 0 on exit
     */
    int main() { return 0; }

    /**
    * This is a test comment
    */
    class MyClass {       // The class
        public:           // Access specifier
        int myNum;        // Attribute (int variable)
        string myString;  // Attribute (string variable)
    };

    // This is a test comment
    enum Color { red, green, blue };

    /** This is a preceding block comment
     * This is the second line
     */
    struct {           // Structure declaration
        int myNum;       // Member (int variable)
        string myString; // Member (string variable)
    } myStructure;

    /**
     * @brief Matrix class.
     */
    template <typename T,
              typename = typename std::enable_if<
                std::is_integral<T>::value || std::is_floating_point<T>::value,
                bool>::type>
    class Matrix2 {
        std::vector<std::vector<T>> _mat;

        public:
            /**
            * @brief Constructor
            * @tparam Integer ensuring integers are being evaluated and not other
            * data types.
            * @param size denoting the size of Matrix as size x size
            */
            template <typename Integer,
                    typename = typename std::enable_if<std::is_integral<Integer>::value,
                    Integer>::type>
            explicit Matrix(const Integer size) {
                for (size_t i = 0; i < size; ++i) {
                    _mat.emplace_back(std::vector<T>(size, 0));
                }
            }
    }"
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Start offsets are hard-coded byte offsets into `text` pointing at each
    // declaration itself, not at its leading comment. For example, 54 is where
    // `int main` begins, 112 is `class MyClass` and 322 is `enum Color`.
    // They must be updated whenever the fixture changes.
    // The trailing `;` of `class MyClass` and `enum Color` is not part of their
    // documents, whereas `} myStructure;` keeps it.
    assert_documents_eq(
        &documents,
        &[
            (
                "
        /**
         * @brief Main function
         * @returns 0 on exit
         */
        int main() { return 0; }"
                    .unindent(),
                54,
            ),
            (
                "
                /**
                * This is a test comment
                */
                class MyClass {       // The class
                    public:           // Access specifier
                    int myNum;        // Attribute (int variable)
                    string myString;  // Attribute (string variable)
                }"
                .unindent(),
                112,
            ),
            (
                "
                // This is a test comment
                enum Color { red, green, blue }"
                    .unindent(),
                322,
            ),
            (
                "
                /** This is a preceding block comment
                 * This is the second line
                 */
                struct {           // Structure declaration
                    int myNum;       // Member (int variable)
                    string myString; // Member (string variable)
                } myStructure;"
                    .unindent(),
                425,
            ),
            (
                "
                /**
                 * @brief Matrix class.
                 */
                template <typename T,
                          typename = typename std::enable_if<
                            std::is_integral<T>::value || std::is_floating_point<T>::value,
                            bool>::type>
                class Matrix2 {
                    std::vector<std::vector<T>> _mat;

                    public:
                        /**
                        * @brief Constructor
                        * @tparam Integer ensuring integers are being evaluated and not other
                        * data types.
                        * @param size denoting the size of Matrix as size x size
                        */
                        template <typename Integer,
                                typename = typename std::enable_if<std::is_integral<Integer>::value,
                                Integer>::type>
                        explicit Matrix(const Integer size) {
                            for (size_t i = 0; i < size; ++i) {
                                _mat.emplace_back(std::vector<T>(size, 0));
                            }
                        }
                }"
                .unindent(),
                612,
            ),
            (
                "
                explicit Matrix(const Integer size) {
                    for (size_t i = 0; i < size; ++i) {
                        _mat.emplace_back(std::vector<T>(size, 0));
                    }
                }"
                .unindent(),
                1226,
            ),
        ],
    );
}
 828
 829#[gpui::test]
 830async fn test_code_context_retrieval_ruby() {
 831    let language = ruby_lang();
 832    let mut retriever = CodeContextRetriever::new();
 833
 834    let text = r#"
 835        # This concern is inspired by "sudo mode" on GitHub. It
 836        # is a way to re-authenticate a user before allowing them
 837        # to see or perform an action.
 838        #
 839        # Add `before_action :require_challenge!` to actions you
 840        # want to protect.
 841        #
 842        # The user will be shown a page to enter the challenge (which
 843        # is either the password, or just the username when no
 844        # password exists). Upon passing, there is a grace period
 845        # during which no challenge will be asked from the user.
 846        #
 847        # Accessing challenge-protected resources during the grace
 848        # period will refresh the grace period.
 849        module ChallengableConcern
 850            extend ActiveSupport::Concern
 851
 852            CHALLENGE_TIMEOUT = 1.hour.freeze
 853
 854            def require_challenge!
 855                return if skip_challenge?
 856
 857                if challenge_passed_recently?
 858                    session[:challenge_passed_at] = Time.now.utc
 859                    return
 860                end
 861
 862                @challenge = Form::Challenge.new(return_to: request.url)
 863
 864                if params.key?(:form_challenge)
 865                    if challenge_passed?
 866                        session[:challenge_passed_at] = Time.now.utc
 867                    else
 868                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 869                        render_challenge
 870                    end
 871                else
 872                    render_challenge
 873                end
 874            end
 875
 876            def challenge_passed?
 877                current_user.valid_password?(challenge_params[:current_password])
 878            end
 879        end
 880
 881        class Animal
 882            include Comparable
 883
 884            attr_reader :legs
 885
 886            def initialize(name, legs)
 887                @name, @legs = name, legs
 888            end
 889
 890            def <=>(other)
 891                legs <=> other.legs
 892            end
 893        end
 894
 895        # Singleton method for car object
 896        def car.wheels
 897            puts "There are four wheels"
 898        end"#
 899        .unindent();
 900
 901    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 902
 903    assert_documents_eq(
 904        &documents,
 905        &[
 906            (
 907                r#"
 908        # This concern is inspired by "sudo mode" on GitHub. It
 909        # is a way to re-authenticate a user before allowing them
 910        # to see or perform an action.
 911        #
 912        # Add `before_action :require_challenge!` to actions you
 913        # want to protect.
 914        #
 915        # The user will be shown a page to enter the challenge (which
 916        # is either the password, or just the username when no
 917        # password exists). Upon passing, there is a grace period
 918        # during which no challenge will be asked from the user.
 919        #
 920        # Accessing challenge-protected resources during the grace
 921        # period will refresh the grace period.
 922        module ChallengableConcern
 923            extend ActiveSupport::Concern
 924
 925            CHALLENGE_TIMEOUT = 1.hour.freeze
 926
 927            def require_challenge!
 928                # ...
 929            end
 930
 931            def challenge_passed?
 932                # ...
 933            end
 934        end"#
 935                    .unindent(),
 936                558,
 937            ),
 938            (
 939                r#"
 940            def require_challenge!
 941                return if skip_challenge?
 942
 943                if challenge_passed_recently?
 944                    session[:challenge_passed_at] = Time.now.utc
 945                    return
 946                end
 947
 948                @challenge = Form::Challenge.new(return_to: request.url)
 949
 950                if params.key?(:form_challenge)
 951                    if challenge_passed?
 952                        session[:challenge_passed_at] = Time.now.utc
 953                    else
 954                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 955                        render_challenge
 956                    end
 957                else
 958                    render_challenge
 959                end
 960            end"#
 961                    .unindent(),
 962                663,
 963            ),
 964            (
 965                r#"
 966                def challenge_passed?
 967                    current_user.valid_password?(challenge_params[:current_password])
 968                end"#
 969                    .unindent(),
 970                1254,
 971            ),
 972            (
 973                r#"
 974                class Animal
 975                    include Comparable
 976
 977                    attr_reader :legs
 978
 979                    def initialize(name, legs)
 980                        # ...
 981                    end
 982
 983                    def <=>(other)
 984                        # ...
 985                    end
 986                end"#
 987                    .unindent(),
 988                1363,
 989            ),
 990            (
 991                r#"
 992                def initialize(name, legs)
 993                    @name, @legs = name, legs
 994                end"#
 995                    .unindent(),
 996                1427,
 997            ),
 998            (
 999                r#"
1000                def <=>(other)
1001                    legs <=> other.legs
1002                end"#
1003                    .unindent(),
1004                1501,
1005            ),
1006            (
1007                r#"
1008                # Singleton method for car object
1009                def car.wheels
1010                    puts "There are four wheels"
1011                end"#
1012                    .unindent(),
1013                1591,
1014            ),
1015        ],
1016    );
1017}
1018
// Parses a small PHP file with the PHP embedding query (`php_lang`) and checks the
// resulting documents: a free function with its leading block comment, a trait whose
// method bodies are collapsed to `{/* ... */}`, each trait method in full (the first one
// together with its `/** ... */` docblock), an interface and an enum.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let mut retriever = CodeContextRetriever::new();

    // Fixture: namespace, commented free function, trait with two methods, interface, enum.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each expected entry is (document text, offset into `text`). In this fixture the
    // offsets line up with the start of the item itself rather than its leading comment
    // context: 123 is where `function functionName` begins (after the multi-line comment),
    // 177 is `trait HasAchievements`, and 245 is `public function grantAchievement`
    // (after its docblock).
    // NOTE(review): offset semantics inferred from the fixture — confirm against
    // `assert_documents_eq`.
    assert_documents_eq(
        &documents,
        &[
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1168
1169#[gpui::test]
1170fn test_dot_product(mut rng: StdRng) {
1171    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1172    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1173
1174    for _ in 0..100 {
1175        let size = 1536;
1176        let mut a = vec![0.; size];
1177        let mut b = vec![0.; size];
1178        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1179            *a = rng.gen();
1180            *b = rng.gen();
1181        }
1182
1183        assert_eq!(
1184            round_to_decimals(dot(&a, &b), 1),
1185            round_to_decimals(reference_dot(&a, &b), 1)
1186        );
1187    }
1188
1189    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1190        let factor = (10.0 as f32).powi(decimal_places);
1191        (n * factor).round() / factor
1192    }
1193
1194    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1195        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1196    }
1197}
1198
/// Stand-in [`EmbeddingProvider`] for tests that records how many spans it was asked
/// to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans received across all `embed_batch` calls.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Number of spans embedded so far.
    fn embedding_count(&self) -> usize {
        let ordering = atomic::Ordering::SeqCst;
        self.embedding_count.load(ordering)
    }
}
1209
1210#[async_trait]
1211impl EmbeddingProvider for FakeEmbeddingProvider {
1212    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1213        self.embedding_count
1214            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1215        Ok(spans
1216            .iter()
1217            .map(|span| {
1218                let mut result = vec![1.0; 26];
1219                for letter in span.chars() {
1220                    let letter = letter.to_ascii_lowercase();
1221                    if letter as u32 >= 'a' as u32 {
1222                        let ix = (letter as u32) - ('a' as u32);
1223                        if ix < 26 {
1224                            result[ix as usize] += 1.0;
1225                        }
1226                    }
1227                }
1228
1229                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1230                for x in &mut result {
1231                    *x /= norm;
1232                }
1233
1234                result
1235            })
1236            .collect())
1237    }
1238}
1239
1240fn js_lang() -> Arc<Language> {
1241    Arc::new(
1242        Language::new(
1243            LanguageConfig {
1244                name: "Javascript".into(),
1245                path_suffixes: vec!["js".into()],
1246                ..Default::default()
1247            },
1248            Some(tree_sitter_typescript::language_tsx()),
1249        )
1250        .with_embedding_query(
1251            &r#"
1252
1253            (
1254                (comment)* @context
1255                .
1256                [
1257                (export_statement
1258                    (function_declaration
1259                        "async"? @name
1260                        "function" @name
1261                        name: (_) @name))
1262                (function_declaration
1263                    "async"? @name
1264                    "function" @name
1265                    name: (_) @name)
1266                ] @item
1267            )
1268
1269            (
1270                (comment)* @context
1271                .
1272                [
1273                (export_statement
1274                    (class_declaration
1275                        "class" @name
1276                        name: (_) @name))
1277                (class_declaration
1278                    "class" @name
1279                    name: (_) @name)
1280                ] @item
1281            )
1282
1283            (
1284                (comment)* @context
1285                .
1286                [
1287                (export_statement
1288                    (interface_declaration
1289                        "interface" @name
1290                        name: (_) @name))
1291                (interface_declaration
1292                    "interface" @name
1293                    name: (_) @name)
1294                ] @item
1295            )
1296
1297            (
1298                (comment)* @context
1299                .
1300                [
1301                (export_statement
1302                    (enum_declaration
1303                        "enum" @name
1304                        name: (_) @name))
1305                (enum_declaration
1306                    "enum" @name
1307                    name: (_) @name)
1308                ] @item
1309            )
1310
1311            (
1312                (comment)* @context
1313                .
1314                (method_definition
1315                    [
1316                        "get"
1317                        "set"
1318                        "async"
1319                        "*"
1320                        "static"
1321                    ]* @name
1322                    name: (_) @name) @item
1323            )
1324
1325                    "#
1326            .unindent(),
1327        )
1328        .unwrap(),
1329    )
1330}
1331
1332fn rust_lang() -> Arc<Language> {
1333    Arc::new(
1334        Language::new(
1335            LanguageConfig {
1336                name: "Rust".into(),
1337                path_suffixes: vec!["rs".into()],
1338                collapsed_placeholder: " /* ... */ ".to_string(),
1339                ..Default::default()
1340            },
1341            Some(tree_sitter_rust::language()),
1342        )
1343        .with_embedding_query(
1344            r#"
1345            (
1346                [(line_comment) (attribute_item)]* @context
1347                .
1348                [
1349                    (struct_item
1350                        name: (_) @name)
1351
1352                    (enum_item
1353                        name: (_) @name)
1354
1355                    (impl_item
1356                        trait: (_)? @name
1357                        "for"? @name
1358                        type: (_) @name)
1359
1360                    (trait_item
1361                        name: (_) @name)
1362
1363                    (function_item
1364                        name: (_) @name
1365                        body: (block
1366                            "{" @keep
1367                            "}" @keep) @collapse)
1368
1369                    (macro_definition
1370                        name: (_) @name)
1371                ] @item
1372            )
1373            "#,
1374        )
1375        .unwrap(),
1376    )
1377}
1378
1379fn json_lang() -> Arc<Language> {
1380    Arc::new(
1381        Language::new(
1382            LanguageConfig {
1383                name: "JSON".into(),
1384                path_suffixes: vec!["json".into()],
1385                ..Default::default()
1386            },
1387            Some(tree_sitter_json::language()),
1388        )
1389        .with_embedding_query(
1390            r#"
1391            (document) @item
1392
1393            (array
1394                "[" @keep
1395                .
1396                (object)? @keep
1397                "]" @keep) @collapse
1398
1399            (pair value: (string
1400                "\"" @keep
1401                "\"" @keep) @collapse)
1402            "#,
1403        )
1404        .unwrap(),
1405    )
1406}
1407
1408fn toml_lang() -> Arc<Language> {
1409    Arc::new(Language::new(
1410        LanguageConfig {
1411            name: "TOML".into(),
1412            path_suffixes: vec!["toml".into()],
1413            ..Default::default()
1414        },
1415        Some(tree_sitter_toml::language()),
1416    ))
1417}
1418
1419fn cpp_lang() -> Arc<Language> {
1420    Arc::new(
1421        Language::new(
1422            LanguageConfig {
1423                name: "CPP".into(),
1424                path_suffixes: vec!["cpp".into()],
1425                ..Default::default()
1426            },
1427            Some(tree_sitter_cpp::language()),
1428        )
1429        .with_embedding_query(
1430            r#"
1431            (
1432                (comment)* @context
1433                .
1434                (function_definition
1435                    (type_qualifier)? @name
1436                    type: (_)? @name
1437                    declarator: [
1438                        (function_declarator
1439                            declarator: (_) @name)
1440                        (pointer_declarator
1441                            "*" @name
1442                            declarator: (function_declarator
1443                            declarator: (_) @name))
1444                        (pointer_declarator
1445                            "*" @name
1446                            declarator: (pointer_declarator
1447                                "*" @name
1448                            declarator: (function_declarator
1449                                declarator: (_) @name)))
1450                        (reference_declarator
1451                            ["&" "&&"] @name
1452                            (function_declarator
1453                            declarator: (_) @name))
1454                    ]
1455                    (type_qualifier)? @name) @item
1456                )
1457
1458            (
1459                (comment)* @context
1460                .
1461                (template_declaration
1462                    (class_specifier
1463                        "class" @name
1464                        name: (_) @name)
1465                        ) @item
1466            )
1467
1468            (
1469                (comment)* @context
1470                .
1471                (class_specifier
1472                    "class" @name
1473                    name: (_) @name) @item
1474                )
1475
1476            (
1477                (comment)* @context
1478                .
1479                (enum_specifier
1480                    "enum" @name
1481                    name: (_) @name) @item
1482                )
1483
1484            (
1485                (comment)* @context
1486                .
1487                (declaration
1488                    type: (struct_specifier
1489                    "struct" @name)
1490                    declarator: (_) @name) @item
1491            )
1492
1493            "#,
1494        )
1495        .unwrap(),
1496    )
1497}
1498
1499fn lua_lang() -> Arc<Language> {
1500    Arc::new(
1501        Language::new(
1502            LanguageConfig {
1503                name: "Lua".into(),
1504                path_suffixes: vec!["lua".into()],
1505                collapsed_placeholder: "--[ ... ]--".to_string(),
1506                ..Default::default()
1507            },
1508            Some(tree_sitter_lua::language()),
1509        )
1510        .with_embedding_query(
1511            r#"
1512            (
1513                (comment)* @context
1514                .
1515                (function_declaration
1516                    "function" @name
1517                    name: (_) @name
1518                    (comment)* @collapse
1519                    body: (block) @collapse
1520                ) @item
1521            )
1522        "#,
1523        )
1524        .unwrap(),
1525    )
1526}
1527
1528fn php_lang() -> Arc<Language> {
1529    Arc::new(
1530        Language::new(
1531            LanguageConfig {
1532                name: "PHP".into(),
1533                path_suffixes: vec!["php".into()],
1534                collapsed_placeholder: "/* ... */".into(),
1535                ..Default::default()
1536            },
1537            Some(tree_sitter_php::language()),
1538        )
1539        .with_embedding_query(
1540            r#"
1541            (
1542                (comment)* @context
1543                .
1544                [
1545                    (function_definition
1546                        "function" @name
1547                        name: (_) @name
1548                        body: (_
1549                            "{" @keep
1550                            "}" @keep) @collapse
1551                        )
1552
1553                    (trait_declaration
1554                        "trait" @name
1555                        name: (_) @name)
1556
1557                    (method_declaration
1558                        "function" @name
1559                        name: (_) @name
1560                        body: (_
1561                            "{" @keep
1562                            "}" @keep) @collapse
1563                        )
1564
1565                    (interface_declaration
1566                        "interface" @name
1567                        name: (_) @name
1568                        )
1569
1570                    (enum_declaration
1571                        "enum" @name
1572                        name: (_) @name
1573                        )
1574
1575                ] @item
1576            )
1577            "#,
1578        )
1579        .unwrap(),
1580    )
1581}
1582
1583fn ruby_lang() -> Arc<Language> {
1584    Arc::new(
1585        Language::new(
1586            LanguageConfig {
1587                name: "Ruby".into(),
1588                path_suffixes: vec!["rb".into()],
1589                collapsed_placeholder: "# ...".to_string(),
1590                ..Default::default()
1591            },
1592            Some(tree_sitter_ruby::language()),
1593        )
1594        .with_embedding_query(
1595            r#"
1596            (
1597                (comment)* @context
1598                .
1599                [
1600                (module
1601                    "module" @name
1602                    name: (_) @name)
1603                (method
1604                    "def" @name
1605                    name: (_) @name
1606                    body: (body_statement) @collapse)
1607                (class
1608                    "class" @name
1609                    name: (_) @name)
1610                (singleton_method
1611                    "def" @name
1612                    object: (_) @name
1613                    "." @name
1614                    name: (_) @name
1615                    body: (body_statement) @collapse)
1616                ] @item
1617            )
1618            "#,
1619        )
1620        .unwrap(),
1621    )
1622}
1623
1624fn elixir_lang() -> Arc<Language> {
1625    Arc::new(
1626        Language::new(
1627            LanguageConfig {
1628                name: "Elixir".into(),
1629                path_suffixes: vec!["rs".into()],
1630                ..Default::default()
1631            },
1632            Some(tree_sitter_elixir::language()),
1633        )
1634        .with_embedding_query(
1635            r#"
1636            (
1637                (unary_operator
1638                    operator: "@"
1639                    operand: (call
1640                        target: (identifier) @unary
1641                        (#match? @unary "^(doc)$"))
1642                    ) @context
1643                .
1644                (call
1645                target: (identifier) @name
1646                (arguments
1647                [
1648                (identifier) @name
1649                (call
1650                target: (identifier) @name)
1651                (binary_operator
1652                left: (call
1653                target: (identifier) @name)
1654                operator: "when")
1655                ])
1656                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1657                )
1658
1659            (call
1660                target: (identifier) @name
1661                (arguments (alias) @name)
1662                (#match? @name "^(defmodule|defprotocol)$")) @item
1663            "#,
1664        )
1665        .unwrap(),
1666    )
1667}
1668
1669#[gpui::test]
1670fn test_subtract_ranges() {
1671    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1672
1673    assert_eq!(
1674        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1675        vec![1..4, 10..21]
1676    );
1677
1678    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1679}