semantic_index_tests.rs

   1use crate::{
   2    db::dot,
   3    embedding::{DummyEmbeddings, EmbeddingProvider},
   4    parsing::{subtract_ranges, CodeContextRetriever, Document},
   5    semantic_index_settings::SemanticIndexSettings,
   6    SearchResult, SemanticIndex,
   7};
   8use anyhow::Result;
   9use async_trait::async_trait;
  10use gpui::{Task, TestAppContext};
  11use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
  12use pretty_assertions::assert_eq;
  13use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
  14use rand::{rngs::StdRng, Rng};
  15use serde_json::json;
  16use settings::SettingsStore;
  17use std::{
  18    path::Path,
  19    sync::{
  20        atomic::{self, AtomicUsize},
  21        Arc,
  22    },
  23};
  24use unindent::Unindent;
  25
  26#[ctor::ctor]
  27fn init_logger() {
  28    if std::env::var("RUST_LOG").is_ok() {
  29        env_logger::init();
  30    }
  31}
  32
  33#[gpui::test]
  34async fn test_semantic_index(cx: &mut TestAppContext) {
  35    cx.update(|cx| {
  36        cx.set_global(SettingsStore::test(cx));
  37        settings::register::<SemanticIndexSettings>(cx);
  38        settings::register::<ProjectSettings>(cx);
  39    });
  40
  41    let fs = FakeFs::new(cx.background());
  42    fs.insert_tree(
  43        "/the-root",
  44        json!({
  45            "src": {
  46                "file1.rs": "
  47                    fn aaa() {
  48                        println!(\"aaaaaaaaaaaa!\");
  49                    }
  50
  51                    fn zzzzz() {
  52                        println!(\"SLEEPING\");
  53                    }
  54                ".unindent(),
  55                "file2.rs": "
  56                    fn bbb() {
  57                        println!(\"bbbbbbbbbbbbb!\");
  58                    }
  59                ".unindent(),
  60                "file3.toml": "
  61                    ZZZZZZZZZZZZZZZZZZ = 5
  62                ".unindent(),
  63            }
  64        }),
  65    )
  66    .await;
  67
  68    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
  69    let rust_language = rust_lang();
  70    let toml_language = toml_lang();
  71    languages.add(rust_language);
  72    languages.add(toml_language);
  73
  74    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
  75    let db_path = db_dir.path().join("db.sqlite");
  76
  77    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
  78    let store = SemanticIndex::new(
  79        fs.clone(),
  80        db_path,
  81        embedding_provider.clone(),
  82        languages,
  83        cx.to_async(),
  84    )
  85    .await
  86    .unwrap();
  87
  88    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
  89
  90    let _ = store
  91        .update(cx, |store, cx| {
  92            store.initialize_project(project.clone(), cx)
  93        })
  94        .await;
  95
  96    let (file_count, outstanding_file_count) = store
  97        .update(cx, |store, cx| store.index_project(project.clone(), cx))
  98        .await
  99        .unwrap();
 100    assert_eq!(file_count, 3);
 101    cx.foreground().run_until_parked();
 102    assert_eq!(*outstanding_file_count.borrow(), 0);
 103
 104    let search_results = store
 105        .update(cx, |store, cx| {
 106            store.search_project(
 107                project.clone(),
 108                "aaaaaabbbbzz".to_string(),
 109                5,
 110                vec![],
 111                vec![],
 112                cx,
 113            )
 114        })
 115        .await
 116        .unwrap();
 117
 118    assert_search_results(
 119        &search_results,
 120        &[
 121            (Path::new("src/file1.rs").into(), 0),
 122            (Path::new("src/file2.rs").into(), 0),
 123            (Path::new("src/file3.toml").into(), 0),
 124            (Path::new("src/file1.rs").into(), 45),
 125        ],
 126        cx,
 127    );
 128
 129    // Test Include Files Functonality
 130    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
 131    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
 132    let rust_only_search_results = store
 133        .update(cx, |store, cx| {
 134            store.search_project(
 135                project.clone(),
 136                "aaaaaabbbbzz".to_string(),
 137                5,
 138                include_files,
 139                vec![],
 140                cx,
 141            )
 142        })
 143        .await
 144        .unwrap();
 145
 146    assert_search_results(
 147        &rust_only_search_results,
 148        &[
 149            (Path::new("src/file1.rs").into(), 0),
 150            (Path::new("src/file2.rs").into(), 0),
 151            (Path::new("src/file1.rs").into(), 45),
 152        ],
 153        cx,
 154    );
 155
 156    let no_rust_search_results = store
 157        .update(cx, |store, cx| {
 158            store.search_project(
 159                project.clone(),
 160                "aaaaaabbbbzz".to_string(),
 161                5,
 162                vec![],
 163                exclude_files,
 164                cx,
 165            )
 166        })
 167        .await
 168        .unwrap();
 169
 170    assert_search_results(
 171        &no_rust_search_results,
 172        &[(Path::new("src/file3.toml").into(), 0)],
 173        cx,
 174    );
 175
 176    fs.save(
 177        "/the-root/src/file2.rs".as_ref(),
 178        &"
 179            fn dddd() { println!(\"ddddd!\"); }
 180            struct pqpqpqp {}
 181        "
 182        .unindent()
 183        .into(),
 184        Default::default(),
 185    )
 186    .await
 187    .unwrap();
 188
 189    cx.foreground().run_until_parked();
 190
 191    let prev_embedding_count = embedding_provider.embedding_count();
 192    let (file_count, outstanding_file_count) = store
 193        .update(cx, |store, cx| store.index_project(project.clone(), cx))
 194        .await
 195        .unwrap();
 196    assert_eq!(file_count, 1);
 197
 198    cx.foreground().run_until_parked();
 199    assert_eq!(*outstanding_file_count.borrow(), 0);
 200
 201    assert_eq!(
 202        embedding_provider.embedding_count() - prev_embedding_count,
 203        2
 204    );
 205}
 206
 207#[track_caller]
 208fn assert_search_results(
 209    actual: &[SearchResult],
 210    expected: &[(Arc<Path>, usize)],
 211    cx: &TestAppContext,
 212) {
 213    let actual = actual
 214        .iter()
 215        .map(|search_result| {
 216            search_result.buffer.read_with(cx, |buffer, _cx| {
 217                (
 218                    buffer.file().unwrap().path().clone(),
 219                    search_result.range.start.to_offset(buffer),
 220                )
 221            })
 222        })
 223        .collect::<Vec<_>>();
 224    assert_eq!(actual, expected);
 225}
 226
 227#[gpui::test]
 228async fn test_code_context_retrieval_rust() {
 229    let language = rust_lang();
 230    let embedding_provider = Arc::new(DummyEmbeddings {});
 231    let mut retriever = CodeContextRetriever::new(embedding_provider);
 232
 233    let text = "
 234        /// A doc comment
 235        /// that spans multiple lines
 236        #[gpui::test]
 237        fn a() {
 238            b
 239        }
 240
 241        impl C for D {
 242        }
 243
 244        impl E {
 245            // This is also a preceding comment
 246            pub fn function_1() -> Option<()> {
 247                todo!();
 248            }
 249
 250            // This is a preceding comment
 251            fn function_2() -> Result<()> {
 252                todo!();
 253            }
 254        }
 255    "
 256    .unindent();
 257
 258    let documents = retriever.parse_file(&text, language).unwrap();
 259
 260    assert_documents_eq(
 261        &documents,
 262        &[
 263            (
 264                "
 265                /// A doc comment
 266                /// that spans multiple lines
 267                #[gpui::test]
 268                fn a() {
 269                    b
 270                }"
 271                .unindent(),
 272                text.find("fn a").unwrap(),
 273            ),
 274            (
 275                "
 276                impl C for D {
 277                }"
 278                .unindent(),
 279                text.find("impl C").unwrap(),
 280            ),
 281            (
 282                "
 283                impl E {
 284                    // This is also a preceding comment
 285                    pub fn function_1() -> Option<()> { /* ... */ }
 286
 287                    // This is a preceding comment
 288                    fn function_2() -> Result<()> { /* ... */ }
 289                }"
 290                .unindent(),
 291                text.find("impl E").unwrap(),
 292            ),
 293            (
 294                "
 295                // This is also a preceding comment
 296                pub fn function_1() -> Option<()> {
 297                    todo!();
 298                }"
 299                .unindent(),
 300                text.find("pub fn function_1").unwrap(),
 301            ),
 302            (
 303                "
 304                // This is a preceding comment
 305                fn function_2() -> Result<()> {
 306                    todo!();
 307                }"
 308                .unindent(),
 309                text.find("fn function_2").unwrap(),
 310            ),
 311        ],
 312    );
 313}
 314
 315#[gpui::test]
 316async fn test_code_context_retrieval_json() {
 317    let language = json_lang();
 318    let embedding_provider = Arc::new(DummyEmbeddings {});
 319    let mut retriever = CodeContextRetriever::new(embedding_provider);
 320
 321    let text = r#"
 322        {
 323            "array": [1, 2, 3, 4],
 324            "string": "abcdefg",
 325            "nested_object": {
 326                "array_2": [5, 6, 7, 8],
 327                "string_2": "hijklmnop",
 328                "boolean": true,
 329                "none": null
 330            }
 331        }
 332    "#
 333    .unindent();
 334
 335    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 336
 337    assert_documents_eq(
 338        &documents,
 339        &[(
 340            r#"
 341                {
 342                    "array": [],
 343                    "string": "",
 344                    "nested_object": {
 345                        "array_2": [],
 346                        "string_2": "",
 347                        "boolean": true,
 348                        "none": null
 349                    }
 350                }"#
 351            .unindent(),
 352            text.find("{").unwrap(),
 353        )],
 354    );
 355
 356    let text = r#"
 357        [
 358            {
 359                "name": "somebody",
 360                "age": 42
 361            },
 362            {
 363                "name": "somebody else",
 364                "age": 43
 365            }
 366        ]
 367    "#
 368    .unindent();
 369
 370    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 371
 372    assert_documents_eq(
 373        &documents,
 374        &[(
 375            r#"
 376            [{
 377                    "name": "",
 378                    "age": 42
 379                }]"#
 380            .unindent(),
 381            text.find("[").unwrap(),
 382        )],
 383    );
 384}
 385
 386fn assert_documents_eq(
 387    documents: &[Document],
 388    expected_contents_and_start_offsets: &[(String, usize)],
 389) {
 390    assert_eq!(
 391        documents
 392            .iter()
 393            .map(|document| (document.content.clone(), document.range.start))
 394            .collect::<Vec<_>>(),
 395        expected_contents_and_start_offsets
 396    );
 397}
 398
 399#[gpui::test]
 400async fn test_code_context_retrieval_javascript() {
 401    let language = js_lang();
 402    let embedding_provider = Arc::new(DummyEmbeddings {});
 403    let mut retriever = CodeContextRetriever::new(embedding_provider);
 404
 405    let text = "
 406        /* globals importScripts, backend */
 407        function _authorize() {}
 408
 409        /**
 410         * Sometimes the frontend build is way faster than backend.
 411         */
 412        export async function authorizeBank() {
 413            _authorize(pushModal, upgradingAccountId, {});
 414        }
 415
 416        export class SettingsPage {
 417            /* This is a test setting */
 418            constructor(page) {
 419                this.page = page;
 420            }
 421        }
 422
 423        /* This is a test comment */
 424        class TestClass {}
 425
 426        /* Schema for editor_events in Clickhouse. */
 427        export interface ClickhouseEditorEvent {
 428            installation_id: string
 429            operation: string
 430        }
 431        "
 432    .unindent();
 433
 434    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 435
 436    assert_documents_eq(
 437        &documents,
 438        &[
 439            (
 440                "
 441            /* globals importScripts, backend */
 442            function _authorize() {}"
 443                    .unindent(),
 444                37,
 445            ),
 446            (
 447                "
 448            /**
 449             * Sometimes the frontend build is way faster than backend.
 450             */
 451            export async function authorizeBank() {
 452                _authorize(pushModal, upgradingAccountId, {});
 453            }"
 454                .unindent(),
 455                131,
 456            ),
 457            (
 458                "
 459                export class SettingsPage {
 460                    /* This is a test setting */
 461                    constructor(page) {
 462                        this.page = page;
 463                    }
 464                }"
 465                .unindent(),
 466                225,
 467            ),
 468            (
 469                "
 470                /* This is a test setting */
 471                constructor(page) {
 472                    this.page = page;
 473                }"
 474                .unindent(),
 475                290,
 476            ),
 477            (
 478                "
 479                /* This is a test comment */
 480                class TestClass {}"
 481                    .unindent(),
 482                374,
 483            ),
 484            (
 485                "
 486                /* Schema for editor_events in Clickhouse. */
 487                export interface ClickhouseEditorEvent {
 488                    installation_id: string
 489                    operation: string
 490                }"
 491                .unindent(),
 492                440,
 493            ),
 494        ],
 495    )
 496}
 497
 498#[gpui::test]
 499async fn test_code_context_retrieval_lua() {
 500    let language = lua_lang();
 501    let embedding_provider = Arc::new(DummyEmbeddings {});
 502    let mut retriever = CodeContextRetriever::new(embedding_provider);
 503
 504    let text = r#"
 505        -- Creates a new class
 506        -- @param baseclass The Baseclass of this class, or nil.
 507        -- @return A new class reference.
 508        function classes.class(baseclass)
 509            -- Create the class definition and metatable.
 510            local classdef = {}
 511            -- Find the super class, either Object or user-defined.
 512            baseclass = baseclass or classes.Object
 513            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 514            setmetatable(classdef, { __index = baseclass })
 515            -- All class instances have a reference to the class object.
 516            classdef.class = classdef
 517            --- Recursivly allocates the inheritance tree of the instance.
 518            -- @param mastertable The 'root' of the inheritance tree.
 519            -- @return Returns the instance with the allocated inheritance tree.
 520            function classdef.alloc(mastertable)
 521                -- All class instances have a reference to a superclass object.
 522                local instance = { super = baseclass.alloc(mastertable) }
 523                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 524                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 525                return instance
 526            end
 527        end
 528        "#.unindent();
 529
 530    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 531
 532    assert_documents_eq(
 533        &documents,
 534        &[
 535            (r#"
 536                -- Creates a new class
 537                -- @param baseclass The Baseclass of this class, or nil.
 538                -- @return A new class reference.
 539                function classes.class(baseclass)
 540                    -- Create the class definition and metatable.
 541                    local classdef = {}
 542                    -- Find the super class, either Object or user-defined.
 543                    baseclass = baseclass or classes.Object
 544                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
 545                    setmetatable(classdef, { __index = baseclass })
 546                    -- All class instances have a reference to the class object.
 547                    classdef.class = classdef
 548                    --- Recursivly allocates the inheritance tree of the instance.
 549                    -- @param mastertable The 'root' of the inheritance tree.
 550                    -- @return Returns the instance with the allocated inheritance tree.
 551                    function classdef.alloc(mastertable)
 552                        --[ ... ]--
 553                        --[ ... ]--
 554                    end
 555                end"#.unindent(),
 556            114),
 557            (r#"
 558            --- Recursivly allocates the inheritance tree of the instance.
 559            -- @param mastertable The 'root' of the inheritance tree.
 560            -- @return Returns the instance with the allocated inheritance tree.
 561            function classdef.alloc(mastertable)
 562                -- All class instances have a reference to a superclass object.
 563                local instance = { super = baseclass.alloc(mastertable) }
 564                -- Any functions this instance does not know of will 'look up' to the superclass definition.
 565                setmetatable(instance, { __index = classdef, __newindex = mastertable })
 566                return instance
 567            end"#.unindent(), 809),
 568        ]
 569    );
 570}
 571
 572#[gpui::test]
 573async fn test_code_context_retrieval_elixir() {
 574    let language = elixir_lang();
 575    let embedding_provider = Arc::new(DummyEmbeddings {});
 576    let mut retriever = CodeContextRetriever::new(embedding_provider);
 577
 578    let text = r#"
 579        defmodule File.Stream do
 580            @moduledoc """
 581            Defines a `File.Stream` struct returned by `File.stream!/3`.
 582
 583            The following fields are public:
 584
 585            * `path`          - the file path
 586            * `modes`         - the file modes
 587            * `raw`           - a boolean indicating if bin functions should be used
 588            * `line_or_bytes` - if reading should read lines or a given number of bytes
 589            * `node`          - the node the file belongs to
 590
 591            """
 592
 593            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 594
 595            @type t :: %__MODULE__{}
 596
 597            @doc false
 598            def __build__(path, modes, line_or_bytes) do
 599            raw = :lists.keyfind(:encoding, 1, modes) == false
 600
 601            modes =
 602                case raw do
 603                true ->
 604                    case :lists.keyfind(:read_ahead, 1, modes) do
 605                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 606                    {:read_ahead, _} -> [:raw | modes]
 607                    false -> [:raw, :read_ahead | modes]
 608                    end
 609
 610                false ->
 611                    modes
 612                end
 613
 614            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 615
 616            end"#
 617    .unindent();
 618
 619    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 620
 621    assert_documents_eq(
 622        &documents,
 623        &[(
 624            r#"
 625        defmodule File.Stream do
 626            @moduledoc """
 627            Defines a `File.Stream` struct returned by `File.stream!/3`.
 628
 629            The following fields are public:
 630
 631            * `path`          - the file path
 632            * `modes`         - the file modes
 633            * `raw`           - a boolean indicating if bin functions should be used
 634            * `line_or_bytes` - if reading should read lines or a given number of bytes
 635            * `node`          - the node the file belongs to
 636
 637            """
 638
 639            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
 640
 641            @type t :: %__MODULE__{}
 642
 643            @doc false
 644            def __build__(path, modes, line_or_bytes) do
 645            raw = :lists.keyfind(:encoding, 1, modes) == false
 646
 647            modes =
 648                case raw do
 649                true ->
 650                    case :lists.keyfind(:read_ahead, 1, modes) do
 651                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 652                    {:read_ahead, _} -> [:raw | modes]
 653                    false -> [:raw, :read_ahead | modes]
 654                    end
 655
 656                false ->
 657                    modes
 658                end
 659
 660            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 661
 662            end"#
 663                .unindent(),
 664            0,
 665        ),(r#"
 666            @doc false
 667            def __build__(path, modes, line_or_bytes) do
 668            raw = :lists.keyfind(:encoding, 1, modes) == false
 669
 670            modes =
 671                case raw do
 672                true ->
 673                    case :lists.keyfind(:read_ahead, 1, modes) do
 674                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
 675                    {:read_ahead, _} -> [:raw | modes]
 676                    false -> [:raw, :read_ahead | modes]
 677                    end
 678
 679                false ->
 680                    modes
 681                end
 682
 683            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
 684
 685            end"#.unindent(), 574)],
 686    );
 687}
 688
 689#[gpui::test]
 690async fn test_code_context_retrieval_cpp() {
 691    let language = cpp_lang();
 692    let embedding_provider = Arc::new(DummyEmbeddings {});
 693    let mut retriever = CodeContextRetriever::new(embedding_provider);
 694
 695    let text = "
 696    /**
 697     * @brief Main function
 698     * @returns 0 on exit
 699     */
 700    int main() { return 0; }
 701
 702    /**
 703    * This is a test comment
 704    */
 705    class MyClass {       // The class
 706        public:           // Access specifier
 707        int myNum;        // Attribute (int variable)
 708        string myString;  // Attribute (string variable)
 709    };
 710
 711    // This is a test comment
 712    enum Color { red, green, blue };
 713
 714    /** This is a preceding block comment
 715     * This is the second line
 716     */
 717    struct {           // Structure declaration
 718        int myNum;       // Member (int variable)
 719        string myString; // Member (string variable)
 720    } myStructure;
 721
 722    /**
 723     * @brief Matrix class.
 724     */
 725    template <typename T,
 726              typename = typename std::enable_if<
 727                std::is_integral<T>::value || std::is_floating_point<T>::value,
 728                bool>::type>
 729    class Matrix2 {
 730        std::vector<std::vector<T>> _mat;
 731
 732        public:
 733            /**
 734            * @brief Constructor
 735            * @tparam Integer ensuring integers are being evaluated and not other
 736            * data types.
 737            * @param size denoting the size of Matrix as size x size
 738            */
 739            template <typename Integer,
 740                    typename = typename std::enable_if<std::is_integral<Integer>::value,
 741                    Integer>::type>
 742            explicit Matrix(const Integer size) {
 743                for (size_t i = 0; i < size; ++i) {
 744                    _mat.emplace_back(std::vector<T>(size, 0));
 745                }
 746            }
 747    }"
 748    .unindent();
 749
 750    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 751
 752    assert_documents_eq(
 753        &documents,
 754        &[
 755            (
 756                "
 757        /**
 758         * @brief Main function
 759         * @returns 0 on exit
 760         */
 761        int main() { return 0; }"
 762                    .unindent(),
 763                54,
 764            ),
 765            (
 766                "
 767                /**
 768                * This is a test comment
 769                */
 770                class MyClass {       // The class
 771                    public:           // Access specifier
 772                    int myNum;        // Attribute (int variable)
 773                    string myString;  // Attribute (string variable)
 774                }"
 775                .unindent(),
 776                112,
 777            ),
 778            (
 779                "
 780                // This is a test comment
 781                enum Color { red, green, blue }"
 782                    .unindent(),
 783                322,
 784            ),
 785            (
 786                "
 787                /** This is a preceding block comment
 788                 * This is the second line
 789                 */
 790                struct {           // Structure declaration
 791                    int myNum;       // Member (int variable)
 792                    string myString; // Member (string variable)
 793                } myStructure;"
 794                    .unindent(),
 795                425,
 796            ),
 797            (
 798                "
 799                /**
 800                 * @brief Matrix class.
 801                 */
 802                template <typename T,
 803                          typename = typename std::enable_if<
 804                            std::is_integral<T>::value || std::is_floating_point<T>::value,
 805                            bool>::type>
 806                class Matrix2 {
 807                    std::vector<std::vector<T>> _mat;
 808
 809                    public:
 810                        /**
 811                        * @brief Constructor
 812                        * @tparam Integer ensuring integers are being evaluated and not other
 813                        * data types.
 814                        * @param size denoting the size of Matrix as size x size
 815                        */
 816                        template <typename Integer,
 817                                typename = typename std::enable_if<std::is_integral<Integer>::value,
 818                                Integer>::type>
 819                        explicit Matrix(const Integer size) {
 820                            for (size_t i = 0; i < size; ++i) {
 821                                _mat.emplace_back(std::vector<T>(size, 0));
 822                            }
 823                        }
 824                }"
 825                .unindent(),
 826                612,
 827            ),
 828            (
 829                "
 830                explicit Matrix(const Integer size) {
 831                    for (size_t i = 0; i < size; ++i) {
 832                        _mat.emplace_back(std::vector<T>(size, 0));
 833                    }
 834                }"
 835                .unindent(),
 836                1226,
 837            ),
 838        ],
 839    );
 840}
 841
 842#[gpui::test]
 843async fn test_code_context_retrieval_ruby() {
 844    let language = ruby_lang();
 845    let embedding_provider = Arc::new(DummyEmbeddings {});
 846    let mut retriever = CodeContextRetriever::new(embedding_provider);
 847
 848    let text = r#"
 849        # This concern is inspired by "sudo mode" on GitHub. It
 850        # is a way to re-authenticate a user before allowing them
 851        # to see or perform an action.
 852        #
 853        # Add `before_action :require_challenge!` to actions you
 854        # want to protect.
 855        #
 856        # The user will be shown a page to enter the challenge (which
 857        # is either the password, or just the username when no
 858        # password exists). Upon passing, there is a grace period
 859        # during which no challenge will be asked from the user.
 860        #
 861        # Accessing challenge-protected resources during the grace
 862        # period will refresh the grace period.
 863        module ChallengableConcern
 864            extend ActiveSupport::Concern
 865
 866            CHALLENGE_TIMEOUT = 1.hour.freeze
 867
 868            def require_challenge!
 869                return if skip_challenge?
 870
 871                if challenge_passed_recently?
 872                    session[:challenge_passed_at] = Time.now.utc
 873                    return
 874                end
 875
 876                @challenge = Form::Challenge.new(return_to: request.url)
 877
 878                if params.key?(:form_challenge)
 879                    if challenge_passed?
 880                        session[:challenge_passed_at] = Time.now.utc
 881                    else
 882                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 883                        render_challenge
 884                    end
 885                else
 886                    render_challenge
 887                end
 888            end
 889
 890            def challenge_passed?
 891                current_user.valid_password?(challenge_params[:current_password])
 892            end
 893        end
 894
 895        class Animal
 896            include Comparable
 897
 898            attr_reader :legs
 899
 900            def initialize(name, legs)
 901                @name, @legs = name, legs
 902            end
 903
 904            def <=>(other)
 905                legs <=> other.legs
 906            end
 907        end
 908
 909        # Singleton method for car object
 910        def car.wheels
 911            puts "There are four wheels"
 912        end"#
 913        .unindent();
 914
 915    let documents = retriever.parse_file(&text, language.clone()).unwrap();
 916
 917    assert_documents_eq(
 918        &documents,
 919        &[
 920            (
 921                r#"
 922        # This concern is inspired by "sudo mode" on GitHub. It
 923        # is a way to re-authenticate a user before allowing them
 924        # to see or perform an action.
 925        #
 926        # Add `before_action :require_challenge!` to actions you
 927        # want to protect.
 928        #
 929        # The user will be shown a page to enter the challenge (which
 930        # is either the password, or just the username when no
 931        # password exists). Upon passing, there is a grace period
 932        # during which no challenge will be asked from the user.
 933        #
 934        # Accessing challenge-protected resources during the grace
 935        # period will refresh the grace period.
 936        module ChallengableConcern
 937            extend ActiveSupport::Concern
 938
 939            CHALLENGE_TIMEOUT = 1.hour.freeze
 940
 941            def require_challenge!
 942                # ...
 943            end
 944
 945            def challenge_passed?
 946                # ...
 947            end
 948        end"#
 949                    .unindent(),
 950                558,
 951            ),
 952            (
 953                r#"
 954            def require_challenge!
 955                return if skip_challenge?
 956
 957                if challenge_passed_recently?
 958                    session[:challenge_passed_at] = Time.now.utc
 959                    return
 960                end
 961
 962                @challenge = Form::Challenge.new(return_to: request.url)
 963
 964                if params.key?(:form_challenge)
 965                    if challenge_passed?
 966                        session[:challenge_passed_at] = Time.now.utc
 967                    else
 968                        flash.now[:alert] = I18n.t('challenge.invalid_password')
 969                        render_challenge
 970                    end
 971                else
 972                    render_challenge
 973                end
 974            end"#
 975                    .unindent(),
 976                663,
 977            ),
 978            (
 979                r#"
 980                def challenge_passed?
 981                    current_user.valid_password?(challenge_params[:current_password])
 982                end"#
 983                    .unindent(),
 984                1254,
 985            ),
 986            (
 987                r#"
 988                class Animal
 989                    include Comparable
 990
 991                    attr_reader :legs
 992
 993                    def initialize(name, legs)
 994                        # ...
 995                    end
 996
 997                    def <=>(other)
 998                        # ...
 999                    end
1000                end"#
1001                    .unindent(),
1002                1363,
1003            ),
1004            (
1005                r#"
1006                def initialize(name, legs)
1007                    @name, @legs = name, legs
1008                end"#
1009                    .unindent(),
1010                1427,
1011            ),
1012            (
1013                r#"
1014                def <=>(other)
1015                    legs <=> other.legs
1016                end"#
1017                    .unindent(),
1018                1501,
1019            ),
1020            (
1021                r#"
1022                # Singleton method for car object
1023                def car.wheels
1024                    puts "There are four wheels"
1025                end"#
1026                    .unindent(),
1027                1591,
1028            ),
1029        ],
1030    );
1031}
1032
/// Parses a small PHP file and checks the documents produced by the `php_lang` embedding
/// query: a free function (with its leading `/* */` comment as context), a trait whose
/// method bodies are collapsed to `{/* ... */}`, each of the trait's methods in full
/// (the first one together with its `/** */` docblock), an interface, and an enum.
#[gpui::test]
async fn test_code_context_retrieval_php() {
    let language = php_lang();
    let embedding_provider = Arc::new(DummyEmbeddings {});
    let mut retriever = CodeContextRetriever::new(embedding_provider);

    // Fixture containing one item of each kind the PHP query captures. The raw string is
    // passed through `unindent`, so only the relative indentation of its lines matters.
    let text = r#"
        <?php

        namespace LevelUp\Experience\Concerns;

        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }

        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }

            public function achievements(): BelongsToMany
            {
                return $this->belongsToMany(related: Achievement::class)
                ->withPivot(columns: 'progress')
                ->where('is_secret', false)
                ->using(AchievementUser::class);
            }
        }

        interface Multiplier
        {
            public function qualifies(array $data): bool;

            public function setMultiplier(): int;
        }

        enum AuditType: string
        {
            case Add = 'add';
            case Remove = 'remove';
            case Reset = 'reset';
            case LevelUp = 'level_up';
        }

        ?>"#
    .unindent();

    let documents = retriever.parse_file(&text, language.clone()).unwrap();

    // Each tuple is (expected document text, offset). The offset appears to be where the
    // item itself starts in `text`, after any leading context comment. For example, 123 is
    // the start of `function functionName`, just past the `/* ... */` block.
    // NOTE(review): confirm this reading against `assert_documents_eq`.
    assert_documents_eq(
        &documents,
        &[
            (
                r#"
        /*
        This is a multiple-lines comment block
        that spans over multiple
        lines
        */
        function functionName() {
            echo "Hello world!";
        }"#
                .unindent(),
                123,
            ),
            // Inside the trait document, method bodies are collapsed down to their braces
            // plus `php_lang`'s `collapsed_placeholder`.
            (
                r#"
        trait HasAchievements
        {
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {/* ... */}

            public function achievements(): BelongsToMany
            {/* ... */}
        }"#
                .unindent(),
                177,
            ),
            // The trait's methods are also emitted as their own documents, with full bodies.
            (r#"
            /**
            * @throws \Exception
            */
            public function grantAchievement(Achievement $achievement, $progress = null): void
            {
                if ($progress > 100) {
                    throw new Exception(message: 'Progress cannot be greater than 100');
                }

                if ($this->achievements()->find($achievement->id)) {
                    throw new Exception(message: 'User already has this Achievement');
                }

                $this->achievements()->attach($achievement, [
                    'progress' => $progress ?? null,
                ]);

                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
            }"#.unindent(), 245),
            (r#"
                public function achievements(): BelongsToMany
                {
                    return $this->belongsToMany(related: Achievement::class)
                    ->withPivot(columns: 'progress')
                    ->where('is_secret', false)
                    ->using(AchievementUser::class);
                }"#.unindent(), 902),
            (r#"
                interface Multiplier
                {
                    public function qualifies(array $data): bool;

                    public function setMultiplier(): int;
                }"#.unindent(),
                1146),
            (r#"
                enum AuditType: string
                {
                    case Add = 'add';
                    case Remove = 'remove';
                    case Reset = 'reset';
                    case LevelUp = 'level_up';
                }"#.unindent(), 1265)
        ],
    );
}
1183
1184#[gpui::test]
1185fn test_dot_product(mut rng: StdRng) {
1186    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
1187    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
1188
1189    for _ in 0..100 {
1190        let size = 1536;
1191        let mut a = vec![0.; size];
1192        let mut b = vec![0.; size];
1193        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
1194            *a = rng.gen();
1195            *b = rng.gen();
1196        }
1197
1198        assert_eq!(
1199            round_to_decimals(dot(&a, &b), 1),
1200            round_to_decimals(reference_dot(&a, &b), 1)
1201        );
1202    }
1203
1204    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
1205        let factor = (10.0 as f32).powi(decimal_places);
1206        (n * factor).round() / factor
1207    }
1208
1209    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
1210        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
1211    }
1212}
1213
/// Embedding provider for tests that keeps a tally of how many spans it was asked to embed.
#[derive(Default)]
struct FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch` so far; starts at zero.
    embedding_count: AtomicUsize,
}
1218
1219impl FakeEmbeddingProvider {
1220    fn embedding_count(&self) -> usize {
1221        self.embedding_count.load(atomic::Ordering::SeqCst)
1222    }
1223}
1224
1225#[async_trait]
1226impl EmbeddingProvider for FakeEmbeddingProvider {
1227    fn count_tokens(&self, span: &str) -> usize {
1228        span.len()
1229    }
1230
1231    fn should_truncate(&self, span: &str) -> bool {
1232        false
1233    }
1234
1235    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1236        self.embedding_count
1237            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
1238        Ok(spans
1239            .iter()
1240            .map(|span| {
1241                let mut result = vec![1.0; 26];
1242                for letter in span.chars() {
1243                    let letter = letter.to_ascii_lowercase();
1244                    if letter as u32 >= 'a' as u32 {
1245                        let ix = (letter as u32) - ('a' as u32);
1246                        if ix < 26 {
1247                            result[ix as usize] += 1.0;
1248                        }
1249                    }
1250                }
1251
1252                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
1253                for x in &mut result {
1254                    *x /= norm;
1255                }
1256
1257                result
1258            })
1259            .collect())
1260    }
1261}
1262
1263fn js_lang() -> Arc<Language> {
1264    Arc::new(
1265        Language::new(
1266            LanguageConfig {
1267                name: "Javascript".into(),
1268                path_suffixes: vec!["js".into()],
1269                ..Default::default()
1270            },
1271            Some(tree_sitter_typescript::language_tsx()),
1272        )
1273        .with_embedding_query(
1274            &r#"
1275
1276            (
1277                (comment)* @context
1278                .
1279                [
1280                (export_statement
1281                    (function_declaration
1282                        "async"? @name
1283                        "function" @name
1284                        name: (_) @name))
1285                (function_declaration
1286                    "async"? @name
1287                    "function" @name
1288                    name: (_) @name)
1289                ] @item
1290            )
1291
1292            (
1293                (comment)* @context
1294                .
1295                [
1296                (export_statement
1297                    (class_declaration
1298                        "class" @name
1299                        name: (_) @name))
1300                (class_declaration
1301                    "class" @name
1302                    name: (_) @name)
1303                ] @item
1304            )
1305
1306            (
1307                (comment)* @context
1308                .
1309                [
1310                (export_statement
1311                    (interface_declaration
1312                        "interface" @name
1313                        name: (_) @name))
1314                (interface_declaration
1315                    "interface" @name
1316                    name: (_) @name)
1317                ] @item
1318            )
1319
1320            (
1321                (comment)* @context
1322                .
1323                [
1324                (export_statement
1325                    (enum_declaration
1326                        "enum" @name
1327                        name: (_) @name))
1328                (enum_declaration
1329                    "enum" @name
1330                    name: (_) @name)
1331                ] @item
1332            )
1333
1334            (
1335                (comment)* @context
1336                .
1337                (method_definition
1338                    [
1339                        "get"
1340                        "set"
1341                        "async"
1342                        "*"
1343                        "static"
1344                    ]* @name
1345                    name: (_) @name) @item
1346            )
1347
1348                    "#
1349            .unindent(),
1350        )
1351        .unwrap(),
1352    )
1353}
1354
1355fn rust_lang() -> Arc<Language> {
1356    Arc::new(
1357        Language::new(
1358            LanguageConfig {
1359                name: "Rust".into(),
1360                path_suffixes: vec!["rs".into()],
1361                collapsed_placeholder: " /* ... */ ".to_string(),
1362                ..Default::default()
1363            },
1364            Some(tree_sitter_rust::language()),
1365        )
1366        .with_embedding_query(
1367            r#"
1368            (
1369                [(line_comment) (attribute_item)]* @context
1370                .
1371                [
1372                    (struct_item
1373                        name: (_) @name)
1374
1375                    (enum_item
1376                        name: (_) @name)
1377
1378                    (impl_item
1379                        trait: (_)? @name
1380                        "for"? @name
1381                        type: (_) @name)
1382
1383                    (trait_item
1384                        name: (_) @name)
1385
1386                    (function_item
1387                        name: (_) @name
1388                        body: (block
1389                            "{" @keep
1390                            "}" @keep) @collapse)
1391
1392                    (macro_definition
1393                        name: (_) @name)
1394                ] @item
1395            )
1396            "#,
1397        )
1398        .unwrap(),
1399    )
1400}
1401
1402fn json_lang() -> Arc<Language> {
1403    Arc::new(
1404        Language::new(
1405            LanguageConfig {
1406                name: "JSON".into(),
1407                path_suffixes: vec!["json".into()],
1408                ..Default::default()
1409            },
1410            Some(tree_sitter_json::language()),
1411        )
1412        .with_embedding_query(
1413            r#"
1414            (document) @item
1415
1416            (array
1417                "[" @keep
1418                .
1419                (object)? @keep
1420                "]" @keep) @collapse
1421
1422            (pair value: (string
1423                "\"" @keep
1424                "\"" @keep) @collapse)
1425            "#,
1426        )
1427        .unwrap(),
1428    )
1429}
1430
1431fn toml_lang() -> Arc<Language> {
1432    Arc::new(Language::new(
1433        LanguageConfig {
1434            name: "TOML".into(),
1435            path_suffixes: vec!["toml".into()],
1436            ..Default::default()
1437        },
1438        Some(tree_sitter_toml::language()),
1439    ))
1440}
1441
1442fn cpp_lang() -> Arc<Language> {
1443    Arc::new(
1444        Language::new(
1445            LanguageConfig {
1446                name: "CPP".into(),
1447                path_suffixes: vec!["cpp".into()],
1448                ..Default::default()
1449            },
1450            Some(tree_sitter_cpp::language()),
1451        )
1452        .with_embedding_query(
1453            r#"
1454            (
1455                (comment)* @context
1456                .
1457                (function_definition
1458                    (type_qualifier)? @name
1459                    type: (_)? @name
1460                    declarator: [
1461                        (function_declarator
1462                            declarator: (_) @name)
1463                        (pointer_declarator
1464                            "*" @name
1465                            declarator: (function_declarator
1466                            declarator: (_) @name))
1467                        (pointer_declarator
1468                            "*" @name
1469                            declarator: (pointer_declarator
1470                                "*" @name
1471                            declarator: (function_declarator
1472                                declarator: (_) @name)))
1473                        (reference_declarator
1474                            ["&" "&&"] @name
1475                            (function_declarator
1476                            declarator: (_) @name))
1477                    ]
1478                    (type_qualifier)? @name) @item
1479                )
1480
1481            (
1482                (comment)* @context
1483                .
1484                (template_declaration
1485                    (class_specifier
1486                        "class" @name
1487                        name: (_) @name)
1488                        ) @item
1489            )
1490
1491            (
1492                (comment)* @context
1493                .
1494                (class_specifier
1495                    "class" @name
1496                    name: (_) @name) @item
1497                )
1498
1499            (
1500                (comment)* @context
1501                .
1502                (enum_specifier
1503                    "enum" @name
1504                    name: (_) @name) @item
1505                )
1506
1507            (
1508                (comment)* @context
1509                .
1510                (declaration
1511                    type: (struct_specifier
1512                    "struct" @name)
1513                    declarator: (_) @name) @item
1514            )
1515
1516            "#,
1517        )
1518        .unwrap(),
1519    )
1520}
1521
1522fn lua_lang() -> Arc<Language> {
1523    Arc::new(
1524        Language::new(
1525            LanguageConfig {
1526                name: "Lua".into(),
1527                path_suffixes: vec!["lua".into()],
1528                collapsed_placeholder: "--[ ... ]--".to_string(),
1529                ..Default::default()
1530            },
1531            Some(tree_sitter_lua::language()),
1532        )
1533        .with_embedding_query(
1534            r#"
1535            (
1536                (comment)* @context
1537                .
1538                (function_declaration
1539                    "function" @name
1540                    name: (_) @name
1541                    (comment)* @collapse
1542                    body: (block) @collapse
1543                ) @item
1544            )
1545        "#,
1546        )
1547        .unwrap(),
1548    )
1549}
1550
1551fn php_lang() -> Arc<Language> {
1552    Arc::new(
1553        Language::new(
1554            LanguageConfig {
1555                name: "PHP".into(),
1556                path_suffixes: vec!["php".into()],
1557                collapsed_placeholder: "/* ... */".into(),
1558                ..Default::default()
1559            },
1560            Some(tree_sitter_php::language()),
1561        )
1562        .with_embedding_query(
1563            r#"
1564            (
1565                (comment)* @context
1566                .
1567                [
1568                    (function_definition
1569                        "function" @name
1570                        name: (_) @name
1571                        body: (_
1572                            "{" @keep
1573                            "}" @keep) @collapse
1574                        )
1575
1576                    (trait_declaration
1577                        "trait" @name
1578                        name: (_) @name)
1579
1580                    (method_declaration
1581                        "function" @name
1582                        name: (_) @name
1583                        body: (_
1584                            "{" @keep
1585                            "}" @keep) @collapse
1586                        )
1587
1588                    (interface_declaration
1589                        "interface" @name
1590                        name: (_) @name
1591                        )
1592
1593                    (enum_declaration
1594                        "enum" @name
1595                        name: (_) @name
1596                        )
1597
1598                ] @item
1599            )
1600            "#,
1601        )
1602        .unwrap(),
1603    )
1604}
1605
1606fn ruby_lang() -> Arc<Language> {
1607    Arc::new(
1608        Language::new(
1609            LanguageConfig {
1610                name: "Ruby".into(),
1611                path_suffixes: vec!["rb".into()],
1612                collapsed_placeholder: "# ...".to_string(),
1613                ..Default::default()
1614            },
1615            Some(tree_sitter_ruby::language()),
1616        )
1617        .with_embedding_query(
1618            r#"
1619            (
1620                (comment)* @context
1621                .
1622                [
1623                (module
1624                    "module" @name
1625                    name: (_) @name)
1626                (method
1627                    "def" @name
1628                    name: (_) @name
1629                    body: (body_statement) @collapse)
1630                (class
1631                    "class" @name
1632                    name: (_) @name)
1633                (singleton_method
1634                    "def" @name
1635                    object: (_) @name
1636                    "." @name
1637                    name: (_) @name
1638                    body: (body_statement) @collapse)
1639                ] @item
1640            )
1641            "#,
1642        )
1643        .unwrap(),
1644    )
1645}
1646
1647fn elixir_lang() -> Arc<Language> {
1648    Arc::new(
1649        Language::new(
1650            LanguageConfig {
1651                name: "Elixir".into(),
1652                path_suffixes: vec!["rs".into()],
1653                ..Default::default()
1654            },
1655            Some(tree_sitter_elixir::language()),
1656        )
1657        .with_embedding_query(
1658            r#"
1659            (
1660                (unary_operator
1661                    operator: "@"
1662                    operand: (call
1663                        target: (identifier) @unary
1664                        (#match? @unary "^(doc)$"))
1665                    ) @context
1666                .
1667                (call
1668                target: (identifier) @name
1669                (arguments
1670                [
1671                (identifier) @name
1672                (call
1673                target: (identifier) @name)
1674                (binary_operator
1675                left: (call
1676                target: (identifier) @name)
1677                operator: "when")
1678                ])
1679                (#match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
1680                )
1681
1682            (call
1683                target: (identifier) @name
1684                (arguments (alias) @name)
1685                (#match? @name "^(defmodule|defprotocol)$")) @item
1686            "#,
1687        )
1688        .unwrap(),
1689    )
1690}
1691
1692#[gpui::test]
1693fn test_subtract_ranges() {
1694    // collapsed_ranges: Vec<Range<usize>>, keep_ranges: Vec<Range<usize>>
1695
1696    assert_eq!(
1697        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
1698        vec![1..4, 10..21]
1699    );
1700
1701    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
1702}