use super::*;
use crate::{
    ReadFileToolInput,
    edit_file_tool::{EditFileMode, EditFileToolInput},
    grep_tool::GrepToolInput,
};
use Role::*;
use anyhow::anyhow;
use assistant_tool::ToolRegistry;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
    LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
    cmp::Reverse,
    fmt::{self, Display},
    io::Write as _,
    str::FromStr,
    sync::mpsc,
};
use util::path;

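// These evals talk to real model providers, so they are ignored unless the
// `eval` feature is enabled. A typical invocation (assuming this crate's
// `eval` feature and provider credentials are configured) is:
//
//     cargo test --features eval
//
// `ZED_AGENT_MODEL` and `ZED_JUDGE_MODEL` override which models are used as
// the agent under test and as the judge (see `EditAgentTest::new`).
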
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
    let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
    let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and extract a method in
                        the final stanza of `run_git_blame` to deal with command failures,
                        call it `handle_command_output` and take the std::process::Output as the only parameter.

                        Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
    let input_file_path = "root/blame.rs";
    let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
    let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
    let edit_description = "Delete the `run_git_blame` function.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and delete `run_git_blame`. Just that
                        one function, not its usages.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::assert_eq(output_file_content),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_translate_doc_comments() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
    let edit_description = "Translate all doc comments to Italian";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the {input_file_path} file and edit it (without overwriting it),
                        translating all the doc comments to Italian.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
    let input_file_path = "root/lib.rs";
    let input_file_content =
        include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
    let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(formatdoc! {"
                        Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
                        Use `ureq` to download the SDK for the current platform and architecture.
                        Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
                        Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
                        that's inside of the archive.
                        Don't re-download the SDK if that executable already exists.

                        Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}

                        Here are the available wasi-sdk assets:
                        - wasi-sdk-25.0-x86_64-macos.tar.gz
                        - wasi-sdk-25.0-arm64-macos.tar.gz
                        - wasi-sdk-25.0-x86_64-linux.tar.gz
                        - wasi-sdk-25.0-arm64-linux.tar.gz
                        - wasi-sdk-25.0-x86_64-windows.tar.gz
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(971),
                            end_line: Some(1050),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        lines(input_file_content, 971..1050),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1050),
                            end_line: Some(1100),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 1050..1100),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: Some(1100),
                            end_line: Some(1150),
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1100..1150),
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The compile_parser_to_wasm method has been changed to use wasi-sdk
                - ureq is used to download the SDK for current platform and architecture
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
    let input_file_path = "root/editor.rs";
    let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
    let edit_description = "Comment out the call to `BlinkManager::enable`";
    eval(
        200,
        0.95,
        EvalInput {
            conversation: vec![
                message(User, [text("Let's research how cursor blinking works.")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "grep",
                        GrepToolInput {
                            regex: "blink".into(),
                            include_pattern: None,
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "grep",
                        [
                            lines(input_file_content, 100..400),
                            lines(input_file_content, 800..1300),
                            lines(input_file_content, 1600..2000),
                            lines(input_file_content, 5000..5500),
                            lines(input_file_content, 8000..9000),
                            lines(input_file_content, 18455..18470),
                            lines(input_file_content, 20000..20500),
                            lines(input_file_content, 21000..21300),
                        ]
                        .join("Match found:\n\n"),
                    )],
                ),
                message(
                    User,
                    [text(indoc! {"
                        Comment out the lines that interact with the BlinkManager.
                        Keep the outer `update` blocks, but comment out everything that's inside (including if statements).
                        Don't add additional comments.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - Calls to BlinkManager in `observe_window_activation` were commented out
                - The call to `blink_manager.enable` above the call to show_cursor_names was commented out
                - All the edits have valid indentation
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
    let input_file_path = "root/canvas.rs";
    let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
    let edit_description = "Implement from_pixels constructor and add tests.";
    eval(
        100,
        0.95,
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new `from_pixels` constructor in Canvas and
                        also add tests for it in the same file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result("tool_1", "read_file", input_file_content)],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_2",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/canvas.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_2", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_3",
                        "grep",
                        GrepToolInput {
                            regex: "mod\\s+tests".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(User, [tool_result("tool_3", "grep", "No matches found")]),
                message(
                    Assistant,
                    [tool_use(
                        "tool_4",
                        "grep",
                        GrepToolInput {
                            regex: "#\\[test\\]".into(),
                            include_pattern: Some("font-kit/src/**/*.rs".into()),
                            offset: 0,
                            case_sensitive: false,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "grep",
                        indoc! {"
                            Found 6 matches:

                            ## Matches in font-kit/src/loaders/core_text.rs

                            ### mod test › L926-936
                            ```
                            mod test {
                                use super::Font;
                                use crate::properties::{Stretch, Weight};

                                #[cfg(feature = \"source\")]
                                use crate::source::SystemSource;

                                static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";

                                #[cfg(feature = \"source\")]
                                #[test]
                            ```

                            55 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L947-951
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_weight() {
                                    // Exact matches
                            ```

                            ### mod test › L959-963
                            ```
                                }

                                #[test]
                                fn test_core_text_to_css_font_stretch() {
                                    // Exact matches
                            ```

                            ## Matches in font-kit/src/loaders/freetype.rs

                            ### mod test › L1238-1248
                            ```
                            mod test {
                                use crate::loaders::freetype::Font;

                                static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
                                static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";

                                #[test]
                                fn get_pcf_postscript_name() {
                                    let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
                                    assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
                                }
                            ```

                            1 lines remaining in ancestor node. Read the file to see all.

                            ## Matches in font-kit/src/sources/core_text.rs

                            ### mod test › L265-275
                            ```
                            mod test {
                                use crate::properties::{Stretch, Weight};

                                #[test]
                                fn test_css_to_core_text_font_weight() {
                                    // Exact matches
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
                                    assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);

                            ```

                            27 lines remaining in ancestor node. Read the file to see all.

                            ### mod test › L278-282
                            ```
                                }

                                #[test]
                                fn test_css_to_core_text_font_stretch() {
                                    // Exact matches
                            ```
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_5",
                        "edit_file",
                        EditFileToolInput {
                            display_description: edit_description.into(),
                            path: input_file_path.into(),
                            mode: EditFileMode::Edit,
                        },
                    )],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(indoc! {"
                - The diff contains a new `from_pixels` constructor
                - The diff contains new tests for the `from_pixels` constructor
            "}),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
    let input_file_path = "root/zode.py";
    let edit_description = "Create the main Zode CLI script";
    eval(
        200,
        1.,
        EvalInput {
            conversation: vec![
                message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
                message(
                    Assistant,
                    [
                        tool_use(
                            "tool_1",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: "root/eval/react_test.py".into(),
                                start_line: None,
                                end_line: None,
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [
                        tool_result(
                            "tool_1",
                            "read_file",
                            include_str!("evals/fixtures/zode/react.py"),
                        ),
                        tool_result(
                            "tool_2",
                            "read_file",
                            include_str!("evals/fixtures/zode/react_test.py"),
                        ),
                    ],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now that I understand what we need to build, I'll create the main Python script:",
                        ),
                        tool_use(
                            "tool_3",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                mode: EditFileMode::Create,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: None,
            edit_description: edit_description.into(),
            assertion: EvalAssertion::new(async move |sample, _, _cx| {
                let invalid_starts = [' ', '`', '\n'];
                let mut message = String::new();
                for start in invalid_starts {
                    if sample.text.starts_with(start) {
                        message.push_str(&format!("The sample starts with a {:?}\n", start));
                        break;
                    }
                }
                // Remove trailing newline.
                message.pop();

                if message.is_empty() {
                    Ok(EvalAssertionOutcome {
                        score: 100,
                        message: None,
                    })
                } else {
                    Ok(EvalAssertionOutcome {
                        score: 0,
                        message: Some(message),
                    })
                }
            }),
        },
    );
}

#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_add_overwrite_test() {
    let input_file_path = "root/action_log.rs";
    let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
    let edit_description = "Add a new test for overwriting a file in action_log.rs";
    eval(
        200,
        0.5, // TODO: make this eval better
        EvalInput {
            conversation: vec![
                message(
                    User,
                    [text(indoc! {"
                        Introduce a new test in `action_log.rs` to test overwriting a file.
                        That is, a file already exists, but we call `buffer_created` as if the file were new.
                        Take inspiration from all the other tests in the file.
                    "})],
                ),
                message(
                    Assistant,
                    [tool_use(
                        "tool_1",
                        "read_file",
                        ReadFileToolInput {
                            path: input_file_path.into(),
                            start_line: None,
                            end_line: None,
                        },
                    )],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_1",
                        "read_file",
                        indoc! {"
                            pub struct ActionLog [L13-20]
                             tracked_buffers [L15]
                             edited_since_project_diagnostics_check [L17]
                             project [L19]
                            impl ActionLog [L22-498]
                             pub fn new [L24-30]
                             pub fn project [L32-34]
                             pub fn checked_project_diagnostics [L37-39]
                             pub fn has_edited_files_since_project_diagnostics_check [L42-44]
                             fn track_buffer_internal [L46-101]
                             fn handle_buffer_event [L103-116]
                             fn handle_buffer_edited [L118-123]
                             fn handle_buffer_file_changed [L125-158]
                             async fn maintain_diff [L160-264]
                             pub fn buffer_read [L267-269]
                             pub fn buffer_created [L272-276]
                             pub fn buffer_edited [L279-287]
                             pub fn will_delete_buffer [L289-304]
                             pub fn keep_edits_in_range [L306-364]
                             pub fn reject_edits_in_ranges [L366-459]
                             pub fn keep_all_edits [L461-473]
                             pub fn changed_buffers [L476-482]
                             pub fn stale_buffers [L485-497]
                            fn apply_non_conflicting_edits [L500-561]
                            fn diff_snapshots [L563-585]
                            fn point_to_row_edit [L587-614]
                            enum ChangeAuthor [L617-620]
                             User [L618]
                             Agent [L619]
                            enum TrackedBufferStatus [L623-627]
                             Created [L624]
                             Modified [L625]
                             Deleted [L626]
                            struct TrackedBuffer [L629-641]
                             buffer [L630]
                             base_text [L631]
                             unreviewed_changes [L632]
                             status [L633]
                             version [L634]
                             diff [L635]
                             snapshot [L636]
                             diff_update [L637]
                             _open_lsp_handle [L638]
                             _maintain_diff [L639]
                             _subscription [L640]
                            impl TrackedBuffer [L643-657]
                             fn has_changes [L644-650]
                             fn schedule_diff_update [L652-656]
                            pub struct ChangedBuffer [L659-661]
                             pub diff [L660]
                            mod tests [L664-1574]
                             fn init_logger [L678-682]
                             fn init_test [L684-691]
                             async fn test_keep_edits [L694-769]
                             async fn test_deletions [L772-854]
                             async fn test_overlapping_user_edits [L857-951]
                             async fn test_creating_files [L954-1010]
                             async fn test_deleting_files [L1013-1120]
                             async fn test_reject_edits [L1123-1255]
                             async fn test_reject_multiple_edits [L1258-1331]
                             async fn test_reject_deleted_file [L1334-1388]
                             async fn test_reject_created_file [L1391-1443]
                             async fn test_random_diffs [L1446-1535]
                              fn quiesce [L1510-1534]
                             struct HunkStatus [L1538-1542]
                              range [L1539]
                              diff_status [L1540]
                              old_text [L1541]
                             fn unreviewed_hunks [L1544-1573]

                            Showing symbols 1-69 (total symbols: 69)

                            Using the line numbers in this outline, you can call this tool again while specifying
                            the start_line and end_line fields to see the implementations of symbols in the outline.
                        "},
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Let's look at the existing test cases that are related to file creation to understand the pattern:",
                        ),
                        tool_use(
                            "tool_2",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(953),
                                end_line: Some(1010),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_2",
                        "read_file",
                        lines(input_file_content, 953..1010),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(
                            "Now let's look at another relevant test to understand how they're structured:",
                        ),
                        tool_use(
                            "tool_3",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(1012),
                                end_line: Some(1120),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_3",
                        "read_file",
                        lines(input_file_content, 1012..1120),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text("Now let's look at how `buffer_created` is implemented:"),
                        tool_use(
                            "tool_4",
                            "read_file",
                            ReadFileToolInput {
                                path: input_file_path.into(),
                                start_line: Some(271),
                                end_line: Some(276),
                            },
                        ),
                    ],
                ),
                message(
                    User,
                    [tool_result(
                        "tool_4",
                        "read_file",
                        lines(input_file_content, 271..276),
                    )],
                ),
                message(
                    Assistant,
                    [
                        text(indoc! {"
                            Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:

                            1. Create a file
                            2. Open a buffer for it
                            3. Track it in the action log
                            4. Modify the buffer
                            5. Call `buffer_created` again as if the file were new
                            6. Check that the changes are properly tracked

                            Let's write the test:
                        "}),
                        tool_use(
                            "tool_5",
                            "edit_file",
                            EditFileToolInput {
                                display_description: edit_description.into(),
                                path: input_file_path.into(),
                                mode: EditFileMode::Edit,
                            },
                        ),
                    ],
                ),
            ],
            input_path: input_file_path.into(),
            input_content: Some(input_file_content.into()),
            edit_description: edit_description.into(),
            assertion: EvalAssertion::judge_diff(
                "A new test for overwritten files was created, without changing any previous test",
            ),
        },
    );
}

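/// Builds a `LanguageModelRequestMessage` with the given role and contents.
/// Caching is disabled by default; `eval` enables it for the final message.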
fn message(
    role: Role,
    contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
    LanguageModelRequestMessage {
        role,
        content: contents.into_iter().collect(),
        cache: false,
    }
}

fn text(text: impl Into<String>) -> MessageContent {
    MessageContent::Text(text.into())
}

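/// Returns the half-open range of lines `range.start..range.end` (zero-based,
/// end-exclusive) from `input`, joined with newlines.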
fn lines(input: &str, range: Range<usize>) -> String {
    input
        .lines()
        .skip(range.start)
        .take(range.len())
        .collect::<Vec<_>>()
        .join("\n")
}

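/// Builds a `MessageContent::ToolUse` for a tool call, serializing `input`
/// both as pretty-printed JSON (`raw_input`) and as a JSON value.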
fn tool_use(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    input: impl Serialize,
) -> MessageContent {
    MessageContent::ToolUse(LanguageModelToolUse {
        id: LanguageModelToolUseId::from(id.into()),
        name: name.into(),
        raw_input: serde_json::to_string_pretty(&input).unwrap(),
        input: serde_json::to_value(input).unwrap(),
        is_input_complete: true,
    })
}

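/// Builds a successful (non-error) `MessageContent::ToolResult` carrying
/// `result` as plain text.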
fn tool_result(
    id: impl Into<Arc<str>>,
    name: impl Into<Arc<str>>,
    result: impl Into<Arc<str>>,
) -> MessageContent {
    MessageContent::ToolResult(LanguageModelToolResult {
        tool_use_id: LanguageModelToolUseId::from(id.into()),
        tool_name: name.into(),
        is_error: false,
        content: LanguageModelToolResultContent::Text(result.into()),
        output: None,
    })
}

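/// A scripted conversation to replay against the edit agent, plus the file it
/// should edit and the assertion used to score the result.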
#[derive(Clone)]
struct EvalInput {
    conversation: Vec<LanguageModelRequestMessage>,
    input_path: PathBuf,
    input_content: Option<String>,
    edit_description: String,
    assertion: EvalAssertion,
}

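/// One agent run: the resulting buffer text, the agent's raw edit output, and
/// a unified diff against the input content.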
#[derive(Clone)]
struct EvalSample {
    text: String,
    edit_output: EditAgentOutput,
    diff: String,
}

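/// Object-safe wrapper around an async assertion function, so closures and
/// `AsyncFn`s can be stored behind `Arc<dyn AssertionFn>` in `EvalAssertion`.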
trait AssertionFn: 'static + Send + Sync {
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}

impl<F> AssertionFn for F
where
    F: 'static
        + Send
        + Sync
        + AsyncFn(
            &EvalSample,
            Arc<dyn LanguageModel>,
            &mut TestAppContext,
        ) -> Result<EvalAssertionOutcome>,
{
    fn assert<'a>(
        &'a self,
        sample: &'a EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &'a mut TestAppContext,
    ) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
        (self)(sample, judge_model, cx).boxed_local()
    }
}

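/// Scores an `EvalSample`, either programmatically (`new`, `assert_eq`) or by
/// asking a judge model to grade the diff (`judge_diff`).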
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);

impl EvalAssertion {
    fn new<F>(f: F) -> Self
    where
        F: 'static
            + Send
            + Sync
            + AsyncFn(
                &EvalSample,
                Arc<dyn LanguageModel>,
                &mut TestAppContext,
            ) -> Result<EvalAssertionOutcome>,
    {
        EvalAssertion(Arc::new(f))
    }

    fn assert_eq(expected: impl Into<String>) -> Self {
        let expected = expected.into();
        Self::new(async move |sample, _judge, _cx| {
            Ok(EvalAssertionOutcome {
                score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
                    100
                } else {
                    0
                },
                message: None,
            })
        })
    }

    fn judge_diff(assertions: &'static str) -> Self {
        Self::new(async move |sample, judge, cx| {
            let prompt = DiffJudgeTemplate {
                diff: sample.diff.clone(),
                assertions,
            }
            .render(&Templates::new())
            .unwrap();

            let request = LanguageModelRequest {
                messages: vec![LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![prompt.into()],
                    cache: false,
                }],
                ..Default::default()
            };
            let mut response = judge
                .stream_completion_text(request, &cx.to_async())
                .await?;
            let mut output = String::new();
            while let Some(chunk) = response.stream.next().await {
                let chunk = chunk?;
                output.push_str(&chunk);
            }

            // Parse the score from the response.
            let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
            if let Some(captures) = re.captures(&output) {
                if let Some(score_match) = captures.get(1) {
                    let score = score_match.as_str().parse().unwrap_or(0);
                    return Ok(EvalAssertionOutcome {
                        score,
                        message: Some(output),
                    });
                }
            }

            Err(anyhow!(
                "No score found in response. Raw output: {}",
                output
            ))
        })
    }

    async fn run(
        &self,
        input: &EvalSample,
        judge_model: Arc<dyn LanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<EvalAssertionOutcome> {
        self.0.assert(input, judge_model, cx).await
    }
}

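/// Runs the given eval for `iterations` samples (scores below 80 count as
/// failures) and panics if the pass ratio falls below `expected_pass_ratio`
/// or if more than 5% of edit parser tags were mismatched.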
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
    let mut evaluated_count = 0;
    let mut failed_count = 0;
    report_progress(evaluated_count, failed_count, iterations);

    let (tx, rx) = mpsc::channel();

    // Mark the last message in the conversation as cached, and run one eval
    // up front so that all subsequent runs hit the cache.
    eval.conversation.last_mut().unwrap().cache = true;
    run_eval(eval.clone(), tx.clone());

    let executor = gpui::background_executor();
    for _ in 1..iterations {
        let eval = eval.clone();
        let tx = tx.clone();
        executor.spawn(async move { run_eval(eval, tx) }).detach();
    }
    drop(tx);

    let mut failed_evals = HashMap::default();
    let mut errored_evals = HashMap::default();
    let mut eval_outputs = Vec::new();
    let mut cumulative_parser_metrics = EditParserMetrics::default();
    while let Ok(output) = rx.recv() {
        match output {
            Ok(output) => {
                cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
                eval_outputs.push(output.clone());
                if output.assertion.score < 80 {
                    failed_count += 1;
                    failed_evals
                        .entry(output.sample.text.clone())
                        .or_insert(Vec::new())
                        .push(output);
                }
            }
            Err(error) => {
                failed_count += 1;
                *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
            }
        }

        evaluated_count += 1;
        report_progress(evaluated_count, failed_count, iterations);
    }

    let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
    println!("Actual pass ratio: {}\n", actual_pass_ratio);
    if actual_pass_ratio < expected_pass_ratio {
        let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
        errored_evals.sort_by_key(|(_, count)| Reverse(*count));
        for (error, count) in errored_evals {
            println!("Eval errored {} times. Error: {}", count, error);
        }

        let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
        failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
        for (_buffer_output, failed_evals) in failed_evals {
            let eval_output = failed_evals.first().unwrap();
            println!("Eval failed {} times", failed_evals.len());
            println!("{}", eval_output);
        }

        panic!(
            "Actual pass ratio: {}\nExpected pass ratio: {}",
            actual_pass_ratio, expected_pass_ratio
        );
    }

    let mismatched_tag_ratio =
        cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
    if mismatched_tag_ratio > 0.05 {
        for eval_output in eval_outputs {
            println!("{}", eval_output);
        }
        panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
    }
}

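/// Runs a single eval iteration on a fresh `TestAppContext` and reports the
/// result over `tx`.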
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
    let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
    let mut cx = TestAppContext::build(dispatcher, None);
    let output = cx.executor().block_test(async {
        let test = EditAgentTest::new(&mut cx).await;
        test.eval(eval, &mut cx).await
    });
    tx.send(output).unwrap();
}

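/// The outcome of a single eval iteration: the sample that was produced and
/// the assertion's score for it.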
#[derive(Clone)]
struct EvalOutput {
    sample: EvalSample,
    assertion: EvalAssertionOutcome,
}

impl Display for EvalOutput {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Score: {:?}", self.assertion.score)?;
        if let Some(message) = self.assertion.message.as_ref() {
            writeln!(f, "Message: {}", message)?;
        }

        writeln!(f, "Diff:\n{}", self.sample.diff)?;

        writeln!(
            f,
            "Parser Metrics:\n{:#?}",
            self.sample.edit_output.parser_metrics
        )?;
        writeln!(f, "Raw Edits:\n{}", self.sample.edit_output.raw_edits)?;
        Ok(())
    }
}

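/// Rewrites the current terminal line with progress so far, e.g.
/// `Evaluated 42/100 (95.24% passed)`.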
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
    let passed_count = evaluated_count - failed_count;
    let passed_ratio = if evaluated_count == 0 {
        0.0
    } else {
        passed_count as f64 / evaluated_count as f64
    };
    print!(
        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
        evaluated_count,
        iterations,
        passed_ratio * 100.0
    );
    std::io::stdout().flush().unwrap();
}

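/// Shared harness for evals: a test project on a fake filesystem plus the
/// agent and judge models, selected via `ZED_AGENT_MODEL` and
/// `ZED_JUDGE_MODEL` (defaulting to `anthropic/claude-3-7-sonnet-latest`).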
struct EditAgentTest {
    agent: EditAgent,
    project: Entity<Project>,
    judge_model: Arc<dyn LanguageModel>,
}

impl EditAgentTest {
    async fn new(cx: &mut TestAppContext) -> Self {
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        cx.update(|cx| {
            settings::init(cx);
            gpui_tokio::init(cx);
            let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
            cx.set_http_client(http_client);

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

            settings::init(cx);
            Project::init_settings(cx);
            language::init(cx);
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
            crate::init(client.http_client(), cx);
        });

        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let agent_model = SelectedModel::from_str(
            &std::env::var("ZED_AGENT_MODEL")
                .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
        )
        .unwrap();
        let judge_model = SelectedModel::from_str(
            &std::env::var("ZED_JUDGE_MODEL")
                .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
        )
        .unwrap();
        let (agent_model, judge_model) = cx
            .update(|cx| {
                cx.spawn(async move |cx| {
                    let agent_model = Self::load_model(&agent_model, cx).await;
                    let judge_model = Self::load_model(&judge_model, cx).await;
                    (agent_model.unwrap(), judge_model.unwrap())
                })
            })
            .await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        Self {
            agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
            project,
            judge_model,
        }
    }

    async fn load_model(
        selected_model: &SelectedModel,
        cx: &mut AsyncApp,
    ) -> Result<Arc<dyn LanguageModel>> {
        let (provider, model) = cx.update(|cx| {
            let models = LanguageModelRegistry::read_global(cx);
            let model = models
                .available_models(cx)
                .find(|model| {
                    model.provider_id() == selected_model.provider
                        && model.id() == selected_model.model
                })
                .unwrap();
            let provider = models.provider(&model.provider_id()).unwrap();
            (provider, model)
        })?;
        cx.update(|cx| provider.authenticate(cx))?.await?;
        Ok(model)
    }

    async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
        let path = self
            .project
            .read_with(cx, |project, cx| {
                project.find_project_path(eval.input_path, cx)
            })
            .unwrap();
        let buffer = self
            .project
            .update(cx, |project, cx| project.open_buffer(path, cx))
            .await
            .unwrap();
        let conversation = LanguageModelRequest {
            messages: eval.conversation,
            tools: cx.update(|cx| {
                ToolRegistry::default_global(cx)
                    .tools()
                    .into_iter()
                    .filter_map(|tool| {
                        let input_schema = tool
                            .input_schema(self.agent.model.tool_input_format())
                            .ok()?;
                        Some(LanguageModelRequestTool {
                            name: tool.name(),
                            description: tool.description(),
                            input_schema,
                        })
                    })
                    .collect()
            }),
            ..Default::default()
        };
        let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
            buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
            let (edit_output, _) = self.agent.edit(
                buffer.clone(),
                eval.edit_description,
                &conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        } else {
            let (edit_output, _) = self.agent.overwrite(
                buffer.clone(),
                eval.edit_description,
                &conversation,
                &mut cx.to_async(),
            );
            edit_output.await?
        };

        let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
        let sample = EvalSample {
            edit_output,
            diff: language::unified_diff(
                eval.input_content.as_deref().unwrap_or_default(),
                &buffer_text,
            ),
            text: buffer_text,
        };
        let assertion = eval
            .assertion
            .run(&sample, self.judge_model.clone(), cx)
            .await?;

        Ok(EvalOutput { assertion, sample })
    }
}

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionOutcome {
    score: usize,
    message: Option<String>,
}

#[derive(Serialize)]
pub struct DiffJudgeTemplate {
    diff: String,
    assertions: &'static str,
}

impl Template for DiffJudgeTemplate {
    const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}

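/// Drops blank lines so `assert_eq` comparisons ignore vertical whitespace.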
fn strip_empty_lines(text: &str) -> String {
    text.lines()
        .filter(|line| !line.trim().is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}