diff --git a/Cargo.lock b/Cargo.lock index 3e2f12a91c2b76a393f7f99f68bcd05933cb27f1..03b7339856a9adba3538152ac3874fd0dec859b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -159,6 +159,7 @@ dependencies = [ "derive_more 0.99.20", "editor", "env_logger 0.11.8", + "eval_utils", "fs", "futures 0.3.31", "git", @@ -327,6 +328,7 @@ dependencies = [ "buffer_diff", "chrono", "client", + "clock", "cloud_llm_client", "collections", "command_palette_hooks", @@ -334,6 +336,7 @@ dependencies = [ "context_server", "db", "editor", + "eval_utils", "extension", "extension_host", "feature_flags", @@ -342,6 +345,7 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", + "gpui_tokio", "html_to_markdown", "http_client", "image", @@ -369,6 +373,7 @@ dependencies = [ "proto", "rand 0.9.2", "release_channel", + "reqwest_client", "rope", "rules_library", "schemars", @@ -5775,6 +5780,15 @@ dependencies = [ "watch", ] +[[package]] +name = "eval_utils" +version = "0.1.0" +dependencies = [ + "gpui", + "serde", + "smol", +] + [[package]] name = "event-listener" version = "2.5.3" diff --git a/Cargo.toml b/Cargo.toml index a6512c79093c197f5ed7a195f78bf7a170a15abe..e81e53426fc9ee47000e14cb8141ce4e4b6d8b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ members = [ "crates/zeta2_tools", "crates/editor", "crates/eval", + "crates/eval_utils", "crates/explorer_command_injector", "crates/extension", "crates/extension_api", @@ -288,6 +289,7 @@ deepseek = { path = "crates/deepseek" } derive_refineable = { path = "crates/refineable/derive_refineable" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } +eval_utils = { path = "crates/eval_utils" } extension = { path = "crates/extension" } extension_host = { path = "crates/extension_host" } extensions_ui = { path = "crates/extensions_ui" } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index cacbbd6e4e4423e2560fb963ef59daddce2309dc..667033a1bb33ea0372b8a9d8b0bfb00b23f59347 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -83,6 +83,7 @@ ctor.workspace = true db = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true +eval_utils.workspace = true fs = { workspace = true, "features" = ["test-support"] } git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index 81dce33d0394b5757be4934031f31b6f17233e9c..edf8a0f671d231b3bfbd29526c256388fd41f85a 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -4,7 +4,7 @@ use crate::{ }; use Role::*; use client::{Client, UserStore}; -use collections::HashMap; +use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind}; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; use gpui::{AppContext, TestAppContext, Timer}; @@ -20,16 +20,62 @@ use rand::prelude::*; use reqwest_client::ReqwestClient; use serde_json::json; use std::{ - cmp::Reverse, fmt::{self, Display}, - io::Write as _, path::Path, str::FromStr, - sync::mpsc, time::Duration, }; use util::path; +#[derive(Default, Clone, Debug)] +struct EditAgentOutputProcessor { + mismatched_tag_threshold: f32, + cumulative_tags: usize, + cumulative_mismatched_tags: usize, + eval_outputs: Vec>, +} + +fn mismatched_tag_threshold(mismatched_tag_threshold: f32) -> EditAgentOutputProcessor { + EditAgentOutputProcessor { + mismatched_tag_threshold, + cumulative_tags: 0, + cumulative_mismatched_tags: 0, + eval_outputs: Vec::new(), + } +} + +#[derive(Clone, Debug)] +struct EditEvalMetadata { + tags: usize, + mismatched_tags: usize, +} + +impl EvalOutputProcessor for EditAgentOutputProcessor { + type Metadata = EditEvalMetadata; + + fn process(&mut self, output: &EvalOutput) { + if matches!(output.outcome, OutcomeKind::Passed | OutcomeKind::Failed) { + self.cumulative_mismatched_tags += output.metadata.mismatched_tags; + self.cumulative_tags += output.metadata.tags; + self.eval_outputs.push(output.clone()); + } + } + + fn assert(&mut self) { + let mismatched_tag_ratio = + self.cumulative_mismatched_tags as f32 / self.cumulative_tags as f32; + if mismatched_tag_ratio > self.mismatched_tag_threshold { + for eval_output in &self.eval_outputs { + println!("{}", eval_output.data); + } + panic!( + "Too many mismatched tags: {:?}", + self.cumulative_mismatched_tags + ); + } + } +} + #[test] #[cfg_attr(not(feature = "unit-eval"), ignore)] fn eval_extract_handle_command_output() { @@ -55,22 +101,19 @@ fn eval_extract_handle_command_output() { include_str!("evals/fixtures/extract_handle_command_output/possible-07.diff"), ]; let edit_description = "Extract `handle_command_output` method from `run_git_blame`."; - eval( - 100, - 0.95, - 0.05, - EvalInput::from_conversation( + eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(formatdoc! {" - Read the `{input_file_path}` file and extract a method in - the final stanza of `run_git_blame` to deal with command failures, - call it `handle_command_output` and take the std::process::Output as the only parameter. - Do not document the method and do not add any comments. + Read the `{input_file_path}` file and extract a method in + the final stanza of `run_git_blame` to deal with command failures, + call it `handle_command_output` and take the std::process::Output as the only parameter. + Do not document the method and do not add any comments. - Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`. - "})], + Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`. + "})], ), message( Assistant, @@ -102,9 +145,9 @@ fn eval_extract_handle_command_output() { ), ], Some(input_file_content.into()), - EvalAssertion::assert_diff_any(possible_diffs), - ), - ); + EvalAssertion::assert_diff_any(possible_diffs.clone()), + )) + }); } #[test] @@ -122,18 +165,16 @@ fn eval_delete_run_git_blame() { let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs"); let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs"); let edit_description = "Delete the `run_git_blame` function."; - eval( - 100, - 0.95, - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(formatdoc! {" - Read the `{input_file_path}` file and delete `run_git_blame`. Just that - one function, not its usages. - "})], + Read the `{input_file_path}` file and delete `run_git_blame`. Just that + one function, not its usages. + "})], ), message( Assistant, @@ -166,8 +207,8 @@ fn eval_delete_run_git_blame() { ], Some(input_file_content.into()), EvalAssertion::assert_eq(output_file_content), - ), - ); + )) + }); } #[test] @@ -185,18 +226,16 @@ fn eval_translate_doc_comments() { let input_file_path = "root/canvas.rs"; let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs"); let edit_description = "Translate all doc comments to Italian"; - eval( - 200, - 1., - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(200, 1., mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(formatdoc! {" - Read the {input_file_path} file and edit it (without overwriting it), - translating all the doc comments to italian. - "})], + Read the {input_file_path} file and edit it (without overwriting it), + translating all the doc comments to italian. + "})], ), message( Assistant, @@ -229,8 +268,8 @@ fn eval_translate_doc_comments() { ], Some(input_file_content.into()), EvalAssertion::judge_diff("Doc comments were translated to Italian"), - ), - ); + )) + }); } #[test] @@ -249,33 +288,31 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { let input_file_content = include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs"); let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten"; - eval( - 100, - 0.95, - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(formatdoc! {" - Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten. - Use `ureq` to download the SDK for the current platform and architecture. - Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir. - Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows) - that's inside of the archive. - Don't re-download the SDK if that executable already exists. - - Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}} - - Here are the available wasi-sdk assets: - - wasi-sdk-25.0-x86_64-macos.tar.gz - - wasi-sdk-25.0-arm64-macos.tar.gz - - wasi-sdk-25.0-x86_64-linux.tar.gz - - wasi-sdk-25.0-arm64-linux.tar.gz - - wasi-sdk-25.0-x86_64-linux.tar.gz - - wasi-sdk-25.0-arm64-linux.tar.gz - - wasi-sdk-25.0-x86_64-windows.tar.gz - "})], + Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten. + Use `ureq` to download the SDK for the current platform and architecture. + Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir. + Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows) + that's inside of the archive. + Don't re-download the SDK if that executable already exists. + + Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}} + + Here are the available wasi-sdk assets: + - wasi-sdk-25.0-x86_64-macos.tar.gz + - wasi-sdk-25.0-arm64-macos.tar.gz + - wasi-sdk-25.0-x86_64-linux.tar.gz + - wasi-sdk-25.0-arm64-linux.tar.gz + - wasi-sdk-25.0-x86_64-linux.tar.gz + - wasi-sdk-25.0-arm64-linux.tar.gz + - wasi-sdk-25.0-x86_64-windows.tar.gz + "})], ), message( Assistant, @@ -352,11 +389,11 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { ], Some(input_file_content.into()), EvalAssertion::judge_diff(indoc! {" - - The compile_parser_to_wasm method has been changed to use wasi-sdk - - ureq is used to download the SDK for current platform and architecture - "}), - ), - ); + - The compile_parser_to_wasm method has been changed to use wasi-sdk + - ureq is used to download the SDK for current platform and architecture + "}), + )) + }); } #[test] @@ -380,11 +417,8 @@ fn eval_disable_cursor_blinking() { include_str!("evals/fixtures/disable_cursor_blinking/possible-03.diff"), include_str!("evals/fixtures/disable_cursor_blinking/possible-04.diff"), ]; - eval( - 100, - 0.51, - 0.05, - EvalInput::from_conversation( + eval_utils::eval(100, 0.51, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message(User, [text("Let's research how to cursor blinking works.")]), message( @@ -421,10 +455,10 @@ fn eval_disable_cursor_blinking() { message( User, [text(indoc! {" - Comment out the lines that interact with the BlinkManager. - Keep the outer `update` blocks, but comments everything that's inside (including if statements). - Don't add additional comments. - "})], + Comment out the lines that interact with the BlinkManager. + Keep the outer `update` blocks, but comments everything that's inside (including if statements). + Don't add additional comments. + "})], ), message( Assistant, @@ -440,9 +474,9 @@ fn eval_disable_cursor_blinking() { ), ], Some(input_file_content.into()), - EvalAssertion::assert_diff_any(possible_diffs), - ), - ); + EvalAssertion::assert_diff_any(possible_diffs.clone()), + )) + }); } #[test] @@ -467,20 +501,16 @@ fn eval_from_pixels_constructor() { let input_file_path = "root/canvas.rs"; let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs"); let edit_description = "Implement from_pixels constructor and add tests."; - eval( - 100, - 0.95, - // For whatever reason, this eval produces more mismatched tags. - // Increasing for now, let's see if we can bring this down. - 0.25, - EvalInput::from_conversation( + + eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.25), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(indoc! {" - Introduce a new `from_pixels` constructor in Canvas and - also add tests for it in the same file. - "})], + Introduce a new `from_pixels` constructor in Canvas and + also add tests for it in the same file. + "})], ), message( Assistant, @@ -545,92 +575,92 @@ fn eval_from_pixels_constructor() { "tool_4", "grep", indoc! {" - Found 6 matches: + Found 6 matches: - ## Matches in font-kit/src/loaders/core_text.rs + ## Matches in font-kit/src/loaders/core_text.rs - ### mod test › L926-936 - ``` - mod test { - use super::Font; - use crate::properties::{Stretch, Weight}; + ### mod test › L926-936 + ``` + mod test { + use super::Font; + use crate::properties::{Stretch, Weight}; - #[cfg(feature = \"source\")] - use crate::source::SystemSource; + #[cfg(feature = \"source\")] + use crate::source::SystemSource; - static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\"; + static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\"; - #[cfg(feature = \"source\")] - #[test] - ``` + #[cfg(feature = \"source\")] + #[test] + ``` - 55 lines remaining in ancestor node. Read the file to see all. + 55 lines remaining in ancestor node. Read the file to see all. - ### mod test › L947-951 - ``` - } + ### mod test › L947-951 + ``` + } - #[test] - fn test_core_text_to_css_font_weight() { - // Exact matches - ``` + #[test] + fn test_core_text_to_css_font_weight() { + // Exact matches + ``` - ### mod test › L959-963 - ``` - } + ### mod test › L959-963 + ``` + } - #[test] - fn test_core_text_to_css_font_stretch() { - // Exact matches - ``` + #[test] + fn test_core_text_to_css_font_stretch() { + // Exact matches + ``` - ## Matches in font-kit/src/loaders/freetype.rs + ## Matches in font-kit/src/loaders/freetype.rs - ### mod test › L1238-1248 - ``` - mod test { - use crate::loaders::freetype::Font; + ### mod test › L1238-1248 + ``` + mod test { + use crate::loaders::freetype::Font; - static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\"; - static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\"; + static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\"; + static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\"; - #[test] - fn get_pcf_postscript_name() { - let font = Font::from_path(PCF_FONT_PATH, 0).unwrap(); - assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME); - } - ``` + #[test] + fn get_pcf_postscript_name() { + let font = Font::from_path(PCF_FONT_PATH, 0).unwrap(); + assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME); + } + ``` - 1 lines remaining in ancestor node. Read the file to see all. + 1 lines remaining in ancestor node. Read the file to see all. - ## Matches in font-kit/src/sources/core_text.rs + ## Matches in font-kit/src/sources/core_text.rs - ### mod test › L265-275 - ``` - mod test { - use crate::properties::{Stretch, Weight}; + ### mod test › L265-275 + ``` + mod test { + use crate::properties::{Stretch, Weight}; - #[test] - fn test_css_to_core_text_font_weight() { - // Exact matches - assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7); - assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0); - assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4); - assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8); + #[test] + fn test_css_to_core_text_font_weight() { + // Exact matches + assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7); + assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0); + assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4); + assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8); - ``` + ``` - 27 lines remaining in ancestor node. Read the file to see all. + 27 lines remaining in ancestor node. Read the file to see all. - ### mod test › L278-282 - ``` - } + ### mod test › L278-282 + ``` + } - #[test] - fn test_css_to_core_text_font_stretch() { - // Exact matches - ``` - "}, + #[test] + fn test_css_to_core_text_font_stretch() { + // Exact matches + ``` + "}, )], ), message( @@ -648,11 +678,11 @@ fn eval_from_pixels_constructor() { ], Some(input_file_content.into()), EvalAssertion::judge_diff(indoc! {" - - The diff contains a new `from_pixels` constructor - - The diff contains new tests for the `from_pixels` constructor - "}), - ), - ); + - The diff contains a new `from_pixels` constructor + - The diff contains new tests for the `from_pixels` constructor + "}), + )) + }); } #[test] @@ -670,11 +700,9 @@ fn eval_zode() { let input_file_path = "root/zode.py"; let input_content = None; let edit_description = "Create the main Zode CLI script"; - eval( - 50, - 1., - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(50, 1., mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]), message( @@ -733,7 +761,7 @@ fn eval_zode() { ], ), ], - input_content, + input_content.clone(), EvalAssertion::new(async move |sample, _, _cx| { let invalid_starts = [' ', '`', '\n']; let mut message = String::new(); @@ -758,8 +786,8 @@ fn eval_zode() { }) } }), - ), - ); + )) + }); } #[test] @@ -777,19 +805,17 @@ fn eval_add_overwrite_test() { let input_file_path = "root/action_log.rs"; let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs"); let edit_description = "Add a new test for overwriting a file in action_log.rs"; - eval( - 200, - 0.5, // TODO: make this eval better - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(200, 0.5, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message( User, [text(indoc! {" - Introduce a new test in `action_log.rs` to test overwriting a file. - That is, a file already exists, but we call `buffer_created` as if the file were new. - Take inspiration from all the other tests in the file. - "})], + Introduce a new test in `action_log.rs` to test overwriting a file. + That is, a file already exists, but we call `buffer_created` as if the file were new. + Take inspiration from all the other tests in the file. + "})], ), message( Assistant, @@ -809,81 +835,81 @@ fn eval_add_overwrite_test() { "tool_1", "read_file", indoc! {" - pub struct ActionLog [L13-20] - tracked_buffers [L15] - edited_since_project_diagnostics_check [L17] - project [L19] - impl ActionLog [L22-498] - pub fn new [L24-30] - pub fn project [L32-34] - pub fn checked_project_diagnostics [L37-39] - pub fn has_edited_files_since_project_diagnostics_check [L42-44] - fn track_buffer_internal [L46-101] - fn handle_buffer_event [L103-116] - fn handle_buffer_edited [L118-123] - fn handle_buffer_file_changed [L125-158] - async fn maintain_diff [L160-264] - pub fn buffer_read [L267-269] - pub fn buffer_created [L272-276] - pub fn buffer_edited [L279-287] - pub fn will_delete_buffer [L289-304] - pub fn keep_edits_in_range [L306-364] - pub fn reject_edits_in_ranges [L366-459] - pub fn keep_all_edits [L461-473] - pub fn changed_buffers [L476-482] - pub fn stale_buffers [L485-497] - fn apply_non_conflicting_edits [L500-561] - fn diff_snapshots [L563-585] - fn point_to_row_edit [L587-614] - enum ChangeAuthor [L617-620] - User [L618] - Agent [L619] - enum TrackedBufferStatus [L623-627] - Created [L624] - Modified [L625] - Deleted [L626] - struct TrackedBuffer [L629-641] - buffer [L630] - base_text [L631] - unreviewed_changes [L632] - status [L633] - version [L634] - diff [L635] - snapshot [L636] - diff_update [L637] - _open_lsp_handle [L638] - _maintain_diff [L639] - _subscription [L640] - impl TrackedBuffer [L643-657] - fn has_changes [L644-650] - fn schedule_diff_update [L652-656] - pub struct ChangedBuffer [L659-661] - pub diff [L660] - mod tests [L664-1574] - fn init_logger [L678-682] - fn init_test [L684-691] - async fn test_keep_edits [L694-769] - async fn test_deletions [L772-854] - async fn test_overlapping_user_edits [L857-951] - async fn test_creating_files [L954-1010] - async fn test_deleting_files [L1013-1120] - async fn test_reject_edits [L1123-1255] - async fn test_reject_multiple_edits [L1258-1331] - async fn test_reject_deleted_file [L1334-1388] - async fn test_reject_created_file [L1391-1443] - async fn test_random_diffs [L1446-1535] - fn quiesce [L1510-1534] - struct HunkStatus [L1538-1542] - range [L1539] - diff_status [L1540] - old_text [L1541] - fn unreviewed_hunks [L1544-1573] - - Showing symbols 1-69 (total symbols: 69) - - Using the line numbers in this outline, you can call this tool again while specifying - the start_line and end_line fields to see the implementations of symbols in the outline. - "}, + pub struct ActionLog [L13-20] + tracked_buffers [L15] + edited_since_project_diagnostics_check [L17] + project [L19] + impl ActionLog [L22-498] + pub fn new [L24-30] + pub fn project [L32-34] + pub fn checked_project_diagnostics [L37-39] + pub fn has_edited_files_since_project_diagnostics_check [L42-44] + fn track_buffer_internal [L46-101] + fn handle_buffer_event [L103-116] + fn handle_buffer_edited [L118-123] + fn handle_buffer_file_changed [L125-158] + async fn maintain_diff [L160-264] + pub fn buffer_read [L267-269] + pub fn buffer_created [L272-276] + pub fn buffer_edited [L279-287] + pub fn will_delete_buffer [L289-304] + pub fn keep_edits_in_range [L306-364] + pub fn reject_edits_in_ranges [L366-459] + pub fn keep_all_edits [L461-473] + pub fn changed_buffers [L476-482] + pub fn stale_buffers [L485-497] + fn apply_non_conflicting_edits [L500-561] + fn diff_snapshots [L563-585] + fn point_to_row_edit [L587-614] + enum ChangeAuthor [L617-620] + User [L618] + Agent [L619] + enum TrackedBufferStatus [L623-627] + Created [L624] + Modified [L625] + Deleted [L626] + struct TrackedBuffer [L629-641] + buffer [L630] + base_text [L631] + unreviewed_changes [L632] + status [L633] + version [L634] + diff [L635] + snapshot [L636] + diff_update [L637] + _open_lsp_handle [L638] + _maintain_diff [L639] + _subscription [L640] + impl TrackedBuffer [L643-657] + fn has_changes [L644-650] + fn schedule_diff_update [L652-656] + pub struct ChangedBuffer [L659-661] + pub diff [L660] + mod tests [L664-1574] + fn init_logger [L678-682] + fn init_test [L684-691] + async fn test_keep_edits [L694-769] + async fn test_deletions [L772-854] + async fn test_overlapping_user_edits [L857-951] + async fn test_creating_files [L954-1010] + async fn test_deleting_files [L1013-1120] + async fn test_reject_edits [L1123-1255] + async fn test_reject_multiple_edits [L1258-1331] + async fn test_reject_deleted_file [L1334-1388] + async fn test_reject_created_file [L1391-1443] + async fn test_random_diffs [L1446-1535] + fn quiesce [L1510-1534] + struct HunkStatus [L1538-1542] + range [L1539] + diff_status [L1540] + old_text [L1541] + fn unreviewed_hunks [L1544-1573] + + Showing symbols 1-69 (total symbols: 69) + + Using the line numbers in this outline, you can call this tool again while specifying + the start_line and end_line fields to see the implementations of symbols in the outline. + "}, )], ), message( @@ -963,17 +989,17 @@ fn eval_add_overwrite_test() { Assistant, [ text(indoc! {" - Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll: + Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll: - 1. Create a file - 2. Open a buffer for it - 3. Track it in the action log - 4. Modify the buffer - 5. Call `buffer_created` again as if the file were new - 6. Check that the changes are properly tracked + 1. Create a file + 2. Open a buffer for it + 3. Track it in the action log + 4. Modify the buffer + 5. Call `buffer_created` again as if the file were new + 6. Check that the changes are properly tracked - Let's write the test: - "}), + Let's write the test: + "}), tool_use( "tool_5", "edit_file", @@ -990,8 +1016,8 @@ fn eval_add_overwrite_test() { EvalAssertion::judge_diff( "A new test for overwritten files was created, without changing any previous test", ), - ), - ); + )) + }); } #[test] @@ -1016,20 +1042,18 @@ fn eval_create_empty_file() { let input_file_content = None; let expected_output_content = String::new(); - eval( - 100, - 0.99, - 0.05, - EvalInput::from_conversation( + + eval_utils::eval(100, 0.99, mismatched_tag_threshold(0.05), move || { + run_eval(EvalInput::from_conversation( vec![ message(User, [text("Create a second empty todo file ")]), message( Assistant, [ text(formatdoc! {" - I'll help you create a second empty todo file. - First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one. - "}), + I'll help you create a second empty todo file. + First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one. + "}), tool_use( "toolu_01GAF8TtsgpjKxCr8fgQLDgR", "list_directory", @@ -1051,8 +1075,8 @@ fn eval_create_empty_file() { Assistant, [ text(formatdoc! {" - I can see there's already a `TODO` file in the `root` directory. Let me create a second empty todo file called `TODO3` in the same directory: - "}), + I can see there's already a `TODO` file in the `root` directory. Let me create a second empty todo file called `TODO3` in the same directory: + "}), tool_use( "toolu_01Tb3iQ9griqSYMmVuykQPWU", "edit_file", @@ -1065,12 +1089,12 @@ fn eval_create_empty_file() { ], ), ], - input_file_content, + input_file_content.clone(), // Bad behavior is to write something like // "I'll create an empty TODO3 file as requested." - EvalAssertion::assert_eq(expected_output_content), - ), - ); + EvalAssertion::assert_eq(expected_output_content.clone()), + )) + }); } fn message( @@ -1312,115 +1336,44 @@ impl EvalAssertion { } } -fn eval( - iterations: usize, - expected_pass_ratio: f32, - mismatched_tag_threshold: f32, - mut eval: EvalInput, -) { - let mut evaluated_count = 0; - let mut failed_count = 0; - report_progress(evaluated_count, failed_count, iterations); - - let (tx, rx) = mpsc::channel(); - - // Cache the last message in the conversation, and run one instance of the eval so that - // all the next ones are cached. - eval.conversation.last_mut().unwrap().cache = true; - run_eval(eval.clone(), tx.clone()); - - let executor = gpui::background_executor(); - let semaphore = Arc::new(smol::lock::Semaphore::new(32)); - for _ in 1..iterations { - let eval = eval.clone(); - let tx = tx.clone(); - let semaphore = semaphore.clone(); - executor - .spawn(async move { - let _guard = semaphore.acquire().await; - run_eval(eval, tx) - }) - .detach(); - } - drop(tx); - - let mut failed_evals = HashMap::default(); - let mut errored_evals = HashMap::default(); - let mut eval_outputs = Vec::new(); - let mut cumulative_parser_metrics = EditParserMetrics::default(); - while let Ok(output) = rx.recv() { - match output { - Ok(output) => { - cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone(); - eval_outputs.push(output.clone()); - if output.assertion.score < 80 { - failed_count += 1; - failed_evals - .entry(output.sample.text_after.clone()) - .or_insert(Vec::new()) - .push(output); - } - } - Err(error) => { - failed_count += 1; - *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1; - } - } - - evaluated_count += 1; - report_progress(evaluated_count, failed_count, iterations); - } - - let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32; - println!("Actual pass ratio: {}\n", actual_pass_ratio); - if actual_pass_ratio < expected_pass_ratio { - let mut errored_evals = errored_evals.into_iter().collect::>(); - errored_evals.sort_by_key(|(_, count)| Reverse(*count)); - for (error, count) in errored_evals { - println!("Eval errored {} times. Error: {}", count, error); - } - - let mut failed_evals = failed_evals.into_iter().collect::>(); - failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len())); - for (_buffer_output, failed_evals) in failed_evals { - let eval_output = failed_evals.first().unwrap(); - println!("Eval failed {} times", failed_evals.len()); - println!("{}", eval_output); - } - - panic!( - "Actual pass ratio: {}\nExpected pass ratio: {}", - actual_pass_ratio, expected_pass_ratio - ); - } - - let mismatched_tag_ratio = - cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32; - if mismatched_tag_ratio > mismatched_tag_threshold { - for eval_output in eval_outputs { - println!("{}", eval_output); - } - panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics); - } -} - -fn run_eval(eval: EvalInput, tx: mpsc::Sender>) { +fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput { let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng()); let mut cx = TestAppContext::build(dispatcher, None); - let output = cx.executor().block_test(async { + let result = cx.executor().block_test(async { let test = EditAgentTest::new(&mut cx).await; test.eval(eval, &mut cx).await }); - tx.send(output).unwrap(); + match result { + Ok(output) => eval_utils::EvalOutput { + data: output.to_string(), + outcome: if output.assertion.score < 80 { + eval_utils::OutcomeKind::Failed + } else { + eval_utils::OutcomeKind::Passed + }, + metadata: EditEvalMetadata { + tags: output.sample.edit_output.parser_metrics.tags, + mismatched_tags: output.sample.edit_output.parser_metrics.mismatched_tags, + }, + }, + Err(e) => eval_utils::EvalOutput { + data: format!("{e:?}"), + outcome: eval_utils::OutcomeKind::Error, + metadata: EditEvalMetadata { + tags: 0, + mismatched_tags: 0, + }, + }, + } } #[derive(Clone)] -struct EvalOutput { +struct EditEvalOutput { sample: EvalSample, assertion: EvalAssertionOutcome, } -impl Display for EvalOutput { +impl Display for EditEvalOutput { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "Score: {:?}", self.assertion.score)?; if let Some(message) = self.assertion.message.as_ref() { @@ -1439,22 +1392,6 @@ impl Display for EvalOutput { } } -fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) { - let passed_count = evaluated_count - failed_count; - let passed_ratio = if evaluated_count == 0 { - 0.0 - } else { - passed_count as f64 / evaluated_count as f64 - }; - print!( - "\r\x1b[KEvaluated {}/{} ({:.2}% passed)", - evaluated_count, - iterations, - passed_ratio * 100.0 - ); - std::io::stdout().flush().unwrap(); -} - struct EditAgentTest { agent: EditAgent, project: Entity, @@ -1550,7 +1487,10 @@ impl EditAgentTest { }) } - async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result { + async fn eval(&self, mut eval: EvalInput, cx: &mut TestAppContext) -> Result { + // Make sure the last message in the conversation is cached. + eval.conversation.last_mut().unwrap().cache = true; + let path = self .project .read_with(cx, |project, cx| { @@ -1656,7 +1596,7 @@ impl EditAgentTest { .run(&sample, self.judge_model.clone(), cx) .await?; - Ok(EvalOutput { assertion, sample }) + Ok(EditEvalOutput { assertion, sample }) } } diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 0f52c07078f447c9d8a95312ccd96561516907a1..048ffab9b72bdecce3754320bf34f1702f021554 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -13,7 +13,8 @@ path = "src/agent_ui.rs" doctest = false [features] -test-support = ["gpui/test-support", "language/test-support"] +test-support = ["gpui/test-support", "language/test-support", "reqwest_client"] +unit-eval = [] [dependencies] acp_thread.workspace = true @@ -47,6 +48,7 @@ fs.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true +gpui_tokio.workspace = true html_to_markdown.workspace = true http_client.workspace = true indoc.workspace = true @@ -98,14 +100,17 @@ workspace.workspace = true zed_actions.workspace = true image.workspace = true async-fs.workspace = true +reqwest_client = { workspace = true, optional = true } [dev-dependencies] acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } assistant_text_thread = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } +clock.workspace = true db = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } +eval_utils.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true language = { workspace = true, "features" = ["test-support"] } @@ -115,5 +120,6 @@ pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } semver.workspace = true rand.workspace = true +reqwest_client.workspace = true tree-sitter-md.workspace = true unindent.workspace = true diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 9dd77774ff4e6f00bdfd26d024e9ee4b389b7f7e..18e8f1e731defa82e865dd45e66389634992037c 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -2685,16 +2685,17 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { return; }; let project = workspace.read(cx).project().downgrade(); + let thread_store = panel.read(cx).thread_store().clone(); assistant.assist( prompt_editor, self.workspace.clone(), project, - panel.read(cx).thread_store().clone(), + thread_store, None, initial_prompt, window, cx, - ) + ); }) } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 5f5682b7dcc90d2b779744ba353380987a5907a1..f7b07b7bd393b8d3efffc3757eaf6025d5c651cd 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -7,6 +7,8 @@ mod buffer_codegen; mod completion_provider; mod context; mod context_server_configuration; +#[cfg(test)] +mod evals; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 1ac3ec1aec38c8d44d7557e1cf1e3ff09832c9d9..972ead664464876e57d7830b18db3f2b0c49629c 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -719,6 +719,7 @@ impl CodegenAlternative { output_tokens = usage.output_tokens, ) } + cx.emit(CodegenEvent::Finished); cx.notify(); }) diff --git a/crates/agent_ui/src/evals.rs b/crates/agent_ui/src/evals.rs new file mode 100644 index 0000000000000000000000000000000000000000..e82d21bd1fdb02a666c61bdf4754f27e79f92fda --- /dev/null +++ b/crates/agent_ui/src/evals.rs @@ -0,0 +1,89 @@ +use std::str::FromStr; + +use crate::inline_assistant::test::run_inline_assistant_test; + +use eval_utils::{EvalOutput, NoProcessor}; +use gpui::TestAppContext; +use language_model::{LanguageModelRegistry, SelectedModel}; +use rand::{SeedableRng as _, rngs::StdRng}; + +#[test] +#[cfg_attr(not(feature = "unit-eval"), ignore)] +fn eval_single_cursor_edit() { + eval_utils::eval(20, 1.0, NoProcessor, move || { + run_eval( + &EvalInput { + prompt: "Rename this variable to buffer_text".to_string(), + buffer: indoc::indoc! {" + struct EvalExampleStruct { + text: Strˇing, + prompt: String, + } + "} + .to_string(), + }, + &|_, output| { + let expected = indoc::indoc! {" + struct EvalExampleStruct { + buffer_text: String, + prompt: String, + } + "}; + if output == expected { + EvalOutput { + outcome: eval_utils::OutcomeKind::Passed, + data: "Passed!".to_string(), + metadata: (), + } + } else { + EvalOutput { + outcome: eval_utils::OutcomeKind::Failed, + data: format!("Failed to rename variable, output: {}", output), + metadata: (), + } + } + }, + ) + }); +} + +struct EvalInput { + buffer: String, + prompt: String, +} + +fn run_eval( + input: &EvalInput, + judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>, +) -> eval_utils::EvalOutput<()> { + let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng()); + let mut cx = TestAppContext::build(dispatcher, None); + cx.skip_drawing(); + + let buffer_text = run_inline_assistant_test( + input.buffer.clone(), + input.prompt.clone(), + |cx| { + // Reconfigure to use a real model instead of the fake one + let model_name = std::env::var("ZED_AGENT_MODEL") + .unwrap_or("anthropic/claude-sonnet-4-latest".into()); + + let selected_model = SelectedModel::from_str(&model_name) + .expect("Invalid model format. Use 'provider/model-id'"); + + log::info!("Selected model: {selected_model:?}"); + + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.select_inline_assistant_model(Some(&selected_model), cx); + }); + }); + }, + |_cx| { + log::info!("Waiting for actual response from the LLM..."); + }, + &mut cx, + ); + + judge(input, &buffer_text) +} diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 3f27d0985991f19148cc852c44bfa60c57eaf750..cbc5891036fdf03ee04cca6b77820748faed2d0a 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -32,7 +32,7 @@ use editor::{ }, }; use fs::Fs; -use futures::FutureExt; +use futures::{FutureExt, channel::mpsc}; use gpui::{ App, Context, Entity, Focusable, Global, HighlightStyle, Subscription, Task, UpdateGlobal, WeakEntity, Window, point, @@ -102,6 +102,7 @@ pub struct InlineAssistant { prompt_builder: Arc, telemetry: Arc, fs: Arc, + _inline_assistant_completions: Option>>, } impl Global for InlineAssistant {} @@ -123,9 +124,18 @@ impl InlineAssistant { prompt_builder, telemetry, fs, + _inline_assistant_completions: None, } } + #[cfg(any(test, feature = "test-support"))] + pub fn set_completion_receiver( + &mut self, + sender: mpsc::UnboundedSender>, + ) { + self._inline_assistant_completions = Some(sender); + } + pub fn register_workspace( &mut self, workspace: &Entity, @@ -287,7 +297,7 @@ impl InlineAssistant { action.prompt.clone(), window, cx, - ) + ); }) } InlineAssistTarget::Terminal(active_terminal) => { @@ -301,8 +311,8 @@ impl InlineAssistant { action.prompt.clone(), window, cx, - ) - }) + ); + }); } }; @@ -598,13 +608,13 @@ impl InlineAssistant { initial_prompt: Option, window: &mut Window, cx: &mut App, - ) { + ) -> Option { let snapshot = editor.update(cx, |editor, cx| editor.snapshot(window, cx)); let Some((codegen_ranges, newest_selection)) = self.codegen_ranges(editor, &snapshot, window, cx) else { - return; + return None; }; let assist_to_focus = self.batch_assist( @@ -624,6 +634,8 @@ impl InlineAssistant { if let Some(assist_id) = assist_to_focus { self.focus_assist(assist_id, window, cx); } + + assist_to_focus } pub fn suggest_assist( @@ -1740,6 +1752,16 @@ impl InlineAssist { && assist.decorations.is_none() && let Some(workspace) = assist.workspace.upgrade() { + #[cfg(any(test, feature = "test-support"))] + if let Some(sender) = &mut this._inline_assistant_completions { + sender + .unbounded_send(Err(anyhow::anyhow!( + "Inline assistant error: {}", + error + ))) + .ok(); + } + let error = format!("Inline assistant error: {}", error); workspace.update(cx, |workspace, cx| { struct InlineAssistantError; @@ -1750,6 +1772,11 @@ impl InlineAssist { workspace.show_toast(Toast::new(id, error), cx); }) + } else { + #[cfg(any(test, feature = "test-support"))] + if let Some(sender) = &mut this._inline_assistant_completions { + sender.unbounded_send(Ok(assist_id)).ok(); + } } if assist.decorations.is_none() { @@ -1943,3 +1970,160 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } } + +#[cfg(any(test, feature = "test-support"))] +pub mod test { + use std::sync::Arc; + + use agent::HistoryStore; + use assistant_text_thread::TextThreadStore; + use client::{Client, UserStore}; + use editor::{Editor, MultiBuffer, MultiBufferOffset}; + use fs::FakeFs; + use futures::channel::mpsc; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; + use language::Buffer; + use language_model::LanguageModelRegistry; + use project::Project; + use prompt_store::PromptBuilder; + use smol::stream::StreamExt as _; + use util::test::marked_text_ranges; + use workspace::Workspace; + + use crate::InlineAssistant; + + pub fn run_inline_assistant_test( + base_buffer: String, + prompt: String, + setup: SetupF, + test: TestF, + cx: &mut TestAppContext, + ) -> String + where + SetupF: FnOnce(&mut gpui::VisualTestContext), + TestF: FnOnce(&mut gpui::VisualTestContext), + { + let fs = FakeFs::new(cx.executor()); + let app_state = cx.update(|cx| workspace::AppState::test(cx)); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let http = Arc::new(reqwest_client::ReqwestClient::user_agent("agent tests").unwrap()); + let client = cx.update(|cx| { + cx.set_http_client(http); + Client::production(cx) + }); + let mut inline_assistant = + InlineAssistant::new(fs.clone(), prompt_builder, client.telemetry().clone()); + + let (tx, mut completion_rx) = mpsc::unbounded(); + inline_assistant.set_completion_receiver(tx); + + // Initialize settings and client + cx.update(|cx| { + gpui_tokio::init(cx); + settings::init(cx); + client::init(&client, cx); + workspace::init(app_state.clone(), cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store, client.clone(), cx); + + cx.set_global(inline_assistant); + }); + + let project = cx + .executor() + .block_test(async { Project::test(fs.clone(), [], cx).await }); + + // Create workspace with window + let (workspace, cx) = cx.add_window_view(|window, cx| { + window.activate_window(); + Workspace::new(None, project.clone(), app_state.clone(), window, cx) + }); + + setup(cx); + + let (_editor, buffer) = cx.update(|window, cx| { + let buffer = cx.new(|cx| Buffer::local("", cx)); + let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx)); + let editor = cx.new(|cx| Editor::for_multibuffer(multibuffer, None, window, cx)); + editor.update(cx, |editor, cx| { + let (unmarked_text, selection_ranges) = marked_text_ranges(&base_buffer, true); + editor.set_text(unmarked_text, window, cx); + editor.change_selections(Default::default(), window, cx, |s| { + s.select_ranges( + selection_ranges.into_iter().map(|range| { + MultiBufferOffset(range.start)..MultiBufferOffset(range.end) + }), + ) + }) + }); + + let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx)); + + // Add editor to workspace + workspace.update(cx, |workspace, cx| { + workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, window, cx); + }); + + // Call assist method + InlineAssistant::update_global(cx, |inline_assistant, cx| { + let assist_id = inline_assistant + .assist( + &editor, + workspace.downgrade(), + project.downgrade(), + history_store, // thread_store + None, // prompt_store + Some(prompt), + window, + cx, + ) + .unwrap(); + + inline_assistant.start_assist(assist_id, window, cx); + }); + + (editor, buffer) + }); + + cx.run_until_parked(); + + test(cx); + + cx.executor() + .block_test(async { completion_rx.next().await }); + + buffer.read_with(cx, |buffer, _| buffer.text()) + } + + #[allow(unused)] + pub fn test_inline_assistant( + base_buffer: &'static str, + llm_output: &'static str, + cx: &mut TestAppContext, + ) -> String { + run_inline_assistant_test( + base_buffer.to_string(), + "Prompt doesn't matter because we're using a fake model".to_string(), + |cx| { + cx.update(|_, cx| LanguageModelRegistry::test(cx)); + }, + |cx| { + let fake_model = cx.update(|_, cx| { + LanguageModelRegistry::global(cx) + .update(cx, |registry, _| registry.fake_model()) + }); + let fake = fake_model.as_fake(); + + // let fake = fake_model; + fake.send_last_completion_stream_text_chunk(llm_output.to_string()); + fake.end_last_completion_stream(); + + // Run again to process the model's response + cx.run_until_parked(); + }, + cx, + ) + } +} diff --git a/crates/eval_utils/Cargo.toml b/crates/eval_utils/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a512035f5d1754f0f6f942faa27d063e169a22ef --- /dev/null +++ b/crates/eval_utils/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "eval_utils" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/eval_utils.rs" +doctest = false + +[dependencies] +gpui.workspace = true +serde.workspace = true +smol.workspace = true diff --git a/crates/eval_utils/LICENSE-GPL b/crates/eval_utils/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..e0f9dbd5d63fef1630c297edc4ceba4790be6f02 --- /dev/null +++ b/crates/eval_utils/LICENSE-GPL @@ -0,0 +1 @@ +LICENSE-GPL \ No newline at end of file diff --git a/crates/eval_utils/README.md b/crates/eval_utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..617077a81524ff918e8b9b93aa970d636504479c --- /dev/null +++ b/crates/eval_utils/README.md @@ -0,0 +1,3 @@ +# eval_utils + +Utilities for evals of agents. diff --git a/crates/eval_utils/src/eval_utils.rs b/crates/eval_utils/src/eval_utils.rs new file mode 100644 index 0000000000000000000000000000000000000000..880b1a97e414bbc3219bdf8f7163dbf9b6c9c82b --- /dev/null +++ b/crates/eval_utils/src/eval_utils.rs @@ -0,0 +1,128 @@ +//! Utilities for evaluation and benchmarking. + +use std::{ + collections::HashMap, + sync::{Arc, mpsc}, +}; + +fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) { + let passed_count = evaluated_count - failed_count; + let passed_ratio = if evaluated_count == 0 { + 0.0 + } else { + passed_count as f64 / evaluated_count as f64 + }; + println!( + "\r\x1b[KEvaluated {}/{} ({:.2}% passed)", + evaluated_count, + iterations, + passed_ratio * 100.0 + ) +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum OutcomeKind { + Passed, + Failed, + Error, +} + +pub trait EvalOutputProcessor { + type Metadata: 'static + Send; + fn process(&mut self, output: &EvalOutput); + fn assert(&mut self); +} + +#[derive(Clone, Debug)] +pub struct EvalOutput { + pub outcome: OutcomeKind, + pub data: String, + pub metadata: M, +} + +pub struct NoProcessor; +impl EvalOutputProcessor for NoProcessor { + type Metadata = (); + + fn process(&mut self, _output: &EvalOutput) {} + + fn assert(&mut self) {} +} + +pub fn eval

( + iterations: usize, + expected_pass_ratio: f32, + mut processor: P, + evalf: impl Fn() -> EvalOutput + Send + Sync + 'static, +) where + P: EvalOutputProcessor, +{ + let mut evaluated_count = 0; + let mut failed_count = 0; + let evalf = Arc::new(evalf); + report_progress(evaluated_count, failed_count, iterations); + + let (tx, rx) = mpsc::channel(); + + let executor = gpui::background_executor(); + let semaphore = Arc::new(smol::lock::Semaphore::new(32)); + let evalf = Arc::new(evalf); + // Warm the cache once + let first_output = evalf(); + tx.send(first_output).ok(); + + for _ in 1..iterations { + let tx = tx.clone(); + let semaphore = semaphore.clone(); + let evalf = evalf.clone(); + executor + .spawn(async move { + let _guard = semaphore.acquire().await; + let output = evalf(); + tx.send(output).ok(); + }) + .detach(); + } + drop(tx); + + let mut failed_evals = Vec::new(); + let mut errored_evals = HashMap::new(); + while let Ok(output) = rx.recv() { + processor.process(&output); + + match output.outcome { + OutcomeKind::Passed => {} + OutcomeKind::Failed => { + failed_count += 1; + failed_evals.push(output); + } + OutcomeKind::Error => { + failed_count += 1; + *errored_evals.entry(output.data).or_insert(0) += 1; + } + } + + evaluated_count += 1; + report_progress(evaluated_count, failed_count, iterations); + } + + let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32; + println!("Actual pass ratio: {}\n", actual_pass_ratio); + if actual_pass_ratio < expected_pass_ratio { + for (error, count) in errored_evals { + println!("Eval errored {} times. Error: {}", count, error); + } + + for failed in failed_evals { + println!("Eval failed"); + println!("{}", failed.data); + } + + panic!( + "Actual pass ratio: {}\nExpected pass ratio: {}", + actual_pass_ratio, expected_pass_ratio + ); + } + + processor.assert(); +} diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index c042d85a1239dc6723b6501b27690a9f593a021b..2f4c7611dcf9d24302b3dda1d05c4c2b8711a68d 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -551,12 +551,39 @@ impl SystemWindowTabController { } } +pub(crate) enum GpuiMode { + #[cfg(any(test, feature = "test-support"))] + Test { + skip_drawing: bool, + }, + Production, +} + +impl GpuiMode { + #[cfg(any(test, feature = "test-support"))] + pub fn test() -> Self { + GpuiMode::Test { + skip_drawing: false, + } + } + + #[inline] + pub(crate) fn skip_drawing(&self) -> bool { + match self { + #[cfg(any(test, feature = "test-support"))] + GpuiMode::Test { skip_drawing } => *skip_drawing, + GpuiMode::Production => false, + } + } +} + /// Contains the state of the full application, and passed as a reference to a variety of callbacks. /// Other [Context] derefs to this type. /// You need a reference to an `App` to access the state of a [Entity]. pub struct App { pub(crate) this: Weak, pub(crate) platform: Rc, + pub(crate) mode: GpuiMode, text_system: Arc, flushing_effects: bool, pending_updates: usize, @@ -635,6 +662,7 @@ impl App { this: this.clone(), platform: platform.clone(), text_system, + mode: GpuiMode::Production, actions: Rc::new(ActionRegistry::default()), flushing_effects: false, pending_updates: 0, diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index 4a7b73c359ed3dd55b136b22e9487dee1735e42e..5be2e394e8edfd26a25c70c79c321a7fb8fdc8ba 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -5,7 +5,7 @@ use crate::{ ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Platform, Point, Render, Result, Size, Task, TestDispatcher, TestPlatform, TestScreenCaptureSource, TestWindow, TextSystem, VisualContext, Window, WindowBounds, - WindowHandle, WindowOptions, + WindowHandle, WindowOptions, app::GpuiMode, }; use anyhow::{anyhow, bail}; use futures::{Stream, StreamExt, channel::oneshot}; @@ -132,8 +132,11 @@ impl TestAppContext { let http_client = http_client::FakeHttpClient::with_404_response(); let text_system = Arc::new(TextSystem::new(platform.text_system())); + let mut app = App::new_app(platform.clone(), asset_source, http_client); + app.borrow_mut().mode = GpuiMode::test(); + Self { - app: App::new_app(platform.clone(), asset_source, http_client), + app, background_executor, foreground_executor, dispatcher, @@ -144,6 +147,11 @@ impl TestAppContext { } } + /// Skip all drawing operations for the duration of this test. + pub fn skip_drawing(&mut self) { + self.app.borrow_mut().mode = GpuiMode::Test { skip_drawing: true }; + } + /// Create a single TestAppContext, for non-multi-client tests pub fn single() -> Self { let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); diff --git a/crates/gpui/src/window.rs b/crates/gpui/src/window.rs index dabf7cf2b42cf57becb996e1f9360aaba0b6eead..2d525adb8f82a96c24ee3f524030782a7de3577c 100644 --- a/crates/gpui/src/window.rs +++ b/crates/gpui/src/window.rs @@ -2006,7 +2006,9 @@ impl Window { if let Some(input_handler) = self.platform_window.take_input_handler() { self.rendered_frame.input_handlers.push(Some(input_handler)); } - self.draw_roots(cx); + if !cx.mode.skip_drawing() { + self.draw_roots(cx); + } self.dirty_views.clear(); self.next_frame.window_active = self.active.get(); diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 98c67f4e27a8e8b20489cc3c4ad4a1207e8b848f..f357e01da062398d18134df6625d30b8129bf875 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -408,6 +408,7 @@ impl FakeHttpClient { } pub fn with_404_response() -> Arc { + log::warn!("Using fake HTTP client with 404 response"); Self::create(|_| async move { Ok(Response::builder() .status(404) @@ -417,6 +418,7 @@ impl FakeHttpClient { } pub fn with_200_response() -> Arc { + log::warn!("Using fake HTTP client with 200 response"); Self::create(|_| async move { Ok(Response::builder() .status(200) diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 6ed8bf07c4e976c88fecebd929843335333b1fa6..27b8309810962981d3c0ec78e6e67dfdfba122bf 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -135,6 +135,11 @@ impl LanguageModelRegistry { fake_provider } + #[cfg(any(test, feature = "test-support"))] + pub fn fake_model(&self) -> Arc { + self.default_model.as_ref().unwrap().model.clone() + } + pub fn register_provider( &mut self, provider: Arc,