use super::*;
use crate::{
    ReadFileToolInput, grep_tool::GrepToolInput,
    streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    sync::mpsc,
};
use util::path;
28
/// Eval: extract the final error-handling stanza of `run_git_blame` into a new
/// `handle_command_output` method.
///
/// The replayed conversation has the assistant read the whole file and then issue an
/// `edit_file` tool call. The eval runs 100 times with an expected pass ratio of 0.95, and
/// each run is checked against the `after.rs` fixture via `EvalAssertion::assert_eq`.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` interpolates `{input_file_path}`; plain `indoc!` would send
                    // the placeholder text verbatim to the model.
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
87
/// Eval: delete the `run_git_blame` function while leaving its usages untouched.
///
/// The replayed conversation has the assistant read the whole file and then issue an
/// `edit_file` tool call. The eval runs 100 times with an expected pass ratio of 0.95, and
/// each run is checked against the `after.rs` fixture via `EvalAssertion::assert_eq`.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` interpolates `{input_file_path}`; plain `indoc!` would send
                    // the placeholder text verbatim to the model.
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
143
/// Eval: rewrite `compile_parser_to_wasm` to build with `wasi-sdk` instead of emscripten.
///
/// The replayed conversation has the assistant read three consecutive line ranges of the
/// fixture before issuing the `edit_file` tool call. The resulting diff is scored by the
/// judge model against the listed assertions; the eval runs 100 times with an expected
/// pass ratio of 0.95.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    // `formatdoc!` interpolates `{input_file_path}`. `{{language_name}}` is
                    // escaped so the model still sees the literal `{language_name}` placeholder.
                    // Each available asset is listed exactly once.
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        StreamingEditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            create_or_overwrite: false,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
261
/// Eval: comment out the code that interacts with the `BlinkManager` in `editor.rs`.
///
/// The replayed conversation contains a `grep` for "blink" whose result is assembled from
/// several line ranges of the fixture, followed by the user's instructions and an
/// `edit_file` tool call. The edited buffer is compared with the `after.rs` fixture.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";

    // Simulated grep result: excerpts of the fixture, with the separator placed between them.
    let grep_output = [
        lines(input_file_content, 100..400),
        lines(input_file_content, 800..1300),
        lines(input_file_content, 1600..2000),
        lines(input_file_content, 5000..5500),
        lines(input_file_content, 8000..9000),
        lines(input_file_content, 18455..18470),
        lines(input_file_content, 20000..20500),
        lines(input_file_content, 21000..21300),
    ]
    .join("Match found:\n\n");

    let grep_request = tool_use(
        "tool_1",
        "grep",
        GrepToolInput {
            regex: "blink".into(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        },
    );
    let edit_request = tool_use(
        "tool_4",
        "edit_file",
        StreamingEditFileToolInput {
            display_description: edit_description.into(),
            path: input_file_path.into(),
            create_or_overwrite: false,
        },
    );

    let conversation = vec![
        message(User, [text("Let's research how to cursor blinking works.")]),
        message(Assistant, [grep_request]),
        message(User, [tool_result("tool_1", "grep", grep_output)]),
        message(
            User,
            [text(indoc! {"
                Comment out the lines that interact with the BlinkManager.
                Keep the outer `update` blocks, but comments everything that's inside (including if statements).
                Don't add additional comments.
            "})],
        ),
        message(Assistant, [edit_request]),
    ];

    eval(
        200,
        0.6, // TODO: make this eval better
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
334
/// Eval: add a `from_pixels` constructor to `Canvas` together with tests for it.
///
/// The replayed conversation has the assistant read the file and run three `grep` calls
/// looking for existing tests (the first two find nothing) before issuing the `edit_file`
/// tool call. The resulting diff is scored by the judge model.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";

    // Builds a case-insensitive `grep` tool call restricted to `include_pattern`.
    let grep_for = |id: &str, regex: &str, include_pattern: &str| {
        tool_use(
            id,
            "grep",
            GrepToolInput {
                regex: regex.into(),
                include_pattern: Some(include_pattern.into()),
                offset: 0,
                case_sensitive: false,
            },
        )
    };

    // Canned output for the `#[test]` grep.
    let test_attribute_matches = indoc! {"
        Found 6 matches:

        ## Matches in font-kit/src/loaders/core_text.rs

        ### mod test › L926-936
        ```
        mod test {
        use super::Font;
        use crate::properties::{Stretch, Weight};

        #[cfg(feature = \"source\")]
        use crate::source::SystemSource;

        static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

        #[cfg(feature = \"source\")]
        #[test]
        ```

        55 lines remaining in ancestor node. Read the file to see all.

        ### mod test › L947-951
        ```
        }

        #[test]
        fn test_core_text_to_css_font_weight() {
        // Exact matches
        ```

        ### mod test › L959-963
        ```
        }

        #[test]
        fn test_core_text_to_css_font_stretch() {
        // Exact matches
        ```

        ## Matches in font-kit/src/loaders/freetype.rs

        ### mod test › L1238-1248
        ```
        mod test {
        use crate::loaders::freetype::Font;

        static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
        static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

        #[test]
        fn get_pcf_postscript_name() {
        let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
        assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
        }
        ```

        1 lines remaining in ancestor node. Read the file to see all.

        ## Matches in font-kit/src/sources/core_text.rs

        ### mod test › L265-275
        ```
        mod test {
        use crate::properties::{Stretch, Weight};

        #[test]
        fn test_css_to_core_text_font_weight() {
        // Exact matches
        assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
        assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
        assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
        assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

        ```

        27 lines remaining in ancestor node. Read the file to see all.

        ### mod test › L278-282
        ```
        }

        #[test]
        fn test_css_to_core_text_font_stretch() {
        // Exact matches
        ```
    "};

    let conversation = vec![
        message(
            User,
            [text(indoc! {"
                Introduce a new `from_pixels` constructor in Canvas and
                also add tests for it in the same file.
            "})],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_1",
                "read_file",
                ReadFileToolInput {
                    path: input_file_path.into(),
                    start_line: None,
                    end_line: None,
                },
            )],
        ),
        message(
            User,
            [tool_result("tool_1", "read_file", input_file_content)],
        ),
        message(
            Assistant,
            [grep_for("tool_2", "mod\\s+tests", "font-kit/src/canvas.rs")],
        ),
        message(User, [tool_result("tool_2", "grep", "No matches found")]),
        message(
            Assistant,
            [grep_for("tool_3", "mod\\s+tests", "font-kit/src/**/*.rs")],
        ),
        message(User, [tool_result("tool_3", "grep", "No matches found")]),
        message(
            Assistant,
            [grep_for("tool_4", "#\\[test\\]", "font-kit/src/**/*.rs")],
        ),
        message(
            User,
            [tool_result("tool_4", "grep", test_attribute_matches)],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_5",
                "edit_file",
                StreamingEditFileToolInput {
                    display_description: edit_description.into(),
                    path: input_file_path.into(),
                    create_or_overwrite: false,
                },
            )],
        ),
    ];

    eval(
        100,
        0.95,
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}
527
/// Eval: create `root/zode.py` from scratch using the overwrite flow.
///
/// No input content is supplied, so the agent overwrites the buffer. A run passes when the
/// produced text does not start with a space, a backtick or a newline; the expected pass
/// ratio is 1.0 over 200 runs.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";

    let read_requests = [
        tool_use(
            "tool_1",
            "read_file",
            ReadFileToolInput {
                path: "root/eval/react.py".into(),
                start_line: None,
                end_line: None,
            },
        ),
        tool_use(
            "tool_2",
            "read_file",
            ReadFileToolInput {
                path: "root/eval/react_test.py".into(),
                start_line: None,
                end_line: None,
            },
        ),
    ];
    let read_results = [
        tool_result(
            "tool_1",
            "read_file",
            include_str!("evals/fixtures/zode/react.py"),
        ),
        tool_result(
            "tool_2",
            "read_file",
            include_str!("evals/fixtures/zode/react_test.py"),
        ),
    ];

    let conversation = vec![
        message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
        message(Assistant, read_requests),
        message(User, read_results),
        message(
            Assistant,
            [
                text(
                    "Now that I understand what we need to build, I'll create the main Python script:",
                ),
                tool_use(
                    "tool_3",
                    "edit_file",
                    StreamingEditFileToolInput {
                        display_description: edit_description.into(),
                        path: input_file_path.into(),
                        create_or_overwrite: true,
                    },
                ),
            ],
        ),
    ];

    eval(
        200,
        1.,
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                // Report the first forbidden leading character, if any.
                let invalid_starts = [' ', '`', '\n'];
                let offending_start = invalid_starts
                    .iter()
                    .copied()
                    .find(|start| sample.text.starts_with(*start));

                let outcome = match offending_start {
                    Some(start) => EvalAssertionOutcome {
                        score: 0,
                        message: Some(format!("The sample starts with a {:?}", start)),
                    },
                    None => EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    },
                };
                Ok(outcome)
            }),
        },
    );
}
625
/// Eval: add a test to `action_log.rs` covering the case where an existing file is
/// overwritten through `buffer_created`.
///
/// The replayed conversation has the assistant receive a symbol outline of the file, read
/// three line ranges of it, and announce its plan before issuing the `edit_file` tool call.
/// The resulting diff is scored by the judge model.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_add_overwrite_test() {
    let input_file_path = "root/action_log.rs";
    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
    let edit_description = "Add a new test for overwriting a file in action_log.rs";

    // Builds a `read_file` tool call for an explicit line range of the input file.
    let read_range = |id: &str, start, end| {
        tool_use(
            id,
            "read_file",
            ReadFileToolInput {
                path: input_file_path.into(),
                start_line: Some(start),
                end_line: Some(end),
            },
        )
    };

    // Canned outline returned for the first, range-less `read_file` call.
    let file_outline = indoc! {"
        pub struct ActionLog [L13-20]
        tracked_buffers [L15]
        edited_since_project_diagnostics_check [L17]
        project [L19]
        impl ActionLog [L22-498]
        pub fn new [L24-30]
        pub fn project [L32-34]
        pub fn checked_project_diagnostics [L37-39]
        pub fn has_edited_files_since_project_diagnostics_check [L42-44]
        fn track_buffer_internal [L46-101]
        fn handle_buffer_event [L103-116]
        fn handle_buffer_edited [L118-123]
        fn handle_buffer_file_changed [L125-158]
        async fn maintain_diff [L160-264]
        pub fn buffer_read [L267-269]
        pub fn buffer_created [L272-276]
        pub fn buffer_edited [L279-287]
        pub fn will_delete_buffer [L289-304]
        pub fn keep_edits_in_range [L306-364]
        pub fn reject_edits_in_ranges [L366-459]
        pub fn keep_all_edits [L461-473]
        pub fn changed_buffers [L476-482]
        pub fn stale_buffers [L485-497]
        fn apply_non_conflicting_edits [L500-561]
        fn diff_snapshots [L563-585]
        fn point_to_row_edit [L587-614]
        enum ChangeAuthor [L617-620]
        User [L618]
        Agent [L619]
        enum TrackedBufferStatus [L623-627]
        Created [L624]
        Modified [L625]
        Deleted [L626]
        struct TrackedBuffer [L629-641]
        buffer [L630]
        base_text [L631]
        unreviewed_changes [L632]
        status [L633]
        version [L634]
        diff [L635]
        snapshot [L636]
        diff_update [L637]
        _open_lsp_handle [L638]
        _maintain_diff [L639]
        _subscription [L640]
        impl TrackedBuffer [L643-657]
        fn has_changes [L644-650]
        fn schedule_diff_update [L652-656]
        pub struct ChangedBuffer [L659-661]
        pub diff [L660]
        mod tests [L664-1574]
        fn init_logger [L678-682]
        fn init_test [L684-691]
        async fn test_keep_edits [L694-769]
        async fn test_deletions [L772-854]
        async fn test_overlapping_user_edits [L857-951]
        async fn test_creating_files [L954-1010]
        async fn test_deleting_files [L1013-1120]
        async fn test_reject_edits [L1123-1255]
        async fn test_reject_multiple_edits [L1258-1331]
        async fn test_reject_deleted_file [L1334-1388]
        async fn test_reject_created_file [L1391-1443]
        async fn test_random_diffs [L1446-1535]
        fn quiesce [L1510-1534]
        struct HunkStatus [L1538-1542]
        range [L1539]
        diff_status [L1540]
        old_text [L1541]
        fn unreviewed_hunks [L1544-1573]

        Showing symbols 1-69 (total symbols: 69)

        Using the line numbers in this outline, you can call this tool again while specifying
        the start_line and end_line fields to see the implementations of symbols in the outline.
    "};

    let conversation = vec![
        message(
            User,
            [text(indoc! {"
                Introduce a new test in `action_log.rs` to test overwriting a file.
                That is, a file already exists, but we call `buffer_created` as if the file were new.
                Take inspiration from all the other tests in the file.
            "})],
        ),
        message(
            Assistant,
            [tool_use(
                "tool_1",
                "read_file",
                ReadFileToolInput {
                    path: input_file_path.into(),
                    start_line: None,
                    end_line: None,
                },
            )],
        ),
        message(User, [tool_result("tool_1", "read_file", file_outline)]),
        message(
            Assistant,
            [
                text(
                    "Let's look at the existing test cases that are related to file creation to understand the pattern:",
                ),
                read_range("tool_2", 953, 1010),
            ],
        ),
        message(
            User,
            [tool_result(
                "tool_2",
                "read_file",
                lines(input_file_content, 953..1010),
            )],
        ),
        message(
            Assistant,
            [
                text(
                    "Now let's look at another relevant test to understand how they're structured:",
                ),
                read_range("tool_3", 1012, 1120),
            ],
        ),
        message(
            User,
            [tool_result(
                "tool_3",
                "read_file",
                lines(input_file_content, 1012..1120),
            )],
        ),
        message(
            Assistant,
            [
                text("Now let's look at how `buffer_created` is implemented:"),
                read_range("tool_4", 271, 276),
            ],
        ),
        message(
            User,
            [tool_result(
                "tool_4",
                "read_file",
                lines(input_file_content, 271..276),
            )],
        ),
        message(
            Assistant,
            [
                text(indoc! {"
                    Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:

                    1. Create a file
                    2. Open a buffer for it
                    3. Track it in the action log
                    4. Modify the buffer
                    5. Call `buffer_created` again as if the file were new
                    6. Check that the changes are properly tracked

                    Let's write the test:
                "}),
                tool_use(
                    "tool_5",
                    "edit_file",
                    StreamingEditFileToolInput {
                        display_description: edit_description.into(),
                        path: input_file_path.into(),
                        create_or_overwrite: false,
                    },
                ),
            ],
        ),
    ];

    eval(
        200,
        0.5, // TODO: make this eval better
        EvalInput {
            conversation,
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(
                "A new test for overwritten files was created, without changing any previous test",
            ),
        },
    );
}
849
850fn message(
851 role: Role,
852 contents: impl IntoIterator<Item = MessageContent>,
853) -> LanguageModelRequestMessage {
854 LanguageModelRequestMessage {
855 role,
856 content: contents.into_iter().collect(),
857 cache: false,
858 }
859}
860
861fn text(text: impl Into<String>) -> MessageContent {
862 MessageContent::Text(text.into())
863}
864
/// Returns the lines of `input` selected by `range`, joined with `\n`.
///
/// `range.start` lines are skipped and at most `range.len()` lines are kept, so the range
/// is zero-based over the line sequence. Ranges reaching past the end of `input` simply
/// yield fewer lines.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected = String::new();
    let wanted = input.lines().skip(range.start).take(range.len());
    for (index, line) in wanted.enumerate() {
        if index > 0 {
            selected.push('\n');
        }
        selected.push_str(line);
    }
    selected
}
873
874fn tool_use(
875 id: impl Into<Arc<str>>,
876 name: impl Into<Arc<str>>,
877 input: impl Serialize,
878) -> MessageContent {
879 MessageContent::ToolUse(LanguageModelToolUse {
880 id: LanguageModelToolUseId::from(id.into()),
881 name: name.into(),
882 raw_input: serde_json::to_string_pretty(&input).unwrap(),
883 input: serde_json::to_value(input).unwrap(),
884 is_input_complete: true,
885 })
886}
887
888fn tool_result(
889 id: impl Into<Arc<str>>,
890 name: impl Into<Arc<str>>,
891 result: impl Into<Arc<str>>,
892) -> MessageContent {
893 MessageContent::ToolResult(LanguageModelToolResult {
894 tool_use_id: LanguageModelToolUseId::from(id.into()),
895 tool_name: name.into(),
896 is_error: false,
897 content: result.into(),
898 })
899}
900
/// Everything needed to run a single eval: the conversation to replay, the buffer to edit
/// and the assertion applied to the result.
#[derive(Clone)]
struct EvalInput {
    /// Messages handed to the edit agent along with the edit description.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Path of the buffer to edit; it is resolved through the test project.
    input_path: PathBuf,
    /// Initial buffer text. When `None`, the agent's overwrite flow is used instead of its
    /// edit flow.
    input_content: Option<String>,
    /// Description of the requested edit, passed to the agent.
    edit_description: String,
    /// Check run against the sample produced by the agent.
    assertion: EvalAssertion,
}
909
/// The result of one agent run, as seen by an assertion.
#[derive(Clone)]
struct EvalSample {
    /// Full buffer text after the agent finished.
    text: String,
    /// Output returned by the edit agent (includes parser metrics and raw edits).
    edit_output: EditAgentOutput,
    /// Unified diff from the initial content (empty when there was none) to `text`.
    diff: String,
}
916
/// Object-safe form of an async assertion, so assertions can be stored as
/// `Arc<dyn AssertionFn>`.
trait AssertionFn: 'static + Send + Sync {
    /// Evaluates `sample`, optionally consulting `judge_model`, and returns the outcome as a
    /// boxed future that borrows `self`, `sample` and `cx`.
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
925
926impl<F> AssertionFn for F
927where
928 F: 'static
929 + Send
930 + Sync
931 + AsyncFn(
932 &EvalSample,
933 Arc<dyn LanguageModel>,
934 &mut TestAppContext,
935 ) -> Result<EvalAssertionOutcome>,
936{
937 fn assert<'a>(
938 &'a self,
939 sample: &'a EvalSample,
940 judge_model: Arc<dyn LanguageModel>,
941 cx: &'a mut TestAppContext,
942 ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
943 (self)(sample, judge_model, cx).boxed_local()
944 }
945}
946
/// A cloneable handle to an assertion; the underlying `AssertionFn` is shared through an
/// `Arc`.
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
949
950impl EvalAssertion {
951 fn new<F>(f: F) -> Self
952 where
953 F: 'static
954 + Send
955 + Sync
956 + AsyncFn(
957 &EvalSample,
958 Arc<dyn LanguageModel>,
959 &mut TestAppContext,
960 ) -> Result<EvalAssertionOutcome>,
961 {
962 EvalAssertion(Arc::new(f))
963 }
964
965 fn assert_eq(expected: impl Into<String>) -> Self {
966 let expected = expected.into();
967 Self::new(async move |sample, _judge, _cx| {
968 Ok(EvalAssertionOutcome {
969 score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
970 100
971 } else {
972 0
973 },
974 message: None,
975 })
976 })
977 }
978
979 fn judge_diff(assertions: &'static str) -> Self {
980 Self::new(async move |sample, judge, cx| {
981 let prompt = DiffJudgeTemplate {
982 diff: sample.diff.clone(),
983 assertions,
984 }
985 .render(&Templates::new())
986 .unwrap();
987
988 let request = LanguageModelRequest {
989 messages: vec![LanguageModelRequestMessage {
990 role: Role::User,
991 content: vec![prompt.into()],
992 cache: false,
993 }],
994 ..Default::default()
995 };
996 let mut response = judge
997 .stream_completion_text(request, &cx.to_async())
998 .await?;
999 let mut output = String::new();
1000 while let Some(chunk) = response.stream.next().await {
1001 let chunk = chunk?;
1002 output.push_str(&chunk);
1003 }
1004
1005 // Parse the score from the response
1006 let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
1007 if let Some(captures) = re.captures(&output) {
1008 if let Some(score_match) = captures.get(1) {
1009 let score = score_match.as_str().parse().unwrap_or(0);
1010 return Ok(EvalAssertionOutcome {
1011 score,
1012 message: Some(output),
1013 });
1014 }
1015 }
1016
1017 Err(anyhow!(
1018 "No score found in response. Raw output: {}",
1019 output
1020 ))
1021 })
1022 }
1023
1024 async fn run(
1025 &self,
1026 input: &EvalSample,
1027 judge_model: Arc<dyn LanguageModel>,
1028 cx: &mut TestAppContext,
1029 ) -> Result<EvalAssertionOutcome> {
1030 self.0.assert(input, judge_model, cx).await
1031 }
1032}
1033
1034fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
1035 let mut evaluated_count = 0;
1036 report_progress(evaluated_count, iterations);
1037
1038 let (tx, rx) = mpsc::channel();
1039
1040 // Cache the last message in the conversation, and run one instance of the eval so that
1041 // all the next ones are cached.
1042 eval.conversation.last_mut().unwrap().cache = true;
1043 run_eval(eval.clone(), tx.clone());
1044
1045 let executor = gpui::background_executor();
1046 for _ in 1..iterations {
1047 let eval = eval.clone();
1048 let tx = tx.clone();
1049 executor.spawn(async move { run_eval(eval, tx) }).detach();
1050 }
1051 drop(tx);
1052
1053 let mut failed_count = 0;
1054 let mut failed_evals = HashMap::default();
1055 let mut errored_evals = HashMap::default();
1056 let mut eval_outputs = Vec::new();
1057 let mut cumulative_parser_metrics = EditParserMetrics::default();
1058 while let Ok(output) = rx.recv() {
1059 match output {
1060 Ok(output) => {
1061 cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
1062 eval_outputs.push(output.clone());
1063 if output.assertion.score < 80 {
1064 failed_count += 1;
1065 failed_evals
1066 .entry(output.sample.text.clone())
1067 .or_insert(Vec::new())
1068 .push(output);
1069 }
1070 }
1071 Err(error) => {
1072 failed_count += 1;
1073 *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
1074 }
1075 }
1076
1077 evaluated_count += 1;
1078 report_progress(evaluated_count, iterations);
1079 }
1080
1081 let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
1082 println!("Actual pass ratio: {}\n", actual_pass_ratio);
1083 if actual_pass_ratio < expected_pass_ratio {
1084 let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
1085 errored_evals.sort_by_key(|(_, count)| Reverse(*count));
1086 for (error, count) in errored_evals {
1087 println!("Eval errored {} times. Error: {}", count, error);
1088 }
1089
1090 let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
1091 failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
1092 for (_buffer_output, failed_evals) in failed_evals {
1093 let eval_output = failed_evals.first().unwrap();
1094 println!("Eval failed {} times", failed_evals.len());
1095 println!("{}", eval_output);
1096 }
1097
1098 panic!(
1099 "Actual pass ratio: {}\nExpected pass ratio: {}",
1100 actual_pass_ratio, expected_pass_ratio
1101 );
1102 }
1103
1104 let mismatched_tag_ratio =
1105 cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
1106 if mismatched_tag_ratio > 0.05 {
1107 for eval_output in eval_outputs {
1108 println!("{}", eval_output);
1109 }
1110 panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
1111 }
1112}
1113
1114fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1115 let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1116 let mut cx = TestAppContext::build(dispatcher, None);
1117 let output = cx.executor().block_test(async {
1118 let test = EditAgentTest::new(&mut cx).await;
1119 test.eval(eval, &mut cx).await
1120 });
1121 tx.send(output).unwrap();
1122}
1123
/// A finished eval run: the sample the agent produced and the assertion's verdict on it.
#[derive(Clone)]
struct EvalOutput {
    /// What the agent produced.
    sample: EvalSample,
    /// Score and optional message from the assertion.
    assertion: EvalAssertionOutcome,
}
1129
1130impl Display for EvalOutput {
1131 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1132 writeln!(f, "Score: {:?}", self.assertion.score)?;
1133 if let Some(message) = self.assertion.message.as_ref() {
1134 writeln!(f, "Message: {}", message)?;
1135 }
1136
1137 writeln!(f, "Diff:\n{}", self.sample.diff)?;
1138
1139 writeln!(
1140 f,
1141 "Parser Metrics:\n{:#?}",
1142 self.sample.edit_output._parser_metrics
1143 )?;
1144 writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
1145 Ok(())
1146 }
1147}
1148
/// Redraws the progress line in place: returns to the start of the line, clears it, prints
/// the counters and flushes stdout because no newline is written.
fn report_progress(evaluated_count: usize, iterations: usize) {
    print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
    let mut stdout = std::io::stdout();
    stdout.flush().unwrap();
}
1153
/// Test harness bundling the edit agent under test with its project and the judge model.
struct EditAgentTest {
    /// The agent whose edits are being evaluated.
    agent: EditAgent,
    /// Test project used to resolve paths and open buffers.
    project: Entity<Project>,
    /// Model handed to assertions for scoring.
    judge_model: Arc<dyn LanguageModel>,
}
1159
1160impl EditAgentTest {
1161 async fn new(cx: &mut TestAppContext) -> Self {
1162 cx.executor().allow_parking();
1163 cx.update(settings::init);
1164 cx.update(Project::init_settings);
1165 cx.update(language::init);
1166 cx.update(gpui_tokio::init);
1167 cx.update(client::init_settings);
1168
1169 let fs = FakeFs::new(cx.executor().clone());
1170 fs.insert_tree("/root", json!({})).await;
1171 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1172 let (agent_model, judge_model) = cx
1173 .update(|cx| {
1174 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1175 cx.set_http_client(Arc::new(http_client));
1176
1177 let client = Client::production(cx);
1178 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1179 language_model::init(client.clone(), cx);
1180 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
1181
1182 cx.spawn(async move |cx| {
1183 let agent_model =
1184 Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1185 let judge_model =
1186 Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1187 (agent_model.unwrap(), judge_model.unwrap())
1188 })
1189 })
1190 .await;
1191 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1192
1193 Self {
1194 agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
1195 project,
1196 judge_model,
1197 }
1198 }
1199
1200 async fn load_model(
1201 provider: &str,
1202 id: &str,
1203 cx: &mut AsyncApp,
1204 ) -> Result<Arc<dyn LanguageModel>> {
1205 let (provider, model) = cx.update(|cx| {
1206 let models = LanguageModelRegistry::read_global(cx);
1207 let model = models
1208 .available_models(cx)
1209 .find(|model| model.provider_id().0 == provider && model.id().0 == id)
1210 .unwrap();
1211 let provider = models.provider(&model.provider_id()).unwrap();
1212 (provider, model)
1213 })?;
1214 cx.update(|cx| provider.authenticate(cx))?.await?;
1215 Ok(model)
1216 }
1217
1218 async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
1219 let path = self
1220 .project
1221 .read_with(cx, |project, cx| {
1222 project.find_project_path(eval.input_path, cx)
1223 })
1224 .unwrap();
1225 let buffer = self
1226 .project
1227 .update(cx, |project, cx| project.open_buffer(path, cx))
1228 .await
1229 .unwrap();
1230 let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
1231 buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
1232 let (edit_output, _) = self.agent.edit(
1233 buffer.clone(),
1234 eval.edit_description,
1235 eval.conversation,
1236 &mut cx.to_async(),
1237 );
1238 edit_output.await?
1239 } else {
1240 let (edit_output, _) = self.agent.overwrite(
1241 buffer.clone(),
1242 eval.edit_description,
1243 eval.conversation,
1244 &mut cx.to_async(),
1245 );
1246 edit_output.await?
1247 };
1248
1249 let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
1250 let sample = EvalSample {
1251 edit_output,
1252 diff: language::unified_diff(
1253 eval.input_content.as_deref().unwrap_or_default(),
1254 &buffer_text,
1255 ),
1256 text: buffer_text,
1257 };
1258 let assertion = eval
1259 .assertion
1260 .run(&sample, self.judge_model.clone(), cx)
1261 .await?;
1262
1263 Ok(EvalOutput { assertion, sample })
1264 }
1265}
1266
/// Verdict produced by an assertion.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    /// Numeric score; `eval` treats anything below 80 as a failure.
    score: usize,
    /// Optional explanation, e.g. the judge model's full response.
    message: Option<String>,
}
1272
/// Data rendered into the judge prompt template.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    /// The diff the judge is asked to score.
    diff: String,
    /// The assertions the diff is scored against.
    assertions: &'static str,
}
1278
/// Binds `DiffJudgeTemplate` to the `diff_judge.hbs` template.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1282
/// Returns `text` without its blank lines (empty or whitespace-only), with the remaining
/// lines joined by `\n`. Kept lines are not trimmed, and no trailing newline is emitted.
fn strip_empty_lines(text: &str) -> String {
    let mut stripped = String::new();
    for line in text.lines() {
        if line.trim().is_empty() {
            continue;
        }
        // Every kept line is non-empty, so an empty accumulator means "first kept line".
        if !stripped.is_empty() {
            stripped.push('\n');
        }
        stripped.push_str(line);
    }
    stripped
}