1use super::*;
2use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
3use Role::*;
4use anyhow::anyhow;
5use assistant_tool::ToolRegistry;
6use client::{Client, UserStore};
7use collections::HashMap;
8use fs::FakeFs;
9use futures::{FutureExt, future::LocalBoxFuture};
10use gpui::{AppContext, TestAppContext};
11use indoc::{formatdoc, indoc};
12use language_model::{
13 LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
14 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId,
15};
16use project::Project;
17use rand::prelude::*;
18use reqwest_client::ReqwestClient;
19use serde_json::json;
20use std::{
21 cmp::Reverse,
22 fmt::{self, Display},
23 io::Write as _,
24 sync::mpsc,
25};
26use util::path;
27
28#[test]
29#[cfg_attr(not(feature = "eval"), ignore)]
30fn eval_extract_handle_command_output() {
31 let input_file_path = "root/blame.rs";
32 let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
33 let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
34 let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
35 eval(
36 100,
37 0.95,
38 EvalInput {
39 conversation: vec![
40 message(
41 User,
42 [text(formatdoc! {"
43 Read the `{input_file_path}` file and extract a method in
44 the final stanza of `run_git_blame` to deal with command failures,
45 call it `handle_command_output` and take the std::process::Output as the only parameter.
46
47 Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
48 "})],
49 ),
50 message(
51 Assistant,
52 [tool_use(
53 "tool_1",
54 "read_file",
55 ReadFileToolInput {
56 path: input_file_path.into(),
57 start_line: None,
58 end_line: None,
59 },
60 )],
61 ),
62 message(
63 User,
64 [tool_result("tool_1", "read_file", input_file_content)],
65 ),
66 message(
67 Assistant,
68 [tool_use(
69 "tool_2",
70 "edit_file",
71 EditFileToolInput {
72 display_description: edit_description.into(),
73 path: input_file_path.into(),
74 create_or_overwrite: false,
75 },
76 )],
77 ),
78 ],
79 input_path: input_file_path.into(),
80 input_content: Some(input_file_content.into()),
81 edit_description: edit_description.into(),
82 assertion: EvalAssertion::assert_eq(output_file_content),
83 },
84 );
85}
86
87#[test]
88#[cfg_attr(not(feature = "eval"), ignore)]
89fn eval_delete_run_git_blame() {
90 let input_file_path = "root/blame.rs";
91 let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
92 let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
93 let edit_description = "Delete the `run_git_blame` function.";
94 eval(
95 100,
96 0.95,
97 EvalInput {
98 conversation: vec![
99 message(
100 User,
101 [text(formatdoc! {"
102 Read the `{input_file_path}` file and delete `run_git_blame`. Just that
103 one function, not its usages.
104 "})],
105 ),
106 message(
107 Assistant,
108 [tool_use(
109 "tool_1",
110 "read_file",
111 ReadFileToolInput {
112 path: input_file_path.into(),
113 start_line: None,
114 end_line: None,
115 },
116 )],
117 ),
118 message(
119 User,
120 [tool_result("tool_1", "read_file", input_file_content)],
121 ),
122 message(
123 Assistant,
124 [tool_use(
125 "tool_2",
126 "edit_file",
127 EditFileToolInput {
128 display_description: edit_description.into(),
129 path: input_file_path.into(),
130 create_or_overwrite: false,
131 },
132 )],
133 ),
134 ],
135 input_path: input_file_path.into(),
136 input_content: Some(input_file_content.into()),
137 edit_description: edit_description.into(),
138 assertion: EvalAssertion::assert_eq(output_file_content),
139 },
140 );
141}
142
143#[test]
144#[cfg_attr(not(feature = "eval"), ignore)]
145fn eval_translate_doc_comments() {
146 let input_file_path = "root/canvas.rs";
147 let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
148 let edit_description = "Translate all doc comments to Italian";
149 eval(
150 200,
151 1.,
152 EvalInput {
153 conversation: vec![
154 message(
155 User,
156 [text(formatdoc! {"
157 Read the {input_file_path} file and edit it (without overwriting it),
158 translating all the doc comments to italian.
159 "})],
160 ),
161 message(
162 Assistant,
163 [tool_use(
164 "tool_1",
165 "read_file",
166 ReadFileToolInput {
167 path: input_file_path.into(),
168 start_line: None,
169 end_line: None,
170 },
171 )],
172 ),
173 message(
174 User,
175 [tool_result("tool_1", "read_file", input_file_content)],
176 ),
177 message(
178 Assistant,
179 [tool_use(
180 "tool_2",
181 "edit_file",
182 EditFileToolInput {
183 display_description: edit_description.into(),
184 path: input_file_path.into(),
185 create_or_overwrite: false,
186 },
187 )],
188 ),
189 ],
190 input_path: input_file_path.into(),
191 input_content: Some(input_file_content.into()),
192 edit_description: edit_description.into(),
193 assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
194 },
195 );
196}
197
198#[test]
199#[cfg_attr(not(feature = "eval"), ignore)]
200fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
201 let input_file_path = "root/lib.rs";
202 let input_file_content =
203 include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
204 let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
205 eval(
206 100,
207 0.95,
208 EvalInput {
209 conversation: vec![
210 message(
211 User,
212 [text(formatdoc! {"
213 Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
214 Use `ureq` to download the SDK for the current platform and architecture.
215 Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
216 Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
217 that's inside of the archive.
218 Don't re-download the SDK if that executable already exists.
219
220 Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
221
222 Here are the available wasi-sdk assets:
223 - wasi-sdk-25.0-x86_64-macos.tar.gz
224 - wasi-sdk-25.0-arm64-macos.tar.gz
225 - wasi-sdk-25.0-x86_64-linux.tar.gz
226 - wasi-sdk-25.0-arm64-linux.tar.gz
227 - wasi-sdk-25.0-x86_64-linux.tar.gz
228 - wasi-sdk-25.0-arm64-linux.tar.gz
229 - wasi-sdk-25.0-x86_64-windows.tar.gz
230 "})],
231 ),
232 message(
233 Assistant,
234 [tool_use(
235 "tool_1",
236 "read_file",
237 ReadFileToolInput {
238 path: input_file_path.into(),
239 start_line: Some(971),
240 end_line: Some(1050),
241 },
242 )],
243 ),
244 message(
245 User,
246 [tool_result(
247 "tool_1",
248 "read_file",
249 lines(input_file_content, 971..1050),
250 )],
251 ),
252 message(
253 Assistant,
254 [tool_use(
255 "tool_2",
256 "read_file",
257 ReadFileToolInput {
258 path: input_file_path.into(),
259 start_line: Some(1050),
260 end_line: Some(1100),
261 },
262 )],
263 ),
264 message(
265 User,
266 [tool_result(
267 "tool_2",
268 "read_file",
269 lines(input_file_content, 1050..1100),
270 )],
271 ),
272 message(
273 Assistant,
274 [tool_use(
275 "tool_3",
276 "read_file",
277 ReadFileToolInput {
278 path: input_file_path.into(),
279 start_line: Some(1100),
280 end_line: Some(1150),
281 },
282 )],
283 ),
284 message(
285 User,
286 [tool_result(
287 "tool_3",
288 "read_file",
289 lines(input_file_content, 1100..1150),
290 )],
291 ),
292 message(
293 Assistant,
294 [tool_use(
295 "tool_4",
296 "edit_file",
297 EditFileToolInput {
298 display_description: edit_description.into(),
299 path: input_file_path.into(),
300 create_or_overwrite: false,
301 },
302 )],
303 ),
304 ],
305 input_path: input_file_path.into(),
306 input_content: Some(input_file_content.into()),
307 edit_description: edit_description.into(),
308 assertion: EvalAssertion::judge_diff(indoc! {"
309 - The compile_parser_to_wasm method has been changed to use wasi-sdk
310 - ureq is used to download the SDK for current platform and architecture
311 "}),
312 },
313 );
314}
315
316#[test]
317#[cfg_attr(not(feature = "eval"), ignore)]
318fn eval_disable_cursor_blinking() {
319 let input_file_path = "root/editor.rs";
320 let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
321 let edit_description = "Comment out the call to `BlinkManager::enable`";
322 eval(
323 200,
324 0.95,
325 EvalInput {
326 conversation: vec![
327 message(User, [text("Let's research how to cursor blinking works.")]),
328 message(
329 Assistant,
330 [tool_use(
331 "tool_1",
332 "grep",
333 GrepToolInput {
334 regex: "blink".into(),
335 include_pattern: None,
336 offset: 0,
337 case_sensitive: false,
338 },
339 )],
340 ),
341 message(
342 User,
343 [tool_result(
344 "tool_1",
345 "grep",
346 [
347 lines(input_file_content, 100..400),
348 lines(input_file_content, 800..1300),
349 lines(input_file_content, 1600..2000),
350 lines(input_file_content, 5000..5500),
351 lines(input_file_content, 8000..9000),
352 lines(input_file_content, 18455..18470),
353 lines(input_file_content, 20000..20500),
354 lines(input_file_content, 21000..21300),
355 ]
356 .join("Match found:\n\n"),
357 )],
358 ),
359 message(
360 User,
361 [text(indoc! {"
362 Comment out the lines that interact with the BlinkManager.
363 Keep the outer `update` blocks, but comments everything that's inside (including if statements).
364 Don't add additional comments.
365 "})],
366 ),
367 message(
368 Assistant,
369 [tool_use(
370 "tool_4",
371 "edit_file",
372 EditFileToolInput {
373 display_description: edit_description.into(),
374 path: input_file_path.into(),
375 create_or_overwrite: false,
376 },
377 )],
378 ),
379 ],
380 input_path: input_file_path.into(),
381 input_content: Some(input_file_content.into()),
382 edit_description: edit_description.into(),
383 assertion: EvalAssertion::judge_diff(indoc! {"
384 - Calls to BlinkManager in `observe_window_activation` were commented out
385 - The call to `blink_manager.enable` above the call to show_cursor_names was commented out
386 - All the edits have valid indentation
387 "}),
388 },
389 );
390}
391
392#[test]
393#[cfg_attr(not(feature = "eval"), ignore)]
394fn eval_from_pixels_constructor() {
395 let input_file_path = "root/canvas.rs";
396 let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
397 let edit_description = "Implement from_pixels constructor and add tests.";
398 eval(
399 100,
400 0.95,
401 EvalInput {
402 conversation: vec![
403 message(
404 User,
405 [text(indoc! {"
406 Introduce a new `from_pixels` constructor in Canvas and
407 also add tests for it in the same file.
408 "})],
409 ),
410 message(
411 Assistant,
412 [tool_use(
413 "tool_1",
414 "read_file",
415 ReadFileToolInput {
416 path: input_file_path.into(),
417 start_line: None,
418 end_line: None,
419 },
420 )],
421 ),
422 message(
423 User,
424 [tool_result("tool_1", "read_file", input_file_content)],
425 ),
426 message(
427 Assistant,
428 [tool_use(
429 "tool_2",
430 "grep",
431 GrepToolInput {
432 regex: "mod\\s+tests".into(),
433 include_pattern: Some("font-kit/src/canvas.rs".into()),
434 offset: 0,
435 case_sensitive: false,
436 },
437 )],
438 ),
439 message(User, [tool_result("tool_2", "grep", "No matches found")]),
440 message(
441 Assistant,
442 [tool_use(
443 "tool_3",
444 "grep",
445 GrepToolInput {
446 regex: "mod\\s+tests".into(),
447 include_pattern: Some("font-kit/src/**/*.rs".into()),
448 offset: 0,
449 case_sensitive: false,
450 },
451 )],
452 ),
453 message(User, [tool_result("tool_3", "grep", "No matches found")]),
454 message(
455 Assistant,
456 [tool_use(
457 "tool_4",
458 "grep",
459 GrepToolInput {
460 regex: "#\\[test\\]".into(),
461 include_pattern: Some("font-kit/src/**/*.rs".into()),
462 offset: 0,
463 case_sensitive: false,
464 },
465 )],
466 ),
467 message(
468 User,
469 [tool_result(
470 "tool_4",
471 "grep",
472 indoc! {"
473 Found 6 matches:
474
475 ## Matches in font-kit/src/loaders/core_text.rs
476
477 ### mod test › L926-936
478 ```
479 mod test {
480 use super::Font;
481 use crate::properties::{Stretch, Weight};
482
483 #[cfg(feature = \"source\")]
484 use crate::source::SystemSource;
485
486 static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
487
488 #[cfg(feature = \"source\")]
489 #[test]
490 ```
491
492 55 lines remaining in ancestor node. Read the file to see all.
493
494 ### mod test › L947-951
495 ```
496 }
497
498 #[test]
499 fn test_core_text_to_css_font_weight() {
500 // Exact matches
501 ```
502
503 ### mod test › L959-963
504 ```
505 }
506
507 #[test]
508 fn test_core_text_to_css_font_stretch() {
509 // Exact matches
510 ```
511
512 ## Matches in font-kit/src/loaders/freetype.rs
513
514 ### mod test › L1238-1248
515 ```
516 mod test {
517 use crate::loaders::freetype::Font;
518
519 static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
520 static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
521
522 #[test]
523 fn get_pcf_postscript_name() {
524 let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
525 assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
526 }
527 ```
528
529 1 lines remaining in ancestor node. Read the file to see all.
530
531 ## Matches in font-kit/src/sources/core_text.rs
532
533 ### mod test › L265-275
534 ```
535 mod test {
536 use crate::properties::{Stretch, Weight};
537
538 #[test]
539 fn test_css_to_core_text_font_weight() {
540 // Exact matches
541 assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
542 assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
543 assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
544 assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
545
546 ```
547
548 27 lines remaining in ancestor node. Read the file to see all.
549
550 ### mod test › L278-282
551 ```
552 }
553
554 #[test]
555 fn test_css_to_core_text_font_stretch() {
556 // Exact matches
557 ```
558 "},
559 )],
560 ),
561 message(
562 Assistant,
563 [tool_use(
564 "tool_5",
565 "edit_file",
566 EditFileToolInput {
567 display_description: edit_description.into(),
568 path: input_file_path.into(),
569 create_or_overwrite: false,
570 },
571 )],
572 ),
573 ],
574 input_path: input_file_path.into(),
575 input_content: Some(input_file_content.into()),
576 edit_description: edit_description.into(),
577 assertion: EvalAssertion::judge_diff(indoc! {"
578 - The diff contains a new `from_pixels` constructor
579 - The diff contains new tests for the `from_pixels` constructor
580 "}),
581 },
582 );
583}
584
585#[test]
586#[cfg_attr(not(feature = "eval"), ignore)]
587fn eval_zode() {
588 let input_file_path = "root/zode.py";
589 let edit_description = "Create the main Zode CLI script";
590 eval(
591 200,
592 1.,
593 EvalInput {
594 conversation: vec![
595 message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
596 message(
597 Assistant,
598 [
599 tool_use(
600 "tool_1",
601 "read_file",
602 ReadFileToolInput {
603 path: "root/eval/react.py".into(),
604 start_line: None,
605 end_line: None,
606 },
607 ),
608 tool_use(
609 "tool_2",
610 "read_file",
611 ReadFileToolInput {
612 path: "root/eval/react_test.py".into(),
613 start_line: None,
614 end_line: None,
615 },
616 ),
617 ],
618 ),
619 message(
620 User,
621 [
622 tool_result(
623 "tool_1",
624 "read_file",
625 include_str!("evals/fixtures/zode/react.py"),
626 ),
627 tool_result(
628 "tool_2",
629 "read_file",
630 include_str!("evals/fixtures/zode/react_test.py"),
631 ),
632 ],
633 ),
634 message(
635 Assistant,
636 [
637 text(
638 "Now that I understand what we need to build, I'll create the main Python script:",
639 ),
640 tool_use(
641 "tool_3",
642 "edit_file",
643 EditFileToolInput {
644 display_description: edit_description.into(),
645 path: input_file_path.into(),
646 create_or_overwrite: true,
647 },
648 ),
649 ],
650 ),
651 ],
652 input_path: input_file_path.into(),
653 input_content: None,
654 edit_description: edit_description.into(),
655 assertion: EvalAssertion::new(async move |sample, _, _cx| {
656 let invalid_starts = [' ', '`', '\n'];
657 let mut message = String::new();
658 for start in invalid_starts {
659 if sample.text.starts_with(start) {
660 message.push_str(&format!("The sample starts with a {:?}\n", start));
661 break;
662 }
663 }
664 // Remove trailing newline.
665 message.pop();
666
667 if message.is_empty() {
668 Ok(EvalAssertionOutcome {
669 score: 100,
670 message: None,
671 })
672 } else {
673 Ok(EvalAssertionOutcome {
674 score: 0,
675 message: Some(message),
676 })
677 }
678 }),
679 },
680 );
681}
682
683#[test]
684#[cfg_attr(not(feature = "eval"), ignore)]
685fn eval_add_overwrite_test() {
686 let input_file_path = "root/action_log.rs";
687 let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
688 let edit_description = "Add a new test for overwriting a file in action_log.rs";
689 eval(
690 200,
691 0.5, // TODO: make this eval better
692 EvalInput {
693 conversation: vec![
694 message(
695 User,
696 [text(indoc! {"
697 Introduce a new test in `action_log.rs` to test overwriting a file.
698 That is, a file already exists, but we call `buffer_created` as if the file were new.
699 Take inspiration from all the other tests in the file.
700 "})],
701 ),
702 message(
703 Assistant,
704 [tool_use(
705 "tool_1",
706 "read_file",
707 ReadFileToolInput {
708 path: input_file_path.into(),
709 start_line: None,
710 end_line: None,
711 },
712 )],
713 ),
714 message(
715 User,
716 [tool_result(
717 "tool_1",
718 "read_file",
719 indoc! {"
720 pub struct ActionLog [L13-20]
721 tracked_buffers [L15]
722 edited_since_project_diagnostics_check [L17]
723 project [L19]
724 impl ActionLog [L22-498]
725 pub fn new [L24-30]
726 pub fn project [L32-34]
727 pub fn checked_project_diagnostics [L37-39]
728 pub fn has_edited_files_since_project_diagnostics_check [L42-44]
729 fn track_buffer_internal [L46-101]
730 fn handle_buffer_event [L103-116]
731 fn handle_buffer_edited [L118-123]
732 fn handle_buffer_file_changed [L125-158]
733 async fn maintain_diff [L160-264]
734 pub fn buffer_read [L267-269]
735 pub fn buffer_created [L272-276]
736 pub fn buffer_edited [L279-287]
737 pub fn will_delete_buffer [L289-304]
738 pub fn keep_edits_in_range [L306-364]
739 pub fn reject_edits_in_ranges [L366-459]
740 pub fn keep_all_edits [L461-473]
741 pub fn changed_buffers [L476-482]
742 pub fn stale_buffers [L485-497]
743 fn apply_non_conflicting_edits [L500-561]
744 fn diff_snapshots [L563-585]
745 fn point_to_row_edit [L587-614]
746 enum ChangeAuthor [L617-620]
747 User [L618]
748 Agent [L619]
749 enum TrackedBufferStatus [L623-627]
750 Created [L624]
751 Modified [L625]
752 Deleted [L626]
753 struct TrackedBuffer [L629-641]
754 buffer [L630]
755 base_text [L631]
756 unreviewed_changes [L632]
757 status [L633]
758 version [L634]
759 diff [L635]
760 snapshot [L636]
761 diff_update [L637]
762 _open_lsp_handle [L638]
763 _maintain_diff [L639]
764 _subscription [L640]
765 impl TrackedBuffer [L643-657]
766 fn has_changes [L644-650]
767 fn schedule_diff_update [L652-656]
768 pub struct ChangedBuffer [L659-661]
769 pub diff [L660]
770 mod tests [L664-1574]
771 fn init_logger [L678-682]
772 fn init_test [L684-691]
773 async fn test_keep_edits [L694-769]
774 async fn test_deletions [L772-854]
775 async fn test_overlapping_user_edits [L857-951]
776 async fn test_creating_files [L954-1010]
777 async fn test_deleting_files [L1013-1120]
778 async fn test_reject_edits [L1123-1255]
779 async fn test_reject_multiple_edits [L1258-1331]
780 async fn test_reject_deleted_file [L1334-1388]
781 async fn test_reject_created_file [L1391-1443]
782 async fn test_random_diffs [L1446-1535]
783 fn quiesce [L1510-1534]
784 struct HunkStatus [L1538-1542]
785 range [L1539]
786 diff_status [L1540]
787 old_text [L1541]
788 fn unreviewed_hunks [L1544-1573]
789
790 Showing symbols 1-69 (total symbols: 69)
791
792 Using the line numbers in this outline, you can call this tool again while specifying
793 the start_line and end_line fields to see the implementations of symbols in the outline.
794 "},
795 )],
796 ),
797 message(
798 Assistant,
799 [
800 text(
801 "Let's look at the existing test cases that are related to file creation to understand the pattern:",
802 ),
803 tool_use(
804 "tool_2",
805 "read_file",
806 ReadFileToolInput {
807 path: input_file_path.into(),
808 start_line: Some(953),
809 end_line: Some(1010),
810 },
811 ),
812 ],
813 ),
814 message(
815 User,
816 [tool_result(
817 "tool_2",
818 "read_file",
819 lines(input_file_content, 953..1010),
820 )],
821 ),
822 message(
823 Assistant,
824 [
825 text(
826 "Now let's look at another relevant test to understand how they're structured:",
827 ),
828 tool_use(
829 "tool_3",
830 "read_file",
831 ReadFileToolInput {
832 path: input_file_path.into(),
833 start_line: Some(1012),
834 end_line: Some(1120),
835 },
836 ),
837 ],
838 ),
839 message(
840 User,
841 [tool_result(
842 "tool_3",
843 "read_file",
844 lines(input_file_content, 1012..1120),
845 )],
846 ),
847 message(
848 Assistant,
849 [
850 text("Now let's look at how `buffer_created` is implemented:"),
851 tool_use(
852 "tool_4",
853 "read_file",
854 ReadFileToolInput {
855 path: input_file_path.into(),
856 start_line: Some(271),
857 end_line: Some(276),
858 },
859 ),
860 ],
861 ),
862 message(
863 User,
864 [tool_result(
865 "tool_4",
866 "read_file",
867 lines(input_file_content, 271..276),
868 )],
869 ),
870 message(
871 Assistant,
872 [
873 text(indoc! {"
874 Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
875
876 1. Create a file
877 2. Open a buffer for it
878 3. Track it in the action log
879 4. Modify the buffer
880 5. Call `buffer_created` again as if the file were new
881 6. Check that the changes are properly tracked
882
883 Let's write the test:
884 "}),
885 tool_use(
886 "tool_5",
887 "edit_file",
888 EditFileToolInput {
889 display_description: edit_description.into(),
890 path: input_file_path.into(),
891 create_or_overwrite: false,
892 },
893 ),
894 ],
895 ),
896 ],
897 input_path: input_file_path.into(),
898 input_content: Some(input_file_content.into()),
899 edit_description: edit_description.into(),
900 assertion: EvalAssertion::judge_diff(
901 "A new test for overwritten files was created, without changing any previous test",
902 ),
903 },
904 );
905}
906
907fn message(
908 role: Role,
909 contents: impl IntoIterator<Item = MessageContent>,
910) -> LanguageModelRequestMessage {
911 LanguageModelRequestMessage {
912 role,
913 content: contents.into_iter().collect(),
914 cache: false,
915 }
916}
917
918fn text(text: impl Into<String>) -> MessageContent {
919 MessageContent::Text(text.into())
920}
921
922fn lines(input: &str, range: Range<usize>) -> String {
923 input
924 .lines()
925 .skip(range.start)
926 .take(range.len())
927 .collect::<Vec<_>>()
928 .join("\n")
929}
930
931fn tool_use(
932 id: impl Into<Arc<str>>,
933 name: impl Into<Arc<str>>,
934 input: impl Serialize,
935) -> MessageContent {
936 MessageContent::ToolUse(LanguageModelToolUse {
937 id: LanguageModelToolUseId::from(id.into()),
938 name: name.into(),
939 raw_input: serde_json::to_string_pretty(&input).unwrap(),
940 input: serde_json::to_value(input).unwrap(),
941 is_input_complete: true,
942 })
943}
944
945fn tool_result(
946 id: impl Into<Arc<str>>,
947 name: impl Into<Arc<str>>,
948 result: impl Into<Arc<str>>,
949) -> MessageContent {
950 MessageContent::ToolResult(LanguageModelToolResult {
951 tool_use_id: LanguageModelToolUseId::from(id.into()),
952 tool_name: name.into(),
953 is_error: false,
954 content: LanguageModelToolResultContent::Text(result.into()),
955 output: None,
956 })
957}
958
959#[derive(Clone)]
960struct EvalInput {
961 conversation: Vec<LanguageModelRequestMessage>,
962 input_path: PathBuf,
963 input_content: Option<String>,
964 edit_description: String,
965 assertion: EvalAssertion,
966}
967
968#[derive(Clone)]
969struct EvalSample {
970 text: String,
971 edit_output: EditAgentOutput,
972 diff: String,
973}
974
975trait AssertionFn: 'static + Send + Sync {
976 fn assert<'a>(
977 &'a self,
978 sample: &'a EvalSample,
979 judge_model: Arc<dyn LanguageModel>,
980 cx: &'a mut TestAppContext,
981 ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
982}
983
984impl<F> AssertionFn for F
985where
986 F: 'static
987 + Send
988 + Sync
989 + AsyncFn(
990 &EvalSample,
991 Arc<dyn LanguageModel>,
992 &mut TestAppContext,
993 ) -> Result<EvalAssertionOutcome>,
994{
995 fn assert<'a>(
996 &'a self,
997 sample: &'a EvalSample,
998 judge_model: Arc<dyn LanguageModel>,
999 cx: &'a mut TestAppContext,
1000 ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
1001 (self)(sample, judge_model, cx).boxed_local()
1002 }
1003}
1004
1005#[derive(Clone)]
1006struct EvalAssertion(Arc<dyn AssertionFn>);
1007
1008impl EvalAssertion {
1009 fn new<F>(f: F) -> Self
1010 where
1011 F: 'static
1012 + Send
1013 + Sync
1014 + AsyncFn(
1015 &EvalSample,
1016 Arc<dyn LanguageModel>,
1017 &mut TestAppContext,
1018 ) -> Result<EvalAssertionOutcome>,
1019 {
1020 EvalAssertion(Arc::new(f))
1021 }
1022
1023 fn assert_eq(expected: impl Into<String>) -> Self {
1024 let expected = expected.into();
1025 Self::new(async move |sample, _judge, _cx| {
1026 Ok(EvalAssertionOutcome {
1027 score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
1028 100
1029 } else {
1030 0
1031 },
1032 message: None,
1033 })
1034 })
1035 }
1036
1037 fn judge_diff(assertions: &'static str) -> Self {
1038 Self::new(async move |sample, judge, cx| {
1039 let prompt = DiffJudgeTemplate {
1040 diff: sample.diff.clone(),
1041 assertions,
1042 }
1043 .render(&Templates::new())
1044 .unwrap();
1045
1046 let request = LanguageModelRequest {
1047 messages: vec![LanguageModelRequestMessage {
1048 role: Role::User,
1049 content: vec![prompt.into()],
1050 cache: false,
1051 }],
1052 ..Default::default()
1053 };
1054 let mut response = judge
1055 .stream_completion_text(request, &cx.to_async())
1056 .await?;
1057 let mut output = String::new();
1058 while let Some(chunk) = response.stream.next().await {
1059 let chunk = chunk?;
1060 output.push_str(&chunk);
1061 }
1062
1063 // Parse the score from the response
1064 let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
1065 if let Some(captures) = re.captures(&output) {
1066 if let Some(score_match) = captures.get(1) {
1067 let score = score_match.as_str().parse().unwrap_or(0);
1068 return Ok(EvalAssertionOutcome {
1069 score,
1070 message: Some(output),
1071 });
1072 }
1073 }
1074
1075 Err(anyhow!(
1076 "No score found in response. Raw output: {}",
1077 output
1078 ))
1079 })
1080 }
1081
1082 async fn run(
1083 &self,
1084 input: &EvalSample,
1085 judge_model: Arc<dyn LanguageModel>,
1086 cx: &mut TestAppContext,
1087 ) -> Result<EvalAssertionOutcome> {
1088 self.0.assert(input, judge_model, cx).await
1089 }
1090}
1091
1092fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
1093 let mut evaluated_count = 0;
1094 let mut failed_count = 0;
1095 report_progress(evaluated_count, failed_count, iterations);
1096
1097 let (tx, rx) = mpsc::channel();
1098
1099 // Cache the last message in the conversation, and run one instance of the eval so that
1100 // all the next ones are cached.
1101 eval.conversation.last_mut().unwrap().cache = true;
1102 run_eval(eval.clone(), tx.clone());
1103
1104 let executor = gpui::background_executor();
1105 for _ in 1..iterations {
1106 let eval = eval.clone();
1107 let tx = tx.clone();
1108 executor.spawn(async move { run_eval(eval, tx) }).detach();
1109 }
1110 drop(tx);
1111
1112 let mut failed_evals = HashMap::default();
1113 let mut errored_evals = HashMap::default();
1114 let mut eval_outputs = Vec::new();
1115 let mut cumulative_parser_metrics = EditParserMetrics::default();
1116 while let Ok(output) = rx.recv() {
1117 match output {
1118 Ok(output) => {
1119 cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
1120 eval_outputs.push(output.clone());
1121 if output.assertion.score < 80 {
1122 failed_count += 1;
1123 failed_evals
1124 .entry(output.sample.text.clone())
1125 .or_insert(Vec::new())
1126 .push(output);
1127 }
1128 }
1129 Err(error) => {
1130 failed_count += 1;
1131 *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
1132 }
1133 }
1134
1135 evaluated_count += 1;
1136 report_progress(evaluated_count, failed_count, iterations);
1137 }
1138
1139 let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
1140 println!("Actual pass ratio: {}\n", actual_pass_ratio);
1141 if actual_pass_ratio < expected_pass_ratio {
1142 let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
1143 errored_evals.sort_by_key(|(_, count)| Reverse(*count));
1144 for (error, count) in errored_evals {
1145 println!("Eval errored {} times. Error: {}", count, error);
1146 }
1147
1148 let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
1149 failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
1150 for (_buffer_output, failed_evals) in failed_evals {
1151 let eval_output = failed_evals.first().unwrap();
1152 println!("Eval failed {} times", failed_evals.len());
1153 println!("{}", eval_output);
1154 }
1155
1156 panic!(
1157 "Actual pass ratio: {}\nExpected pass ratio: {}",
1158 actual_pass_ratio, expected_pass_ratio
1159 );
1160 }
1161
1162 let mismatched_tag_ratio =
1163 cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
1164 if mismatched_tag_ratio > 0.05 {
1165 for eval_output in eval_outputs {
1166 println!("{}", eval_output);
1167 }
1168 panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
1169 }
1170}
1171
1172fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
1173 let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
1174 let mut cx = TestAppContext::build(dispatcher, None);
1175 let output = cx.executor().block_test(async {
1176 let test = EditAgentTest::new(&mut cx).await;
1177 test.eval(eval, &mut cx).await
1178 });
1179 tx.send(output).unwrap();
1180}
1181
1182#[derive(Clone)]
1183struct EvalOutput {
1184 sample: EvalSample,
1185 assertion: EvalAssertionOutcome,
1186}
1187
1188impl Display for EvalOutput {
1189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1190 writeln!(f, "Score: {:?}", self.assertion.score)?;
1191 if let Some(message) = self.assertion.message.as_ref() {
1192 writeln!(f, "Message: {}", message)?;
1193 }
1194
1195 writeln!(f, "Diff:\n{}", self.sample.diff)?;
1196
1197 writeln!(
1198 f,
1199 "Parser Metrics:\n{:#?}",
1200 self.sample.edit_output.parser_metrics
1201 )?;
1202 writeln!(f, "Raw Edits:\n{}", self.sample.edit_output.raw_edits)?;
1203 Ok(())
1204 }
1205}
1206
1207fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
1208 let passed_count = evaluated_count - failed_count;
1209 let passed_ratio = if evaluated_count == 0 {
1210 0.0
1211 } else {
1212 passed_count as f64 / evaluated_count as f64
1213 };
1214 print!(
1215 "\r\x1b[KEvaluated {}/{} ({:.2}%)",
1216 evaluated_count,
1217 iterations,
1218 passed_ratio * 100.0
1219 );
1220 std::io::stdout().flush().unwrap();
1221}
1222
1223struct EditAgentTest {
1224 agent: EditAgent,
1225 project: Entity<Project>,
1226 judge_model: Arc<dyn LanguageModel>,
1227}
1228
1229impl EditAgentTest {
1230 async fn new(cx: &mut TestAppContext) -> Self {
1231 cx.executor().allow_parking();
1232
1233 let fs = FakeFs::new(cx.executor().clone());
1234 cx.update(|cx| {
1235 settings::init(cx);
1236 gpui_tokio::init(cx);
1237 let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
1238 cx.set_http_client(http_client);
1239
1240 client::init_settings(cx);
1241 let client = Client::production(cx);
1242 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1243
1244 settings::init(cx);
1245 Project::init_settings(cx);
1246 language::init(cx);
1247 language_model::init(client.clone(), cx);
1248 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
1249 crate::init(client.http_client(), cx);
1250 });
1251
1252 fs.insert_tree("/root", json!({})).await;
1253 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1254 let (agent_model, judge_model) = cx
1255 .update(|cx| {
1256 cx.spawn(async move |cx| {
1257 let agent_model =
1258 Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1259 let judge_model =
1260 Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
1261 (agent_model.unwrap(), judge_model.unwrap())
1262 })
1263 })
1264 .await;
1265 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1266
1267 Self {
1268 agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
1269 project,
1270 judge_model,
1271 }
1272 }
1273
1274 async fn load_model(
1275 provider: &str,
1276 id: &str,
1277 cx: &mut AsyncApp,
1278 ) -> Result<Arc<dyn LanguageModel>> {
1279 let (provider, model) = cx.update(|cx| {
1280 let models = LanguageModelRegistry::read_global(cx);
1281 let model = models
1282 .available_models(cx)
1283 .find(|model| model.provider_id().0 == provider && model.id().0 == id)
1284 .unwrap();
1285 let provider = models.provider(&model.provider_id()).unwrap();
1286 (provider, model)
1287 })?;
1288 cx.update(|cx| provider.authenticate(cx))?.await?;
1289 Ok(model)
1290 }
1291
1292 async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
1293 let path = self
1294 .project
1295 .read_with(cx, |project, cx| {
1296 project.find_project_path(eval.input_path, cx)
1297 })
1298 .unwrap();
1299 let buffer = self
1300 .project
1301 .update(cx, |project, cx| project.open_buffer(path, cx))
1302 .await
1303 .unwrap();
1304 let conversation = LanguageModelRequest {
1305 messages: eval.conversation,
1306 tools: cx.update(|cx| {
1307 ToolRegistry::default_global(cx)
1308 .tools()
1309 .into_iter()
1310 .filter_map(|tool| {
1311 let input_schema = tool
1312 .input_schema(self.agent.model.tool_input_format())
1313 .ok()?;
1314 Some(LanguageModelRequestTool {
1315 name: tool.name(),
1316 description: tool.description(),
1317 input_schema,
1318 })
1319 })
1320 .collect()
1321 }),
1322 ..Default::default()
1323 };
1324 let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
1325 buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
1326 let (edit_output, _) = self.agent.edit(
1327 buffer.clone(),
1328 eval.edit_description,
1329 &conversation,
1330 &mut cx.to_async(),
1331 );
1332 edit_output.await?
1333 } else {
1334 let (edit_output, _) = self.agent.overwrite(
1335 buffer.clone(),
1336 eval.edit_description,
1337 &conversation,
1338 &mut cx.to_async(),
1339 );
1340 edit_output.await?
1341 };
1342
1343 let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
1344 let sample = EvalSample {
1345 edit_output,
1346 diff: language::unified_diff(
1347 eval.input_content.as_deref().unwrap_or_default(),
1348 &buffer_text,
1349 ),
1350 text: buffer_text,
1351 };
1352 let assertion = eval
1353 .assertion
1354 .run(&sample, self.judge_model.clone(), cx)
1355 .await?;
1356
1357 Ok(EvalOutput { assertion, sample })
1358 }
1359}
1360
1361#[derive(Clone, Debug, Eq, PartialEq, Hash)]
1362struct EvalAssertionOutcome {
1363 score: usize,
1364 message: Option<String>,
1365}
1366
1367#[derive(Serialize)]
1368pub struct DiffJudgeTemplate {
1369 diff: String,
1370 assertions: &'static str,
1371}
1372
1373impl Template for DiffJudgeTemplate {
1374 const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
1375}
1376
1377fn strip_empty_lines(text: &str) -> String {
1378 text.lines()
1379 .filter(|line| !line.trim().is_empty())
1380 .collect::<Vec<_>>()
1381 .join("\n")
1382}