evals.rs

   1use super::*;
   2use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
   3use Role::*;
   4use anyhow::anyhow;
   5use client::{Client, UserStore};
   6use collections::HashMap;
   7use fs::FakeFs;
   8use futures::{FutureExt, future::LocalBoxFuture};
   9use gpui::{AppContext, TestAppContext};
  10use indoc::indoc;
  11use language_model::{
  12    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
  13};
  14use project::Project;
  15use rand::prelude::*;
  16use reqwest_client::ReqwestClient;
  17use serde_json::json;
  18use std::{
  19    cmp::Reverse,
  20    fmt::{self, Display},
  21    io::Write as _,
  22    sync::mpsc,
  23};
  24use util::path;
  25
  26#[test]
  27#[cfg_attr(not(feature = "eval"), ignore)]
  28fn eval_extract_handle_command_output() {
  29    let input_file_path = "root/blame.rs";
  30    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
  31    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
  32    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
  33    eval(
  34        100,
  35        0.95,
  36        EvalInput {
  37            conversation: vec![
  38                message(
  39                    User,
  40                    [text(indoc! {"
  41                        Read the `{input_file_path}` file and extract a method in
  42                        the final stanza of `run_git_blame` to deal with command failures,
  43                        call it `handle_command_output` and take the std::process::Output as the only parameter.
  44
  45                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
  46                    "})],
  47                ),
  48                message(
  49                    Assistant,
  50                    [tool_use(
  51                        "tool_1",
  52                        "read_file",
  53                        ReadFileToolInput {
  54                            path: input_file_path.into(),
  55                            start_line: None,
  56                            end_line: None,
  57                        },
  58                    )],
  59                ),
  60                message(
  61                    User,
  62                    [tool_result("tool_1", "read_file", input_file_content)],
  63                ),
  64                message(
  65                    Assistant,
  66                    [tool_use(
  67                        "tool_2",
  68                        "edit_file",
  69                        EditFileToolInput {
  70                            display_description: edit_description.into(),
  71                            path: input_file_path.into(),
  72                            create_or_overwrite: false,
  73                        },
  74                    )],
  75                ),
  76            ],
  77            input_path: input_file_path.into(),
  78            input_content: Some(input_file_content.into()),
  79            edit_description: edit_description.into(),
  80            assertion: EvalAssertion::assert_eq(output_file_content),
  81        },
  82    );
  83}
  84
  85#[test]
  86#[cfg_attr(not(feature = "eval"), ignore)]
  87fn eval_delete_run_git_blame() {
  88    let input_file_path = "root/blame.rs";
  89    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
  90    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
  91    let edit_description = "Delete the `run_git_blame` function.";
  92    eval(
  93        100,
  94        0.95,
  95        EvalInput {
  96            conversation: vec![
  97                message(
  98                    User,
  99                    [text(indoc! {"
 100                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
 101                        one function, not its usages.
 102                    "})],
 103                ),
 104                message(
 105                    Assistant,
 106                    [tool_use(
 107                        "tool_1",
 108                        "read_file",
 109                        ReadFileToolInput {
 110                            path: input_file_path.into(),
 111                            start_line: None,
 112                            end_line: None,
 113                        },
 114                    )],
 115                ),
 116                message(
 117                    User,
 118                    [tool_result("tool_1", "read_file", input_file_content)],
 119                ),
 120                message(
 121                    Assistant,
 122                    [tool_use(
 123                        "tool_2",
 124                        "edit_file",
 125                        EditFileToolInput {
 126                            display_description: edit_description.into(),
 127                            path: input_file_path.into(),
 128                            create_or_overwrite: false,
 129                        },
 130                    )],
 131                ),
 132            ],
 133            input_path: input_file_path.into(),
 134            input_content: Some(input_file_content.into()),
 135            edit_description: edit_description.into(),
 136            assertion: EvalAssertion::assert_eq(output_file_content),
 137        },
 138    );
 139}
 140
 141#[test]
 142#[cfg_attr(not(feature = "eval"), ignore)]
 143fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
 144    let input_file_path = "root/lib.rs";
 145    let input_file_content =
 146        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
 147    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
 148    eval(
 149        100,
 150        0.95,
 151        EvalInput {
 152            conversation: vec![
 153                message(
 154                    User,
 155                    [text(indoc! {"
 156                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
 157                        Use `ureq` to download the SDK for the current platform and architecture.
 158                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
 159                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
 160                        that's inside of the archive.
 161                        Don't re-download the SDK if that executable already exists.
 162
 163                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
 164
 165                        Here are the available wasi-sdk assets:
 166                        - wasi-sdk-25.0-x86_64-macos.tar.gz
 167                        - wasi-sdk-25.0-arm64-macos.tar.gz
 168                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 169                        - wasi-sdk-25.0-arm64-linux.tar.gz
 170                        - wasi-sdk-25.0-x86_64-linux.tar.gz
 171                        - wasi-sdk-25.0-arm64-linux.tar.gz
 172                        - wasi-sdk-25.0-x86_64-windows.tar.gz
 173                    "})],
 174                ),
 175                message(
 176                    Assistant,
 177                    [tool_use(
 178                        "tool_1",
 179                        "read_file",
 180                        ReadFileToolInput {
 181                            path: input_file_path.into(),
 182                            start_line: Some(971),
 183                            end_line: Some(1050),
 184                        },
 185                    )],
 186                ),
 187                message(
 188                    User,
 189                    [tool_result(
 190                        "tool_1",
 191                        "read_file",
 192                        lines(input_file_content, 971..1050),
 193                    )],
 194                ),
 195                message(
 196                    Assistant,
 197                    [tool_use(
 198                        "tool_2",
 199                        "read_file",
 200                        ReadFileToolInput {
 201                            path: input_file_path.into(),
 202                            start_line: Some(1050),
 203                            end_line: Some(1100),
 204                        },
 205                    )],
 206                ),
 207                message(
 208                    User,
 209                    [tool_result(
 210                        "tool_2",
 211                        "read_file",
 212                        lines(input_file_content, 1050..1100),
 213                    )],
 214                ),
 215                message(
 216                    Assistant,
 217                    [tool_use(
 218                        "tool_3",
 219                        "read_file",
 220                        ReadFileToolInput {
 221                            path: input_file_path.into(),
 222                            start_line: Some(1100),
 223                            end_line: Some(1150),
 224                        },
 225                    )],
 226                ),
 227                message(
 228                    User,
 229                    [tool_result(
 230                        "tool_3",
 231                        "read_file",
 232                        lines(input_file_content, 1100..1150),
 233                    )],
 234                ),
 235                message(
 236                    Assistant,
 237                    [tool_use(
 238                        "tool_4",
 239                        "edit_file",
 240                        EditFileToolInput {
 241                            display_description: edit_description.into(),
 242                            path: input_file_path.into(),
 243                            create_or_overwrite: false,
 244                        },
 245                    )],
 246                ),
 247            ],
 248            input_path: input_file_path.into(),
 249            input_content: Some(input_file_content.into()),
 250            edit_description: edit_description.into(),
 251            assertion: EvalAssertion::judge_diff(indoc! {"
 252                - The compile_parser_to_wasm method has been changed to use wasi-sdk
 253                - ureq is used to download the SDK for current platform and architecture
 254            "}),
 255        },
 256    );
 257}
 258
 259#[test]
 260#[cfg_attr(not(feature = "eval"), ignore)]
 261fn eval_disable_cursor_blinking() {
 262    let input_file_path = "root/editor.rs";
 263    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
 264    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
 265    let edit_description = "Comment out the call to `BlinkManager::enable`";
 266    eval(
 267        200,
 268        0.6, // TODO: make this eval better
 269        EvalInput {
 270            conversation: vec![
 271                message(User, [text("Let's research how to cursor blinking works.")]),
 272                message(
 273                    Assistant,
 274                    [tool_use(
 275                        "tool_1",
 276                        "grep",
 277                        GrepToolInput {
 278                            regex: "blink".into(),
 279                            include_pattern: None,
 280                            offset: 0,
 281                            case_sensitive: false,
 282                        },
 283                    )],
 284                ),
 285                message(
 286                    User,
 287                    [tool_result(
 288                        "tool_1",
 289                        "grep",
 290                        [
 291                            lines(input_file_content, 100..400),
 292                            lines(input_file_content, 800..1300),
 293                            lines(input_file_content, 1600..2000),
 294                            lines(input_file_content, 5000..5500),
 295                            lines(input_file_content, 8000..9000),
 296                            lines(input_file_content, 18455..18470),
 297                            lines(input_file_content, 20000..20500),
 298                            lines(input_file_content, 21000..21300),
 299                        ]
 300                        .join("Match found:\n\n"),
 301                    )],
 302                ),
 303                message(
 304                    User,
 305                    [text(indoc! {"
 306                        Comment out the lines that interact with the BlinkManager.
 307                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
 308                        Don't add additional comments.
 309                    "})],
 310                ),
 311                message(
 312                    Assistant,
 313                    [tool_use(
 314                        "tool_4",
 315                        "edit_file",
 316                        EditFileToolInput {
 317                            display_description: edit_description.into(),
 318                            path: input_file_path.into(),
 319                            create_or_overwrite: false,
 320                        },
 321                    )],
 322                ),
 323            ],
 324            input_path: input_file_path.into(),
 325            input_content: Some(input_file_content.into()),
 326            edit_description: edit_description.into(),
 327            assertion: EvalAssertion::assert_eq(output_file_content),
 328        },
 329    );
 330}
 331
 332#[test]
 333#[cfg_attr(not(feature = "eval"), ignore)]
 334fn eval_from_pixels_constructor() {
 335    let input_file_path = "root/canvas.rs";
 336    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
 337    let edit_description = "Implement from_pixels constructor and add tests.";
 338    eval(
 339        100,
 340        0.95,
 341        EvalInput {
 342            conversation: vec![
 343                message(
 344                    User,
 345                    [text(indoc! {"
 346                        Introduce a new `from_pixels` constructor in Canvas and
 347                        also add tests for it in the same file.
 348                    "})],
 349                ),
 350                message(
 351                    Assistant,
 352                    [tool_use(
 353                        "tool_1",
 354                        "read_file",
 355                        ReadFileToolInput {
 356                            path: input_file_path.into(),
 357                            start_line: None,
 358                            end_line: None,
 359                        },
 360                    )],
 361                ),
 362                message(
 363                    User,
 364                    [tool_result("tool_1", "read_file", input_file_content)],
 365                ),
 366                message(
 367                    Assistant,
 368                    [tool_use(
 369                        "tool_2",
 370                        "grep",
 371                        GrepToolInput {
 372                            regex: "mod\\s+tests".into(),
 373                            include_pattern: Some("font-kit/src/canvas.rs".into()),
 374                            offset: 0,
 375                            case_sensitive: false,
 376                        },
 377                    )],
 378                ),
 379                message(User, [tool_result("tool_2", "grep", "No matches found")]),
 380                message(
 381                    Assistant,
 382                    [tool_use(
 383                        "tool_3",
 384                        "grep",
 385                        GrepToolInput {
 386                            regex: "mod\\s+tests".into(),
 387                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 388                            offset: 0,
 389                            case_sensitive: false,
 390                        },
 391                    )],
 392                ),
 393                message(User, [tool_result("tool_3", "grep", "No matches found")]),
 394                message(
 395                    Assistant,
 396                    [tool_use(
 397                        "tool_4",
 398                        "grep",
 399                        GrepToolInput {
 400                            regex: "#\\[test\\]".into(),
 401                            include_pattern: Some("font-kit/src/**/*.rs".into()),
 402                            offset: 0,
 403                            case_sensitive: false,
 404                        },
 405                    )],
 406                ),
 407                message(
 408                    User,
 409                    [tool_result(
 410                        "tool_4",
 411                        "grep",
 412                        indoc! {"
 413                            Found 6 matches:
 414
 415                            ## Matches in font-kit/src/loaders/core_text.rs
 416
 417                            ### mod test › L926-936
 418                            ```
 419                            mod test {
 420                                use super::Font;
 421                                use crate::properties::{Stretch, Weight};
 422
 423                                #[cfg(feature = \"source\")]
 424                                use crate::source::SystemSource;
 425
 426                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
 427
 428                                #[cfg(feature = \"source\")]
 429                                #[test]
 430                            ```
 431
 432                            55 lines remaining in ancestor node. Read the file to see all.
 433
 434                            ### mod test › L947-951
 435                            ```
 436                                }
 437
 438                                #[test]
 439                                fn test_core_text_to_css_font_weight() {
 440                                    // Exact matches
 441                            ```
 442
 443                            ### mod test › L959-963
 444                            ```
 445                                }
 446
 447                                #[test]
 448                                fn test_core_text_to_css_font_stretch() {
 449                                    // Exact matches
 450                            ```
 451
 452                            ## Matches in font-kit/src/loaders/freetype.rs
 453
 454                            ### mod test › L1238-1248
 455                            ```
 456                            mod test {
 457                                use crate::loaders::freetype::Font;
 458
 459                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
 460                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
 461
 462                                #[test]
 463                                fn get_pcf_postscript_name() {
 464                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
 465                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
 466                                }
 467                            ```
 468
 469                            1 lines remaining in ancestor node. Read the file to see all.
 470
 471                            ## Matches in font-kit/src/sources/core_text.rs
 472
 473                            ### mod test › L265-275
 474                            ```
 475                            mod test {
 476                                use crate::properties::{Stretch, Weight};
 477
 478                                #[test]
 479                                fn test_css_to_core_text_font_weight() {
 480                                    // Exact matches
 481                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
 482                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
 483                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
 484                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
 485
 486                            ```
 487
 488                            27 lines remaining in ancestor node. Read the file to see all.
 489
 490                            ### mod test › L278-282
 491                            ```
 492                                }
 493
 494                                #[test]
 495                                fn test_css_to_core_text_font_stretch() {
 496                                    // Exact matches
 497                            ```
 498                        "},
 499                    )],
 500                ),
 501                message(
 502                    Assistant,
 503                    [tool_use(
 504                        "tool_5",
 505                        "edit_file",
 506                        EditFileToolInput {
 507                            display_description: edit_description.into(),
 508                            path: input_file_path.into(),
 509                            create_or_overwrite: false,
 510                        },
 511                    )],
 512                ),
 513            ],
 514            input_path: input_file_path.into(),
 515            input_content: Some(input_file_content.into()),
 516            edit_description: edit_description.into(),
 517            assertion: EvalAssertion::judge_diff(indoc! {"
 518                - The diff contains a new `from_pixels` constructor
 519                - The diff contains new tests for the `from_pixels` constructor
 520            "}),
 521        },
 522    );
 523}
 524
 525#[test]
 526#[cfg_attr(not(feature = "eval"), ignore)]
 527fn eval_zode() {
 528    let input_file_path = "root/zode.py";
 529    let edit_description = "Create the main Zode CLI script";
 530    eval(
 531        200,
 532        1.,
 533        EvalInput {
 534            conversation: vec![
 535                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
 536                message(
 537                    Assistant,
 538                    [
 539                        tool_use(
 540                            "tool_1",
 541                            "read_file",
 542                            ReadFileToolInput {
 543                                path: "root/eval/react.py".into(),
 544                                start_line: None,
 545                                end_line: None,
 546                            },
 547                        ),
 548                        tool_use(
 549                            "tool_2",
 550                            "read_file",
 551                            ReadFileToolInput {
 552                                path: "root/eval/react_test.py".into(),
 553                                start_line: None,
 554                                end_line: None,
 555                            },
 556                        ),
 557                    ],
 558                ),
 559                message(
 560                    User,
 561                    [
 562                        tool_result(
 563                            "tool_1",
 564                            "read_file",
 565                            include_str!("evals/fixtures/zode/react.py"),
 566                        ),
 567                        tool_result(
 568                            "tool_2",
 569                            "read_file",
 570                            include_str!("evals/fixtures/zode/react_test.py"),
 571                        ),
 572                    ],
 573                ),
 574                message(
 575                    Assistant,
 576                    [
 577                        text(
 578                            "Now that I understand what we need to build, I'll create the main Python script:",
 579                        ),
 580                        tool_use(
 581                            "tool_3",
 582                            "edit_file",
 583                            EditFileToolInput {
 584                                display_description: edit_description.into(),
 585                                path: input_file_path.into(),
 586                                create_or_overwrite: true,
 587                            },
 588                        ),
 589                    ],
 590                ),
 591            ],
 592            input_path: input_file_path.into(),
 593            input_content: None,
 594            edit_description: edit_description.into(),
 595            assertion: EvalAssertion::new(async move |sample, _, _cx| {
 596                let invalid_starts = [' ', '`', '\n'];
 597                let mut message = String::new();
 598                for start in invalid_starts {
 599                    if sample.text.starts_with(start) {
 600                        message.push_str(&format!("The sample starts with a {:?}\n", start));
 601                        break;
 602                    }
 603                }
 604                // Remove trailing newline.
 605                message.pop();
 606
 607                if message.is_empty() {
 608                    Ok(EvalAssertionOutcome {
 609                        score: 100,
 610                        message: None,
 611                    })
 612                } else {
 613                    Ok(EvalAssertionOutcome {
 614                        score: 0,
 615                        message: Some(message),
 616                    })
 617                }
 618            }),
 619        },
 620    );
 621}
 622
 623#[test]
 624#[cfg_attr(not(feature = "eval"), ignore)]
 625fn eval_add_overwrite_test() {
 626    let input_file_path = "root/action_log.rs";
 627    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
 628    let edit_description = "Add a new test for overwriting a file in action_log.rs";
 629    eval(
 630        200,
 631        0.5, // TODO: make this eval better
 632        EvalInput {
 633            conversation: vec![
 634                message(
 635                    User,
 636                    [text(indoc! {"
 637                        Introduce a new test in `action_log.rs` to test overwriting a file.
 638                        That is, a file already exists, but we call `buffer_created` as if the file were new.
 639                        Take inspiration from all the other tests in the file.
 640                    "})],
 641                ),
 642                message(
 643                    Assistant,
 644                    [tool_use(
 645                        "tool_1",
 646                        "read_file",
 647                        ReadFileToolInput {
 648                            path: input_file_path.into(),
 649                            start_line: None,
 650                            end_line: None,
 651                        },
 652                    )],
 653                ),
 654                message(
 655                    User,
 656                    [tool_result(
 657                        "tool_1",
 658                        "read_file",
 659                        indoc! {"
 660                            pub struct ActionLog [L13-20]
 661                             tracked_buffers [L15]
 662                             edited_since_project_diagnostics_check [L17]
 663                             project [L19]
 664                            impl ActionLog [L22-498]
 665                             pub fn new [L24-30]
 666                             pub fn project [L32-34]
 667                             pub fn checked_project_diagnostics [L37-39]
 668                             pub fn has_edited_files_since_project_diagnostics_check [L42-44]
 669                             fn track_buffer_internal [L46-101]
 670                             fn handle_buffer_event [L103-116]
 671                             fn handle_buffer_edited [L118-123]
 672                             fn handle_buffer_file_changed [L125-158]
 673                             async fn maintain_diff [L160-264]
 674                             pub fn buffer_read [L267-269]
 675                             pub fn buffer_created [L272-276]
 676                             pub fn buffer_edited [L279-287]
 677                             pub fn will_delete_buffer [L289-304]
 678                             pub fn keep_edits_in_range [L306-364]
 679                             pub fn reject_edits_in_ranges [L366-459]
 680                             pub fn keep_all_edits [L461-473]
 681                             pub fn changed_buffers [L476-482]
 682                             pub fn stale_buffers [L485-497]
 683                            fn apply_non_conflicting_edits [L500-561]
 684                            fn diff_snapshots [L563-585]
 685                            fn point_to_row_edit [L587-614]
 686                            enum ChangeAuthor [L617-620]
 687                             User [L618]
 688                             Agent [L619]
 689                            enum TrackedBufferStatus [L623-627]
 690                             Created [L624]
 691                             Modified [L625]
 692                             Deleted [L626]
 693                            struct TrackedBuffer [L629-641]
 694                             buffer [L630]
 695                             base_text [L631]
 696                             unreviewed_changes [L632]
 697                             status [L633]
 698                             version [L634]
 699                             diff [L635]
 700                             snapshot [L636]
 701                             diff_update [L637]
 702                             _open_lsp_handle [L638]
 703                             _maintain_diff [L639]
 704                             _subscription [L640]
 705                            impl TrackedBuffer [L643-657]
 706                             fn has_changes [L644-650]
 707                             fn schedule_diff_update [L652-656]
 708                            pub struct ChangedBuffer [L659-661]
 709                             pub diff [L660]
 710                            mod tests [L664-1574]
 711                             fn init_logger [L678-682]
 712                             fn init_test [L684-691]
 713                             async fn test_keep_edits [L694-769]
 714                             async fn test_deletions [L772-854]
 715                             async fn test_overlapping_user_edits [L857-951]
 716                             async fn test_creating_files [L954-1010]
 717                             async fn test_deleting_files [L1013-1120]
 718                             async fn test_reject_edits [L1123-1255]
 719                             async fn test_reject_multiple_edits [L1258-1331]
 720                             async fn test_reject_deleted_file [L1334-1388]
 721                             async fn test_reject_created_file [L1391-1443]
 722                             async fn test_random_diffs [L1446-1535]
 723                              fn quiesce [L1510-1534]
 724                             struct HunkStatus [L1538-1542]
 725                              range [L1539]
 726                              diff_status [L1540]
 727                              old_text [L1541]
 728                             fn unreviewed_hunks [L1544-1573]
 729
 730                            Showing symbols 1-69 (total symbols: 69)
 731
 732                            Using the line numbers in this outline, you can call this tool again while specifying
 733                            the start_line and end_line fields to see the implementations of symbols in the outline.
 734                        "},
 735                    )],
 736                ),
 737                message(
 738                    Assistant,
 739                    [
 740                        text(
 741                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
 742                        ),
 743                        tool_use(
 744                            "tool_2",
 745                            "read_file",
 746                            ReadFileToolInput {
 747                                path: input_file_path.into(),
 748                                start_line: Some(953),
 749                                end_line: Some(1010),
 750                            },
 751                        ),
 752                    ],
 753                ),
 754                message(
 755                    User,
 756                    [tool_result(
 757                        "tool_2",
 758                        "read_file",
 759                        lines(input_file_content, 953..1010),
 760                    )],
 761                ),
 762                message(
 763                    Assistant,
 764                    [
 765                        text(
 766                            "Now let's look at another relevant test to understand how they're structured:",
 767                        ),
 768                        tool_use(
 769                            "tool_3",
 770                            "read_file",
 771                            ReadFileToolInput {
 772                                path: input_file_path.into(),
 773                                start_line: Some(1012),
 774                                end_line: Some(1120),
 775                            },
 776                        ),
 777                    ],
 778                ),
 779                message(
 780                    User,
 781                    [tool_result(
 782                        "tool_3",
 783                        "read_file",
 784                        lines(input_file_content, 1012..1120),
 785                    )],
 786                ),
 787                message(
 788                    Assistant,
 789                    [
 790                        text("Now let's look at how `buffer_created` is implemented:"),
 791                        tool_use(
 792                            "tool_4",
 793                            "read_file",
 794                            ReadFileToolInput {
 795                                path: input_file_path.into(),
 796                                start_line: Some(271),
 797                                end_line: Some(276),
 798                            },
 799                        ),
 800                    ],
 801                ),
 802                message(
 803                    User,
 804                    [tool_result(
 805                        "tool_4",
 806                        "read_file",
 807                        lines(input_file_content, 271..276),
 808                    )],
 809                ),
 810                message(
 811                    Assistant,
 812                    [
 813                        text(indoc! {"
 814                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
 815
 816                            1. Create a file
 817                            2. Open a buffer for it
 818                            3. Track it in the action log
 819                            4. Modify the buffer
 820                            5. Call `buffer_created` again as if the file were new
 821                            6. Check that the changes are properly tracked
 822
 823                            Let's write the test:
 824                        "}),
 825                        tool_use(
 826                            "tool_5",
 827                            "edit_file",
 828                            EditFileToolInput {
 829                                display_description: edit_description.into(),
 830                                path: input_file_path.into(),
 831                                create_or_overwrite: false,
 832                            },
 833                        ),
 834                    ],
 835                ),
 836            ],
 837            input_path: input_file_path.into(),
 838            input_content: Some(input_file_content.into()),
 839            edit_description: edit_description.into(),
 840            assertion: EvalAssertion::judge_diff(
 841                "A new test for overwritten files was created, without changing any previous test",
 842            ),
 843        },
 844    );
 845}
 846
 847fn message(
 848    role: Role,
 849    contents: impl IntoIterator<Item = MessageContent>,
 850) -> LanguageModelRequestMessage {
 851    LanguageModelRequestMessage {
 852        role,
 853        content: contents.into_iter().collect(),
 854        cache: false,
 855    }
 856}
 857
 858fn text(text: impl Into<String>) -> MessageContent {
 859    MessageContent::Text(text.into())
 860}
 861
 862fn lines(input: &str, range: Range<usize>) -> String {
 863    input
 864        .lines()
 865        .skip(range.start)
 866        .take(range.len())
 867        .collect::<Vec<_>>()
 868        .join("\n")
 869}
 870
 871fn tool_use(
 872    id: impl Into<Arc<str>>,
 873    name: impl Into<Arc<str>>,
 874    input: impl Serialize,
 875) -> MessageContent {
 876    MessageContent::ToolUse(LanguageModelToolUse {
 877        id: LanguageModelToolUseId::from(id.into()),
 878        name: name.into(),
 879        raw_input: serde_json::to_string_pretty(&input).unwrap(),
 880        input: serde_json::to_value(input).unwrap(),
 881        is_input_complete: true,
 882    })
 883}
 884
 885fn tool_result(
 886    id: impl Into<Arc<str>>,
 887    name: impl Into<Arc<str>>,
 888    result: impl Into<Arc<str>>,
 889) -> MessageContent {
 890    MessageContent::ToolResult(LanguageModelToolResult {
 891        tool_use_id: LanguageModelToolUseId::from(id.into()),
 892        tool_name: name.into(),
 893        is_error: false,
 894        content: result.into(),
 895        output: None,
 896    })
 897}
 898
 899#[derive(Clone)]
 900struct EvalInput {
 901    conversation: Vec<LanguageModelRequestMessage>,
 902    input_path: PathBuf,
 903    input_content: Option<String>,
 904    edit_description: String,
 905    assertion: EvalAssertion,
 906}
 907
 908#[derive(Clone)]
 909struct EvalSample {
 910    text: String,
 911    edit_output: EditAgentOutput,
 912    diff: String,
 913}
 914
 915trait AssertionFn: 'static + Send + Sync {
 916    fn assert<'a>(
 917        &'a self,
 918        sample: &'a EvalSample,
 919        judge_model: Arc<dyn LanguageModel>,
 920        cx: &'a mut TestAppContext,
 921    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
 922}
 923
 924impl<F> AssertionFn for F
 925where
 926    F: 'static
 927        + Send
 928        + Sync
 929        + AsyncFn(
 930            &EvalSample,
 931            Arc<dyn LanguageModel>,
 932            &mut TestAppContext,
 933        ) -> Result<EvalAssertionOutcome>,
 934{
 935    fn assert<'a>(
 936        &'a self,
 937        sample: &'a EvalSample,
 938        judge_model: Arc<dyn LanguageModel>,
 939        cx: &'a mut TestAppContext,
 940    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
 941        (self)(sample, judge_model, cx).boxed_local()
 942    }
 943}
 944
 945#[derive(Clone)]
 946struct EvalAssertion(Arc<dyn AssertionFn>);
 947
 948impl EvalAssertion {
 949    fn new<F>(f: F) -> Self
 950    where
 951        F: 'static
 952            + Send
 953            + Sync
 954            + AsyncFn(
 955                &EvalSample,
 956                Arc<dyn LanguageModel>,
 957                &mut TestAppContext,
 958            ) -> Result<EvalAssertionOutcome>,
 959    {
 960        EvalAssertion(Arc::new(f))
 961    }
 962
 963    fn assert_eq(expected: impl Into<String>) -> Self {
 964        let expected = expected.into();
 965        Self::new(async move |sample, _judge, _cx| {
 966            Ok(EvalAssertionOutcome {
 967                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
 968                    100
 969                } else {
 970                    0
 971                },
 972                message: None,
 973            })
 974        })
 975    }
 976
 977    fn judge_diff(assertions: &'static str) -> Self {
 978        Self::new(async move |sample, judge, cx| {
 979            let prompt = DiffJudgeTemplate {
 980                diff: sample.diff.clone(),
 981                assertions,
 982            }
 983            .render(&Templates::new())
 984            .unwrap();
 985
 986            let request = LanguageModelRequest {
 987                messages: vec![LanguageModelRequestMessage {
 988                    role: Role::User,
 989                    content: vec![prompt.into()],
 990                    cache: false,
 991                }],
 992                ..Default::default()
 993            };
 994            let mut response = judge
 995                .stream_completion_text(request, &cx.to_async())
 996                .await?;
 997            let mut output = String::new();
 998            while let Some(chunk) = response.stream.next().await {
 999                let chunk = chunk?;
1000                output.push_str(&chunk);
1001            }
1002
1003            // Parse the score from the response
1004            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
1005            if let Some(captures) = re.captures(&output) {
1006                if let Some(score_match) = captures.get(1) {
1007                    let score = score_match.as_str().parse().unwrap_or(0);
1008                    return Ok(EvalAssertionOutcome {
1009                        score,
1010                        message: Some(output),
1011                    });
1012                }
1013            }
1014
1015            Err(anyhow!(
1016                "No score found in response. Raw output: {}",
1017                output
1018            ))
1019        })
1020    }
1021
1022    async fn run(
1023        &self,
1024        input: &EvalSample,
1025        judge_model: Arc<dyn LanguageModel>,
1026        cx: &mut TestAppContext,
1027    ) -> Result<EvalAssertionOutcome> {
1028        self.0.assert(input, judge_model, cx).await
1029    }
1030}
1031
1032fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
1033    let mut evaluated_count = 0;
1034    report_progress(evaluated_count, iterations);
1035
1036    let (tx, rx) = mpsc::channel();
1037
1038    // Cache the last message in the conversation, and run one instance of the eval so that
1039    // all the next ones are cached.
1040    eval.conversation.last_mut().unwrap().cache = true;
1041    run_eval(eval.clone(), tx.clone());
1042
1043    let executor = gpui::background_executor();
1044    for _ in 1..iterations {
1045        let eval = eval.clone();
1046        let tx = tx.clone();
1047        executor.spawn(async move { run_eval(eval, tx) }).detach();
1048    }
1049    drop(tx);
1050
1051    let mut failed_count = 0;
1052    let mut failed_evals = HashMap::default();
1053    let mut errored_evals = HashMap::default();
1054    let mut eval_outputs = Vec::new();
1055    let mut cumulative_parser_metrics = EditParserMetrics::default();
1056    while let Ok(output) = rx.recv() {
1057        match output {
1058            Ok(output) => {
1059                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
1060                eval_outputs.push(output.clone());
1061                if output.assertion.score < 80 {
1062                    failed_count += 1;
1063                    failed_evals
1064                        .entry(output.sample.text.clone())
1065                        .or_insert(Vec::new())
1066                        .push(output);
1067                }
1068            }
1069            Err(error) => {
1070                failed_count += 1;
1071                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
1072            }
1073        }
1074
1075        evaluated_count += 1;
1076        report_progress(evaluated_count, iterations);
1077    }
1078
1079    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
1080    println!("Actual pass ratio: {}\n", actual_pass_ratio);
1081    if actual_pass_ratio < expected_pass_ratio {
1082        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
1083        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
1084        for (error, count) in errored_evals {
1085            println!("Eval errored {} times. Error: {}", count, error);
1086        }
1087
1088        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
1089        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
1090        for (_buffer_output, failed_evals) in failed_evals {
1091            let eval_output = failed_evals.first().unwrap();
1092            println!("Eval failed {} times", failed_evals.len());
1093            println!("{}", eval_output);
1094        }
1095
1096        panic!(
1097            "Actual pass ratio: {}\nExpected pass ratio: {}",
1098            actual_pass_ratio, expected_pass_ratio
1099        );
1100    }
1101
1102    let mismatched_tag_ratio =
1103        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
1104    if mismatched_tag_ratio > 0.05 {
1105        for eval_output in eval_outputs {
1106            println!("{}", eval_output);
1107        }
1108        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
1109    }
1110}
1111
1112fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1113    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1114    let mut cx = TestAppContext::build(dispatcher, None);
1115    let output = cx.executor().block_test(async {
1116        let test = EditAgentTest::new(&mut cx).await;
1117        test.eval(eval, &mut cx).await
1118    });
1119    tx.send(output).unwrap();
1120}
1121
1122#[derive(Clone)]
1123struct EvalOutput {
1124    sample: EvalSample,
1125    assertion: EvalAssertionOutcome,
1126}
1127
1128impl Display for EvalOutput {
1129    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1130        writeln!(f, "Score: {:?}", self.assertion.score)?;
1131        if let Some(message) = self.assertion.message.as_ref() {
1132            writeln!(f, "Message: {}", message)?;
1133        }
1134
1135        writeln!(f, "Diff:\n{}", self.sample.diff)?;
1136
1137        writeln!(
1138            f,
1139            "Parser Metrics:\n{:#?}",
1140            self.sample.edit_output._parser_metrics
1141        )?;
1142        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
1143        Ok(())
1144    }
1145}
1146
/// Rewrites the current terminal line with an `Evaluated x/y` progress counter.
///
/// # Panics
///
/// Panics if stdout cannot be flushed.
fn report_progress(evaluated_count: usize, iterations: usize) {
    // `\r` returns to the start of the line and `ESC[K` clears it before redrawing.
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    // No newline is written, so flush explicitly to make the update visible.
    let mut stdout = std::io::stdout();
    stdout.flush().unwrap();
}
1151
1152struct EditAgentTest {
1153    agent: EditAgent,
1154    project: Entity<Project>,
1155    judge_model: Arc<dyn LanguageModel>,
1156}
1157
1158impl EditAgentTest {
1159    async fn new(cx: &mut TestAppContext) -> Self {
1160        cx.executor().allow_parking();
1161        cx.update(settings::init);
1162        cx.update(Project::init_settings);
1163        cx.update(language::init);
1164        cx.update(gpui_tokio::init);
1165        cx.update(client::init_settings);
1166
1167        let fs = FakeFs::new(cx.executor().clone());
1168        fs.insert_tree("/root", json!({})).await;
1169        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1170        let (agent_model, judge_model) = cx
1171            .update(|cx| {
1172                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1173                cx.set_http_client(Arc::new(http_client));
1174
1175                let client = Client::production(cx);
1176                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1177                language_model::init(client.clone(), cx);
1178                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
1179
1180                cx.spawn(async move |cx| {
1181                    let agent_model =
1182                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1183                    let judge_model =
1184                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1185                    (agent_model.unwrap(), judge_model.unwrap())
1186                })
1187            })
1188            .await;
1189        let action_log = cx.new(|_| ActionLog::new(project.clone()));
1190
1191        Self {
1192            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
1193            project,
1194            judge_model,
1195        }
1196    }
1197
1198    async fn load_model(
1199        provider: &str,
1200        id: &str,
1201        cx: &mut AsyncApp,
1202    ) -> Result<Arc<dyn LanguageModel>> {
1203        let (provider, model) = cx.update(|cx| {
1204            let models = LanguageModelRegistry::read_global(cx);
1205            let model = models
1206                .available_models(cx)
1207                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
1208                .unwrap();
1209            let provider = models.provider(&model.provider_id()).unwrap();
1210            (provider, model)
1211        })?;
1212        cx.update(|cx| provider.authenticate(cx))?.await?;
1213        Ok(model)
1214    }
1215
1216    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
1217        let path = self
1218            .project
1219            .read_with(cx, |project, cx| {
1220                project.find_project_path(eval.input_path, cx)
1221            })
1222            .unwrap();
1223        let buffer = self
1224            .project
1225            .update(cx, |project, cx| project.open_buffer(path, cx))
1226            .await
1227            .unwrap();
1228        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
1229            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
1230            let (edit_output, _) = self.agent.edit(
1231                buffer.clone(),
1232                eval.edit_description,
1233                eval.conversation,
1234                &mut cx.to_async(),
1235            );
1236            edit_output.await?
1237        } else {
1238            let (edit_output, _) = self.agent.overwrite(
1239                buffer.clone(),
1240                eval.edit_description,
1241                eval.conversation,
1242                &mut cx.to_async(),
1243            );
1244            edit_output.await?
1245        };
1246
1247        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
1248        let sample = EvalSample {
1249            edit_output,
1250            diff: language::unified_diff(
1251                eval.input_content.as_deref().unwrap_or_default(),
1252                &buffer_text,
1253            ),
1254            text: buffer_text,
1255        };
1256        let assertion = eval
1257            .assertion
1258            .run(&sample, self.judge_model.clone(), cx)
1259            .await?;
1260
1261        Ok(EvalOutput { assertion, sample })
1262    }
1263}
1264
/// Score assigned to a sample, plus an optional explanation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct EvalAssertionOutcome {
    /// Numeric score for the sample.
    score: usize,
    /// Optional free-form text accompanying the score.
    message: Option<String>,
}
1270
1271#[derive(Serialize)]
1272pub struct DiffJudgeTemplate {
1273    diff: String,
1274    assertions: &'static str,
1275}
1276
1277impl Template for DiffJudgeTemplate {
1278    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
1279}
1280
/// Returns `text` with every blank or whitespace-only line removed.
///
/// Kept lines are left untouched (no trimming) and are joined with `\n`, without a
/// trailing newline.
fn strip_empty_lines(text: &str) -> String {
    let mut kept = String::new();
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        // Every kept line is non-empty, so `kept` is empty only before the first one.
        if !kept.is_empty() {
            kept.push('\n');
        }
        kept.push_str(line);
    }
    kept
}