evals.rs

   1use super::*;
   2use crate::{
   3    ReadFileToolInput, grep_tool::GrepToolInput,
   4    streaming_edit_file_tool::StreamingEditFileToolInput,
   5};
   6use Role::*;
   7use anyhow::anyhow;
   8use client::{Client, UserStore};
   9use collections::HashMap;
  10use fs::FakeFs;
  11use futures::{FutureExt, future::LocalBoxFuture};
  12use gpui::{AppContext, TestAppContext};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
  16};
  17use project::Project;
  18use rand::prelude::*;
  19use reqwest_client::ReqwestClient;
  20use serde_json::json;
  21use std::{
  22    cmp::Reverse,
  23    fmt::{self, Display},
  24    io::Write as _,
  25    sync::mpsc,
  26};
  27use util::path;
  28
  29#[test]
  30#[cfg_attr(not(feature = "eval"), ignore)]
  31fn eval_extract_handle_command_output() {
  32    let input_file_path = "root/blame.rs";
  33    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
  34    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
  35    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
  36    eval(
  37        100,
  38        0.95,
  39        EvalInput {
  40            conversation: vec![
  41                message(
  42                    User,
  43                    [text(indoc! {"
  44                        Read the `{input_file_path}` file and extract a method in
  45                        the final stanza of `run_git_blame` to deal with command failures,
  46                        call it `handle_command_output` and take the std::process::Output as the only parameter.
  47
  48                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
  49                    "})],
  50                ),
  51                message(
  52                    Assistant,
  53                    [tool_use(
  54                        "tool_1",
  55                        "read_file",
  56                        ReadFileToolInput {
  57                            path: input_file_path.into(),
  58                            start_line: None,
  59                            end_line: None,
  60                        },
  61                    )],
  62                ),
  63                message(
  64                    User,
  65                    [tool_result("tool_1", "read_file", input_file_content)],
  66                ),
  67                message(
  68                    Assistant,
  69                    [tool_use(
  70                        "tool_2",
  71                        "edit_file",
  72                        StreamingEditFileToolInput {
  73                            display_description: edit_description.into(),
  74                            path: input_file_path.into(),
  75                            create_or_overwrite: false,
  76                        },
  77                    )],
  78                ),
  79            ],
  80            input_path: input_file_path.into(),
  81            input_content: Some(input_file_content.into()),
  82            edit_description: edit_description.into(),
  83            assertion: EvalAssertion::assert_eq(output_file_content),
  84        },
  85    );
  86}
  87
  88#[test]
  89#[cfg_attr(not(feature = "eval"), ignore)]
  90fn eval_delete_run_git_blame() {
  91    let input_file_path = "root/blame.rs";
  92    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
  93    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
  94    let edit_description = "Delete the `run_git_blame` function.";
  95    eval(
  96        100,
  97        0.95,
  98        EvalInput {
  99            conversation: vec![
 100                message(
 101                    User,
 102                    [text(indoc! {"
 103                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
 104                        one function, not its usages.
 105                    "})],
 106                ),
 107                message(
 108                    Assistant,
 109                    [tool_use(
 110                        "tool_1",
 111                        "read_file",
 112                        ReadFileToolInput {
 113                            path: input_file_path.into(),
 114                            start_line: None,
 115                            end_line: None,
 116                        },
 117                    )],
 118                ),
 119                message(
 120                    User,
 121                    [tool_result("tool_1", "read_file", input_file_content)],
 122                ),
 123                message(
 124                    Assistant,
 125                    [tool_use(
 126                        "tool_2",
 127                        "edit_file",
 128                        StreamingEditFileToolInput {
 129                            display_description: edit_description.into(),
 130                            path: input_file_path.into(),
 131                            create_or_overwrite: false,
 132                        },
 133                    )],
 134                ),
 135            ],
 136            input_path: input_file_path.into(),
 137            input_content: Some(input_file_content.into()),
 138            edit_description: edit_description.into(),
 139            assertion: EvalAssertion::assert_eq(output_file_content),
 140        },
 141    );
 142}
 143
 144#[test]
 145#[cfg_attr(not(feature = "eval"), ignore)]
 146fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 147    let input_file_path = "root/lib.rs";
 148    let input_file_content =
 149        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
 150    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
 151    eval(
 152        100,
 153        0.95,
 154        EvalInput {
 155            conversation: vec![
 156                message(
 157                    User,
 158                    [text(indoc! {"
 159                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
 160                        Use `ureq` to download the SDK for the current platform and architecture.
 161                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
 162                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
 163                        that's inside of the archive.
 164                        Don't re-download the SDK if that executable already exists.
 165
 166                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
 167
 168                        Here are the available wasi-sdk assets:
 169                        - wasi-sdk-25.0-x86_64-macos.tar.gz
 170                        - wasi-sdk-25.0-arm64-macos.tar.gz
 171                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 172                        - wasi-sdk-25.0-arm64-linux.tar.gz
 173                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 174                        - wasi-sdk-25.0-arm64-linux.tar.gz
 175                        - wasi-sdk-25.0-x86_64-windows.tar.gz
 176                    "})],
 177                ),
 178                message(
 179                    Assistant,
 180                    [tool_use(
 181                        "tool_1",
 182                        "read_file",
 183                        ReadFileToolInput {
 184                            path: input_file_path.into(),
 185                            start_line: Some(971),
 186                            end_line: Some(1050),
 187                        },
 188                    )],
 189                ),
 190                message(
 191                    User,
 192                    [tool_result(
 193                        "tool_1",
 194                        "read_file",
 195                        lines(input_file_content, 971..1050),
 196                    )],
 197                ),
 198                message(
 199                    Assistant,
 200                    [tool_use(
 201                        "tool_2",
 202                        "read_file",
 203                        ReadFileToolInput {
 204                            path: input_file_path.into(),
 205                            start_line: Some(1050),
 206                            end_line: Some(1100),
 207                        },
 208                    )],
 209                ),
 210                message(
 211                    User,
 212                    [tool_result(
 213                        "tool_2",
 214                        "read_file",
 215                        lines(input_file_content, 1050..1100),
 216                    )],
 217                ),
 218                message(
 219                    Assistant,
 220                    [tool_use(
 221                        "tool_3",
 222                        "read_file",
 223                        ReadFileToolInput {
 224                            path: input_file_path.into(),
 225                            start_line: Some(1100),
 226                            end_line: Some(1150),
 227                        },
 228                    )],
 229                ),
 230                message(
 231                    User,
 232                    [tool_result(
 233                        "tool_3",
 234                        "read_file",
 235                        lines(input_file_content, 1100..1150),
 236                    )],
 237                ),
 238                message(
 239                    Assistant,
 240                    [tool_use(
 241                        "tool_4",
 242                        "edit_file",
 243                        StreamingEditFileToolInput {
 244                            display_description: edit_description.into(),
 245                            path: input_file_path.into(),
 246                            create_or_overwrite: false,
 247                        },
 248                    )],
 249                ),
 250            ],
 251            input_path: input_file_path.into(),
 252            input_content: Some(input_file_content.into()),
 253            edit_description: edit_description.into(),
 254            assertion: EvalAssertion::judge_diff(indoc! {"
 255                - The compile_parser_to_wasm method has been changed to use wasi-sdk
 256                - ureq is used to download the SDK for current platform and architecture
 257            "}),
 258        },
 259    );
 260}
 261
 262#[test]
 263#[cfg_attr(not(feature = "eval"), ignore)]
 264fn eval_disable_cursor_blinking() {
 265    let input_file_path = "root/editor.rs";
 266    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
 267    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
 268    let edit_description = "Comment out the call to `BlinkManager::enable`";
 269    eval(
 270        200,
 271        0.6, // TODO: make this eval better
 272        EvalInput {
 273            conversation: vec![
 274                message(User, [text("Let's research how to cursor blinking works.")]),
 275                message(
 276                    Assistant,
 277                    [tool_use(
 278                        "tool_1",
 279                        "grep",
 280                        GrepToolInput {
 281                            regex: "blink".into(),
 282                            include_pattern: None,
 283                            offset: 0,
 284                            case_sensitive: false,
 285                        },
 286                    )],
 287                ),
 288                message(
 289                    User,
 290                    [tool_result(
 291                        "tool_1",
 292                        "grep",
 293                        [
 294                            lines(input_file_content, 100..400),
 295                            lines(input_file_content, 800..1300),
 296                            lines(input_file_content, 1600..2000),
 297                            lines(input_file_content, 5000..5500),
 298                            lines(input_file_content, 8000..9000),
 299                            lines(input_file_content, 18455..18470),
 300                            lines(input_file_content, 20000..20500),
 301                            lines(input_file_content, 21000..21300),
 302                        ]
 303                        .join("Match found:\n\n"),
 304                    )],
 305                ),
 306                message(
 307                    User,
 308                    [text(indoc! {"
 309                        Comment out the lines that interact with the BlinkManager.
 310                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
 311                        Don't add additional comments.
 312                    "})],
 313                ),
 314                message(
 315                    Assistant,
 316                    [tool_use(
 317                        "tool_4",
 318                        "edit_file",
 319                        StreamingEditFileToolInput {
 320                            display_description: edit_description.into(),
 321                            path: input_file_path.into(),
 322                            create_or_overwrite: false,
 323                        },
 324                    )],
 325                ),
 326            ],
 327            input_path: input_file_path.into(),
 328            input_content: Some(input_file_content.into()),
 329            edit_description: edit_description.into(),
 330            assertion: EvalAssertion::assert_eq(output_file_content),
 331        },
 332    );
 333}
 334
 335#[test]
 336#[cfg_attr(not(feature = "eval"), ignore)]
 337fn eval_from_pixels_constructor() {
 338    let input_file_path = "root/canvas.rs";
 339    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
 340    let edit_description = "Implement from_pixels constructor and add tests.";
 341    eval(
 342        100,
 343        0.95,
 344        EvalInput {
 345            conversation: vec![
 346                message(
 347                    User,
 348                    [text(indoc! {"
 349                        Introduce a new `from_pixels` constructor in Canvas and
 350                        also add tests for it in the same file.
 351                    "})],
 352                ),
 353                message(
 354                    Assistant,
 355                    [tool_use(
 356                        "tool_1",
 357                        "read_file",
 358                        ReadFileToolInput {
 359                            path: input_file_path.into(),
 360                            start_line: None,
 361                            end_line: None,
 362                        },
 363                    )],
 364                ),
 365                message(
 366                    User,
 367                    [tool_result("tool_1", "read_file", input_file_content)],
 368                ),
 369                message(
 370                    Assistant,
 371                    [tool_use(
 372                        "tool_2",
 373                        "grep",
 374                        GrepToolInput {
 375                            regex: "mod\\s+tests".into(),
 376                            include_pattern: Some("font-kit/src/canvas.rs".into()),
 377                            offset: 0,
 378                            case_sensitive: false,
 379                        },
 380                    )],
 381                ),
 382                message(User, [tool_result("tool_2", "grep", "No matches found")]),
 383                message(
 384                    Assistant,
 385                    [tool_use(
 386                        "tool_3",
 387                        "grep",
 388                        GrepToolInput {
 389                            regex: "mod\\s+tests".into(),
 390                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 391                            offset: 0,
 392                            case_sensitive: false,
 393                        },
 394                    )],
 395                ),
 396                message(User, [tool_result("tool_3", "grep", "No matches found")]),
 397                message(
 398                    Assistant,
 399                    [tool_use(
 400                        "tool_4",
 401                        "grep",
 402                        GrepToolInput {
 403                            regex: "#\\[test\\]".into(),
 404                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 405                            offset: 0,
 406                            case_sensitive: false,
 407                        },
 408                    )],
 409                ),
 410                message(
 411                    User,
 412                    [tool_result(
 413                        "tool_4",
 414                        "grep",
 415                        indoc! {"
 416                            Found 6 matches:
 417
 418                            ## Matches in font-kit/src/loaders/core_text.rs
 419
 420                            ### mod test › L926-936
 421                            ```
 422                            mod test {
 423                                use super::Font;
 424                                use crate::properties::{Stretch, Weight};
 425
 426                                #[cfg(feature = \"source\")]
 427                                use crate::source::SystemSource;
 428
 429                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
 430
 431                                #[cfg(feature = \"source\")]
 432                                #[test]
 433                            ```
 434
 435                            55 lines remaining in ancestor node. Read the file to see all.
 436
 437                            ### mod test › L947-951
 438                            ```
 439                                }
 440
 441                                #[test]
 442                                fn test_core_text_to_css_font_weight() {
 443                                    // Exact matches
 444                            ```
 445
 446                            ### mod test › L959-963
 447                            ```
 448                                }
 449
 450                                #[test]
 451                                fn test_core_text_to_css_font_stretch() {
 452                                    // Exact matches
 453                            ```
 454
 455                            ## Matches in font-kit/src/loaders/freetype.rs
 456
 457                            ### mod test › L1238-1248
 458                            ```
 459                            mod test {
 460                                use crate::loaders::freetype::Font;
 461
 462                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
 463                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
 464
 465                                #[test]
 466                                fn get_pcf_postscript_name() {
 467                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
 468                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
 469                                }
 470                            ```
 471
 472                            1 lines remaining in ancestor node. Read the file to see all.
 473
 474                            ## Matches in font-kit/src/sources/core_text.rs
 475
 476                            ### mod test › L265-275
 477                            ```
 478                            mod test {
 479                                use crate::properties::{Stretch, Weight};
 480
 481                                #[test]
 482                                fn test_css_to_core_text_font_weight() {
 483                                    // Exact matches
 484                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
 485                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
 486                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
 487                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
 488
 489                            ```
 490
 491                            27 lines remaining in ancestor node. Read the file to see all.
 492
 493                            ### mod test › L278-282
 494                            ```
 495                                }
 496
 497                                #[test]
 498                                fn test_css_to_core_text_font_stretch() {
 499                                    // Exact matches
 500                            ```
 501                        "},
 502                    )],
 503                ),
 504                message(
 505                    Assistant,
 506                    [tool_use(
 507                        "tool_5",
 508                        "edit_file",
 509                        StreamingEditFileToolInput {
 510                            display_description: edit_description.into(),
 511                            path: input_file_path.into(),
 512                            create_or_overwrite: false,
 513                        },
 514                    )],
 515                ),
 516            ],
 517            input_path: input_file_path.into(),
 518            input_content: Some(input_file_content.into()),
 519            edit_description: edit_description.into(),
 520            assertion: EvalAssertion::judge_diff(indoc! {"
 521                - The diff contains a new `from_pixels` constructor
 522                - The diff contains new tests for the `from_pixels` constructor
 523            "}),
 524        },
 525    );
 526}
 527
 528#[test]
 529#[cfg_attr(not(feature = "eval"), ignore)]
 530fn eval_zode() {
 531    let input_file_path = "root/zode.py";
 532    let edit_description = "Create the main Zode CLI script";
 533    eval(
 534        200,
 535        1.,
 536        EvalInput {
 537            conversation: vec![
 538                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
 539                message(
 540                    Assistant,
 541                    [
 542                        tool_use(
 543                            "tool_1",
 544                            "read_file",
 545                            ReadFileToolInput {
 546                                path: "root/eval/react.py".into(),
 547                                start_line: None,
 548                                end_line: None,
 549                            },
 550                        ),
 551                        tool_use(
 552                            "tool_2",
 553                            "read_file",
 554                            ReadFileToolInput {
 555                                path: "root/eval/react_test.py".into(),
 556                                start_line: None,
 557                                end_line: None,
 558                            },
 559                        ),
 560                    ],
 561                ),
 562                message(
 563                    User,
 564                    [
 565                        tool_result(
 566                            "tool_1",
 567                            "read_file",
 568                            include_str!("evals/fixtures/zode/react.py"),
 569                        ),
 570                        tool_result(
 571                            "tool_2",
 572                            "read_file",
 573                            include_str!("evals/fixtures/zode/react_test.py"),
 574                        ),
 575                    ],
 576                ),
 577                message(
 578                    Assistant,
 579                    [
 580                        text(
 581                            "Now that I understand what we need to build, I'll create the main Python script:",
 582                        ),
 583                        tool_use(
 584                            "tool_3",
 585                            "edit_file",
 586                            StreamingEditFileToolInput {
 587                                display_description: edit_description.into(),
 588                                path: input_file_path.into(),
 589                                create_or_overwrite: true,
 590                            },
 591                        ),
 592                    ],
 593                ),
 594            ],
 595            input_path: input_file_path.into(),
 596            input_content: None,
 597            edit_description: edit_description.into(),
 598            assertion: EvalAssertion::new(async move |sample, _, _cx| {
 599                let invalid_starts = [' ', '`', '\n'];
 600                let mut message = String::new();
 601                for start in invalid_starts {
 602                    if sample.text.starts_with(start) {
 603                        message.push_str(&format!("The sample starts with a {:?}\n", start));
 604                        break;
 605                    }
 606                }
 607                // Remove trailing newline.
 608                message.pop();
 609
 610                if message.is_empty() {
 611                    Ok(EvalAssertionOutcome {
 612                        score: 100,
 613                        message: None,
 614                    })
 615                } else {
 616                    Ok(EvalAssertionOutcome {
 617                        score: 0,
 618                        message: Some(message),
 619                    })
 620                }
 621            }),
 622        },
 623    );
 624}
 625
 626#[test]
 627#[cfg_attr(not(feature = "eval"), ignore)]
 628fn eval_add_overwrite_test() {
 629    let input_file_path = "root/action_log.rs";
 630    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
 631    let edit_description = "Add a new test for overwriting a file in action_log.rs";
 632    eval(
 633        200,
 634        0.5, // TODO: make this eval better
 635        EvalInput {
 636            conversation: vec![
 637                message(
 638                    User,
 639                    [text(indoc! {"
 640                        Introduce a new test in `action_log.rs` to test overwriting a file.
 641                        That is, a file already exists, but we call `buffer_created` as if the file were new.
 642                        Take inspiration from all the other tests in the file.
 643                    "})],
 644                ),
 645                message(
 646                    Assistant,
 647                    [tool_use(
 648                        "tool_1",
 649                        "read_file",
 650                        ReadFileToolInput {
 651                            path: input_file_path.into(),
 652                            start_line: None,
 653                            end_line: None,
 654                        },
 655                    )],
 656                ),
 657                message(
 658                    User,
 659                    [tool_result(
 660                        "tool_1",
 661                        "read_file",
 662                        indoc! {"
 663                            pub struct ActionLog [L13-20]
 664                             tracked_buffers [L15]
 665                             edited_since_project_diagnostics_check [L17]
 666                             project [L19]
 667                            impl ActionLog [L22-498]
 668                             pub fn new [L24-30]
 669                             pub fn project [L32-34]
 670                             pub fn checked_project_diagnostics [L37-39]
 671                             pub fn has_edited_files_since_project_diagnostics_check [L42-44]
 672                             fn track_buffer_internal [L46-101]
 673                             fn handle_buffer_event [L103-116]
 674                             fn handle_buffer_edited [L118-123]
 675                             fn handle_buffer_file_changed [L125-158]
 676                             async fn maintain_diff [L160-264]
 677                             pub fn buffer_read [L267-269]
 678                             pub fn buffer_created [L272-276]
 679                             pub fn buffer_edited [L279-287]
 680                             pub fn will_delete_buffer [L289-304]
 681                             pub fn keep_edits_in_range [L306-364]
 682                             pub fn reject_edits_in_ranges [L366-459]
 683                             pub fn keep_all_edits [L461-473]
 684                             pub fn changed_buffers [L476-482]
 685                             pub fn stale_buffers [L485-497]
 686                            fn apply_non_conflicting_edits [L500-561]
 687                            fn diff_snapshots [L563-585]
 688                            fn point_to_row_edit [L587-614]
 689                            enum ChangeAuthor [L617-620]
 690                             User [L618]
 691                             Agent [L619]
 692                            enum TrackedBufferStatus [L623-627]
 693                             Created [L624]
 694                             Modified [L625]
 695                             Deleted [L626]
 696                            struct TrackedBuffer [L629-641]
 697                             buffer [L630]
 698                             base_text [L631]
 699                             unreviewed_changes [L632]
 700                             status [L633]
 701                             version [L634]
 702                             diff [L635]
 703                             snapshot [L636]
 704                             diff_update [L637]
 705                             _open_lsp_handle [L638]
 706                             _maintain_diff [L639]
 707                             _subscription [L640]
 708                            impl TrackedBuffer [L643-657]
 709                             fn has_changes [L644-650]
 710                             fn schedule_diff_update [L652-656]
 711                            pub struct ChangedBuffer [L659-661]
 712                             pub diff [L660]
 713                            mod tests [L664-1574]
 714                             fn init_logger [L678-682]
 715                             fn init_test [L684-691]
 716                             async fn test_keep_edits [L694-769]
 717                             async fn test_deletions [L772-854]
 718                             async fn test_overlapping_user_edits [L857-951]
 719                             async fn test_creating_files [L954-1010]
 720                             async fn test_deleting_files [L1013-1120]
 721                             async fn test_reject_edits [L1123-1255]
 722                             async fn test_reject_multiple_edits [L1258-1331]
 723                             async fn test_reject_deleted_file [L1334-1388]
 724                             async fn test_reject_created_file [L1391-1443]
 725                             async fn test_random_diffs [L1446-1535]
 726                              fn quiesce [L1510-1534]
 727                             struct HunkStatus [L1538-1542]
 728                              range [L1539]
 729                              diff_status [L1540]
 730                              old_text [L1541]
 731                             fn unreviewed_hunks [L1544-1573]
 732
 733                            Showing symbols 1-69 (total symbols: 69)
 734
 735                            Using the line numbers in this outline, you can call this tool again while specifying
 736                            the start_line and end_line fields to see the implementations of symbols in the outline.
 737                        "},
 738                    )],
 739                ),
 740                message(
 741                    Assistant,
 742                    [
 743                        text(
 744                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
 745                        ),
 746                        tool_use(
 747                            "tool_2",
 748                            "read_file",
 749                            ReadFileToolInput {
 750                                path: input_file_path.into(),
 751                                start_line: Some(953),
 752                                end_line: Some(1010),
 753                            },
 754                        ),
 755                    ],
 756                ),
 757                message(
 758                    User,
 759                    [tool_result(
 760                        "tool_2",
 761                        "read_file",
 762                        lines(input_file_content, 953..1010),
 763                    )],
 764                ),
 765                message(
 766                    Assistant,
 767                    [
 768                        text(
 769                            "Now let's look at another relevant test to understand how they're structured:",
 770                        ),
 771                        tool_use(
 772                            "tool_3",
 773                            "read_file",
 774                            ReadFileToolInput {
 775                                path: input_file_path.into(),
 776                                start_line: Some(1012),
 777                                end_line: Some(1120),
 778                            },
 779                        ),
 780                    ],
 781                ),
 782                message(
 783                    User,
 784                    [tool_result(
 785                        "tool_3",
 786                        "read_file",
 787                        lines(input_file_content, 1012..1120),
 788                    )],
 789                ),
 790                message(
 791                    Assistant,
 792                    [
 793                        text("Now let's look at how `buffer_created` is implemented:"),
 794                        tool_use(
 795                            "tool_4",
 796                            "read_file",
 797                            ReadFileToolInput {
 798                                path: input_file_path.into(),
 799                                start_line: Some(271),
 800                                end_line: Some(276),
 801                            },
 802                        ),
 803                    ],
 804                ),
 805                message(
 806                    User,
 807                    [tool_result(
 808                        "tool_4",
 809                        "read_file",
 810                        lines(input_file_content, 271..276),
 811                    )],
 812                ),
 813                message(
 814                    Assistant,
 815                    [
 816                        text(indoc! {"
 817                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
 818
 819                            1. Create a file
 820                            2. Open a buffer for it
 821                            3. Track it in the action log
 822                            4. Modify the buffer
 823                            5. Call `buffer_created` again as if the file were new
 824                            6. Check that the changes are properly tracked
 825
 826                            Let's write the test:
 827                        "}),
 828                        tool_use(
 829                            "tool_5",
 830                            "edit_file",
 831                            StreamingEditFileToolInput {
 832                                display_description: edit_description.into(),
 833                                path: input_file_path.into(),
 834                                create_or_overwrite: false,
 835                            },
 836                        ),
 837                    ],
 838                ),
 839            ],
 840            input_path: input_file_path.into(),
 841            input_content: Some(input_file_content.into()),
 842            edit_description: edit_description.into(),
 843            assertion: EvalAssertion::judge_diff(
 844                "A new test for overwritten files was created, without changing any previous test",
 845            ),
 846        },
 847    );
 848}
 849
 850fn message(
 851    role: Role,
 852    contents: impl IntoIterator<Item = MessageContent>,
 853) -> LanguageModelRequestMessage {
 854    LanguageModelRequestMessage {
 855        role,
 856        content: contents.into_iter().collect(),
 857        cache: false,
 858    }
 859}
 860
 861fn text(text: impl Into<String>) -> MessageContent {
 862    MessageContent::Text(text.into())
 863}
 864
 865fn lines(input: &str, range: Range<usize>) -> String {
 866    input
 867        .lines()
 868        .skip(range.start)
 869        .take(range.len())
 870        .collect::<Vec<_>>()
 871        .join("\n")
 872}
 873
 874fn tool_use(
 875    id: impl Into<Arc<str>>,
 876    name: impl Into<Arc<str>>,
 877    input: impl Serialize,
 878) -> MessageContent {
 879    MessageContent::ToolUse(LanguageModelToolUse {
 880        id: LanguageModelToolUseId::from(id.into()),
 881        name: name.into(),
 882        raw_input: serde_json::to_string_pretty(&input).unwrap(),
 883        input: serde_json::to_value(input).unwrap(),
 884        is_input_complete: true,
 885    })
 886}
 887
 888fn tool_result(
 889    id: impl Into<Arc<str>>,
 890    name: impl Into<Arc<str>>,
 891    result: impl Into<Arc<str>>,
 892) -> MessageContent {
 893    MessageContent::ToolResult(LanguageModelToolResult {
 894        tool_use_id: LanguageModelToolUseId::from(id.into()),
 895        tool_name: name.into(),
 896        is_error: false,
 897        content: result.into(),
 898        output: None,
 899    })
 900}
 901
/// Everything needed to run one eval sample.
#[derive(Clone)]
struct EvalInput {
    /// Conversation handed to the edit agent.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Path of the file to edit; resolved against the test project.
    input_path: PathBuf,
    /// Initial buffer text. When `Some`, the buffer is set to this text and the
    /// agent edits it; when `None`, the agent overwrites the buffer instead.
    input_content: Option<String>,
    /// Description of the edit, passed to the agent.
    edit_description: String,
    /// Scores the resulting sample.
    assertion: EvalAssertion,
}
 910
/// The result of one agent run, before it has been scored.
#[derive(Clone)]
struct EvalSample {
    /// Buffer text after the agent finished.
    text: String,
    /// Output returned by the edit agent.
    edit_output: EditAgentOutput,
    /// Unified diff from the input content (or the empty string) to `text`.
    diff: String,
}
 917
/// Object-safe form of an async assertion over an [`EvalSample`].
///
/// Returning a boxed local future lets assertions be stored behind
/// `Arc<dyn AssertionFn>`.
trait AssertionFn: 'static + Send + Sync {
    /// Scores `sample`, optionally consulting `judge_model`.
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
 926
 927impl<F> AssertionFn for F
 928where
 929    F: 'static
 930        + Send
 931        + Sync
 932        + AsyncFn(
 933            &EvalSample,
 934            Arc<dyn LanguageModel>,
 935            &mut TestAppContext,
 936        ) -> Result<EvalAssertionOutcome>,
 937{
 938    fn assert<'a>(
 939        &'a self,
 940        sample: &'a EvalSample,
 941        judge_model: Arc<dyn LanguageModel>,
 942        cx: &'a mut TestAppContext,
 943    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
 944        (self)(sample, judge_model, cx).boxed_local()
 945    }
 946}
 947
/// A cheaply clonable, type-erased assertion used to score an [`EvalSample`].
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
 950
 951impl EvalAssertion {
 952    fn new<F>(f: F) -> Self
 953    where
 954        F: 'static
 955            + Send
 956            + Sync
 957            + AsyncFn(
 958                &EvalSample,
 959                Arc<dyn LanguageModel>,
 960                &mut TestAppContext,
 961            ) -> Result<EvalAssertionOutcome>,
 962    {
 963        EvalAssertion(Arc::new(f))
 964    }
 965
 966    fn assert_eq(expected: impl Into<String>) -> Self {
 967        let expected = expected.into();
 968        Self::new(async move |sample, _judge, _cx| {
 969            Ok(EvalAssertionOutcome {
 970                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
 971                    100
 972                } else {
 973                    0
 974                },
 975                message: None,
 976            })
 977        })
 978    }
 979
 980    fn judge_diff(assertions: &'static str) -> Self {
 981        Self::new(async move |sample, judge, cx| {
 982            let prompt = DiffJudgeTemplate {
 983                diff: sample.diff.clone(),
 984                assertions,
 985            }
 986            .render(&Templates::new())
 987            .unwrap();
 988
 989            let request = LanguageModelRequest {
 990                messages: vec![LanguageModelRequestMessage {
 991                    role: Role::User,
 992                    content: vec![prompt.into()],
 993                    cache: false,
 994                }],
 995                ..Default::default()
 996            };
 997            let mut response = judge
 998                .stream_completion_text(request, &cx.to_async())
 999                .await?;
1000            let mut output = String::new();
1001            while let Some(chunk) = response.stream.next().await {
1002                let chunk = chunk?;
1003                output.push_str(&chunk);
1004            }
1005
1006            // Parse the score from the response
1007            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
1008            if let Some(captures) = re.captures(&output) {
1009                if let Some(score_match) = captures.get(1) {
1010                    let score = score_match.as_str().parse().unwrap_or(0);
1011                    return Ok(EvalAssertionOutcome {
1012                        score,
1013                        message: Some(output),
1014                    });
1015                }
1016            }
1017
1018            Err(anyhow!(
1019                "No score found in response. Raw output: {}",
1020                output
1021            ))
1022        })
1023    }
1024
1025    async fn run(
1026        &self,
1027        input: &EvalSample,
1028        judge_model: Arc<dyn LanguageModel>,
1029        cx: &mut TestAppContext,
1030    ) -> Result<EvalAssertionOutcome> {
1031        self.0.assert(input, judge_model, cx).await
1032    }
1033}
1034
1035fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
1036    let mut evaluated_count = 0;
1037    report_progress(evaluated_count, iterations);
1038
1039    let (tx, rx) = mpsc::channel();
1040
1041    // Cache the last message in the conversation, and run one instance of the eval so that
1042    // all the next ones are cached.
1043    eval.conversation.last_mut().unwrap().cache = true;
1044    run_eval(eval.clone(), tx.clone());
1045
1046    let executor = gpui::background_executor();
1047    for _ in 1..iterations {
1048        let eval = eval.clone();
1049        let tx = tx.clone();
1050        executor.spawn(async move { run_eval(eval, tx) }).detach();
1051    }
1052    drop(tx);
1053
1054    let mut failed_count = 0;
1055    let mut failed_evals = HashMap::default();
1056    let mut errored_evals = HashMap::default();
1057    let mut eval_outputs = Vec::new();
1058    let mut cumulative_parser_metrics = EditParserMetrics::default();
1059    while let Ok(output) = rx.recv() {
1060        match output {
1061            Ok(output) => {
1062                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
1063                eval_outputs.push(output.clone());
1064                if output.assertion.score < 80 {
1065                    failed_count += 1;
1066                    failed_evals
1067                        .entry(output.sample.text.clone())
1068                        .or_insert(Vec::new())
1069                        .push(output);
1070                }
1071            }
1072            Err(error) => {
1073                failed_count += 1;
1074                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
1075            }
1076        }
1077
1078        evaluated_count += 1;
1079        report_progress(evaluated_count, iterations);
1080    }
1081
1082    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
1083    println!("Actual pass ratio: {}\n", actual_pass_ratio);
1084    if actual_pass_ratio < expected_pass_ratio {
1085        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
1086        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
1087        for (error, count) in errored_evals {
1088            println!("Eval errored {} times. Error: {}", count, error);
1089        }
1090
1091        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
1092        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
1093        for (_buffer_output, failed_evals) in failed_evals {
1094            let eval_output = failed_evals.first().unwrap();
1095            println!("Eval failed {} times", failed_evals.len());
1096            println!("{}", eval_output);
1097        }
1098
1099        panic!(
1100            "Actual pass ratio: {}\nExpected pass ratio: {}",
1101            actual_pass_ratio, expected_pass_ratio
1102        );
1103    }
1104
1105    let mismatched_tag_ratio =
1106        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
1107    if mismatched_tag_ratio > 0.05 {
1108        for eval_output in eval_outputs {
1109            println!("{}", eval_output);
1110        }
1111        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
1112    }
1113}
1114
1115fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1116    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1117    let mut cx = TestAppContext::build(dispatcher, None);
1118    let output = cx.executor().block_test(async {
1119        let test = EditAgentTest::new(&mut cx).await;
1120        test.eval(eval, &mut cx).await
1121    });
1122    tx.send(output).unwrap();
1123}
1124
/// A scored eval sample: what the agent produced plus the assertion's verdict.
#[derive(Clone)]
struct EvalOutput {
    /// The agent's result for this run.
    sample: EvalSample,
    /// Score (and optional explanation) assigned to `sample`.
    assertion: EvalAssertionOutcome,
}
1130
1131impl Display for EvalOutput {
1132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1133        writeln!(f, "Score: {:?}", self.assertion.score)?;
1134        if let Some(message) = self.assertion.message.as_ref() {
1135            writeln!(f, "Message: {}", message)?;
1136        }
1137
1138        writeln!(f, "Diff:\n{}", self.sample.diff)?;
1139
1140        writeln!(
1141            f,
1142            "Parser Metrics:\n{:#?}",
1143            self.sample.edit_output._parser_metrics
1144        )?;
1145        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
1146        Ok(())
1147    }
1148}
1149
/// Rewrites the current terminal line with an `Evaluated x/y` counter.
///
/// The leading `\r\x1b[K` returns to the start of the line and clears it; no
/// newline is printed, so stdout is flushed explicitly.
///
/// # Panics
///
/// Panics if stdout cannot be flushed.
fn report_progress(evaluated_count: usize, iterations: usize) {
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    let mut stdout = std::io::stdout();
    std::io::Write::flush(&mut stdout).unwrap();
}
1154
/// Test harness bundling the edit agent under evaluation with the project it
/// edits and the model used to judge its output.
struct EditAgentTest {
    /// Agent whose edits are being evaluated.
    agent: EditAgent,
    /// Test project the edited buffers are opened in.
    project: Entity<Project>,
    /// Model passed to assertions that need a judge.
    judge_model: Arc<dyn LanguageModel>,
}
1160
impl EditAgentTest {
    /// Sets up global state, a fake filesystem with an empty `/root`, a test
    /// project, and the agent and judge models.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be created or either model fails to load.
    async fn new(cx: &mut TestAppContext) -> Self {
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                // Both the agent and the judge currently use the same provider and model id.
                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    /// Looks up the model with the given provider and model ids among the
    /// registry's available models, authenticates its provider, and returns it.
    ///
    /// # Errors
    ///
    /// Returns an error if updating the app context or authenticating fails.
    ///
    /// # Panics
    ///
    /// Panics if no available model matches `provider` and `id`, or if the
    /// registry has no provider for the matched model.
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    /// Runs the agent once for `eval` and scores the result.
    ///
    /// Opens a buffer at `eval.input_path`, lets the agent edit it (or
    /// overwrite it when there is no input content), then runs the eval's
    /// assertion on the resulting text and diff.
    ///
    /// # Errors
    ///
    /// Returns an error if the agent run or the assertion fails.
    ///
    /// # Panics
    ///
    /// Panics if the path cannot be resolved in the project or the buffer
    /// cannot be opened.
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            // Existing file: seed the buffer with the input text, then edit it.
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            // No input content: the agent writes the whole buffer.
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            // Diff against the empty string when there was no input content.
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}
1267
/// Verdict produced by an [`EvalAssertion`].
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    /// Numeric score; `eval` treats anything below 80 as a failure.
    score: usize,
    /// Optional explanation, e.g. the judge model's raw reply.
    message: Option<String>,
}
1273
/// Data for the prompt that asks a judge model to grade a diff.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    /// The diff to be judged.
    diff: String,
    /// The assertions the diff is judged against.
    assertions: &'static str,
}
1279
/// Associates the judge data with the `diff_judge.hbs` template.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1283
/// Returns `text` with every blank or whitespace-only line removed; the
/// remaining lines are kept verbatim and joined with `\n`.
fn strip_empty_lines(text: &str) -> String {
    let mut stripped = String::new();
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        // Every kept line is non-empty, so a non-empty accumulator means at
        // least one line was already written and a separator is due.
        if !stripped.is_empty() {
            stripped.push('\n');
        }
        stripped.push_str(line);
    }
    stripped
}