semantic_index_tests.rs

   1use crate::{
   2    embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
   3    embedding_queue::EmbeddingQueue,
   4    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
   5    semantic_index_settings::SemanticIndexSettings,
   6    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
   7};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{executor::Deterministic, Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use parking_lot::Mutex;
  13use pretty_assertions::assert_eq;
  14use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  15use rand::{rngs::StdRng, Rng};
  16use serde_json::json;
  17use settings::SettingsStore;
  18use std::{
  19    path::Path,
  20    sync::{
  21        atomic::{self, AtomicUsize},
  22        Arc,
  23    },
  24    time::{Instant, SystemTime},
  25};
  26use unindent::Unindent;
  27use util::RandomCharIter;
  28
  29#[ctor::ctor]
  30fn init_logger() {
  31    if std::env::var("RUST_LOG").is_ok() {
  32        env_logger::init();
  33    }
  34}
  35
  36#[gpui::test]
  37async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
  38    init_test(cx);
  39
  40    let fs = FakeFs::new(cx.background());
  41    fs.insert_tree(
  42        "/the-root",
  43        json!({
  44            "src": {
  45                "file1.rs": "
  46                    fn aaa() {
  47                        println!(\"aaaaaaaaaaaa!\");
  48                    }
  49
  50                    fn zzzzz() {
  51                        println!(\"SLEEPING\");
  52                    }
  53                ".unindent(),
  54                "file2.rs": "
  55                    fn bbb() {
  56                        println!(\"bbbbbbbbbbbbb!\");
  57                    }
  58                    struct pqpqpqp {}
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let semantic_index = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89
  90    let search_results = semantic_index.update(cx, |store, cx| {
  91        store.search_project(
  92            project.clone(),
  93            "aaaaaabbbbzz".to_string(),
  94            5,
  95            vec![],
  96            vec![],
  97            cx,
  98        )
  99    });
 100    let pending_file_count =
 101        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
 102    deterministic.run_until_parked();
 103    assert_eq!(*pending_file_count.borrow(), 3);
 104    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 105    assert_eq!(*pending_file_count.borrow(), 0);
 106
 107    let search_results = search_results.await.unwrap();
 108    assert_search_results(
 109        &search_results,
 110        &[
 111            (Path::new("src/file1.rs").into(), 0),
 112            (Path::new("src/file2.rs").into(), 0),
 113            (Path::new("src/file3.toml").into(), 0),
 114            (Path::new("src/file1.rs").into(), 45),
 115            (Path::new("src/file2.rs").into(), 45),
 116        ],
 117        cx,
 118    );
 119
 120    // Test Include Files Functonality
 121    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 122    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 123    let rust_only_search_results = semantic_index
 124        .update(cx, |store, cx| {
 125            store.search_project(
 126                project.clone(),
 127                "aaaaaabbbbzz".to_string(),
 128                5,
 129                include_files,
 130                vec![],
 131                cx,
 132            )
 133        })
 134        .await
 135        .unwrap();
 136
 137    assert_search_results(
 138        &rust_only_search_results,
 139        &[
 140            (Path::new("src/file1.rs").into(), 0),
 141            (Path::new("src/file2.rs").into(), 0),
 142            (Path::new("src/file1.rs").into(), 45),
 143            (Path::new("src/file2.rs").into(), 45),
 144        ],
 145        cx,
 146    );
 147
 148    let no_rust_search_results = semantic_index
 149        .update(cx, |store, cx| {
 150            store.search_project(
 151                project.clone(),
 152                "aaaaaabbbbzz".to_string(),
 153                5,
 154                vec![],
 155                exclude_files,
 156                cx,
 157            )
 158        })
 159        .await
 160        .unwrap();
 161
 162    assert_search_results(
 163        &no_rust_search_results,
 164        &[(Path::new("src/file3.toml").into(), 0)],
 165        cx,
 166    );
 167
 168    fs.save(
 169        "/the-root/src/file2.rs".as_ref(),
 170        &"
 171            fn dddd() { println!(\"ddddd!\"); }
 172            struct pqpqpqp {}
 173        "
 174        .unindent()
 175        .into(),
 176        Default::default(),
 177    )
 178    .await
 179    .unwrap();
 180
 181    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 182
 183    let prev_embedding_count = embedding_provider.embedding_count();
 184    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
 185    deterministic.run_until_parked();
 186    assert_eq!(*pending_file_count.borrow(), 1);
 187    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 188    assert_eq!(*pending_file_count.borrow(), 0);
 189    index.await.unwrap();
 190
 191    assert_eq!(
 192        embedding_provider.embedding_count() - prev_embedding_count,
 193        1
 194    );
 195}
 196
 197#[gpui::test(iterations = 10)]
 198async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 199    let (outstanding_job_count, _) = postage::watch::channel_with(0);
 200    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
 201
 202    let files = (1..=3)
 203        .map(|file_ix| FileToEmbed {
 204            worktree_id: 5,
 205            path: Path::new(&format!("path-{file_ix}")).into(),
 206            mtime: SystemTime::now(),
 207            spans: (0..rng.gen_range(4..22))
 208                .map(|document_ix| {
 209                    let content_len = rng.gen_range(10..100);
 210                    let content = RandomCharIter::new(&mut rng)
 211                        .with_simple_text()
 212                        .take(content_len)
 213                        .collect::<String>();
 214                    let digest = SpanDigest::from(content.as_str());
 215                    Span {
 216                        range: 0..10,
 217                        embedding: None,
 218                        name: format!("document {document_ix}"),
 219                        content,
 220                        digest,
 221                        token_count: rng.gen_range(10..30),
 222                    }
 223                })
 224                .collect(),
 225            job_handle: JobHandle::new(&outstanding_job_count),
 226        })
 227        .collect::<Vec<_>>();
 228
 229    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 230
 231    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
 232    for file in &files {
 233        queue.push(file.clone());
 234    }
 235    queue.flush();
 236
 237    cx.foreground().run_until_parked();
 238    let finished_files = queue.finished_files();
 239    let mut embedded_files: Vec<_> = files
 240        .iter()
 241        .map(|_| finished_files.try_recv().expect("no finished file"))
 242        .collect();
 243
 244    let expected_files: Vec<_> = files
 245        .iter()
 246        .map(|file| {
 247            let mut file = file.clone();
 248            for doc in &mut file.spans {
 249                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
 250            }
 251            file
 252        })
 253        .collect();
 254
 255    embedded_files.sort_by_key(|f| f.path.clone());
 256
 257    assert_eq!(embedded_files, expected_files);
 258}
 259
 260#[track_caller]
 261fn assert_search_results(
 262    actual: &[SearchResult],
 263    expected: &[(Arc<Path>, usize)],
 264    cx: &TestAppContext,
 265) {
 266    let actual = actual
 267        .iter()
 268        .map(|search_result| {
 269            search_result.buffer.read_with(cx, |buffer, _cx| {
 270                (
 271                    buffer.file().unwrap().path().clone(),
 272                    search_result.range.start.to_offset(buffer),
 273                )
 274            })
 275        })
 276        .collect::<Vec<_>>();
 277    assert_eq!(actual, expected);
 278}
 279
 280#[gpui::test]
 281async fn test_code_context_retrieval_rust() {
 282    let language = rust_lang();
 283    let embedding_provider = Arc::new(DummyEmbeddings {});
 284    let mut retriever = CodeContextRetriever::new(embedding_provider);
 285
 286    let text = "
 287        /// A doc comment
 288        /// that spans multiple lines
 289        #[gpui::test]
 290        fn a() {
 291            b
 292        }
 293
 294        impl C for D {
 295        }
 296
 297        impl E {
 298            // This is also a preceding comment
 299            pub fn function_1() -> Option<()> {
 300                todo!();
 301            }
 302
 303            // This is a preceding comment
 304            fn function_2() -> Result<()> {
 305                todo!();
 306            }
 307        }
 308    "
 309    .unindent();
 310
 311    let documents = retriever.parse_file(&text, language).unwrap();
 312
 313    assert_documents_eq(
 314        &documents,
 315        &[
 316            (
 317                "
 318                /// A doc comment
 319                /// that spans multiple lines
 320                #[gpui::test]
 321                fn a() {
 322                    b
 323                }"
 324                .unindent(),
 325                text.find("fn a").unwrap(),
 326            ),
 327            (
 328                "
 329                impl C for D {
 330                }"
 331                .unindent(),
 332                text.find("impl C").unwrap(),
 333            ),
 334            (
 335                "
 336                impl E {
 337                    // This is also a preceding comment
 338                    pub fn function_1() -> Option<()> { /* ... */ }
 339
 340                    // This is a preceding comment
 341                    fn function_2() -> Result<()> { /* ... */ }
 342                }"
 343                .unindent(),
 344                text.find("impl E").unwrap(),
 345            ),
 346            (
 347                "
 348                // This is also a preceding comment
 349                pub fn function_1() -> Option<()> {
 350                    todo!();
 351                }"
 352                .unindent(),
 353                text.find("pub fn function_1").unwrap(),
 354            ),
 355            (
 356                "
 357                // This is a preceding comment
 358                fn function_2() -> Result<()> {
 359                    todo!();
 360                }"
 361                .unindent(),
 362                text.find("fn function_2").unwrap(),
 363            ),
 364        ],
 365    );
 366}
 367
 368#[gpui::test]
 369async fn test_code_context_retrieval_json() {
 370    let language = json_lang();
 371    let embedding_provider = Arc::new(DummyEmbeddings {});
 372    let mut retriever = CodeContextRetriever::new(embedding_provider);
 373
 374    let text = r#"
 375        {
 376            "array": [1, 2, 3, 4],
 377            "string": "abcdefg",
 378            "nested_object": {
 379                "array_2": [5, 6, 7, 8],
 380                "string_2": "hijklmnop",
 381                "boolean": true,
 382                "none": null
 383            }
 384        }
 385    "#
 386    .unindent();
 387
 388    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 389
 390    assert_documents_eq(
 391        &documents,
 392        &[(
 393            r#"
 394                {
 395                    "array": [],
 396                    "string": "",
 397                    "nested_object": {
 398                        "array_2": [],
 399                        "string_2": "",
 400                        "boolean": true,
 401                        "none": null
 402                    }
 403                }"#
 404            .unindent(),
 405            text.find("{").unwrap(),
 406        )],
 407    );
 408
 409    let text = r#"
 410        [
 411            {
 412                "name": "somebody",
 413                "age": 42
 414            },
 415            {
 416                "name": "somebody else",
 417                "age": 43
 418            }
 419        ]
 420    "#
 421    .unindent();
 422
 423    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 424
 425    assert_documents_eq(
 426        &documents,
 427        &[(
 428            r#"
 429            [{
 430                    "name": "",
 431                    "age": 42
 432                }]"#
 433            .unindent(),
 434            text.find("[").unwrap(),
 435        )],
 436    );
 437}
 438
 439fn assert_documents_eq(
 440    documents: &[Span],
 441    expected_contents_and_start_offsets: &[(String, usize)],
 442) {
 443    assert_eq!(
 444        documents
 445            .iter()
 446            .map(|document| (document.content.clone(), document.range.start))
 447            .collect::<Vec<_>>(),
 448        expected_contents_and_start_offsets
 449    );
 450}
 451
 452#[gpui::test]
 453async fn test_code_context_retrieval_javascript() {
 454    let language = js_lang();
 455    let embedding_provider = Arc::new(DummyEmbeddings {});
 456    let mut retriever = CodeContextRetriever::new(embedding_provider);
 457
 458    let text = "
 459        /* globals importScripts, backend */
 460        function _authorize() {}
 461
 462        /**
 463         * Sometimes the frontend build is way faster than backend.
 464         */
 465        export async function authorizeBank() {
 466            _authorize(pushModal, upgradingAccountId, {});
 467        }
 468
 469        export class SettingsPage {
 470            /* This is a test setting */
 471            constructor(page) {
 472                this.page = page;
 473            }
 474        }
 475
 476        /* This is a test comment */
 477        class TestClass {}
 478
 479        /* Schema for editor_events in Clickhouse. */
 480        export interface ClickhouseEditorEvent {
 481            installation_id: string
 482            operation: string
 483        }
 484        "
 485    .unindent();
 486
 487    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 488
 489    assert_documents_eq(
 490        &documents,
 491        &[
 492            (
 493                "
 494            /* globals importScripts, backend */
 495            function _authorize() {}"
 496                    .unindent(),
 497                37,
 498            ),
 499            (
 500                "
 501            /**
 502             * Sometimes the frontend build is way faster than backend.
 503             */
 504            export async function authorizeBank() {
 505                _authorize(pushModal, upgradingAccountId, {});
 506            }"
 507                .unindent(),
 508                131,
 509            ),
 510            (
 511                "
 512                export class SettingsPage {
 513                    /* This is a test setting */
 514                    constructor(page) {
 515                        this.page = page;
 516                    }
 517                }"
 518                .unindent(),
 519                225,
 520            ),
 521            (
 522                "
 523                /* This is a test setting */
 524                constructor(page) {
 525                    this.page = page;
 526                }"
 527                .unindent(),
 528                290,
 529            ),
 530            (
 531                "
 532                /* This is a test comment */
 533                class TestClass {}"
 534                    .unindent(),
 535                374,
 536            ),
 537            (
 538                "
 539                /* Schema for editor_events in Clickhouse. */
 540                export interface ClickhouseEditorEvent {
 541                    installation_id: string
 542                    operation: string
 543                }"
 544                .unindent(),
 545                440,
 546            ),
 547        ],
 548    )
 549}
 550
 551#[gpui::test]
 552async fn test_code_context_retrieval_lua() {
 553    let language = lua_lang();
 554    let embedding_provider = Arc::new(DummyEmbeddings {});
 555    let mut retriever = CodeContextRetriever::new(embedding_provider);
 556
 557    let text = r#"
 558        -- Creates a new class
 559        -- @param baseclass The Baseclass of this class, or nil.
 560        -- @return A new class reference.
 561        function classes.class(baseclass)
 562            -- Create the class definition and metatable.
 563            local classdef = {}
 564            -- Find the super class, either Object or user-defined.
 565            baseclass = baseclass or classes.Object
 566            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 567            setmetatable(classdef, { __index = baseclass })
 568            -- All class instances have a reference to the class object.
 569            classdef.class = classdef
 570            --- Recursivly allocates the inheritance tree of the instance.
 571            -- @param mastertable The 'root' of the inheritance tree.
 572            -- @return Returns the instance with the allocated inheritance tree.
 573            function classdef.alloc(mastertable)
 574                -- All class instances have a reference to a superclass object.
 575                local instance = { super = baseclass.alloc(mastertable) }
 576                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 577                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 578                return instance
 579            end
 580        end
 581        "#.unindent();
 582
 583    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 584
 585    assert_documents_eq(
 586        &documents,
 587        &[
 588            (r#"
 589                -- Creates a new class
 590                -- @param baseclass The Baseclass of this class, or nil.
 591                -- @return A new class reference.
 592                function classes.class(baseclass)
 593                    -- Create the class definition and metatable.
 594                    local classdef = {}
 595                    -- Find the super class, either Object or user-defined.
 596                    baseclass = baseclass or classes.Object
 597                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 598                    setmetatable(classdef, { __index = baseclass })
 599                    -- All class instances have a reference to the class object.
 600                    classdef.class = classdef
 601                    --- Recursivly allocates the inheritance tree of the instance.
 602                    -- @param mastertable The 'root' of the inheritance tree.
 603                    -- @return Returns the instance with the allocated inheritance tree.
 604                    function classdef.alloc(mastertable)
 605                        --[ ... ]--
 606                        --[ ... ]--
 607                    end
 608                end"#.unindent(),
 609            114),
 610            (r#"
 611            --- Recursivly allocates the inheritance tree of the instance.
 612            -- @param mastertable The 'root' of the inheritance tree.
 613            -- @return Returns the instance with the allocated inheritance tree.
 614            function classdef.alloc(mastertable)
 615                -- All class instances have a reference to a superclass object.
 616                local instance = { super = baseclass.alloc(mastertable) }
 617                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 618                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 619                return instance
 620            end"#.unindent(), 809),
 621        ]
 622    );
 623}
 624
 625#[gpui::test]
 626async fn test_code_context_retrieval_elixir() {
 627    let language = elixir_lang();
 628    let embedding_provider = Arc::new(DummyEmbeddings {});
 629    let mut retriever = CodeContextRetriever::new(embedding_provider);
 630
 631    let text = r#"
 632        defmodule File.Stream do
 633            @moduledoc """
 634            Defines a `File.Stream` struct returned by `File.stream!/3`.
 635
 636            The following fields are public:
 637
 638            * `path`          - the file path
 639            * `modes`         - the file modes
 640            * `raw`           - a boolean indicating if bin functions should be used
 641            * `line_or_bytes` - if reading should read lines or a given number of bytes
 642            * `node`          - the node the file belongs to
 643
 644            """
 645
 646            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 647
 648            @type t :: %__MODULE__{}
 649
 650            @doc false
 651            def __build__(path, modes, line_or_bytes) do
 652            raw = :lists.keyfind(:encoding, 1, modes) == false
 653
 654            modes =
 655                case raw do
 656                true ->
 657                    case :lists.keyfind(:read_ahead, 1, modes) do
 658                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 659                    {:read_ahead, _} -> [:raw | modes]
 660                    false -> [:raw, :read_ahead | modes]
 661                    end
 662
 663                false ->
 664                    modes
 665                end
 666
 667            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 668
 669            end"#
 670    .unindent();
 671
 672    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 673
 674    assert_documents_eq(
 675        &documents,
 676        &[(
 677            r#"
 678        defmodule File.Stream do
 679            @moduledoc """
 680            Defines a `File.Stream` struct returned by `File.stream!/3`.
 681
 682            The following fields are public:
 683
 684            * `path`          - the file path
 685            * `modes`         - the file modes
 686            * `raw`           - a boolean indicating if bin functions should be used
 687            * `line_or_bytes` - if reading should read lines or a given number of bytes
 688            * `node`          - the node the file belongs to
 689
 690            """
 691
 692            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 693
 694            @type t :: %__MODULE__{}
 695
 696            @doc false
 697            def __build__(path, modes, line_or_bytes) do
 698            raw = :lists.keyfind(:encoding, 1, modes) == false
 699
 700            modes =
 701                case raw do
 702                true ->
 703                    case :lists.keyfind(:read_ahead, 1, modes) do
 704                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 705                    {:read_ahead, _} -> [:raw | modes]
 706                    false -> [:raw, :read_ahead | modes]
 707                    end
 708
 709                false ->
 710                    modes
 711                end
 712
 713            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 714
 715            end"#
 716                .unindent(),
 717            0,
 718        ),(r#"
 719            @doc false
 720            def __build__(path, modes, line_or_bytes) do
 721            raw = :lists.keyfind(:encoding, 1, modes) == false
 722
 723            modes =
 724                case raw do
 725                true ->
 726                    case :lists.keyfind(:read_ahead, 1, modes) do
 727                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 728                    {:read_ahead, _} -> [:raw | modes]
 729                    false -> [:raw, :read_ahead | modes]
 730                    end
 731
 732                false ->
 733                    modes
 734                end
 735
 736            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 737
 738            end"#.unindent(), 574)],
 739    );
 740}
 741
 742#[gpui::test]
 743async fn test_code_context_retrieval_cpp() {
 744    let language = cpp_lang();
 745    let embedding_provider = Arc::new(DummyEmbeddings {});
 746    let mut retriever = CodeContextRetriever::new(embedding_provider);
 747
 748    let text = "
 749    /**
 750     * @brief Main function
 751     * @returns 0 on exit
 752     */
 753    int main() { return 0; }
 754
 755    /**
 756    * This is a test comment
 757    */
 758    class MyClass {       // The class
 759        public:           // Access specifier
 760        int myNum;        // Attribute (int variable)
 761        string myString;  // Attribute (string variable)
 762    };
 763
 764    // This is a test comment
 765    enum Color { red, green, blue };
 766
 767    /** This is a preceding block comment
 768     * This is the second line
 769     */
 770    struct {           // Structure declaration
 771        int myNum;       // Member (int variable)
 772        string myString; // Member (string variable)
 773    } myStructure;
 774
 775    /**
 776     * @brief Matrix class.
 777     */
 778    template <typename T,
 779              typename = typename std::enable_if<
 780                std::is_integral<T>::value || std::is_floating_point<T>::value,
 781                bool>::type>
 782    class Matrix2 {
 783        std::vector<std::vector<T>> _mat;
 784
 785        public:
 786            /**
 787            * @brief Constructor
 788            * @tparam Integer ensuring integers are being evaluated and not other
 789            * data types.
 790            * @param size denoting the size of Matrix as size x size
 791            */
 792            template <typename Integer,
 793                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 794                    Integer>::type>
 795            explicit Matrix(const Integer size) {
 796                for (size_t i = 0; i < size; ++i) {
 797                    _mat.emplace_back(std::vector<T>(size, 0));
 798                }
 799            }
 800    }"
 801    .unindent();
 802
 803    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 804
 805    assert_documents_eq(
 806        &documents,
 807        &[
 808            (
 809                "
 810        /**
 811         * @brief Main function
 812         * @returns 0 on exit
 813         */
 814        int main() { return 0; }"
 815                    .unindent(),
 816                54,
 817            ),
 818            (
 819                "
 820                /**
 821                * This is a test comment
 822                */
 823                class MyClass {       // The class
 824                    public:           // Access specifier
 825                    int myNum;        // Attribute (int variable)
 826                    string myString;  // Attribute (string variable)
 827                }"
 828                .unindent(),
 829                112,
 830            ),
 831            (
 832                "
 833                // This is a test comment
 834                enum Color { red, green, blue }"
 835                    .unindent(),
 836                322,
 837            ),
 838            (
 839                "
 840                /** This is a preceding block comment
 841                 * This is the second line
 842                 */
 843                struct {           // Structure declaration
 844                    int myNum;       // Member (int variable)
 845                    string myString; // Member (string variable)
 846                } myStructure;"
 847                    .unindent(),
 848                425,
 849            ),
 850            (
 851                "
 852                /**
 853                 * @brief Matrix class.
 854                 */
 855                template <typename T,
 856                          typename = typename std::enable_if<
 857                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 858                            bool>::type>
 859                class Matrix2 {
 860                    std::vector<std::vector<T>> _mat;
 861
 862                    public:
 863                        /**
 864                        * @brief Constructor
 865                        * @tparam Integer ensuring integers are being evaluated and not other
 866                        * data types.
 867                        * @param size denoting the size of Matrix as size x size
 868                        */
 869                        template <typename Integer,
 870                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 871                                Integer>::type>
 872                        explicit Matrix(const Integer size) {
 873                            for (size_t i = 0; i < size; ++i) {
 874                                _mat.emplace_back(std::vector<T>(size, 0));
 875                            }
 876                        }
 877                }"
 878                .unindent(),
 879                612,
 880            ),
 881            (
 882                "
 883                explicit Matrix(const Integer size) {
 884                    for (size_t i = 0; i < size; ++i) {
 885                        _mat.emplace_back(std::vector<T>(size, 0));
 886                    }
 887                }"
 888                .unindent(),
 889                1226,
 890            ),
 891        ],
 892    );
 893}
 894
 895#[gpui::test]
 896async fn test_code_context_retrieval_ruby() {
 897    let language = ruby_lang();
 898    let embedding_provider = Arc::new(DummyEmbeddings {});
 899    let mut retriever = CodeContextRetriever::new(embedding_provider);
 900
 901    let text = r#"
 902        # This concern is inspired by "sudo mode" on GitHub. It
 903        # is a way to re-authenticate a user before allowing them
 904        # to see or perform an action.
 905        #
 906        # Add `before_action :require_challenge!` to actions you
 907        # want to protect.
 908        #
 909        # The user will be shown a page to enter the challenge (which
 910        # is either the password, or just the username when no
 911        # password exists). Upon passing, there is a grace period
 912        # during which no challenge will be asked from the user.
 913        #
 914        # Accessing challenge-protected resources during the grace
 915        # period will refresh the grace period.
 916        module ChallengableConcern
 917            extend ActiveSupport::Concern
 918
 919            CHALLENGE_TIMEOUT = 1.hour.freeze
 920
 921            def require_challenge!
 922                return if skip_challenge?
 923
 924                if challenge_passed_recently?
 925                    session[:challenge_passed_at] = Time.now.utc
 926                    return
 927                end
 928
 929                @challenge = Form::Challenge.new(return_to: request.url)
 930
 931                if params.key?(:form_challenge)
 932                    if challenge_passed?
 933                        session[:challenge_passed_at] = Time.now.utc
 934                    else
 935                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 936                        render_challenge
 937                    end
 938                else
 939                    render_challenge
 940                end
 941            end
 942
 943            def challenge_passed?
 944                current_user.valid_password?(challenge_params[:current_password])
 945            end
 946        end
 947
 948        class Animal
 949            include Comparable
 950
 951            attr_reader :legs
 952
 953            def initialize(name, legs)
 954                @name, @legs = name, legs
 955            end
 956
 957            def <=>(other)
 958                legs <=> other.legs
 959            end
 960        end
 961
 962        # Singleton method for car object
 963        def car.wheels
 964            puts "There are four wheels"
 965        end"#
 966        .unindent();
 967
 968    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 969
 970    assert_documents_eq(
 971        &documents,
 972        &[
 973            (
 974                r#"
 975        # This concern is inspired by "sudo mode" on GitHub. It
 976        # is a way to re-authenticate a user before allowing them
 977        # to see or perform an action.
 978        #
 979        # Add `before_action :require_challenge!` to actions you
 980        # want to protect.
 981        #
 982        # The user will be shown a page to enter the challenge (which
 983        # is either the password, or just the username when no
 984        # password exists). Upon passing, there is a grace period
 985        # during which no challenge will be asked from the user.
 986        #
 987        # Accessing challenge-protected resources during the grace
 988        # period will refresh the grace period.
 989        module ChallengableConcern
 990            extend ActiveSupport::Concern
 991
 992            CHALLENGE_TIMEOUT = 1.hour.freeze
 993
 994            def require_challenge!
 995                # ...
 996            end
 997
 998            def challenge_passed?
 999                # ...
1000            end
1001        end"#
1002                    .unindent(),
1003                558,
1004            ),
1005            (
1006                r#"
1007            def require_challenge!
1008                return if skip_challenge?
1009
1010                if challenge_passed_recently?
1011                    session[:challenge_passed_at] = Time.now.utc
1012                    return
1013                end
1014
1015                @challenge = Form::Challenge.new(return_to: request.url)
1016
1017                if params.key?(:form_challenge)
1018                    if challenge_passed?
1019                        session[:challenge_passed_at] = Time.now.utc
1020                    else
1021                        flash.now[:alert] = I18n.t('challenge.invalid_password')
1022                        render_challenge
1023                    end
1024                else
1025                    render_challenge
1026                end
1027            end"#
1028                    .unindent(),
1029                663,
1030            ),
1031            (
1032                r#"
1033                def challenge_passed?
1034                    current_user.valid_password?(challenge_params[:current_password])
1035                end"#
1036                    .unindent(),
1037                1254,
1038            ),
1039            (
1040                r#"
1041                class Animal
1042                    include Comparable
1043
1044                    attr_reader :legs
1045
1046                    def initialize(name, legs)
1047                        # ...
1048                    end
1049
1050                    def <=>(other)
1051                        # ...
1052                    end
1053                end"#
1054                    .unindent(),
1055                1363,
1056            ),
1057            (
1058                r#"
1059                def initialize(name, legs)
1060                    @name, @legs = name, legs
1061                end"#
1062                    .unindent(),
1063                1427,
1064            ),
1065            (
1066                r#"
1067                def <=>(other)
1068                    legs <=> other.legs
1069                end"#
1070                    .unindent(),
1071                1501,
1072            ),
1073            (
1074                r#"
1075                # Singleton method for car object
1076                def car.wheels
1077                    puts "There are four wheels"
1078                end"#
1079                    .unindent(),
1080                1591,
1081            ),
1082        ],
1083    );
1084}
1085
1086#[gpui::test]
1087async fn test_code_context_retrieval_php() {
1088    let language = php_lang();
1089    let embedding_provider = Arc::new(DummyEmbeddings {});
1090    let mut retriever = CodeContextRetriever::new(embedding_provider);
1091
1092    let text = r#"
1093        <?php
1094
1095        namespace LevelUp\Experience\Concerns;
1096
1097        /*
1098        This is a multiple-lines comment block
1099        that spans over multiple
1100        lines
1101        */
1102        function functionName() {
1103            echo "Hello world!";
1104        }
1105
1106        trait HasAchievements
1107        {
1108            /**
1109            * @throws \Exception
1110            */
1111            public function grantAchievement(Achievement $achievement, $progress = null): void
1112            {
1113                if ($progress > 100) {
1114                    throw new Exception(message: 'Progress cannot be greater than 100');
1115                }
1116
1117                if ($this->achievements()->find($achievement->id)) {
1118                    throw new Exception(message: 'User already has this Achievement');
1119                }
1120
1121                $this->achievements()->attach($achievement, [
1122                    'progress' => $progress ?? null,
1123                ]);
1124
1125                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1126            }
1127
1128            public function achievements(): BelongsToMany
1129            {
1130                return $this->belongsToMany(related: Achievement::class)
1131                ->withPivot(columns: 'progress')
1132                ->where('is_secret', false)
1133                ->using(AchievementUser::class);
1134            }
1135        }
1136
1137        interface Multiplier
1138        {
1139            public function qualifies(array $data): bool;
1140
1141            public function setMultiplier(): int;
1142        }
1143
1144        enum AuditType: string
1145        {
1146            case Add = 'add';
1147            case Remove = 'remove';
1148            case Reset = 'reset';
1149            case LevelUp = 'level_up';
1150        }
1151
1152        ?>"#
1153    .unindent();
1154
1155    let documents = retriever.parse_file(&text, language.clone()).unwrap();
1156
1157    assert_documents_eq(
1158        &documents,
1159        &[
1160            (
1161                r#"
1162        /*
1163        This is a multiple-lines comment block
1164        that spans over multiple
1165        lines
1166        */
1167        function functionName() {
1168            echo "Hello world!";
1169        }"#
1170                .unindent(),
1171                123,
1172            ),
1173            (
1174                r#"
1175        trait HasAchievements
1176        {
1177            /**
1178            * @throws \Exception
1179            */
1180            public function grantAchievement(Achievement $achievement, $progress = null): void
1181            {/* ... */}
1182
1183            public function achievements(): BelongsToMany
1184            {/* ... */}
1185        }"#
1186                .unindent(),
1187                177,
1188            ),
1189            (r#"
1190            /**
1191            * @throws \Exception
1192            */
1193            public function grantAchievement(Achievement $achievement, $progress = null): void
1194            {
1195                if ($progress > 100) {
1196                    throw new Exception(message: 'Progress cannot be greater than 100');
1197                }
1198
1199                if ($this->achievements()->find($achievement->id)) {
1200                    throw new Exception(message: 'User already has this Achievement');
1201                }
1202
1203                $this->achievements()->attach($achievement, [
1204                    'progress' => $progress ?? null,
1205                ]);
1206
1207                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
1208            }"#.unindent(), 245),
1209            (r#"
1210                public function achievements(): BelongsToMany
1211                {
1212                    return $this->belongsToMany(related: Achievement::class)
1213                    ->withPivot(columns: 'progress')
1214                    ->where('is_secret', false)
1215                    ->using(AchievementUser::class);
1216                }"#.unindent(), 902),
1217            (r#"
1218                interface Multiplier
1219                {
1220                    public function qualifies(array $data): bool;
1221
1222                    public function setMultiplier(): int;
1223                }"#.unindent(),
1224                1146),
1225            (r#"
1226                enum AuditType: string
1227                {
1228                    case Add = 'add';
1229                    case Remove = 'remove';
1230                    case Reset = 'reset';
1231                    case LevelUp = 'level_up';
1232                }"#.unindent(), 1265)
1233        ],
1234    );
1235}
1236
1237#[derive(Default)]
1238struct FakeEmbeddingProvider {
1239    embedding_count: AtomicUsize,
1240}
1241
1242impl FakeEmbeddingProvider {
1243    fn embedding_count(&self) -> usize {
1244        self.embedding_count.load(atomic::Ordering::SeqCst)
1245    }
1246
1247    fn embed_sync(&self, span: &str) -> Embedding {
1248        let mut result = vec![1.0; 26];
1249        for letter in span.chars() {
1250            let letter = letter.to_ascii_lowercase();
1251            if letter as u32 >= 'a' as u32 {
1252                let ix = (letter as u32) - ('a' as u32);
1253                if ix < 26 {
1254                    result[ix as usize] += 1.0;
1255                }
1256            }
1257        }
1258
1259        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1260        for x in &mut result {
1261            *x /= norm;
1262        }
1263
1264        result.into()
1265    }
1266}
1267
1268#[async_trait]
1269impl EmbeddingProvider for FakeEmbeddingProvider {
1270    fn truncate(&self, span: &str) -> (String, usize) {
1271        (span.to_string(), 1)
1272    }
1273
1274    fn max_tokens_per_batch(&self) -> usize {
1275        200
1276    }
1277
1278    fn rate_limit_expiration(&self) -> Option<Instant> {
1279        None
1280    }
1281
1282    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
1283        self.embedding_count
1284            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1285        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
1286    }
1287}
1288
1289fn js_lang() -> Arc<Language> {
1290    Arc::new(
1291        Language::new(
1292            LanguageConfig {
1293                name: "Javascript".into(),
1294                path_suffixes: vec!["js".into()],
1295                ..Default::default()
1296            },
1297            Some(tree_sitter_typescript::language_tsx()),
1298        )
1299        .with_embedding_query(
1300            &r#"
1301
1302            (
1303                (comment)* @context
1304                .
1305                [
1306                (export_statement
1307                    (function_declaration
1308                        "async"? @name
1309                        "function" @name
1310                        name: (_) @name))
1311                (function_declaration
1312                    "async"? @name
1313                    "function" @name
1314                    name: (_) @name)
1315                ] @item
1316            )
1317
1318            (
1319                (comment)* @context
1320                .
1321                [
1322                (export_statement
1323                    (class_declaration
1324                        "class" @name
1325                        name: (_) @name))
1326                (class_declaration
1327                    "class" @name
1328                    name: (_) @name)
1329                ] @item
1330            )
1331
1332            (
1333                (comment)* @context
1334                .
1335                [
1336                (export_statement
1337                    (interface_declaration
1338                        "interface" @name
1339                        name: (_) @name))
1340                (interface_declaration
1341                    "interface" @name
1342                    name: (_) @name)
1343                ] @item
1344            )
1345
1346            (
1347                (comment)* @context
1348                .
1349                [
1350                (export_statement
1351                    (enum_declaration
1352                        "enum" @name
1353                        name: (_) @name))
1354                (enum_declaration
1355                    "enum" @name
1356                    name: (_) @name)
1357                ] @item
1358            )
1359
1360            (
1361                (comment)* @context
1362                .
1363                (method_definition
1364                    [
1365                        "get"
1366                        "set"
1367                        "async"
1368                        "*"
1369                        "static"
1370                    ]* @name
1371                    name: (_) @name) @item
1372            )
1373
1374                    "#
1375            .unindent(),
1376        )
1377        .unwrap(),
1378    )
1379}
1380
1381fn rust_lang() -> Arc<Language> {
1382    Arc::new(
1383        Language::new(
1384            LanguageConfig {
1385                name: "Rust".into(),
1386                path_suffixes: vec!["rs".into()],
1387                collapsed_placeholder: " /* ... */ ".to_string(),
1388                ..Default::default()
1389            },
1390            Some(tree_sitter_rust::language()),
1391        )
1392        .with_embedding_query(
1393            r#"
1394            (
1395                [(line_comment) (attribute_item)]* @context
1396                .
1397                [
1398                    (struct_item
1399                        name: (_) @name)
1400
1401                    (enum_item
1402                        name: (_) @name)
1403
1404                    (impl_item
1405                        trait: (_)? @name
1406                        "for"? @name
1407                        type: (_) @name)
1408
1409                    (trait_item
1410                        name: (_) @name)
1411
1412                    (function_item
1413                        name: (_) @name
1414                        body: (block
1415                            "{" @keep
1416                            "}" @keep) @collapse)
1417
1418                    (macro_definition
1419                        name: (_) @name)
1420                ] @item
1421            )
1422            "#,
1423        )
1424        .unwrap(),
1425    )
1426}
1427
1428fn json_lang() -> Arc<Language> {
1429    Arc::new(
1430        Language::new(
1431            LanguageConfig {
1432                name: "JSON".into(),
1433                path_suffixes: vec!["json".into()],
1434                ..Default::default()
1435            },
1436            Some(tree_sitter_json::language()),
1437        )
1438        .with_embedding_query(
1439            r#"
1440            (document) @item
1441
1442            (array
1443                "[" @keep
1444                .
1445                (object)? @keep
1446                "]" @keep) @collapse
1447
1448            (pair value: (string
1449                "\"" @keep
1450                "\"" @keep) @collapse)
1451            "#,
1452        )
1453        .unwrap(),
1454    )
1455}
1456
1457fn toml_lang() -> Arc<Language> {
1458    Arc::new(Language::new(
1459        LanguageConfig {
1460            name: "TOML".into(),
1461            path_suffixes: vec!["toml".into()],
1462            ..Default::default()
1463        },
1464        Some(tree_sitter_toml::language()),
1465    ))
1466}
1467
1468fn cpp_lang() -> Arc<Language> {
1469    Arc::new(
1470        Language::new(
1471            LanguageConfig {
1472                name: "CPP".into(),
1473                path_suffixes: vec!["cpp".into()],
1474                ..Default::default()
1475            },
1476            Some(tree_sitter_cpp::language()),
1477        )
1478        .with_embedding_query(
1479            r#"
1480            (
1481                (comment)* @context
1482                .
1483                (function_definition
1484                    (type_qualifier)? @name
1485                    type: (_)? @name
1486                    declarator: [
1487                        (function_declarator
1488                            declarator: (_) @name)
1489                        (pointer_declarator
1490                            "*" @name
1491                            declarator: (function_declarator
1492                            declarator: (_) @name))
1493                        (pointer_declarator
1494                            "*" @name
1495                            declarator: (pointer_declarator
1496                                "*" @name
1497                            declarator: (function_declarator
1498                                declarator: (_) @name)))
1499                        (reference_declarator
1500                            ["&" "&&"] @name
1501                            (function_declarator
1502                            declarator: (_) @name))
1503                    ]
1504                    (type_qualifier)? @name) @item
1505                )
1506
1507            (
1508                (comment)* @context
1509                .
1510                (template_declaration
1511                    (class_specifier
1512                        "class" @name
1513                        name: (_) @name)
1514                        ) @item
1515            )
1516
1517            (
1518                (comment)* @context
1519                .
1520                (class_specifier
1521                    "class" @name
1522                    name: (_) @name) @item
1523                )
1524
1525            (
1526                (comment)* @context
1527                .
1528                (enum_specifier
1529                    "enum" @name
1530                    name: (_) @name) @item
1531                )
1532
1533            (
1534                (comment)* @context
1535                .
1536                (declaration
1537                    type: (struct_specifier
1538                    "struct" @name)
1539                    declarator: (_) @name) @item
1540            )
1541
1542            "#,
1543        )
1544        .unwrap(),
1545    )
1546}
1547
1548fn lua_lang() -> Arc<Language> {
1549    Arc::new(
1550        Language::new(
1551            LanguageConfig {
1552                name: "Lua".into(),
1553                path_suffixes: vec!["lua".into()],
1554                collapsed_placeholder: "--[ ... ]--".to_string(),
1555                ..Default::default()
1556            },
1557            Some(tree_sitter_lua::language()),
1558        )
1559        .with_embedding_query(
1560            r#"
1561            (
1562                (comment)* @context
1563                .
1564                (function_declaration
1565                    "function" @name
1566                    name: (_) @name
1567                    (comment)* @collapse
1568                    body: (block) @collapse
1569                ) @item
1570            )
1571        "#,
1572        )
1573        .unwrap(),
1574    )
1575}
1576
1577fn php_lang() -> Arc<Language> {
1578    Arc::new(
1579        Language::new(
1580            LanguageConfig {
1581                name: "PHP".into(),
1582                path_suffixes: vec!["php".into()],
1583                collapsed_placeholder: "/* ... */".into(),
1584                ..Default::default()
1585            },
1586            Some(tree_sitter_php::language()),
1587        )
1588        .with_embedding_query(
1589            r#"
1590            (
1591                (comment)* @context
1592                .
1593                [
1594                    (function_definition
1595                        "function" @name
1596                        name: (_) @name
1597                        body: (_
1598                            "{" @keep
1599                            "}" @keep) @collapse
1600                        )
1601
1602                    (trait_declaration
1603                        "trait" @name
1604                        name: (_) @name)
1605
1606                    (method_declaration
1607                        "function" @name
1608                        name: (_) @name
1609                        body: (_
1610                            "{" @keep
1611                            "}" @keep) @collapse
1612                        )
1613
1614                    (interface_declaration
1615                        "interface" @name
1616                        name: (_) @name
1617                        )
1618
1619                    (enum_declaration
1620                        "enum" @name
1621                        name: (_) @name
1622                        )
1623
1624                ] @item
1625            )
1626            "#,
1627        )
1628        .unwrap(),
1629    )
1630}
1631
1632fn ruby_lang() -> Arc<Language> {
1633    Arc::new(
1634        Language::new(
1635            LanguageConfig {
1636                name: "Ruby".into(),
1637                path_suffixes: vec!["rb".into()],
1638                collapsed_placeholder: "# ...".to_string(),
1639                ..Default::default()
1640            },
1641            Some(tree_sitter_ruby::language()),
1642        )
1643        .with_embedding_query(
1644            r#"
1645            (
1646                (comment)* @context
1647                .
1648                [
1649                (module
1650                    "module" @name
1651                    name: (_) @name)
1652                (method
1653                    "def" @name
1654                    name: (_) @name
1655                    body: (body_statement) @collapse)
1656                (class
1657                    "class" @name
1658                    name: (_) @name)
1659                (singleton_method
1660                    "def" @name
1661                    object: (_) @name
1662                    "." @name
1663                    name: (_) @name
1664                    body: (body_statement) @collapse)
1665                ] @item
1666            )
1667            "#,
1668        )
1669        .unwrap(),
1670    )
1671}
1672
1673fn elixir_lang() -> Arc<Language> {
1674    Arc::new(
1675        Language::new(
1676            LanguageConfig {
1677                name: "Elixir".into(),
1678                path_suffixes: vec!["rs".into()],
1679                ..Default::default()
1680            },
1681            Some(tree_sitter_elixir::language()),
1682        )
1683        .with_embedding_query(
1684            r#"
1685            (
1686                (unary_operator
1687                    operator: "@"
1688                    operand: (call
1689                        target: (identifier) @unary
1690                        (#match? @unary "^(doc)$"))
1691                    ) @context
1692                .
1693                (call
1694                target: (identifier) @name
1695                (arguments
1696                [
1697                (identifier) @name
1698                (call
1699                target: (identifier) @name)
1700                (binary_operator
1701                left: (call
1702                target: (identifier) @name)
1703                operator: "when")
1704                ])
1705                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1706                )
1707
1708            (call
1709                target: (identifier) @name
1710                (arguments (alias) @name)
1711                (#match? @name "^(defmodule|defprotocol)$")) @item
1712            "#,
1713        )
1714        .unwrap(),
1715    )
1716}
1717
1718#[gpui::test]
1719fn test_subtract_ranges() {
1720    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1721
1722    assert_eq!(
1723        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1724        vec![1..4, 10..21]
1725    );
1726
1727    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1728}
1729
1730fn init_test(cx: &mut TestAppContext) {
1731    cx.update(|cx| {
1732        cx.set_global(SettingsStore::test(cx));
1733        settings::register::<SemanticIndexSettings>(cx);
1734        settings::register::<ProjectSettings>(cx);
1735    });
1736}