// semantic_index_tests.rs

   1use crate::{
   2    embedding_queue::EmbeddingQueue,
   3    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   4    semantic_index_settings::SemanticIndexSettings,
   5    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   6};
   7use ai::test::FakeEmbeddingProvider;
   8
   9use gpui::{Task, TestAppContext};
  10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  11use parking_lot::Mutex;
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::{Settings, SettingsStore};
  17use std::{path::Path, sync::Arc, time::SystemTime};
  18use unindent::Unindent;
  19use util::{paths::PathMatcher, RandomCharIter};
  20
  21#[ctor::ctor]
  22fn init_logger() {
  23    if std::env::var("RUST_LOG").is_ok() {
  24        env_logger::init();
  25    }
  26}
  27
// End-to-end test of `SemanticIndex`: indexes a small fake project (two Rust
// files and one TOML file), runs searches with and without include/exclude
// path filters, then edits one file and checks that re-indexing requests only
// one new embedding.
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaaaaaaaaaa!\");
                    }

                    fn zzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbbbbbbbbbbb!\");
                    }
                    struct pqpqpqp {}
                ".unindent(),
                "file3.toml": "
                    ZZZZZZZZZZZZZZZZZZ = 5
                ".unindent(),
            }
        }),
    )
    .await;

    // Register the languages needed to parse the fixture files.
    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = rust_lang();
    let toml_language = toml_lang();
    languages.add(rust_language);
    languages.add(toml_language);

    // Back the index with a throwaway SQLite database in a temp directory.
    let db_dir = tempfile::Builder::new()
        .prefix("vector-store")
        .tempdir()
        .unwrap();
    let db_path = db_dir.path().join("db.sqlite");

    // The fake provider lets the test count how many embeddings were requested
    // (see `embedding_count()` at the end).
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let semantic_index = SemanticIndex::new(
        fs.clone(),
        db_path,
        embedding_provider.clone(),
        languages,
        cx.to_async(),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

    // Start a search before the project has been indexed; the returned future
    // is only awaited after the embedding queue has been flushed below.
    let search_results = semantic_index.update(cx, |store, cx| {
        store.search_project(
            project.clone(),
            "aaaaaabbbbzz".to_string(),
            5,
            vec![],
            vec![],
            cx,
        )
    });
    // Watch the number of files still waiting to be embedded: all three are
    // pending once background work parks, and none remain after the
    // embedding queue's flush timeout elapses.
    let pending_file_count =
        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
    cx.background_executor.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 3);
    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);

    // Each expected entry is (project-relative path, offset of the match start).
    // Offset 45 is where the second item of each Rust file begins
    // (`fn zzzzz` in file1.rs, `struct pqpqpqp` in file2.rs).
    let search_results = search_results.await.unwrap();
    assert_search_results(
        &search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file3.toml").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    // Path filters: the same `*.rs` matcher is used once as an include filter
    // (only Rust results expected) and once as an exclude filter (only the
    // TOML result expected).
    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
    let rust_only_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                include_files,
                vec![],
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &rust_only_search_results,
        &[
            (Path::new("src/file1.rs").into(), 0),
            (Path::new("src/file2.rs").into(), 0),
            (Path::new("src/file1.rs").into(), 45),
            (Path::new("src/file2.rs").into(), 45),
        ],
        cx,
    );

    let no_rust_search_results = semantic_index
        .update(cx, |store, cx| {
            store.search_project(
                project.clone(),
                "aaaaaabbbbzz".to_string(),
                5,
                vec![],
                exclude_files,
                cx,
            )
        })
        .await
        .unwrap();

    assert_search_results(
        &no_rust_search_results,
        &[(Path::new("src/file3.toml").into(), 0)],
        cx,
    );

    // Replace `fn bbb` with `fn dddd` in file2.rs, leaving `struct pqpqpqp`
    // unchanged.
    fs.save(
        "/the-root/src/file2.rs".as_ref(),
        &"
            fn dddd() { println!(\"ddddd!\"); }
            struct pqpqpqp {}
        "
        .unindent()
        .into(),
        Default::default(),
    )
    .await
    .unwrap();

    // NOTE(review): presumably lets the project observe the file change
    // before re-indexing; confirm this clock advance is required.
    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);

    // Re-index: exactly one file (presumably the edited file2.rs) is pending
    // until the flush timeout elapses.
    let prev_embedding_count = embedding_provider.embedding_count();
    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
    cx.background_executor.run_until_parked();
    assert_eq!(*pending_file_count.borrow(), 1);
    cx.background_executor
        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
    assert_eq!(*pending_file_count.borrow(), 0);
    index.await.unwrap();

    // Only one new embedding was requested. The unchanged `struct pqpqpqp`
    // span is presumably reused from the database rather than re-embedded.
    assert_eq!(
        embedding_provider.embedding_count() - prev_embedding_count,
        1
    );
}
 194
 195#[gpui::test(iterations = 10)]
 196async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 197    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 198    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 199
 200    let files = (1..=3)
 201        .map(|file_ix| FileToEmbed {
 202            worktree_id: 5,
 203            path: Path::new(&format!("path-{file_ix}")).into(),
 204            mtime: SystemTime::now(),
 205            spans: (0..rng.gen_range(4..22))
 206                .map(|document_ix| {
 207                    let content_len = rng.gen_range(10..100);
 208                    let content = RandomCharIter::new(&mut rng)
 209                        .with_simple_text()
 210                        .take(content_len)
 211                        .collect::<String>();
 212                    let digest = SpanDigest::from(content.as_str());
 213                    Span {
 214                        range: 0..10,
 215                        embedding: None,
 216                        name: format!("document {document_ix}"),
 217                        content,
 218                        digest,
 219                        token_count: rng.gen_range(10..30),
 220                    }
 221                })
 222                .collect(),
 223            job_handle: JobHandle::new(&outstanding_job_count),
 224        })
 225        .collect::<Vec<_>>();
 226
 227    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 228
 229    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
 230    for file in &files {
 231        queue.push(file.clone());
 232    }
 233    queue.flush();
 234
 235    cx.background_executor.run_until_parked();
 236    let finished_files = queue.finished_files();
 237    let mut embedded_files: Vec<_> = files
 238        .iter()
 239        .map(|_| finished_files.try_recv().expect("no finished file"))
 240        .collect();
 241
 242    let expected_files: Vec<_> = files
 243        .iter()
 244        .map(|file| {
 245            let mut file = file.clone();
 246            for doc in &mut file.spans {
 247                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 248            }
 249            file
 250        })
 251        .collect();
 252
 253    embedded_files.sort_by_key(|f| f.path.clone());
 254
 255    assert_eq!(embedded_files, expected_files);
 256}
 257
 258#[track_caller]
 259fn assert_search_results(
 260    actual: &[SearchResult],
 261    expected: &[(Arc<Path>, usize)],
 262    cx: &TestAppContext,
 263) {
 264    let actual = actual
 265        .iter()
 266        .map(|search_result| {
 267            search_result.buffer.read_with(cx, |buffer, _cx| {
 268                (
 269                    buffer.file().unwrap().path().clone(),
 270                    search_result.range.start.to_offset(buffer),
 271                )
 272            })
 273        })
 274        .collect::<Vec<_>>();
 275    assert_eq!(actual, expected);
 276}
 277
// Checks which spans `CodeContextRetriever` extracts from a Rust file.
// - Each item's preceding doc comments and attributes are part of its content,
//   while its start offset is that of the item itself (e.g. `fn a`).
// - An `impl` block yields an outline with method bodies collapsed, plus a
//   separate span for each method.
#[gpui::test]
async fn test_code_context_retrieval_rust() {
    let language = rust_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /// A doc comment
        /// that spans multiple lines
        #[gpui::test]
        fn a() {
            b
        }

        impl C for D {
        }

        impl E {
            // This is also a preceding comment
            pub fn function_1() -> Option<()> {
                unimplemented!();
            }

            // This is a preceding comment
            fn function_2() -> Result<()> {
                unimplemented!();
            }
        }

        #[derive(Clone)]
        struct D {
            name: String
        }
    "
    .unindent();

    let documents = retriever.parse_file(&text, language).unwrap();

    // Expected spans in order: (content, start offset in `text`).
    assert_documents_eq(
        &documents,
        &[
            (
                "
                /// A doc comment
                /// that spans multiple lines
                #[gpui::test]
                fn a() {
                    b
                }"
                .unindent(),
                text.find("fn a").unwrap(),
            ),
            (
                "
                impl C for D {
                }"
                .unindent(),
                text.find("impl C").unwrap(),
            ),
            // In the impl's own span, method bodies are replaced by `{ /* ... */ }`...
            (
                "
                impl E {
                    // This is also a preceding comment
                    pub fn function_1() -> Option<()> { /* ... */ }

                    // This is a preceding comment
                    fn function_2() -> Result<()> { /* ... */ }
                }"
                .unindent(),
                text.find("impl E").unwrap(),
            ),
            // ...and each method also gets its own full span, including the
            // comment directly above it.
            (
                "
                // This is also a preceding comment
                pub fn function_1() -> Option<()> {
                    unimplemented!();
                }"
                .unindent(),
                text.find("pub fn function_1").unwrap(),
            ),
            (
                "
                // This is a preceding comment
                fn function_2() -> Result<()> {
                    unimplemented!();
                }"
                .unindent(),
                text.find("fn function_2").unwrap(),
            ),
            (
                "
                #[derive(Clone)]
                struct D {
                    name: String
                }"
                .unindent(),
                text.find("struct D").unwrap(),
            ),
        ],
    );
}
 379
// Checks JSON extraction. The expectations below encode these rules:
// - The whole document becomes a single span.
// - String values are emptied and arrays of scalars are emptied; numbers,
//   booleans and null are kept.
// - For a top-level array, only its first element is kept.
#[gpui::test]
async fn test_code_context_retrieval_json() {
    let language = json_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Case 1: a top-level object with nested values.
    let text = r#"
        {
            "array": [1, 2, 3, 4],
            "string": "abcdefg",
            "nested_object": {
                "array_2": [5, 6, 7, 8],
                "string_2": "hijklmnop",
                "boolean": true,
                "none": null
            }
        }
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
                {
                    "array": [],
                    "string": "",
                    "nested_object": {
                        "array_2": [],
                        "string_2": "",
                        "boolean": true,
                        "none": null
                    }
                }"#
            .unindent(),
            text.find("{").unwrap(),
        )],
    );

    // Case 2: a top-level array of objects.
    let text = r#"
        [
            {
                "name": "somebody",
                "age": 42
            },
            {
                "name": "somebody else",
                "age": 43
            }
        ]
    "#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[(
            r#"
            [{
                    "name": "",
                    "age": 42
                }]"#
            .unindent(),
            text.find("[").unwrap(),
        )],
    );
}
 450
 451fn assert_documents_eq(
 452    documents: &[Span],
 453    expected_contents_and_start_offsets: &[(String, usize)],
 454) {
 455    assert_eq!(
 456        documents
 457            .iter()
 458            .map(|document| (document.content.clone(), document.range.start))
 459            .collect::<Vec<_>>(),
 460        expected_contents_and_start_offsets
 461    );
 462}
 463
// Checks JavaScript extraction. Functions, classes, class methods and
// interfaces each become a span that includes the comment directly above them.
//
// Start offsets are hard-coded byte offsets into the unindented `text`. They
// point at the item rather than its comment: 37 is where
// `function _authorize` begins, and 131 is where
// `export async function authorizeBank` begins.
//
// NOTE(review): these literals are brittle; `text.find(...)`, as in the Rust
// test, would survive fixture edits.
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
    let language = js_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
        /* globals importScripts, backend */
        function _authorize() {}

        /**
         * Sometimes the frontend build is way faster than backend.
         */
        export async function authorizeBank() {
            _authorize(pushModal, upgradingAccountId, {});
        }

        export class SettingsPage {
            /* This is a test setting */
            constructor(page) {
                this.page = page;
            }
        }

        /* This is a test comment */
        class TestClass {}

        /* Schema for editor_events in Clickhouse. */
        export interface ClickhouseEditorEvent {
            installation_id: string
            operation: string
        }
        "
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Expected spans in order: (content, start offset in `text`).
    assert_documents_eq(
        &documents,
        &[
            (
                "
            /* globals importScripts, backend */
            function _authorize() {}"
                    .unindent(),
                37,
            ),
            (
                "
            /**
             * Sometimes the frontend build is way faster than backend.
             */
            export async function authorizeBank() {
                _authorize(pushModal, upgradingAccountId, {});
            }"
                .unindent(),
                131,
            ),
            // The class span keeps its full body; the constructor also gets
            // its own span below.
            (
                "
                export class SettingsPage {
                    /* This is a test setting */
                    constructor(page) {
                        this.page = page;
                    }
                }"
                .unindent(),
                225,
            ),
            (
                "
                /* This is a test setting */
                constructor(page) {
                    this.page = page;
                }"
                .unindent(),
                290,
            ),
            (
                "
                /* This is a test comment */
                class TestClass {}"
                    .unindent(),
                374,
            ),
            (
                "
                /* Schema for editor_events in Clickhouse. */
                export interface ClickhouseEditorEvent {
                    installation_id: string
                    operation: string
                }"
                .unindent(),
                440,
            ),
        ],
    )
}
 562
// Checks Lua extraction. Each function, including the nested
// `classdef.alloc`, becomes a span together with its leading `--` comments.
// In the outer function's span, the nested function's body is elided with
// `--[ ... ]--` placeholders.
//
// 114 is the byte offset of `function classes.class` in the unindented
// `text`, i.e. the offset points at the item, not its comment.
#[gpui::test]
async fn test_code_context_retrieval_lua() {
    let language = lua_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        -- Creates a new class
        -- @param baseclass The Baseclass of this class, or nil.
        -- @return A new class reference.
        function classes.class(baseclass)
            -- Create the class definition and metatable.
            local classdef = {}
            -- Find the super class, either Object or user-defined.
            baseclass = baseclass or classes.Object
            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
            setmetatable(classdef, { __index = baseclass })
            -- All class instances have a reference to the class object.
            classdef.class = classdef
            --- Recursively allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end
        end
        "#.unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Expected spans in order: (content, start offset in `text`).
    assert_documents_eq(
        &documents,
        &[
            // Outer function, with the nested function's body elided.
            (r#"
                -- Creates a new class
                -- @param baseclass The Baseclass of this class, or nil.
                -- @return A new class reference.
                function classes.class(baseclass)
                    -- Create the class definition and metatable.
                    local classdef = {}
                    -- Find the super class, either Object or user-defined.
                    baseclass = baseclass or classes.Object
                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
                    setmetatable(classdef, { __index = baseclass })
                    -- All class instances have a reference to the class object.
                    classdef.class = classdef
                    --- Recursively allocates the inheritance tree of the instance.
                    -- @param mastertable The 'root' of the inheritance tree.
                    -- @return Returns the instance with the allocated inheritance tree.
                    function classdef.alloc(mastertable)
                        --[ ... ]--
                        --[ ... ]--
                    end
                end"#.unindent(),
            114),
            // Nested function, in full.
            (r#"
            --- Recursively allocates the inheritance tree of the instance.
            -- @param mastertable The 'root' of the inheritance tree.
            -- @return Returns the instance with the allocated inheritance tree.
            function classdef.alloc(mastertable)
                -- All class instances have a reference to a superclass object.
                local instance = { super = baseclass.alloc(mastertable) }
                -- Any functions this instance does not know of will 'look up' to the superclass definition.
                setmetatable(instance, { __index = classdef, __newindex = mastertable })
                return instance
            end"#.unindent(), 810),
        ]
    );
}
 636
// Checks Elixir extraction.
// - The `defmodule` yields a span covering the whole text (offset 0).
// - `def __build__` yields its own span that includes the preceding
//   `@doc false` attribute.
//
// NOTE(review): the fixture never closes `defmodule ... do`; the final `end`
// closes `def __build__`. The parser is therefore fed an unterminated module
// here. Confirm that this is intentional.
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
    let language = elixir_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = r#"
        defmodule File.Stream do
            @moduledoc """
            Defines a `File.Stream` struct returned by `File.stream!/3`.

            The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

            """

            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

            @type t :: %__MODULE__{}

            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Expected spans in order: the whole module at offset 0, then the
    // `__build__` function (with its `@doc false`).
    assert_documents_eq(
        &documents,
        &[(
            r#"
        defmodule File.Stream do
            @moduledoc """
            Defines a `File.Stream` struct returned by `File.stream!/3`.

            The following fields are public:

            * `path`          - the file path
            * `modes`         - the file modes
            * `raw`           - a boolean indicating if bin functions should be used
            * `line_or_bytes` - if reading should read lines or a given number of bytes
            * `node`          - the node the file belongs to

            """

            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil

            @type t :: %__MODULE__{}

            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#
                .unindent(),
            0,
        ),(r#"
            @doc false
            def __build__(path, modes, line_or_bytes) do
            raw = :lists.keyfind(:encoding, 1, modes) == false

            modes =
                case raw do
                true ->
                    case :lists.keyfind(:read_ahead, 1, modes) do
                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
                    {:read_ahead, _} -> [:raw | modes]
                    false -> [:raw, :read_ahead | modes]
                    end

                false ->
                    modes
                end

            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}

            end"#.unindent(), 574)],
    );
}
 753
// Checks C++ extraction of these items, each together with its leading
// comment:
// - a function,
// - a class,
// - an enum,
// - an anonymous struct variable,
// - a templated class,
// - the constructor nested inside that class.
//
// Start offsets are hard-coded byte offsets into the unindented `text` that
// point at the item itself: 54 is where `int main` begins, and 112 is where
// `class MyClass` begins.
//
// Note that the expected class/enum contents stop at the closing `}` (no
// trailing `;`), while the struct's content includes `myStructure;`.
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
    let language = cpp_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    let text = "
    /**
     * @brief Main function
     * @returns 0 on exit
     */
    int main() { return 0; }

    /**
    * This is a test comment
    */
    class MyClass {       // The class
        public:           // Access specifier
        int myNum;        // Attribute (int variable)
        string myString;  // Attribute (string variable)
    };

    // This is a test comment
    enum Color { red, green, blue };

    /** This is a preceding block comment
     * This is the second line
     */
    struct {           // Structure declaration
        int myNum;       // Member (int variable)
        string myString; // Member (string variable)
    } myStructure;

    /**
     * @brief Matrix class.
     */
    template <typename T,
              typename = typename std::enable_if<
                std::is_integral<T>::value || std::is_floating_point<T>::value,
                bool>::type>
    class Matrix2 {
        std::vector<std::vector<T>> _mat;

        public:
            /**
            * @brief Constructor
            * @tparam Integer ensuring integers are being evaluated and not other
            * data types.
            * @param size denoting the size of Matrix as size x size
            */
            template <typename Integer,
                    typename = typename std::enable_if<std::is_integral<Integer>::value,
                    Integer>::type>
            explicit Matrix(const Integer size) {
                for (size_t i = 0; i < size; ++i) {
                    _mat.emplace_back(std::vector<T>(size, 0));
                }
            }
    }"
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Expected spans in order: (content, start offset in `text`).
    assert_documents_eq(
        &documents,
        &[
            (
                "
        /**
         * @brief Main function
         * @returns 0 on exit
         */
        int main() { return 0; }"
                    .unindent(),
                54,
            ),
            (
                "
                /**
                * This is a test comment
                */
                class MyClass {       // The class
                    public:           // Access specifier
                    int myNum;        // Attribute (int variable)
                    string myString;  // Attribute (string variable)
                }"
                .unindent(),
                112,
            ),
            (
                "
                // This is a test comment
                enum Color { red, green, blue }"
                    .unindent(),
                322,
            ),
            (
                "
                /** This is a preceding block comment
                 * This is the second line
                 */
                struct {           // Structure declaration
                    int myNum;       // Member (int variable)
                    string myString; // Member (string variable)
                } myStructure;"
                    .unindent(),
                425,
            ),
            (
                "
                /**
                 * @brief Matrix class.
                 */
                template <typename T,
                          typename = typename std::enable_if<
                            std::is_integral<T>::value || std::is_floating_point<T>::value,
                            bool>::type>
                class Matrix2 {
                    std::vector<std::vector<T>> _mat;

                    public:
                        /**
                        * @brief Constructor
                        * @tparam Integer ensuring integers are being evaluated and not other
                        * data types.
                        * @param size denoting the size of Matrix as size x size
                        */
                        template <typename Integer,
                                typename = typename std::enable_if<std::is_integral<Integer>::value,
                                Integer>::type>
                        explicit Matrix(const Integer size) {
                            for (size_t i = 0; i < size; ++i) {
                                _mat.emplace_back(std::vector<T>(size, 0));
                            }
                        }
                }"
                .unindent(),
                612,
            ),
            // The constructor nested in `Matrix2` is also emitted on its own,
            // without its doc comment or template header.
            (
                "
                explicit Matrix(const Integer size) {
                    for (size_t i = 0; i < size; ++i) {
                        _mat.emplace_back(std::vector<T>(size, 0));
                    }
                }"
                .unindent(),
                1226,
            ),
        ],
    );
}
 906
/// Parses a Ruby source containing a commented module, a class and a
/// singleton method with `ruby_lang()`, and checks the resulting documents.
///
/// Container items (the module and the class) are expected with their method
/// bodies replaced by the `# ...` placeholder, while every method is also
/// expected as its own un-collapsed document. Leading `#` comments are part of
/// the document for the item they precede.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
    let language = ruby_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Ruby fixture fed to the retriever.
    let text = r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end

            def challenge_passed?
                current_user.valid_password?(challenge_params[:current_password])
            end
        end

        class Animal
            include Comparable

            attr_reader :legs

            def initialize(name, legs)
                @name, @legs = name, legs
            end

            def <=>(other)
                legs <=> other.legs
            end
        end

        # Singleton method for car object
        def car.wheels
            puts "There are four wheels"
        end"#
        .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each expected entry pairs the document text with a number that is
    // presumably the item's start offset within `text` (checked by
    // `assert_documents_eq`, defined elsewhere in this file) — TODO confirm.
    assert_documents_eq(
        &documents,
        &[
            // The module with its leading comment block; method bodies collapsed.
            (
                r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                # ...
            end

            def challenge_passed?
                # ...
            end
        end"#
                    .unindent(),
                558,
            ),
            // Each method of the module, in full.
            (
                r#"
            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end"#
                    .unindent(),
                663,
            ),
            (
                r#"
                def challenge_passed?
                    current_user.valid_password?(challenge_params[:current_password])
                end"#
                    .unindent(),
                1254,
            ),
            // The class with method bodies collapsed.
            (
                r#"
                class Animal
                    include Comparable

                    attr_reader :legs

                    def initialize(name, legs)
                        # ...
                    end

                    def <=>(other)
                        # ...
                    end
                end"#
                    .unindent(),
                1363,
            ),
            // Each method of the class, in full.
            (
                r#"
                def initialize(name, legs)
                    @name, @legs = name, legs
                end"#
                    .unindent(),
                1427,
            ),
            (
                r#"
                def <=>(other)
                    legs <=> other.legs
                end"#
                    .unindent(),
                1501,
            ),
            // The singleton method together with its leading comment.
            (
                r#"
                # Singleton method for car object
                def car.wheels
                    puts "There are four wheels"
                end"#
                    .unindent(),
                1591,
            ),
        ],
    );
}
1097
/// Parses a PHP source (a free function, a trait, an interface and an enum)
/// with `php_lang()` and checks the resulting documents.
///
/// The trait is expected with its method bodies replaced by `{/* ... */}`,
/// and each trait method is also expected as its own un-collapsed document.
/// The `<?php` opener and the `namespace` line do not appear in any expected
/// document.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // PHP fixture fed to the retriever.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each expected entry pairs the document text with a number that is
    // presumably the item's start offset within `text` (checked by
    // `assert_documents_eq`, defined elsewhere in this file) — TODO confirm.
    assert_documents_eq(
        &documents,
        &[
            // The free function, including its preceding block comment.
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            // The trait, with method bodies collapsed to the placeholder.
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            // Each trait method in full; the doc comment stays attached.
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            // Interface and enum are expected verbatim.
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1248
1249fn js_lang() -> Arc<Language> {
1250    Arc::new(
1251        Language::new(
1252            LanguageConfig {
1253                name: "Javascript".into(),
1254                path_suffixes: vec!["js".into()],
1255                ..Default::default()
1256            },
1257            Some(tree_sitter_typescript::language_tsx()),
1258        )
1259        .with_embedding_query(
1260            &r#"
1261
1262            (
1263                (comment)* @context
1264                .
1265                [
1266                (export_statement
1267                    (function_declaration
1268                        "async"? @name
1269                        "function" @name
1270                        name: (_) @name))
1271                (function_declaration
1272                    "async"? @name
1273                    "function" @name
1274                    name: (_) @name)
1275                ] @item
1276            )
1277
1278            (
1279                (comment)* @context
1280                .
1281                [
1282                (export_statement
1283                    (class_declaration
1284                        "class" @name
1285                        name: (_) @name))
1286                (class_declaration
1287                    "class" @name
1288                    name: (_) @name)
1289                ] @item
1290            )
1291
1292            (
1293                (comment)* @context
1294                .
1295                [
1296                (export_statement
1297                    (interface_declaration
1298                        "interface" @name
1299                        name: (_) @name))
1300                (interface_declaration
1301                    "interface" @name
1302                    name: (_) @name)
1303                ] @item
1304            )
1305
1306            (
1307                (comment)* @context
1308                .
1309                [
1310                (export_statement
1311                    (enum_declaration
1312                        "enum" @name
1313                        name: (_) @name))
1314                (enum_declaration
1315                    "enum" @name
1316                    name: (_) @name)
1317                ] @item
1318            )
1319
1320            (
1321                (comment)* @context
1322                .
1323                (method_definition
1324                    [
1325                        "get"
1326                        "set"
1327                        "async"
1328                        "*"
1329                        "static"
1330                    ]* @name
1331                    name: (_) @name) @item
1332            )
1333
1334                    "#
1335            .unindent(),
1336        )
1337        .unwrap(),
1338    )
1339}
1340
1341fn rust_lang() -> Arc<Language> {
1342    Arc::new(
1343        Language::new(
1344            LanguageConfig {
1345                name: "Rust".into(),
1346                path_suffixes: vec!["rs".into()],
1347                collapsed_placeholder: " /* ... */ ".to_string(),
1348                ..Default::default()
1349            },
1350            Some(tree_sitter_rust::language()),
1351        )
1352        .with_embedding_query(
1353            r#"
1354            (
1355                [(line_comment) (attribute_item)]* @context
1356                .
1357                [
1358                    (struct_item
1359                        name: (_) @name)
1360
1361                    (enum_item
1362                        name: (_) @name)
1363
1364                    (impl_item
1365                        trait: (_)? @name
1366                        "for"? @name
1367                        type: (_) @name)
1368
1369                    (trait_item
1370                        name: (_) @name)
1371
1372                    (function_item
1373                        name: (_) @name
1374                        body: (block
1375                            "{" @keep
1376                            "}" @keep) @collapse)
1377
1378                    (macro_definition
1379                        name: (_) @name)
1380                ] @item
1381            )
1382
1383            (attribute_item) @collapse
1384            (use_declaration) @collapse
1385            "#,
1386        )
1387        .unwrap(),
1388    )
1389}
1390
1391fn json_lang() -> Arc<Language> {
1392    Arc::new(
1393        Language::new(
1394            LanguageConfig {
1395                name: "JSON".into(),
1396                path_suffixes: vec!["json".into()],
1397                ..Default::default()
1398            },
1399            Some(tree_sitter_json::language()),
1400        )
1401        .with_embedding_query(
1402            r#"
1403            (document) @item
1404
1405            (array
1406                "[" @keep
1407                .
1408                (object)? @keep
1409                "]" @keep) @collapse
1410
1411            (pair value: (string
1412                "\"" @keep
1413                "\"" @keep) @collapse)
1414            "#,
1415        )
1416        .unwrap(),
1417    )
1418}
1419
1420fn toml_lang() -> Arc<Language> {
1421    Arc::new(Language::new(
1422        LanguageConfig {
1423            name: "TOML".into(),
1424            path_suffixes: vec!["toml".into()],
1425            ..Default::default()
1426        },
1427        Some(tree_sitter_toml::language()),
1428    ))
1429}
1430
1431fn cpp_lang() -> Arc<Language> {
1432    Arc::new(
1433        Language::new(
1434            LanguageConfig {
1435                name: "CPP".into(),
1436                path_suffixes: vec!["cpp".into()],
1437                ..Default::default()
1438            },
1439            Some(tree_sitter_cpp::language()),
1440        )
1441        .with_embedding_query(
1442            r#"
1443            (
1444                (comment)* @context
1445                .
1446                (function_definition
1447                    (type_qualifier)? @name
1448                    type: (_)? @name
1449                    declarator: [
1450                        (function_declarator
1451                            declarator: (_) @name)
1452                        (pointer_declarator
1453                            "*" @name
1454                            declarator: (function_declarator
1455                            declarator: (_) @name))
1456                        (pointer_declarator
1457                            "*" @name
1458                            declarator: (pointer_declarator
1459                                "*" @name
1460                            declarator: (function_declarator
1461                                declarator: (_) @name)))
1462                        (reference_declarator
1463                            ["&" "&&"] @name
1464                            (function_declarator
1465                            declarator: (_) @name))
1466                    ]
1467                    (type_qualifier)? @name) @item
1468                )
1469
1470            (
1471                (comment)* @context
1472                .
1473                (template_declaration
1474                    (class_specifier
1475                        "class" @name
1476                        name: (_) @name)
1477                        ) @item
1478            )
1479
1480            (
1481                (comment)* @context
1482                .
1483                (class_specifier
1484                    "class" @name
1485                    name: (_) @name) @item
1486                )
1487
1488            (
1489                (comment)* @context
1490                .
1491                (enum_specifier
1492                    "enum" @name
1493                    name: (_) @name) @item
1494                )
1495
1496            (
1497                (comment)* @context
1498                .
1499                (declaration
1500                    type: (struct_specifier
1501                    "struct" @name)
1502                    declarator: (_) @name) @item
1503            )
1504
1505            "#,
1506        )
1507        .unwrap(),
1508    )
1509}
1510
1511fn lua_lang() -> Arc<Language> {
1512    Arc::new(
1513        Language::new(
1514            LanguageConfig {
1515                name: "Lua".into(),
1516                path_suffixes: vec!["lua".into()],
1517                collapsed_placeholder: "--[ ... ]--".to_string(),
1518                ..Default::default()
1519            },
1520            Some(tree_sitter_lua::language()),
1521        )
1522        .with_embedding_query(
1523            r#"
1524            (
1525                (comment)* @context
1526                .
1527                (function_declaration
1528                    "function" @name
1529                    name: (_) @name
1530                    (comment)* @collapse
1531                    body: (block) @collapse
1532                ) @item
1533            )
1534        "#,
1535        )
1536        .unwrap(),
1537    )
1538}
1539
1540fn php_lang() -> Arc<Language> {
1541    Arc::new(
1542        Language::new(
1543            LanguageConfig {
1544                name: "PHP".into(),
1545                path_suffixes: vec!["php".into()],
1546                collapsed_placeholder: "/* ... */".into(),
1547                ..Default::default()
1548            },
1549            Some(tree_sitter_php::language_php()),
1550        )
1551        .with_embedding_query(
1552            r#"
1553            (
1554                (comment)* @context
1555                .
1556                [
1557                    (function_definition
1558                        "function" @name
1559                        name: (_) @name
1560                        body: (_
1561                            "{" @keep
1562                            "}" @keep) @collapse
1563                        )
1564
1565                    (trait_declaration
1566                        "trait" @name
1567                        name: (_) @name)
1568
1569                    (method_declaration
1570                        "function" @name
1571                        name: (_) @name
1572                        body: (_
1573                            "{" @keep
1574                            "}" @keep) @collapse
1575                        )
1576
1577                    (interface_declaration
1578                        "interface" @name
1579                        name: (_) @name
1580                        )
1581
1582                    (enum_declaration
1583                        "enum" @name
1584                        name: (_) @name
1585                        )
1586
1587                ] @item
1588            )
1589            "#,
1590        )
1591        .unwrap(),
1592    )
1593}
1594
1595fn ruby_lang() -> Arc<Language> {
1596    Arc::new(
1597        Language::new(
1598            LanguageConfig {
1599                name: "Ruby".into(),
1600                path_suffixes: vec!["rb".into()],
1601                collapsed_placeholder: "# ...".to_string(),
1602                ..Default::default()
1603            },
1604            Some(tree_sitter_ruby::language()),
1605        )
1606        .with_embedding_query(
1607            r#"
1608            (
1609                (comment)* @context
1610                .
1611                [
1612                (module
1613                    "module" @name
1614                    name: (_) @name)
1615                (method
1616                    "def" @name
1617                    name: (_) @name
1618                    body: (body_statement) @collapse)
1619                (class
1620                    "class" @name
1621                    name: (_) @name)
1622                (singleton_method
1623                    "def" @name
1624                    object: (_) @name
1625                    "." @name
1626                    name: (_) @name
1627                    body: (body_statement) @collapse)
1628                ] @item
1629            )
1630            "#,
1631        )
1632        .unwrap(),
1633    )
1634}
1635
1636fn elixir_lang() -> Arc<Language> {
1637    Arc::new(
1638        Language::new(
1639            LanguageConfig {
1640                name: "Elixir".into(),
1641                path_suffixes: vec!["rs".into()],
1642                ..Default::default()
1643            },
1644            Some(tree_sitter_elixir::language()),
1645        )
1646        .with_embedding_query(
1647            r#"
1648            (
1649                (unary_operator
1650                    operator: "@"
1651                    operand: (call
1652                        target: (identifier) @unary
1653                        (#match? @unary "^(doc)$"))
1654                    ) @context
1655                .
1656                (call
1657                target: (identifier) @name
1658                (arguments
1659                [
1660                (identifier) @name
1661                (call
1662                target: (identifier) @name)
1663                (binary_operator
1664                left: (call
1665                target: (identifier) @name)
1666                operator: "when")
1667                ])
1668                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1669                )
1670
1671            (call
1672                target: (identifier) @name
1673                (arguments (alias) @name)
1674                (#any-match? @name "^(defmodule|defprotocol)$")) @item
1675            "#,
1676        )
1677        .unwrap(),
1678    )
1679}
1680
1681#[gpui::test]
1682fn test_subtract_ranges() {
1683    assert_eq!(
1684        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1685        vec![1..4, 10..21]
1686    );
1687
1688    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1689}
1690
1691fn init_test(cx: &mut TestAppContext) {
1692    cx.update(|cx| {
1693        let settings_store = SettingsStore::test(cx);
1694        cx.set_global(settings_store);
1695        SemanticIndexSettings::register(cx);
1696        ProjectSettings::register(cx);
1697    });
1698}