// semantic_index_tests.rs

   1use crate::{
   2    embedding_queue::EmbeddingQueue,
   3    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   4    semantic_index_settings::SemanticIndexSettings,
   5    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   6};
   7use ai::test::FakeEmbeddingProvider;
   8use gpui::TestAppContext;
   9use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
  10use parking_lot::Mutex;
  11use pretty_assertions::assert_eq;
  12use project::{FakeFs, Fs, Project};
  13use rand::{rngs::StdRng, Rng};
  14use serde_json::json;
  15use settings::{Settings, SettingsStore};
  16use std::{path::Path, sync::Arc, time::SystemTime};
  17use unindent::Unindent;
  18use util::{paths::PathMatcher, RandomCharIter};
  19
  20#[ctor::ctor]
  21fn init_logger() {
  22    if std::env::var("RUST_LOG").is_ok() {
  23        env_logger::init();
  24    }
  25}
  26
/// End-to-end check of `SemanticIndex` over a fake three-file project.
///
/// The first search indexes the project, and the pending-file counter drains once
/// the embedding queue's flush timeout elapses. Include/exclude path filters
/// narrow the results. Re-indexing after editing one file requests exactly one
/// new embedding.
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
    init_test(cx);

    // Fake project: two Rust files with two items each, plus one TOML file.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaaaaaaaaaa!\");
                    }

                    fn zzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbbbbbbbbbbb!\");
                    }
                    struct pqpqpqp {}
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZZZZZZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    // Register the languages the fixture files need so they can be parsed.
    let languages = Arc::new(LanguageRegistry::test(cx.executor().clone()));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    // Back the index with a throwaway on-disk database. `db_dir` must stay in
    // scope for the whole test, because `TempDir` removes the directory on drop.
    let db_dir = tempfile::Builder::new()
        .prefix("vector-store")
        .tempdir()
        .unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    // Keep a handle to the fake provider so embedding counts can be inspected later.
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let semantic_index = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

    // Start a search without awaiting it yet, so indexing progress can be observed
    // through the pending-file counter first.
    let search_results = semantic_index.update(cx, |store, cx| {
        store.search_project(
            project.clone(),
            "aaaaaabbbbzz".to_string(),
            5,
            vec![],
            vec![],
            cx,
        )
    });
    let pending_file_count =
        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
    // Once background work parks, all three files are pending. Advancing the clock
    // by the queue's flush timeout lets the queue flush, and the counter drops to 0.
    cx.background_executor.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 3);
    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);

    // Expected (path, start offset) pairs. Offset 0 is each file's first item, and
    // offset 45 is where the second item of each Rust file begins.
    let search_results = search_results.await.unwrap();
    assert_search_results(
        &search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file3.toml").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    // Test Include Files Functionality
    // The same `*.rs` pattern is used both as an include filter (Rust results only)
    // and as an exclude filter (only the TOML result remains).
    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
    let rust_only_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                include_files,
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &rust_only_search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    let no_rust_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                exclude_files,
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &no_rust_search_results,
        &[(Path::new("src/file3.toml").into(), 0)],
        cx,
    );

    // Edit file2.rs: `fn bbb` is replaced by `fn dddd`, while `struct pqpqpqp {}`
    // is left unchanged.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();

    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

    // Re-index. Only the edited file should become pending, and exactly one new
    // embedding should be requested. Presumably the unchanged struct span is reused
    // via its digest; NOTE(review): confirm against `SemanticIndex`.
    let prev_embedding_count = embedding_provider.embedding_count();
    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
    cx.background_executor.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 1);
    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);
    index.await.unwrap();

    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        1
    );
}
 193
 194#[gpui::test(iterations = 10)]
 195async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 196    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 197    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 198
 199    let files = (1..=3)
 200        .map(|file_ix| FileToEmbed {
 201            worktree_id: 5,
 202            path: Path::new(&format!("path-{file_ix}")).into(),
 203            mtime: SystemTime::now(),
 204            spans: (0..rng.gen_range(4..22))
 205                .map(|document_ix| {
 206                    let content_len = rng.gen_range(10..100);
 207                    let content = RandomCharIter::new(&mut rng)
 208                        .with_simple_text()
 209                        .take(content_len)
 210                        .collect::<String>();
 211                    let digest = SpanDigest::from(content.as_str());
 212                    Span {
 213                        range: 0..10,
 214                        embedding: None,
 215                        name: format!("document {document_ix}"),
 216                        content,
 217                        digest,
 218                        token_count: rng.gen_range(10..30),
 219                    }
 220                })
 221                .collect(),
 222            job_handle: JobHandle::new(&outstanding_job_count),
 223        })
 224        .collect::<Vec<_>>();
 225
 226    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 227
 228    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
 229    for file in &files {
 230        queue.push(file.clone());
 231    }
 232    queue.flush();
 233
 234    cx.background_executor.run_until_parked();
 235    let finished_files = queue.finished_files();
 236    let mut embedded_files: Vec<_> = files
 237        .iter()
 238        .map(|_| finished_files.try_recv().expect("no finished file"))
 239        .collect();
 240
 241    let expected_files: Vec<_> = files
 242        .iter()
 243        .map(|file| {
 244            let mut file = file.clone();
 245            for doc in &mut file.spans {
 246                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 247            }
 248            file
 249        })
 250        .collect();
 251
 252    embedded_files.sort_by_key(|f| f.path.clone());
 253
 254    assert_eq!(embedded_files, expected_files);
 255}
 256
 257#[track_caller]
 258fn assert_search_results(
 259    actual: &[SearchResult],
 260    expected: &[(Arc<Path>, usize)],
 261    cx: &TestAppContext,
 262) {
 263    let actual = actual
 264        .iter()
 265        .map(|search_result| {
 266            search_result.buffer.read_with(cx, |buffer, _cx| {
 267                (
 268                    buffer.file().unwrap().path().clone(),
 269                    search_result.range.start.to_offset(buffer),
 270                )
 271            })
 272        })
 273        .collect::<Vec<_>>();
 274    assert_eq!(actual, expected);
 275}
 276
 277#[gpui::test]
 278async fn test_code_context_retrieval_rust() {
 279    let language = rust_lang();
 280    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 281    let mut retriever = CodeContextRetriever::new(embedding_provider);
 282
 283    let text = "
 284        /// A doc comment
 285        /// that spans multiple lines
 286        #[gpui::test]
 287        fn a() {
 288            b
 289        }
 290
 291        impl C for D {
 292        }
 293
 294        impl E {
 295            // This is also a preceding comment
 296            pub fn function_1() -> Option<()> {
 297                unimplemented!();
 298            }
 299
 300            // This is a preceding comment
 301            fn function_2() -> Result<()> {
 302                unimplemented!();
 303            }
 304        }
 305
 306        #[derive(Clone)]
 307        struct D {
 308            name: String
 309        }
 310    "
 311    .unindent();
 312
 313    let documents = retriever.parse_file(&text, language).unwrap();
 314
 315    assert_documents_eq(
 316        &documents,
 317        &[
 318            (
 319                "
 320                /// A doc comment
 321                /// that spans multiple lines
 322                #[gpui::test]
 323                fn a() {
 324                    b
 325                }"
 326                .unindent(),
 327                text.find("fn a").unwrap(),
 328            ),
 329            (
 330                "
 331                impl C for D {
 332                }"
 333                .unindent(),
 334                text.find("impl C").unwrap(),
 335            ),
 336            (
 337                "
 338                impl E {
 339                    // This is also a preceding comment
 340                    pub fn function_1() -> Option<()> { /* ... */ }
 341
 342                    // This is a preceding comment
 343                    fn function_2() -> Result<()> { /* ... */ }
 344                }"
 345                .unindent(),
 346                text.find("impl E").unwrap(),
 347            ),
 348            (
 349                "
 350                // This is also a preceding comment
 351                pub fn function_1() -> Option<()> {
 352                    unimplemented!();
 353                }"
 354                .unindent(),
 355                text.find("pub fn function_1").unwrap(),
 356            ),
 357            (
 358                "
 359                // This is a preceding comment
 360                fn function_2() -> Result<()> {
 361                    unimplemented!();
 362                }"
 363                .unindent(),
 364                text.find("fn function_2").unwrap(),
 365            ),
 366            (
 367                "
 368                #[derive(Clone)]
 369                struct D {
 370                    name: String
 371                }"
 372                .unindent(),
 373                text.find("struct D").unwrap(),
 374            ),
 375        ],
 376    );
 377}
 378
 379#[gpui::test]
 380async fn test_code_context_retrieval_json() {
 381    let language = json_lang();
 382    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 383    let mut retriever = CodeContextRetriever::new(embedding_provider);
 384
 385    let text = r#"
 386        {
 387            "array": [1, 2, 3, 4],
 388            "string": "abcdefg",
 389            "nested_object": {
 390                "array_2": [5, 6, 7, 8],
 391                "string_2": "hijklmnop",
 392                "boolean": true,
 393                "none": null
 394            }
 395        }
 396    "#
 397    .unindent();
 398
 399    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 400
 401    assert_documents_eq(
 402        &documents,
 403        &[(
 404            r#"
 405                {
 406                    "array": [],
 407                    "string": "",
 408                    "nested_object": {
 409                        "array_2": [],
 410                        "string_2": "",
 411                        "boolean": true,
 412                        "none": null
 413                    }
 414                }"#
 415            .unindent(),
 416            text.find('{').unwrap(),
 417        )],
 418    );
 419
 420    let text = r#"
 421        [
 422            {
 423                "name": "somebody",
 424                "age": 42
 425            },
 426            {
 427                "name": "somebody else",
 428                "age": 43
 429            }
 430        ]
 431    "#
 432    .unindent();
 433
 434    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 435
 436    assert_documents_eq(
 437        &documents,
 438        &[(
 439            r#"
 440            [{
 441                    "name": "",
 442                    "age": 42
 443                }]"#
 444            .unindent(),
 445            text.find('[').unwrap(),
 446        )],
 447    );
 448}
 449
 450fn assert_documents_eq(
 451    documents: &[Span],
 452    expected_contents_and_start_offsets: &[(String, usize)],
 453) {
 454    assert_eq!(
 455        documents
 456            .iter()
 457            .map(|document| (document.content.clone(), document.range.start))
 458            .collect::<Vec<_>>(),
 459        expected_contents_and_start_offsets
 460    );
 461}
 462
 463#[gpui::test]
 464async fn test_code_context_retrieval_javascript() {
 465    let language = js_lang();
 466    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 467    let mut retriever = CodeContextRetriever::new(embedding_provider);
 468
 469    let text = "
 470        /* globals importScripts, backend */
 471        function _authorize() {}
 472
 473        /**
 474         * Sometimes the frontend build is way faster than backend.
 475         */
 476        export async function authorizeBank() {
 477            _authorize(pushModal, upgradingAccountId, {});
 478        }
 479
 480        export class SettingsPage {
 481            /* This is a test setting */
 482            constructor(page) {
 483                this.page = page;
 484            }
 485        }
 486
 487        /* This is a test comment */
 488        class TestClass {}
 489
 490        /* Schema for editor_events in Clickhouse. */
 491        export interface ClickhouseEditorEvent {
 492            installation_id: string
 493            operation: string
 494        }
 495        "
 496    .unindent();
 497
 498    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 499
 500    assert_documents_eq(
 501        &documents,
 502        &[
 503            (
 504                "
 505            /* globals importScripts, backend */
 506            function _authorize() {}"
 507                    .unindent(),
 508                37,
 509            ),
 510            (
 511                "
 512            /**
 513             * Sometimes the frontend build is way faster than backend.
 514             */
 515            export async function authorizeBank() {
 516                _authorize(pushModal, upgradingAccountId, {});
 517            }"
 518                .unindent(),
 519                131,
 520            ),
 521            (
 522                "
 523                export class SettingsPage {
 524                    /* This is a test setting */
 525                    constructor(page) {
 526                        this.page = page;
 527                    }
 528                }"
 529                .unindent(),
 530                225,
 531            ),
 532            (
 533                "
 534                /* This is a test setting */
 535                constructor(page) {
 536                    this.page = page;
 537                }"
 538                .unindent(),
 539                290,
 540            ),
 541            (
 542                "
 543                /* This is a test comment */
 544                class TestClass {}"
 545                    .unindent(),
 546                374,
 547            ),
 548            (
 549                "
 550                /* Schema for editor_events in Clickhouse. */
 551                export interface ClickhouseEditorEvent {
 552                    installation_id: string
 553                    operation: string
 554                }"
 555                .unindent(),
 556                440,
 557            ),
 558        ],
 559    )
 560}
 561
 562#[gpui::test]
 563async fn test_code_context_retrieval_lua() {
 564    let language = lua_lang();
 565    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 566    let mut retriever = CodeContextRetriever::new(embedding_provider);
 567
 568    let text = r#"
 569        -- Creates a new class
 570        -- @param baseclass The Baseclass of this class, or nil.
 571        -- @return A new class reference.
 572        function classes.class(baseclass)
 573            -- Create the class definition and metatable.
 574            local classdef = {}
 575            -- Find the super class, either Object or user-defined.
 576            baseclass = baseclass or classes.Object
 577            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 578            setmetatable(classdef, { __index = baseclass })
 579            -- All class instances have a reference to the class object.
 580            classdef.class = classdef
 581            --- Recursively allocates the inheritance tree of the instance.
 582            -- @param mastertable The 'root' of the inheritance tree.
 583            -- @return Returns the instance with the allocated inheritance tree.
 584            function classdef.alloc(mastertable)
 585                -- All class instances have a reference to a superclass object.
 586                local instance = { super = baseclass.alloc(mastertable) }
 587                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 588                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 589                return instance
 590            end
 591        end
 592        "#.unindent();
 593
 594    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 595
 596    assert_documents_eq(
 597        &documents,
 598        &[
 599            (r#"
 600                -- Creates a new class
 601                -- @param baseclass The Baseclass of this class, or nil.
 602                -- @return A new class reference.
 603                function classes.class(baseclass)
 604                    -- Create the class definition and metatable.
 605                    local classdef = {}
 606                    -- Find the super class, either Object or user-defined.
 607                    baseclass = baseclass or classes.Object
 608                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 609                    setmetatable(classdef, { __index = baseclass })
 610                    -- All class instances have a reference to the class object.
 611                    classdef.class = classdef
 612                    --- Recursively allocates the inheritance tree of the instance.
 613                    -- @param mastertable The 'root' of the inheritance tree.
 614                    -- @return Returns the instance with the allocated inheritance tree.
 615                    function classdef.alloc(mastertable)
 616                        --[ ... ]--
 617                        --[ ... ]--
 618                    end
 619                end"#.unindent(),
 620            114),
 621            (r#"
 622            --- Recursively allocates the inheritance tree of the instance.
 623            -- @param mastertable The 'root' of the inheritance tree.
 624            -- @return Returns the instance with the allocated inheritance tree.
 625            function classdef.alloc(mastertable)
 626                -- All class instances have a reference to a superclass object.
 627                local instance = { super = baseclass.alloc(mastertable) }
 628                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 629                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 630                return instance
 631            end"#.unindent(), 810),
 632        ]
 633    );
 634}
 635
 636#[gpui::test]
 637async fn test_code_context_retrieval_elixir() {
 638    let language = elixir_lang();
 639    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 640    let mut retriever = CodeContextRetriever::new(embedding_provider);
 641
 642    let text = r#"
 643        defmodule File.Stream do
 644            @moduledoc """
 645            Defines a `File.Stream` struct returned by `File.stream!/3`.
 646
 647            The following fields are public:
 648
 649            * `path`          - the file path
 650            * `modes`         - the file modes
 651            * `raw`           - a boolean indicating if bin functions should be used
 652            * `line_or_bytes` - if reading should read lines or a given number of bytes
 653            * `node`          - the node the file belongs to
 654
 655            """
 656
 657            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 658
 659            @type t :: %__MODULE__{}
 660
 661            @doc false
 662            def __build__(path, modes, line_or_bytes) do
 663            raw = :lists.keyfind(:encoding, 1, modes) == false
 664
 665            modes =
 666                case raw do
 667                true ->
 668                    case :lists.keyfind(:read_ahead, 1, modes) do
 669                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 670                    {:read_ahead, _} -> [:raw | modes]
 671                    false -> [:raw, :read_ahead | modes]
 672                    end
 673
 674                false ->
 675                    modes
 676                end
 677
 678            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 679
 680            end"#
 681    .unindent();
 682
 683    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 684
 685    assert_documents_eq(
 686        &documents,
 687        &[(
 688            r#"
 689        defmodule File.Stream do
 690            @moduledoc """
 691            Defines a `File.Stream` struct returned by `File.stream!/3`.
 692
 693            The following fields are public:
 694
 695            * `path`          - the file path
 696            * `modes`         - the file modes
 697            * `raw`           - a boolean indicating if bin functions should be used
 698            * `line_or_bytes` - if reading should read lines or a given number of bytes
 699            * `node`          - the node the file belongs to
 700
 701            """
 702
 703            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 704
 705            @type t :: %__MODULE__{}
 706
 707            @doc false
 708            def __build__(path, modes, line_or_bytes) do
 709            raw = :lists.keyfind(:encoding, 1, modes) == false
 710
 711            modes =
 712                case raw do
 713                true ->
 714                    case :lists.keyfind(:read_ahead, 1, modes) do
 715                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 716                    {:read_ahead, _} -> [:raw | modes]
 717                    false -> [:raw, :read_ahead | modes]
 718                    end
 719
 720                false ->
 721                    modes
 722                end
 723
 724            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 725
 726            end"#
 727                .unindent(),
 728            0,
 729        ),(r#"
 730            @doc false
 731            def __build__(path, modes, line_or_bytes) do
 732            raw = :lists.keyfind(:encoding, 1, modes) == false
 733
 734            modes =
 735                case raw do
 736                true ->
 737                    case :lists.keyfind(:read_ahead, 1, modes) do
 738                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 739                    {:read_ahead, _} -> [:raw | modes]
 740                    false -> [:raw, :read_ahead | modes]
 741                    end
 742
 743                false ->
 744                    modes
 745                end
 746
 747            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 748
 749            end"#.unindent(), 574)],
 750    );
 751}
 752
 753#[gpui::test]
 754async fn test_code_context_retrieval_cpp() {
 755    let language = cpp_lang();
 756    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 757    let mut retriever = CodeContextRetriever::new(embedding_provider);
 758
 759    let text = "
 760    /**
 761     * @brief Main function
 762     * @returns 0 on exit
 763     */
 764    int main() { return 0; }
 765
 766    /**
 767    * This is a test comment
 768    */
 769    class MyClass {       // The class
 770        public:           // Access specifier
 771        int myNum;        // Attribute (int variable)
 772        string myString;  // Attribute (string variable)
 773    };
 774
 775    // This is a test comment
 776    enum Color { red, green, blue };
 777
 778    /** This is a preceding block comment
 779     * This is the second line
 780     */
 781    struct {           // Structure declaration
 782        int myNum;       // Member (int variable)
 783        string myString; // Member (string variable)
 784    } myStructure;
 785
 786    /**
 787     * @brief Matrix class.
 788     */
 789    template <typename T,
 790              typename = typename std::enable_if<
 791                std::is_integral<T>::value || std::is_floating_point<T>::value,
 792                bool>::type>
 793    class Matrix2 {
 794        std::vector<std::vector<T>> _mat;
 795
 796        public:
 797            /**
 798            * @brief Constructor
 799            * @tparam Integer ensuring integers are being evaluated and not other
 800            * data types.
 801            * @param size denoting the size of Matrix as size x size
 802            */
 803            template <typename Integer,
 804                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 805                    Integer>::type>
 806            explicit Matrix(const Integer size) {
 807                for (size_t i = 0; i < size; ++i) {
 808                    _mat.emplace_back(std::vector<T>(size, 0));
 809                }
 810            }
 811    }"
 812    .unindent();
 813
 814    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 815
 816    assert_documents_eq(
 817        &documents,
 818        &[
 819            (
 820                "
 821        /**
 822         * @brief Main function
 823         * @returns 0 on exit
 824         */
 825        int main() { return 0; }"
 826                    .unindent(),
 827                54,
 828            ),
 829            (
 830                "
 831                /**
 832                * This is a test comment
 833                */
 834                class MyClass {       // The class
 835                    public:           // Access specifier
 836                    int myNum;        // Attribute (int variable)
 837                    string myString;  // Attribute (string variable)
 838                }"
 839                .unindent(),
 840                112,
 841            ),
 842            (
 843                "
 844                // This is a test comment
 845                enum Color { red, green, blue }"
 846                    .unindent(),
 847                322,
 848            ),
 849            (
 850                "
 851                /** This is a preceding block comment
 852                 * This is the second line
 853                 */
 854                struct {           // Structure declaration
 855                    int myNum;       // Member (int variable)
 856                    string myString; // Member (string variable)
 857                } myStructure;"
 858                    .unindent(),
 859                425,
 860            ),
 861            (
 862                "
 863                /**
 864                 * @brief Matrix class.
 865                 */
 866                template <typename T,
 867                          typename = typename std::enable_if<
 868                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 869                            bool>::type>
 870                class Matrix2 {
 871                    std::vector<std::vector<T>> _mat;
 872
 873                    public:
 874                        /**
 875                        * @brief Constructor
 876                        * @tparam Integer ensuring integers are being evaluated and not other
 877                        * data types.
 878                        * @param size denoting the size of Matrix as size x size
 879                        */
 880                        template <typename Integer,
 881                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 882                                Integer>::type>
 883                        explicit Matrix(const Integer size) {
 884                            for (size_t i = 0; i < size; ++i) {
 885                                _mat.emplace_back(std::vector<T>(size, 0));
 886                            }
 887                        }
 888                }"
 889                .unindent(),
 890                612,
 891            ),
 892            (
 893                "
 894                explicit Matrix(const Integer size) {
 895                    for (size_t i = 0; i < size; ++i) {
 896                        _mat.emplace_back(std::vector<T>(size, 0));
 897                    }
 898                }"
 899                .unindent(),
 900                1226,
 901            ),
 902        ],
 903    );
 904}
 905
// Parses a Ruby source with `ruby_lang()` and checks the documents the
// retriever extracts: the module and the class (with method bodies replaced
// by the `# ...` placeholder), each method in full, and a singleton method.
// Leading `#` comments are expected to be included with the item they precede.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
    let language = ruby_lang();
    // Fake provider: this test only exercises `parse_file`, never embedding.
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // The fixture is unindented before parsing, so the expected offsets below
    // refer to positions in the unindented text.
    let text = r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end

            def challenge_passed?
                current_user.valid_password?(challenge_params[:current_password])
            end
        end

        class Animal
            include Comparable

            attr_reader :legs

            def initialize(name, legs)
                @name, @legs = name, legs
            end

            def <=>(other)
                legs <=> other.legs
            end
        end

        # Singleton method for car object
        def car.wheels
            puts "There are four wheels"
        end"#
        .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each entry is (expected document content, expected offset). The offset
    // appears to be where the matched item itself begins in `text`, after any
    // leading comment context: 663 - 558 is exactly the distance from
    // `module ChallengableConcern` to `def require_challenge!`.
    // TODO(review): confirm against `assert_documents_eq`.
    assert_documents_eq(
        &documents,
        &[
            // The module with its leading comment block; method bodies are collapsed.
            (
                r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                # ...
            end

            def challenge_passed?
                # ...
            end
        end"#
                    .unindent(),
                558,
            ),
            // The module's methods, each in full.
            (
                r#"
            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end"#
                    .unindent(),
                663,
            ),
            (
                r#"
                def challenge_passed?
                    current_user.valid_password?(challenge_params[:current_password])
                end"#
                    .unindent(),
                1254,
            ),
            // The class with method bodies collapsed.
            (
                r#"
                class Animal
                    include Comparable

                    attr_reader :legs

                    def initialize(name, legs)
                        # ...
                    end

                    def <=>(other)
                        # ...
                    end
                end"#
                    .unindent(),
                1363,
            ),
            // The class's methods, each in full.
            (
                r#"
                def initialize(name, legs)
                    @name, @legs = name, legs
                end"#
                    .unindent(),
                1427,
            ),
            (
                r#"
                def <=>(other)
                    legs <=> other.legs
                end"#
                    .unindent(),
                1501,
            ),
            // The top-level singleton method together with its leading comment.
            (
                r#"
                # Singleton method for car object
                def car.wheels
                    puts "There are four wheels"
                end"#
                    .unindent(),
                1591,
            ),
        ],
    );
}
1096
// Parses a PHP source with `php_lang()` and checks the documents the retriever
// extracts: a free function, a trait (method bodies collapsed to `{/* ... */}`),
// each trait method in full, an interface and an enum. Comments immediately
// preceding an item are expected to be included in its document.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    // Fake provider: this test only exercises `parse_file`, never embedding.
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // The fixture is unindented before parsing, so the expected offsets below
    // refer to positions in the unindented text.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each entry is (expected document content, expected offset). The offset
    // is where the matched item itself begins in `text`, after any leading
    // comment context. For example, 123 is the start of `function functionName()`
    // (after its `/* ... */` block), and 245 is the start of
    // `public function grantAchievement` (after its `/** ... */` docblock).
    assert_documents_eq(
        &documents,
        &[
            // Free function with its preceding block comment.
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            // Trait with method bodies collapsed to the placeholder.
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            // Trait methods in full (the first keeps its docblock as context).
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            // Interface and enum are captured whole.
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1247
1248fn js_lang() -> Arc<Language> {
1249    Arc::new(
1250        Language::new(
1251            LanguageConfig {
1252                name: "Javascript".into(),
1253                matcher: LanguageMatcher {
1254                    path_suffixes: vec!["js".into()],
1255                    ..Default::default()
1256                },
1257                ..Default::default()
1258            },
1259            Some(tree_sitter_typescript::language_tsx()),
1260        )
1261        .with_embedding_query(
1262            &r#"
1263
1264            (
1265                (comment)* @context
1266                .
1267                [
1268                (export_statement
1269                    (function_declaration
1270                        "async"? @name
1271                        "function" @name
1272                        name: (_) @name))
1273                (function_declaration
1274                    "async"? @name
1275                    "function" @name
1276                    name: (_) @name)
1277                ] @item
1278            )
1279
1280            (
1281                (comment)* @context
1282                .
1283                [
1284                (export_statement
1285                    (class_declaration
1286                        "class" @name
1287                        name: (_) @name))
1288                (class_declaration
1289                    "class" @name
1290                    name: (_) @name)
1291                ] @item
1292            )
1293
1294            (
1295                (comment)* @context
1296                .
1297                [
1298                (export_statement
1299                    (interface_declaration
1300                        "interface" @name
1301                        name: (_) @name))
1302                (interface_declaration
1303                    "interface" @name
1304                    name: (_) @name)
1305                ] @item
1306            )
1307
1308            (
1309                (comment)* @context
1310                .
1311                [
1312                (export_statement
1313                    (enum_declaration
1314                        "enum" @name
1315                        name: (_) @name))
1316                (enum_declaration
1317                    "enum" @name
1318                    name: (_) @name)
1319                ] @item
1320            )
1321
1322            (
1323                (comment)* @context
1324                .
1325                (method_definition
1326                    [
1327                        "get"
1328                        "set"
1329                        "async"
1330                        "*"
1331                        "static"
1332                    ]* @name
1333                    name: (_) @name) @item
1334            )
1335
1336                    "#
1337            .unindent(),
1338        )
1339        .unwrap(),
1340    )
1341}
1342
1343fn rust_lang() -> Arc<Language> {
1344    Arc::new(
1345        Language::new(
1346            LanguageConfig {
1347                name: "Rust".into(),
1348                matcher: LanguageMatcher {
1349                    path_suffixes: vec!["rs".into()],
1350                    ..Default::default()
1351                },
1352                collapsed_placeholder: " /* ... */ ".to_string(),
1353                ..Default::default()
1354            },
1355            Some(tree_sitter_rust::language()),
1356        )
1357        .with_embedding_query(
1358            r#"
1359            (
1360                [(line_comment) (attribute_item)]* @context
1361                .
1362                [
1363                    (struct_item
1364                        name: (_) @name)
1365
1366                    (enum_item
1367                        name: (_) @name)
1368
1369                    (impl_item
1370                        trait: (_)? @name
1371                        "for"? @name
1372                        type: (_) @name)
1373
1374                    (trait_item
1375                        name: (_) @name)
1376
1377                    (function_item
1378                        name: (_) @name
1379                        body: (block
1380                            "{" @keep
1381                            "}" @keep) @collapse)
1382
1383                    (macro_definition
1384                        name: (_) @name)
1385                ] @item
1386            )
1387
1388            (attribute_item) @collapse
1389            (use_declaration) @collapse
1390            "#,
1391        )
1392        .unwrap(),
1393    )
1394}
1395
1396fn json_lang() -> Arc<Language> {
1397    Arc::new(
1398        Language::new(
1399            LanguageConfig {
1400                name: "JSON".into(),
1401                matcher: LanguageMatcher {
1402                    path_suffixes: vec!["json".into()],
1403                    ..Default::default()
1404                },
1405                ..Default::default()
1406            },
1407            Some(tree_sitter_json::language()),
1408        )
1409        .with_embedding_query(
1410            r#"
1411            (document) @item
1412
1413            (array
1414                "[" @keep
1415                .
1416                (object)? @keep
1417                "]" @keep) @collapse
1418
1419            (pair value: (string
1420                "\"" @keep
1421                "\"" @keep) @collapse)
1422            "#,
1423        )
1424        .unwrap(),
1425    )
1426}
1427
1428fn toml_lang() -> Arc<Language> {
1429    Arc::new(Language::new(
1430        LanguageConfig {
1431            name: "TOML".into(),
1432            matcher: LanguageMatcher {
1433                path_suffixes: vec!["toml".into()],
1434                ..Default::default()
1435            },
1436            ..Default::default()
1437        },
1438        Some(tree_sitter_toml::language()),
1439    ))
1440}
1441
1442fn cpp_lang() -> Arc<Language> {
1443    Arc::new(
1444        Language::new(
1445            LanguageConfig {
1446                name: "CPP".into(),
1447                matcher: LanguageMatcher {
1448                    path_suffixes: vec!["cpp".into()],
1449                    ..Default::default()
1450                },
1451                ..Default::default()
1452            },
1453            Some(tree_sitter_cpp::language()),
1454        )
1455        .with_embedding_query(
1456            r#"
1457            (
1458                (comment)* @context
1459                .
1460                (function_definition
1461                    (type_qualifier)? @name
1462                    type: (_)? @name
1463                    declarator: [
1464                        (function_declarator
1465                            declarator: (_) @name)
1466                        (pointer_declarator
1467                            "*" @name
1468                            declarator: (function_declarator
1469                            declarator: (_) @name))
1470                        (pointer_declarator
1471                            "*" @name
1472                            declarator: (pointer_declarator
1473                                "*" @name
1474                            declarator: (function_declarator
1475                                declarator: (_) @name)))
1476                        (reference_declarator
1477                            ["&" "&&"] @name
1478                            (function_declarator
1479                            declarator: (_) @name))
1480                    ]
1481                    (type_qualifier)? @name) @item
1482                )
1483
1484            (
1485                (comment)* @context
1486                .
1487                (template_declaration
1488                    (class_specifier
1489                        "class" @name
1490                        name: (_) @name)
1491                        ) @item
1492            )
1493
1494            (
1495                (comment)* @context
1496                .
1497                (class_specifier
1498                    "class" @name
1499                    name: (_) @name) @item
1500                )
1501
1502            (
1503                (comment)* @context
1504                .
1505                (enum_specifier
1506                    "enum" @name
1507                    name: (_) @name) @item
1508                )
1509
1510            (
1511                (comment)* @context
1512                .
1513                (declaration
1514                    type: (struct_specifier
1515                    "struct" @name)
1516                    declarator: (_) @name) @item
1517            )
1518
1519            "#,
1520        )
1521        .unwrap(),
1522    )
1523}
1524
1525fn lua_lang() -> Arc<Language> {
1526    Arc::new(
1527        Language::new(
1528            LanguageConfig {
1529                name: "Lua".into(),
1530                matcher: LanguageMatcher {
1531                    path_suffixes: vec!["lua".into()],
1532                    ..Default::default()
1533                },
1534                collapsed_placeholder: "--[ ... ]--".to_string(),
1535                ..Default::default()
1536            },
1537            Some(tree_sitter_lua::language()),
1538        )
1539        .with_embedding_query(
1540            r#"
1541            (
1542                (comment)* @context
1543                .
1544                (function_declaration
1545                    "function" @name
1546                    name: (_) @name
1547                    (comment)* @collapse
1548                    body: (block) @collapse
1549                ) @item
1550            )
1551        "#,
1552        )
1553        .unwrap(),
1554    )
1555}
1556
1557fn php_lang() -> Arc<Language> {
1558    Arc::new(
1559        Language::new(
1560            LanguageConfig {
1561                name: "PHP".into(),
1562                matcher: LanguageMatcher {
1563                    path_suffixes: vec!["php".into()],
1564                    ..Default::default()
1565                },
1566                collapsed_placeholder: "/* ... */".into(),
1567                ..Default::default()
1568            },
1569            Some(tree_sitter_php::language_php()),
1570        )
1571        .with_embedding_query(
1572            r#"
1573            (
1574                (comment)* @context
1575                .
1576                [
1577                    (function_definition
1578                        "function" @name
1579                        name: (_) @name
1580                        body: (_
1581                            "{" @keep
1582                            "}" @keep) @collapse
1583                        )
1584
1585                    (trait_declaration
1586                        "trait" @name
1587                        name: (_) @name)
1588
1589                    (method_declaration
1590                        "function" @name
1591                        name: (_) @name
1592                        body: (_
1593                            "{" @keep
1594                            "}" @keep) @collapse
1595                        )
1596
1597                    (interface_declaration
1598                        "interface" @name
1599                        name: (_) @name
1600                        )
1601
1602                    (enum_declaration
1603                        "enum" @name
1604                        name: (_) @name
1605                        )
1606
1607                ] @item
1608            )
1609            "#,
1610        )
1611        .unwrap(),
1612    )
1613}
1614
1615fn ruby_lang() -> Arc<Language> {
1616    Arc::new(
1617        Language::new(
1618            LanguageConfig {
1619                name: "Ruby".into(),
1620                matcher: LanguageMatcher {
1621                    path_suffixes: vec!["rb".into()],
1622                    ..Default::default()
1623                },
1624                collapsed_placeholder: "# ...".to_string(),
1625                ..Default::default()
1626            },
1627            Some(tree_sitter_ruby::language()),
1628        )
1629        .with_embedding_query(
1630            r#"
1631            (
1632                (comment)* @context
1633                .
1634                [
1635                (module
1636                    "module" @name
1637                    name: (_) @name)
1638                (method
1639                    "def" @name
1640                    name: (_) @name
1641                    body: (body_statement) @collapse)
1642                (class
1643                    "class" @name
1644                    name: (_) @name)
1645                (singleton_method
1646                    "def" @name
1647                    object: (_) @name
1648                    "." @name
1649                    name: (_) @name
1650                    body: (body_statement) @collapse)
1651                ] @item
1652            )
1653            "#,
1654        )
1655        .unwrap(),
1656    )
1657}
1658
1659fn elixir_lang() -> Arc<Language> {
1660    Arc::new(
1661        Language::new(
1662            LanguageConfig {
1663                name: "Elixir".into(),
1664                matcher: LanguageMatcher {
1665                    path_suffixes: vec!["rs".into()],
1666                    ..Default::default()
1667                },
1668                ..Default::default()
1669            },
1670            Some(tree_sitter_elixir::language()),
1671        )
1672        .with_embedding_query(
1673            r#"
1674            (
1675                (unary_operator
1676                    operator: "@"
1677                    operand: (call
1678                        target: (identifier) @unary
1679                        (#match? @unary "^(doc)$"))
1680                    ) @context
1681                .
1682                (call
1683                target: (identifier) @name
1684                (arguments
1685                [
1686                (identifier) @name
1687                (call
1688                target: (identifier) @name)
1689                (binary_operator
1690                left: (call
1691                target: (identifier) @name)
1692                operator: "when")
1693                ])
1694                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1695                )
1696
1697            (call
1698                target: (identifier) @name
1699                (arguments (alias) @name)
1700                (#any-match? @name "^(defmodule|defprotocol)$")) @item
1701            "#,
1702        )
1703        .unwrap(),
1704    )
1705}
1706
1707#[gpui::test]
1708fn test_subtract_ranges() {
1709    assert_eq!(
1710        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1711        vec![1..4, 10..21]
1712    );
1713
1714    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1715}
1716
1717fn init_test(cx: &mut TestAppContext) {
1718    cx.update(|cx| {
1719        let settings_store = SettingsStore::test(cx);
1720        cx.set_global(settings_store);
1721        SemanticIndexSettings::register(cx);
1722        language::init(cx);
1723        Project::init_settings(cx);
1724    });
1725}