evals.rs

use super::*;
use crate::{
    ReadFileToolInput, grep_tool::GrepToolInput,
    streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    sync::mpsc,
};
use util::path;
 27
 28#[test]
 29#[cfg_attr(not(feature = "eval"), ignore)]
 30fn eval_extract_handle_command_output() {
 31    let input_file_path = "root/blame.rs";
 32    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
 33    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
 34    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
 35    eval(
 36        100,
 37        0.95,
 38        EvalInput {
 39            conversation: vec![
 40                message(
 41                    User,
 42                    [text(indoc! {"
 43                        Read the `{input_file_path}` file and extract a method in
 44                        the final stanza of `run_git_blame` to deal with command failures,
 45                        call it `handle_command_output` and take the std::process::Output as the only parameter.
 46
 47                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
 48                    "})],
 49                ),
 50                message(
 51                    Assistant,
 52                    [tool_use(
 53                        "tool_1",
 54                        "read_file",
 55                        ReadFileToolInput {
 56                            path: input_file_path.into(),
 57                            start_line: None,
 58                            end_line: None,
 59                        },
 60                    )],
 61                ),
 62                message(
 63                    User,
 64                    [tool_result("tool_1", "read_file", input_file_content)],
 65                ),
 66                message(
 67                    Assistant,
 68                    [tool_use(
 69                        "tool_2",
 70                        "edit_file",
 71                        StreamingEditFileToolInput {
 72                            display_description: edit_description.into(),
 73                            path: input_file_path.into(),
 74                        },
 75                    )],
 76                ),
 77            ],
 78            input_path: input_file_path.into(),
 79            input_content: input_file_content.into(),
 80            edit_description: edit_description.into(),
 81            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
 82        },
 83    );
 84}
 85
 86#[test]
 87#[cfg_attr(not(feature = "eval"), ignore)]
 88fn eval_delete_run_git_blame() {
 89    let input_file_path = "root/blame.rs";
 90    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
 91    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
 92    let edit_description = "Delete the `run_git_blame` function.";
 93    eval(
 94        100,
 95        0.95,
 96        EvalInput {
 97            conversation: vec![
 98                message(
 99                    User,
100                    [text(indoc! {"
101                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
102                        one function, not its usages.
103                    "})],
104                ),
105                message(
106                    Assistant,
107                    [tool_use(
108                        "tool_1",
109                        "read_file",
110                        ReadFileToolInput {
111                            path: input_file_path.into(),
112                            start_line: None,
113                            end_line: None,
114                        },
115                    )],
116                ),
117                message(
118                    User,
119                    [tool_result("tool_1", "read_file", input_file_content)],
120                ),
121                message(
122                    Assistant,
123                    [tool_use(
124                        "tool_2",
125                        "edit_file",
126                        StreamingEditFileToolInput {
127                            display_description: edit_description.into(),
128                            path: input_file_path.into(),
129                        },
130                    )],
131                ),
132            ],
133            input_path: input_file_path.into(),
134            input_content: input_file_content.into(),
135            edit_description: edit_description.into(),
136            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
137        },
138    );
139}
140
/// Evaluates rewriting `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
///
/// The scripted conversation has the assistant read three consecutive slices of
/// the fixture file before starting an `edit_file` call. The resulting diff is
/// scored by the judge model against the listed assertions and must pass in at
/// least 95% of 100 runs.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` interpolates `{input_file_path}`; the doubled braces keep
                    // `{language_name}` as literal text in the prompt.
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::JudgeDiff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
257
/// Evaluates commenting out the `BlinkManager` interactions in the editor fixture.
///
/// The scripted conversation contains a `grep` for "blink" whose result is built
/// from several slices of the fixture file, followed by the user's instruction
/// and an `edit_file` call. The edited buffer must equal the `after.rs` fixture
/// (blank lines are ignored by the equality assertion).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";

    // Slices of the fixture that stand in for the grep tool's output.
    let grep_ranges = [
        100..400,
        800..1300,
        1600..2000,
        5000..5500,
        8000..9000,
        18455..18470,
        20000..20500,
        21000..21300,
    ];
    let grep_output = grep_ranges
        .into_iter()
        .map(|range| lines(input_file_content, range))
        .collect::<Vec<_>>()
        .join("Match found:\n\n");

    let research_request = message(User, [text("Let's research how to cursor blinking works.")]);
    let grep_call = message(
        Assistant,
        [tool_use(
            "tool_1",
            "grep",
            GrepToolInput {
                regex: "blink".into(),
                include_pattern: None,
                offset: 0,
                case_sensitive: false,
            },
        )],
    );
    let grep_response = message(User, [tool_result("tool_1", "grep", grep_output)]);
    let edit_request = message(
        User,
        [text(indoc! {"
            Comment out the lines that interact with the BlinkManager.
            Keep the outer `update` blocks, but comments everything that's inside (including if statements).
            Don't add additional comments.
        "})],
    );
    let edit_call = message(
        Assistant,
        [tool_use(
            "tool_4",
            "edit_file",
            StreamingEditFileToolInput {
                display_description: edit_description.into(),
                path: input_file_path.into(),
            },
        )],
    );

    let input = EvalInput {
        conversation: vec![
            research_request,
            grep_call,
            grep_response,
            edit_request,
            edit_call,
        ],
        input_path: input_file_path.into(),
        input_content: input_file_content.into(),
        edit_description: edit_description.into(),
        assertion: EvalAssertion::AssertEqual(output_file_content.into()),
    };

    eval(
        100,
        0.6, // TODO: make this eval better
        input,
    );
}
329
/// Evaluates adding a `from_pixels` constructor plus tests to the canvas fixture.
///
/// The scripted conversation has the assistant read the file and run three
/// `grep` calls looking for existing tests before starting an `edit_file` call.
/// The resulting diff is scored by the judge model against the listed assertions.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";

    // Builds the input of a case-insensitive grep call starting at offset 0.
    let grep_for = |regex: &'static str, include_pattern: &'static str| GrepToolInput {
        regex: regex.into(),
        include_pattern: Some(include_pattern.into()),
        offset: 0,
        case_sensitive: false,
    };

    // Canned result of the `#[test]` grep over the font-kit sources.
    let test_grep_output = indoc! {"
        Found 6 matches:

        ## Matches in font-kit/src/loaders/core_text.rs

        ### mod test › L926-936
        ```
        mod test {
            use super::Font;
            use crate::properties::{Stretch, Weight};

            #[cfg(feature = \"source\")]
            use crate::source::SystemSource;

            static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

            #[cfg(feature = \"source\")]
            #[test]
        ```

        55 lines remaining in ancestor node. Read the file to see all.

        ### mod test › L947-951
        ```
            }

            #[test]
            fn test_core_text_to_css_font_weight() {
                // Exact matches
        ```

        ### mod test › L959-963
        ```
            }

            #[test]
            fn test_core_text_to_css_font_stretch() {
                // Exact matches
        ```

        ## Matches in font-kit/src/loaders/freetype.rs

        ### mod test › L1238-1248
        ```
        mod test {
            use crate::loaders::freetype::Font;

            static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
            static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

            #[test]
            fn get_pcf_postscript_name() {
                let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
            }
        ```

        1 lines remaining in ancestor node. Read the file to see all.

        ## Matches in font-kit/src/sources/core_text.rs

        ### mod test › L265-275
        ```
        mod test {
            use crate::properties::{Stretch, Weight};

            #[test]
            fn test_css_to_core_text_font_weight() {
                // Exact matches
                assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

        ```

        27 lines remaining in ancestor node. Read the file to see all.

        ### mod test › L278-282
        ```
            }

            #[test]
            fn test_css_to_core_text_font_stretch() {
                // Exact matches
        ```
    "};

    let conversation = vec![
        message(
            User,
            [text(indoc! {"
                Introduce a new `from_pixels` constructor in Canvas and
                also add tests for it in the same file.
            "})],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_1",
                "read_file",
                ReadFileToolInput {
                    path: input_file_path.into(),
                    start_line: None,
                    end_line: None,
                },
            )],
        ),
        message(
            User,
            [tool_result("tool_1", "read_file", input_file_content)],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_2",
                "grep",
                grep_for("mod\\s+tests", "font-kit/src/canvas.rs"),
            )],
        ),
        message(User, [tool_result("tool_2", "grep", "No matches found")]),
        message(
            Assistant,
            [tool_use(
                "tool_3",
                "grep",
                grep_for("mod\\s+tests", "font-kit/src/**/*.rs"),
            )],
        ),
        message(User, [tool_result("tool_3", "grep", "No matches found")]),
        message(
            Assistant,
            [tool_use(
                "tool_4",
                "grep",
                grep_for("#\\[test\\]", "font-kit/src/**/*.rs"),
            )],
        ),
        message(User, [tool_result("tool_4", "grep", test_grep_output)]),
        message(
            Assistant,
            [tool_use(
                "tool_5",
                "edit_file",
                StreamingEditFileToolInput {
                    display_description: edit_description.into(),
                    path: input_file_path.into(),
                },
            )],
        ),
    ];

    let judge_assertions = indoc! {"
        - The diff contains a new `from_pixels` constructor
        - The diff contains new tests for the `from_pixels` constructor
    "};

    eval(
        100,
        0.95,
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::JudgeDiff(judge_assertions),
        },
    );
}
521
522fn message(
523    role: Role,
524    contents: impl IntoIterator<Item = MessageContent>,
525) -> LanguageModelRequestMessage {
526    LanguageModelRequestMessage {
527        role,
528        content: contents.into_iter().collect(),
529        cache: false,
530    }
531}
532
533fn text(text: impl Into<String>) -> MessageContent {
534    MessageContent::Text(text.into())
535}
536
/// Returns the lines of `input` whose zero-based index falls in `range`,
/// joined with `\n` and without a trailing newline.
///
/// Indices past the end of `input` are ignored, so the result may contain
/// fewer lines than the range asks for.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected = String::new();
    let wanted = input.lines().skip(range.start).take(range.len());
    for (position, line) in wanted.enumerate() {
        if position > 0 {
            selected.push('\n');
        }
        selected.push_str(line);
    }
    selected
}
545
546fn tool_use(
547    id: impl Into<Arc<str>>,
548    name: impl Into<Arc<str>>,
549    input: impl Serialize,
550) -> MessageContent {
551    MessageContent::ToolUse(LanguageModelToolUse {
552        id: LanguageModelToolUseId::from(id.into()),
553        name: name.into(),
554        raw_input: serde_json::to_string_pretty(&input).unwrap(),
555        input: serde_json::to_value(input).unwrap(),
556        is_input_complete: true,
557    })
558}
559
560fn tool_result(
561    id: impl Into<Arc<str>>,
562    name: impl Into<Arc<str>>,
563    result: impl Into<Arc<str>>,
564) -> MessageContent {
565    MessageContent::ToolResult(LanguageModelToolResult {
566        tool_use_id: LanguageModelToolUseId::from(id.into()),
567        tool_name: name.into(),
568        is_error: false,
569        content: result.into(),
570    })
571}
572
/// Everything needed to run one eval: the scripted conversation, the file to
/// edit, and how to score the result.
#[derive(Clone)]
struct EvalInput {
    /// Conversation handed to the edit agent.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Path of the file to edit; resolved against the test project.
    input_path: PathBuf,
    /// Text placed in the buffer before the agent edits it; also the baseline
    /// for the reported diff.
    input_content: String,
    /// Description of the edit, passed to the edit agent.
    edit_description: String,
    /// How the edited buffer is scored.
    assertion: EvalAssertion,
}
581
582fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
583    let mut evaluated_count = 0;
584    report_progress(evaluated_count, iterations);
585
586    let (tx, rx) = mpsc::channel();
587
588    // Cache the last message in the conversation, and run one instance of the eval so that
589    // all the next ones are cached.
590    eval.conversation.last_mut().unwrap().cache = true;
591    run_eval(eval.clone(), tx.clone());
592
593    let executor = gpui::background_executor();
594    for _ in 1..iterations {
595        let eval = eval.clone();
596        let tx = tx.clone();
597        executor.spawn(async move { run_eval(eval, tx) }).detach();
598    }
599    drop(tx);
600
601    let mut failed_count = 0;
602    let mut failed_evals = HashMap::default();
603    let mut errored_evals = HashMap::default();
604    let mut eval_outputs = Vec::new();
605    let mut cumulative_parser_metrics = EditParserMetrics::default();
606    while let Ok(output) = rx.recv() {
607        match output {
608            Ok(output) => {
609                cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
610                eval_outputs.push(output.clone());
611                if output.assertion.score < 80 {
612                    failed_count += 1;
613                    failed_evals
614                        .entry(output.buffer_text.clone())
615                        .or_insert(Vec::new())
616                        .push(output);
617                }
618            }
619            Err(error) => {
620                failed_count += 1;
621                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
622            }
623        }
624
625        evaluated_count += 1;
626        report_progress(evaluated_count, iterations);
627    }
628
629    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
630    println!("Actual pass ratio: {}\n", actual_pass_ratio);
631    if actual_pass_ratio < expected_pass_ratio {
632        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
633        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
634        for (error, count) in errored_evals {
635            println!("Eval errored {} times. Error: {}", count, error);
636        }
637
638        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
639        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
640        for (_buffer_output, failed_evals) in failed_evals {
641            let eval_output = failed_evals.first().unwrap();
642            println!("Eval failed {} times", failed_evals.len());
643            println!("{}", eval_output);
644        }
645
646        panic!(
647            "Actual pass ratio: {}\nExpected pass ratio: {}",
648            actual_pass_ratio, expected_pass_ratio
649        );
650    }
651
652    let mismatched_tag_ratio =
653        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
654    if mismatched_tag_ratio > 0.02 {
655        for eval_output in eval_outputs {
656            println!("{}", eval_output);
657        }
658        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
659    }
660}
661
662fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
663    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
664    let mut cx = TestAppContext::build(dispatcher, None);
665    let output = cx.executor().block_test(async {
666        let test = EditAgentTest::new(&mut cx).await;
667        test.eval(eval, &mut cx).await
668    });
669    tx.send(output).unwrap();
670}
671
/// Result of a single eval run.
#[derive(Clone)]
struct EvalOutput {
    /// Score (and optional judge message) produced by the eval's assertion.
    assertion: EvalAssertionResult,
    /// Full text of the buffer after the agent's edit.
    buffer_text: String,
    /// Output returned by the edit agent; its parser metrics and raw edits are
    /// included in the `Display` report.
    edit_output: EditAgentOutput,
    /// Unified diff between the input content and `buffer_text`.
    diff: String,
}
679
680impl Display for EvalOutput {
681    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
682        writeln!(f, "Score: {:?}", self.assertion.score)?;
683        if let Some(message) = self.assertion.message.as_ref() {
684            writeln!(f, "Message: {}", message)?;
685        }
686
687        writeln!(f, "Diff:\n{}", self.diff)?;
688
689        writeln!(
690            f,
691            "Parser Metrics:\n{:#?}",
692            self.edit_output._parser_metrics
693        )?;
694        writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
695        Ok(())
696    }
697}
698
/// Prints an in-place progress line ("Evaluated x/y") and flushes stdout.
///
/// Panics if flushing stdout fails.
fn report_progress(evaluated_count: usize, iterations: usize) {
    // `\r` returns to the start of the line and `\x1b[K` (ANSI "erase in line")
    // clears it, so each update overwrites the previous one.
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    // No newline is printed, so flush explicitly to make the update visible.
    std::io::stdout().flush().unwrap();
}
703
/// Test harness bundling the edit agent under evaluation with the project it
/// edits and the model used to judge diffs.
struct EditAgentTest {
    /// Agent whose edits are being evaluated.
    agent: EditAgent,
    /// Test project in which the input buffer is opened.
    project: Entity<Project>,
    /// Model asked to score diffs for `EvalAssertion::JudgeDiff`.
    judge_model: Arc<dyn LanguageModel>,
}
709
710impl EditAgentTest {
711    async fn new(cx: &mut TestAppContext) -> Self {
712        cx.executor().allow_parking();
713        cx.update(settings::init);
714        cx.update(Project::init_settings);
715        cx.update(language::init);
716        cx.update(gpui_tokio::init);
717        cx.update(client::init_settings);
718
719        let fs = FakeFs::new(cx.executor().clone());
720        fs.insert_tree("/root", json!({})).await;
721        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
722        let (agent_model, judge_model) = cx
723            .update(|cx| {
724                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
725                cx.set_http_client(Arc::new(http_client));
726
727                let client = Client::production(cx);
728                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
729                language_model::init(client.clone(), cx);
730                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
731
732                cx.spawn(async move |cx| {
733                    let agent_model =
734                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
735                    let judge_model =
736                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
737                    (agent_model.unwrap(), judge_model.unwrap())
738                })
739            })
740            .await;
741        let action_log = cx.new(|_| ActionLog::new(project.clone()));
742
743        Self {
744            agent: EditAgent::new(agent_model, action_log, Templates::new()),
745            project,
746            judge_model,
747        }
748    }
749
750    async fn load_model(
751        provider: &str,
752        id: &str,
753        cx: &mut AsyncApp,
754    ) -> Result<Arc<dyn LanguageModel>> {
755        let (provider, model) = cx.update(|cx| {
756            let models = LanguageModelRegistry::read_global(cx);
757            let model = models
758                .available_models(cx)
759                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
760                .unwrap();
761            let provider = models.provider(&model.provider_id()).unwrap();
762            (provider, model)
763        })?;
764        cx.update(|cx| provider.authenticate(cx))?.await?;
765        Ok(model)
766    }
767
768    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
769        let path = self
770            .project
771            .read_with(cx, |project, cx| {
772                project.find_project_path(eval.input_path, cx)
773            })
774            .unwrap();
775        let buffer = self
776            .project
777            .update(cx, |project, cx| project.open_buffer(path, cx))
778            .await
779            .unwrap();
780        buffer.update(cx, |buffer, cx| {
781            buffer.set_text(eval.input_content.clone(), cx)
782        });
783        let (edit_output, _events) = self.agent.edit(
784            buffer.clone(),
785            eval.edit_description,
786            eval.conversation,
787            &mut cx.to_async(),
788        );
789        let edit_output = edit_output.await?;
790        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
791        let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
792        let assertion = match eval.assertion {
793            EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
794                score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
795                    100
796                } else {
797                    0
798                },
799                message: None,
800            },
801            EvalAssertion::JudgeDiff(assertions) => self
802                .judge_diff(&actual_diff, assertions, &cx.to_async())
803                .await
804                .context("failed comparing diffs")?,
805        };
806
807        Ok(EvalOutput {
808            assertion,
809            diff: actual_diff,
810            buffer_text,
811            edit_output,
812        })
813    }
814
815    async fn judge_diff(
816        &self,
817        diff: &str,
818        assertions: &'static str,
819        cx: &AsyncApp,
820    ) -> Result<EvalAssertionResult> {
821        let prompt = DiffJudgeTemplate {
822            diff: diff.to_string(),
823            assertions,
824        }
825        .render(&self.agent.templates)
826        .unwrap();
827
828        let request = LanguageModelRequest {
829            messages: vec![LanguageModelRequestMessage {
830                role: Role::User,
831                content: vec![prompt.into()],
832                cache: false,
833            }],
834            ..Default::default()
835        };
836        let mut response = self.judge_model.stream_completion_text(request, cx).await?;
837        let mut output = String::new();
838        while let Some(chunk) = response.stream.next().await {
839            let chunk = chunk?;
840            output.push_str(&chunk);
841        }
842
843        // Parse the score from the response
844        let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
845        if let Some(captures) = re.captures(&output) {
846            if let Some(score_match) = captures.get(1) {
847                let score = score_match.as_str().parse().unwrap_or(0);
848                return Ok(EvalAssertionResult {
849                    score,
850                    message: Some(output),
851                });
852            }
853        }
854
855        Err(anyhow!(
856            "No score found in response. Raw output: {}",
857            output
858        ))
859    }
860}
861
/// How the outcome of an eval run is scored.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
    /// The edited buffer must equal this expected text once blank lines are
    /// stripped from both sides; scores 100 on a match and 0 otherwise.
    AssertEqual(String),
    /// The diff is scored by the judge model against these assertions.
    JudgeDiff(&'static str),
}
867
/// Outcome of evaluating an `EvalAssertion`.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
    /// Score of the run: 100 or 0 for equality assertions, or the number the
    /// judge put in its `<score>` tag.
    score: usize,
    /// The judge's full response, when a judge was involved.
    message: Option<String>,
}
873
/// Data serialized into the diff judge prompt template.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    /// Unified diff produced by the edit under evaluation.
    diff: String,
    /// Assertions the judge should check the diff against.
    assertions: &'static str,
}
879
/// Binds `DiffJudgeTemplate` to the template named `diff_judge.hbs`.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
883
/// Returns `text` without its empty or whitespace-only lines.
///
/// The remaining lines are kept verbatim (including their own leading and
/// trailing whitespace) and joined with `\n`, without a trailing newline.
fn strip_empty_lines(text: &str) -> String {
    let mut kept = String::with_capacity(text.len());
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        // A kept line is never empty, so `kept` is empty only before the first one.
        if !kept.is_empty() {
            kept.push('\n');
        }
        kept.push_str(line);
    }
    kept
}