1use super::*;
2use crate::{
3 ReadFileToolInput, grep_tool::GrepToolInput,
4 streaming_edit_file_tool::StreamingEditFileToolInput,
5};
6use Role::*;
7use anyhow::anyhow;
8use client::{Client, UserStore};
9use collections::HashMap;
10use fs::FakeFs;
11use futures::{FutureExt, future::LocalBoxFuture};
12use gpui::{AppContext, TestAppContext};
13use indoc::indoc;
14use language_model::{
15 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
16};
17use project::Project;
18use rand::prelude::*;
19use reqwest_client::ReqwestClient;
20use serde_json::json;
21use std::{
22 cmp::Reverse,
23 fmt::{self, Display},
24 io::Write as _,
25 sync::mpsc,
26};
27use util::path;
28
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: the agent must extract a `handle_command_output` method out of
/// `run_git_blame` and reproduce the expected `after.rs` fixture exactly.
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `indoc!` performs no interpolation, so the prompt used to
                    // contain the literal text "{input_file_path}". Substitute the
                    // real path so the model is asked about the actual file.
                    [text(
                        indoc! {"
                            Read the `{input_file_path}` file and extract a method in
                            the final stanza of `run_git_blame` to deal with command failures,
                            call it `handle_command_output` and take the std::process::Output as the only parameter.

                            Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                        "}
                        .replace("{input_file_path}", input_file_path),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
87
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: the agent must delete only the `run_git_blame` function (not its
/// callers) and match the expected `after.rs` fixture exactly.
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `indoc!` performs no interpolation, so the prompt used to
                    // contain the literal text "{input_file_path}". Substitute the
                    // real path so the model is asked about the actual file.
                    [text(
                        indoc! {"
                            Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                            one function, not its usages.
                        "}
                        .replace("{input_file_path}", input_file_path),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
143
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: the agent rewrites `compile_parser_to_wasm` to use wasi-sdk after
/// reading the file in three windowed chunks; the result is judged by diff.
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `indoc!` performs no interpolation, so the prompt used to
                    // contain the literal text "{input_file_path}"; substitute the
                    // real path. `{language_name}` stays literal on purpose: it is
                    // part of the clang-flags template the model should emit.
                    // Also removed the accidentally duplicated x86_64/arm64 linux
                    // entries from the asset list.
                    [text(
                        indoc! {"
                            Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                            Use `ureq` to download the SDK for the current platform and architecture.
                            Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                            Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                            that's inside of the archive.
                            Don't re-download the SDK if that executable already exists.

                            Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}

                            Here are the available wasi-sdk assets:
                            - wasi-sdk-25.0-x86_64-macos.tar.gz
                            - wasi-sdk-25.0-arm64-macos.tar.gz
                            - wasi-sdk-25.0-x86_64-linux.tar.gz
                            - wasi-sdk-25.0-arm64-linux.tar.gz
                            - wasi-sdk-25.0-x86_64-windows.tar.gz
                        "}
                        .replace("{input_file_path}", input_file_path),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
261
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: after "researching" blinking via a canned grep, the agent must
/// comment out every line that touches `BlinkManager` while keeping the
/// enclosing `update` blocks, matching the `after.rs` fixture exactly.
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        100,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                // NOTE(review): "how to cursor blinking works" is a typo in the
                // prompt; left as-is because changing it would change the eval.
                message(User, [text("Let's research how to cursor blinking works.")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        // Simulated grep output: several large excerpts of the
                        // fixture concatenated together.
                        // NOTE(review): `join` only inserts the separator
                        // *between* chunks, so the first excerpt is not prefixed
                        // with "Match found:" — confirm this is the intended
                        // shape of the fake tool output.
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
334
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: the agent adds a `from_pixels` constructor to `Canvas` plus tests,
/// after a canned research phase (read file, three greps for existing test
/// modules). Scored by an LLM judge over the resulting diff.
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                // First grep: look for a tests module in the edited file itself.
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                // Second grep: widen the search to the whole crate.
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                // Third grep: find existing `#[test]` functions to copy the style of.
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    // Canned grep output showing how tests look elsewhere in the
                    // (simulated) font-kit crate.
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                                use super::Font;
                                use crate::properties::{Stretch, Weight};

                                #[cfg(feature = \"source\")]
                                use crate::source::SystemSource;

                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                                #[cfg(feature = \"source\")]
                                #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_weight() {
                                    // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_stretch() {
                                    // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                                use crate::loaders::freetype::Font;

                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                                #[test]
                                fn get_pcf_postscript_name() {
                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                                }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                                use crate::properties::{Stretch, Weight};

                                #[test]
                                fn test_css_to_core_text_font_weight() {
                                    // Exact matches
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                                }

                                #[test]
                                fn test_css_to_core_text_font_stretch() {
                                    // Exact matches
                            ```
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}
527
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
/// Eval: the agent creates a brand-new Python CLI script from scratch
/// (`input_content: None` puts the harness in overwrite mode) after reading
/// two reference files. Runs 200 iterations and requires a 100% pass rate.
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
                message(
                    Assistant,
                    [
                        tool_use(
                            "tool_1",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react_test.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [
                        tool_result(
                            "tool_1",
                            "read_file",
                            include_str!("evals/fixtures/zode/react.py"),
                        ),
                        tool_result(
                            "tool_2",
                            "read_file",
                            include_str!("evals/fixtures/zode/react_test.py"),
                        ),
                    ],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now that I understand what we need to build, I'll create the main Python script:",
                        ),
                        tool_use(
                            "tool_3",
                            "edit_file",
                            StreamingEditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                create_or_overwrite: true,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            // Custom assertion: the generated file must not start with a space,
            // a backtick, or a newline — telltale signs of stray indentation or
            // a leaked markdown code fence in the model output.
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                let invalid_starts = [' ', '`', '\n'];
                let mut message = String::new();
                for start in invalid_starts {
                    if sample.text.starts_with(start) {
                        message.push_str(&format!("The sample starts with a {:?}\n", start));
                        break;
                    }
                }
                // Remove trailing newline.
                message.pop();

                // Empty message means no invalid start was found: full score.
                if message.is_empty() {
                    Ok(EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    })
                } else {
                    Ok(EvalAssertionOutcome {
                        score: 0,
                        message: Some(message),
                    })
                }
            }),
        },
    );
}
625
626fn message(
627 role: Role,
628 contents: impl IntoIterator<Item = MessageContent>,
629) -> LanguageModelRequestMessage {
630 LanguageModelRequestMessage {
631 role,
632 content: contents.into_iter().collect(),
633 cache: false,
634 }
635}
636
637fn text(text: impl Into<String>) -> MessageContent {
638 MessageContent::Text(text.into())
639}
640
/// Returns the 0-indexed, half-open `range` of lines from `input`, joined
/// with `\n` (no trailing newline). Out-of-bounds ranges yield fewer (or no)
/// lines rather than panicking.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected = Vec::new();
    for (index, line) in input.lines().enumerate() {
        if range.contains(&index) {
            selected.push(line);
        }
    }
    selected.join("\n")
}
649
650fn tool_use(
651 id: impl Into<Arc<str>>,
652 name: impl Into<Arc<str>>,
653 input: impl Serialize,
654) -> MessageContent {
655 MessageContent::ToolUse(LanguageModelToolUse {
656 id: LanguageModelToolUseId::from(id.into()),
657 name: name.into(),
658 raw_input: serde_json::to_string_pretty(&input).unwrap(),
659 input: serde_json::to_value(input).unwrap(),
660 is_input_complete: true,
661 })
662}
663
664fn tool_result(
665 id: impl Into<Arc<str>>,
666 name: impl Into<Arc<str>>,
667 result: impl Into<Arc<str>>,
668) -> MessageContent {
669 MessageContent::ToolResult(LanguageModelToolResult {
670 tool_use_id: LanguageModelToolUseId::from(id.into()),
671 tool_name: name.into(),
672 is_error: false,
673 content: result.into(),
674 })
675}
676
/// A fully-specified eval scenario: the canned conversation to replay, the
/// file being edited, and the assertion used to score the result.
#[derive(Clone)]
struct EvalInput {
    // Canned request messages replayed to the agent model; `eval` marks the
    // last one `cache: true` before running.
    conversation: Vec<LanguageModelRequestMessage>,
    // Project-relative path of the file the agent edits.
    input_path: PathBuf,
    // Initial file contents; `None` switches the harness to overwrite mode.
    input_content: Option<String>,
    // Human-readable description passed to the edit agent.
    edit_description: String,
    // Scoring logic applied to the edited buffer.
    assertion: EvalAssertion,
}
685
/// The product of a single agent run, handed to assertions for scoring.
#[derive(Clone)]
struct EvalSample {
    // Full buffer text after the agent's edits were applied.
    text: String,
    // Raw edit-agent output (parser metrics and raw edit stream).
    edit_output: EditAgentOutput,
    // Unified diff from the original input content to `text`.
    diff: String,
}
692
/// Object-safe assertion interface so `EvalAssertion` can hold any async
/// scoring function behind an `Arc<dyn AssertionFn>`.
trait AssertionFn: 'static + Send + Sync {
    /// Scores `sample`, optionally consulting `judge_model`, and resolves to
    /// an outcome carrying a 0-100 score plus an optional message.
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
701
/// Blanket impl: any `AsyncFn` closure with the matching signature can be
/// used directly as an assertion.
impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        // Invoke the closure and box its future locally (`LocalBoxFuture` is
        // the non-`Send` box, since the future borrows `sample` and `cx`).
        (self)(sample, judge_model, cx).boxed_local()
    }
}
722
/// Cloneable handle to an async assertion used to score an eval sample.
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
725
impl EvalAssertion {
    /// Wraps an arbitrary async scoring closure in an `EvalAssertion`.
    fn new<F>(f: F) -> Self
    where
        F: 'static
            + Send
            + Sync
            + AsyncFn(
                &EvalSample,
                Arc<dyn LanguageModel>,
                &mut TestAppContext,
            ) -> Result<EvalAssertionOutcome>,
    {
        EvalAssertion(Arc::new(f))
    }

    /// Exact-match assertion: scores 100 when the sample text equals
    /// `expected` after stripping blank lines on both sides, otherwise 0.
    fn assert_eq(expected: impl Into<String>) -> Self {
        let expected = expected.into();
        Self::new(async move |sample, _judge, _cx| {
            Ok(EvalAssertionOutcome {
                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
                    100
                } else {
                    0
                },
                message: None,
            })
        })
    }

    /// LLM-judged assertion: renders the sample's diff and the given bullet
    /// list of `assertions` into the diff-judge prompt, streams the judge
    /// model's reply, and extracts a numeric `<score>N</score>` tag from it.
    /// Errors if the judge's output contains no score tag.
    fn judge_diff(assertions: &'static str) -> Self {
        Self::new(async move |sample, judge, cx| {
            let prompt = DiffJudgeTemplate {
                diff: sample.diff.clone(),
                assertions,
            }
            .render(&Templates::new())
            .unwrap();

            let request = LanguageModelRequest {
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![prompt.into()],
                    cache: false,
                }],
                ..Default::default()
            };
            // Accumulate the streamed completion into a single string.
            let mut response = judge
                .stream_completion_text(request, &cx.to_async())
                .await?;
            let mut output = String::new();
            while let Some(chunk) = response.stream.next().await {
                let chunk = chunk?;
                output.push_str(&chunk);
            }

            // Parse the score from the response
            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
            if let Some(captures) = re.captures(&output) {
                if let Some(score_match) = captures.get(1) {
                    // An unparsable number degrades to a score of 0 rather
                    // than an error, so the full judge output is preserved.
                    let score = score_match.as_str().parse().unwrap_or(0);
                    return Ok(EvalAssertionOutcome {
                        score,
                        message: Some(output),
                    });
                }
            }

            Err(anyhow!(
                "No score found in response. Raw output: {}",
                output
            ))
        })
    }

    /// Runs the wrapped assertion against `input`.
    async fn run(
        &self,
        input: &EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<EvalAssertionOutcome> {
        self.0.assert(input, judge_model, cx).await
    }
}
809
/// Runs the given eval `iterations` times (concurrently on the background
/// executor) and panics if fewer than `expected_pass_ratio` of the runs score
/// at least 80, or if more than 2% of edit-parser tags were mismatched across
/// all runs combined.
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    report_progress(evaluated_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Cache the last message in the conversation, and run one instance of the eval so that
    // all the next ones are cached.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    // The first run already happened above, so spawn `iterations - 1` more.
    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    // Drop the original sender so `rx.recv()` terminates once every spawned
    // run has finished and dropped its clone.
    drop(tx);

    let mut failed_count = 0;
    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
                eval_outputs.push(output.clone());
                // A score below 80 counts as a failed run.
                if output.assertion.score < 80 {
                    failed_count += 1;
                    // Group failures by resulting buffer text so identical
                    // failures are reported once, with a count.
                    failed_evals
                        .entry(output.sample.text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                // Errors (as opposed to low scores) are tallied by message.
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        // Print the most frequent errors and failures first, then panic.
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    // Even when the pass ratio is met, fail if the streaming edit parser saw
    // too many mismatched tags overall (more than 2% of all tags).
    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.02 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}
889
890fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
891 let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
892 let mut cx = TestAppContext::build(dispatcher, None);
893 let output = cx.executor().block_test(async {
894 let test = EditAgentTest::new(&mut cx).await;
895 test.eval(eval, &mut cx).await
896 });
897 tx.send(output).unwrap();
898}
899
/// A completed, scored eval run: the produced sample and its assertion outcome.
#[derive(Clone)]
struct EvalOutput {
    sample: EvalSample,
    assertion: EvalAssertionOutcome,
}
905
906impl Display for EvalOutput {
907 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
908 writeln!(f, "Score: {:?}", self.assertion.score)?;
909 if let Some(message) = self.assertion.message.as_ref() {
910 writeln!(f, "Message: {}", message)?;
911 }
912
913 writeln!(f, "Diff:\n{}", self.sample.diff)?;
914
915 writeln!(
916 f,
917 "Parser Metrics:\n{:#?}",
918 self.sample.edit_output._parser_metrics
919 )?;
920 writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
921 Ok(())
922 }
923}
924
/// Rewrites the current terminal line with an eval progress counter.
/// `\r` returns the cursor to column 0 and `\x1b[K` clears the line, so each
/// update replaces the previous one instead of stacking.
fn report_progress(evaluated_count: usize, iterations: usize) {
    let mut stdout = std::io::stdout();
    write!(stdout, "\r\x1b[KEvaluated {}/{}", evaluated_count, iterations).unwrap();
    // `print!` doesn't flush; do it explicitly so progress shows immediately.
    stdout.flush().unwrap();
}
929
/// Harness for a single eval run: the edit agent under test, the test
/// project it edits, and the model used to judge diffs.
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}
935
impl EditAgentTest {
    /// Boots a headless test app, wires up a real HTTP client and the
    /// production collab client, then loads the agent and judge models.
    ///
    /// NOTE(review): `Client::production` and `ReqwestClient` mean this talks
    /// to real services and needs provider credentials to authenticate —
    /// consistent with these evals being `#[ignore]`d without the `eval`
    /// feature.
    async fn new(cx: &mut TestAppContext) -> Self {
        // Real network futures block, so the test executor must allow parking.
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        // The edited project lives on an in-memory FS rooted at "/root".
        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                // Agent and judge both use the same Anthropic model.
                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    /// Finds a registered model by provider id and model id, then
    /// authenticates its provider. Panics (`unwrap`) when the model is not
    /// present in the registry.
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    /// Replays one eval: opens the target buffer, seeds it with
    /// `input_content` (edit mode) or leaves it empty (overwrite mode), runs
    /// the agent, diffs the result against the original content, and scores
    /// the sample with the eval's assertion.
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            // Existing-file scenario: seed the buffer, then stream edits into it.
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            // New-file scenario: the agent writes the buffer from scratch.
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            // Diff against the original content (empty string for new files).
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}
1042
/// Result of running an assertion: a 0-100 score (scores below 80 are
/// treated as failures by `eval`) plus an optional message such as the judge
/// model's raw output.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    score: usize,
    message: Option<String>,
}
1048
/// Template data for the diff-judging prompt: the diff to evaluate and the
/// bullet list of assertions the judge model should check it against.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}
1054
impl Template for DiffJudgeTemplate {
    // Handlebars template rendered via `Templates::new()` in `judge_diff`.
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1058
/// Drops blank (whitespace-only) lines from `text` and rejoins the rest with
/// `\n`, so text comparisons ignore vertical-spacing differences.
fn strip_empty_lines(text: &str) -> String {
    let mut kept = Vec::new();
    for line in text.lines() {
        if !line.trim().is_empty() {
            kept.push(line);
        }
    }
    kept.join("\n")
}