evals.rs

use super::*;
use crate::{
    ReadFileToolInput, grep_tool::GrepToolInput,
    streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    sync::mpsc,
};
use util::path;
  28
  29#[test]
  30#[cfg_attr(not(feature = "eval"), ignore)]
  31fn eval_extract_handle_command_output() {
  32    let input_file_path = "root/blame.rs";
  33    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
  34    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
  35    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
  36    eval(
  37        100,
  38        0.95,
  39        EvalInput {
  40            conversation: vec![
  41                message(
  42                    User,
  43                    [text(indoc! {"
  44                        Read the `{input_file_path}` file and extract a method in
  45                        the final stanza of `run_git_blame` to deal with command failures,
  46                        call it `handle_command_output` and take the std::process::Output as the only parameter.
  47
  48                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
  49                    "})],
  50                ),
  51                message(
  52                    Assistant,
  53                    [tool_use(
  54                        "tool_1",
  55                        "read_file",
  56                        ReadFileToolInput {
  57                            path: input_file_path.into(),
  58                            start_line: None,
  59                            end_line: None,
  60                        },
  61                    )],
  62                ),
  63                message(
  64                    User,
  65                    [tool_result("tool_1", "read_file", input_file_content)],
  66                ),
  67                message(
  68                    Assistant,
  69                    [tool_use(
  70                        "tool_2",
  71                        "edit_file",
  72                        StreamingEditFileToolInput {
  73                            display_description: edit_description.into(),
  74                            path: input_file_path.into(),
  75                            create_or_overwrite: false,
  76                        },
  77                    )],
  78                ),
  79            ],
  80            input_path: input_file_path.into(),
  81            input_content: Some(input_file_content.into()),
  82            edit_description: edit_description.into(),
  83            assertion: EvalAssertion::assert_eq(output_file_content),
  84        },
  85    );
  86}
  87
  88#[test]
  89#[cfg_attr(not(feature = "eval"), ignore)]
  90fn eval_delete_run_git_blame() {
  91    let input_file_path = "root/blame.rs";
  92    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
  93    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
  94    let edit_description = "Delete the `run_git_blame` function.";
  95    eval(
  96        100,
  97        0.95,
  98        EvalInput {
  99            conversation: vec![
 100                message(
 101                    User,
 102                    [text(indoc! {"
 103                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
 104                        one function, not its usages.
 105                    "})],
 106                ),
 107                message(
 108                    Assistant,
 109                    [tool_use(
 110                        "tool_1",
 111                        "read_file",
 112                        ReadFileToolInput {
 113                            path: input_file_path.into(),
 114                            start_line: None,
 115                            end_line: None,
 116                        },
 117                    )],
 118                ),
 119                message(
 120                    User,
 121                    [tool_result("tool_1", "read_file", input_file_content)],
 122                ),
 123                message(
 124                    Assistant,
 125                    [tool_use(
 126                        "tool_2",
 127                        "edit_file",
 128                        StreamingEditFileToolInput {
 129                            display_description: edit_description.into(),
 130                            path: input_file_path.into(),
 131                            create_or_overwrite: false,
 132                        },
 133                    )],
 134                ),
 135            ],
 136            input_path: input_file_path.into(),
 137            input_content: Some(input_file_content.into()),
 138            edit_description: edit_description.into(),
 139            assertion: EvalAssertion::assert_eq(output_file_content),
 140        },
 141    );
 142}
 143
 144#[test]
 145#[cfg_attr(not(feature = "eval"), ignore)]
 146fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 147    let input_file_path = "root/lib.rs";
 148    let input_file_content =
 149        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
 150    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
 151    eval(
 152        100,
 153        0.95,
 154        EvalInput {
 155            conversation: vec![
 156                message(
 157                    User,
 158                    [text(indoc! {"
 159                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
 160                        Use `ureq` to download the SDK for the current platform and architecture.
 161                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
 162                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
 163                        that's inside of the archive.
 164                        Don't re-download the SDK if that executable already exists.
 165
 166                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
 167
 168                        Here are the available wasi-sdk assets:
 169                        - wasi-sdk-25.0-x86_64-macos.tar.gz
 170                        - wasi-sdk-25.0-arm64-macos.tar.gz
 171                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 172                        - wasi-sdk-25.0-arm64-linux.tar.gz
 173                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 174                        - wasi-sdk-25.0-arm64-linux.tar.gz
 175                        - wasi-sdk-25.0-x86_64-windows.tar.gz
 176                    "})],
 177                ),
 178                message(
 179                    Assistant,
 180                    [tool_use(
 181                        "tool_1",
 182                        "read_file",
 183                        ReadFileToolInput {
 184                            path: input_file_path.into(),
 185                            start_line: Some(971),
 186                            end_line: Some(1050),
 187                        },
 188                    )],
 189                ),
 190                message(
 191                    User,
 192                    [tool_result(
 193                        "tool_1",
 194                        "read_file",
 195                        lines(input_file_content, 971..1050),
 196                    )],
 197                ),
 198                message(
 199                    Assistant,
 200                    [tool_use(
 201                        "tool_2",
 202                        "read_file",
 203                        ReadFileToolInput {
 204                            path: input_file_path.into(),
 205                            start_line: Some(1050),
 206                            end_line: Some(1100),
 207                        },
 208                    )],
 209                ),
 210                message(
 211                    User,
 212                    [tool_result(
 213                        "tool_2",
 214                        "read_file",
 215                        lines(input_file_content, 1050..1100),
 216                    )],
 217                ),
 218                message(
 219                    Assistant,
 220                    [tool_use(
 221                        "tool_3",
 222                        "read_file",
 223                        ReadFileToolInput {
 224                            path: input_file_path.into(),
 225                            start_line: Some(1100),
 226                            end_line: Some(1150),
 227                        },
 228                    )],
 229                ),
 230                message(
 231                    User,
 232                    [tool_result(
 233                        "tool_3",
 234                        "read_file",
 235                        lines(input_file_content, 1100..1150),
 236                    )],
 237                ),
 238                message(
 239                    Assistant,
 240                    [tool_use(
 241                        "tool_4",
 242                        "edit_file",
 243                        StreamingEditFileToolInput {
 244                            display_description: edit_description.into(),
 245                            path: input_file_path.into(),
 246                            create_or_overwrite: false,
 247                        },
 248                    )],
 249                ),
 250            ],
 251            input_path: input_file_path.into(),
 252            input_content: Some(input_file_content.into()),
 253            edit_description: edit_description.into(),
 254            assertion: EvalAssertion::judge_diff(indoc! {"
 255                - The compile_parser_to_wasm method has been changed to use wasi-sdk
 256                - ureq is used to download the SDK for current platform and architecture
 257            "}),
 258        },
 259    );
 260}
 261
 262#[test]
 263#[cfg_attr(not(feature = "eval"), ignore)]
 264fn eval_disable_cursor_blinking() {
 265    let input_file_path = "root/editor.rs";
 266    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
 267    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
 268    let edit_description = "Comment out the call to `BlinkManager::enable`";
 269    eval(
 270        200,
 271        0.6, // TODO: make this eval better
 272        EvalInput {
 273            conversation: vec![
 274                message(User, [text("Let's research how to cursor blinking works.")]),
 275                message(
 276                    Assistant,
 277                    [tool_use(
 278                        "tool_1",
 279                        "grep",
 280                        GrepToolInput {
 281                            regex: "blink".into(),
 282                            include_pattern: None,
 283                            offset: 0,
 284                            case_sensitive: false,
 285                        },
 286                    )],
 287                ),
 288                message(
 289                    User,
 290                    [tool_result(
 291                        "tool_1",
 292                        "grep",
 293                        [
 294                            lines(input_file_content, 100..400),
 295                            lines(input_file_content, 800..1300),
 296                            lines(input_file_content, 1600..2000),
 297                            lines(input_file_content, 5000..5500),
 298                            lines(input_file_content, 8000..9000),
 299                            lines(input_file_content, 18455..18470),
 300                            lines(input_file_content, 20000..20500),
 301                            lines(input_file_content, 21000..21300),
 302                        ]
 303                        .join("Match found:\n\n"),
 304                    )],
 305                ),
 306                message(
 307                    User,
 308                    [text(indoc! {"
 309                        Comment out the lines that interact with the BlinkManager.
 310                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
 311                        Don't add additional comments.
 312                    "})],
 313                ),
 314                message(
 315                    Assistant,
 316                    [tool_use(
 317                        "tool_4",
 318                        "edit_file",
 319                        StreamingEditFileToolInput {
 320                            display_description: edit_description.into(),
 321                            path: input_file_path.into(),
 322                            create_or_overwrite: false,
 323                        },
 324                    )],
 325                ),
 326            ],
 327            input_path: input_file_path.into(),
 328            input_content: Some(input_file_content.into()),
 329            edit_description: edit_description.into(),
 330            assertion: EvalAssertion::assert_eq(output_file_content),
 331        },
 332    );
 333}
 334
 335#[test]
 336#[cfg_attr(not(feature = "eval"), ignore)]
 337fn eval_from_pixels_constructor() {
 338    let input_file_path = "root/canvas.rs";
 339    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
 340    let edit_description = "Implement from_pixels constructor and add tests.";
 341    eval(
 342        100,
 343        0.95,
 344        EvalInput {
 345            conversation: vec![
 346                message(
 347                    User,
 348                    [text(indoc! {"
 349                        Introduce a new `from_pixels` constructor in Canvas and
 350                        also add tests for it in the same file.
 351                    "})],
 352                ),
 353                message(
 354                    Assistant,
 355                    [tool_use(
 356                        "tool_1",
 357                        "read_file",
 358                        ReadFileToolInput {
 359                            path: input_file_path.into(),
 360                            start_line: None,
 361                            end_line: None,
 362                        },
 363                    )],
 364                ),
 365                message(
 366                    User,
 367                    [tool_result("tool_1", "read_file", input_file_content)],
 368                ),
 369                message(
 370                    Assistant,
 371                    [tool_use(
 372                        "tool_2",
 373                        "grep",
 374                        GrepToolInput {
 375                            regex: "mod\\s+tests".into(),
 376                            include_pattern: Some("font-kit/src/canvas.rs".into()),
 377                            offset: 0,
 378                            case_sensitive: false,
 379                        },
 380                    )],
 381                ),
 382                message(User, [tool_result("tool_2", "grep", "No matches found")]),
 383                message(
 384                    Assistant,
 385                    [tool_use(
 386                        "tool_3",
 387                        "grep",
 388                        GrepToolInput {
 389                            regex: "mod\\s+tests".into(),
 390                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 391                            offset: 0,
 392                            case_sensitive: false,
 393                        },
 394                    )],
 395                ),
 396                message(User, [tool_result("tool_3", "grep", "No matches found")]),
 397                message(
 398                    Assistant,
 399                    [tool_use(
 400                        "tool_4",
 401                        "grep",
 402                        GrepToolInput {
 403                            regex: "#\\[test\\]".into(),
 404                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 405                            offset: 0,
 406                            case_sensitive: false,
 407                        },
 408                    )],
 409                ),
 410                message(
 411                    User,
 412                    [tool_result(
 413                        "tool_4",
 414                        "grep",
 415                        indoc! {"
 416                            Found 6 matches:
 417
 418                            ## Matches in font-kit/src/loaders/core_text.rs
 419
 420                            ### mod test › L926-936
 421                            ```
 422                            mod test {
 423                                use super::Font;
 424                                use crate::properties::{Stretch, Weight};
 425
 426                                #[cfg(feature = \"source\")]
 427                                use crate::source::SystemSource;
 428
 429                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
 430
 431                                #[cfg(feature = \"source\")]
 432                                #[test]
 433                            ```
 434
 435                            55 lines remaining in ancestor node. Read the file to see all.
 436
 437                            ### mod test › L947-951
 438                            ```
 439                                }
 440
 441                                #[test]
 442                                fn test_core_text_to_css_font_weight() {
 443                                    // Exact matches
 444                            ```
 445
 446                            ### mod test › L959-963
 447                            ```
 448                                }
 449
 450                                #[test]
 451                                fn test_core_text_to_css_font_stretch() {
 452                                    // Exact matches
 453                            ```
 454
 455                            ## Matches in font-kit/src/loaders/freetype.rs
 456
 457                            ### mod test › L1238-1248
 458                            ```
 459                            mod test {
 460                                use crate::loaders::freetype::Font;
 461
 462                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
 463                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
 464
 465                                #[test]
 466                                fn get_pcf_postscript_name() {
 467                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
 468                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
 469                                }
 470                            ```
 471
 472                            1 lines remaining in ancestor node. Read the file to see all.
 473
 474                            ## Matches in font-kit/src/sources/core_text.rs
 475
 476                            ### mod test › L265-275
 477                            ```
 478                            mod test {
 479                                use crate::properties::{Stretch, Weight};
 480
 481                                #[test]
 482                                fn test_css_to_core_text_font_weight() {
 483                                    // Exact matches
 484                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
 485                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
 486                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
 487                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
 488
 489                            ```
 490
 491                            27 lines remaining in ancestor node. Read the file to see all.
 492
 493                            ### mod test › L278-282
 494                            ```
 495                                }
 496
 497                                #[test]
 498                                fn test_css_to_core_text_font_stretch() {
 499                                    // Exact matches
 500                            ```
 501                        "},
 502                    )],
 503                ),
 504                message(
 505                    Assistant,
 506                    [tool_use(
 507                        "tool_5",
 508                        "edit_file",
 509                        StreamingEditFileToolInput {
 510                            display_description: edit_description.into(),
 511                            path: input_file_path.into(),
 512                            create_or_overwrite: false,
 513                        },
 514                    )],
 515                ),
 516            ],
 517            input_path: input_file_path.into(),
 518            input_content: Some(input_file_content.into()),
 519            edit_description: edit_description.into(),
 520            assertion: EvalAssertion::judge_diff(indoc! {"
 521                - The diff contains a new `from_pixels` constructor
 522                - The diff contains new tests for the `from_pixels` constructor
 523            "}),
 524        },
 525    );
 526}
 527
 528#[test]
 529#[cfg_attr(not(feature = "eval"), ignore)]
 530fn eval_zode() {
 531    let input_file_path = "root/zode.py";
 532    let edit_description = "Create the main Zode CLI script";
 533    eval(
 534        200,
 535        1.,
 536        EvalInput {
 537            conversation: vec![
 538                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
 539                message(
 540                    Assistant,
 541                    [
 542                        tool_use(
 543                            "tool_1",
 544                            "read_file",
 545                            ReadFileToolInput {
 546                                path: "root/eval/react.py".into(),
 547                                start_line: None,
 548                                end_line: None,
 549                            },
 550                        ),
 551                        tool_use(
 552                            "tool_2",
 553                            "read_file",
 554                            ReadFileToolInput {
 555                                path: "root/eval/react_test.py".into(),
 556                                start_line: None,
 557                                end_line: None,
 558                            },
 559                        ),
 560                    ],
 561                ),
 562                message(
 563                    User,
 564                    [
 565                        tool_result(
 566                            "tool_1",
 567                            "read_file",
 568                            include_str!("evals/fixtures/zode/react.py"),
 569                        ),
 570                        tool_result(
 571                            "tool_2",
 572                            "read_file",
 573                            include_str!("evals/fixtures/zode/react_test.py"),
 574                        ),
 575                    ],
 576                ),
 577                message(
 578                    Assistant,
 579                    [
 580                        text(
 581                            "Now that I understand what we need to build, I'll create the main Python script:",
 582                        ),
 583                        tool_use(
 584                            "tool_3",
 585                            "edit_file",
 586                            StreamingEditFileToolInput {
 587                                display_description: edit_description.into(),
 588                                path: input_file_path.into(),
 589                                create_or_overwrite: true,
 590                            },
 591                        ),
 592                    ],
 593                ),
 594            ],
 595            input_path: input_file_path.into(),
 596            input_content: None,
 597            edit_description: edit_description.into(),
 598            assertion: EvalAssertion::new(async move |sample, _, _cx| {
 599                let invalid_starts = [' ', '`', '\n'];
 600                let mut message = String::new();
 601                for start in invalid_starts {
 602                    if sample.text.starts_with(start) {
 603                        message.push_str(&format!("The sample starts with a {:?}\n", start));
 604                        break;
 605                    }
 606                }
 607                // Remove trailing newline.
 608                message.pop();
 609
 610                if message.is_empty() {
 611                    Ok(EvalAssertionOutcome {
 612                        score: 100,
 613                        message: None,
 614                    })
 615                } else {
 616                    Ok(EvalAssertionOutcome {
 617                        score: 0,
 618                        message: Some(message),
 619                    })
 620                }
 621            }),
 622        },
 623    );
 624}
 625
 626#[test]
 627#[cfg_attr(not(feature = "eval"), ignore)]
 628fn eval_add_overwrite_test() {
 629    let input_file_path = "root/action_log.rs";
 630    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
 631    let edit_description = "Add a new test for overwriting a file in action_log.rs";
 632    eval(
 633        200,
 634        0.5, // TODO: make this eval better
 635        EvalInput {
 636            conversation: vec![
 637                message(
 638                    User,
 639                    [text(indoc! {"
 640                        Introduce a new test in `action_log.rs` to test overwriting a file.
 641                        That is, a file already exists, but we call `buffer_created` as if the file were new.
 642                        Take inspiration from all the other tests in the file.
 643                    "})],
 644                ),
 645                message(
 646                    Assistant,
 647                    [tool_use(
 648                        "tool_1",
 649                        "read_file",
 650                        ReadFileToolInput {
 651                            path: input_file_path.into(),
 652                            start_line: None,
 653                            end_line: None,
 654                        },
 655                    )],
 656                ),
 657                message(
 658                    User,
 659                    [tool_result(
 660                        "tool_1",
 661                        "read_file",
 662                        indoc! {"
 663                            pub struct ActionLog [L13-20]
 664                             tracked_buffers [L15]
 665                             edited_since_project_diagnostics_check [L17]
 666                             project [L19]
 667                            impl ActionLog [L22-498]
 668                             pub fn new [L24-30]
 669                             pub fn project [L32-34]
 670                             pub fn checked_project_diagnostics [L37-39]
 671                             pub fn has_edited_files_since_project_diagnostics_check [L42-44]
 672                             fn track_buffer_internal [L46-101]
 673                             fn handle_buffer_event [L103-116]
 674                             fn handle_buffer_edited [L118-123]
 675                             fn handle_buffer_file_changed [L125-158]
 676                             async fn maintain_diff [L160-264]
 677                             pub fn buffer_read [L267-269]
 678                             pub fn buffer_created [L272-276]
 679                             pub fn buffer_edited [L279-287]
 680                             pub fn will_delete_buffer [L289-304]
 681                             pub fn keep_edits_in_range [L306-364]
 682                             pub fn reject_edits_in_ranges [L366-459]
 683                             pub fn keep_all_edits [L461-473]
 684                             pub fn changed_buffers [L476-482]
 685                             pub fn stale_buffers [L485-497]
 686                            fn apply_non_conflicting_edits [L500-561]
 687                            fn diff_snapshots [L563-585]
 688                            fn point_to_row_edit [L587-614]
 689                            enum ChangeAuthor [L617-620]
 690                             User [L618]
 691                             Agent [L619]
 692                            enum TrackedBufferStatus [L623-627]
 693                             Created [L624]
 694                             Modified [L625]
 695                             Deleted [L626]
 696                            struct TrackedBuffer [L629-641]
 697                             buffer [L630]
 698                             base_text [L631]
 699                             unreviewed_changes [L632]
 700                             status [L633]
 701                             version [L634]
 702                             diff [L635]
 703                             snapshot [L636]
 704                             diff_update [L637]
 705                             _open_lsp_handle [L638]
 706                             _maintain_diff [L639]
 707                             _subscription [L640]
 708                            impl TrackedBuffer [L643-657]
 709                             fn has_changes [L644-650]
 710                             fn schedule_diff_update [L652-656]
 711                            pub struct ChangedBuffer [L659-661]
 712                             pub diff [L660]
 713                            mod tests [L664-1574]
 714                             fn init_logger [L678-682]
 715                             fn init_test [L684-691]
 716                             async fn test_keep_edits [L694-769]
 717                             async fn test_deletions [L772-854]
 718                             async fn test_overlapping_user_edits [L857-951]
 719                             async fn test_creating_files [L954-1010]
 720                             async fn test_deleting_files [L1013-1120]
 721                             async fn test_reject_edits [L1123-1255]
 722                             async fn test_reject_multiple_edits [L1258-1331]
 723                             async fn test_reject_deleted_file [L1334-1388]
 724                             async fn test_reject_created_file [L1391-1443]
 725                             async fn test_random_diffs [L1446-1535]
 726                              fn quiesce [L1510-1534]
 727                             struct HunkStatus [L1538-1542]
 728                              range [L1539]
 729                              diff_status [L1540]
 730                              old_text [L1541]
 731                             fn unreviewed_hunks [L1544-1573]
 732
 733                            Showing symbols 1-69 (total symbols: 69)
 734
 735                            Using the line numbers in this outline, you can call this tool again while specifying
 736                            the start_line and end_line fields to see the implementations of symbols in the outline.
 737                        "},
 738                    )],
 739                ),
 740                message(
 741                    Assistant,
 742                    [
 743                        text(
 744                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
 745                        ),
 746                        tool_use(
 747                            "tool_2",
 748                            "read_file",
 749                            ReadFileToolInput {
 750                                path: input_file_path.into(),
 751                                start_line: Some(953),
 752                                end_line: Some(1010),
 753                            },
 754                        ),
 755                    ],
 756                ),
 757                message(
 758                    User,
 759                    [tool_result(
 760                        "tool_2",
 761                        "read_file",
 762                        lines(input_file_content, 953..1010),
 763                    )],
 764                ),
 765                message(
 766                    Assistant,
 767                    [
 768                        text(
 769                            "Now let's look at another relevant test to understand how they're structured:",
 770                        ),
 771                        tool_use(
 772                            "tool_3",
 773                            "read_file",
 774                            ReadFileToolInput {
 775                                path: input_file_path.into(),
 776                                start_line: Some(1012),
 777                                end_line: Some(1120),
 778                            },
 779                        ),
 780                    ],
 781                ),
 782                message(
 783                    User,
 784                    [tool_result(
 785                        "tool_3",
 786                        "read_file",
 787                        lines(input_file_content, 1012..1120),
 788                    )],
 789                ),
 790                message(
 791                    Assistant,
 792                    [
 793                        text("Now let's look at how `buffer_created` is implemented:"),
 794                        tool_use(
 795                            "tool_4",
 796                            "read_file",
 797                            ReadFileToolInput {
 798                                path: input_file_path.into(),
 799                                start_line: Some(271),
 800                                end_line: Some(276),
 801                            },
 802                        ),
 803                    ],
 804                ),
 805                message(
 806                    User,
 807                    [tool_result(
 808                        "tool_4",
 809                        "read_file",
 810                        lines(input_file_content, 271..276),
 811                    )],
 812                ),
 813                message(
 814                    Assistant,
 815                    [
 816                        text(indoc! {"
 817                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
 818
 819                            1. Create a file
 820                            2. Open a buffer for it
 821                            3. Track it in the action log
 822                            4. Modify the buffer
 823                            5. Call `buffer_created` again as if the file were new
 824                            6. Check that the changes are properly tracked
 825
 826                            Let's write the test:
 827                        "}),
 828                        tool_use(
 829                            "tool_5",
 830                            "edit_file",
 831                            StreamingEditFileToolInput {
 832                                display_description: edit_description.into(),
 833                                path: input_file_path.into(),
 834                                create_or_overwrite: false,
 835                            },
 836                        ),
 837                    ],
 838                ),
 839            ],
 840            input_path: input_file_path.into(),
 841            input_content: Some(input_file_content.into()),
 842            edit_description: edit_description.into(),
 843            assertion: EvalAssertion::judge_diff(
 844                "A new test for overwritten files was created, without changing any previous test",
 845            ),
 846        },
 847    );
 848}
 849
 850fn message(
 851    role: Role,
 852    contents: impl IntoIterator<Item = MessageContent>,
 853) -> LanguageModelRequestMessage {
 854    LanguageModelRequestMessage {
 855        role,
 856        content: contents.into_iter().collect(),
 857        cache: false,
 858    }
 859}
 860
 861fn text(text: impl Into<String>) -> MessageContent {
 862    MessageContent::Text(text.into())
 863}
 864
 865fn lines(input: &str, range: Range<usize>) -> String {
 866    input
 867        .lines()
 868        .skip(range.start)
 869        .take(range.len())
 870        .collect::<Vec<_>>()
 871        .join("\n")
 872}
 873
 874fn tool_use(
 875    id: impl Into<Arc<str>>,
 876    name: impl Into<Arc<str>>,
 877    input: impl Serialize,
 878) -> MessageContent {
 879    MessageContent::ToolUse(LanguageModelToolUse {
 880        id: LanguageModelToolUseId::from(id.into()),
 881        name: name.into(),
 882        raw_input: serde_json::to_string_pretty(&input).unwrap(),
 883        input: serde_json::to_value(input).unwrap(),
 884        is_input_complete: true,
 885    })
 886}
 887
 888fn tool_result(
 889    id: impl Into<Arc<str>>,
 890    name: impl Into<Arc<str>>,
 891    result: impl Into<Arc<str>>,
 892) -> MessageContent {
 893    MessageContent::ToolResult(LanguageModelToolResult {
 894        tool_use_id: LanguageModelToolUseId::from(id.into()),
 895        tool_name: name.into(),
 896        is_error: false,
 897        content: result.into(),
 898    })
 899}
 900
/// One eval scenario: the conversation given to the edit agent, the file it works on, and
/// the assertion used to score the result.
#[derive(Clone)]
struct EvalInput {
    /// Messages passed to the edit agent together with `edit_description`.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Path of the file to edit, resolved against the test project.
    input_path: PathBuf,
    /// Initial buffer contents. `Some` sets the buffer text and runs the agent's `edit`;
    /// `None` runs the agent's `overwrite` instead.
    input_content: Option<String>,
    /// Description of the requested edit, passed to the agent.
    edit_description: String,
    /// Assertion that scores the produced sample.
    assertion: EvalAssertion,
}
 909
/// What a single agent run produced.
#[derive(Clone)]
struct EvalSample {
    /// Full buffer text after the agent finished.
    text: String,
    /// Agent output, including its parser metrics and raw edits.
    edit_output: EditAgentOutput,
    /// Unified diff from the input content (empty when there was none) to `text`.
    diff: String,
}
 916
/// Object-safe form of an async assertion, so assertions can be stored type-erased as
/// `Arc<dyn AssertionFn>`.
trait AssertionFn: 'static + Send + Sync {
    /// Scores `sample`, optionally consulting `judge_model`.
    ///
    /// The returned future may borrow `self`, `sample`, and `cx` for `'a`.
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
 925
/// Blanket impl: any `'static + Send + Sync` async closure with the assertion signature
/// can be used as an `AssertionFn`.
impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        // Invoke the closure and box its future as a `LocalBoxFuture` (no `Send` bound)
        // so it fits the object-safe trait signature.
        (self)(sample, judge_model, cx).boxed_local()
    }
}
 946
/// A type-erased assertion applied to each `EvalSample`. Cloning only bumps the `Arc`
/// refcount.
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
 949
 950impl EvalAssertion {
 951    fn new<F>(f: F) -> Self
 952    where
 953        F: 'static
 954            + Send
 955            + Sync
 956            + AsyncFn(
 957                &EvalSample,
 958                Arc<dyn LanguageModel>,
 959                &mut TestAppContext,
 960            ) -> Result<EvalAssertionOutcome>,
 961    {
 962        EvalAssertion(Arc::new(f))
 963    }
 964
 965    fn assert_eq(expected: impl Into<String>) -> Self {
 966        let expected = expected.into();
 967        Self::new(async move |sample, _judge, _cx| {
 968            Ok(EvalAssertionOutcome {
 969                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
 970                    100
 971                } else {
 972                    0
 973                },
 974                message: None,
 975            })
 976        })
 977    }
 978
 979    fn judge_diff(assertions: &'static str) -> Self {
 980        Self::new(async move |sample, judge, cx| {
 981            let prompt = DiffJudgeTemplate {
 982                diff: sample.diff.clone(),
 983                assertions,
 984            }
 985            .render(&Templates::new())
 986            .unwrap();
 987
 988            let request = LanguageModelRequest {
 989                messages: vec![LanguageModelRequestMessage {
 990                    role: Role::User,
 991                    content: vec![prompt.into()],
 992                    cache: false,
 993                }],
 994                ..Default::default()
 995            };
 996            let mut response = judge
 997                .stream_completion_text(request, &cx.to_async())
 998                .await?;
 999            let mut output = String::new();
1000            while let Some(chunk) = response.stream.next().await {
1001                let chunk = chunk?;
1002                output.push_str(&chunk);
1003            }
1004
1005            // Parse the score from the response
1006            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
1007            if let Some(captures) = re.captures(&output) {
1008                if let Some(score_match) = captures.get(1) {
1009                    let score = score_match.as_str().parse().unwrap_or(0);
1010                    return Ok(EvalAssertionOutcome {
1011                        score,
1012                        message: Some(output),
1013                    });
1014                }
1015            }
1016
1017            Err(anyhow!(
1018                "No score found in response. Raw output: {}",
1019                output
1020            ))
1021        })
1022    }
1023
1024    async fn run(
1025        &self,
1026        input: &EvalSample,
1027        judge_model: Arc<dyn LanguageModel>,
1028        cx: &mut TestAppContext,
1029    ) -> Result<EvalAssertionOutcome> {
1030        self.0.assert(input, judge_model, cx).await
1031    }
1032}
1033
1034fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
1035    let mut evaluated_count = 0;
1036    report_progress(evaluated_count, iterations);
1037
1038    let (tx, rx) = mpsc::channel();
1039
1040    // Cache the last message in the conversation, and run one instance of the eval so that
1041    // all the next ones are cached.
1042    eval.conversation.last_mut().unwrap().cache = true;
1043    run_eval(eval.clone(), tx.clone());
1044
1045    let executor = gpui::background_executor();
1046    for _ in 1..iterations {
1047        let eval = eval.clone();
1048        let tx = tx.clone();
1049        executor.spawn(async move { run_eval(eval, tx) }).detach();
1050    }
1051    drop(tx);
1052
1053    let mut failed_count = 0;
1054    let mut failed_evals = HashMap::default();
1055    let mut errored_evals = HashMap::default();
1056    let mut eval_outputs = Vec::new();
1057    let mut cumulative_parser_metrics = EditParserMetrics::default();
1058    while let Ok(output) = rx.recv() {
1059        match output {
1060            Ok(output) => {
1061                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
1062                eval_outputs.push(output.clone());
1063                if output.assertion.score < 80 {
1064                    failed_count += 1;
1065                    failed_evals
1066                        .entry(output.sample.text.clone())
1067                        .or_insert(Vec::new())
1068                        .push(output);
1069                }
1070            }
1071            Err(error) => {
1072                failed_count += 1;
1073                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
1074            }
1075        }
1076
1077        evaluated_count += 1;
1078        report_progress(evaluated_count, iterations);
1079    }
1080
1081    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
1082    println!("Actual pass ratio: {}\n", actual_pass_ratio);
1083    if actual_pass_ratio < expected_pass_ratio {
1084        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
1085        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
1086        for (error, count) in errored_evals {
1087            println!("Eval errored {} times. Error: {}", count, error);
1088        }
1089
1090        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
1091        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
1092        for (_buffer_output, failed_evals) in failed_evals {
1093            let eval_output = failed_evals.first().unwrap();
1094            println!("Eval failed {} times", failed_evals.len());
1095            println!("{}", eval_output);
1096        }
1097
1098        panic!(
1099            "Actual pass ratio: {}\nExpected pass ratio: {}",
1100            actual_pass_ratio, expected_pass_ratio
1101        );
1102    }
1103
1104    let mismatched_tag_ratio =
1105        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
1106    if mismatched_tag_ratio > 0.05 {
1107        for eval_output in eval_outputs {
1108            println!("{}", eval_output);
1109        }
1110        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
1111    }
1112}
1113
1114fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1115    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1116    let mut cx = TestAppContext::build(dispatcher, None);
1117    let output = cx.executor().block_test(async {
1118        let test = EditAgentTest::new(&mut cx).await;
1119        test.eval(eval, &mut cx).await
1120    });
1121    tx.send(output).unwrap();
1122}
1123
/// A scored sample, as sent back over the channel by each eval run.
#[derive(Clone)]
struct EvalOutput {
    /// What the agent produced.
    sample: EvalSample,
    /// How the assertion scored it.
    assertion: EvalAssertionOutcome,
}
1129
1130impl Display for EvalOutput {
1131    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1132        writeln!(f, "Score: {:?}", self.assertion.score)?;
1133        if let Some(message) = self.assertion.message.as_ref() {
1134            writeln!(f, "Message: {}", message)?;
1135        }
1136
1137        writeln!(f, "Diff:\n{}", self.sample.diff)?;
1138
1139        writeln!(
1140            f,
1141            "Parser Metrics:\n{:#?}",
1142            self.sample.edit_output._parser_metrics
1143        )?;
1144        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
1145        Ok(())
1146    }
1147}
1148
/// Redraws the progress counter in place and flushes stdout so it appears immediately.
fn report_progress(evaluated_count: usize, iterations: usize) {
    // `\r` returns to column 0 and `\x1b[K` clears the rest of the line, so each call
    // overwrites the previous counter instead of appending a new line.
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    let mut stdout = std::io::stdout();
    stdout.flush().unwrap();
}
1153
/// Test harness holding an edit agent, the project it edits, and a judge model.
struct EditAgentTest {
    /// Agent under evaluation.
    agent: EditAgent,
    /// Project containing the files the agent edits.
    project: Entity<Project>,
    /// Model handed to assertions, e.g. to grade diffs.
    judge_model: Arc<dyn LanguageModel>,
}
1159
1160impl EditAgentTest {
1161    async fn new(cx: &mut TestAppContext) -> Self {
1162        cx.executor().allow_parking();
1163        cx.update(settings::init);
1164        cx.update(Project::init_settings);
1165        cx.update(language::init);
1166        cx.update(gpui_tokio::init);
1167        cx.update(client::init_settings);
1168
1169        let fs = FakeFs::new(cx.executor().clone());
1170        fs.insert_tree("/root", json!({})).await;
1171        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1172        let (agent_model, judge_model) = cx
1173            .update(|cx| {
1174                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1175                cx.set_http_client(Arc::new(http_client));
1176
1177                let client = Client::production(cx);
1178                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1179                language_model::init(client.clone(), cx);
1180                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
1181
1182                cx.spawn(async move |cx| {
1183                    let agent_model =
1184                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1185                    let judge_model =
1186                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1187                    (agent_model.unwrap(), judge_model.unwrap())
1188                })
1189            })
1190            .await;
1191        let action_log = cx.new(|_| ActionLog::new(project.clone()));
1192
1193        Self {
1194            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
1195            project,
1196            judge_model,
1197        }
1198    }
1199
1200    async fn load_model(
1201        provider: &str,
1202        id: &str,
1203        cx: &mut AsyncApp,
1204    ) -> Result<Arc<dyn LanguageModel>> {
1205        let (provider, model) = cx.update(|cx| {
1206            let models = LanguageModelRegistry::read_global(cx);
1207            let model = models
1208                .available_models(cx)
1209                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
1210                .unwrap();
1211            let provider = models.provider(&model.provider_id()).unwrap();
1212            (provider, model)
1213        })?;
1214        cx.update(|cx| provider.authenticate(cx))?.await?;
1215        Ok(model)
1216    }
1217
1218    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
1219        let path = self
1220            .project
1221            .read_with(cx, |project, cx| {
1222                project.find_project_path(eval.input_path, cx)
1223            })
1224            .unwrap();
1225        let buffer = self
1226            .project
1227            .update(cx, |project, cx| project.open_buffer(path, cx))
1228            .await
1229            .unwrap();
1230        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
1231            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
1232            let (edit_output, _) = self.agent.edit(
1233                buffer.clone(),
1234                eval.edit_description,
1235                eval.conversation,
1236                &mut cx.to_async(),
1237            );
1238            edit_output.await?
1239        } else {
1240            let (edit_output, _) = self.agent.overwrite(
1241                buffer.clone(),
1242                eval.edit_description,
1243                eval.conversation,
1244                &mut cx.to_async(),
1245            );
1246            edit_output.await?
1247        };
1248
1249        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
1250        let sample = EvalSample {
1251            edit_output,
1252            diff: language::unified_diff(
1253                eval.input_content.as_deref().unwrap_or_default(),
1254                &buffer_text,
1255            ),
1256            text: buffer_text,
1257        };
1258        let assertion = eval
1259            .assertion
1260            .run(&sample, self.judge_model.clone(), cx)
1261            .await?;
1262
1263        Ok(EvalOutput { assertion, sample })
1264    }
1265}
1266
/// Result of applying an assertion: a score plus optional explanatory text.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    /// Score for the sample. `eval` treats anything below 80 as a failure, and `assert_eq`
    /// yields either 0 or 100.
    score: usize,
    /// Extra context, such as the judge model's raw reply.
    message: Option<String>,
}
1272
/// Serializable context rendered into the `diff_judge.hbs` prompt for the judge model.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    /// Unified diff the judge should grade.
    diff: String,
    /// Description of what the diff is expected to accomplish.
    assertions: &'static str,
}
1278
/// Associates `DiffJudgeTemplate` with the template named `diff_judge.hbs`.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1282
/// Drops every line that is empty or whitespace-only and rejoins the rest with `\n`.
///
/// Line endings are normalized, so a trailing newline is not preserved.
fn strip_empty_lines(text: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for line in text.lines() {
        // A line is kept as soon as it contains a single non-whitespace character.
        if line.chars().any(|c| !c.is_whitespace()) {
            kept.push(line);
        }
    }
    kept.join("\n")
}