evals.rs

use super::*;
use crate::{
    ReadFileToolInput, grep_tool::GrepToolInput,
    streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    sync::mpsc,
};
use util::path;

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        100,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(User, [text("Let's research how cursor blinking works.")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comment out everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                                use super::Font;
                                use crate::properties::{Stretch, Weight};

                                #[cfg(feature = \"source\")]
                                use crate::source::SystemSource;

                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                                #[cfg(feature = \"source\")]
                                #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_weight() {
                                    // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_stretch() {
                                    // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                                use crate::loaders::freetype::Font;

                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                                #[test]
                                fn get_pcf_postscript_name() {
                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                                }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                                use crate::properties::{Stretch, Weight};

                                #[test]
                                fn test_css_to_core_text_font_weight() {
                                    // Exact matches
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                                }

                                #[test]
                                fn test_css_to_core_text_font_stretch() {
                                    // Exact matches
                            ```
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
                message(
                    Assistant,
                    [
                        tool_use(
                            "tool_1",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react_test.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [
                        tool_result(
                            "tool_1",
                            "read_file",
                            include_str!("evals/fixtures/zode/react.py"),
                        ),
                        tool_result(
                            "tool_2",
                            "read_file",
                            include_str!("evals/fixtures/zode/react_test.py"),
                        ),
                    ],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now that I understand what we need to build, I'll create the main Python script:",
                        ),
                        tool_use(
                            "tool_3",
                            "edit_file",
                            StreamingEditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                create_or_overwrite: true,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                let invalid_starts = [' ', '`', '\n'];
                let mut message = String::new();
                for start in invalid_starts {
                    if sample.text.starts_with(start) {
                        message.push_str(&format!("The sample starts with a {:?}\n", start));
                        break;
                    }
                }
                // Remove trailing newline.
                message.pop();

                if message.is_empty() {
                    Ok(EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    })
                } else {
                    Ok(EvalAssertionOutcome {
                        score: 0,
                        message: Some(message),
                    })
                }
            }),
        },
    );
}

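/// Builds a `LanguageModelRequestMessage` with the given role and content items (caching disabled).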
fn message(
    role: Role,
    contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
    LanguageModelRequestMessage {
        role,
        content: contents.into_iter().collect(),
        cache: false,
    }
}

fn text(text: impl Into<String>) -> MessageContent {
    MessageContent::Text(text.into())
}

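/// Returns the zero-based `range` of lines from `input`, joined with newlines.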
fn lines(input: &str, range: Range<usize>) -> String {
    input
        .lines()
        .skip(range.start)
        .take(range.len())
        .collect::<Vec<_>>()
        .join("\n")
}

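/// Builds a `MessageContent::ToolUse` entry, recording `input` both as pretty-printed JSON text and as a structured JSON value.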
fn tool_use(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    input: impl Serialize,
) -> MessageContent {
    MessageContent::ToolUse(LanguageModelToolUse {
        id: LanguageModelToolUseId::from(id.into()),
        name: name.into(),
        raw_input: serde_json::to_string_pretty(&input).unwrap(),
        input: serde_json::to_value(input).unwrap(),
        is_input_complete: true,
    })
}

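/// Builds a successful (non-error) `MessageContent::ToolResult` entry for the given tool use.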
fn tool_result(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    result: impl Into<Arc<str>>,
) -> MessageContent {
    MessageContent::ToolResult(LanguageModelToolResult {
        tool_use_id: LanguageModelToolUseId::from(id.into()),
        tool_name: name.into(),
        is_error: false,
        content: result.into(),
    })
}

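/// A single eval scenario: the conversation to replay, the target file and its optional starting
/// content, the edit description, and the assertion to run on the result.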
#[derive(Clone)]
struct EvalInput {
    conversation: Vec<LanguageModelRequestMessage>,
    input_path: PathBuf,
    input_content: Option<String>,
    edit_description: String,
    assertion: EvalAssertion,
}

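/// One agent run: the resulting buffer text, the raw agent output, and the unified diff against the input.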
#[derive(Clone)]
struct EvalSample {
    text: String,
    edit_output: EditAgentOutput,
    diff: String,
}

trait AssertionFn: 'static + Send + Sync {
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}

impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        (self)(sample, judge_model, cx).boxed_local()
    }
}

#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);

impl EvalAssertion {
    fn new<F>(f: F) -> Self
    where
        F: 'static
            + Send
            + Sync
            + AsyncFn(
                &EvalSample,
                Arc<dyn LanguageModel>,
                &mut TestAppContext,
            ) -> Result<EvalAssertionOutcome>,
    {
        EvalAssertion(Arc::new(f))
    }

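    /// Passes (score 100) iff the sample text matches `expected`, ignoring empty lines.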
    fn assert_eq(expected: impl Into<String>) -> Self {
        let expected = expected.into();
        Self::new(async move |sample, _judge, _cx| {
            Ok(EvalAssertionOutcome {
                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
                    100
                } else {
                    0
                },
                message: None,
            })
        })
    }

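    /// Asks the judge model to score the sample's diff against the given assertions, expecting a `<score>` tag in its reply.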
    fn judge_diff(assertions: &'static str) -> Self {
        Self::new(async move |sample, judge, cx| {
            let prompt = DiffJudgeTemplate {
                diff: sample.diff.clone(),
                assertions,
            }
            .render(&Templates::new())
            .unwrap();

            let request = LanguageModelRequest {
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![prompt.into()],
                    cache: false,
                }],
                ..Default::default()
            };
            let mut response = judge
                .stream_completion_text(request, &cx.to_async())
                .await?;
            let mut output = String::new();
            while let Some(chunk) = response.stream.next().await {
                let chunk = chunk?;
                output.push_str(&chunk);
            }

            // Parse the score from the response
            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
            if let Some(captures) = re.captures(&output) {
                if let Some(score_match) = captures.get(1) {
                    let score = score_match.as_str().parse().unwrap_or(0);
                    return Ok(EvalAssertionOutcome {
                        score,
                        message: Some(output),
                    });
                }
            }

            Err(anyhow!(
                "No score found in response. Raw output: {}",
                output
            ))
        })
    }

    async fn run(
        &self,
        input: &EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<EvalAssertionOutcome> {
        self.0.assert(input, judge_model, cx).await
    }
}

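/// Runs the given eval for `iterations` samples (one warm-up run to populate the prompt cache,
/// the rest in parallel) and panics if the pass ratio falls below `expected_pass_ratio` or the
/// edit parser reports too many mismatched tags.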
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    report_progress(evaluated_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Mark the last message in the conversation as cacheable, and run one instance of the
    // eval up front so that all subsequent runs hit the cache.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    drop(tx);

    let mut failed_count = 0;
    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
                eval_outputs.push(output.clone());
                if output.assertion.score < 80 {
                    failed_count += 1;
                    failed_evals
                        .entry(output.sample.text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.02 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}

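/// Executes a single eval sample on a fresh `TestAppContext` with a randomly seeded dispatcher and sends the result over `tx`.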
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
    let mut cx = TestAppContext::build(dispatcher, None);
    let output = cx.executor().block_test(async {
        let test = EditAgentTest::new(&mut cx).await;
        test.eval(eval, &mut cx).await
    });
    tx.send(output).unwrap();
}

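/// The outcome of one eval sample: the generated edit plus the score and message from its assertion.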
#[derive(Clone)]
struct EvalOutput {
    sample: EvalSample,
    assertion: EvalAssertionOutcome,
}

impl Display for EvalOutput {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Score: {:?}", self.assertion.score)?;
        if let Some(message) = self.assertion.message.as_ref() {
            writeln!(f, "Message: {}", message)?;
        }

        writeln!(f, "Diff:\n{}", self.sample.diff)?;

        writeln!(
            f,
            "Parser Metrics:\n{:#?}",
            self.sample.edit_output._parser_metrics
        )?;
        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
        Ok(())
    }
}

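/// Overwrites the current terminal line with progress (`Evaluated N/M`).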
fn report_progress(evaluated_count: usize, iterations: usize) {
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    std::io::stdout().flush().unwrap();
}

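/// Harness that wires an `EditAgent` and a judge model to an in-memory test project.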
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}

impl EditAgentTest {
    async fn new(cx: &mut TestAppContext) -> Self {
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

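    /// Looks up the requested model in the global `LanguageModelRegistry` and authenticates its provider before returning it.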
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

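    /// Opens the eval's input buffer, replays the conversation through the agent (editing when
    /// starting content is provided, otherwise overwriting), and runs the assertion on the
    /// resulting text and diff.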
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    score: usize,
    message: Option<String>,
}

#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}

impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}

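/// Removes whitespace-only lines so text comparisons ignore blank-line differences.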
fn strip_empty_lines(text: &str) -> String {
    text.lines()
        .filter(|line| !line.trim().is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}