// semantic_index_tests.rs

   1use crate::{
   2    embedding_queue::EmbeddingQueue,
   3    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   4    semantic_index_settings::SemanticIndexSettings,
   5    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   6};
   7use ai::test::FakeEmbeddingProvider;
   8
   9use gpui::{Task, TestAppContext};
  10use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
  11use parking_lot::Mutex;
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::{Settings, SettingsStore};
  17use std::{path::Path, sync::Arc, time::SystemTime};
  18use unindent::Unindent;
  19use util::{paths::PathMatcher, RandomCharIter};
  20
  21#[ctor::ctor]
  22fn init_logger() {
  23    if std::env::var("RUST_LOG").is_ok() {
  24        env_logger::init();
  25    }
  26}
  27
  28#[gpui::test]
  29async fn test_semantic_index(cx: &mut TestAppContext) {
  30    init_test(cx);
  31
  32    let fs = FakeFs::new(cx.background_executor.clone());
  33    fs.insert_tree(
  34        "/the-root",
  35        json!({
  36            "src": {
  37                "file1.rs": "
  38                    fn aaa() {
  39                        println!(\"aaaaaaaaaaaa!\");
  40                    }
  41
  42                    fn zzzzz() {
  43                        println!(\"SLEEPING\");
  44                    }
  45                ".unindent(),
  46                "file2.rs": "
  47                    fn bbb() {
  48                        println!(\"bbbbbbbbbbbbb!\");
  49                    }
  50                    struct pqpqpqp {}
  51                ".unindent(),
  52                "file3.toml": "
  53                    ZZZZZZZZZZZZZZZZZZ = 5
  54                ".unindent(),
  55            }
  56        }),
  57    )
  58    .await;
  59
  60    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  61    let rust_language = rust_lang();
  62    let toml_language = toml_lang();
  63    languages.add(rust_language);
  64    languages.add(toml_language);
  65
  66    let db_dir = tempfile::Builder::new()
  67        .prefix("vector-store")
  68        .tempdir()
  69        .unwrap();
  70    let db_path = db_dir.path().join("db.sqlite");
  71
  72    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  73    let semantic_index = SemanticIndex::new(
  74        fs.clone(),
  75        db_path,
  76        embedding_provider.clone(),
  77        languages,
  78        cx.to_async(),
  79    )
  80    .await
  81    .unwrap();
  82
  83    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  84
  85    let search_results = semantic_index.update(cx, |store, cx| {
  86        store.search_project(
  87            project.clone(),
  88            "aaaaaabbbbzz".to_string(),
  89            5,
  90            vec![],
  91            vec![],
  92            cx,
  93        )
  94    });
  95    let pending_file_count =
  96        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
  97    cx.background_executor.run_until_parked();
  98    assert_eq!(*pending_file_count.borrow(), 3);
  99    cx.background_executor
 100        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 101    assert_eq!(*pending_file_count.borrow(), 0);
 102
 103    let search_results = search_results.await.unwrap();
 104    assert_search_results(
 105        &search_results,
 106        &[
 107            (Path::new("src/file1.rs").into(), 0),
 108            (Path::new("src/file2.rs").into(), 0),
 109            (Path::new("src/file3.toml").into(), 0),
 110            (Path::new("src/file1.rs").into(), 45),
 111            (Path::new("src/file2.rs").into(), 45),
 112        ],
 113        cx,
 114    );
 115
 116    // Test Include Files Functionality
 117    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 118    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 119    let rust_only_search_results = semantic_index
 120        .update(cx, |store, cx| {
 121            store.search_project(
 122                project.clone(),
 123                "aaaaaabbbbzz".to_string(),
 124                5,
 125                include_files,
 126                vec![],
 127                cx,
 128            )
 129        })
 130        .await
 131        .unwrap();
 132
 133    assert_search_results(
 134        &rust_only_search_results,
 135        &[
 136            (Path::new("src/file1.rs").into(), 0),
 137            (Path::new("src/file2.rs").into(), 0),
 138            (Path::new("src/file1.rs").into(), 45),
 139            (Path::new("src/file2.rs").into(), 45),
 140        ],
 141        cx,
 142    );
 143
 144    let no_rust_search_results = semantic_index
 145        .update(cx, |store, cx| {
 146            store.search_project(
 147                project.clone(),
 148                "aaaaaabbbbzz".to_string(),
 149                5,
 150                vec![],
 151                exclude_files,
 152                cx,
 153            )
 154        })
 155        .await
 156        .unwrap();
 157
 158    assert_search_results(
 159        &no_rust_search_results,
 160        &[(Path::new("src/file3.toml").into(), 0)],
 161        cx,
 162    );
 163
 164    fs.save(
 165        "/the-root/src/file2.rs".as_ref(),
 166        &"
 167            fn dddd() { println!(\"ddddd!\"); }
 168            struct pqpqpqp {}
 169        "
 170        .unindent()
 171        .into(),
 172        Default::default(),
 173    )
 174    .await
 175    .unwrap();
 176
 177    cx.background_executor
 178        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 179
 180    let prev_embedding_count = embedding_provider.embedding_count();
 181    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 182    cx.background_executor.run_until_parked();
 183    assert_eq!(*pending_file_count.borrow(), 1);
 184    cx.background_executor
 185        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 186    assert_eq!(*pending_file_count.borrow(), 0);
 187    index.await.unwrap();
 188
 189    assert_eq!(
 190        embedding_provider.embedding_count() - prev_embedding_count,
 191        1
 192    );
 193}
 194
 195#[gpui::test(iterations = 10)]
 196async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 197    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 198    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 199
 200    let files = (1..=3)
 201        .map(|file_ix| FileToEmbed {
 202            worktree_id: 5,
 203            path: Path::new(&format!("path-{file_ix}")).into(),
 204            mtime: SystemTime::now(),
 205            spans: (0..rng.gen_range(4..22))
 206                .map(|document_ix| {
 207                    let content_len = rng.gen_range(10..100);
 208                    let content = RandomCharIter::new(&mut rng)
 209                        .with_simple_text()
 210                        .take(content_len)
 211                        .collect::<String>();
 212                    let digest = SpanDigest::from(content.as_str());
 213                    Span {
 214                        range: 0..10,
 215                        embedding: None,
 216                        name: format!("document {document_ix}"),
 217                        content,
 218                        digest,
 219                        token_count: rng.gen_range(10..30),
 220                    }
 221                })
 222                .collect(),
 223            job_handle: JobHandle::new(&outstanding_job_count),
 224        })
 225        .collect::<Vec<_>>();
 226
 227    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 228
 229    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
 230    for file in &files {
 231        queue.push(file.clone());
 232    }
 233    queue.flush();
 234
 235    cx.background_executor.run_until_parked();
 236    let finished_files = queue.finished_files();
 237    let mut embedded_files: Vec<_> = files
 238        .iter()
 239        .map(|_| finished_files.try_recv().expect("no finished file"))
 240        .collect();
 241
 242    let expected_files: Vec<_> = files
 243        .iter()
 244        .map(|file| {
 245            let mut file = file.clone();
 246            for doc in &mut file.spans {
 247                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 248            }
 249            file
 250        })
 251        .collect();
 252
 253    embedded_files.sort_by_key(|f| f.path.clone());
 254
 255    assert_eq!(embedded_files, expected_files);
 256}
 257
 258#[track_caller]
 259fn assert_search_results(
 260    actual: &[SearchResult],
 261    expected: &[(Arc<Path>, usize)],
 262    cx: &TestAppContext,
 263) {
 264    let actual = actual
 265        .iter()
 266        .map(|search_result| {
 267            search_result.buffer.read_with(cx, |buffer, _cx| {
 268                (
 269                    buffer.file().unwrap().path().clone(),
 270                    search_result.range.start.to_offset(buffer),
 271                )
 272            })
 273        })
 274        .collect::<Vec<_>>();
 275    assert_eq!(actual, expected);
 276}
 277
 278#[gpui::test]
 279async fn test_code_context_retrieval_rust() {
 280    let language = rust_lang();
 281    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 282    let mut retriever = CodeContextRetriever::new(embedding_provider);
 283
 284    let text = "
 285        /// A doc comment
 286        /// that spans multiple lines
 287        #[gpui::test]
 288        fn a() {
 289            b
 290        }
 291
 292        impl C for D {
 293        }
 294
 295        impl E {
 296            // This is also a preceding comment
 297            pub fn function_1() -> Option<()> {
 298                unimplemented!();
 299            }
 300
 301            // This is a preceding comment
 302            fn function_2() -> Result<()> {
 303                unimplemented!();
 304            }
 305        }
 306
 307        #[derive(Clone)]
 308        struct D {
 309            name: String
 310        }
 311    "
 312    .unindent();
 313
 314    let documents = retriever.parse_file(&text, language).unwrap();
 315
 316    assert_documents_eq(
 317        &documents,
 318        &[
 319            (
 320                "
 321                /// A doc comment
 322                /// that spans multiple lines
 323                #[gpui::test]
 324                fn a() {
 325                    b
 326                }"
 327                .unindent(),
 328                text.find("fn a").unwrap(),
 329            ),
 330            (
 331                "
 332                impl C for D {
 333                }"
 334                .unindent(),
 335                text.find("impl C").unwrap(),
 336            ),
 337            (
 338                "
 339                impl E {
 340                    // This is also a preceding comment
 341                    pub fn function_1() -> Option<()> { /* ... */ }
 342
 343                    // This is a preceding comment
 344                    fn function_2() -> Result<()> { /* ... */ }
 345                }"
 346                .unindent(),
 347                text.find("impl E").unwrap(),
 348            ),
 349            (
 350                "
 351                // This is also a preceding comment
 352                pub fn function_1() -> Option<()> {
 353                    unimplemented!();
 354                }"
 355                .unindent(),
 356                text.find("pub fn function_1").unwrap(),
 357            ),
 358            (
 359                "
 360                // This is a preceding comment
 361                fn function_2() -> Result<()> {
 362                    unimplemented!();
 363                }"
 364                .unindent(),
 365                text.find("fn function_2").unwrap(),
 366            ),
 367            (
 368                "
 369                #[derive(Clone)]
 370                struct D {
 371                    name: String
 372                }"
 373                .unindent(),
 374                text.find("struct D").unwrap(),
 375            ),
 376        ],
 377    );
 378}
 379
 380#[gpui::test]
 381async fn test_code_context_retrieval_json() {
 382    let language = json_lang();
 383    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 384    let mut retriever = CodeContextRetriever::new(embedding_provider);
 385
 386    let text = r#"
 387        {
 388            "array": [1, 2, 3, 4],
 389            "string": "abcdefg",
 390            "nested_object": {
 391                "array_2": [5, 6, 7, 8],
 392                "string_2": "hijklmnop",
 393                "boolean": true,
 394                "none": null
 395            }
 396        }
 397    "#
 398    .unindent();
 399
 400    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 401
 402    assert_documents_eq(
 403        &documents,
 404        &[(
 405            r#"
 406                {
 407                    "array": [],
 408                    "string": "",
 409                    "nested_object": {
 410                        "array_2": [],
 411                        "string_2": "",
 412                        "boolean": true,
 413                        "none": null
 414                    }
 415                }"#
 416            .unindent(),
 417            text.find('{').unwrap(),
 418        )],
 419    );
 420
 421    let text = r#"
 422        [
 423            {
 424                "name": "somebody",
 425                "age": 42
 426            },
 427            {
 428                "name": "somebody else",
 429                "age": 43
 430            }
 431        ]
 432    "#
 433    .unindent();
 434
 435    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 436
 437    assert_documents_eq(
 438        &documents,
 439        &[(
 440            r#"
 441            [{
 442                    "name": "",
 443                    "age": 42
 444                }]"#
 445            .unindent(),
 446            text.find('[').unwrap(),
 447        )],
 448    );
 449}
 450
 451fn assert_documents_eq(
 452    documents: &[Span],
 453    expected_contents_and_start_offsets: &[(String, usize)],
 454) {
 455    assert_eq!(
 456        documents
 457            .iter()
 458            .map(|document| (document.content.clone(), document.range.start))
 459            .collect::<Vec<_>>(),
 460        expected_contents_and_start_offsets
 461    );
 462}
 463
 464#[gpui::test]
 465async fn test_code_context_retrieval_javascript() {
 466    let language = js_lang();
 467    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 468    let mut retriever = CodeContextRetriever::new(embedding_provider);
 469
 470    let text = "
 471        /* globals importScripts, backend */
 472        function _authorize() {}
 473
 474        /**
 475         * Sometimes the frontend build is way faster than backend.
 476         */
 477        export async function authorizeBank() {
 478            _authorize(pushModal, upgradingAccountId, {});
 479        }
 480
 481        export class SettingsPage {
 482            /* This is a test setting */
 483            constructor(page) {
 484                this.page = page;
 485            }
 486        }
 487
 488        /* This is a test comment */
 489        class TestClass {}
 490
 491        /* Schema for editor_events in Clickhouse. */
 492        export interface ClickhouseEditorEvent {
 493            installation_id: string
 494            operation: string
 495        }
 496        "
 497    .unindent();
 498
 499    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 500
 501    assert_documents_eq(
 502        &documents,
 503        &[
 504            (
 505                "
 506            /* globals importScripts, backend */
 507            function _authorize() {}"
 508                    .unindent(),
 509                37,
 510            ),
 511            (
 512                "
 513            /**
 514             * Sometimes the frontend build is way faster than backend.
 515             */
 516            export async function authorizeBank() {
 517                _authorize(pushModal, upgradingAccountId, {});
 518            }"
 519                .unindent(),
 520                131,
 521            ),
 522            (
 523                "
 524                export class SettingsPage {
 525                    /* This is a test setting */
 526                    constructor(page) {
 527                        this.page = page;
 528                    }
 529                }"
 530                .unindent(),
 531                225,
 532            ),
 533            (
 534                "
 535                /* This is a test setting */
 536                constructor(page) {
 537                    this.page = page;
 538                }"
 539                .unindent(),
 540                290,
 541            ),
 542            (
 543                "
 544                /* This is a test comment */
 545                class TestClass {}"
 546                    .unindent(),
 547                374,
 548            ),
 549            (
 550                "
 551                /* Schema for editor_events in Clickhouse. */
 552                export interface ClickhouseEditorEvent {
 553                    installation_id: string
 554                    operation: string
 555                }"
 556                .unindent(),
 557                440,
 558            ),
 559        ],
 560    )
 561}
 562
 563#[gpui::test]
 564async fn test_code_context_retrieval_lua() {
 565    let language = lua_lang();
 566    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 567    let mut retriever = CodeContextRetriever::new(embedding_provider);
 568
 569    let text = r#"
 570        -- Creates a new class
 571        -- @param baseclass The Baseclass of this class, or nil.
 572        -- @return A new class reference.
 573        function classes.class(baseclass)
 574            -- Create the class definition and metatable.
 575            local classdef = {}
 576            -- Find the super class, either Object or user-defined.
 577            baseclass = baseclass or classes.Object
 578            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 579            setmetatable(classdef, { __index = baseclass })
 580            -- All class instances have a reference to the class object.
 581            classdef.class = classdef
 582            --- Recursively allocates the inheritance tree of the instance.
 583            -- @param mastertable The 'root' of the inheritance tree.
 584            -- @return Returns the instance with the allocated inheritance tree.
 585            function classdef.alloc(mastertable)
 586                -- All class instances have a reference to a superclass object.
 587                local instance = { super = baseclass.alloc(mastertable) }
 588                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 589                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 590                return instance
 591            end
 592        end
 593        "#.unindent();
 594
 595    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 596
 597    assert_documents_eq(
 598        &documents,
 599        &[
 600            (r#"
 601                -- Creates a new class
 602                -- @param baseclass The Baseclass of this class, or nil.
 603                -- @return A new class reference.
 604                function classes.class(baseclass)
 605                    -- Create the class definition and metatable.
 606                    local classdef = {}
 607                    -- Find the super class, either Object or user-defined.
 608                    baseclass = baseclass or classes.Object
 609                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 610                    setmetatable(classdef, { __index = baseclass })
 611                    -- All class instances have a reference to the class object.
 612                    classdef.class = classdef
 613                    --- Recursively allocates the inheritance tree of the instance.
 614                    -- @param mastertable The 'root' of the inheritance tree.
 615                    -- @return Returns the instance with the allocated inheritance tree.
 616                    function classdef.alloc(mastertable)
 617                        --[ ... ]--
 618                        --[ ... ]--
 619                    end
 620                end"#.unindent(),
 621            114),
 622            (r#"
 623            --- Recursively allocates the inheritance tree of the instance.
 624            -- @param mastertable The 'root' of the inheritance tree.
 625            -- @return Returns the instance with the allocated inheritance tree.
 626            function classdef.alloc(mastertable)
 627                -- All class instances have a reference to a superclass object.
 628                local instance = { super = baseclass.alloc(mastertable) }
 629                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 630                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 631                return instance
 632            end"#.unindent(), 810),
 633        ]
 634    );
 635}
 636
 637#[gpui::test]
 638async fn test_code_context_retrieval_elixir() {
 639    let language = elixir_lang();
 640    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 641    let mut retriever = CodeContextRetriever::new(embedding_provider);
 642
 643    let text = r#"
 644        defmodule File.Stream do
 645            @moduledoc """
 646            Defines a `File.Stream` struct returned by `File.stream!/3`.
 647
 648            The following fields are public:
 649
 650            * `path`          - the file path
 651            * `modes`         - the file modes
 652            * `raw`           - a boolean indicating if bin functions should be used
 653            * `line_or_bytes` - if reading should read lines or a given number of bytes
 654            * `node`          - the node the file belongs to
 655
 656            """
 657
 658            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 659
 660            @type t :: %__MODULE__{}
 661
 662            @doc false
 663            def __build__(path, modes, line_or_bytes) do
 664            raw = :lists.keyfind(:encoding, 1, modes) == false
 665
 666            modes =
 667                case raw do
 668                true ->
 669                    case :lists.keyfind(:read_ahead, 1, modes) do
 670                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 671                    {:read_ahead, _} -> [:raw | modes]
 672                    false -> [:raw, :read_ahead | modes]
 673                    end
 674
 675                false ->
 676                    modes
 677                end
 678
 679            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 680
 681            end"#
 682    .unindent();
 683
 684    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 685
 686    assert_documents_eq(
 687        &documents,
 688        &[(
 689            r#"
 690        defmodule File.Stream do
 691            @moduledoc """
 692            Defines a `File.Stream` struct returned by `File.stream!/3`.
 693
 694            The following fields are public:
 695
 696            * `path`          - the file path
 697            * `modes`         - the file modes
 698            * `raw`           - a boolean indicating if bin functions should be used
 699            * `line_or_bytes` - if reading should read lines or a given number of bytes
 700            * `node`          - the node the file belongs to
 701
 702            """
 703
 704            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 705
 706            @type t :: %__MODULE__{}
 707
 708            @doc false
 709            def __build__(path, modes, line_or_bytes) do
 710            raw = :lists.keyfind(:encoding, 1, modes) == false
 711
 712            modes =
 713                case raw do
 714                true ->
 715                    case :lists.keyfind(:read_ahead, 1, modes) do
 716                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 717                    {:read_ahead, _} -> [:raw | modes]
 718                    false -> [:raw, :read_ahead | modes]
 719                    end
 720
 721                false ->
 722                    modes
 723                end
 724
 725            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 726
 727            end"#
 728                .unindent(),
 729            0,
 730        ),(r#"
 731            @doc false
 732            def __build__(path, modes, line_or_bytes) do
 733            raw = :lists.keyfind(:encoding, 1, modes) == false
 734
 735            modes =
 736                case raw do
 737                true ->
 738                    case :lists.keyfind(:read_ahead, 1, modes) do
 739                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 740                    {:read_ahead, _} -> [:raw | modes]
 741                    false -> [:raw, :read_ahead | modes]
 742                    end
 743
 744                false ->
 745                    modes
 746                end
 747
 748            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 749
 750            end"#.unindent(), 574)],
 751    );
 752}
 753
 754#[gpui::test]
 755async fn test_code_context_retrieval_cpp() {
 756    let language = cpp_lang();
 757    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 758    let mut retriever = CodeContextRetriever::new(embedding_provider);
 759
 760    let text = "
 761    /**
 762     * @brief Main function
 763     * @returns 0 on exit
 764     */
 765    int main() { return 0; }
 766
 767    /**
 768    * This is a test comment
 769    */
 770    class MyClass {       // The class
 771        public:           // Access specifier
 772        int myNum;        // Attribute (int variable)
 773        string myString;  // Attribute (string variable)
 774    };
 775
 776    // This is a test comment
 777    enum Color { red, green, blue };
 778
 779    /** This is a preceding block comment
 780     * This is the second line
 781     */
 782    struct {           // Structure declaration
 783        int myNum;       // Member (int variable)
 784        string myString; // Member (string variable)
 785    } myStructure;
 786
 787    /**
 788     * @brief Matrix class.
 789     */
 790    template <typename T,
 791              typename = typename std::enable_if<
 792                std::is_integral<T>::value || std::is_floating_point<T>::value,
 793                bool>::type>
 794    class Matrix2 {
 795        std::vector<std::vector<T>> _mat;
 796
 797        public:
 798            /**
 799            * @brief Constructor
 800            * @tparam Integer ensuring integers are being evaluated and not other
 801            * data types.
 802            * @param size denoting the size of Matrix as size x size
 803            */
 804            template <typename Integer,
 805                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 806                    Integer>::type>
 807            explicit Matrix(const Integer size) {
 808                for (size_t i = 0; i < size; ++i) {
 809                    _mat.emplace_back(std::vector<T>(size, 0));
 810                }
 811            }
 812    }"
 813    .unindent();
 814
 815    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 816
 817    assert_documents_eq(
 818        &documents,
 819        &[
 820            (
 821                "
 822        /**
 823         * @brief Main function
 824         * @returns 0 on exit
 825         */
 826        int main() { return 0; }"
 827                    .unindent(),
 828                54,
 829            ),
 830            (
 831                "
 832                /**
 833                * This is a test comment
 834                */
 835                class MyClass {       // The class
 836                    public:           // Access specifier
 837                    int myNum;        // Attribute (int variable)
 838                    string myString;  // Attribute (string variable)
 839                }"
 840                .unindent(),
 841                112,
 842            ),
 843            (
 844                "
 845                // This is a test comment
 846                enum Color { red, green, blue }"
 847                    .unindent(),
 848                322,
 849            ),
 850            (
 851                "
 852                /** This is a preceding block comment
 853                 * This is the second line
 854                 */
 855                struct {           // Structure declaration
 856                    int myNum;       // Member (int variable)
 857                    string myString; // Member (string variable)
 858                } myStructure;"
 859                    .unindent(),
 860                425,
 861            ),
 862            (
 863                "
 864                /**
 865                 * @brief Matrix class.
 866                 */
 867                template <typename T,
 868                          typename = typename std::enable_if<
 869                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 870                            bool>::type>
 871                class Matrix2 {
 872                    std::vector<std::vector<T>> _mat;
 873
 874                    public:
 875                        /**
 876                        * @brief Constructor
 877                        * @tparam Integer ensuring integers are being evaluated and not other
 878                        * data types.
 879                        * @param size denoting the size of Matrix as size x size
 880                        */
 881                        template <typename Integer,
 882                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 883                                Integer>::type>
 884                        explicit Matrix(const Integer size) {
 885                            for (size_t i = 0; i < size; ++i) {
 886                                _mat.emplace_back(std::vector<T>(size, 0));
 887                            }
 888                        }
 889                }"
 890                .unindent(),
 891                612,
 892            ),
 893            (
 894                "
 895                explicit Matrix(const Integer size) {
 896                    for (size_t i = 0; i < size; ++i) {
 897                        _mat.emplace_back(std::vector<T>(size, 0));
 898                    }
 899                }"
 900                .unindent(),
 901                1226,
 902            ),
 903        ],
 904    );
 905}
 906
 907#[gpui::test]
 908async fn test_code_context_retrieval_ruby() {
 909    let language = ruby_lang();
 910    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 911    let mut retriever = CodeContextRetriever::new(embedding_provider);
 912
 913    let text = r#"
 914        # This concern is inspired by "sudo mode" on GitHub. It
 915        # is a way to re-authenticate a user before allowing them
 916        # to see or perform an action.
 917        #
 918        # Add `before_action :require_challenge!` to actions you
 919        # want to protect.
 920        #
 921        # The user will be shown a page to enter the challenge (which
 922        # is either the password, or just the username when no
 923        # password exists). Upon passing, there is a grace period
 924        # during which no challenge will be asked from the user.
 925        #
 926        # Accessing challenge-protected resources during the grace
 927        # period will refresh the grace period.
 928        module ChallengableConcern
 929            extend ActiveSupport::Concern
 930
 931            CHALLENGE_TIMEOUT = 1.hour.freeze
 932
 933            def require_challenge!
 934                return if skip_challenge?
 935
 936                if challenge_passed_recently?
 937                    session[:challenge_passed_at] = Time.now.utc
 938                    return
 939                end
 940
 941                @challenge = Form::Challenge.new(return_to: request.url)
 942
 943                if params.key?(:form_challenge)
 944                    if challenge_passed?
 945                        session[:challenge_passed_at] = Time.now.utc
 946                    else
 947                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 948                        render_challenge
 949                    end
 950                else
 951                    render_challenge
 952                end
 953            end
 954
 955            def challenge_passed?
 956                current_user.valid_password?(challenge_params[:current_password])
 957            end
 958        end
 959
 960        class Animal
 961            include Comparable
 962
 963            attr_reader :legs
 964
 965            def initialize(name, legs)
 966                @name, @legs = name, legs
 967            end
 968
 969            def <=>(other)
 970                legs <=> other.legs
 971            end
 972        end
 973
 974        # Singleton method for car object
 975        def car.wheels
 976            puts "There are four wheels"
 977        end"#
 978        .unindent();
 979
 980    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 981
 982    assert_documents_eq(
 983        &documents,
 984        &[
 985            (
 986                r#"
 987        # This concern is inspired by "sudo mode" on GitHub. It
 988        # is a way to re-authenticate a user before allowing them
 989        # to see or perform an action.
 990        #
 991        # Add `before_action :require_challenge!` to actions you
 992        # want to protect.
 993        #
 994        # The user will be shown a page to enter the challenge (which
 995        # is either the password, or just the username when no
 996        # password exists). Upon passing, there is a grace period
 997        # during which no challenge will be asked from the user.
 998        #
 999        # Accessing challenge-protected resources during the grace
1000        # period will refresh the grace period.
1001        module ChallengableConcern
1002            extend ActiveSupport::Concern
1003
1004            CHALLENGE_TIMEOUT = 1.hour.freeze
1005
1006            def require_challenge!
1007                # ...
1008            end
1009
1010            def challenge_passed?
1011                # ...
1012            end
1013        end"#
1014                    .unindent(),
1015                558,
1016            ),
1017            (
1018                r#"
1019            def require_challenge!
1020                return if skip_challenge?
1021
1022                if challenge_passed_recently?
1023                    session[:challenge_passed_at] = Time.now.utc
1024                    return
1025                end
1026
1027                @challenge = Form::Challenge.new(return_to: request.url)
1028
1029                if params.key?(:form_challenge)
1030                    if challenge_passed?
1031                        session[:challenge_passed_at] = Time.now.utc
1032                    else
1033                        flash.now[:alert] = I18n.t('challenge.invalid_password')
1034                        render_challenge
1035                    end
1036                else
1037                    render_challenge
1038                end
1039            end"#
1040                    .unindent(),
1041                663,
1042            ),
1043            (
1044                r#"
1045                def challenge_passed?
1046                    current_user.valid_password?(challenge_params[:current_password])
1047                end"#
1048                    .unindent(),
1049                1254,
1050            ),
1051            (
1052                r#"
1053                class Animal
1054                    include Comparable
1055
1056                    attr_reader :legs
1057
1058                    def initialize(name, legs)
1059                        # ...
1060                    end
1061
1062                    def <=>(other)
1063                        # ...
1064                    end
1065                end"#
1066                    .unindent(),
1067                1363,
1068            ),
1069            (
1070                r#"
1071                def initialize(name, legs)
1072                    @name, @legs = name, legs
1073                end"#
1074                    .unindent(),
1075                1427,
1076            ),
1077            (
1078                r#"
1079                def <=>(other)
1080                    legs <=> other.legs
1081                end"#
1082                    .unindent(),
1083                1501,
1084            ),
1085            (
1086                r#"
1087                # Singleton method for car object
1088                def car.wheels
1089                    puts "There are four wheels"
1090                end"#
1091                    .unindent(),
1092                1591,
1093            ),
1094        ],
1095    );
1096}
1097
1098#[gpui::test]
1099async fn test_code_context_retrieval_php() {
1100    let language = php_lang();
1101    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
1102    let mut retriever = CodeContextRetriever::new(embedding_provider);
1103
1104    let text = r#"
1105        <?php
1106
1107        namespace LevelUp\Experience\Concerns;
1108
1109        /*
1110        This is a multiple-lines comment block
1111        that spans over multiple
1112        lines
1113        */
1114        function functionName() {
1115            echo "Hello world!";
1116        }
1117
1118        trait HasAchievements
1119        {
1120            /**
1121            * @throws \Exception
1122            */
1123            public function grantAchievement(Achievement $achievement, $progress = null): void
1124            {
1125                if ($progress > 100) {
1126                    throw new Exception(message: 'Progress cannot be greater than 100');
1127                }
1128
1129                if ($this->achievements()->find($achievement->id)) {
1130                    throw new Exception(message: 'User already has this Achievement');
1131                }
1132
1133                $this->achievements()->attach($achievement, [
1134                    'progress' => $progress ?? null,
1135                ]);
1136
1137                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1138            }
1139
1140            public function achievements(): BelongsToMany
1141            {
1142                return $this->belongsToMany(related: Achievement::class)
1143                ->withPivot(columns: 'progress')
1144                ->where('is_secret', false)
1145                ->using(AchievementUser::class);
1146            }
1147        }
1148
1149        interface Multiplier
1150        {
1151            public function qualifies(array $data): bool;
1152
1153            public function setMultiplier(): int;
1154        }
1155
1156        enum AuditType: string
1157        {
1158            case Add = 'add';
1159            case Remove = 'remove';
1160            case Reset = 'reset';
1161            case LevelUp = 'level_up';
1162        }
1163
1164        ?>"#
1165    .unindent();
1166
1167    let documents = retriever.parse_file(&text, language.clone()).unwrap();
1168
1169    assert_documents_eq(
1170        &documents,
1171        &[
1172            (
1173                r#"
1174        /*
1175        This is a multiple-lines comment block
1176        that spans over multiple
1177        lines
1178        */
1179        function functionName() {
1180            echo "Hello world!";
1181        }"#
1182                .unindent(),
1183                123,
1184            ),
1185            (
1186                r#"
1187        trait HasAchievements
1188        {
1189            /**
1190            * @throws \Exception
1191            */
1192            public function grantAchievement(Achievement $achievement, $progress = null): void
1193            {/* ... */}
1194
1195            public function achievements(): BelongsToMany
1196            {/* ... */}
1197        }"#
1198                .unindent(),
1199                177,
1200            ),
1201            (r#"
1202            /**
1203            * @throws \Exception
1204            */
1205            public function grantAchievement(Achievement $achievement, $progress = null): void
1206            {
1207                if ($progress > 100) {
1208                    throw new Exception(message: 'Progress cannot be greater than 100');
1209                }
1210
1211                if ($this->achievements()->find($achievement->id)) {
1212                    throw new Exception(message: 'User already has this Achievement');
1213                }
1214
1215                $this->achievements()->attach($achievement, [
1216                    'progress' => $progress ?? null,
1217                ]);
1218
1219                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1220            }"#.unindent(), 245),
1221            (r#"
1222                public function achievements(): BelongsToMany
1223                {
1224                    return $this->belongsToMany(related: Achievement::class)
1225                    ->withPivot(columns: 'progress')
1226                    ->where('is_secret', false)
1227                    ->using(AchievementUser::class);
1228                }"#.unindent(), 902),
1229            (r#"
1230                interface Multiplier
1231                {
1232                    public function qualifies(array $data): bool;
1233
1234                    public function setMultiplier(): int;
1235                }"#.unindent(),
1236                1146),
1237            (r#"
1238                enum AuditType: string
1239                {
1240                    case Add = 'add';
1241                    case Remove = 'remove';
1242                    case Reset = 'reset';
1243                    case LevelUp = 'level_up';
1244                }"#.unindent(), 1265)
1245        ],
1246    );
1247}
1248
1249fn js_lang() -> Arc<Language> {
1250    Arc::new(
1251        Language::new(
1252            LanguageConfig {
1253                name: "Javascript".into(),
1254                matcher: LanguageMatcher {
1255                    path_suffixes: vec!["js".into()],
1256                    ..Default::default()
1257                },
1258                ..Default::default()
1259            },
1260            Some(tree_sitter_typescript::language_tsx()),
1261        )
1262        .with_embedding_query(
1263            &r#"
1264
1265            (
1266                (comment)* @context
1267                .
1268                [
1269                (export_statement
1270                    (function_declaration
1271                        "async"? @name
1272                        "function" @name
1273                        name: (_) @name))
1274                (function_declaration
1275                    "async"? @name
1276                    "function" @name
1277                    name: (_) @name)
1278                ] @item
1279            )
1280
1281            (
1282                (comment)* @context
1283                .
1284                [
1285                (export_statement
1286                    (class_declaration
1287                        "class" @name
1288                        name: (_) @name))
1289                (class_declaration
1290                    "class" @name
1291                    name: (_) @name)
1292                ] @item
1293            )
1294
1295            (
1296                (comment)* @context
1297                .
1298                [
1299                (export_statement
1300                    (interface_declaration
1301                        "interface" @name
1302                        name: (_) @name))
1303                (interface_declaration
1304                    "interface" @name
1305                    name: (_) @name)
1306                ] @item
1307            )
1308
1309            (
1310                (comment)* @context
1311                .
1312                [
1313                (export_statement
1314                    (enum_declaration
1315                        "enum" @name
1316                        name: (_) @name))
1317                (enum_declaration
1318                    "enum" @name
1319                    name: (_) @name)
1320                ] @item
1321            )
1322
1323            (
1324                (comment)* @context
1325                .
1326                (method_definition
1327                    [
1328                        "get"
1329                        "set"
1330                        "async"
1331                        "*"
1332                        "static"
1333                    ]* @name
1334                    name: (_) @name) @item
1335            )
1336
1337                    "#
1338            .unindent(),
1339        )
1340        .unwrap(),
1341    )
1342}
1343
1344fn rust_lang() -> Arc<Language> {
1345    Arc::new(
1346        Language::new(
1347            LanguageConfig {
1348                name: "Rust".into(),
1349                matcher: LanguageMatcher {
1350                    path_suffixes: vec!["rs".into()],
1351                    ..Default::default()
1352                },
1353                collapsed_placeholder: " /* ... */ ".to_string(),
1354                ..Default::default()
1355            },
1356            Some(tree_sitter_rust::language()),
1357        )
1358        .with_embedding_query(
1359            r#"
1360            (
1361                [(line_comment) (attribute_item)]* @context
1362                .
1363                [
1364                    (struct_item
1365                        name: (_) @name)
1366
1367                    (enum_item
1368                        name: (_) @name)
1369
1370                    (impl_item
1371                        trait: (_)? @name
1372                        "for"? @name
1373                        type: (_) @name)
1374
1375                    (trait_item
1376                        name: (_) @name)
1377
1378                    (function_item
1379                        name: (_) @name
1380                        body: (block
1381                            "{" @keep
1382                            "}" @keep) @collapse)
1383
1384                    (macro_definition
1385                        name: (_) @name)
1386                ] @item
1387            )
1388
1389            (attribute_item) @collapse
1390            (use_declaration) @collapse
1391            "#,
1392        )
1393        .unwrap(),
1394    )
1395}
1396
1397fn json_lang() -> Arc<Language> {
1398    Arc::new(
1399        Language::new(
1400            LanguageConfig {
1401                name: "JSON".into(),
1402                matcher: LanguageMatcher {
1403                    path_suffixes: vec!["json".into()],
1404                    ..Default::default()
1405                },
1406                ..Default::default()
1407            },
1408            Some(tree_sitter_json::language()),
1409        )
1410        .with_embedding_query(
1411            r#"
1412            (document) @item
1413
1414            (array
1415                "[" @keep
1416                .
1417                (object)? @keep
1418                "]" @keep) @collapse
1419
1420            (pair value: (string
1421                "\"" @keep
1422                "\"" @keep) @collapse)
1423            "#,
1424        )
1425        .unwrap(),
1426    )
1427}
1428
1429fn toml_lang() -> Arc<Language> {
1430    Arc::new(Language::new(
1431        LanguageConfig {
1432            name: "TOML".into(),
1433            matcher: LanguageMatcher {
1434                path_suffixes: vec!["toml".into()],
1435                ..Default::default()
1436            },
1437            ..Default::default()
1438        },
1439        Some(tree_sitter_toml::language()),
1440    ))
1441}
1442
1443fn cpp_lang() -> Arc<Language> {
1444    Arc::new(
1445        Language::new(
1446            LanguageConfig {
1447                name: "CPP".into(),
1448                matcher: LanguageMatcher {
1449                    path_suffixes: vec!["cpp".into()],
1450                    ..Default::default()
1451                },
1452                ..Default::default()
1453            },
1454            Some(tree_sitter_cpp::language()),
1455        )
1456        .with_embedding_query(
1457            r#"
1458            (
1459                (comment)* @context
1460                .
1461                (function_definition
1462                    (type_qualifier)? @name
1463                    type: (_)? @name
1464                    declarator: [
1465                        (function_declarator
1466                            declarator: (_) @name)
1467                        (pointer_declarator
1468                            "*" @name
1469                            declarator: (function_declarator
1470                            declarator: (_) @name))
1471                        (pointer_declarator
1472                            "*" @name
1473                            declarator: (pointer_declarator
1474                                "*" @name
1475                            declarator: (function_declarator
1476                                declarator: (_) @name)))
1477                        (reference_declarator
1478                            ["&" "&&"] @name
1479                            (function_declarator
1480                            declarator: (_) @name))
1481                    ]
1482                    (type_qualifier)? @name) @item
1483                )
1484
1485            (
1486                (comment)* @context
1487                .
1488                (template_declaration
1489                    (class_specifier
1490                        "class" @name
1491                        name: (_) @name)
1492                        ) @item
1493            )
1494
1495            (
1496                (comment)* @context
1497                .
1498                (class_specifier
1499                    "class" @name
1500                    name: (_) @name) @item
1501                )
1502
1503            (
1504                (comment)* @context
1505                .
1506                (enum_specifier
1507                    "enum" @name
1508                    name: (_) @name) @item
1509                )
1510
1511            (
1512                (comment)* @context
1513                .
1514                (declaration
1515                    type: (struct_specifier
1516                    "struct" @name)
1517                    declarator: (_) @name) @item
1518            )
1519
1520            "#,
1521        )
1522        .unwrap(),
1523    )
1524}
1525
1526fn lua_lang() -> Arc<Language> {
1527    Arc::new(
1528        Language::new(
1529            LanguageConfig {
1530                name: "Lua".into(),
1531                matcher: LanguageMatcher {
1532                    path_suffixes: vec!["lua".into()],
1533                    ..Default::default()
1534                },
1535                collapsed_placeholder: "--[ ... ]--".to_string(),
1536                ..Default::default()
1537            },
1538            Some(tree_sitter_lua::language()),
1539        )
1540        .with_embedding_query(
1541            r#"
1542            (
1543                (comment)* @context
1544                .
1545                (function_declaration
1546                    "function" @name
1547                    name: (_) @name
1548                    (comment)* @collapse
1549                    body: (block) @collapse
1550                ) @item
1551            )
1552        "#,
1553        )
1554        .unwrap(),
1555    )
1556}
1557
1558fn php_lang() -> Arc<Language> {
1559    Arc::new(
1560        Language::new(
1561            LanguageConfig {
1562                name: "PHP".into(),
1563                matcher: LanguageMatcher {
1564                    path_suffixes: vec!["php".into()],
1565                    ..Default::default()
1566                },
1567                collapsed_placeholder: "/* ... */".into(),
1568                ..Default::default()
1569            },
1570            Some(tree_sitter_php::language_php()),
1571        )
1572        .with_embedding_query(
1573            r#"
1574            (
1575                (comment)* @context
1576                .
1577                [
1578                    (function_definition
1579                        "function" @name
1580                        name: (_) @name
1581                        body: (_
1582                            "{" @keep
1583                            "}" @keep) @collapse
1584                        )
1585
1586                    (trait_declaration
1587                        "trait" @name
1588                        name: (_) @name)
1589
1590                    (method_declaration
1591                        "function" @name
1592                        name: (_) @name
1593                        body: (_
1594                            "{" @keep
1595                            "}" @keep) @collapse
1596                        )
1597
1598                    (interface_declaration
1599                        "interface" @name
1600                        name: (_) @name
1601                        )
1602
1603                    (enum_declaration
1604                        "enum" @name
1605                        name: (_) @name
1606                        )
1607
1608                ] @item
1609            )
1610            "#,
1611        )
1612        .unwrap(),
1613    )
1614}
1615
1616fn ruby_lang() -> Arc<Language> {
1617    Arc::new(
1618        Language::new(
1619            LanguageConfig {
1620                name: "Ruby".into(),
1621                matcher: LanguageMatcher {
1622                    path_suffixes: vec!["rb".into()],
1623                    ..Default::default()
1624                },
1625                collapsed_placeholder: "# ...".to_string(),
1626                ..Default::default()
1627            },
1628            Some(tree_sitter_ruby::language()),
1629        )
1630        .with_embedding_query(
1631            r#"
1632            (
1633                (comment)* @context
1634                .
1635                [
1636                (module
1637                    "module" @name
1638                    name: (_) @name)
1639                (method
1640                    "def" @name
1641                    name: (_) @name
1642                    body: (body_statement) @collapse)
1643                (class
1644                    "class" @name
1645                    name: (_) @name)
1646                (singleton_method
1647                    "def" @name
1648                    object: (_) @name
1649                    "." @name
1650                    name: (_) @name
1651                    body: (body_statement) @collapse)
1652                ] @item
1653            )
1654            "#,
1655        )
1656        .unwrap(),
1657    )
1658}
1659
1660fn elixir_lang() -> Arc<Language> {
1661    Arc::new(
1662        Language::new(
1663            LanguageConfig {
1664                name: "Elixir".into(),
1665                matcher: LanguageMatcher {
1666                    path_suffixes: vec!["rs".into()],
1667                    ..Default::default()
1668                },
1669                ..Default::default()
1670            },
1671            Some(tree_sitter_elixir::language()),
1672        )
1673        .with_embedding_query(
1674            r#"
1675            (
1676                (unary_operator
1677                    operator: "@"
1678                    operand: (call
1679                        target: (identifier) @unary
1680                        (#match? @unary "^(doc)$"))
1681                    ) @context
1682                .
1683                (call
1684                target: (identifier) @name
1685                (arguments
1686                [
1687                (identifier) @name
1688                (call
1689                target: (identifier) @name)
1690                (binary_operator
1691                left: (call
1692                target: (identifier) @name)
1693                operator: "when")
1694                ])
1695                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1696                )
1697
1698            (call
1699                target: (identifier) @name
1700                (arguments (alias) @name)
1701                (#any-match? @name "^(defmodule|defprotocol)$")) @item
1702            "#,
1703        )
1704        .unwrap(),
1705    )
1706}
1707
1708#[gpui::test]
1709fn test_subtract_ranges() {
1710    assert_eq!(
1711        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1712        vec![1..4, 10..21]
1713    );
1714
1715    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1716}
1717
1718fn init_test(cx: &mut TestAppContext) {
1719    cx.update(|cx| {
1720        let settings_store = SettingsStore::test(cx);
1721        cx.set_global(settings_store);
1722        SemanticIndexSettings::register(cx);
1723        ProjectSettings::register(cx);
1724    });
1725}