1use super::*;
2use crate::{
3 ReadFileToolInput,
4 edit_file_tool::{EditFileMode, EditFileToolInput},
5 grep_tool::GrepToolInput,
6};
7use Role::*;
8use anyhow::anyhow;
9use assistant_tool::ToolRegistry;
10use client::{Client, UserStore};
11use collections::HashMap;
12use fs::FakeFs;
13use futures::{FutureExt, future::LocalBoxFuture};
14use gpui::{AppContext, TestAppContext};
15use indoc::{formatdoc, indoc};
16use language_model::{
17 LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
18 LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId,
19};
20use project::Project;
21use rand::prelude::*;
22use reqwest_client::ReqwestClient;
23use serde_json::json;
24use std::{
25 cmp::Reverse,
26 fmt::{self, Display},
27 io::Write as _,
28 sync::mpsc,
29};
30use util::path;
31
/// Eval: ask the agent to extract a `handle_command_output` method out of
/// `run_git_blame`, then require the edited file to match the `after.rs`
/// fixture exactly (modulo empty lines) in at least 95% of 100 runs.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            // Scripted conversation: user request → read_file tool round-trip →
            // the edit_file tool call whose streamed edits are under test.
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
90
/// Eval: delete just the `run_git_blame` function (leaving its callers
/// untouched) and compare the result byte-for-byte (ignoring empty lines)
/// with the `after.rs` fixture; 95% of 100 runs must pass.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}
146
/// Eval: translate all doc comments in a file to Italian. There is no exact
/// `after` fixture here — the result is graded by the LLM judge via
/// `judge_diff`, and all 200 iterations are required to pass (ratio 1.0).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_translate_doc_comments() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
    let edit_description = "Translate all doc comments to Italian";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the {input_file_path} file and edit it (without overwriting it),
                        translating all the doc comments to italian.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
        },
    );
}
201
/// Eval: rewrite `compile_parser_to_wasm` to use wasi-sdk instead of
/// emscripten. The conversation simulates the agent reading the relevant
/// region of a large file in three chunked `read_file` calls before editing;
/// the result is graded by the LLM diff judge.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                // NOTE(review): the asset list in this prompt repeats the
                // x86_64/arm64 linux archives — presumably a copy-paste slip in
                // the fixture prompt; confirm before deduplicating, since the
                // prompt text is part of the eval's cached input.
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}
319
/// Eval: comment out all `BlinkManager` interactions in a very large editor
/// file. The grep tool result is simulated by stitching together several
/// slices of the fixture; the edit is graded by the LLM diff judge.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        200,
        0.95,
        EvalInput {
            conversation: vec![
                message(User, [text("Let's research how to cursor blinking works.")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        // NOTE(review): `join` puts the separator *between*
                        // slices, so the first slice is not preceded by a
                        // "Match found:" header — presumably close enough for
                        // the eval; confirm if header placement matters.
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comments everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - Calls to BlinkManager in `observe_window_activation` were commented out
                - The call to `blink_manager.enable` above the call to show_cursor_names was commented out
                - All the edits have valid indentation
            "}),
        },
    );
}
395
/// Eval: add a `from_pixels` constructor to `Canvas` plus tests in the same
/// file. The conversation simulates the agent probing for existing test
/// modules with several grep calls (two of which find nothing) before
/// editing; graded by the LLM diff judge.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                // Canned grep transcript showing how tests are structured
                // elsewhere in the (fictional) repository.
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                                use super::Font;
                                use crate::properties::{Stretch, Weight};

                                #[cfg(feature = \"source\")]
                                use crate::source::SystemSource;

                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                                #[cfg(feature = \"source\")]
                                #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_weight() {
                                    // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_stretch() {
                                    // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                                use crate::loaders::freetype::Font;

                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                                #[test]
                                fn get_pcf_postscript_name() {
                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                                }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                                use crate::properties::{Stretch, Weight};

                                #[test]
                                fn test_css_to_core_text_font_weight() {
                                    // Exact matches
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                                }

                                #[test]
                                fn test_css_to_core_text_font_stretch() {
                                    // Exact matches
                            ```
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}
588
/// Eval: create a brand-new Python script (`input_content: None`, so the
/// agent's `overwrite` path is exercised — see `EditAgentTest::eval`). The
/// custom assertion only checks the generated file doesn't start with a
/// space, backtick, or newline; all 200 iterations must pass.
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
                // A single assistant turn can carry multiple tool uses; the
                // matching results arrive together in the next user turn.
                message(
                    Assistant,
                    [
                        tool_use(
                            "tool_1",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react_test.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [
                        tool_result(
                            "tool_1",
                            "read_file",
                            include_str!("evals/fixtures/zode/react.py"),
                        ),
                        tool_result(
                            "tool_2",
                            "read_file",
                            include_str!("evals/fixtures/zode/react_test.py"),
                        ),
                    ],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now that I understand what we need to build, I'll create the main Python script:",
                        ),
                        tool_use(
                            "tool_3",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                mode: EditFileMode::Create,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            // Score 100 only when none of the invalid leading characters are
            // present; otherwise 0 with a message naming the offender.
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                let invalid_starts = [' ', '`', '\n'];
                let mut message = String::new();
                for start in invalid_starts {
                    if sample.text.starts_with(start) {
                        message.push_str(&format!("The sample starts with a {:?}\n", start));
                        break;
                    }
                }
                // Remove trailing newline.
                message.pop();

                if message.is_empty() {
                    Ok(EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    })
                } else {
                    Ok(EvalAssertionOutcome {
                        score: 0,
                        message: Some(message),
                    })
                }
            }),
        },
    );
}
686
/// Eval: add a new "overwrite" test to `action_log.rs`. The first read_file
/// result is a symbol *outline* (not file contents) to mimic the tool's
/// large-file behavior; follow-up reads pull specific line ranges. Judged by
/// the LLM with a deliberately low 0.5 threshold (see TODO).
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_add_overwrite_test() {
    let input_file_path = "root/action_log.rs";
    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
    let edit_description = "Add a new test for overwriting a file in action_log.rs";
    eval(
        200,
        0.5, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new test in `action_log.rs` to test overwriting a file.
                        That is, a file already exists, but we call `buffer_created` as if the file were new.
                        Take inspiration from all the other tests in the file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        indoc! {"
                            pub struct ActionLog [L13-20]
                             tracked_buffers [L15]
                             edited_since_project_diagnostics_check [L17]
                             project [L19]
                            impl ActionLog [L22-498]
                             pub fn new [L24-30]
                             pub fn project [L32-34]
                             pub fn checked_project_diagnostics [L37-39]
                             pub fn has_edited_files_since_project_diagnostics_check [L42-44]
                             fn track_buffer_internal [L46-101]
                             fn handle_buffer_event [L103-116]
                             fn handle_buffer_edited [L118-123]
                             fn handle_buffer_file_changed [L125-158]
                             async fn maintain_diff [L160-264]
                             pub fn buffer_read [L267-269]
                             pub fn buffer_created [L272-276]
                             pub fn buffer_edited [L279-287]
                             pub fn will_delete_buffer [L289-304]
                             pub fn keep_edits_in_range [L306-364]
                             pub fn reject_edits_in_ranges [L366-459]
                             pub fn keep_all_edits [L461-473]
                             pub fn changed_buffers [L476-482]
                             pub fn stale_buffers [L485-497]
                            fn apply_non_conflicting_edits [L500-561]
                            fn diff_snapshots [L563-585]
                            fn point_to_row_edit [L587-614]
                            enum ChangeAuthor [L617-620]
                             User [L618]
                             Agent [L619]
                            enum TrackedBufferStatus [L623-627]
                             Created [L624]
                             Modified [L625]
                             Deleted [L626]
                            struct TrackedBuffer [L629-641]
                             buffer [L630]
                             base_text [L631]
                             unreviewed_changes [L632]
                             status [L633]
                             version [L634]
                             diff [L635]
                             snapshot [L636]
                             diff_update [L637]
                             _open_lsp_handle [L638]
                             _maintain_diff [L639]
                             _subscription [L640]
                            impl TrackedBuffer [L643-657]
                             fn has_changes [L644-650]
                             fn schedule_diff_update [L652-656]
                            pub struct ChangedBuffer [L659-661]
                             pub diff [L660]
                            mod tests [L664-1574]
                             fn init_logger [L678-682]
                             fn init_test [L684-691]
                             async fn test_keep_edits [L694-769]
                             async fn test_deletions [L772-854]
                             async fn test_overlapping_user_edits [L857-951]
                             async fn test_creating_files [L954-1010]
                             async fn test_deleting_files [L1013-1120]
                             async fn test_reject_edits [L1123-1255]
                             async fn test_reject_multiple_edits [L1258-1331]
                             async fn test_reject_deleted_file [L1334-1388]
                             async fn test_reject_created_file [L1391-1443]
                             async fn test_random_diffs [L1446-1535]
                              fn quiesce [L1510-1534]
                            struct HunkStatus [L1538-1542]
                             range [L1539]
                             diff_status [L1540]
                             old_text [L1541]
                            fn unreviewed_hunks [L1544-1573]

                            Showing symbols 1-69 (total symbols: 69)

                            Using the line numbers in this outline, you can call this tool again while specifying
                            the start_line and end_line fields to see the implementations of symbols in the outline.
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(953),
                                end_line: Some(1010),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 953..1010),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now let's look at another relevant test to understand how they're structured:",
                        ),
                        tool_use(
                            "tool_3",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(1012),
                                end_line: Some(1120),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1012..1120),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text("Now let's look at how `buffer_created` is implemented:"),
                        tool_use(
                            "tool_4",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(271),
                                end_line: Some(276),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "read_file",
                        lines(input_file_content, 271..276),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(indoc! {"
                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:

                            1. Create a file
                            2. Open a buffer for it
                            3. Track it in the action log
                            4. Modify the buffer
                            5. Call `buffer_created` again as if the file were new
                            6. Check that the changes are properly tracked

                            Let's write the test:
                        "}),
                        tool_use(
                            "tool_5",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                mode: EditFileMode::Edit,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(
                "A new test for overwritten files was created, without changing any previous test",
            ),
        },
    );
}
910
911fn message(
912 role: Role,
913 contents: impl IntoIterator<Item = MessageContent>,
914) -> LanguageModelRequestMessage {
915 LanguageModelRequestMessage {
916 role,
917 content: contents.into_iter().collect(),
918 cache: false,
919 }
920}
921
922fn text(text: impl Into<String>) -> MessageContent {
923 MessageContent::Text(text.into())
924}
925
/// Returns the 0-based, half-open `range` of lines from `input`, joined with
/// `\n`. An out-of-bounds or empty range simply yields fewer (or no) lines.
fn lines(input: &str, range: Range<usize>) -> String {
    let mut selected = Vec::new();
    for (index, line) in input.lines().enumerate() {
        if index >= range.end {
            break;
        }
        if index >= range.start {
            selected.push(line);
        }
    }
    selected.join("\n")
}
934
935fn tool_use(
936 id: impl Into<Arc<str>>,
937 name: impl Into<Arc<str>>,
938 input: impl Serialize,
939) -> MessageContent {
940 MessageContent::ToolUse(LanguageModelToolUse {
941 id: LanguageModelToolUseId::from(id.into()),
942 name: name.into(),
943 raw_input: serde_json::to_string_pretty(&input).unwrap(),
944 input: serde_json::to_value(input).unwrap(),
945 is_input_complete: true,
946 })
947}
948
949fn tool_result(
950 id: impl Into<Arc<str>>,
951 name: impl Into<Arc<str>>,
952 result: impl Into<Arc<str>>,
953) -> MessageContent {
954 MessageContent::ToolResult(LanguageModelToolResult {
955 tool_use_id: LanguageModelToolUseId::from(id.into()),
956 tool_name: name.into(),
957 is_error: false,
958 content: LanguageModelToolResultContent::Text(result.into()),
959 output: None,
960 })
961}
962
/// One eval scenario: the conversation to replay, the file being edited, and
/// the assertion that grades the result.
#[derive(Clone)]
struct EvalInput {
    /// Conversation history fed to the edit agent; `eval` marks the last
    /// message as cacheable before the first run.
    conversation: Vec<LanguageModelRequestMessage>,
    /// Project-relative path of the file the agent edits.
    input_path: PathBuf,
    /// Initial file contents. `None` means the file is created from scratch
    /// (the agent's `overwrite` path is used instead of `edit`).
    input_content: Option<String>,
    /// Natural-language description of the desired edit.
    edit_description: String,
    /// How the resulting buffer/diff is scored.
    assertion: EvalAssertion,
}
971
/// The outcome of a single agent run.
#[derive(Clone)]
struct EvalSample {
    /// Final buffer text after the agent's edits.
    text: String,
    /// The agent's raw edit output, including parser metrics.
    edit_output: EditAgentOutput,
    /// Unified diff of the input content vs. `text`.
    diff: String,
}
978
/// Object-safe assertion over an `EvalSample`. Implementations score the
/// sample, optionally consulting `judge_model` (e.g. for LLM-judged diffs).
trait AssertionFn: 'static + Send + Sync {
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
987
/// Blanket impl: any matching async closure can serve as an assertion, which
/// is what `EvalAssertion::new` relies on.
impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        // Box the future so the trait stays object-safe.
        (self)(sample, judge_model, cx).boxed_local()
    }
}
1008
/// Cheaply-cloneable handle to an assertion; shared across the concurrent
/// eval iterations via the `Arc`.
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
1011
impl EvalAssertion {
    /// Wraps an async assertion closure.
    fn new<F>(f: F) -> Self
    where
        F: 'static
            + Send
            + Sync
            + AsyncFn(
                &EvalSample,
                Arc<dyn LanguageModel>,
                &mut TestAppContext,
            ) -> Result<EvalAssertionOutcome>,
    {
        EvalAssertion(Arc::new(f))
    }

    /// Scores 100 iff the sample's text equals `expected`, comparing with
    /// empty/whitespace-only lines stripped from both sides; 0 otherwise.
    fn assert_eq(expected: impl Into<String>) -> Self {
        let expected = expected.into();
        Self::new(async move |sample, _judge, _cx| {
            Ok(EvalAssertionOutcome {
                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
                    100
                } else {
                    0
                },
                message: None,
            })
        })
    }

    /// Asks the judge model to grade the sample's diff against the given
    /// natural-language assertions, expecting a `<score>N</score>` tag in the
    /// response. Errors if no score tag is found.
    fn judge_diff(assertions: &'static str) -> Self {
        Self::new(async move |sample, judge, cx| {
            let prompt = DiffJudgeTemplate {
                diff: sample.diff.clone(),
                assertions,
            }
            .render(&Templates::new())
            .unwrap();

            let request = LanguageModelRequest {
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![prompt.into()],
                    cache: false,
                }],
                ..Default::default()
            };
            // Stream and accumulate the judge's full textual response.
            let mut response = judge
                .stream_completion_text(request, &cx.to_async())
                .await?;
            let mut output = String::new();
            while let Some(chunk) = response.stream.next().await {
                let chunk = chunk?;
                output.push_str(&chunk);
            }

            // Parse the score from the response
            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
            if let Some(captures) = re.captures(&output) {
                if let Some(score_match) = captures.get(1) {
                    // An unparsable number degrades to a score of 0 rather
                    // than an error, keeping the judge's raw output visible.
                    let score = score_match.as_str().parse().unwrap_or(0);
                    return Ok(EvalAssertionOutcome {
                        score,
                        message: Some(output),
                    });
                }
            }

            Err(anyhow!(
                "No score found in response. Raw output: {}",
                output
            ))
        })
    }

    /// Runs the wrapped assertion against `input`.
    async fn run(
        &self,
        input: &EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<EvalAssertionOutcome> {
        self.0.assert(input, judge_model, cx).await
    }
}
1095
/// Runs `eval` `iterations` times — one priming run on the current thread,
/// the rest concurrently on the background executor — and panics if the pass
/// ratio falls below `expected_pass_ratio` or if more than 5% of parsed edit
/// tags were mismatched across all runs.
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    let mut failed_count = 0;
    report_progress(evaluated_count, failed_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Cache the last message in the conversation, and run one instance of the eval so that
    // all the next ones are cached.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    // Drop the original sender so `rx` disconnects once all runs finish.
    drop(tx);

    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
                eval_outputs.push(output.clone());
                // A score below 80 counts as a failure; failures are grouped
                // by the resulting buffer text so duplicates are reported once.
                if output.assertion.score < 80 {
                    failed_count += 1;
                    failed_evals
                        .entry(output.sample.text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, failed_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        // Report most frequent errors and failures first.
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    // Even a passing run fails the eval if the edit parser saw too many
    // mismatched tags overall (> 5%).
    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.05 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}
1175
/// Executes a single eval iteration on a fresh `TestAppContext` (seeded with
/// fresh entropy) and sends the result over `tx`.
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
    let mut cx = TestAppContext::build(dispatcher, None);
    let output = cx.executor().block_test(async {
        let test = EditAgentTest::new(&mut cx).await;
        test.eval(eval, &mut cx).await
    });
    tx.send(output).unwrap();
}
1185
/// A completed eval iteration: the sample produced by the agent plus its
/// scored assertion outcome.
#[derive(Clone)]
struct EvalOutput {
    sample: EvalSample,
    assertion: EvalAssertionOutcome,
}
1191
/// Human-readable failure report: score, optional judge message, the diff,
/// parser metrics, and the agent's raw edits.
impl Display for EvalOutput {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Score: {:?}", self.assertion.score)?;
        if let Some(message) = self.assertion.message.as_ref() {
            writeln!(f, "Message: {}", message)?;
        }

        writeln!(f, "Diff:\n{}", self.sample.diff)?;

        writeln!(
            f,
            "Parser Metrics:\n{:#?}",
            self.sample.edit_output.parser_metrics
        )?;
        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output.raw_edits)?;
        Ok(())
    }
}
1210
/// Rewrites the current terminal line with progress so far: how many
/// iterations have been evaluated out of the total, and the pass percentage
/// among those already evaluated.
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
    let passed_count = evaluated_count - failed_count;
    // Guard the division: before any iteration completes, report 0%.
    let passed_ratio = match evaluated_count {
        0 => 0.0,
        evaluated => passed_count as f64 / evaluated as f64,
    };
    // `\r` returns to column 0 and `\x1b[K` clears the line, so each report
    // overwrites the previous one in place.
    print!(
        "\r\x1b[KEvaluated {}/{} ({:.2}%)",
        evaluated_count,
        iterations,
        passed_ratio * 100.0
    );
    std::io::stdout().flush().unwrap();
}
1226
/// Shared fixture for a single eval run: the agent under test, the test
/// project it edits, and the model used to judge results.
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}
1232
impl EditAgentTest {
    /// Boots a test app with real networking (production client, Reqwest HTTP)
    /// over a fake in-memory filesystem rooted at `/root`, then loads the
    /// agent and judge models.
    async fn new(cx: &mut TestAppContext) -> Self {
        // Real network calls happen below, so the executor must allow parking.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        cx.update(|cx| {
            settings::init(cx);
            gpui_tokio::init(cx);
            let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
            cx.set_http_client(http_client);

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

            // NOTE(review): `settings::init` is also called at the top of this
            // closure — presumably redundant; confirm before removing.
            settings::init(cx);
            Project::init_settings(cx);
            language::init(cx);
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
            crate::init(client.http_client(), cx);
        });

        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        // The same model currently serves as both the agent and the judge.
        let (agent_model, judge_model) = cx
            .update(|cx| {
                cx.spawn(async move |cx| {
                    let agent_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    let judge_model =
                        Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    /// Looks up a registered model by provider/model id and authenticates its
    /// provider. Panics (`unwrap`) if the model isn't registered.
    async fn load_model(
        provider: &str,
        id: &str,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| model.provider_id().0 == provider && model.id().0 == id)
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    /// Runs one eval: opens the target buffer, hands the scripted
    /// conversation (plus the registered tool schemas) to the agent, applies
    /// its edits, and scores the result with the eval's assertion.
    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let conversation = LanguageModelRequest {
            messages: eval.conversation,
            // Advertise every registered tool whose schema serializes for
            // this model's input format.
            tools: cx.update(|cx| {
                ToolRegistry::default_global(cx)
                    .tools()
                    .into_iter()
                    .filter_map(|tool| {
                        let input_schema = tool
                            .input_schema(self.agent.model.tool_input_format())
                            .ok()?;
                        Some(LanguageModelRequestTool {
                            name: tool.name(),
                            description: tool.description(),
                            input_schema,
                        })
                    })
                    .collect()
            }),
            ..Default::default()
        };
        // With initial content we exercise the agent's `edit` path; without
        // it (file creation), the `overwrite` path.
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                &conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                &conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}
1364
/// Result of one assertion: a score in 0..=100 (below 80 counts as a
/// failure in `eval`) and an optional explanatory message.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    score: usize,
    message: Option<String>,
}
1370
/// Template context for the LLM diff judge: the unified diff under review
/// and the natural-language assertions it must check.
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}
1376
// Renders via the `diff_judge.hbs` Handlebars template.
impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
1380
/// Drops lines that are empty or whitespace-only and rejoins the remainder
/// with `\n`, so `assert_eq` comparisons ignore blank-line differences.
fn strip_empty_lines(text: &str) -> String {
    let mut kept = Vec::new();
    for line in text.lines() {
        if !line.trim().is_empty() {
            kept.push(line);
        }
    }
    kept.join("\n")
}