use super::*;
use crate::{
    ReadFileToolInput, grep_tool::GrepToolInput,
    streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    sync::mpsc,
};
use util::path;

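// The evals below exercise the edit agent end-to-end against a live language model.
// They are ignored by default and only run when the crate is built with the "eval"
// feature enabled (see the `cfg_attr` on each test).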
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::JudgeDiff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        100,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(User, [text("Let's research how cursor blinking works.")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comment everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::AssertEqual(output_file_content.into()),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                                use super::Font;
                                use crate::properties::{Stretch, Weight};

                                #[cfg(feature = \"source\")]
                                use crate::source::SystemSource;

                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                                #[cfg(feature = \"source\")]
                                #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_weight() {
                                    // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_stretch() {
                                    // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                                use crate::loaders::freetype::Font;

                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                                #[test]
                                fn get_pcf_postscript_name() {
                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                                }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                                use crate::properties::{Stretch, Weight};

                                #[test]
                                fn test_css_to_core_text_font_weight() {
                                    // Exact matches
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                                }

                                #[test]
                                fn test_css_to_core_text_font_stretch() {
                                    // Exact matches
                            ```
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: input_file_content.into(),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::JudgeDiff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}

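/// Builds a request message with the given role and content parts. `cache` starts out
/// false; the `eval` harness turns it on for the final message of the conversation.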
fn message(
    role: Role,
    contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
    LanguageModelRequestMessage {
        role,
        content: contents.into_iter().collect(),
        cache: false,
    }
}

fn text(text: impl Into<String>) -> MessageContent {
    MessageContent::Text(text.into())
}

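/// Returns the zero-based, end-exclusive `range` of lines from `input`, joined with newlines.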
fn lines(input: &str, range: Range<usize>) -> String {
    input
        .lines()
        .skip(range.start)
        .take(range.len())
        .collect::<Vec<_>>()
        .join("\n")
}

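/// Builds an assistant tool-use content part, serializing `input` both as pretty-printed
/// JSON (`raw_input`) and as a JSON value (`input`).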
fn tool_use(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    input: impl Serialize,
) -> MessageContent {
    MessageContent::ToolUse(LanguageModelToolUse {
        id: LanguageModelToolUseId::from(id.into()),
        name: name.into(),
        raw_input: serde_json::to_string_pretty(&input).unwrap(),
        input: serde_json::to_value(input).unwrap(),
        is_input_complete: true,
    })
}

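/// Builds a successful (`is_error: false`) tool result for the tool use with the given id and name.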
fn tool_result(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    result: impl Into<Arc<str>>,
) -> MessageContent {
    MessageContent::ToolResult(LanguageModelToolResult {
        tool_use_id: LanguageModelToolUseId::from(id.into()),
        tool_name: name.into(),
        is_error: false,
        content: result.into(),
    })
}

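/// A single eval scenario: the conversation to replay, the file being edited, and the
/// assertion used to score the resulting buffer.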
#[derive(Clone)]
struct EvalInput {
    conversation: Vec<LanguageModelRequestMessage>,
    input_path: PathBuf,
    input_content: String,
    edit_description: String,
    assertion: EvalAssertion,
}

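/// Runs `eval` for `iterations` iterations (one up front so subsequent runs hit the prompt
/// cache, the rest on the background executor) and panics if the pass ratio falls below
/// `expected_pass_ratio` or if more than 2% of edit parser tags are mismatched. An output
/// scoring below 80 counts as a failure.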
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    report_progress(evaluated_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Cache the last message in the conversation, and run one instance of the eval so that
    // all the next ones are cached.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    drop(tx);

    let mut failed_count = 0;
    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
                eval_outputs.push(output.clone());
                if output.assertion.score < 80 {
                    failed_count += 1;
                    failed_evals
                        .entry(output.buffer_text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.02 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}

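/// Runs a single eval iteration on a fresh `TestAppContext` and sends the result down `tx`.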
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
    let mut cx = TestAppContext::build(dispatcher, None);
    let output = cx.executor().block_test(async {
        let test = EditAgentTest::new(&mut cx).await;
        test.eval(eval, &mut cx).await
    });
    tx.send(output).unwrap();
}

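/// The outcome of one eval iteration: the assertion score, the final buffer text, the raw
/// edit-agent output, and the unified diff that was produced.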
#[derive(Clone)]
struct EvalOutput {
    assertion: EvalAssertionResult,
    buffer_text: String,
    edit_output: EditAgentOutput,
    diff: String,
}

impl Display for EvalOutput {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Score: {:?}", self.assertion.score)?;
        if let Some(message) = self.assertion.message.as_ref() {
            writeln!(f, "Message: {}", message)?;
        }

        writeln!(f, "Diff:\n{}", self.diff)?;

        writeln!(
            f,
            "Parser Metrics:\n{:#?}",
            self.edit_output._parser_metrics
        )?;
        writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
        Ok(())
    }
}

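/// Rewrites the current terminal line (carriage return plus clear-to-end-of-line) with the
/// number of iterations evaluated so far.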
fn report_progress(evaluated_count: usize, iterations: usize) {
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    std::io::stdout().flush().unwrap();
}

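/// Test harness pairing an `EditAgent` (backed by a test `Project`) with a judge model used
/// to score diffs.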
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}

impl EditAgentTest {
    async fn new(cx: &mut TestAppContext) -> Self {
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, action_log, Templates::new()),
            project,
            judge_model,
        }
    }

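    /// Looks up the model with the given provider and id in the `LanguageModelRegistry` and
    /// authenticates its provider before returning it.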
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

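    /// Opens the eval's input file, replaces its contents with `input_content`, runs the
    /// agent's edit over the recorded conversation, and scores the resulting buffer against
    /// the eval's assertion.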
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.set_text(eval.input_content.clone(), cx)
        });
        let (edit_output, _events) = self.agent.edit(
            buffer.clone(),
            eval.edit_description,
            eval.conversation,
            &mut cx.to_async(),
        );
        let edit_output = edit_output.await?;
        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
        let assertion = match eval.assertion {
            EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
                score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
                    100
                } else {
                    0
                },
                message: None,
            },
            EvalAssertion::JudgeDiff(assertions) => self
                .judge_diff(&actual_diff, assertions, &cx.to_async())
                .await
                .context("failed comparing diffs")?,
        };

        Ok(EvalOutput {
            assertion,
            diff: actual_diff,
            buffer_text,
            edit_output,
        })
    }

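    /// Renders the diff-judge prompt, streams a completion from the judge model, and parses
    /// a `<score>` tag out of the response.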
    async fn judge_diff(
        &self,
        diff: &str,
        assertions: &'static str,
        cx: &AsyncApp,
    ) -> Result<EvalAssertionResult> {
        let prompt = DiffJudgeTemplate {
            diff: diff.to_string(),
            assertions,
        }
        .render(&self.agent.templates)
        .unwrap();

        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![prompt.into()],
                cache: false,
            }],
            ..Default::default()
        };
        let mut response = self.judge_model.stream_completion_text(request, cx).await?;
        let mut output = String::new();
        while let Some(chunk) = response.stream.next().await {
            let chunk = chunk?;
            output.push_str(&chunk);
        }

        // Parse the score from the response
        let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
        if let Some(captures) = re.captures(&output) {
            if let Some(score_match) = captures.get(1) {
                let score = score_match.as_str().parse().unwrap_or(0);
                return Ok(EvalAssertionResult {
                    score,
                    message: Some(output),
                });
            }
        }

        Err(anyhow!(
            "No score found in response. Raw output: {}",
            output
        ))
    }
}

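/// How an eval's output is scored: either an exact match against the expected file contents
/// (ignoring empty lines) or a model-judged diff checked against a list of assertions.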
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
    AssertEqual(String),
    JudgeDiff(&'static str),
}

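/// The score assigned to an eval run (exact-match assertions yield 0 or 100), along with an
/// optional message such as the judge model's raw output.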
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
    score: usize,
    message: Option<String>,
}

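/// Context rendered into the `diff_judge.hbs` template when asking the judge model to score
/// a diff against a set of assertions.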
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}

impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}

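/// Removes empty (or whitespace-only) lines so that `AssertEqual` comparisons ignore
/// blank-line differences.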
fn strip_empty_lines(text: &str) -> String {
    text.lines()
        .filter(|line| !line.trim().is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}