semantic_index_tests.rs

   1use crate::{
   2    embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
   3    embedding_queue::EmbeddingQueue,
   4    parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
   5    semantic_index_settings::SemanticIndexSettings,
   6    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   7};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{executor::Deterministic, Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use parking_lot::Mutex;
  13use pretty_assertions::assert_eq;
  14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  15use rand::{rngs::StdRng, Rng};
  16use serde_json::json;
  17use settings::SettingsStore;
  18use std::{
  19    path::Path,
  20    sync::{
  21        atomic::{self, AtomicUsize},
  22        Arc,
  23    },
  24    time::SystemTime,
  25};
  26use unindent::Unindent;
  27use util::RandomCharIter;
  28
  29#[ctor::ctor]
  30fn init_logger() {
  31    if std::env::var("RUST_LOG").is_ok() {
  32        env_logger::init();
  33    }
  34}
  35
  36#[gpui::test]
  37async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
  38    init_test(cx);
  39
  40    let fs = FakeFs::new(cx.background());
  41    fs.insert_tree(
  42        "/the-root",
  43        json!({
  44            "src": {
  45                "file1.rs": "
  46                    fn aaa() {
  47                        println!(\"aaaaaaaaaaaa!\");
  48                    }
  49
  50                    fn zzzzz() {
  51                        println!(\"SLEEPING\");
  52                    }
  53                ".unindent(),
  54                "file2.rs": "
  55                    fn bbb() {
  56                        println!(\"bbbbbbbbbbbbb!\");
  57                    }
  58                    struct pqpqpqp {}
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let semantic_index = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89
  90    let _ = semantic_index
  91        .update(cx, |store, cx| {
  92            store.initialize_project(project.clone(), cx)
  93        })
  94        .await;
  95
  96    let (file_count, outstanding_file_count) = semantic_index
  97        .update(cx, |store, cx| store.index_project(project.clone(), cx))
  98        .await
  99        .unwrap();
 100    assert_eq!(file_count, 3);
 101    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 102    assert_eq!(*outstanding_file_count.borrow(), 0);
 103
 104    let search_results = semantic_index
 105        .update(cx, |store, cx| {
 106            store.search_project(
 107                project.clone(),
 108                "aaaaaabbbbzz".to_string(),
 109                5,
 110                vec![],
 111                vec![],
 112                cx,
 113            )
 114        })
 115        .await
 116        .unwrap();
 117
 118    assert_search_results(
 119        &search_results,
 120        &[
 121            (Path::new("src/file1.rs").into(), 0),
 122            (Path::new("src/file2.rs").into(), 0),
 123            (Path::new("src/file3.toml").into(), 0),
 124            (Path::new("src/file1.rs").into(), 45),
 125            (Path::new("src/file2.rs").into(), 45),
 126        ],
 127        cx,
 128    );
 129
 130    // Test Include Files Functonality
 131    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 132    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 133    let rust_only_search_results = semantic_index
 134        .update(cx, |store, cx| {
 135            store.search_project(
 136                project.clone(),
 137                "aaaaaabbbbzz".to_string(),
 138                5,
 139                include_files,
 140                vec![],
 141                cx,
 142            )
 143        })
 144        .await
 145        .unwrap();
 146
 147    assert_search_results(
 148        &rust_only_search_results,
 149        &[
 150            (Path::new("src/file1.rs").into(), 0),
 151            (Path::new("src/file2.rs").into(), 0),
 152            (Path::new("src/file1.rs").into(), 45),
 153            (Path::new("src/file2.rs").into(), 45),
 154        ],
 155        cx,
 156    );
 157
 158    let no_rust_search_results = semantic_index
 159        .update(cx, |store, cx| {
 160            store.search_project(
 161                project.clone(),
 162                "aaaaaabbbbzz".to_string(),
 163                5,
 164                vec![],
 165                exclude_files,
 166                cx,
 167            )
 168        })
 169        .await
 170        .unwrap();
 171
 172    assert_search_results(
 173        &no_rust_search_results,
 174        &[(Path::new("src/file3.toml").into(), 0)],
 175        cx,
 176    );
 177
 178    fs.save(
 179        "/the-root/src/file2.rs".as_ref(),
 180        &"
 181            fn dddd() { println!(\"ddddd!\"); }
 182            struct pqpqpqp {}
 183        "
 184        .unindent()
 185        .into(),
 186        Default::default(),
 187    )
 188    .await
 189    .unwrap();
 190
 191    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 192
 193    let prev_embedding_count = embedding_provider.embedding_count();
 194    let (file_count, outstanding_file_count) = semantic_index
 195        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 196        .await
 197        .unwrap();
 198    assert_eq!(file_count, 1);
 199
 200    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 201    assert_eq!(*outstanding_file_count.borrow(), 0);
 202
 203    assert_eq!(
 204        embedding_provider.embedding_count() - prev_embedding_count,
 205        1
 206    );
 207}
 208
 209#[gpui::test(iterations = 10)]
 210async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 211    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 212    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 213
 214    let files = (1..=3)
 215        .map(|file_ix| FileToEmbed {
 216            worktree_id: 5,
 217            path: format!("path-{file_ix}").into(),
 218            mtime: SystemTime::now(),
 219            documents: (0..rng.gen_range(4..22))
 220                .map(|document_ix| {
 221                    let content_len = rng.gen_range(10..100);
 222                    let content = RandomCharIter::new(&mut rng)
 223                        .with_simple_text()
 224                        .take(content_len)
 225                        .collect::<String>();
 226                    let digest = DocumentDigest::from(content.as_str());
 227                    Document {
 228                        range: 0..10,
 229                        embedding: None,
 230                        name: format!("document {document_ix}"),
 231                        content,
 232                        digest,
 233                        token_count: rng.gen_range(10..30),
 234                    }
 235                })
 236                .collect(),
 237            job_handle: JobHandle::new(&outstanding_job_count),
 238        })
 239        .collect::<Vec<_>>();
 240
 241    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 242
 243    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
 244    for file in &files {
 245        queue.push(file.clone());
 246    }
 247    queue.flush();
 248
 249    cx.foreground().run_until_parked();
 250    let finished_files = queue.finished_files();
 251    let mut embedded_files: Vec<_> = files
 252        .iter()
 253        .map(|_| finished_files.try_recv().expect("no finished file"))
 254        .collect();
 255
 256    let expected_files: Vec<_> = files
 257        .iter()
 258        .map(|file| {
 259            let mut file = file.clone();
 260            for doc in &mut file.documents {
 261                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 262            }
 263            file
 264        })
 265        .collect();
 266
 267    embedded_files.sort_by_key(|f| f.path.clone());
 268
 269    assert_eq!(embedded_files, expected_files);
 270}
 271
 272#[track_caller]
 273fn assert_search_results(
 274    actual: &[SearchResult],
 275    expected: &[(Arc<Path>, usize)],
 276    cx: &TestAppContext,
 277) {
 278    let actual = actual
 279        .iter()
 280        .map(|search_result| {
 281            search_result.buffer.read_with(cx, |buffer, _cx| {
 282                (
 283                    buffer.file().unwrap().path().clone(),
 284                    search_result.range.start.to_offset(buffer),
 285                )
 286            })
 287        })
 288        .collect::<Vec<_>>();
 289    assert_eq!(actual, expected);
 290}
 291
 292#[gpui::test]
 293async fn test_code_context_retrieval_rust() {
 294    let language = rust_lang();
 295    let embedding_provider = Arc::new(DummyEmbeddings {});
 296    let mut retriever = CodeContextRetriever::new(embedding_provider);
 297
 298    let text = "
 299        /// A doc comment
 300        /// that spans multiple lines
 301        #[gpui::test]
 302        fn a() {
 303            b
 304        }
 305
 306        impl C for D {
 307        }
 308
 309        impl E {
 310            // This is also a preceding comment
 311            pub fn function_1() -> Option<()> {
 312                todo!();
 313            }
 314
 315            // This is a preceding comment
 316            fn function_2() -> Result<()> {
 317                todo!();
 318            }
 319        }
 320    "
 321    .unindent();
 322
 323    let documents = retriever.parse_file(&text, language).unwrap();
 324
 325    assert_documents_eq(
 326        &documents,
 327        &[
 328            (
 329                "
 330                /// A doc comment
 331                /// that spans multiple lines
 332                #[gpui::test]
 333                fn a() {
 334                    b
 335                }"
 336                .unindent(),
 337                text.find("fn a").unwrap(),
 338            ),
 339            (
 340                "
 341                impl C for D {
 342                }"
 343                .unindent(),
 344                text.find("impl C").unwrap(),
 345            ),
 346            (
 347                "
 348                impl E {
 349                    // This is also a preceding comment
 350                    pub fn function_1() -> Option<()> { /* ... */ }
 351
 352                    // This is a preceding comment
 353                    fn function_2() -> Result<()> { /* ... */ }
 354                }"
 355                .unindent(),
 356                text.find("impl E").unwrap(),
 357            ),
 358            (
 359                "
 360                // This is also a preceding comment
 361                pub fn function_1() -> Option<()> {
 362                    todo!();
 363                }"
 364                .unindent(),
 365                text.find("pub fn function_1").unwrap(),
 366            ),
 367            (
 368                "
 369                // This is a preceding comment
 370                fn function_2() -> Result<()> {
 371                    todo!();
 372                }"
 373                .unindent(),
 374                text.find("fn function_2").unwrap(),
 375            ),
 376        ],
 377    );
 378}
 379
 380#[gpui::test]
 381async fn test_code_context_retrieval_json() {
 382    let language = json_lang();
 383    let embedding_provider = Arc::new(DummyEmbeddings {});
 384    let mut retriever = CodeContextRetriever::new(embedding_provider);
 385
 386    let text = r#"
 387        {
 388            "array": [1, 2, 3, 4],
 389            "string": "abcdefg",
 390            "nested_object": {
 391                "array_2": [5, 6, 7, 8],
 392                "string_2": "hijklmnop",
 393                "boolean": true,
 394                "none": null
 395            }
 396        }
 397    "#
 398    .unindent();
 399
 400    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 401
 402    assert_documents_eq(
 403        &documents,
 404        &[(
 405            r#"
 406                {
 407                    "array": [],
 408                    "string": "",
 409                    "nested_object": {
 410                        "array_2": [],
 411                        "string_2": "",
 412                        "boolean": true,
 413                        "none": null
 414                    }
 415                }"#
 416            .unindent(),
 417            text.find("{").unwrap(),
 418        )],
 419    );
 420
 421    let text = r#"
 422        [
 423            {
 424                "name": "somebody",
 425                "age": 42
 426            },
 427            {
 428                "name": "somebody else",
 429                "age": 43
 430            }
 431        ]
 432    "#
 433    .unindent();
 434
 435    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 436
 437    assert_documents_eq(
 438        &documents,
 439        &[(
 440            r#"
 441            [{
 442                    "name": "",
 443                    "age": 42
 444                }]"#
 445            .unindent(),
 446            text.find("[").unwrap(),
 447        )],
 448    );
 449}
 450
 451fn assert_documents_eq(
 452    documents: &[Document],
 453    expected_contents_and_start_offsets: &[(String, usize)],
 454) {
 455    assert_eq!(
 456        documents
 457            .iter()
 458            .map(|document| (document.content.clone(), document.range.start))
 459            .collect::<Vec<_>>(),
 460        expected_contents_and_start_offsets
 461    );
 462}
 463
 464#[gpui::test]
 465async fn test_code_context_retrieval_javascript() {
 466    let language = js_lang();
 467    let embedding_provider = Arc::new(DummyEmbeddings {});
 468    let mut retriever = CodeContextRetriever::new(embedding_provider);
 469
 470    let text = "
 471        /* globals importScripts, backend */
 472        function _authorize() {}
 473
 474        /**
 475         * Sometimes the frontend build is way faster than backend.
 476         */
 477        export async function authorizeBank() {
 478            _authorize(pushModal, upgradingAccountId, {});
 479        }
 480
 481        export class SettingsPage {
 482            /* This is a test setting */
 483            constructor(page) {
 484                this.page = page;
 485            }
 486        }
 487
 488        /* This is a test comment */
 489        class TestClass {}
 490
 491        /* Schema for editor_events in Clickhouse. */
 492        export interface ClickhouseEditorEvent {
 493            installation_id: string
 494            operation: string
 495        }
 496        "
 497    .unindent();
 498
 499    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 500
 501    assert_documents_eq(
 502        &documents,
 503        &[
 504            (
 505                "
 506            /* globals importScripts, backend */
 507            function _authorize() {}"
 508                    .unindent(),
 509                37,
 510            ),
 511            (
 512                "
 513            /**
 514             * Sometimes the frontend build is way faster than backend.
 515             */
 516            export async function authorizeBank() {
 517                _authorize(pushModal, upgradingAccountId, {});
 518            }"
 519                .unindent(),
 520                131,
 521            ),
 522            (
 523                "
 524                export class SettingsPage {
 525                    /* This is a test setting */
 526                    constructor(page) {
 527                        this.page = page;
 528                    }
 529                }"
 530                .unindent(),
 531                225,
 532            ),
 533            (
 534                "
 535                /* This is a test setting */
 536                constructor(page) {
 537                    this.page = page;
 538                }"
 539                .unindent(),
 540                290,
 541            ),
 542            (
 543                "
 544                /* This is a test comment */
 545                class TestClass {}"
 546                    .unindent(),
 547                374,
 548            ),
 549            (
 550                "
 551                /* Schema for editor_events in Clickhouse. */
 552                export interface ClickhouseEditorEvent {
 553                    installation_id: string
 554                    operation: string
 555                }"
 556                .unindent(),
 557                440,
 558            ),
 559        ],
 560    )
 561}
 562
 563#[gpui::test]
 564async fn test_code_context_retrieval_lua() {
 565    let language = lua_lang();
 566    let embedding_provider = Arc::new(DummyEmbeddings {});
 567    let mut retriever = CodeContextRetriever::new(embedding_provider);
 568
 569    let text = r#"
 570        -- Creates a new class
 571        -- @param baseclass The Baseclass of this class, or nil.
 572        -- @return A new class reference.
 573        function classes.class(baseclass)
 574            -- Create the class definition and metatable.
 575            local classdef = {}
 576            -- Find the super class, either Object or user-defined.
 577            baseclass = baseclass or classes.Object
 578            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 579            setmetatable(classdef, { __index = baseclass })
 580            -- All class instances have a reference to the class object.
 581            classdef.class = classdef
 582            --- Recursivly allocates the inheritance tree of the instance.
 583            -- @param mastertable The 'root' of the inheritance tree.
 584            -- @return Returns the instance with the allocated inheritance tree.
 585            function classdef.alloc(mastertable)
 586                -- All class instances have a reference to a superclass object.
 587                local instance = { super = baseclass.alloc(mastertable) }
 588                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 589                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 590                return instance
 591            end
 592        end
 593        "#.unindent();
 594
 595    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 596
 597    assert_documents_eq(
 598        &documents,
 599        &[
 600            (r#"
 601                -- Creates a new class
 602                -- @param baseclass The Baseclass of this class, or nil.
 603                -- @return A new class reference.
 604                function classes.class(baseclass)
 605                    -- Create the class definition and metatable.
 606                    local classdef = {}
 607                    -- Find the super class, either Object or user-defined.
 608                    baseclass = baseclass or classes.Object
 609                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 610                    setmetatable(classdef, { __index = baseclass })
 611                    -- All class instances have a reference to the class object.
 612                    classdef.class = classdef
 613                    --- Recursivly allocates the inheritance tree of the instance.
 614                    -- @param mastertable The 'root' of the inheritance tree.
 615                    -- @return Returns the instance with the allocated inheritance tree.
 616                    function classdef.alloc(mastertable)
 617                        --[ ... ]--
 618                        --[ ... ]--
 619                    end
 620                end"#.unindent(),
 621            114),
 622            (r#"
 623            --- Recursivly allocates the inheritance tree of the instance.
 624            -- @param mastertable The 'root' of the inheritance tree.
 625            -- @return Returns the instance with the allocated inheritance tree.
 626            function classdef.alloc(mastertable)
 627                -- All class instances have a reference to a superclass object.
 628                local instance = { super = baseclass.alloc(mastertable) }
 629                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 630                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 631                return instance
 632            end"#.unindent(), 809),
 633        ]
 634    );
 635}
 636
 637#[gpui::test]
 638async fn test_code_context_retrieval_elixir() {
 639    let language = elixir_lang();
 640    let embedding_provider = Arc::new(DummyEmbeddings {});
 641    let mut retriever = CodeContextRetriever::new(embedding_provider);
 642
 643    let text = r#"
 644        defmodule File.Stream do
 645            @moduledoc """
 646            Defines a `File.Stream` struct returned by `File.stream!/3`.
 647
 648            The following fields are public:
 649
 650            * `path`          - the file path
 651            * `modes`         - the file modes
 652            * `raw`           - a boolean indicating if bin functions should be used
 653            * `line_or_bytes` - if reading should read lines or a given number of bytes
 654            * `node`          - the node the file belongs to
 655
 656            """
 657
 658            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 659
 660            @type t :: %__MODULE__{}
 661
 662            @doc false
 663            def __build__(path, modes, line_or_bytes) do
 664            raw = :lists.keyfind(:encoding, 1, modes) == false
 665
 666            modes =
 667                case raw do
 668                true ->
 669                    case :lists.keyfind(:read_ahead, 1, modes) do
 670                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 671                    {:read_ahead, _} -> [:raw | modes]
 672                    false -> [:raw, :read_ahead | modes]
 673                    end
 674
 675                false ->
 676                    modes
 677                end
 678
 679            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 680
 681            end"#
 682    .unindent();
 683
 684    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 685
 686    assert_documents_eq(
 687        &documents,
 688        &[(
 689            r#"
 690        defmodule File.Stream do
 691            @moduledoc """
 692            Defines a `File.Stream` struct returned by `File.stream!/3`.
 693
 694            The following fields are public:
 695
 696            * `path`          - the file path
 697            * `modes`         - the file modes
 698            * `raw`           - a boolean indicating if bin functions should be used
 699            * `line_or_bytes` - if reading should read lines or a given number of bytes
 700            * `node`          - the node the file belongs to
 701
 702            """
 703
 704            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 705
 706            @type t :: %__MODULE__{}
 707
 708            @doc false
 709            def __build__(path, modes, line_or_bytes) do
 710            raw = :lists.keyfind(:encoding, 1, modes) == false
 711
 712            modes =
 713                case raw do
 714                true ->
 715                    case :lists.keyfind(:read_ahead, 1, modes) do
 716                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 717                    {:read_ahead, _} -> [:raw | modes]
 718                    false -> [:raw, :read_ahead | modes]
 719                    end
 720
 721                false ->
 722                    modes
 723                end
 724
 725            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 726
 727            end"#
 728                .unindent(),
 729            0,
 730        ),(r#"
 731            @doc false
 732            def __build__(path, modes, line_or_bytes) do
 733            raw = :lists.keyfind(:encoding, 1, modes) == false
 734
 735            modes =
 736                case raw do
 737                true ->
 738                    case :lists.keyfind(:read_ahead, 1, modes) do
 739                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 740                    {:read_ahead, _} -> [:raw | modes]
 741                    false -> [:raw, :read_ahead | modes]
 742                    end
 743
 744                false ->
 745                    modes
 746                end
 747
 748            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 749
 750            end"#.unindent(), 574)],
 751    );
 752}
 753
 754#[gpui::test]
 755async fn test_code_context_retrieval_cpp() {
 756    let language = cpp_lang();
 757    let embedding_provider = Arc::new(DummyEmbeddings {});
 758    let mut retriever = CodeContextRetriever::new(embedding_provider);
 759
 760    let text = "
 761    /**
 762     * @brief Main function
 763     * @returns 0 on exit
 764     */
 765    int main() { return 0; }
 766
 767    /**
 768    * This is a test comment
 769    */
 770    class MyClass {       // The class
 771        public:           // Access specifier
 772        int myNum;        // Attribute (int variable)
 773        string myString;  // Attribute (string variable)
 774    };
 775
 776    // This is a test comment
 777    enum Color { red, green, blue };
 778
 779    /** This is a preceding block comment
 780     * This is the second line
 781     */
 782    struct {           // Structure declaration
 783        int myNum;       // Member (int variable)
 784        string myString; // Member (string variable)
 785    } myStructure;
 786
 787    /**
 788     * @brief Matrix class.
 789     */
 790    template <typename T,
 791              typename = typename std::enable_if<
 792                std::is_integral<T>::value || std::is_floating_point<T>::value,
 793                bool>::type>
 794    class Matrix2 {
 795        std::vector<std::vector<T>> _mat;
 796
 797        public:
 798            /**
 799            * @brief Constructor
 800            * @tparam Integer ensuring integers are being evaluated and not other
 801            * data types.
 802            * @param size denoting the size of Matrix as size x size
 803            */
 804            template <typename Integer,
 805                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 806                    Integer>::type>
 807            explicit Matrix(const Integer size) {
 808                for (size_t i = 0; i < size; ++i) {
 809                    _mat.emplace_back(std::vector<T>(size, 0));
 810                }
 811            }
 812    }"
 813    .unindent();
 814
 815    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 816
 817    assert_documents_eq(
 818        &documents,
 819        &[
 820            (
 821                "
 822        /**
 823         * @brief Main function
 824         * @returns 0 on exit
 825         */
 826        int main() { return 0; }"
 827                    .unindent(),
 828                54,
 829            ),
 830            (
 831                "
 832                /**
 833                * This is a test comment
 834                */
 835                class MyClass {       // The class
 836                    public:           // Access specifier
 837                    int myNum;        // Attribute (int variable)
 838                    string myString;  // Attribute (string variable)
 839                }"
 840                .unindent(),
 841                112,
 842            ),
 843            (
 844                "
 845                // This is a test comment
 846                enum Color { red, green, blue }"
 847                    .unindent(),
 848                322,
 849            ),
 850            (
 851                "
 852                /** This is a preceding block comment
 853                 * This is the second line
 854                 */
 855                struct {           // Structure declaration
 856                    int myNum;       // Member (int variable)
 857                    string myString; // Member (string variable)
 858                } myStructure;"
 859                    .unindent(),
 860                425,
 861            ),
 862            (
 863                "
 864                /**
 865                 * @brief Matrix class.
 866                 */
 867                template <typename T,
 868                          typename = typename std::enable_if<
 869                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 870                            bool>::type>
 871                class Matrix2 {
 872                    std::vector<std::vector<T>> _mat;
 873
 874                    public:
 875                        /**
 876                        * @brief Constructor
 877                        * @tparam Integer ensuring integers are being evaluated and not other
 878                        * data types.
 879                        * @param size denoting the size of Matrix as size x size
 880                        */
 881                        template <typename Integer,
 882                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 883                                Integer>::type>
 884                        explicit Matrix(const Integer size) {
 885                            for (size_t i = 0; i < size; ++i) {
 886                                _mat.emplace_back(std::vector<T>(size, 0));
 887                            }
 888                        }
 889                }"
 890                .unindent(),
 891                612,
 892            ),
 893            (
 894                "
 895                explicit Matrix(const Integer size) {
 896                    for (size_t i = 0; i < size; ++i) {
 897                        _mat.emplace_back(std::vector<T>(size, 0));
 898                    }
 899                }"
 900                .unindent(),
 901                1226,
 902            ),
 903        ],
 904    );
 905}
 906
 907#[gpui::test]
 908async fn test_code_context_retrieval_ruby() {
 909    let language = ruby_lang();
 910    let embedding_provider = Arc::new(DummyEmbeddings {});
 911    let mut retriever = CodeContextRetriever::new(embedding_provider);
 912
 913    let text = r#"
 914        # This concern is inspired by "sudo mode" on GitHub. It
 915        # is a way to re-authenticate a user before allowing them
 916        # to see or perform an action.
 917        #
 918        # Add `before_action :require_challenge!` to actions you
 919        # want to protect.
 920        #
 921        # The user will be shown a page to enter the challenge (which
 922        # is either the password, or just the username when no
 923        # password exists). Upon passing, there is a grace period
 924        # during which no challenge will be asked from the user.
 925        #
 926        # Accessing challenge-protected resources during the grace
 927        # period will refresh the grace period.
 928        module ChallengableConcern
 929            extend ActiveSupport::Concern
 930
 931            CHALLENGE_TIMEOUT = 1.hour.freeze
 932
 933            def require_challenge!
 934                return if skip_challenge?
 935
 936                if challenge_passed_recently?
 937                    session[:challenge_passed_at] = Time.now.utc
 938                    return
 939                end
 940
 941                @challenge = Form::Challenge.new(return_to: request.url)
 942
 943                if params.key?(:form_challenge)
 944                    if challenge_passed?
 945                        session[:challenge_passed_at] = Time.now.utc
 946                    else
 947                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 948                        render_challenge
 949                    end
 950                else
 951                    render_challenge
 952                end
 953            end
 954
 955            def challenge_passed?
 956                current_user.valid_password?(challenge_params[:current_password])
 957            end
 958        end
 959
 960        class Animal
 961            include Comparable
 962
 963            attr_reader :legs
 964
 965            def initialize(name, legs)
 966                @name, @legs = name, legs
 967            end
 968
 969            def <=>(other)
 970                legs <=> other.legs
 971            end
 972        end
 973
 974        # Singleton method for car object
 975        def car.wheels
 976            puts "There are four wheels"
 977        end"#
 978        .unindent();
 979
 980    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 981
 982    assert_documents_eq(
 983        &documents,
 984        &[
 985            (
 986                r#"
 987        # This concern is inspired by "sudo mode" on GitHub. It
 988        # is a way to re-authenticate a user before allowing them
 989        # to see or perform an action.
 990        #
 991        # Add `before_action :require_challenge!` to actions you
 992        # want to protect.
 993        #
 994        # The user will be shown a page to enter the challenge (which
 995        # is either the password, or just the username when no
 996        # password exists). Upon passing, there is a grace period
 997        # during which no challenge will be asked from the user.
 998        #
 999        # Accessing challenge-protected resources during the grace
1000        # period will refresh the grace period.
1001        module ChallengableConcern
1002            extend ActiveSupport::Concern
1003
1004            CHALLENGE_TIMEOUT = 1.hour.freeze
1005
1006            def require_challenge!
1007                # ...
1008            end
1009
1010            def challenge_passed?
1011                # ...
1012            end
1013        end"#
1014                    .unindent(),
1015                558,
1016            ),
1017            (
1018                r#"
1019            def require_challenge!
1020                return if skip_challenge?
1021
1022                if challenge_passed_recently?
1023                    session[:challenge_passed_at] = Time.now.utc
1024                    return
1025                end
1026
1027                @challenge = Form::Challenge.new(return_to: request.url)
1028
1029                if params.key?(:form_challenge)
1030                    if challenge_passed?
1031                        session[:challenge_passed_at] = Time.now.utc
1032                    else
1033                        flash.now[:alert] = I18n.t('challenge.invalid_password')
1034                        render_challenge
1035                    end
1036                else
1037                    render_challenge
1038                end
1039            end"#
1040                    .unindent(),
1041                663,
1042            ),
1043            (
1044                r#"
1045                def challenge_passed?
1046                    current_user.valid_password?(challenge_params[:current_password])
1047                end"#
1048                    .unindent(),
1049                1254,
1050            ),
1051            (
1052                r#"
1053                class Animal
1054                    include Comparable
1055
1056                    attr_reader :legs
1057
1058                    def initialize(name, legs)
1059                        # ...
1060                    end
1061
1062                    def <=>(other)
1063                        # ...
1064                    end
1065                end"#
1066                    .unindent(),
1067                1363,
1068            ),
1069            (
1070                r#"
1071                def initialize(name, legs)
1072                    @name, @legs = name, legs
1073                end"#
1074                    .unindent(),
1075                1427,
1076            ),
1077            (
1078                r#"
1079                def <=>(other)
1080                    legs <=> other.legs
1081                end"#
1082                    .unindent(),
1083                1501,
1084            ),
1085            (
1086                r#"
1087                # Singleton method for car object
1088                def car.wheels
1089                    puts "There are four wheels"
1090                end"#
1091                    .unindent(),
1092                1591,
1093            ),
1094        ],
1095    );
1096}
1097
1098#[gpui::test]
1099async fn test_code_context_retrieval_php() {
1100    let language = php_lang();
1101    let embedding_provider = Arc::new(DummyEmbeddings {});
1102    let mut retriever = CodeContextRetriever::new(embedding_provider);
1103
1104    let text = r#"
1105        <?php
1106
1107        namespace LevelUp\Experience\Concerns;
1108
1109        /*
1110        This is a multiple-lines comment block
1111        that spans over multiple
1112        lines
1113        */
1114        function functionName() {
1115            echo "Hello world!";
1116        }
1117
1118        trait HasAchievements
1119        {
1120            /**
1121            * @throws \Exception
1122            */
1123            public function grantAchievement(Achievement $achievement, $progress = null): void
1124            {
1125                if ($progress > 100) {
1126                    throw new Exception(message: 'Progress cannot be greater than 100');
1127                }
1128
1129                if ($this->achievements()->find($achievement->id)) {
1130                    throw new Exception(message: 'User already has this Achievement');
1131                }
1132
1133                $this->achievements()->attach($achievement, [
1134                    'progress' => $progress ?? null,
1135                ]);
1136
1137                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1138            }
1139
1140            public function achievements(): BelongsToMany
1141            {
1142                return $this->belongsToMany(related: Achievement::class)
1143                ->withPivot(columns: 'progress')
1144                ->where('is_secret', false)
1145                ->using(AchievementUser::class);
1146            }
1147        }
1148
1149        interface Multiplier
1150        {
1151            public function qualifies(array $data): bool;
1152
1153            public function setMultiplier(): int;
1154        }
1155
1156        enum AuditType: string
1157        {
1158            case Add = 'add';
1159            case Remove = 'remove';
1160            case Reset = 'reset';
1161            case LevelUp = 'level_up';
1162        }
1163
1164        ?>"#
1165    .unindent();
1166
1167    let documents = retriever.parse_file(&text, language.clone()).unwrap();
1168
1169    assert_documents_eq(
1170        &documents,
1171        &[
1172            (
1173                r#"
1174        /*
1175        This is a multiple-lines comment block
1176        that spans over multiple
1177        lines
1178        */
1179        function functionName() {
1180            echo "Hello world!";
1181        }"#
1182                .unindent(),
1183                123,
1184            ),
1185            (
1186                r#"
1187        trait HasAchievements
1188        {
1189            /**
1190            * @throws \Exception
1191            */
1192            public function grantAchievement(Achievement $achievement, $progress = null): void
1193            {/* ... */}
1194
1195            public function achievements(): BelongsToMany
1196            {/* ... */}
1197        }"#
1198                .unindent(),
1199                177,
1200            ),
1201            (r#"
1202            /**
1203            * @throws \Exception
1204            */
1205            public function grantAchievement(Achievement $achievement, $progress = null): void
1206            {
1207                if ($progress > 100) {
1208                    throw new Exception(message: 'Progress cannot be greater than 100');
1209                }
1210
1211                if ($this->achievements()->find($achievement->id)) {
1212                    throw new Exception(message: 'User already has this Achievement');
1213                }
1214
1215                $this->achievements()->attach($achievement, [
1216                    'progress' => $progress ?? null,
1217                ]);
1218
1219                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1220            }"#.unindent(), 245),
1221            (r#"
1222                public function achievements(): BelongsToMany
1223                {
1224                    return $this->belongsToMany(related: Achievement::class)
1225                    ->withPivot(columns: 'progress')
1226                    ->where('is_secret', false)
1227                    ->using(AchievementUser::class);
1228                }"#.unindent(), 902),
1229            (r#"
1230                interface Multiplier
1231                {
1232                    public function qualifies(array $data): bool;
1233
1234                    public function setMultiplier(): int;
1235                }"#.unindent(),
1236                1146),
1237            (r#"
1238                enum AuditType: string
1239                {
1240                    case Add = 'add';
1241                    case Remove = 'remove';
1242                    case Reset = 'reset';
1243                    case LevelUp = 'level_up';
1244                }"#.unindent(), 1265)
1245        ],
1246    );
1247}
1248
1249#[derive(Default)]
1250struct FakeEmbeddingProvider {
1251    embedding_count: AtomicUsize,
1252}
1253
1254impl FakeEmbeddingProvider {
1255    fn embedding_count(&self) -> usize {
1256        self.embedding_count.load(atomic::Ordering::SeqCst)
1257    }
1258
1259    fn embed_sync(&self, span: &str) -> Embedding {
1260        let mut result = vec![1.0; 26];
1261        for letter in span.chars() {
1262            let letter = letter.to_ascii_lowercase();
1263            if letter as u32 >= 'a' as u32 {
1264                let ix = (letter as u32) - ('a' as u32);
1265                if ix < 26 {
1266                    result[ix as usize] += 1.0;
1267                }
1268            }
1269        }
1270
1271        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1272        for x in &mut result {
1273            *x /= norm;
1274        }
1275
1276        result.into()
1277    }
1278}
1279
1280#[async_trait]
1281impl EmbeddingProvider for FakeEmbeddingProvider {
1282    fn truncate(&self, span: &str) -> (String, usize) {
1283        (span.to_string(), 1)
1284    }
1285
1286    fn max_tokens_per_batch(&self) -> usize {
1287        200
1288    }
1289
1290    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1291        self.embedding_count
1292            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1293        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1294    }
1295}
1296
1297fn js_lang() -> Arc<Language> {
1298    Arc::new(
1299        Language::new(
1300            LanguageConfig {
1301                name: "Javascript".into(),
1302                path_suffixes: vec!["js".into()],
1303                ..Default::default()
1304            },
1305            Some(tree_sitter_typescript::language_tsx()),
1306        )
1307        .with_embedding_query(
1308            &r#"
1309
1310            (
1311                (comment)* @context
1312                .
1313                [
1314                (export_statement
1315                    (function_declaration
1316                        "async"? @name
1317                        "function" @name
1318                        name: (_) @name))
1319                (function_declaration
1320                    "async"? @name
1321                    "function" @name
1322                    name: (_) @name)
1323                ] @item
1324            )
1325
1326            (
1327                (comment)* @context
1328                .
1329                [
1330                (export_statement
1331                    (class_declaration
1332                        "class" @name
1333                        name: (_) @name))
1334                (class_declaration
1335                    "class" @name
1336                    name: (_) @name)
1337                ] @item
1338            )
1339
1340            (
1341                (comment)* @context
1342                .
1343                [
1344                (export_statement
1345                    (interface_declaration
1346                        "interface" @name
1347                        name: (_) @name))
1348                (interface_declaration
1349                    "interface" @name
1350                    name: (_) @name)
1351                ] @item
1352            )
1353
1354            (
1355                (comment)* @context
1356                .
1357                [
1358                (export_statement
1359                    (enum_declaration
1360                        "enum" @name
1361                        name: (_) @name))
1362                (enum_declaration
1363                    "enum" @name
1364                    name: (_) @name)
1365                ] @item
1366            )
1367
1368            (
1369                (comment)* @context
1370                .
1371                (method_definition
1372                    [
1373                        "get"
1374                        "set"
1375                        "async"
1376                        "*"
1377                        "static"
1378                    ]* @name
1379                    name: (_) @name) @item
1380            )
1381
1382                    "#
1383            .unindent(),
1384        )
1385        .unwrap(),
1386    )
1387}
1388
1389fn rust_lang() -> Arc<Language> {
1390    Arc::new(
1391        Language::new(
1392            LanguageConfig {
1393                name: "Rust".into(),
1394                path_suffixes: vec!["rs".into()],
1395                collapsed_placeholder: " /* ... */ ".to_string(),
1396                ..Default::default()
1397            },
1398            Some(tree_sitter_rust::language()),
1399        )
1400        .with_embedding_query(
1401            r#"
1402            (
1403                [(line_comment) (attribute_item)]* @context
1404                .
1405                [
1406                    (struct_item
1407                        name: (_) @name)
1408
1409                    (enum_item
1410                        name: (_) @name)
1411
1412                    (impl_item
1413                        trait: (_)? @name
1414                        "for"? @name
1415                        type: (_) @name)
1416
1417                    (trait_item
1418                        name: (_) @name)
1419
1420                    (function_item
1421                        name: (_) @name
1422                        body: (block
1423                            "{" @keep
1424                            "}" @keep) @collapse)
1425
1426                    (macro_definition
1427                        name: (_) @name)
1428                ] @item
1429            )
1430            "#,
1431        )
1432        .unwrap(),
1433    )
1434}
1435
1436fn json_lang() -> Arc<Language> {
1437    Arc::new(
1438        Language::new(
1439            LanguageConfig {
1440                name: "JSON".into(),
1441                path_suffixes: vec!["json".into()],
1442                ..Default::default()
1443            },
1444            Some(tree_sitter_json::language()),
1445        )
1446        .with_embedding_query(
1447            r#"
1448            (document) @item
1449
1450            (array
1451                "[" @keep
1452                .
1453                (object)? @keep
1454                "]" @keep) @collapse
1455
1456            (pair value: (string
1457                "\"" @keep
1458                "\"" @keep) @collapse)
1459            "#,
1460        )
1461        .unwrap(),
1462    )
1463}
1464
1465fn toml_lang() -> Arc<Language> {
1466    Arc::new(Language::new(
1467        LanguageConfig {
1468            name: "TOML".into(),
1469            path_suffixes: vec!["toml".into()],
1470            ..Default::default()
1471        },
1472        Some(tree_sitter_toml::language()),
1473    ))
1474}
1475
1476fn cpp_lang() -> Arc<Language> {
1477    Arc::new(
1478        Language::new(
1479            LanguageConfig {
1480                name: "CPP".into(),
1481                path_suffixes: vec!["cpp".into()],
1482                ..Default::default()
1483            },
1484            Some(tree_sitter_cpp::language()),
1485        )
1486        .with_embedding_query(
1487            r#"
1488            (
1489                (comment)* @context
1490                .
1491                (function_definition
1492                    (type_qualifier)? @name
1493                    type: (_)? @name
1494                    declarator: [
1495                        (function_declarator
1496                            declarator: (_) @name)
1497                        (pointer_declarator
1498                            "*" @name
1499                            declarator: (function_declarator
1500                            declarator: (_) @name))
1501                        (pointer_declarator
1502                            "*" @name
1503                            declarator: (pointer_declarator
1504                                "*" @name
1505                            declarator: (function_declarator
1506                                declarator: (_) @name)))
1507                        (reference_declarator
1508                            ["&" "&&"] @name
1509                            (function_declarator
1510                            declarator: (_) @name))
1511                    ]
1512                    (type_qualifier)? @name) @item
1513                )
1514
1515            (
1516                (comment)* @context
1517                .
1518                (template_declaration
1519                    (class_specifier
1520                        "class" @name
1521                        name: (_) @name)
1522                        ) @item
1523            )
1524
1525            (
1526                (comment)* @context
1527                .
1528                (class_specifier
1529                    "class" @name
1530                    name: (_) @name) @item
1531                )
1532
1533            (
1534                (comment)* @context
1535                .
1536                (enum_specifier
1537                    "enum" @name
1538                    name: (_) @name) @item
1539                )
1540
1541            (
1542                (comment)* @context
1543                .
1544                (declaration
1545                    type: (struct_specifier
1546                    "struct" @name)
1547                    declarator: (_) @name) @item
1548            )
1549
1550            "#,
1551        )
1552        .unwrap(),
1553    )
1554}
1555
1556fn lua_lang() -> Arc<Language> {
1557    Arc::new(
1558        Language::new(
1559            LanguageConfig {
1560                name: "Lua".into(),
1561                path_suffixes: vec!["lua".into()],
1562                collapsed_placeholder: "--[ ... ]--".to_string(),
1563                ..Default::default()
1564            },
1565            Some(tree_sitter_lua::language()),
1566        )
1567        .with_embedding_query(
1568            r#"
1569            (
1570                (comment)* @context
1571                .
1572                (function_declaration
1573                    "function" @name
1574                    name: (_) @name
1575                    (comment)* @collapse
1576                    body: (block) @collapse
1577                ) @item
1578            )
1579        "#,
1580        )
1581        .unwrap(),
1582    )
1583}
1584
1585fn php_lang() -> Arc<Language> {
1586    Arc::new(
1587        Language::new(
1588            LanguageConfig {
1589                name: "PHP".into(),
1590                path_suffixes: vec!["php".into()],
1591                collapsed_placeholder: "/* ... */".into(),
1592                ..Default::default()
1593            },
1594            Some(tree_sitter_php::language()),
1595        )
1596        .with_embedding_query(
1597            r#"
1598            (
1599                (comment)* @context
1600                .
1601                [
1602                    (function_definition
1603                        "function" @name
1604                        name: (_) @name
1605                        body: (_
1606                            "{" @keep
1607                            "}" @keep) @collapse
1608                        )
1609
1610                    (trait_declaration
1611                        "trait" @name
1612                        name: (_) @name)
1613
1614                    (method_declaration
1615                        "function" @name
1616                        name: (_) @name
1617                        body: (_
1618                            "{" @keep
1619                            "}" @keep) @collapse
1620                        )
1621
1622                    (interface_declaration
1623                        "interface" @name
1624                        name: (_) @name
1625                        )
1626
1627                    (enum_declaration
1628                        "enum" @name
1629                        name: (_) @name
1630                        )
1631
1632                ] @item
1633            )
1634            "#,
1635        )
1636        .unwrap(),
1637    )
1638}
1639
1640fn ruby_lang() -> Arc<Language> {
1641    Arc::new(
1642        Language::new(
1643            LanguageConfig {
1644                name: "Ruby".into(),
1645                path_suffixes: vec!["rb".into()],
1646                collapsed_placeholder: "# ...".to_string(),
1647                ..Default::default()
1648            },
1649            Some(tree_sitter_ruby::language()),
1650        )
1651        .with_embedding_query(
1652            r#"
1653            (
1654                (comment)* @context
1655                .
1656                [
1657                (module
1658                    "module" @name
1659                    name: (_) @name)
1660                (method
1661                    "def" @name
1662                    name: (_) @name
1663                    body: (body_statement) @collapse)
1664                (class
1665                    "class" @name
1666                    name: (_) @name)
1667                (singleton_method
1668                    "def" @name
1669                    object: (_) @name
1670                    "." @name
1671                    name: (_) @name
1672                    body: (body_statement) @collapse)
1673                ] @item
1674            )
1675            "#,
1676        )
1677        .unwrap(),
1678    )
1679}
1680
1681fn elixir_lang() -> Arc<Language> {
1682    Arc::new(
1683        Language::new(
1684            LanguageConfig {
1685                name: "Elixir".into(),
1686                path_suffixes: vec!["rs".into()],
1687                ..Default::default()
1688            },
1689            Some(tree_sitter_elixir::language()),
1690        )
1691        .with_embedding_query(
1692            r#"
1693            (
1694                (unary_operator
1695                    operator: "@"
1696                    operand: (call
1697                        target: (identifier) @unary
1698                        (#match? @unary "^(doc)$"))
1699                    ) @context
1700                .
1701                (call
1702                target: (identifier) @name
1703                (arguments
1704                [
1705                (identifier) @name
1706                (call
1707                target: (identifier) @name)
1708                (binary_operator
1709                left: (call
1710                target: (identifier) @name)
1711                operator: "when")
1712                ])
1713                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1714                )
1715
1716            (call
1717                target: (identifier) @name
1718                (arguments (alias) @name)
1719                (#match? @name "^(defmodule|defprotocol)$")) @item
1720            "#,
1721        )
1722        .unwrap(),
1723    )
1724}
1725
1726#[gpui::test]
1727fn test_subtract_ranges() {
1728    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1729
1730    assert_eq!(
1731        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1732        vec![1..4, 10..21]
1733    );
1734
1735    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1736}
1737
1738fn init_test(cx: &mut TestAppContext) {
1739    cx.update(|cx| {
1740        cx.set_global(SettingsStore::test(cx));
1741        settings::register::<SemanticIndexSettings>(cx);
1742        settings::register::<ProjectSettings>(cx);
1743    });
1744}