semantic_index_tests.rs

   1use crate::{
   2    embedding_queue::EmbeddingQueue,
   3    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   4    semantic_index_settings::SemanticIndexSettings,
   5    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   6};
   7use ai::test::FakeEmbeddingProvider;
   8
   9use gpui::{executor::Deterministic, Task, TestAppContext};
  10use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  11use parking_lot::Mutex;
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::SettingsStore;
  17use std::{path::Path, sync::Arc, time::SystemTime};
  18use unindent::Unindent;
  19use util::{paths::PathMatcher, RandomCharIter};
  20
  21#[ctor::ctor]
  22fn init_logger() {
  23    if std::env::var("RUST_LOG").is_ok() {
  24        env_logger::init();
  25    }
  26}
  27
  28#[gpui::test]
  29async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
  30    init_test(cx);
  31
  32    let fs = FakeFs::new(cx.background());
  33    fs.insert_tree(
  34        "/the-root",
  35        json!({
  36            "src": {
  37                "file1.rs": "
  38                    fn aaa() {
  39                        println!(\"aaaaaaaaaaaa!\");
  40                    }
  41
  42                    fn zzzzz() {
  43                        println!(\"SLEEPING\");
  44                    }
  45                ".unindent(),
  46                "file2.rs": "
  47                    fn bbb() {
  48                        println!(\"bbbbbbbbbbbbb!\");
  49                    }
  50                    struct pqpqpqp {}
  51                ".unindent(),
  52                "file3.toml": "
  53                    ZZZZZZZZZZZZZZZZZZ = 5
  54                ".unindent(),
  55            }
  56        }),
  57    )
  58    .await;
  59
  60    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  61    let rust_language = rust_lang();
  62    let toml_language = toml_lang();
  63    languages.add(rust_language);
  64    languages.add(toml_language);
  65
  66    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  67    let db_path = db_dir.path().join("db.sqlite");
  68
  69    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  70    let semantic_index = SemanticIndex::new(
  71        fs.clone(),
  72        db_path,
  73        embedding_provider.clone(),
  74        languages,
  75        cx.to_async(),
  76    )
  77    .await
  78    .unwrap();
  79
  80    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  81
  82    let search_results = semantic_index.update(cx, |store, cx| {
  83        store.search_project(
  84            project.clone(),
  85            "aaaaaabbbbzz".to_string(),
  86            5,
  87            vec![],
  88            vec![],
  89            cx,
  90        )
  91    });
  92    let pending_file_count =
  93        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
  94    deterministic.run_until_parked();
  95    assert_eq!(*pending_file_count.borrow(), 3);
  96    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
  97    assert_eq!(*pending_file_count.borrow(), 0);
  98
  99    let search_results = search_results.await.unwrap();
 100    assert_search_results(
 101        &search_results,
 102        &[
 103            (Path::new("src/file1.rs").into(), 0),
 104            (Path::new("src/file2.rs").into(), 0),
 105            (Path::new("src/file3.toml").into(), 0),
 106            (Path::new("src/file1.rs").into(), 45),
 107            (Path::new("src/file2.rs").into(), 45),
 108        ],
 109        cx,
 110    );
 111
 112    // Test Include Files Functonality
 113    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 114    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 115    let rust_only_search_results = semantic_index
 116        .update(cx, |store, cx| {
 117            store.search_project(
 118                project.clone(),
 119                "aaaaaabbbbzz".to_string(),
 120                5,
 121                include_files,
 122                vec![],
 123                cx,
 124            )
 125        })
 126        .await
 127        .unwrap();
 128
 129    assert_search_results(
 130        &rust_only_search_results,
 131        &[
 132            (Path::new("src/file1.rs").into(), 0),
 133            (Path::new("src/file2.rs").into(), 0),
 134            (Path::new("src/file1.rs").into(), 45),
 135            (Path::new("src/file2.rs").into(), 45),
 136        ],
 137        cx,
 138    );
 139
 140    let no_rust_search_results = semantic_index
 141        .update(cx, |store, cx| {
 142            store.search_project(
 143                project.clone(),
 144                "aaaaaabbbbzz".to_string(),
 145                5,
 146                vec![],
 147                exclude_files,
 148                cx,
 149            )
 150        })
 151        .await
 152        .unwrap();
 153
 154    assert_search_results(
 155        &no_rust_search_results,
 156        &[(Path::new("src/file3.toml").into(), 0)],
 157        cx,
 158    );
 159
 160    fs.save(
 161        "/the-root/src/file2.rs".as_ref(),
 162        &"
 163            fn dddd() { println!(\"ddddd!\"); }
 164            struct pqpqpqp {}
 165        "
 166        .unindent()
 167        .into(),
 168        Default::default(),
 169    )
 170    .await
 171    .unwrap();
 172
 173    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 174
 175    let prev_embedding_count = embedding_provider.embedding_count();
 176    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 177    deterministic.run_until_parked();
 178    assert_eq!(*pending_file_count.borrow(), 1);
 179    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 180    assert_eq!(*pending_file_count.borrow(), 0);
 181    index.await.unwrap();
 182
 183    assert_eq!(
 184        embedding_provider.embedding_count() - prev_embedding_count,
 185        1
 186    );
 187}
 188
 189#[gpui::test(iterations = 10)]
 190async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 191    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 192    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 193
 194    let files = (1..=3)
 195        .map(|file_ix| FileToEmbed {
 196            worktree_id: 5,
 197            path: Path::new(&format!("path-{file_ix}")).into(),
 198            mtime: SystemTime::now(),
 199            spans: (0..rng.gen_range(4..22))
 200                .map(|document_ix| {
 201                    let content_len = rng.gen_range(10..100);
 202                    let content = RandomCharIter::new(&mut rng)
 203                        .with_simple_text()
 204                        .take(content_len)
 205                        .collect::<String>();
 206                    let digest = SpanDigest::from(content.as_str());
 207                    Span {
 208                        range: 0..10,
 209                        embedding: None,
 210                        name: format!("document {document_ix}"),
 211                        content,
 212                        digest,
 213                        token_count: rng.gen_range(10..30),
 214                    }
 215                })
 216                .collect(),
 217            job_handle: JobHandle::new(&outstanding_job_count),
 218        })
 219        .collect::<Vec<_>>();
 220
 221    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 222
 223    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
 224    for file in &files {
 225        queue.push(file.clone());
 226    }
 227    queue.flush();
 228
 229    cx.foreground().run_until_parked();
 230    let finished_files = queue.finished_files();
 231    let mut embedded_files: Vec<_> = files
 232        .iter()
 233        .map(|_| finished_files.try_recv().expect("no finished file"))
 234        .collect();
 235
 236    let expected_files: Vec<_> = files
 237        .iter()
 238        .map(|file| {
 239            let mut file = file.clone();
 240            for doc in &mut file.spans {
 241                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 242            }
 243            file
 244        })
 245        .collect();
 246
 247    embedded_files.sort_by_key(|f| f.path.clone());
 248
 249    assert_eq!(embedded_files, expected_files);
 250}
 251
 252#[track_caller]
 253fn assert_search_results(
 254    actual: &[SearchResult],
 255    expected: &[(Arc<Path>, usize)],
 256    cx: &TestAppContext,
 257) {
 258    let actual = actual
 259        .iter()
 260        .map(|search_result| {
 261            search_result.buffer.read_with(cx, |buffer, _cx| {
 262                (
 263                    buffer.file().unwrap().path().clone(),
 264                    search_result.range.start.to_offset(buffer),
 265                )
 266            })
 267        })
 268        .collect::<Vec<_>>();
 269    assert_eq!(actual, expected);
 270}
 271
 272#[gpui::test]
 273async fn test_code_context_retrieval_rust() {
 274    let language = rust_lang();
 275    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 276    let mut retriever = CodeContextRetriever::new(embedding_provider);
 277
 278    let text = "
 279        /// A doc comment
 280        /// that spans multiple lines
 281        #[gpui::test]
 282        fn a() {
 283            b
 284        }
 285
 286        impl C for D {
 287        }
 288
 289        impl E {
 290            // This is also a preceding comment
 291            pub fn function_1() -> Option<()> {
 292                unimplemented!();
 293            }
 294
 295            // This is a preceding comment
 296            fn function_2() -> Result<()> {
 297                unimplemented!();
 298            }
 299        }
 300
 301        #[derive(Clone)]
 302        struct D {
 303            name: String
 304        }
 305    "
 306    .unindent();
 307
 308    let documents = retriever.parse_file(&text, language).unwrap();
 309
 310    assert_documents_eq(
 311        &documents,
 312        &[
 313            (
 314                "
 315                /// A doc comment
 316                /// that spans multiple lines
 317                #[gpui::test]
 318                fn a() {
 319                    b
 320                }"
 321                .unindent(),
 322                text.find("fn a").unwrap(),
 323            ),
 324            (
 325                "
 326                impl C for D {
 327                }"
 328                .unindent(),
 329                text.find("impl C").unwrap(),
 330            ),
 331            (
 332                "
 333                impl E {
 334                    // This is also a preceding comment
 335                    pub fn function_1() -> Option<()> { /* ... */ }
 336
 337                    // This is a preceding comment
 338                    fn function_2() -> Result<()> { /* ... */ }
 339                }"
 340                .unindent(),
 341                text.find("impl E").unwrap(),
 342            ),
 343            (
 344                "
 345                // This is also a preceding comment
 346                pub fn function_1() -> Option<()> {
 347                    unimplemented!();
 348                }"
 349                .unindent(),
 350                text.find("pub fn function_1").unwrap(),
 351            ),
 352            (
 353                "
 354                // This is a preceding comment
 355                fn function_2() -> Result<()> {
 356                    unimplemented!();
 357                }"
 358                .unindent(),
 359                text.find("fn function_2").unwrap(),
 360            ),
 361            (
 362                "
 363                #[derive(Clone)]
 364                struct D {
 365                    name: String
 366                }"
 367                .unindent(),
 368                text.find("struct D").unwrap(),
 369            ),
 370        ],
 371    );
 372}
 373
 374#[gpui::test]
 375async fn test_code_context_retrieval_json() {
 376    let language = json_lang();
 377    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 378    let mut retriever = CodeContextRetriever::new(embedding_provider);
 379
 380    let text = r#"
 381        {
 382            "array": [1, 2, 3, 4],
 383            "string": "abcdefg",
 384            "nested_object": {
 385                "array_2": [5, 6, 7, 8],
 386                "string_2": "hijklmnop",
 387                "boolean": true,
 388                "none": null
 389            }
 390        }
 391    "#
 392    .unindent();
 393
 394    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 395
 396    assert_documents_eq(
 397        &documents,
 398        &[(
 399            r#"
 400                {
 401                    "array": [],
 402                    "string": "",
 403                    "nested_object": {
 404                        "array_2": [],
 405                        "string_2": "",
 406                        "boolean": true,
 407                        "none": null
 408                    }
 409                }"#
 410            .unindent(),
 411            text.find("{").unwrap(),
 412        )],
 413    );
 414
 415    let text = r#"
 416        [
 417            {
 418                "name": "somebody",
 419                "age": 42
 420            },
 421            {
 422                "name": "somebody else",
 423                "age": 43
 424            }
 425        ]
 426    "#
 427    .unindent();
 428
 429    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 430
 431    assert_documents_eq(
 432        &documents,
 433        &[(
 434            r#"
 435            [{
 436                    "name": "",
 437                    "age": 42
 438                }]"#
 439            .unindent(),
 440            text.find("[").unwrap(),
 441        )],
 442    );
 443}
 444
 445fn assert_documents_eq(
 446    documents: &[Span],
 447    expected_contents_and_start_offsets: &[(String, usize)],
 448) {
 449    assert_eq!(
 450        documents
 451            .iter()
 452            .map(|document| (document.content.clone(), document.range.start))
 453            .collect::<Vec<_>>(),
 454        expected_contents_and_start_offsets
 455    );
 456}
 457
 458#[gpui::test]
 459async fn test_code_context_retrieval_javascript() {
 460    let language = js_lang();
 461    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 462    let mut retriever = CodeContextRetriever::new(embedding_provider);
 463
 464    let text = "
 465        /* globals importScripts, backend */
 466        function _authorize() {}
 467
 468        /**
 469         * Sometimes the frontend build is way faster than backend.
 470         */
 471        export async function authorizeBank() {
 472            _authorize(pushModal, upgradingAccountId, {});
 473        }
 474
 475        export class SettingsPage {
 476            /* This is a test setting */
 477            constructor(page) {
 478                this.page = page;
 479            }
 480        }
 481
 482        /* This is a test comment */
 483        class TestClass {}
 484
 485        /* Schema for editor_events in Clickhouse. */
 486        export interface ClickhouseEditorEvent {
 487            installation_id: string
 488            operation: string
 489        }
 490        "
 491    .unindent();
 492
 493    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 494
 495    assert_documents_eq(
 496        &documents,
 497        &[
 498            (
 499                "
 500            /* globals importScripts, backend */
 501            function _authorize() {}"
 502                    .unindent(),
 503                37,
 504            ),
 505            (
 506                "
 507            /**
 508             * Sometimes the frontend build is way faster than backend.
 509             */
 510            export async function authorizeBank() {
 511                _authorize(pushModal, upgradingAccountId, {});
 512            }"
 513                .unindent(),
 514                131,
 515            ),
 516            (
 517                "
 518                export class SettingsPage {
 519                    /* This is a test setting */
 520                    constructor(page) {
 521                        this.page = page;
 522                    }
 523                }"
 524                .unindent(),
 525                225,
 526            ),
 527            (
 528                "
 529                /* This is a test setting */
 530                constructor(page) {
 531                    this.page = page;
 532                }"
 533                .unindent(),
 534                290,
 535            ),
 536            (
 537                "
 538                /* This is a test comment */
 539                class TestClass {}"
 540                    .unindent(),
 541                374,
 542            ),
 543            (
 544                "
 545                /* Schema for editor_events in Clickhouse. */
 546                export interface ClickhouseEditorEvent {
 547                    installation_id: string
 548                    operation: string
 549                }"
 550                .unindent(),
 551                440,
 552            ),
 553        ],
 554    )
 555}
 556
 557#[gpui::test]
 558async fn test_code_context_retrieval_lua() {
 559    let language = lua_lang();
 560    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 561    let mut retriever = CodeContextRetriever::new(embedding_provider);
 562
 563    let text = r#"
 564        -- Creates a new class
 565        -- @param baseclass The Baseclass of this class, or nil.
 566        -- @return A new class reference.
 567        function classes.class(baseclass)
 568            -- Create the class definition and metatable.
 569            local classdef = {}
 570            -- Find the super class, either Object or user-defined.
 571            baseclass = baseclass or classes.Object
 572            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 573            setmetatable(classdef, { __index = baseclass })
 574            -- All class instances have a reference to the class object.
 575            classdef.class = classdef
 576            --- Recursivly allocates the inheritance tree of the instance.
 577            -- @param mastertable The 'root' of the inheritance tree.
 578            -- @return Returns the instance with the allocated inheritance tree.
 579            function classdef.alloc(mastertable)
 580                -- All class instances have a reference to a superclass object.
 581                local instance = { super = baseclass.alloc(mastertable) }
 582                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 583                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 584                return instance
 585            end
 586        end
 587        "#.unindent();
 588
 589    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 590
 591    assert_documents_eq(
 592        &documents,
 593        &[
 594            (r#"
 595                -- Creates a new class
 596                -- @param baseclass The Baseclass of this class, or nil.
 597                -- @return A new class reference.
 598                function classes.class(baseclass)
 599                    -- Create the class definition and metatable.
 600                    local classdef = {}
 601                    -- Find the super class, either Object or user-defined.
 602                    baseclass = baseclass or classes.Object
 603                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 604                    setmetatable(classdef, { __index = baseclass })
 605                    -- All class instances have a reference to the class object.
 606                    classdef.class = classdef
 607                    --- Recursivly allocates the inheritance tree of the instance.
 608                    -- @param mastertable The 'root' of the inheritance tree.
 609                    -- @return Returns the instance with the allocated inheritance tree.
 610                    function classdef.alloc(mastertable)
 611                        --[ ... ]--
 612                        --[ ... ]--
 613                    end
 614                end"#.unindent(),
 615            114),
 616            (r#"
 617            --- Recursivly allocates the inheritance tree of the instance.
 618            -- @param mastertable The 'root' of the inheritance tree.
 619            -- @return Returns the instance with the allocated inheritance tree.
 620            function classdef.alloc(mastertable)
 621                -- All class instances have a reference to a superclass object.
 622                local instance = { super = baseclass.alloc(mastertable) }
 623                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 624                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 625                return instance
 626            end"#.unindent(), 809),
 627        ]
 628    );
 629}
 630
 631#[gpui::test]
 632async fn test_code_context_retrieval_elixir() {
 633    let language = elixir_lang();
 634    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 635    let mut retriever = CodeContextRetriever::new(embedding_provider);
 636
 637    let text = r#"
 638        defmodule File.Stream do
 639            @moduledoc """
 640            Defines a `File.Stream` struct returned by `File.stream!/3`.
 641
 642            The following fields are public:
 643
 644            * `path`          - the file path
 645            * `modes`         - the file modes
 646            * `raw`           - a boolean indicating if bin functions should be used
 647            * `line_or_bytes` - if reading should read lines or a given number of bytes
 648            * `node`          - the node the file belongs to
 649
 650            """
 651
 652            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 653
 654            @type t :: %__MODULE__{}
 655
 656            @doc false
 657            def __build__(path, modes, line_or_bytes) do
 658            raw = :lists.keyfind(:encoding, 1, modes) == false
 659
 660            modes =
 661                case raw do
 662                true ->
 663                    case :lists.keyfind(:read_ahead, 1, modes) do
 664                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 665                    {:read_ahead, _} -> [:raw | modes]
 666                    false -> [:raw, :read_ahead | modes]
 667                    end
 668
 669                false ->
 670                    modes
 671                end
 672
 673            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 674
 675            end"#
 676    .unindent();
 677
 678    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 679
 680    assert_documents_eq(
 681        &documents,
 682        &[(
 683            r#"
 684        defmodule File.Stream do
 685            @moduledoc """
 686            Defines a `File.Stream` struct returned by `File.stream!/3`.
 687
 688            The following fields are public:
 689
 690            * `path`          - the file path
 691            * `modes`         - the file modes
 692            * `raw`           - a boolean indicating if bin functions should be used
 693            * `line_or_bytes` - if reading should read lines or a given number of bytes
 694            * `node`          - the node the file belongs to
 695
 696            """
 697
 698            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 699
 700            @type t :: %__MODULE__{}
 701
 702            @doc false
 703            def __build__(path, modes, line_or_bytes) do
 704            raw = :lists.keyfind(:encoding, 1, modes) == false
 705
 706            modes =
 707                case raw do
 708                true ->
 709                    case :lists.keyfind(:read_ahead, 1, modes) do
 710                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 711                    {:read_ahead, _} -> [:raw | modes]
 712                    false -> [:raw, :read_ahead | modes]
 713                    end
 714
 715                false ->
 716                    modes
 717                end
 718
 719            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 720
 721            end"#
 722                .unindent(),
 723            0,
 724        ),(r#"
 725            @doc false
 726            def __build__(path, modes, line_or_bytes) do
 727            raw = :lists.keyfind(:encoding, 1, modes) == false
 728
 729            modes =
 730                case raw do
 731                true ->
 732                    case :lists.keyfind(:read_ahead, 1, modes) do
 733                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 734                    {:read_ahead, _} -> [:raw | modes]
 735                    false -> [:raw, :read_ahead | modes]
 736                    end
 737
 738                false ->
 739                    modes
 740                end
 741
 742            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 743
 744            end"#.unindent(), 574)],
 745    );
 746}
 747
 748#[gpui::test]
 749async fn test_code_context_retrieval_cpp() {
 750    let language = cpp_lang();
 751    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 752    let mut retriever = CodeContextRetriever::new(embedding_provider);
 753
 754    let text = "
 755    /**
 756     * @brief Main function
 757     * @returns 0 on exit
 758     */
 759    int main() { return 0; }
 760
 761    /**
 762    * This is a test comment
 763    */
 764    class MyClass {       // The class
 765        public:           // Access specifier
 766        int myNum;        // Attribute (int variable)
 767        string myString;  // Attribute (string variable)
 768    };
 769
 770    // This is a test comment
 771    enum Color { red, green, blue };
 772
 773    /** This is a preceding block comment
 774     * This is the second line
 775     */
 776    struct {           // Structure declaration
 777        int myNum;       // Member (int variable)
 778        string myString; // Member (string variable)
 779    } myStructure;
 780
 781    /**
 782     * @brief Matrix class.
 783     */
 784    template <typename T,
 785              typename = typename std::enable_if<
 786                std::is_integral<T>::value || std::is_floating_point<T>::value,
 787                bool>::type>
 788    class Matrix2 {
 789        std::vector<std::vector<T>> _mat;
 790
 791        public:
 792            /**
 793            * @brief Constructor
 794            * @tparam Integer ensuring integers are being evaluated and not other
 795            * data types.
 796            * @param size denoting the size of Matrix as size x size
 797            */
 798            template <typename Integer,
 799                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 800                    Integer>::type>
 801            explicit Matrix(const Integer size) {
 802                for (size_t i = 0; i < size; ++i) {
 803                    _mat.emplace_back(std::vector<T>(size, 0));
 804                }
 805            }
 806    }"
 807    .unindent();
 808
 809    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 810
 811    assert_documents_eq(
 812        &documents,
 813        &[
 814            (
 815                "
 816        /**
 817         * @brief Main function
 818         * @returns 0 on exit
 819         */
 820        int main() { return 0; }"
 821                    .unindent(),
 822                54,
 823            ),
 824            (
 825                "
 826                /**
 827                * This is a test comment
 828                */
 829                class MyClass {       // The class
 830                    public:           // Access specifier
 831                    int myNum;        // Attribute (int variable)
 832                    string myString;  // Attribute (string variable)
 833                }"
 834                .unindent(),
 835                112,
 836            ),
 837            (
 838                "
 839                // This is a test comment
 840                enum Color { red, green, blue }"
 841                    .unindent(),
 842                322,
 843            ),
 844            (
 845                "
 846                /** This is a preceding block comment
 847                 * This is the second line
 848                 */
 849                struct {           // Structure declaration
 850                    int myNum;       // Member (int variable)
 851                    string myString; // Member (string variable)
 852                } myStructure;"
 853                    .unindent(),
 854                425,
 855            ),
 856            (
 857                "
 858                /**
 859                 * @brief Matrix class.
 860                 */
 861                template <typename T,
 862                          typename = typename std::enable_if<
 863                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 864                            bool>::type>
 865                class Matrix2 {
 866                    std::vector<std::vector<T>> _mat;
 867
 868                    public:
 869                        /**
 870                        * @brief Constructor
 871                        * @tparam Integer ensuring integers are being evaluated and not other
 872                        * data types.
 873                        * @param size denoting the size of Matrix as size x size
 874                        */
 875                        template <typename Integer,
 876                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 877                                Integer>::type>
 878                        explicit Matrix(const Integer size) {
 879                            for (size_t i = 0; i < size; ++i) {
 880                                _mat.emplace_back(std::vector<T>(size, 0));
 881                            }
 882                        }
 883                }"
 884                .unindent(),
 885                612,
 886            ),
 887            (
 888                "
 889                explicit Matrix(const Integer size) {
 890                    for (size_t i = 0; i < size; ++i) {
 891                        _mat.emplace_back(std::vector<T>(size, 0));
 892                    }
 893                }"
 894                .unindent(),
 895                1226,
 896            ),
 897        ],
 898    );
 899}
 900
// Runs `CodeContextRetriever::parse_file` over a Ruby fixture containing a
// commented module (with two methods), a class (with two methods) and a
// commented singleton method, then checks the exact list of documents produced.
//
// Each expected entry is `(document text, offset)`. The offset appears to be the
// byte position in the unindented `text` where the matched item itself starts.
// For example, the gap 1427 -> 1501 is exactly the `initialize` method, a blank
// line and the indent before `def <=>`. Leading `@context` comments are part of
// the document text but not of that offset.
// NOTE(review): the actual comparison lives in `assert_documents_eq`, outside
// this view; confirm there.
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
    let language = ruby_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Fixture source; offsets in the expectations are relative to the
    // unindented text, so the whitespace inside this literal matters.
    let text = r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end

            def challenge_passed?
                current_user.valid_password?(challenge_params[:current_password])
            end
        end

        class Animal
            include Comparable

            attr_reader :legs

            def initialize(name, legs)
                @name, @legs = name, legs
            end

            def <=>(other)
                legs <=> other.legs
            end
        end

        # Singleton method for car object
        def car.wheels
            puts "There are four wheels"
        end"#
        .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    assert_documents_eq(
        &documents,
        &[
            // Module document: the leading comment block is carried as context
            // and both method bodies are collapsed to the `# ...` placeholder.
            (
                r#"
        # This concern is inspired by "sudo mode" on GitHub. It
        # is a way to re-authenticate a user before allowing them
        # to see or perform an action.
        #
        # Add `before_action :require_challenge!` to actions you
        # want to protect.
        #
        # The user will be shown a page to enter the challenge (which
        # is either the password, or just the username when no
        # password exists). Upon passing, there is a grace period
        # during which no challenge will be asked from the user.
        #
        # Accessing challenge-protected resources during the grace
        # period will refresh the grace period.
        module ChallengableConcern
            extend ActiveSupport::Concern

            CHALLENGE_TIMEOUT = 1.hour.freeze

            def require_challenge!
                # ...
            end

            def challenge_passed?
                # ...
            end
        end"#
                    .unindent(),
                558,
            ),
            // Each method inside the module is also emitted on its own, with
            // its body intact.
            (
                r#"
            def require_challenge!
                return if skip_challenge?

                if challenge_passed_recently?
                    session[:challenge_passed_at] = Time.now.utc
                    return
                end

                @challenge = Form::Challenge.new(return_to: request.url)

                if params.key?(:form_challenge)
                    if challenge_passed?
                        session[:challenge_passed_at] = Time.now.utc
                    else
                        flash.now[:alert] = I18n.t('challenge.invalid_password')
                        render_challenge
                    end
                else
                    render_challenge
                end
            end"#
                    .unindent(),
                663,
            ),
            (
                r#"
                def challenge_passed?
                    current_user.valid_password?(challenge_params[:current_password])
                end"#
                    .unindent(),
                1254,
            ),
            // Class document (no leading comment) with collapsed method bodies.
            (
                r#"
                class Animal
                    include Comparable

                    attr_reader :legs

                    def initialize(name, legs)
                        # ...
                    end

                    def <=>(other)
                        # ...
                    end
                end"#
                    .unindent(),
                1363,
            ),
            // The class's methods, each as an uncollapsed document.
            (
                r#"
                def initialize(name, legs)
                    @name, @legs = name, legs
                end"#
                    .unindent(),
                1427,
            ),
            (
                r#"
                def <=>(other)
                    legs <=> other.legs
                end"#
                    .unindent(),
                1501,
            ),
            // Singleton method, with its preceding comment included as context.
            (
                r#"
                # Singleton method for car object
                def car.wheels
                    puts "There are four wheels"
                end"#
                    .unindent(),
                1591,
            ),
        ],
    );
}
1091
// Runs `CodeContextRetriever::parse_file` over a namespaced PHP fixture holding a
// block-commented function, a trait with two methods (one with a docblock), an
// interface and an enum, then checks the exact list of documents produced.
//
// Each expected entry is `(document text, offset)`. The offset is the byte
// position in the unindented `text` where the item itself starts. For example,
// 123 is where `function functionName` begins after its `/* */` comment, and 245
// is `public function grantAchievement` after its docblock. So leading context
// comments are part of the text but not of the offset.
// NOTE(review): the comparison itself lives in `assert_documents_eq`, outside
// this view.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Fixture source; the whitespace inside this literal determines the offsets.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // The `<?php` tag and the `namespace` line yield no documents.
    assert_documents_eq(
        &documents,
        &[
            // Top-level function, with its preceding block comment as context.
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            // Trait document: method bodies collapse to the `/* ... */`
            // placeholder while their braces are kept.
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            // Each trait method as its own uncollapsed document; the docblock
            // above `grantAchievement` is carried along as context.
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            // Interface and enum are emitted whole; the PHP query has no
            // `@collapse` capture for them.
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1242
1243fn js_lang() -> Arc<Language> {
1244    Arc::new(
1245        Language::new(
1246            LanguageConfig {
1247                name: "Javascript".into(),
1248                path_suffixes: vec!["js".into()],
1249                ..Default::default()
1250            },
1251            Some(tree_sitter_typescript::language_tsx()),
1252        )
1253        .with_embedding_query(
1254            &r#"
1255
1256            (
1257                (comment)* @context
1258                .
1259                [
1260                (export_statement
1261                    (function_declaration
1262                        "async"? @name
1263                        "function" @name
1264                        name: (_) @name))
1265                (function_declaration
1266                    "async"? @name
1267                    "function" @name
1268                    name: (_) @name)
1269                ] @item
1270            )
1271
1272            (
1273                (comment)* @context
1274                .
1275                [
1276                (export_statement
1277                    (class_declaration
1278                        "class" @name
1279                        name: (_) @name))
1280                (class_declaration
1281                    "class" @name
1282                    name: (_) @name)
1283                ] @item
1284            )
1285
1286            (
1287                (comment)* @context
1288                .
1289                [
1290                (export_statement
1291                    (interface_declaration
1292                        "interface" @name
1293                        name: (_) @name))
1294                (interface_declaration
1295                    "interface" @name
1296                    name: (_) @name)
1297                ] @item
1298            )
1299
1300            (
1301                (comment)* @context
1302                .
1303                [
1304                (export_statement
1305                    (enum_declaration
1306                        "enum" @name
1307                        name: (_) @name))
1308                (enum_declaration
1309                    "enum" @name
1310                    name: (_) @name)
1311                ] @item
1312            )
1313
1314            (
1315                (comment)* @context
1316                .
1317                (method_definition
1318                    [
1319                        "get"
1320                        "set"
1321                        "async"
1322                        "*"
1323                        "static"
1324                    ]* @name
1325                    name: (_) @name) @item
1326            )
1327
1328                    "#
1329            .unindent(),
1330        )
1331        .unwrap(),
1332    )
1333}
1334
1335fn rust_lang() -> Arc<Language> {
1336    Arc::new(
1337        Language::new(
1338            LanguageConfig {
1339                name: "Rust".into(),
1340                path_suffixes: vec!["rs".into()],
1341                collapsed_placeholder: " /* ... */ ".to_string(),
1342                ..Default::default()
1343            },
1344            Some(tree_sitter_rust::language()),
1345        )
1346        .with_embedding_query(
1347            r#"
1348            (
1349                [(line_comment) (attribute_item)]* @context
1350                .
1351                [
1352                    (struct_item
1353                        name: (_) @name)
1354
1355                    (enum_item
1356                        name: (_) @name)
1357
1358                    (impl_item
1359                        trait: (_)? @name
1360                        "for"? @name
1361                        type: (_) @name)
1362
1363                    (trait_item
1364                        name: (_) @name)
1365
1366                    (function_item
1367                        name: (_) @name
1368                        body: (block
1369                            "{" @keep
1370                            "}" @keep) @collapse)
1371
1372                    (macro_definition
1373                        name: (_) @name)
1374                ] @item
1375            )
1376
1377            (attribute_item) @collapse
1378            (use_declaration) @collapse
1379            "#,
1380        )
1381        .unwrap(),
1382    )
1383}
1384
1385fn json_lang() -> Arc<Language> {
1386    Arc::new(
1387        Language::new(
1388            LanguageConfig {
1389                name: "JSON".into(),
1390                path_suffixes: vec!["json".into()],
1391                ..Default::default()
1392            },
1393            Some(tree_sitter_json::language()),
1394        )
1395        .with_embedding_query(
1396            r#"
1397            (document) @item
1398
1399            (array
1400                "[" @keep
1401                .
1402                (object)? @keep
1403                "]" @keep) @collapse
1404
1405            (pair value: (string
1406                "\"" @keep
1407                "\"" @keep) @collapse)
1408            "#,
1409        )
1410        .unwrap(),
1411    )
1412}
1413
1414fn toml_lang() -> Arc<Language> {
1415    Arc::new(Language::new(
1416        LanguageConfig {
1417            name: "TOML".into(),
1418            path_suffixes: vec!["toml".into()],
1419            ..Default::default()
1420        },
1421        Some(tree_sitter_toml::language()),
1422    ))
1423}
1424
1425fn cpp_lang() -> Arc<Language> {
1426    Arc::new(
1427        Language::new(
1428            LanguageConfig {
1429                name: "CPP".into(),
1430                path_suffixes: vec!["cpp".into()],
1431                ..Default::default()
1432            },
1433            Some(tree_sitter_cpp::language()),
1434        )
1435        .with_embedding_query(
1436            r#"
1437            (
1438                (comment)* @context
1439                .
1440                (function_definition
1441                    (type_qualifier)? @name
1442                    type: (_)? @name
1443                    declarator: [
1444                        (function_declarator
1445                            declarator: (_) @name)
1446                        (pointer_declarator
1447                            "*" @name
1448                            declarator: (function_declarator
1449                            declarator: (_) @name))
1450                        (pointer_declarator
1451                            "*" @name
1452                            declarator: (pointer_declarator
1453                                "*" @name
1454                            declarator: (function_declarator
1455                                declarator: (_) @name)))
1456                        (reference_declarator
1457                            ["&" "&&"] @name
1458                            (function_declarator
1459                            declarator: (_) @name))
1460                    ]
1461                    (type_qualifier)? @name) @item
1462                )
1463
1464            (
1465                (comment)* @context
1466                .
1467                (template_declaration
1468                    (class_specifier
1469                        "class" @name
1470                        name: (_) @name)
1471                        ) @item
1472            )
1473
1474            (
1475                (comment)* @context
1476                .
1477                (class_specifier
1478                    "class" @name
1479                    name: (_) @name) @item
1480                )
1481
1482            (
1483                (comment)* @context
1484                .
1485                (enum_specifier
1486                    "enum" @name
1487                    name: (_) @name) @item
1488                )
1489
1490            (
1491                (comment)* @context
1492                .
1493                (declaration
1494                    type: (struct_specifier
1495                    "struct" @name)
1496                    declarator: (_) @name) @item
1497            )
1498
1499            "#,
1500        )
1501        .unwrap(),
1502    )
1503}
1504
1505fn lua_lang() -> Arc<Language> {
1506    Arc::new(
1507        Language::new(
1508            LanguageConfig {
1509                name: "Lua".into(),
1510                path_suffixes: vec!["lua".into()],
1511                collapsed_placeholder: "--[ ... ]--".to_string(),
1512                ..Default::default()
1513            },
1514            Some(tree_sitter_lua::language()),
1515        )
1516        .with_embedding_query(
1517            r#"
1518            (
1519                (comment)* @context
1520                .
1521                (function_declaration
1522                    "function" @name
1523                    name: (_) @name
1524                    (comment)* @collapse
1525                    body: (block) @collapse
1526                ) @item
1527            )
1528        "#,
1529        )
1530        .unwrap(),
1531    )
1532}
1533
1534fn php_lang() -> Arc<Language> {
1535    Arc::new(
1536        Language::new(
1537            LanguageConfig {
1538                name: "PHP".into(),
1539                path_suffixes: vec!["php".into()],
1540                collapsed_placeholder: "/* ... */".into(),
1541                ..Default::default()
1542            },
1543            Some(tree_sitter_php::language()),
1544        )
1545        .with_embedding_query(
1546            r#"
1547            (
1548                (comment)* @context
1549                .
1550                [
1551                    (function_definition
1552                        "function" @name
1553                        name: (_) @name
1554                        body: (_
1555                            "{" @keep
1556                            "}" @keep) @collapse
1557                        )
1558
1559                    (trait_declaration
1560                        "trait" @name
1561                        name: (_) @name)
1562
1563                    (method_declaration
1564                        "function" @name
1565                        name: (_) @name
1566                        body: (_
1567                            "{" @keep
1568                            "}" @keep) @collapse
1569                        )
1570
1571                    (interface_declaration
1572                        "interface" @name
1573                        name: (_) @name
1574                        )
1575
1576                    (enum_declaration
1577                        "enum" @name
1578                        name: (_) @name
1579                        )
1580
1581                ] @item
1582            )
1583            "#,
1584        )
1585        .unwrap(),
1586    )
1587}
1588
1589fn ruby_lang() -> Arc<Language> {
1590    Arc::new(
1591        Language::new(
1592            LanguageConfig {
1593                name: "Ruby".into(),
1594                path_suffixes: vec!["rb".into()],
1595                collapsed_placeholder: "# ...".to_string(),
1596                ..Default::default()
1597            },
1598            Some(tree_sitter_ruby::language()),
1599        )
1600        .with_embedding_query(
1601            r#"
1602            (
1603                (comment)* @context
1604                .
1605                [
1606                (module
1607                    "module" @name
1608                    name: (_) @name)
1609                (method
1610                    "def" @name
1611                    name: (_) @name
1612                    body: (body_statement) @collapse)
1613                (class
1614                    "class" @name
1615                    name: (_) @name)
1616                (singleton_method
1617                    "def" @name
1618                    object: (_) @name
1619                    "." @name
1620                    name: (_) @name
1621                    body: (body_statement) @collapse)
1622                ] @item
1623            )
1624            "#,
1625        )
1626        .unwrap(),
1627    )
1628}
1629
1630fn elixir_lang() -> Arc<Language> {
1631    Arc::new(
1632        Language::new(
1633            LanguageConfig {
1634                name: "Elixir".into(),
1635                path_suffixes: vec!["rs".into()],
1636                ..Default::default()
1637            },
1638            Some(tree_sitter_elixir::language()),
1639        )
1640        .with_embedding_query(
1641            r#"
1642            (
1643                (unary_operator
1644                    operator: "@"
1645                    operand: (call
1646                        target: (identifier) @unary
1647                        (#match? @unary "^(doc)$"))
1648                    ) @context
1649                .
1650                (call
1651                target: (identifier) @name
1652                (arguments
1653                [
1654                (identifier) @name
1655                (call
1656                target: (identifier) @name)
1657                (binary_operator
1658                left: (call
1659                target: (identifier) @name)
1660                operator: "when")
1661                ])
1662                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1663                )
1664
1665            (call
1666                target: (identifier) @name
1667                (arguments (alias) @name)
1668                (#any-match? @name "^(defmodule|defprotocol)$")) @item
1669            "#,
1670        )
1671        .unwrap(),
1672    )
1673}
1674
1675#[gpui::test]
1676fn test_subtract_ranges() {
1677    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1678
1679    assert_eq!(
1680        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1681        vec![1..4, 10..21]
1682    );
1683
1684    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1685}
1686
1687fn init_test(cx: &mut TestAppContext) {
1688    cx.update(|cx| {
1689        cx.set_global(SettingsStore::test(cx));
1690        settings::register::<SemanticIndexSettings>(cx);
1691        settings::register::<ProjectSettings>(cx);
1692    });
1693}