semantic_index_tests.rs

   1use crate::{
   2    db::dot,
   3    embedding::EmbeddingProvider,
   4    parsing::{subtract_ranges, CodeContextRetriever, Document},
   5    semantic_index_settings::SemanticIndexSettings,
   6    SearchResult, SemanticIndex,
   7};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::SettingsStore;
  17use std::{
  18    path::Path,
  19    sync::{
  20        atomic::{self, AtomicUsize},
  21        Arc,
  22    },
  23};
  24use unindent::Unindent;
  25
  26#[ctor::ctor]
  27fn init_logger() {
  28    if std::env::var("RUST_LOG").is_ok() {
  29        env_logger::init();
  30    }
  31}
  32
  33#[gpui::test]
  34async fn test_semantic_index(cx: &mut TestAppContext) {
  35    cx.update(|cx| {
  36        cx.set_global(SettingsStore::test(cx));
  37        settings::register::<SemanticIndexSettings>(cx);
  38        settings::register::<ProjectSettings>(cx);
  39    });
  40
  41    let fs = FakeFs::new(cx.background());
  42    fs.insert_tree(
  43        "/the-root",
  44        json!({
  45            "src": {
  46                "file1.rs": "
  47                    fn aaa() {
  48                        println!(\"aaaaaaaaaaaa!\");
  49                    }
  50
  51                    fn zzzzz() {
  52                        println!(\"SLEEPING\");
  53                    }
  54                ".unindent(),
  55                "file2.rs": "
  56                    fn bbb() {
  57                        println!(\"bbbbbbbbbbbbb!\");
  58                    }
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let store = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89
  90    let _ = store
  91        .update(cx, |store, cx| {
  92            store.initialize_project(project.clone(), cx)
  93        })
  94        .await;
  95
  96    let (file_count, outstanding_file_count) = store
  97        .update(cx, |store, cx| store.index_project(project.clone(), cx))
  98        .await
  99        .unwrap();
 100    assert_eq!(file_count, 3);
 101    cx.foreground().run_until_parked();
 102    assert_eq!(*outstanding_file_count.borrow(), 0);
 103
 104    let search_results = store
 105        .update(cx, |store, cx| {
 106            store.search_project(
 107                project.clone(),
 108                "aaaaaabbbbzz".to_string(),
 109                5,
 110                vec![],
 111                vec![],
 112                cx,
 113            )
 114        })
 115        .await
 116        .unwrap();
 117
 118    assert_search_results(
 119        &search_results,
 120        &[
 121            (Path::new("src/file1.rs").into(), 0),
 122            (Path::new("src/file2.rs").into(), 0),
 123            (Path::new("src/file3.toml").into(), 0),
 124            (Path::new("src/file1.rs").into(), 45),
 125        ],
 126        cx,
 127    );
 128
 129    // Test Include Files Functonality
 130    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 131    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 132    let rust_only_search_results = store
 133        .update(cx, |store, cx| {
 134            store.search_project(
 135                project.clone(),
 136                "aaaaaabbbbzz".to_string(),
 137                5,
 138                include_files,
 139                vec![],
 140                cx,
 141            )
 142        })
 143        .await
 144        .unwrap();
 145
 146    assert_search_results(
 147        &rust_only_search_results,
 148        &[
 149            (Path::new("src/file1.rs").into(), 0),
 150            (Path::new("src/file2.rs").into(), 0),
 151            (Path::new("src/file1.rs").into(), 45),
 152        ],
 153        cx,
 154    );
 155
 156    let no_rust_search_results = store
 157        .update(cx, |store, cx| {
 158            store.search_project(
 159                project.clone(),
 160                "aaaaaabbbbzz".to_string(),
 161                5,
 162                vec![],
 163                exclude_files,
 164                cx,
 165            )
 166        })
 167        .await
 168        .unwrap();
 169
 170    assert_search_results(
 171        &no_rust_search_results,
 172        &[(Path::new("src/file3.toml").into(), 0)],
 173        cx,
 174    );
 175
 176    fs.save(
 177        "/the-root/src/file2.rs".as_ref(),
 178        &"
 179            fn dddd() { println!(\"ddddd!\"); }
 180            struct pqpqpqp {}
 181        "
 182        .unindent()
 183        .into(),
 184        Default::default(),
 185    )
 186    .await
 187    .unwrap();
 188
 189    cx.foreground().run_until_parked();
 190
 191    let prev_embedding_count = embedding_provider.embedding_count();
 192    let (file_count, outstanding_file_count) = store
 193        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 194        .await
 195        .unwrap();
 196    assert_eq!(file_count, 1);
 197
 198    cx.foreground().run_until_parked();
 199    assert_eq!(*outstanding_file_count.borrow(), 0);
 200
 201    assert_eq!(
 202        embedding_provider.embedding_count() - prev_embedding_count,
 203        2
 204    );
 205}
 206
 207#[track_caller]
 208fn assert_search_results(
 209    actual: &[SearchResult],
 210    expected: &[(Arc<Path>, usize)],
 211    cx: &TestAppContext,
 212) {
 213    let actual = actual
 214        .iter()
 215        .map(|search_result| {
 216            search_result.buffer.read_with(cx, |buffer, _cx| {
 217                (
 218                    buffer.file().unwrap().path().clone(),
 219                    search_result.range.start.to_offset(buffer),
 220                )
 221            })
 222        })
 223        .collect::<Vec<_>>();
 224    assert_eq!(actual, expected);
 225}
 226
 227#[gpui::test]
 228async fn test_code_context_retrieval_rust() {
 229    let language = rust_lang();
 230    let mut retriever = CodeContextRetriever::new();
 231
 232    let text = "
 233        /// A doc comment
 234        /// that spans multiple lines
 235        #[gpui::test]
 236        fn a() {
 237            b
 238        }
 239
 240        impl C for D {
 241        }
 242
 243        impl E {
 244            // This is also a preceding comment
 245            pub fn function_1() -> Option<()> {
 246                todo!();
 247            }
 248
 249            // This is a preceding comment
 250            fn function_2() -> Result<()> {
 251                todo!();
 252            }
 253        }
 254    "
 255    .unindent();
 256
 257    let documents = retriever.parse_file(&text, language).unwrap();
 258
 259    assert_documents_eq(
 260        &documents,
 261        &[
 262            (
 263                "
 264                /// A doc comment
 265                /// that spans multiple lines
 266                #[gpui::test]
 267                fn a() {
 268                    b
 269                }"
 270                .unindent(),
 271                text.find("fn a").unwrap(),
 272            ),
 273            (
 274                "
 275                impl C for D {
 276                }"
 277                .unindent(),
 278                text.find("impl C").unwrap(),
 279            ),
 280            (
 281                "
 282                impl E {
 283                    // This is also a preceding comment
 284                    pub fn function_1() -> Option<()> { /* ... */ }
 285
 286                    // This is a preceding comment
 287                    fn function_2() -> Result<()> { /* ... */ }
 288                }"
 289                .unindent(),
 290                text.find("impl E").unwrap(),
 291            ),
 292            (
 293                "
 294                // This is also a preceding comment
 295                pub fn function_1() -> Option<()> {
 296                    todo!();
 297                }"
 298                .unindent(),
 299                text.find("pub fn function_1").unwrap(),
 300            ),
 301            (
 302                "
 303                // This is a preceding comment
 304                fn function_2() -> Result<()> {
 305                    todo!();
 306                }"
 307                .unindent(),
 308                text.find("fn function_2").unwrap(),
 309            ),
 310        ],
 311    );
 312}
 313
 314#[gpui::test]
 315async fn test_code_context_retrieval_json() {
 316    let language = json_lang();
 317    let mut retriever = CodeContextRetriever::new();
 318
 319    let text = r#"
 320        {
 321            "array": [1, 2, 3, 4],
 322            "string": "abcdefg",
 323            "nested_object": {
 324                "array_2": [5, 6, 7, 8],
 325                "string_2": "hijklmnop",
 326                "boolean": true,
 327                "none": null
 328            }
 329        }
 330    "#
 331    .unindent();
 332
 333    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 334
 335    assert_documents_eq(
 336        &documents,
 337        &[(
 338            r#"
 339                {
 340                    "array": [],
 341                    "string": "",
 342                    "nested_object": {
 343                        "array_2": [],
 344                        "string_2": "",
 345                        "boolean": true,
 346                        "none": null
 347                    }
 348                }"#
 349            .unindent(),
 350            text.find("{").unwrap(),
 351        )],
 352    );
 353
 354    let text = r#"
 355        [
 356            {
 357                "name": "somebody",
 358                "age": 42
 359            },
 360            {
 361                "name": "somebody else",
 362                "age": 43
 363            }
 364        ]
 365    "#
 366    .unindent();
 367
 368    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 369
 370    assert_documents_eq(
 371        &documents,
 372        &[(
 373            r#"
 374            [{
 375                    "name": "",
 376                    "age": 42
 377                }]"#
 378            .unindent(),
 379            text.find("[").unwrap(),
 380        )],
 381    );
 382}
 383
 384fn assert_documents_eq(
 385    documents: &[Document],
 386    expected_contents_and_start_offsets: &[(String, usize)],
 387) {
 388    assert_eq!(
 389        documents
 390            .iter()
 391            .map(|document| (document.content.clone(), document.range.start))
 392            .collect::<Vec<_>>(),
 393        expected_contents_and_start_offsets
 394    );
 395}
 396
 397#[gpui::test]
 398async fn test_code_context_retrieval_javascript() {
 399    let language = js_lang();
 400    let mut retriever = CodeContextRetriever::new();
 401
 402    let text = "
 403        /* globals importScripts, backend */
 404        function _authorize() {}
 405
 406        /**
 407         * Sometimes the frontend build is way faster than backend.
 408         */
 409        export async function authorizeBank() {
 410            _authorize(pushModal, upgradingAccountId, {});
 411        }
 412
 413        export class SettingsPage {
 414            /* This is a test setting */
 415            constructor(page) {
 416                this.page = page;
 417            }
 418        }
 419
 420        /* This is a test comment */
 421        class TestClass {}
 422
 423        /* Schema for editor_events in Clickhouse. */
 424        export interface ClickhouseEditorEvent {
 425            installation_id: string
 426            operation: string
 427        }
 428        "
 429    .unindent();
 430
 431    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 432
 433    assert_documents_eq(
 434        &documents,
 435        &[
 436            (
 437                "
 438            /* globals importScripts, backend */
 439            function _authorize() {}"
 440                    .unindent(),
 441                37,
 442            ),
 443            (
 444                "
 445            /**
 446             * Sometimes the frontend build is way faster than backend.
 447             */
 448            export async function authorizeBank() {
 449                _authorize(pushModal, upgradingAccountId, {});
 450            }"
 451                .unindent(),
 452                131,
 453            ),
 454            (
 455                "
 456                export class SettingsPage {
 457                    /* This is a test setting */
 458                    constructor(page) {
 459                        this.page = page;
 460                    }
 461                }"
 462                .unindent(),
 463                225,
 464            ),
 465            (
 466                "
 467                /* This is a test setting */
 468                constructor(page) {
 469                    this.page = page;
 470                }"
 471                .unindent(),
 472                290,
 473            ),
 474            (
 475                "
 476                /* This is a test comment */
 477                class TestClass {}"
 478                    .unindent(),
 479                374,
 480            ),
 481            (
 482                "
 483                /* Schema for editor_events in Clickhouse. */
 484                export interface ClickhouseEditorEvent {
 485                    installation_id: string
 486                    operation: string
 487                }"
 488                .unindent(),
 489                440,
 490            ),
 491        ],
 492    )
 493}
 494
 495#[gpui::test]
 496async fn test_code_context_retrieval_lua() {
 497    let language = lua_lang();
 498    let mut retriever = CodeContextRetriever::new();
 499
 500    let text = r#"
 501        -- Creates a new class
 502        -- @param baseclass The Baseclass of this class, or nil.
 503        -- @return A new class reference.
 504        function classes.class(baseclass)
 505            -- Create the class definition and metatable.
 506            local classdef = {}
 507            -- Find the super class, either Object or user-defined.
 508            baseclass = baseclass or classes.Object
 509            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 510            setmetatable(classdef, { __index = baseclass })
 511            -- All class instances have a reference to the class object.
 512            classdef.class = classdef
 513            --- Recursivly allocates the inheritance tree of the instance.
 514            -- @param mastertable The 'root' of the inheritance tree.
 515            -- @return Returns the instance with the allocated inheritance tree.
 516            function classdef.alloc(mastertable)
 517                -- All class instances have a reference to a superclass object.
 518                local instance = { super = baseclass.alloc(mastertable) }
 519                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 520                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 521                return instance
 522            end
 523        end
 524        "#.unindent();
 525
 526    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 527
 528    assert_documents_eq(
 529        &documents,
 530        &[
 531            (r#"
 532                -- Creates a new class
 533                -- @param baseclass The Baseclass of this class, or nil.
 534                -- @return A new class reference.
 535                function classes.class(baseclass)
 536                    -- Create the class definition and metatable.
 537                    local classdef = {}
 538                    -- Find the super class, either Object or user-defined.
 539                    baseclass = baseclass or classes.Object
 540                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 541                    setmetatable(classdef, { __index = baseclass })
 542                    -- All class instances have a reference to the class object.
 543                    classdef.class = classdef
 544                    --- Recursivly allocates the inheritance tree of the instance.
 545                    -- @param mastertable The 'root' of the inheritance tree.
 546                    -- @return Returns the instance with the allocated inheritance tree.
 547                    function classdef.alloc(mastertable)
 548                        --[ ... ]--
 549                        --[ ... ]--
 550                    end
 551                end"#.unindent(),
 552            114),
 553            (r#"
 554            --- Recursivly allocates the inheritance tree of the instance.
 555            -- @param mastertable The 'root' of the inheritance tree.
 556            -- @return Returns the instance with the allocated inheritance tree.
 557            function classdef.alloc(mastertable)
 558                -- All class instances have a reference to a superclass object.
 559                local instance = { super = baseclass.alloc(mastertable) }
 560                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 561                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 562                return instance
 563            end"#.unindent(), 809),
 564        ]
 565    );
 566}
 567
 568#[gpui::test]
 569async fn test_code_context_retrieval_elixir() {
 570    let language = elixir_lang();
 571    let mut retriever = CodeContextRetriever::new();
 572
 573    let text = r#"
 574        defmodule File.Stream do
 575            @moduledoc """
 576            Defines a `File.Stream` struct returned by `File.stream!/3`.
 577
 578            The following fields are public:
 579
 580            * `path`          - the file path
 581            * `modes`         - the file modes
 582            * `raw`           - a boolean indicating if bin functions should be used
 583            * `line_or_bytes` - if reading should read lines or a given number of bytes
 584            * `node`          - the node the file belongs to
 585
 586            """
 587
 588            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 589
 590            @type t :: %__MODULE__{}
 591
 592            @doc false
 593            def __build__(path, modes, line_or_bytes) do
 594            raw = :lists.keyfind(:encoding, 1, modes) == false
 595
 596            modes =
 597                case raw do
 598                true ->
 599                    case :lists.keyfind(:read_ahead, 1, modes) do
 600                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 601                    {:read_ahead, _} -> [:raw | modes]
 602                    false -> [:raw, :read_ahead | modes]
 603                    end
 604
 605                false ->
 606                    modes
 607                end
 608
 609            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 610
 611            end"#
 612    .unindent();
 613
 614    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 615
 616    assert_documents_eq(
 617        &documents,
 618        &[(
 619            r#"
 620        defmodule File.Stream do
 621            @moduledoc """
 622            Defines a `File.Stream` struct returned by `File.stream!/3`.
 623
 624            The following fields are public:
 625
 626            * `path`          - the file path
 627            * `modes`         - the file modes
 628            * `raw`           - a boolean indicating if bin functions should be used
 629            * `line_or_bytes` - if reading should read lines or a given number of bytes
 630            * `node`          - the node the file belongs to
 631
 632            """
 633
 634            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 635
 636            @type t :: %__MODULE__{}
 637
 638            @doc false
 639            def __build__(path, modes, line_or_bytes) do
 640            raw = :lists.keyfind(:encoding, 1, modes) == false
 641
 642            modes =
 643                case raw do
 644                true ->
 645                    case :lists.keyfind(:read_ahead, 1, modes) do
 646                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 647                    {:read_ahead, _} -> [:raw | modes]
 648                    false -> [:raw, :read_ahead | modes]
 649                    end
 650
 651                false ->
 652                    modes
 653                end
 654
 655            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 656
 657            end"#
 658                .unindent(),
 659            0,
 660        ),(r#"
 661            @doc false
 662            def __build__(path, modes, line_or_bytes) do
 663            raw = :lists.keyfind(:encoding, 1, modes) == false
 664
 665            modes =
 666                case raw do
 667                true ->
 668                    case :lists.keyfind(:read_ahead, 1, modes) do
 669                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 670                    {:read_ahead, _} -> [:raw | modes]
 671                    false -> [:raw, :read_ahead | modes]
 672                    end
 673
 674                false ->
 675                    modes
 676                end
 677
 678            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 679
 680            end"#.unindent(), 574)],
 681    );
 682}
 683
 684#[gpui::test]
 685async fn test_code_context_retrieval_cpp() {
 686    let language = cpp_lang();
 687    let mut retriever = CodeContextRetriever::new();
 688
 689    let text = "
 690    /**
 691     * @brief Main function
 692     * @returns 0 on exit
 693     */
 694    int main() { return 0; }
 695
 696    /**
 697    * This is a test comment
 698    */
 699    class MyClass {       // The class
 700        public:           // Access specifier
 701        int myNum;        // Attribute (int variable)
 702        string myString;  // Attribute (string variable)
 703    };
 704
 705    // This is a test comment
 706    enum Color { red, green, blue };
 707
 708    /** This is a preceding block comment
 709     * This is the second line
 710     */
 711    struct {           // Structure declaration
 712        int myNum;       // Member (int variable)
 713        string myString; // Member (string variable)
 714    } myStructure;
 715
 716    /**
 717     * @brief Matrix class.
 718     */
 719    template <typename T,
 720              typename = typename std::enable_if<
 721                std::is_integral<T>::value || std::is_floating_point<T>::value,
 722                bool>::type>
 723    class Matrix2 {
 724        std::vector<std::vector<T>> _mat;
 725
 726        public:
 727            /**
 728            * @brief Constructor
 729            * @tparam Integer ensuring integers are being evaluated and not other
 730            * data types.
 731            * @param size denoting the size of Matrix as size x size
 732            */
 733            template <typename Integer,
 734                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 735                    Integer>::type>
 736            explicit Matrix(const Integer size) {
 737                for (size_t i = 0; i < size; ++i) {
 738                    _mat.emplace_back(std::vector<T>(size, 0));
 739                }
 740            }
 741    }"
 742    .unindent();
 743
 744    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 745
 746    assert_documents_eq(
 747        &documents,
 748        &[
 749            (
 750                "
 751        /**
 752         * @brief Main function
 753         * @returns 0 on exit
 754         */
 755        int main() { return 0; }"
 756                    .unindent(),
 757                54,
 758            ),
 759            (
 760                "
 761                /**
 762                * This is a test comment
 763                */
 764                class MyClass {       // The class
 765                    public:           // Access specifier
 766                    int myNum;        // Attribute (int variable)
 767                    string myString;  // Attribute (string variable)
 768                }"
 769                .unindent(),
 770                112,
 771            ),
 772            (
 773                "
 774                // This is a test comment
 775                enum Color { red, green, blue }"
 776                    .unindent(),
 777                322,
 778            ),
 779            (
 780                "
 781                /** This is a preceding block comment
 782                 * This is the second line
 783                 */
 784                struct {           // Structure declaration
 785                    int myNum;       // Member (int variable)
 786                    string myString; // Member (string variable)
 787                } myStructure;"
 788                    .unindent(),
 789                425,
 790            ),
 791            (
 792                "
 793                /**
 794                 * @brief Matrix class.
 795                 */
 796                template <typename T,
 797                          typename = typename std::enable_if<
 798                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 799                            bool>::type>
 800                class Matrix2 {
 801                    std::vector<std::vector<T>> _mat;
 802
 803                    public:
 804                        /**
 805                        * @brief Constructor
 806                        * @tparam Integer ensuring integers are being evaluated and not other
 807                        * data types.
 808                        * @param size denoting the size of Matrix as size x size
 809                        */
 810                        template <typename Integer,
 811                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 812                                Integer>::type>
 813                        explicit Matrix(const Integer size) {
 814                            for (size_t i = 0; i < size; ++i) {
 815                                _mat.emplace_back(std::vector<T>(size, 0));
 816                            }
 817                        }
 818                }"
 819                .unindent(),
 820                612,
 821            ),
 822            (
 823                "
 824                explicit Matrix(const Integer size) {
 825                    for (size_t i = 0; i < size; ++i) {
 826                        _mat.emplace_back(std::vector<T>(size, 0));
 827                    }
 828                }"
 829                .unindent(),
 830                1226,
 831            ),
 832        ],
 833    );
 834}
 835
 836#[gpui::test]
 837async fn test_code_context_retrieval_ruby() {
 838    let language = ruby_lang();
 839    let mut retriever = CodeContextRetriever::new();
 840
 841    let text = r#"
 842        # This concern is inspired by "sudo mode" on GitHub. It
 843        # is a way to re-authenticate a user before allowing them
 844        # to see or perform an action.
 845        #
 846        # Add `before_action :require_challenge!` to actions you
 847        # want to protect.
 848        #
 849        # The user will be shown a page to enter the challenge (which
 850        # is either the password, or just the username when no
 851        # password exists). Upon passing, there is a grace period
 852        # during which no challenge will be asked from the user.
 853        #
 854        # Accessing challenge-protected resources during the grace
 855        # period will refresh the grace period.
 856        module ChallengableConcern
 857            extend ActiveSupport::Concern
 858
 859            CHALLENGE_TIMEOUT = 1.hour.freeze
 860
 861            def require_challenge!
 862                return if skip_challenge?
 863
 864                if challenge_passed_recently?
 865                    session[:challenge_passed_at] = Time.now.utc
 866                    return
 867                end
 868
 869                @challenge = Form::Challenge.new(return_to: request.url)
 870
 871                if params.key?(:form_challenge)
 872                    if challenge_passed?
 873                        session[:challenge_passed_at] = Time.now.utc
 874                    else
 875                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 876                        render_challenge
 877                    end
 878                else
 879                    render_challenge
 880                end
 881            end
 882
 883            def challenge_passed?
 884                current_user.valid_password?(challenge_params[:current_password])
 885            end
 886        end
 887
 888        class Animal
 889            include Comparable
 890
 891            attr_reader :legs
 892
 893            def initialize(name, legs)
 894                @name, @legs = name, legs
 895            end
 896
 897            def <=>(other)
 898                legs <=> other.legs
 899            end
 900        end
 901
 902        # Singleton method for car object
 903        def car.wheels
 904            puts "There are four wheels"
 905        end"#
 906        .unindent();
 907
 908    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 909
 910    assert_documents_eq(
 911        &documents,
 912        &[
 913            (
 914                r#"
 915        # This concern is inspired by "sudo mode" on GitHub. It
 916        # is a way to re-authenticate a user before allowing them
 917        # to see or perform an action.
 918        #
 919        # Add `before_action :require_challenge!` to actions you
 920        # want to protect.
 921        #
 922        # The user will be shown a page to enter the challenge (which
 923        # is either the password, or just the username when no
 924        # password exists). Upon passing, there is a grace period
 925        # during which no challenge will be asked from the user.
 926        #
 927        # Accessing challenge-protected resources during the grace
 928        # period will refresh the grace period.
 929        module ChallengableConcern
 930            extend ActiveSupport::Concern
 931
 932            CHALLENGE_TIMEOUT = 1.hour.freeze
 933
 934            def require_challenge!
 935                # ...
 936            end
 937
 938            def challenge_passed?
 939                # ...
 940            end
 941        end"#
 942                    .unindent(),
 943                558,
 944            ),
 945            (
 946                r#"
 947            def require_challenge!
 948                return if skip_challenge?
 949
 950                if challenge_passed_recently?
 951                    session[:challenge_passed_at] = Time.now.utc
 952                    return
 953                end
 954
 955                @challenge = Form::Challenge.new(return_to: request.url)
 956
 957                if params.key?(:form_challenge)
 958                    if challenge_passed?
 959                        session[:challenge_passed_at] = Time.now.utc
 960                    else
 961                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 962                        render_challenge
 963                    end
 964                else
 965                    render_challenge
 966                end
 967            end"#
 968                    .unindent(),
 969                663,
 970            ),
 971            (
 972                r#"
 973                def challenge_passed?
 974                    current_user.valid_password?(challenge_params[:current_password])
 975                end"#
 976                    .unindent(),
 977                1254,
 978            ),
 979            (
 980                r#"
 981                class Animal
 982                    include Comparable
 983
 984                    attr_reader :legs
 985
 986                    def initialize(name, legs)
 987                        # ...
 988                    end
 989
 990                    def <=>(other)
 991                        # ...
 992                    end
 993                end"#
 994                    .unindent(),
 995                1363,
 996            ),
 997            (
 998                r#"
 999                def initialize(name, legs)
1000                    @name, @legs = name, legs
1001                end"#
1002                    .unindent(),
1003                1427,
1004            ),
1005            (
1006                r#"
1007                def <=>(other)
1008                    legs <=> other.legs
1009                end"#
1010                    .unindent(),
1011                1501,
1012            ),
1013            (
1014                r#"
1015                # Singleton method for car object
1016                def car.wheels
1017                    puts "There are four wheels"
1018                end"#
1019                    .unindent(),
1020                1591,
1021            ),
1022        ],
1023    );
1024}
1025
/// Parses a PHP snippet with the `php_lang()` embedding query and checks the
/// documents that `CodeContextRetriever::parse_file` returns for it.
///
/// Each expected entry pairs the document text with a byte offset into `text`.
/// The offset points at the item itself rather than at a comment captured in
/// front of it: e.g. `123` is where `function functionName` begins, after the
/// block comment, and `245` is where `public function grantAchievement`
/// begins, after its doc comment.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let mut retriever = CodeContextRetriever::new();

    // Input covers one of each item kind the PHP query matches: a free
    // function, a trait with two methods, an interface, and an enum.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            // Free function, with the block comment above it included in the text.
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            // The trait as a whole: method bodies are replaced by the
            // `/* ... */` placeholder between the kept braces.
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            // First trait method on its own, with its doc comment and full body.
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            // Second trait method on its own, with its full body.
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            // Interface, reproduced verbatim.
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            // Enum, reproduced verbatim.
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1175
1176#[gpui::test]
1177fn test_dot_product(mut rng: StdRng) {
1178    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1179    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1180
1181    for _ in 0..100 {
1182        let size = 1536;
1183        let mut a = vec![0.; size];
1184        let mut b = vec![0.; size];
1185        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1186            *a = rng.gen();
1187            *b = rng.gen();
1188        }
1189
1190        assert_eq!(
1191            round_to_decimals(dot(&a, &b), 1),
1192            round_to_decimals(reference_dot(&a, &b), 1)
1193        );
1194    }
1195
1196    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1197        let factor = (10.0 as f32).powi(decimal_places);
1198        (n * factor).round() / factor
1199    }
1200
1201    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1202        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1203    }
1204}
1205
/// Embedding provider stand-in for tests that keeps count of the spans it is
/// asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Running total of spans passed to `embed_batch`.
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Returns the number of spans embedded so far.
    fn embedding_count(&self) -> usize {
        let counter = &self.embedding_count;
        counter.load(atomic::Ordering::SeqCst)
    }
}
1216
1217#[async_trait]
1218impl EmbeddingProvider for FakeEmbeddingProvider {
1219    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1220        self.embedding_count
1221            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1222        Ok(spans
1223            .iter()
1224            .map(|span| {
1225                let mut result = vec![1.0; 26];
1226                for letter in span.chars() {
1227                    let letter = letter.to_ascii_lowercase();
1228                    if letter as u32 >= 'a' as u32 {
1229                        let ix = (letter as u32) - ('a' as u32);
1230                        if ix < 26 {
1231                            result[ix as usize] += 1.0;
1232                        }
1233                    }
1234                }
1235
1236                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1237                for x in &mut result {
1238                    *x /= norm;
1239                }
1240
1241                result
1242            })
1243            .collect())
1244    }
1245}
1246
1247fn js_lang() -> Arc<Language> {
1248    Arc::new(
1249        Language::new(
1250            LanguageConfig {
1251                name: "Javascript".into(),
1252                path_suffixes: vec!["js".into()],
1253                ..Default::default()
1254            },
1255            Some(tree_sitter_typescript::language_tsx()),
1256        )
1257        .with_embedding_query(
1258            &r#"
1259
1260            (
1261                (comment)* @context
1262                .
1263                [
1264                (export_statement
1265                    (function_declaration
1266                        "async"? @name
1267                        "function" @name
1268                        name: (_) @name))
1269                (function_declaration
1270                    "async"? @name
1271                    "function" @name
1272                    name: (_) @name)
1273                ] @item
1274            )
1275
1276            (
1277                (comment)* @context
1278                .
1279                [
1280                (export_statement
1281                    (class_declaration
1282                        "class" @name
1283                        name: (_) @name))
1284                (class_declaration
1285                    "class" @name
1286                    name: (_) @name)
1287                ] @item
1288            )
1289
1290            (
1291                (comment)* @context
1292                .
1293                [
1294                (export_statement
1295                    (interface_declaration
1296                        "interface" @name
1297                        name: (_) @name))
1298                (interface_declaration
1299                    "interface" @name
1300                    name: (_) @name)
1301                ] @item
1302            )
1303
1304            (
1305                (comment)* @context
1306                .
1307                [
1308                (export_statement
1309                    (enum_declaration
1310                        "enum" @name
1311                        name: (_) @name))
1312                (enum_declaration
1313                    "enum" @name
1314                    name: (_) @name)
1315                ] @item
1316            )
1317
1318            (
1319                (comment)* @context
1320                .
1321                (method_definition
1322                    [
1323                        "get"
1324                        "set"
1325                        "async"
1326                        "*"
1327                        "static"
1328                    ]* @name
1329                    name: (_) @name) @item
1330            )
1331
1332                    "#
1333            .unindent(),
1334        )
1335        .unwrap(),
1336    )
1337}
1338
1339fn rust_lang() -> Arc<Language> {
1340    Arc::new(
1341        Language::new(
1342            LanguageConfig {
1343                name: "Rust".into(),
1344                path_suffixes: vec!["rs".into()],
1345                collapsed_placeholder: " /* ... */ ".to_string(),
1346                ..Default::default()
1347            },
1348            Some(tree_sitter_rust::language()),
1349        )
1350        .with_embedding_query(
1351            r#"
1352            (
1353                [(line_comment) (attribute_item)]* @context
1354                .
1355                [
1356                    (struct_item
1357                        name: (_) @name)
1358
1359                    (enum_item
1360                        name: (_) @name)
1361
1362                    (impl_item
1363                        trait: (_)? @name
1364                        "for"? @name
1365                        type: (_) @name)
1366
1367                    (trait_item
1368                        name: (_) @name)
1369
1370                    (function_item
1371                        name: (_) @name
1372                        body: (block
1373                            "{" @keep
1374                            "}" @keep) @collapse)
1375
1376                    (macro_definition
1377                        name: (_) @name)
1378                ] @item
1379            )
1380            "#,
1381        )
1382        .unwrap(),
1383    )
1384}
1385
1386fn json_lang() -> Arc<Language> {
1387    Arc::new(
1388        Language::new(
1389            LanguageConfig {
1390                name: "JSON".into(),
1391                path_suffixes: vec!["json".into()],
1392                ..Default::default()
1393            },
1394            Some(tree_sitter_json::language()),
1395        )
1396        .with_embedding_query(
1397            r#"
1398            (document) @item
1399
1400            (array
1401                "[" @keep
1402                .
1403                (object)? @keep
1404                "]" @keep) @collapse
1405
1406            (pair value: (string
1407                "\"" @keep
1408                "\"" @keep) @collapse)
1409            "#,
1410        )
1411        .unwrap(),
1412    )
1413}
1414
1415fn toml_lang() -> Arc<Language> {
1416    Arc::new(Language::new(
1417        LanguageConfig {
1418            name: "TOML".into(),
1419            path_suffixes: vec!["toml".into()],
1420            ..Default::default()
1421        },
1422        Some(tree_sitter_toml::language()),
1423    ))
1424}
1425
1426fn cpp_lang() -> Arc<Language> {
1427    Arc::new(
1428        Language::new(
1429            LanguageConfig {
1430                name: "CPP".into(),
1431                path_suffixes: vec!["cpp".into()],
1432                ..Default::default()
1433            },
1434            Some(tree_sitter_cpp::language()),
1435        )
1436        .with_embedding_query(
1437            r#"
1438            (
1439                (comment)* @context
1440                .
1441                (function_definition
1442                    (type_qualifier)? @name
1443                    type: (_)? @name
1444                    declarator: [
1445                        (function_declarator
1446                            declarator: (_) @name)
1447                        (pointer_declarator
1448                            "*" @name
1449                            declarator: (function_declarator
1450                            declarator: (_) @name))
1451                        (pointer_declarator
1452                            "*" @name
1453                            declarator: (pointer_declarator
1454                                "*" @name
1455                            declarator: (function_declarator
1456                                declarator: (_) @name)))
1457                        (reference_declarator
1458                            ["&" "&&"] @name
1459                            (function_declarator
1460                            declarator: (_) @name))
1461                    ]
1462                    (type_qualifier)? @name) @item
1463                )
1464
1465            (
1466                (comment)* @context
1467                .
1468                (template_declaration
1469                    (class_specifier
1470                        "class" @name
1471                        name: (_) @name)
1472                        ) @item
1473            )
1474
1475            (
1476                (comment)* @context
1477                .
1478                (class_specifier
1479                    "class" @name
1480                    name: (_) @name) @item
1481                )
1482
1483            (
1484                (comment)* @context
1485                .
1486                (enum_specifier
1487                    "enum" @name
1488                    name: (_) @name) @item
1489                )
1490
1491            (
1492                (comment)* @context
1493                .
1494                (declaration
1495                    type: (struct_specifier
1496                    "struct" @name)
1497                    declarator: (_) @name) @item
1498            )
1499
1500            "#,
1501        )
1502        .unwrap(),
1503    )
1504}
1505
1506fn lua_lang() -> Arc<Language> {
1507    Arc::new(
1508        Language::new(
1509            LanguageConfig {
1510                name: "Lua".into(),
1511                path_suffixes: vec!["lua".into()],
1512                collapsed_placeholder: "--[ ... ]--".to_string(),
1513                ..Default::default()
1514            },
1515            Some(tree_sitter_lua::language()),
1516        )
1517        .with_embedding_query(
1518            r#"
1519            (
1520                (comment)* @context
1521                .
1522                (function_declaration
1523                    "function" @name
1524                    name: (_) @name
1525                    (comment)* @collapse
1526                    body: (block) @collapse
1527                ) @item
1528            )
1529        "#,
1530        )
1531        .unwrap(),
1532    )
1533}
1534
1535fn php_lang() -> Arc<Language> {
1536    Arc::new(
1537        Language::new(
1538            LanguageConfig {
1539                name: "PHP".into(),
1540                path_suffixes: vec!["php".into()],
1541                collapsed_placeholder: "/* ... */".into(),
1542                ..Default::default()
1543            },
1544            Some(tree_sitter_php::language()),
1545        )
1546        .with_embedding_query(
1547            r#"
1548            (
1549                (comment)* @context
1550                .
1551                [
1552                    (function_definition
1553                        "function" @name
1554                        name: (_) @name
1555                        body: (_
1556                            "{" @keep
1557                            "}" @keep) @collapse
1558                        )
1559
1560                    (trait_declaration
1561                        "trait" @name
1562                        name: (_) @name)
1563
1564                    (method_declaration
1565                        "function" @name
1566                        name: (_) @name
1567                        body: (_
1568                            "{" @keep
1569                            "}" @keep) @collapse
1570                        )
1571
1572                    (interface_declaration
1573                        "interface" @name
1574                        name: (_) @name
1575                        )
1576
1577                    (enum_declaration
1578                        "enum" @name
1579                        name: (_) @name
1580                        )
1581
1582                ] @item
1583            )
1584            "#,
1585        )
1586        .unwrap(),
1587    )
1588}
1589
1590fn ruby_lang() -> Arc<Language> {
1591    Arc::new(
1592        Language::new(
1593            LanguageConfig {
1594                name: "Ruby".into(),
1595                path_suffixes: vec!["rb".into()],
1596                collapsed_placeholder: "# ...".to_string(),
1597                ..Default::default()
1598            },
1599            Some(tree_sitter_ruby::language()),
1600        )
1601        .with_embedding_query(
1602            r#"
1603            (
1604                (comment)* @context
1605                .
1606                [
1607                (module
1608                    "module" @name
1609                    name: (_) @name)
1610                (method
1611                    "def" @name
1612                    name: (_) @name
1613                    body: (body_statement) @collapse)
1614                (class
1615                    "class" @name
1616                    name: (_) @name)
1617                (singleton_method
1618                    "def" @name
1619                    object: (_) @name
1620                    "." @name
1621                    name: (_) @name
1622                    body: (body_statement) @collapse)
1623                ] @item
1624            )
1625            "#,
1626        )
1627        .unwrap(),
1628    )
1629}
1630
1631fn elixir_lang() -> Arc<Language> {
1632    Arc::new(
1633        Language::new(
1634            LanguageConfig {
1635                name: "Elixir".into(),
1636                path_suffixes: vec!["rs".into()],
1637                ..Default::default()
1638            },
1639            Some(tree_sitter_elixir::language()),
1640        )
1641        .with_embedding_query(
1642            r#"
1643            (
1644                (unary_operator
1645                    operator: "@"
1646                    operand: (call
1647                        target: (identifier) @unary
1648                        (#match? @unary "^(doc)$"))
1649                    ) @context
1650                .
1651                (call
1652                target: (identifier) @name
1653                (arguments
1654                [
1655                (identifier) @name
1656                (call
1657                target: (identifier) @name)
1658                (binary_operator
1659                left: (call
1660                target: (identifier) @name)
1661                operator: "when")
1662                ])
1663                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1664                )
1665
1666            (call
1667                target: (identifier) @name
1668                (arguments (alias) @name)
1669                (#match? @name "^(defmodule|defprotocol)$")) @item
1670            "#,
1671        )
1672        .unwrap(),
1673    )
1674}
1675
1676#[gpui::test]
1677fn test_subtract_ranges() {
1678    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1679
1680    assert_eq!(
1681        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1682        vec![1..4, 10..21]
1683    );
1684
1685    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1686}