semantic_index_tests.rs

   1use crate::{
   2    embedding_queue::EmbeddingQueue,
   3    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   4    semantic_index_settings::SemanticIndexSettings,
   5    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   6};
   7use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{executor::Deterministic, Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use parking_lot::Mutex;
  13use pretty_assertions::assert_eq;
  14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  15use rand::{rngs::StdRng, Rng};
  16use serde_json::json;
  17use settings::SettingsStore;
  18use std::{
  19    path::Path,
  20    sync::{
  21        atomic::{self, AtomicUsize},
  22        Arc,
  23    },
  24    time::{Instant, SystemTime},
  25};
  26use unindent::Unindent;
  27use util::RandomCharIter;
  28
  29#[ctor::ctor]
  30fn init_logger() {
  31    if std::env::var("RUST_LOG").is_ok() {
  32        env_logger::init();
  33    }
  34}
  35
  36#[gpui::test]
  37async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
  38    init_test(cx);
  39
  40    let fs = FakeFs::new(cx.background());
  41    fs.insert_tree(
  42        "/the-root",
  43        json!({
  44            "src": {
  45                "file1.rs": "
  46                    fn aaa() {
  47                        println!(\"aaaaaaaaaaaa!\");
  48                    }
  49
  50                    fn zzzzz() {
  51                        println!(\"SLEEPING\");
  52                    }
  53                ".unindent(),
  54                "file2.rs": "
  55                    fn bbb() {
  56                        println!(\"bbbbbbbbbbbbb!\");
  57                    }
  58                    struct pqpqpqp {}
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let semantic_index = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89
  90    let search_results = semantic_index.update(cx, |store, cx| {
  91        store.search_project(
  92            project.clone(),
  93            "aaaaaabbbbzz".to_string(),
  94            5,
  95            vec![],
  96            vec![],
  97            cx,
  98        )
  99    });
 100    let pending_file_count =
 101        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
 102    deterministic.run_until_parked();
 103    assert_eq!(*pending_file_count.borrow(), 3);
 104    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 105    assert_eq!(*pending_file_count.borrow(), 0);
 106
 107    let search_results = search_results.await.unwrap();
 108    assert_search_results(
 109        &search_results,
 110        &[
 111            (Path::new("src/file1.rs").into(), 0),
 112            (Path::new("src/file2.rs").into(), 0),
 113            (Path::new("src/file3.toml").into(), 0),
 114            (Path::new("src/file1.rs").into(), 45),
 115            (Path::new("src/file2.rs").into(), 45),
 116        ],
 117        cx,
 118    );
 119
 120    // Test Include Files Functonality
 121    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 122    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 123    let rust_only_search_results = semantic_index
 124        .update(cx, |store, cx| {
 125            store.search_project(
 126                project.clone(),
 127                "aaaaaabbbbzz".to_string(),
 128                5,
 129                include_files,
 130                vec![],
 131                cx,
 132            )
 133        })
 134        .await
 135        .unwrap();
 136
 137    assert_search_results(
 138        &rust_only_search_results,
 139        &[
 140            (Path::new("src/file1.rs").into(), 0),
 141            (Path::new("src/file2.rs").into(), 0),
 142            (Path::new("src/file1.rs").into(), 45),
 143            (Path::new("src/file2.rs").into(), 45),
 144        ],
 145        cx,
 146    );
 147
 148    let no_rust_search_results = semantic_index
 149        .update(cx, |store, cx| {
 150            store.search_project(
 151                project.clone(),
 152                "aaaaaabbbbzz".to_string(),
 153                5,
 154                vec![],
 155                exclude_files,
 156                cx,
 157            )
 158        })
 159        .await
 160        .unwrap();
 161
 162    assert_search_results(
 163        &no_rust_search_results,
 164        &[(Path::new("src/file3.toml").into(), 0)],
 165        cx,
 166    );
 167
 168    fs.save(
 169        "/the-root/src/file2.rs".as_ref(),
 170        &"
 171            fn dddd() { println!(\"ddddd!\"); }
 172            struct pqpqpqp {}
 173        "
 174        .unindent()
 175        .into(),
 176        Default::default(),
 177    )
 178    .await
 179    .unwrap();
 180
 181    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 182
 183    let prev_embedding_count = embedding_provider.embedding_count();
 184    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 185    deterministic.run_until_parked();
 186    assert_eq!(*pending_file_count.borrow(), 1);
 187    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 188    assert_eq!(*pending_file_count.borrow(), 0);
 189    index.await.unwrap();
 190
 191    assert_eq!(
 192        embedding_provider.embedding_count() - prev_embedding_count,
 193        1
 194    );
 195}
 196
 197#[gpui::test(iterations = 10)]
 198async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 199    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 200    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 201
 202    let files = (1..=3)
 203        .map(|file_ix| FileToEmbed {
 204            worktree_id: 5,
 205            path: Path::new(&format!("path-{file_ix}")).into(),
 206            mtime: SystemTime::now(),
 207            spans: (0..rng.gen_range(4..22))
 208                .map(|document_ix| {
 209                    let content_len = rng.gen_range(10..100);
 210                    let content = RandomCharIter::new(&mut rng)
 211                        .with_simple_text()
 212                        .take(content_len)
 213                        .collect::<String>();
 214                    let digest = SpanDigest::from(content.as_str());
 215                    Span {
 216                        range: 0..10,
 217                        embedding: None,
 218                        name: format!("document {document_ix}"),
 219                        content,
 220                        digest,
 221                        token_count: rng.gen_range(10..30),
 222                    }
 223                })
 224                .collect(),
 225            job_handle: JobHandle::new(&outstanding_job_count),
 226        })
 227        .collect::<Vec<_>>();
 228
 229    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 230
 231    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
 232    for file in &files {
 233        queue.push(file.clone());
 234    }
 235    queue.flush();
 236
 237    cx.foreground().run_until_parked();
 238    let finished_files = queue.finished_files();
 239    let mut embedded_files: Vec<_> = files
 240        .iter()
 241        .map(|_| finished_files.try_recv().expect("no finished file"))
 242        .collect();
 243
 244    let expected_files: Vec<_> = files
 245        .iter()
 246        .map(|file| {
 247            let mut file = file.clone();
 248            for doc in &mut file.spans {
 249                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 250            }
 251            file
 252        })
 253        .collect();
 254
 255    embedded_files.sort_by_key(|f| f.path.clone());
 256
 257    assert_eq!(embedded_files, expected_files);
 258}
 259
 260#[track_caller]
 261fn assert_search_results(
 262    actual: &[SearchResult],
 263    expected: &[(Arc<Path>, usize)],
 264    cx: &TestAppContext,
 265) {
 266    let actual = actual
 267        .iter()
 268        .map(|search_result| {
 269            search_result.buffer.read_with(cx, |buffer, _cx| {
 270                (
 271                    buffer.file().unwrap().path().clone(),
 272                    search_result.range.start.to_offset(buffer),
 273                )
 274            })
 275        })
 276        .collect::<Vec<_>>();
 277    assert_eq!(actual, expected);
 278}
 279
 280#[gpui::test]
 281async fn test_code_context_retrieval_rust() {
 282    let language = rust_lang();
 283    let embedding_provider = Arc::new(DummyEmbeddings {});
 284    let mut retriever = CodeContextRetriever::new(embedding_provider);
 285
 286    let text = "
 287        /// A doc comment
 288        /// that spans multiple lines
 289        #[gpui::test]
 290        fn a() {
 291            b
 292        }
 293
 294        impl C for D {
 295        }
 296
 297        impl E {
 298            // This is also a preceding comment
 299            pub fn function_1() -> Option<()> {
 300                todo!();
 301            }
 302
 303            // This is a preceding comment
 304            fn function_2() -> Result<()> {
 305                todo!();
 306            }
 307        }
 308    "
 309    .unindent();
 310
 311    let documents = retriever.parse_file(&text, language).unwrap();
 312
 313    assert_documents_eq(
 314        &documents,
 315        &[
 316            (
 317                "
 318                /// A doc comment
 319                /// that spans multiple lines
 320                #[gpui::test]
 321                fn a() {
 322                    b
 323                }"
 324                .unindent(),
 325                text.find("fn a").unwrap(),
 326            ),
 327            (
 328                "
 329                impl C for D {
 330                }"
 331                .unindent(),
 332                text.find("impl C").unwrap(),
 333            ),
 334            (
 335                "
 336                impl E {
 337                    // This is also a preceding comment
 338                    pub fn function_1() -> Option<()> { /* ... */ }
 339
 340                    // This is a preceding comment
 341                    fn function_2() -> Result<()> { /* ... */ }
 342                }"
 343                .unindent(),
 344                text.find("impl E").unwrap(),
 345            ),
 346            (
 347                "
 348                // This is also a preceding comment
 349                pub fn function_1() -> Option<()> {
 350                    todo!();
 351                }"
 352                .unindent(),
 353                text.find("pub fn function_1").unwrap(),
 354            ),
 355            (
 356                "
 357                // This is a preceding comment
 358                fn function_2() -> Result<()> {
 359                    todo!();
 360                }"
 361                .unindent(),
 362                text.find("fn function_2").unwrap(),
 363            ),
 364        ],
 365    );
 366}
 367
 368#[gpui::test]
 369async fn test_code_context_retrieval_json() {
 370    let language = json_lang();
 371    let embedding_provider = Arc::new(DummyEmbeddings {});
 372    let mut retriever = CodeContextRetriever::new(embedding_provider);
 373
 374    let text = r#"
 375        {
 376            "array": [1, 2, 3, 4],
 377            "string": "abcdefg",
 378            "nested_object": {
 379                "array_2": [5, 6, 7, 8],
 380                "string_2": "hijklmnop",
 381                "boolean": true,
 382                "none": null
 383            }
 384        }
 385    "#
 386    .unindent();
 387
 388    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 389
 390    assert_documents_eq(
 391        &documents,
 392        &[(
 393            r#"
 394                {
 395                    "array": [],
 396                    "string": "",
 397                    "nested_object": {
 398                        "array_2": [],
 399                        "string_2": "",
 400                        "boolean": true,
 401                        "none": null
 402                    }
 403                }"#
 404            .unindent(),
 405            text.find("{").unwrap(),
 406        )],
 407    );
 408
 409    let text = r#"
 410        [
 411            {
 412                "name": "somebody",
 413                "age": 42
 414            },
 415            {
 416                "name": "somebody else",
 417                "age": 43
 418            }
 419        ]
 420    "#
 421    .unindent();
 422
 423    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 424
 425    assert_documents_eq(
 426        &documents,
 427        &[(
 428            r#"
 429            [{
 430                    "name": "",
 431                    "age": 42
 432                }]"#
 433            .unindent(),
 434            text.find("[").unwrap(),
 435        )],
 436    );
 437}
 438
 439fn assert_documents_eq(
 440    documents: &[Span],
 441    expected_contents_and_start_offsets: &[(String, usize)],
 442) {
 443    assert_eq!(
 444        documents
 445            .iter()
 446            .map(|document| (document.content.clone(), document.range.start))
 447            .collect::<Vec<_>>(),
 448        expected_contents_and_start_offsets
 449    );
 450}
 451
 452#[gpui::test]
 453async fn test_code_context_retrieval_javascript() {
 454    let language = js_lang();
 455    let embedding_provider = Arc::new(DummyEmbeddings {});
 456    let mut retriever = CodeContextRetriever::new(embedding_provider);
 457
 458    let text = "
 459        /* globals importScripts, backend */
 460        function _authorize() {}
 461
 462        /**
 463         * Sometimes the frontend build is way faster than backend.
 464         */
 465        export async function authorizeBank() {
 466            _authorize(pushModal, upgradingAccountId, {});
 467        }
 468
 469        export class SettingsPage {
 470            /* This is a test setting */
 471            constructor(page) {
 472                this.page = page;
 473            }
 474        }
 475
 476        /* This is a test comment */
 477        class TestClass {}
 478
 479        /* Schema for editor_events in Clickhouse. */
 480        export interface ClickhouseEditorEvent {
 481            installation_id: string
 482            operation: string
 483        }
 484        "
 485    .unindent();
 486
 487    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 488
 489    assert_documents_eq(
 490        &documents,
 491        &[
 492            (
 493                "
 494            /* globals importScripts, backend */
 495            function _authorize() {}"
 496                    .unindent(),
 497                37,
 498            ),
 499            (
 500                "
 501            /**
 502             * Sometimes the frontend build is way faster than backend.
 503             */
 504            export async function authorizeBank() {
 505                _authorize(pushModal, upgradingAccountId, {});
 506            }"
 507                .unindent(),
 508                131,
 509            ),
 510            (
 511                "
 512                export class SettingsPage {
 513                    /* This is a test setting */
 514                    constructor(page) {
 515                        this.page = page;
 516                    }
 517                }"
 518                .unindent(),
 519                225,
 520            ),
 521            (
 522                "
 523                /* This is a test setting */
 524                constructor(page) {
 525                    this.page = page;
 526                }"
 527                .unindent(),
 528                290,
 529            ),
 530            (
 531                "
 532                /* This is a test comment */
 533                class TestClass {}"
 534                    .unindent(),
 535                374,
 536            ),
 537            (
 538                "
 539                /* Schema for editor_events in Clickhouse. */
 540                export interface ClickhouseEditorEvent {
 541                    installation_id: string
 542                    operation: string
 543                }"
 544                .unindent(),
 545                440,
 546            ),
 547        ],
 548    )
 549}
 550
 551#[gpui::test]
 552async fn test_code_context_retrieval_lua() {
 553    let language = lua_lang();
 554    let embedding_provider = Arc::new(DummyEmbeddings {});
 555    let mut retriever = CodeContextRetriever::new(embedding_provider);
 556
 557    let text = r#"
 558        -- Creates a new class
 559        -- @param baseclass The Baseclass of this class, or nil.
 560        -- @return A new class reference.
 561        function classes.class(baseclass)
 562            -- Create the class definition and metatable.
 563            local classdef = {}
 564            -- Find the super class, either Object or user-defined.
 565            baseclass = baseclass or classes.Object
 566            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 567            setmetatable(classdef, { __index = baseclass })
 568            -- All class instances have a reference to the class object.
 569            classdef.class = classdef
 570            --- Recursivly allocates the inheritance tree of the instance.
 571            -- @param mastertable The 'root' of the inheritance tree.
 572            -- @return Returns the instance with the allocated inheritance tree.
 573            function classdef.alloc(mastertable)
 574                -- All class instances have a reference to a superclass object.
 575                local instance = { super = baseclass.alloc(mastertable) }
 576                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 577                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 578                return instance
 579            end
 580        end
 581        "#.unindent();
 582
 583    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 584
 585    assert_documents_eq(
 586        &documents,
 587        &[
 588            (r#"
 589                -- Creates a new class
 590                -- @param baseclass The Baseclass of this class, or nil.
 591                -- @return A new class reference.
 592                function classes.class(baseclass)
 593                    -- Create the class definition and metatable.
 594                    local classdef = {}
 595                    -- Find the super class, either Object or user-defined.
 596                    baseclass = baseclass or classes.Object
 597                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 598                    setmetatable(classdef, { __index = baseclass })
 599                    -- All class instances have a reference to the class object.
 600                    classdef.class = classdef
 601                    --- Recursivly allocates the inheritance tree of the instance.
 602                    -- @param mastertable The 'root' of the inheritance tree.
 603                    -- @return Returns the instance with the allocated inheritance tree.
 604                    function classdef.alloc(mastertable)
 605                        --[ ... ]--
 606                        --[ ... ]--
 607                    end
 608                end"#.unindent(),
 609            114),
 610            (r#"
 611            --- Recursivly allocates the inheritance tree of the instance.
 612            -- @param mastertable The 'root' of the inheritance tree.
 613            -- @return Returns the instance with the allocated inheritance tree.
 614            function classdef.alloc(mastertable)
 615                -- All class instances have a reference to a superclass object.
 616                local instance = { super = baseclass.alloc(mastertable) }
 617                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 618                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 619                return instance
 620            end"#.unindent(), 809),
 621        ]
 622    );
 623}
 624
 625#[gpui::test]
 626async fn test_code_context_retrieval_elixir() {
 627    let language = elixir_lang();
 628    let embedding_provider = Arc::new(DummyEmbeddings {});
 629    let mut retriever = CodeContextRetriever::new(embedding_provider);
 630
 631    let text = r#"
 632        defmodule File.Stream do
 633            @moduledoc """
 634            Defines a `File.Stream` struct returned by `File.stream!/3`.
 635
 636            The following fields are public:
 637
 638            * `path`          - the file path
 639            * `modes`         - the file modes
 640            * `raw`           - a boolean indicating if bin functions should be used
 641            * `line_or_bytes` - if reading should read lines or a given number of bytes
 642            * `node`          - the node the file belongs to
 643
 644            """
 645
 646            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 647
 648            @type t :: %__MODULE__{}
 649
 650            @doc false
 651            def __build__(path, modes, line_or_bytes) do
 652            raw = :lists.keyfind(:encoding, 1, modes) == false
 653
 654            modes =
 655                case raw do
 656                true ->
 657                    case :lists.keyfind(:read_ahead, 1, modes) do
 658                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 659                    {:read_ahead, _} -> [:raw | modes]
 660                    false -> [:raw, :read_ahead | modes]
 661                    end
 662
 663                false ->
 664                    modes
 665                end
 666
 667            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 668
 669            end"#
 670    .unindent();
 671
 672    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 673
 674    assert_documents_eq(
 675        &documents,
 676        &[(
 677            r#"
 678        defmodule File.Stream do
 679            @moduledoc """
 680            Defines a `File.Stream` struct returned by `File.stream!/3`.
 681
 682            The following fields are public:
 683
 684            * `path`          - the file path
 685            * `modes`         - the file modes
 686            * `raw`           - a boolean indicating if bin functions should be used
 687            * `line_or_bytes` - if reading should read lines or a given number of bytes
 688            * `node`          - the node the file belongs to
 689
 690            """
 691
 692            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 693
 694            @type t :: %__MODULE__{}
 695
 696            @doc false
 697            def __build__(path, modes, line_or_bytes) do
 698            raw = :lists.keyfind(:encoding, 1, modes) == false
 699
 700            modes =
 701                case raw do
 702                true ->
 703                    case :lists.keyfind(:read_ahead, 1, modes) do
 704                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 705                    {:read_ahead, _} -> [:raw | modes]
 706                    false -> [:raw, :read_ahead | modes]
 707                    end
 708
 709                false ->
 710                    modes
 711                end
 712
 713            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 714
 715            end"#
 716                .unindent(),
 717            0,
 718        ),(r#"
 719            @doc false
 720            def __build__(path, modes, line_or_bytes) do
 721            raw = :lists.keyfind(:encoding, 1, modes) == false
 722
 723            modes =
 724                case raw do
 725                true ->
 726                    case :lists.keyfind(:read_ahead, 1, modes) do
 727                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 728                    {:read_ahead, _} -> [:raw | modes]
 729                    false -> [:raw, :read_ahead | modes]
 730                    end
 731
 732                false ->
 733                    modes
 734                end
 735
 736            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 737
 738            end"#.unindent(), 574)],
 739    );
 740}
 741
 742#[gpui::test]
 743async fn test_code_context_retrieval_cpp() {
 744    let language = cpp_lang();
 745    let embedding_provider = Arc::new(DummyEmbeddings {});
 746    let mut retriever = CodeContextRetriever::new(embedding_provider);
 747
 748    let text = "
 749    /**
 750     * @brief Main function
 751     * @returns 0 on exit
 752     */
 753    int main() { return 0; }
 754
 755    /**
 756    * This is a test comment
 757    */
 758    class MyClass {       // The class
 759        public:           // Access specifier
 760        int myNum;        // Attribute (int variable)
 761        string myString;  // Attribute (string variable)
 762    };
 763
 764    // This is a test comment
 765    enum Color { red, green, blue };
 766
 767    /** This is a preceding block comment
 768     * This is the second line
 769     */
 770    struct {           // Structure declaration
 771        int myNum;       // Member (int variable)
 772        string myString; // Member (string variable)
 773    } myStructure;
 774
 775    /**
 776     * @brief Matrix class.
 777     */
 778    template <typename T,
 779              typename = typename std::enable_if<
 780                std::is_integral<T>::value || std::is_floating_point<T>::value,
 781                bool>::type>
 782    class Matrix2 {
 783        std::vector<std::vector<T>> _mat;
 784
 785        public:
 786            /**
 787            * @brief Constructor
 788            * @tparam Integer ensuring integers are being evaluated and not other
 789            * data types.
 790            * @param size denoting the size of Matrix as size x size
 791            */
 792            template <typename Integer,
 793                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 794                    Integer>::type>
 795            explicit Matrix(const Integer size) {
 796                for (size_t i = 0; i < size; ++i) {
 797                    _mat.emplace_back(std::vector<T>(size, 0));
 798                }
 799            }
 800    }"
 801    .unindent();
 802
 803    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 804
 805    assert_documents_eq(
 806        &documents,
 807        &[
 808            (
 809                "
 810        /**
 811         * @brief Main function
 812         * @returns 0 on exit
 813         */
 814        int main() { return 0; }"
 815                    .unindent(),
 816                54,
 817            ),
 818            (
 819                "
 820                /**
 821                * This is a test comment
 822                */
 823                class MyClass {       // The class
 824                    public:           // Access specifier
 825                    int myNum;        // Attribute (int variable)
 826                    string myString;  // Attribute (string variable)
 827                }"
 828                .unindent(),
 829                112,
 830            ),
 831            (
 832                "
 833                // This is a test comment
 834                enum Color { red, green, blue }"
 835                    .unindent(),
 836                322,
 837            ),
 838            (
 839                "
 840                /** This is a preceding block comment
 841                 * This is the second line
 842                 */
 843                struct {           // Structure declaration
 844                    int myNum;       // Member (int variable)
 845                    string myString; // Member (string variable)
 846                } myStructure;"
 847                    .unindent(),
 848                425,
 849            ),
 850            (
 851                "
 852                /**
 853                 * @brief Matrix class.
 854                 */
 855                template <typename T,
 856                          typename = typename std::enable_if<
 857                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 858                            bool>::type>
 859                class Matrix2 {
 860                    std::vector<std::vector<T>> _mat;
 861
 862                    public:
 863                        /**
 864                        * @brief Constructor
 865                        * @tparam Integer ensuring integers are being evaluated and not other
 866                        * data types.
 867                        * @param size denoting the size of Matrix as size x size
 868                        */
 869                        template <typename Integer,
 870                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 871                                Integer>::type>
 872                        explicit Matrix(const Integer size) {
 873                            for (size_t i = 0; i < size; ++i) {
 874                                _mat.emplace_back(std::vector<T>(size, 0));
 875                            }
 876                        }
 877                }"
 878                .unindent(),
 879                612,
 880            ),
 881            (
 882                "
 883                explicit Matrix(const Integer size) {
 884                    for (size_t i = 0; i < size; ++i) {
 885                        _mat.emplace_back(std::vector<T>(size, 0));
 886                    }
 887                }"
 888                .unindent(),
 889                1226,
 890            ),
 891        ],
 892    );
 893}
 894
 895#[gpui::test]
 896async fn test_code_context_retrieval_ruby() {
 897    let language = ruby_lang();
 898    let embedding_provider = Arc::new(DummyEmbeddings {});
 899    let mut retriever = CodeContextRetriever::new(embedding_provider);
 900
 901    let text = r#"
 902        # This concern is inspired by "sudo mode" on GitHub. It
 903        # is a way to re-authenticate a user before allowing them
 904        # to see or perform an action.
 905        #
 906        # Add `before_action :require_challenge!` to actions you
 907        # want to protect.
 908        #
 909        # The user will be shown a page to enter the challenge (which
 910        # is either the password, or just the username when no
 911        # password exists). Upon passing, there is a grace period
 912        # during which no challenge will be asked from the user.
 913        #
 914        # Accessing challenge-protected resources during the grace
 915        # period will refresh the grace period.
 916        module ChallengableConcern
 917            extend ActiveSupport::Concern
 918
 919            CHALLENGE_TIMEOUT = 1.hour.freeze
 920
 921            def require_challenge!
 922                return if skip_challenge?
 923
 924                if challenge_passed_recently?
 925                    session[:challenge_passed_at] = Time.now.utc
 926                    return
 927                end
 928
 929                @challenge = Form::Challenge.new(return_to: request.url)
 930
 931                if params.key?(:form_challenge)
 932                    if challenge_passed?
 933                        session[:challenge_passed_at] = Time.now.utc
 934                    else
 935                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 936                        render_challenge
 937                    end
 938                else
 939                    render_challenge
 940                end
 941            end
 942
 943            def challenge_passed?
 944                current_user.valid_password?(challenge_params[:current_password])
 945            end
 946        end
 947
 948        class Animal
 949            include Comparable
 950
 951            attr_reader :legs
 952
 953            def initialize(name, legs)
 954                @name, @legs = name, legs
 955            end
 956
 957            def <=>(other)
 958                legs <=> other.legs
 959            end
 960        end
 961
 962        # Singleton method for car object
 963        def car.wheels
 964            puts "There are four wheels"
 965        end"#
 966        .unindent();
 967
 968    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 969
 970    assert_documents_eq(
 971        &documents,
 972        &[
 973            (
 974                r#"
 975        # This concern is inspired by "sudo mode" on GitHub. It
 976        # is a way to re-authenticate a user before allowing them
 977        # to see or perform an action.
 978        #
 979        # Add `before_action :require_challenge!` to actions you
 980        # want to protect.
 981        #
 982        # The user will be shown a page to enter the challenge (which
 983        # is either the password, or just the username when no
 984        # password exists). Upon passing, there is a grace period
 985        # during which no challenge will be asked from the user.
 986        #
 987        # Accessing challenge-protected resources during the grace
 988        # period will refresh the grace period.
 989        module ChallengableConcern
 990            extend ActiveSupport::Concern
 991
 992            CHALLENGE_TIMEOUT = 1.hour.freeze
 993
 994            def require_challenge!
 995                # ...
 996            end
 997
 998            def challenge_passed?
 999                # ...
1000            end
1001        end"#
1002                    .unindent(),
1003                558,
1004            ),
1005            (
1006                r#"
1007            def require_challenge!
1008                return if skip_challenge?
1009
1010                if challenge_passed_recently?
1011                    session[:challenge_passed_at] = Time.now.utc
1012                    return
1013                end
1014
1015                @challenge = Form::Challenge.new(return_to: request.url)
1016
1017                if params.key?(:form_challenge)
1018                    if challenge_passed?
1019                        session[:challenge_passed_at] = Time.now.utc
1020                    else
1021                        flash.now[:alert] = I18n.t('challenge.invalid_password')
1022                        render_challenge
1023                    end
1024                else
1025                    render_challenge
1026                end
1027            end"#
1028                    .unindent(),
1029                663,
1030            ),
1031            (
1032                r#"
1033                def challenge_passed?
1034                    current_user.valid_password?(challenge_params[:current_password])
1035                end"#
1036                    .unindent(),
1037                1254,
1038            ),
1039            (
1040                r#"
1041                class Animal
1042                    include Comparable
1043
1044                    attr_reader :legs
1045
1046                    def initialize(name, legs)
1047                        # ...
1048                    end
1049
1050                    def <=>(other)
1051                        # ...
1052                    end
1053                end"#
1054                    .unindent(),
1055                1363,
1056            ),
1057            (
1058                r#"
1059                def initialize(name, legs)
1060                    @name, @legs = name, legs
1061                end"#
1062                    .unindent(),
1063                1427,
1064            ),
1065            (
1066                r#"
1067                def <=>(other)
1068                    legs <=> other.legs
1069                end"#
1070                    .unindent(),
1071                1501,
1072            ),
1073            (
1074                r#"
1075                # Singleton method for car object
1076                def car.wheels
1077                    puts "There are four wheels"
1078                end"#
1079                    .unindent(),
1080                1591,
1081            ),
1082        ],
1083    );
1084}
1085
1086#[gpui::test]
1087async fn test_code_context_retrieval_php() {
1088    let language = php_lang();
1089    let embedding_provider = Arc::new(DummyEmbeddings {});
1090    let mut retriever = CodeContextRetriever::new(embedding_provider);
1091
1092    let text = r#"
1093        <?php
1094
1095        namespace LevelUp\Experience\Concerns;
1096
1097        /*
1098        This is a multiple-lines comment block
1099        that spans over multiple
1100        lines
1101        */
1102        function functionName() {
1103            echo "Hello world!";
1104        }
1105
1106        trait HasAchievements
1107        {
1108            /**
1109            * @throws \Exception
1110            */
1111            public function grantAchievement(Achievement $achievement, $progress = null): void
1112            {
1113                if ($progress > 100) {
1114                    throw new Exception(message: 'Progress cannot be greater than 100');
1115                }
1116
1117                if ($this->achievements()->find($achievement->id)) {
1118                    throw new Exception(message: 'User already has this Achievement');
1119                }
1120
1121                $this->achievements()->attach($achievement, [
1122                    'progress' => $progress ?? null,
1123                ]);
1124
1125                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1126            }
1127
1128            public function achievements(): BelongsToMany
1129            {
1130                return $this->belongsToMany(related: Achievement::class)
1131                ->withPivot(columns: 'progress')
1132                ->where('is_secret', false)
1133                ->using(AchievementUser::class);
1134            }
1135        }
1136
1137        interface Multiplier
1138        {
1139            public function qualifies(array $data): bool;
1140
1141            public function setMultiplier(): int;
1142        }
1143
1144        enum AuditType: string
1145        {
1146            case Add = 'add';
1147            case Remove = 'remove';
1148            case Reset = 'reset';
1149            case LevelUp = 'level_up';
1150        }
1151
1152        ?>"#
1153    .unindent();
1154
1155    let documents = retriever.parse_file(&text, language.clone()).unwrap();
1156
1157    assert_documents_eq(
1158        &documents,
1159        &[
1160            (
1161                r#"
1162        /*
1163        This is a multiple-lines comment block
1164        that spans over multiple
1165        lines
1166        */
1167        function functionName() {
1168            echo "Hello world!";
1169        }"#
1170                .unindent(),
1171                123,
1172            ),
1173            (
1174                r#"
1175        trait HasAchievements
1176        {
1177            /**
1178            * @throws \Exception
1179            */
1180            public function grantAchievement(Achievement $achievement, $progress = null): void
1181            {/* ... */}
1182
1183            public function achievements(): BelongsToMany
1184            {/* ... */}
1185        }"#
1186                .unindent(),
1187                177,
1188            ),
1189            (r#"
1190            /**
1191            * @throws \Exception
1192            */
1193            public function grantAchievement(Achievement $achievement, $progress = null): void
1194            {
1195                if ($progress > 100) {
1196                    throw new Exception(message: 'Progress cannot be greater than 100');
1197                }
1198
1199                if ($this->achievements()->find($achievement->id)) {
1200                    throw new Exception(message: 'User already has this Achievement');
1201                }
1202
1203                $this->achievements()->attach($achievement, [
1204                    'progress' => $progress ?? null,
1205                ]);
1206
1207                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1208            }"#.unindent(), 245),
1209            (r#"
1210                public function achievements(): BelongsToMany
1211                {
1212                    return $this->belongsToMany(related: Achievement::class)
1213                    ->withPivot(columns: 'progress')
1214                    ->where('is_secret', false)
1215                    ->using(AchievementUser::class);
1216                }"#.unindent(), 902),
1217            (r#"
1218                interface Multiplier
1219                {
1220                    public function qualifies(array $data): bool;
1221
1222                    public function setMultiplier(): int;
1223                }"#.unindent(),
1224                1146),
1225            (r#"
1226                enum AuditType: string
1227                {
1228                    case Add = 'add';
1229                    case Remove = 'remove';
1230                    case Reset = 'reset';
1231                    case LevelUp = 'level_up';
1232                }"#.unindent(), 1265)
1233        ],
1234    );
1235}
1236
1237#[derive(Default)]
1238struct FakeEmbeddingProvider {
1239    embedding_count: AtomicUsize,
1240}
1241
1242impl FakeEmbeddingProvider {
1243    fn embedding_count(&self) -> usize {
1244        self.embedding_count.load(atomic::Ordering::SeqCst)
1245    }
1246
1247    fn embed_sync(&self, span: &str) -> Embedding {
1248        let mut result = vec![1.0; 26];
1249        for letter in span.chars() {
1250            let letter = letter.to_ascii_lowercase();
1251            if letter as u32 >= 'a' as u32 {
1252                let ix = (letter as u32) - ('a' as u32);
1253                if ix < 26 {
1254                    result[ix as usize] += 1.0;
1255                }
1256            }
1257        }
1258
1259        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1260        for x in &mut result {
1261            *x /= norm;
1262        }
1263
1264        result.into()
1265    }
1266}
1267
1268#[async_trait]
1269impl EmbeddingProvider for FakeEmbeddingProvider {
1270    fn is_authenticated(&self) -> bool {
1271        true
1272    }
1273    fn truncate(&self, span: &str) -> (String, usize) {
1274        (span.to_string(), 1)
1275    }
1276
1277    fn max_tokens_per_batch(&self) -> usize {
1278        200
1279    }
1280
1281    fn rate_limit_expiration(&self) -> Option<Instant> {
1282        None
1283    }
1284
1285    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1286        self.embedding_count
1287            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1288        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1289    }
1290}
1291
1292fn js_lang() -> Arc<Language> {
1293    Arc::new(
1294        Language::new(
1295            LanguageConfig {
1296                name: "Javascript".into(),
1297                path_suffixes: vec!["js".into()],
1298                ..Default::default()
1299            },
1300            Some(tree_sitter_typescript::language_tsx()),
1301        )
1302        .with_embedding_query(
1303            &r#"
1304
1305            (
1306                (comment)* @context
1307                .
1308                [
1309                (export_statement
1310                    (function_declaration
1311                        "async"? @name
1312                        "function" @name
1313                        name: (_) @name))
1314                (function_declaration
1315                    "async"? @name
1316                    "function" @name
1317                    name: (_) @name)
1318                ] @item
1319            )
1320
1321            (
1322                (comment)* @context
1323                .
1324                [
1325                (export_statement
1326                    (class_declaration
1327                        "class" @name
1328                        name: (_) @name))
1329                (class_declaration
1330                    "class" @name
1331                    name: (_) @name)
1332                ] @item
1333            )
1334
1335            (
1336                (comment)* @context
1337                .
1338                [
1339                (export_statement
1340                    (interface_declaration
1341                        "interface" @name
1342                        name: (_) @name))
1343                (interface_declaration
1344                    "interface" @name
1345                    name: (_) @name)
1346                ] @item
1347            )
1348
1349            (
1350                (comment)* @context
1351                .
1352                [
1353                (export_statement
1354                    (enum_declaration
1355                        "enum" @name
1356                        name: (_) @name))
1357                (enum_declaration
1358                    "enum" @name
1359                    name: (_) @name)
1360                ] @item
1361            )
1362
1363            (
1364                (comment)* @context
1365                .
1366                (method_definition
1367                    [
1368                        "get"
1369                        "set"
1370                        "async"
1371                        "*"
1372                        "static"
1373                    ]* @name
1374                    name: (_) @name) @item
1375            )
1376
1377                    "#
1378            .unindent(),
1379        )
1380        .unwrap(),
1381    )
1382}
1383
1384fn rust_lang() -> Arc<Language> {
1385    Arc::new(
1386        Language::new(
1387            LanguageConfig {
1388                name: "Rust".into(),
1389                path_suffixes: vec!["rs".into()],
1390                collapsed_placeholder: " /* ... */ ".to_string(),
1391                ..Default::default()
1392            },
1393            Some(tree_sitter_rust::language()),
1394        )
1395        .with_embedding_query(
1396            r#"
1397            (
1398                [(line_comment) (attribute_item)]* @context
1399                .
1400                [
1401                    (struct_item
1402                        name: (_) @name)
1403
1404                    (enum_item
1405                        name: (_) @name)
1406
1407                    (impl_item
1408                        trait: (_)? @name
1409                        "for"? @name
1410                        type: (_) @name)
1411
1412                    (trait_item
1413                        name: (_) @name)
1414
1415                    (function_item
1416                        name: (_) @name
1417                        body: (block
1418                            "{" @keep
1419                            "}" @keep) @collapse)
1420
1421                    (macro_definition
1422                        name: (_) @name)
1423                ] @item
1424            )
1425            "#,
1426        )
1427        .unwrap(),
1428    )
1429}
1430
1431fn json_lang() -> Arc<Language> {
1432    Arc::new(
1433        Language::new(
1434            LanguageConfig {
1435                name: "JSON".into(),
1436                path_suffixes: vec!["json".into()],
1437                ..Default::default()
1438            },
1439            Some(tree_sitter_json::language()),
1440        )
1441        .with_embedding_query(
1442            r#"
1443            (document) @item
1444
1445            (array
1446                "[" @keep
1447                .
1448                (object)? @keep
1449                "]" @keep) @collapse
1450
1451            (pair value: (string
1452                "\"" @keep
1453                "\"" @keep) @collapse)
1454            "#,
1455        )
1456        .unwrap(),
1457    )
1458}
1459
1460fn toml_lang() -> Arc<Language> {
1461    Arc::new(Language::new(
1462        LanguageConfig {
1463            name: "TOML".into(),
1464            path_suffixes: vec!["toml".into()],
1465            ..Default::default()
1466        },
1467        Some(tree_sitter_toml::language()),
1468    ))
1469}
1470
1471fn cpp_lang() -> Arc<Language> {
1472    Arc::new(
1473        Language::new(
1474            LanguageConfig {
1475                name: "CPP".into(),
1476                path_suffixes: vec!["cpp".into()],
1477                ..Default::default()
1478            },
1479            Some(tree_sitter_cpp::language()),
1480        )
1481        .with_embedding_query(
1482            r#"
1483            (
1484                (comment)* @context
1485                .
1486                (function_definition
1487                    (type_qualifier)? @name
1488                    type: (_)? @name
1489                    declarator: [
1490                        (function_declarator
1491                            declarator: (_) @name)
1492                        (pointer_declarator
1493                            "*" @name
1494                            declarator: (function_declarator
1495                            declarator: (_) @name))
1496                        (pointer_declarator
1497                            "*" @name
1498                            declarator: (pointer_declarator
1499                                "*" @name
1500                            declarator: (function_declarator
1501                                declarator: (_) @name)))
1502                        (reference_declarator
1503                            ["&" "&&"] @name
1504                            (function_declarator
1505                            declarator: (_) @name))
1506                    ]
1507                    (type_qualifier)? @name) @item
1508                )
1509
1510            (
1511                (comment)* @context
1512                .
1513                (template_declaration
1514                    (class_specifier
1515                        "class" @name
1516                        name: (_) @name)
1517                        ) @item
1518            )
1519
1520            (
1521                (comment)* @context
1522                .
1523                (class_specifier
1524                    "class" @name
1525                    name: (_) @name) @item
1526                )
1527
1528            (
1529                (comment)* @context
1530                .
1531                (enum_specifier
1532                    "enum" @name
1533                    name: (_) @name) @item
1534                )
1535
1536            (
1537                (comment)* @context
1538                .
1539                (declaration
1540                    type: (struct_specifier
1541                    "struct" @name)
1542                    declarator: (_) @name) @item
1543            )
1544
1545            "#,
1546        )
1547        .unwrap(),
1548    )
1549}
1550
1551fn lua_lang() -> Arc<Language> {
1552    Arc::new(
1553        Language::new(
1554            LanguageConfig {
1555                name: "Lua".into(),
1556                path_suffixes: vec!["lua".into()],
1557                collapsed_placeholder: "--[ ... ]--".to_string(),
1558                ..Default::default()
1559            },
1560            Some(tree_sitter_lua::language()),
1561        )
1562        .with_embedding_query(
1563            r#"
1564            (
1565                (comment)* @context
1566                .
1567                (function_declaration
1568                    "function" @name
1569                    name: (_) @name
1570                    (comment)* @collapse
1571                    body: (block) @collapse
1572                ) @item
1573            )
1574        "#,
1575        )
1576        .unwrap(),
1577    )
1578}
1579
1580fn php_lang() -> Arc<Language> {
1581    Arc::new(
1582        Language::new(
1583            LanguageConfig {
1584                name: "PHP".into(),
1585                path_suffixes: vec!["php".into()],
1586                collapsed_placeholder: "/* ... */".into(),
1587                ..Default::default()
1588            },
1589            Some(tree_sitter_php::language()),
1590        )
1591        .with_embedding_query(
1592            r#"
1593            (
1594                (comment)* @context
1595                .
1596                [
1597                    (function_definition
1598                        "function" @name
1599                        name: (_) @name
1600                        body: (_
1601                            "{" @keep
1602                            "}" @keep) @collapse
1603                        )
1604
1605                    (trait_declaration
1606                        "trait" @name
1607                        name: (_) @name)
1608
1609                    (method_declaration
1610                        "function" @name
1611                        name: (_) @name
1612                        body: (_
1613                            "{" @keep
1614                            "}" @keep) @collapse
1615                        )
1616
1617                    (interface_declaration
1618                        "interface" @name
1619                        name: (_) @name
1620                        )
1621
1622                    (enum_declaration
1623                        "enum" @name
1624                        name: (_) @name
1625                        )
1626
1627                ] @item
1628            )
1629            "#,
1630        )
1631        .unwrap(),
1632    )
1633}
1634
1635fn ruby_lang() -> Arc<Language> {
1636    Arc::new(
1637        Language::new(
1638            LanguageConfig {
1639                name: "Ruby".into(),
1640                path_suffixes: vec!["rb".into()],
1641                collapsed_placeholder: "# ...".to_string(),
1642                ..Default::default()
1643            },
1644            Some(tree_sitter_ruby::language()),
1645        )
1646        .with_embedding_query(
1647            r#"
1648            (
1649                (comment)* @context
1650                .
1651                [
1652                (module
1653                    "module" @name
1654                    name: (_) @name)
1655                (method
1656                    "def" @name
1657                    name: (_) @name
1658                    body: (body_statement) @collapse)
1659                (class
1660                    "class" @name
1661                    name: (_) @name)
1662                (singleton_method
1663                    "def" @name
1664                    object: (_) @name
1665                    "." @name
1666                    name: (_) @name
1667                    body: (body_statement) @collapse)
1668                ] @item
1669            )
1670            "#,
1671        )
1672        .unwrap(),
1673    )
1674}
1675
1676fn elixir_lang() -> Arc<Language> {
1677    Arc::new(
1678        Language::new(
1679            LanguageConfig {
1680                name: "Elixir".into(),
1681                path_suffixes: vec!["rs".into()],
1682                ..Default::default()
1683            },
1684            Some(tree_sitter_elixir::language()),
1685        )
1686        .with_embedding_query(
1687            r#"
1688            (
1689                (unary_operator
1690                    operator: "@"
1691                    operand: (call
1692                        target: (identifier) @unary
1693                        (#match? @unary "^(doc)$"))
1694                    ) @context
1695                .
1696                (call
1697                target: (identifier) @name
1698                (arguments
1699                [
1700                (identifier) @name
1701                (call
1702                target: (identifier) @name)
1703                (binary_operator
1704                left: (call
1705                target: (identifier) @name)
1706                operator: "when")
1707                ])
1708                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1709                )
1710
1711            (call
1712                target: (identifier) @name
1713                (arguments (alias) @name)
1714                (#match? @name "^(defmodule|defprotocol)$")) @item
1715            "#,
1716        )
1717        .unwrap(),
1718    )
1719}
1720
1721#[gpui::test]
1722fn test_subtract_ranges() {
1723    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1724
1725    assert_eq!(
1726        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1727        vec![1..4, 10..21]
1728    );
1729
1730    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1731}
1732
1733fn init_test(cx: &mut TestAppContext) {
1734    cx.update(|cx| {
1735        cx.set_global(SettingsStore::test(cx));
1736        settings::register::<SemanticIndexSettings>(cx);
1737        settings::register::<ProjectSettings>(cx);
1738    });
1739}