1use super::*;
2use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
3use Role::*;
4use anyhow::anyhow;
5use client::{Client, UserStore};
6use collections::HashMap;
7use fs::FakeFs;
8use futures::{FutureExt, future::LocalBoxFuture};
9use gpui::{AppContext, TestAppContext};
10use indoc::indoc;
11use language_model::{
12 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
13};
14use project::Project;
15use rand::prelude::*;
16use reqwest_client::ReqwestClient;
17use serde_json::json;
18use std::{
19 cmp::Reverse,
20 fmt::{self, Display},
21 io::Write as _,
22 sync::mpsc,
23};
24use util::path;
25
/// Eval: the agent must extract a `handle_command_output` helper out of
/// `run_git_blame`; the edited buffer must match the `after.rs` fixture
/// exactly (modulo blank lines). Requires the "eval" feature to run.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,  // iterations
        0.95, // required pass ratio
        EvalInput {
            conversation: vec![
                // Turn 1: the user's instruction.
                // NOTE(review): `indoc!` does not interpolate — the literal
                // `{input_file_path}` placeholder is what the model sees.
                // Presumably intentional (or harmless); confirm.
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                // Turn 2: scripted assistant reads the whole file...
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                // ...and the tool result supplies its contents.
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                // Turn 3: the assistant begins the edit; from here the edit
                // agent under test produces the actual edits.
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
84
/// Eval: the agent must delete the `run_git_blame` function (and nothing
/// else); the result must match the `after.rs` fixture exactly.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,  // iterations
        0.95, // required pass ratio
        EvalInput {
            conversation: vec![
                // Turn 1: the user's instruction.
                // NOTE(review): `indoc!` does not interpolate `{input_file_path}`;
                // the placeholder text is sent verbatim — confirm intentional.
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                // Turn 2: scripted read of the whole file.
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                // Turn 3: hand off to the edit agent under test.
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
140
/// Eval: rewrite `compile_parser_to_wasm` to use wasi-sdk instead of
/// emscripten. The conversation simulates three partial reads of a large
/// file; success is graded by an LLM judge over the diff (no `after.rs`
/// fixture here).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,  // iterations
        0.95, // required pass ratio
        EvalInput {
            conversation: vec![
                // Turn 1: the user's instruction.
                // NOTE(review): `indoc!` does not interpolate `{input_file_path}` /
                // `{language_name}` — placeholders go to the model verbatim.
                // Also, the x86_64-linux and arm64-linux assets are each listed
                // twice below — presumably unintentional; confirm.
                message(
                    User,
                    [text(indoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                // Turns 2-4: three scripted partial reads of the relevant
                // region (1-based line ranges in the tool input; the matching
                // `lines(..)` slices below are 0-based — see `lines`).
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                // Final turn: hand off to the edit agent under test.
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
258
/// Eval: comment out every `BlinkManager` interaction in a very large editor
/// file, keeping the enclosing `update` blocks. The grep result is simulated
/// by concatenating several slices of the fixture. Lower pass bar (0.6).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        200,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(User, [text("Let's research how to cursor blinking works.")]),
                // Scripted grep for "blink" across the file...
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                // ...whose result stitches together several regions of the
                // fixture, separated by "Match found:" headers.
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                // Second user turn with the actual edit instruction.
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                // Hand off to the edit agent under test. (Tool id jumps from
                // tool_1 to tool_4; ids need not be sequential.)
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
331
/// Eval: add a `from_pixels` constructor to `Canvas` plus tests for it, after
/// a scripted research phase (full read + three greps whose canned results
/// show how tests are structured elsewhere in the repo). Graded by LLM judge.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,  // iterations
        0.95, // required pass ratio
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                // Scripted read of the whole input file.
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                // Two greps for an existing `mod tests` come up empty...
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                // ...then a grep for `#[test]` returns canned matches from
                // other modules, showing the project's test conventions.
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                            use super::Font;
                            use crate::properties::{Stretch, Weight};

                            #[cfg(feature = \"source\")]
                            use crate::source::SystemSource;

                            static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                            #[cfg(feature = \"source\")]
                            #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                            }

                            #[test]
                            fn test_core_text_to_css_font_weight() {
                            // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                            }

                            #[test]
                            fn test_core_text_to_css_font_stretch() {
                            // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                            use crate::loaders::freetype::Font;

                            static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                            static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                            #[test]
                            fn get_pcf_postscript_name() {
                            let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                            assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                            }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                            use crate::properties::{Stretch, Weight};

                            #[test]
                            fn test_css_to_core_text_font_weight() {
                            // Exact matches
                            assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                            assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                            assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                            assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 matches remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                            }

                            #[test]
                            fn test_css_to_core_text_font_stretch() {
                            // Exact matches
                            ```
                        "},
                    )],
                ),
                // Hand off to the edit agent under test.
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}
524
/// Eval: create a brand-new file (`input_content: None`, overwrite mode).
/// The custom assertion checks only that the generated file does not begin
/// with whitespace, a backtick, or a newline — i.e. the agent emitted clean
/// file content rather than a fenced/indented block. Pass ratio must be 1.0.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";
    eval(
        200,
        1., // every iteration must pass
        EvalInput {
            conversation: vec![
                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
                // Scripted research: read two reference files in one turn.
                message(
                    Assistant,
                    [
                        tool_use(
                            "tool_1",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react_test.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [
                        tool_result(
                            "tool_1",
                            "read_file",
                            include_str!("evals/fixtures/zode/react.py"),
                        ),
                        tool_result(
                            "tool_2",
                            "read_file",
                            include_str!("evals/fixtures/zode/react_test.py"),
                        ),
                    ],
                ),
                // Hand off: note `create_or_overwrite: true` — the agent runs
                // in overwrite mode because the file doesn't exist yet.
                message(
                    Assistant,
                    [
                        text(
                            "Now that I understand what we need to build, I'll create the main Python script:",
                        ),
                        tool_use(
                            "tool_3",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                create_or_overwrite: true,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            // Custom assertion: fail (score 0) if the file starts with any
            // character that suggests a markdown fence or stray indentation.
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                let invalid_starts = [' ', '`', '\n'];
                let mut message = String::new();
                for start in invalid_starts {
                    if sample.text.starts_with(start) {
                        message.push_str(&format!("The sample starts with a {:?}\n", start));
                        break;
                    }
                }
                // Remove trailing newline.
                message.pop();

                if message.is_empty() {
                    Ok(EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    })
                } else {
                    Ok(EvalAssertionOutcome {
                        score: 0,
                        message: Some(message),
                    })
                }
            }),
        },
    );
}
622
/// Eval: add a new test to `action_log.rs` that exercises overwriting an
/// existing file via `buffer_created`. The first read returns a symbol
/// outline (not raw text), and subsequent reads fetch specific line ranges.
/// Graded by LLM judge; lower pass bar (0.5).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_add_overwrite_test() {
    let input_file_path = "root/action_log.rs";
    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
    let edit_description = "Add a new test for overwriting a file in action_log.rs";
    eval(
        200,
        0.5, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new test in `action_log.rs` to test overwriting a file.
                        That is, a file already exists, but we call `buffer_created` as if the file were new.
                        Take inspiration from all the other tests in the file.
                    "})],
                ),
                // First read returns the file's symbol outline with 1-based
                // line ranges the model can use for follow-up reads.
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        indoc! {"
                            pub struct ActionLog [L13-20]
                            tracked_buffers [L15]
                            edited_since_project_diagnostics_check [L17]
                            project [L19]
                            impl ActionLog [L22-498]
                            pub fn new [L24-30]
                            pub fn project [L32-34]
                            pub fn checked_project_diagnostics [L37-39]
                            pub fn has_edited_files_since_project_diagnostics_check [L42-44]
                            fn track_buffer_internal [L46-101]
                            fn handle_buffer_event [L103-116]
                            fn handle_buffer_edited [L118-123]
                            fn handle_buffer_file_changed [L125-158]
                            async fn maintain_diff [L160-264]
                            pub fn buffer_read [L267-269]
                            pub fn buffer_created [L272-276]
                            pub fn buffer_edited [L279-287]
                            pub fn will_delete_buffer [L289-304]
                            pub fn keep_edits_in_range [L306-364]
                            pub fn reject_edits_in_ranges [L366-459]
                            pub fn keep_all_edits [L461-473]
                            pub fn changed_buffers [L476-482]
                            pub fn stale_buffers [L485-497]
                            fn apply_non_conflicting_edits [L500-561]
                            fn diff_snapshots [L563-585]
                            fn point_to_row_edit [L587-614]
                            enum ChangeAuthor [L617-620]
                            User [L618]
                            Agent [L619]
                            enum TrackedBufferStatus [L623-627]
                            Created [L624]
                            Modified [L625]
                            Deleted [L626]
                            struct TrackedBuffer [L629-641]
                            buffer [L630]
                            base_text [L631]
                            unreviewed_changes [L632]
                            status [L633]
                            version [L634]
                            diff [L635]
                            snapshot [L636]
                            diff_update [L637]
                            _open_lsp_handle [L638]
                            _maintain_diff [L639]
                            _subscription [L640]
                            impl TrackedBuffer [L643-657]
                            fn has_changes [L644-650]
                            fn schedule_diff_update [L652-656]
                            pub struct ChangedBuffer [L659-661]
                            pub diff [L660]
                            mod tests [L664-1574]
                            fn init_logger [L678-682]
                            fn init_test [L684-691]
                            async fn test_keep_edits [L694-769]
                            async fn test_deletions [L772-854]
                            async fn test_overlapping_user_edits [L857-951]
                            async fn test_creating_files [L954-1010]
                            async fn test_deleting_files [L1013-1120]
                            async fn test_reject_edits [L1123-1255]
                            async fn test_reject_multiple_edits [L1258-1331]
                            async fn test_reject_deleted_file [L1334-1388]
                            async fn test_reject_created_file [L1391-1443]
                            async fn test_random_diffs [L1446-1535]
                            fn quiesce [L1510-1534]
                            struct HunkStatus [L1538-1542]
                            range [L1539]
                            diff_status [L1540]
                            old_text [L1541]
                            fn unreviewed_hunks [L1544-1573]

                            Showing symbols 1-69 (total symbols: 69)

                            Using the line numbers in this outline, you can call this tool again while specifying
                            the start_line and end_line fields to see the implementations of symbols in the outline.
                        "},
                    )],
                ),
                // Targeted reads of existing tests and of `buffer_created`
                // itself, mirroring the outline's 1-based line ranges with
                // 0-based `lines(..)` slices.
                message(
                    Assistant,
                    [
                        text(
                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(953),
                                end_line: Some(1010),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 953..1010),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now let's look at another relevant test to understand how they're structured:",
                        ),
                        tool_use(
                            "tool_3",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(1012),
                                end_line: Some(1120),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1012..1120),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text("Now let's look at how `buffer_created` is implemented:"),
                        tool_use(
                            "tool_4",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(271),
                                end_line: Some(276),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "read_file",
                        lines(input_file_content, 271..276),
                    )],
                ),
                // Hand off to the edit agent under test.
                message(
                    Assistant,
                    [
                        text(indoc! {"
                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:

                            1. Create a file
                            2. Open a buffer for it
                            3. Track it in the action log
                            4. Modify the buffer
                            5. Call `buffer_created` again as if the file were new
                            6. Check that the changes are properly tracked

                            Let's write the test:
                        "}),
                        tool_use(
                            "tool_5",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                create_or_overwrite: false,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(
                "A new test for overwritten files was created, without changing any previous test",
            ),
        },
    );
}
846
847fn message(
848 role: Role,
849 contents: impl IntoIterator<Item = MessageContent>,
850) -> LanguageModelRequestMessage {
851 LanguageModelRequestMessage {
852 role,
853 content: contents.into_iter().collect(),
854 cache: false,
855 }
856}
857
858fn text(text: impl Into<String>) -> MessageContent {
859 MessageContent::Text(text.into())
860}
861
/// Returns the half-open `range` of 0-indexed lines from `input`, joined
/// with `\n`. An empty or out-of-bounds range yields an empty string; a
/// range extending past the end is silently truncated.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected: Vec<&str> = Vec::new();
    for (index, line) in input.lines().enumerate() {
        if range.contains(&index) {
            selected.push(line);
        }
    }
    selected.join("\n")
}
870
871fn tool_use(
872 id: impl Into<Arc<str>>,
873 name: impl Into<Arc<str>>,
874 input: impl Serialize,
875) -> MessageContent {
876 MessageContent::ToolUse(LanguageModelToolUse {
877 id: LanguageModelToolUseId::from(id.into()),
878 name: name.into(),
879 raw_input: serde_json::to_string_pretty(&input).unwrap(),
880 input: serde_json::to_value(input).unwrap(),
881 is_input_complete: true,
882 })
883}
884
885fn tool_result(
886 id: impl Into<Arc<str>>,
887 name: impl Into<Arc<str>>,
888 result: impl Into<Arc<str>>,
889) -> MessageContent {
890 MessageContent::ToolResult(LanguageModelToolResult {
891 tool_use_id: LanguageModelToolUseId::from(id.into()),
892 tool_name: name.into(),
893 is_error: false,
894 content: result.into(),
895 output: None,
896 })
897}
898
/// One eval scenario: a scripted conversation plus the file state and the
/// assertion used to grade the agent's edit.
#[derive(Clone)]
struct EvalInput {
    /// Scripted conversation replayed to the agent, ending in an `edit_file`
    /// tool use that hands control to the agent under test.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Project-relative path of the file being edited.
    input_path: PathBuf,
    /// Initial file content; `None` means the file is created from scratch
    /// and the agent runs in overwrite mode (see `EditAgentTest::eval`).
    input_content: Option<String>,
    /// Short description of the desired edit, passed to the agent.
    edit_description: String,
    /// How the resulting buffer/diff is scored.
    assertion: EvalAssertion,
}
907
/// The observable result of a single eval iteration.
#[derive(Clone)]
struct EvalSample {
    /// Full text of the buffer after the agent's edit.
    text: String,
    /// Raw agent output, including parser metrics and raw edit stream.
    edit_output: EditAgentOutput,
    /// Unified diff between the original content and `text`.
    diff: String,
}
914
/// Object-safe assertion callback. Returns a boxed non-`Send` future so
/// implementations can borrow `sample` and the test context across `.await`s.
trait AssertionFn: 'static + Send + Sync {
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
923
/// Blanket impl so any matching async closure can be used directly as an
/// assertion (see `EvalAssertion::new`); the future is boxed to erase the
/// closure's concrete future type.
impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        (self)(sample, judge_model, cx).boxed_local()
    }
}
944
/// Shared, cloneable wrapper around an assertion callback; cloned once per
/// eval iteration.
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
947
impl EvalAssertion {
    /// Wraps an async closure as an assertion.
    fn new<F>(f: F) -> Self
    where
        F: 'static
            + Send
            + Sync
            + AsyncFn(
                &EvalSample,
                Arc<dyn LanguageModel>,
                &mut TestAppContext,
            ) -> Result<EvalAssertionOutcome>,
    {
        EvalAssertion(Arc::new(f))
    }

    /// Scores 100 when the sample text equals `expected` after dropping
    /// blank/whitespace-only lines on both sides, 0 otherwise.
    fn assert_eq(expected: impl Into<String>) -> Self {
        let expected = expected.into();
        Self::new(async move |sample, _judge, _cx| {
            Ok(EvalAssertionOutcome {
                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
                    100
                } else {
                    0
                },
                message: None,
            })
        })
    }

    /// Asks the judge model to score the sample's diff against a list of
    /// natural-language assertions; the score is parsed from a
    /// `<score>N</score>` tag in the judge's reply.
    fn judge_diff(assertions: &'static str) -> Self {
        Self::new(async move |sample, judge, cx| {
            let prompt = DiffJudgeTemplate {
                diff: sample.diff.clone(),
                assertions,
            }
            .render(&Templates::new())
            .unwrap();

            let request = LanguageModelRequest {
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![prompt.into()],
                    cache: false,
                }],
                ..Default::default()
            };
            // Stream the judge's full response into `output`.
            let mut response = judge
                .stream_completion_text(request, &cx.to_async())
                .await?;
            let mut output = String::new();
            while let Some(chunk) = response.stream.next().await {
                let chunk = chunk?;
                output.push_str(&chunk);
            }

            // Parse the score from the response
            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
            if let Some(captures) = re.captures(&output) {
                if let Some(score_match) = captures.get(1) {
                    // An unparsable number inside the tag falls back to 0.
                    let score = score_match.as_str().parse().unwrap_or(0);
                    return Ok(EvalAssertionOutcome {
                        score,
                        message: Some(output),
                    });
                }
            }

            Err(anyhow!(
                "No score found in response. Raw output: {}",
                output
            ))
        })
    }

    /// Runs the wrapped assertion against one sample.
    async fn run(
        &self,
        input: &EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<EvalAssertionOutcome> {
        self.0.assert(input, judge_model, cx).await
    }
}
1031
/// Runs `eval` for `iterations` iterations in parallel, reporting progress as
/// results arrive. Panics when the pass ratio falls below
/// `expected_pass_ratio`, or when too many edit tags were mismatched while
/// parsing model output.
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    report_progress(evaluated_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Cache the last message in the conversation, and run one instance of the eval so that
    // all the next ones are cached.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    // Remaining iterations run concurrently on the background executor.
    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    // Drop the original sender so `rx` disconnects once every eval finishes.
    drop(tx);

    let mut failed_count = 0;
    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
                eval_outputs.push(output.clone());
                // A score below 80 counts as a failure; failures are grouped
                // by identical output text so duplicates are reported once.
                if output.assertion.score < 80 {
                    failed_count += 1;
                    failed_evals
                        .entry(output.sample.text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        // Report errors and grouped failures, most frequent first.
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    // Even when scores pass, fail the eval if more than 5% of edit tags were
    // mismatched while parsing the model's output.
    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.05 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}
1111
1112fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1113 let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1114 let mut cx = TestAppContext::build(dispatcher, None);
1115 let output = cx.executor().block_test(async {
1116 let test = EditAgentTest::new(&mut cx).await;
1117 test.eval(eval, &mut cx).await
1118 });
1119 tx.send(output).unwrap();
1120}
1121
/// A graded eval iteration: the sample plus the score/message it received.
#[derive(Clone)]
struct EvalOutput {
    sample: EvalSample,
    assertion: EvalAssertionOutcome,
}
1127
1128impl Display for EvalOutput {
1129 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1130 writeln!(f, "Score: {:?}", self.assertion.score)?;
1131 if let Some(message) = self.assertion.message.as_ref() {
1132 writeln!(f, "Message: {}", message)?;
1133 }
1134
1135 writeln!(f, "Diff:\n{}", self.sample.diff)?;
1136
1137 writeln!(
1138 f,
1139 "Parser Metrics:\n{:#?}",
1140 self.sample.edit_output._parser_metrics
1141 )?;
1142 writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
1143 Ok(())
1144 }
1145}
1146
/// Overwrites the current terminal line with an `Evaluated X/Y` counter.
/// `\r` returns to column 0 and `\x1b[K` clears to end of line, so repeated
/// calls update in place; stdout is flushed because there is no newline.
fn report_progress(evaluated_count: usize, iterations: usize) {
    let mut stdout = std::io::stdout();
    write!(stdout, "\r\x1b[KEvaluated {}/{}", evaluated_count, iterations).unwrap();
    stdout.flush().unwrap();
}
1151
/// Harness holding the agent under test, the test project it edits, and the
/// separate model used to judge diff-based assertions.
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}
1157
impl EditAgentTest {
    /// Builds the harness: initializes settings, a fake filesystem rooted at
    /// `/root`, a test project, and authenticated agent/judge models.
    /// NOTE(review): this reaches the real Anthropic provider over the
    /// network, which is why evals are behind the "eval" feature.
    async fn new(cx: &mut TestAppContext) -> Self {
        // Network calls block; the deterministic executor must allow parking.
        cx.executor().allow_parking();
        cx.update(settings::init);
        cx.update(Project::init_settings);
        cx.update(language::init);
        cx.update(gpui_tokio::init);
        cx.update(client::init_settings);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let (agent_model, judge_model) = cx
            .update(|cx| {
                // Swap in a real HTTP client so model requests leave the test.
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));

                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);

                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    /// Finds the model `id` from `provider` in the global registry and
    /// authenticates the provider before returning the model.
    /// Panics (unwrap) if the provider/model is not registered.
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    /// Replays `eval.conversation` through the edit agent against a buffer
    /// seeded with `eval.input_content` — or, when `input_content` is `None`,
    /// runs the agent in overwrite mode on the new file — then grades the
    /// resulting buffer with `eval.assertion`.
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                eval.conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        // Diff is computed against the original content (empty for new files).
        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}
1264
/// Result of one assertion: a 0-100 score (below 80 counts as a failure in
/// `eval`) and an optional explanatory message (e.g. the judge's raw reply).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    score: usize,
    message: Option<String>,
}
1270
/// Variables rendered into the diff-judge prompt template
/// (see `EvalAssertion::judge_diff`).
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}
1276
// Renders via the `diff_judge.hbs` Handlebars template.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1280
/// Joins the non-blank lines of `text` with `\n`, dropping lines that are
/// empty or whitespace-only. Used by `EvalAssertion::assert_eq` so the
/// comparison ignores blank-line differences (kept lines are unchanged,
/// including their surrounding whitespace).
fn strip_empty_lines(text: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        kept.push(line);
    }
    kept.join("\n")
}