Detailed changes
@@ -159,6 +159,7 @@ dependencies = [
"derive_more 0.99.20",
"editor",
"env_logger 0.11.8",
+ "eval_utils",
"fs",
"futures 0.3.31",
"git",
@@ -327,6 +328,7 @@ dependencies = [
"buffer_diff",
"chrono",
"client",
+ "clock",
"cloud_llm_client",
"collections",
"command_palette_hooks",
@@ -334,6 +336,7 @@ dependencies = [
"context_server",
"db",
"editor",
+ "eval_utils",
"extension",
"extension_host",
"feature_flags",
@@ -342,6 +345,7 @@ dependencies = [
"futures 0.3.31",
"fuzzy",
"gpui",
+ "gpui_tokio",
"html_to_markdown",
"http_client",
"image",
@@ -369,6 +373,7 @@ dependencies = [
"proto",
"rand 0.9.2",
"release_channel",
+ "reqwest_client",
"rope",
"rules_library",
"schemars",
@@ -5775,6 +5780,15 @@ dependencies = [
"watch",
]
+[[package]]
+name = "eval_utils"
+version = "0.1.0"
+dependencies = [
+ "gpui",
+ "serde",
+ "smol",
+]
+
[[package]]
name = "event-listener"
version = "2.5.3"
@@ -59,6 +59,7 @@ members = [
"crates/zeta2_tools",
"crates/editor",
"crates/eval",
+ "crates/eval_utils",
"crates/explorer_command_injector",
"crates/extension",
"crates/extension_api",
@@ -288,6 +289,7 @@ deepseek = { path = "crates/deepseek" }
derive_refineable = { path = "crates/refineable/derive_refineable" }
diagnostics = { path = "crates/diagnostics" }
editor = { path = "crates/editor" }
+eval_utils = { path = "crates/eval_utils" }
extension = { path = "crates/extension" }
extension_host = { path = "crates/extension_host" }
extensions_ui = { path = "crates/extensions_ui" }
@@ -83,6 +83,7 @@ ctor.workspace = true
db = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
+eval_utils.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
@@ -4,7 +4,7 @@ use crate::{
};
use Role::*;
use client::{Client, UserStore};
-use collections::HashMap;
+use eval_utils::{EvalOutput, EvalOutputProcessor, OutcomeKind};
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext, Timer};
@@ -20,16 +20,62 @@ use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
- cmp::Reverse,
fmt::{self, Display},
- io::Write as _,
path::Path,
str::FromStr,
- sync::mpsc,
time::Duration,
};
use util::path;
+#[derive(Default, Clone, Debug)]
+struct EditAgentOutputProcessor {
+ mismatched_tag_threshold: f32,
+ cumulative_tags: usize,
+ cumulative_mismatched_tags: usize,
+ eval_outputs: Vec<EvalOutput<EditEvalMetadata>>,
+}
+
+fn mismatched_tag_threshold(mismatched_tag_threshold: f32) -> EditAgentOutputProcessor {
+ EditAgentOutputProcessor {
+ mismatched_tag_threshold,
+ cumulative_tags: 0,
+ cumulative_mismatched_tags: 0,
+ eval_outputs: Vec::new(),
+ }
+}
+
+#[derive(Clone, Debug)]
+struct EditEvalMetadata {
+ tags: usize,
+ mismatched_tags: usize,
+}
+
+impl EvalOutputProcessor for EditAgentOutputProcessor {
+ type Metadata = EditEvalMetadata;
+
+ fn process(&mut self, output: &EvalOutput<Self::Metadata>) {
+ if matches!(output.outcome, OutcomeKind::Passed | OutcomeKind::Failed) {
+ self.cumulative_mismatched_tags += output.metadata.mismatched_tags;
+ self.cumulative_tags += output.metadata.tags;
+ self.eval_outputs.push(output.clone());
+ }
+ }
+
+ fn assert(&mut self) {
+ let mismatched_tag_ratio =
+ self.cumulative_mismatched_tags as f32 / self.cumulative_tags as f32;
+ if mismatched_tag_ratio > self.mismatched_tag_threshold {
+ for eval_output in &self.eval_outputs {
+ println!("{}", eval_output.data);
+ }
+ panic!(
+ "Too many mismatched tags: {:?}",
+ self.cumulative_mismatched_tags
+ );
+ }
+ }
+}
+
#[test]
#[cfg_attr(not(feature = "unit-eval"), ignore)]
fn eval_extract_handle_command_output() {
@@ -55,22 +101,19 @@ fn eval_extract_handle_command_output() {
include_str!("evals/fixtures/extract_handle_command_output/possible-07.diff"),
];
let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
- eval(
- 100,
- 0.95,
- 0.05,
- EvalInput::from_conversation(
+ eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(formatdoc! {"
- Read the `{input_file_path}` file and extract a method in
- the final stanza of `run_git_blame` to deal with command failures,
- call it `handle_command_output` and take the std::process::Output as the only parameter.
- Do not document the method and do not add any comments.
+ Read the `{input_file_path}` file and extract a method in
+ the final stanza of `run_git_blame` to deal with command failures,
+ call it `handle_command_output` and take the std::process::Output as the only parameter.
+ Do not document the method and do not add any comments.
- Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
- "})],
+ Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
+ "})],
),
message(
Assistant,
@@ -102,9 +145,9 @@ fn eval_extract_handle_command_output() {
),
],
Some(input_file_content.into()),
- EvalAssertion::assert_diff_any(possible_diffs),
- ),
- );
+ EvalAssertion::assert_diff_any(possible_diffs.clone()),
+ ))
+ });
}
#[test]
@@ -122,18 +165,16 @@ fn eval_delete_run_git_blame() {
let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
let edit_description = "Delete the `run_git_blame` function.";
- eval(
- 100,
- 0.95,
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(formatdoc! {"
- Read the `{input_file_path}` file and delete `run_git_blame`. Just that
- one function, not its usages.
- "})],
+ Read the `{input_file_path}` file and delete `run_git_blame`. Just that
+ one function, not its usages.
+ "})],
),
message(
Assistant,
@@ -166,8 +207,8 @@ fn eval_delete_run_git_blame() {
],
Some(input_file_content.into()),
EvalAssertion::assert_eq(output_file_content),
- ),
- );
+ ))
+ });
}
#[test]
@@ -185,18 +226,16 @@ fn eval_translate_doc_comments() {
let input_file_path = "root/canvas.rs";
let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
let edit_description = "Translate all doc comments to Italian";
- eval(
- 200,
- 1.,
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(200, 1., mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(formatdoc! {"
- Read the {input_file_path} file and edit it (without overwriting it),
- translating all the doc comments to italian.
- "})],
+ Read the {input_file_path} file and edit it (without overwriting it),
+ translating all the doc comments to italian.
+ "})],
),
message(
Assistant,
@@ -229,8 +268,8 @@ fn eval_translate_doc_comments() {
],
Some(input_file_content.into()),
EvalAssertion::judge_diff("Doc comments were translated to Italian"),
- ),
- );
+ ))
+ });
}
#[test]
@@ -249,33 +288,31 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
let input_file_content =
include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
- eval(
- 100,
- 0.95,
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(formatdoc! {"
- Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
- Use `ureq` to download the SDK for the current platform and architecture.
- Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
- Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
- that's inside of the archive.
- Don't re-download the SDK if that executable already exists.
-
- Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
-
- Here are the available wasi-sdk assets:
- - wasi-sdk-25.0-x86_64-macos.tar.gz
- - wasi-sdk-25.0-arm64-macos.tar.gz
- - wasi-sdk-25.0-x86_64-linux.tar.gz
- - wasi-sdk-25.0-arm64-linux.tar.gz
- - wasi-sdk-25.0-x86_64-linux.tar.gz
- - wasi-sdk-25.0-arm64-linux.tar.gz
- - wasi-sdk-25.0-x86_64-windows.tar.gz
- "})],
+ Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
+ Use `ureq` to download the SDK for the current platform and architecture.
+ Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
+ Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
+ that's inside of the archive.
+ Don't re-download the SDK if that executable already exists.
+
+ Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
+
+ Here are the available wasi-sdk assets:
+ - wasi-sdk-25.0-x86_64-macos.tar.gz
+ - wasi-sdk-25.0-arm64-macos.tar.gz
+ - wasi-sdk-25.0-x86_64-linux.tar.gz
+ - wasi-sdk-25.0-arm64-linux.tar.gz
+ - wasi-sdk-25.0-x86_64-linux.tar.gz
+ - wasi-sdk-25.0-arm64-linux.tar.gz
+ - wasi-sdk-25.0-x86_64-windows.tar.gz
+ "})],
),
message(
Assistant,
@@ -352,11 +389,11 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
],
Some(input_file_content.into()),
EvalAssertion::judge_diff(indoc! {"
- - The compile_parser_to_wasm method has been changed to use wasi-sdk
- - ureq is used to download the SDK for current platform and architecture
- "}),
- ),
- );
+ - The compile_parser_to_wasm method has been changed to use wasi-sdk
+ - ureq is used to download the SDK for current platform and architecture
+ "}),
+ ))
+ });
}
#[test]
@@ -380,11 +417,8 @@ fn eval_disable_cursor_blinking() {
include_str!("evals/fixtures/disable_cursor_blinking/possible-03.diff"),
include_str!("evals/fixtures/disable_cursor_blinking/possible-04.diff"),
];
- eval(
- 100,
- 0.51,
- 0.05,
- EvalInput::from_conversation(
+ eval_utils::eval(100, 0.51, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(User, [text("Let's research how to cursor blinking works.")]),
message(
@@ -421,10 +455,10 @@ fn eval_disable_cursor_blinking() {
message(
User,
[text(indoc! {"
- Comment out the lines that interact with the BlinkManager.
- Keep the outer `update` blocks, but comments everything that's inside (including if statements).
- Don't add additional comments.
- "})],
+ Comment out the lines that interact with the BlinkManager.
+ Keep the outer `update` blocks, but comments everything that's inside (including if statements).
+ Don't add additional comments.
+ "})],
),
message(
Assistant,
@@ -440,9 +474,9 @@ fn eval_disable_cursor_blinking() {
),
],
Some(input_file_content.into()),
- EvalAssertion::assert_diff_any(possible_diffs),
- ),
- );
+ EvalAssertion::assert_diff_any(possible_diffs.clone()),
+ ))
+ });
}
#[test]
@@ -467,20 +501,16 @@ fn eval_from_pixels_constructor() {
let input_file_path = "root/canvas.rs";
let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
let edit_description = "Implement from_pixels constructor and add tests.";
- eval(
- 100,
- 0.95,
- // For whatever reason, this eval produces more mismatched tags.
- // Increasing for now, let's see if we can bring this down.
- 0.25,
- EvalInput::from_conversation(
+
+ eval_utils::eval(100, 0.95, mismatched_tag_threshold(0.25), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(indoc! {"
- Introduce a new `from_pixels` constructor in Canvas and
- also add tests for it in the same file.
- "})],
+ Introduce a new `from_pixels` constructor in Canvas and
+ also add tests for it in the same file.
+ "})],
),
message(
Assistant,
@@ -545,92 +575,92 @@ fn eval_from_pixels_constructor() {
"tool_4",
"grep",
indoc! {"
- Found 6 matches:
+ Found 6 matches:
- ## Matches in font-kit/src/loaders/core_text.rs
+ ## Matches in font-kit/src/loaders/core_text.rs
- ### mod test › L926-936
- ```
- mod test {
- use super::Font;
- use crate::properties::{Stretch, Weight};
+ ### mod test › L926-936
+ ```
+ mod test {
+ use super::Font;
+ use crate::properties::{Stretch, Weight};
- #[cfg(feature = \"source\")]
- use crate::source::SystemSource;
+ #[cfg(feature = \"source\")]
+ use crate::source::SystemSource;
- static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
+ static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
- #[cfg(feature = \"source\")]
- #[test]
- ```
+ #[cfg(feature = \"source\")]
+ #[test]
+ ```
- 55 lines remaining in ancestor node. Read the file to see all.
+ 55 lines remaining in ancestor node. Read the file to see all.
- ### mod test › L947-951
- ```
- }
+ ### mod test › L947-951
+ ```
+ }
- #[test]
- fn test_core_text_to_css_font_weight() {
- // Exact matches
- ```
+ #[test]
+ fn test_core_text_to_css_font_weight() {
+ // Exact matches
+ ```
- ### mod test › L959-963
- ```
- }
+ ### mod test › L959-963
+ ```
+ }
- #[test]
- fn test_core_text_to_css_font_stretch() {
- // Exact matches
- ```
+ #[test]
+ fn test_core_text_to_css_font_stretch() {
+ // Exact matches
+ ```
- ## Matches in font-kit/src/loaders/freetype.rs
+ ## Matches in font-kit/src/loaders/freetype.rs
- ### mod test › L1238-1248
- ```
- mod test {
- use crate::loaders::freetype::Font;
+ ### mod test › L1238-1248
+ ```
+ mod test {
+ use crate::loaders::freetype::Font;
- static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
- static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
+ static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
+ static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
- #[test]
- fn get_pcf_postscript_name() {
- let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
- assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
- }
- ```
+ #[test]
+ fn get_pcf_postscript_name() {
+ let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
+ assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
+ }
+ ```
- 1 lines remaining in ancestor node. Read the file to see all.
+ 1 lines remaining in ancestor node. Read the file to see all.
- ## Matches in font-kit/src/sources/core_text.rs
+ ## Matches in font-kit/src/sources/core_text.rs
- ### mod test › L265-275
- ```
- mod test {
- use crate::properties::{Stretch, Weight};
+ ### mod test › L265-275
+ ```
+ mod test {
+ use crate::properties::{Stretch, Weight};
- #[test]
- fn test_css_to_core_text_font_weight() {
- // Exact matches
- assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
- assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
- assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
- assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
+ #[test]
+ fn test_css_to_core_text_font_weight() {
+ // Exact matches
+ assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
+ assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
+ assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
+ assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
- ```
+ ```
- 27 lines remaining in ancestor node. Read the file to see all.
+ 27 lines remaining in ancestor node. Read the file to see all.
- ### mod test › L278-282
- ```
- }
+ ### mod test › L278-282
+ ```
+ }
- #[test]
- fn test_css_to_core_text_font_stretch() {
- // Exact matches
- ```
- "},
+ #[test]
+ fn test_css_to_core_text_font_stretch() {
+ // Exact matches
+ ```
+ "},
)],
),
message(
@@ -648,11 +678,11 @@ fn eval_from_pixels_constructor() {
],
Some(input_file_content.into()),
EvalAssertion::judge_diff(indoc! {"
- - The diff contains a new `from_pixels` constructor
- - The diff contains new tests for the `from_pixels` constructor
- "}),
- ),
- );
+ - The diff contains a new `from_pixels` constructor
+ - The diff contains new tests for the `from_pixels` constructor
+ "}),
+ ))
+ });
}
#[test]
@@ -670,11 +700,9 @@ fn eval_zode() {
let input_file_path = "root/zode.py";
let input_content = None;
let edit_description = "Create the main Zode CLI script";
- eval(
- 50,
- 1.,
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(50, 1., mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
message(
@@ -733,7 +761,7 @@ fn eval_zode() {
],
),
],
- input_content,
+ input_content.clone(),
EvalAssertion::new(async move |sample, _, _cx| {
let invalid_starts = [' ', '`', '\n'];
let mut message = String::new();
@@ -758,8 +786,8 @@ fn eval_zode() {
})
}
}),
- ),
- );
+ ))
+ });
}
#[test]
@@ -777,19 +805,17 @@ fn eval_add_overwrite_test() {
let input_file_path = "root/action_log.rs";
let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs");
let edit_description = "Add a new test for overwriting a file in action_log.rs";
- eval(
- 200,
- 0.5, // TODO: make this eval better
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(200, 0.5, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(
User,
[text(indoc! {"
- Introduce a new test in `action_log.rs` to test overwriting a file.
- That is, a file already exists, but we call `buffer_created` as if the file were new.
- Take inspiration from all the other tests in the file.
- "})],
+ Introduce a new test in `action_log.rs` to test overwriting a file.
+ That is, a file already exists, but we call `buffer_created` as if the file were new.
+ Take inspiration from all the other tests in the file.
+ "})],
),
message(
Assistant,
@@ -809,81 +835,81 @@ fn eval_add_overwrite_test() {
"tool_1",
"read_file",
indoc! {"
- pub struct ActionLog [L13-20]
- tracked_buffers [L15]
- edited_since_project_diagnostics_check [L17]
- project [L19]
- impl ActionLog [L22-498]
- pub fn new [L24-30]
- pub fn project [L32-34]
- pub fn checked_project_diagnostics [L37-39]
- pub fn has_edited_files_since_project_diagnostics_check [L42-44]
- fn track_buffer_internal [L46-101]
- fn handle_buffer_event [L103-116]
- fn handle_buffer_edited [L118-123]
- fn handle_buffer_file_changed [L125-158]
- async fn maintain_diff [L160-264]
- pub fn buffer_read [L267-269]
- pub fn buffer_created [L272-276]
- pub fn buffer_edited [L279-287]
- pub fn will_delete_buffer [L289-304]
- pub fn keep_edits_in_range [L306-364]
- pub fn reject_edits_in_ranges [L366-459]
- pub fn keep_all_edits [L461-473]
- pub fn changed_buffers [L476-482]
- pub fn stale_buffers [L485-497]
- fn apply_non_conflicting_edits [L500-561]
- fn diff_snapshots [L563-585]
- fn point_to_row_edit [L587-614]
- enum ChangeAuthor [L617-620]
- User [L618]
- Agent [L619]
- enum TrackedBufferStatus [L623-627]
- Created [L624]
- Modified [L625]
- Deleted [L626]
- struct TrackedBuffer [L629-641]
- buffer [L630]
- base_text [L631]
- unreviewed_changes [L632]
- status [L633]
- version [L634]
- diff [L635]
- snapshot [L636]
- diff_update [L637]
- _open_lsp_handle [L638]
- _maintain_diff [L639]
- _subscription [L640]
- impl TrackedBuffer [L643-657]
- fn has_changes [L644-650]
- fn schedule_diff_update [L652-656]
- pub struct ChangedBuffer [L659-661]
- pub diff [L660]
- mod tests [L664-1574]
- fn init_logger [L678-682]
- fn init_test [L684-691]
- async fn test_keep_edits [L694-769]
- async fn test_deletions [L772-854]
- async fn test_overlapping_user_edits [L857-951]
- async fn test_creating_files [L954-1010]
- async fn test_deleting_files [L1013-1120]
- async fn test_reject_edits [L1123-1255]
- async fn test_reject_multiple_edits [L1258-1331]
- async fn test_reject_deleted_file [L1334-1388]
- async fn test_reject_created_file [L1391-1443]
- async fn test_random_diffs [L1446-1535]
- fn quiesce [L1510-1534]
- struct HunkStatus [L1538-1542]
- range [L1539]
- diff_status [L1540]
- old_text [L1541]
- fn unreviewed_hunks [L1544-1573]
-
- Showing symbols 1-69 (total symbols: 69)
-
- Using the line numbers in this outline, you can call this tool again while specifying
- the start_line and end_line fields to see the implementations of symbols in the outline.
- "},
+ pub struct ActionLog [L13-20]
+ tracked_buffers [L15]
+ edited_since_project_diagnostics_check [L17]
+ project [L19]
+ impl ActionLog [L22-498]
+ pub fn new [L24-30]
+ pub fn project [L32-34]
+ pub fn checked_project_diagnostics [L37-39]
+ pub fn has_edited_files_since_project_diagnostics_check [L42-44]
+ fn track_buffer_internal [L46-101]
+ fn handle_buffer_event [L103-116]
+ fn handle_buffer_edited [L118-123]
+ fn handle_buffer_file_changed [L125-158]
+ async fn maintain_diff [L160-264]
+ pub fn buffer_read [L267-269]
+ pub fn buffer_created [L272-276]
+ pub fn buffer_edited [L279-287]
+ pub fn will_delete_buffer [L289-304]
+ pub fn keep_edits_in_range [L306-364]
+ pub fn reject_edits_in_ranges [L366-459]
+ pub fn keep_all_edits [L461-473]
+ pub fn changed_buffers [L476-482]
+ pub fn stale_buffers [L485-497]
+ fn apply_non_conflicting_edits [L500-561]
+ fn diff_snapshots [L563-585]
+ fn point_to_row_edit [L587-614]
+ enum ChangeAuthor [L617-620]
+ User [L618]
+ Agent [L619]
+ enum TrackedBufferStatus [L623-627]
+ Created [L624]
+ Modified [L625]
+ Deleted [L626]
+ struct TrackedBuffer [L629-641]
+ buffer [L630]
+ base_text [L631]
+ unreviewed_changes [L632]
+ status [L633]
+ version [L634]
+ diff [L635]
+ snapshot [L636]
+ diff_update [L637]
+ _open_lsp_handle [L638]
+ _maintain_diff [L639]
+ _subscription [L640]
+ impl TrackedBuffer [L643-657]
+ fn has_changes [L644-650]
+ fn schedule_diff_update [L652-656]
+ pub struct ChangedBuffer [L659-661]
+ pub diff [L660]
+ mod tests [L664-1574]
+ fn init_logger [L678-682]
+ fn init_test [L684-691]
+ async fn test_keep_edits [L694-769]
+ async fn test_deletions [L772-854]
+ async fn test_overlapping_user_edits [L857-951]
+ async fn test_creating_files [L954-1010]
+ async fn test_deleting_files [L1013-1120]
+ async fn test_reject_edits [L1123-1255]
+ async fn test_reject_multiple_edits [L1258-1331]
+ async fn test_reject_deleted_file [L1334-1388]
+ async fn test_reject_created_file [L1391-1443]
+ async fn test_random_diffs [L1446-1535]
+ fn quiesce [L1510-1534]
+ struct HunkStatus [L1538-1542]
+ range [L1539]
+ diff_status [L1540]
+ old_text [L1541]
+ fn unreviewed_hunks [L1544-1573]
+
+ Showing symbols 1-69 (total symbols: 69)
+
+ Using the line numbers in this outline, you can call this tool again while specifying
+ the start_line and end_line fields to see the implementations of symbols in the outline.
+ "},
)],
),
message(
@@ -963,17 +989,17 @@ fn eval_add_overwrite_test() {
Assistant,
[
text(indoc! {"
- Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
+ Based on my understanding of the existing tests and code, I'll now create a new test for overwriting a file. In this test, we'll:
- 1. Create a file
- 2. Open a buffer for it
- 3. Track it in the action log
- 4. Modify the buffer
- 5. Call `buffer_created` again as if the file were new
- 6. Check that the changes are properly tracked
+ 1. Create a file
+ 2. Open a buffer for it
+ 3. Track it in the action log
+ 4. Modify the buffer
+ 5. Call `buffer_created` again as if the file were new
+ 6. Check that the changes are properly tracked
- Let's write the test:
- "}),
+ Let's write the test:
+ "}),
tool_use(
"tool_5",
"edit_file",
@@ -990,8 +1016,8 @@ fn eval_add_overwrite_test() {
EvalAssertion::judge_diff(
"A new test for overwritten files was created, without changing any previous test",
),
- ),
- );
+ ))
+ });
}
#[test]
@@ -1016,20 +1042,18 @@ fn eval_create_empty_file() {
let input_file_content = None;
let expected_output_content = String::new();
- eval(
- 100,
- 0.99,
- 0.05,
- EvalInput::from_conversation(
+
+ eval_utils::eval(100, 0.99, mismatched_tag_threshold(0.05), move || {
+ run_eval(EvalInput::from_conversation(
vec![
message(User, [text("Create a second empty todo file ")]),
message(
Assistant,
[
text(formatdoc! {"
- I'll help you create a second empty todo file.
- First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one.
- "}),
+ I'll help you create a second empty todo file.
+ First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one.
+ "}),
tool_use(
"toolu_01GAF8TtsgpjKxCr8fgQLDgR",
"list_directory",
@@ -1051,8 +1075,8 @@ fn eval_create_empty_file() {
Assistant,
[
text(formatdoc! {"
- I can see there's already a `TODO` file in the `root` directory. Let me create a second empty todo file called `TODO3` in the same directory:
- "}),
+ I can see there's already a `TODO` file in the `root` directory. Let me create a second empty todo file called `TODO3` in the same directory:
+ "}),
tool_use(
"toolu_01Tb3iQ9griqSYMmVuykQPWU",
"edit_file",
@@ -1065,12 +1089,12 @@ fn eval_create_empty_file() {
],
),
],
- input_file_content,
+ input_file_content.clone(),
// Bad behavior is to write something like
// "I'll create an empty TODO3 file as requested."
- EvalAssertion::assert_eq(expected_output_content),
- ),
- );
+ EvalAssertion::assert_eq(expected_output_content.clone()),
+ ))
+ });
}
fn message(
@@ -1312,115 +1336,44 @@ impl EvalAssertion {
}
}
-fn eval(
- iterations: usize,
- expected_pass_ratio: f32,
- mismatched_tag_threshold: f32,
- mut eval: EvalInput,
-) {
- let mut evaluated_count = 0;
- let mut failed_count = 0;
- report_progress(evaluated_count, failed_count, iterations);
-
- let (tx, rx) = mpsc::channel();
-
- // Cache the last message in the conversation, and run one instance of the eval so that
- // all the next ones are cached.
- eval.conversation.last_mut().unwrap().cache = true;
- run_eval(eval.clone(), tx.clone());
-
- let executor = gpui::background_executor();
- let semaphore = Arc::new(smol::lock::Semaphore::new(32));
- for _ in 1..iterations {
- let eval = eval.clone();
- let tx = tx.clone();
- let semaphore = semaphore.clone();
- executor
- .spawn(async move {
- let _guard = semaphore.acquire().await;
- run_eval(eval, tx)
- })
- .detach();
- }
- drop(tx);
-
- let mut failed_evals = HashMap::default();
- let mut errored_evals = HashMap::default();
- let mut eval_outputs = Vec::new();
- let mut cumulative_parser_metrics = EditParserMetrics::default();
- while let Ok(output) = rx.recv() {
- match output {
- Ok(output) => {
- cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
- eval_outputs.push(output.clone());
- if output.assertion.score < 80 {
- failed_count += 1;
- failed_evals
- .entry(output.sample.text_after.clone())
- .or_insert(Vec::new())
- .push(output);
- }
- }
- Err(error) => {
- failed_count += 1;
- *errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
- }
- }
-
- evaluated_count += 1;
- report_progress(evaluated_count, failed_count, iterations);
- }
-
- let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
- println!("Actual pass ratio: {}\n", actual_pass_ratio);
- if actual_pass_ratio < expected_pass_ratio {
- let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
- errored_evals.sort_by_key(|(_, count)| Reverse(*count));
- for (error, count) in errored_evals {
- println!("Eval errored {} times. Error: {}", count, error);
- }
-
- let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
- failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
- for (_buffer_output, failed_evals) in failed_evals {
- let eval_output = failed_evals.first().unwrap();
- println!("Eval failed {} times", failed_evals.len());
- println!("{}", eval_output);
- }
-
- panic!(
- "Actual pass ratio: {}\nExpected pass ratio: {}",
- actual_pass_ratio, expected_pass_ratio
- );
- }
-
- let mismatched_tag_ratio =
- cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
- if mismatched_tag_ratio > mismatched_tag_threshold {
- for eval_output in eval_outputs {
- println!("{}", eval_output);
- }
- panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
- }
-}
-
-fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
+fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<EditEvalMetadata> {
let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
let mut cx = TestAppContext::build(dispatcher, None);
- let output = cx.executor().block_test(async {
+ let result = cx.executor().block_test(async {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
- tx.send(output).unwrap();
+ match result {
+ Ok(output) => eval_utils::EvalOutput {
+ data: output.to_string(),
+ outcome: if output.assertion.score < 80 {
+ eval_utils::OutcomeKind::Failed
+ } else {
+ eval_utils::OutcomeKind::Passed
+ },
+ metadata: EditEvalMetadata {
+ tags: output.sample.edit_output.parser_metrics.tags,
+ mismatched_tags: output.sample.edit_output.parser_metrics.mismatched_tags,
+ },
+ },
+ Err(e) => eval_utils::EvalOutput {
+ data: format!("{e:?}"),
+ outcome: eval_utils::OutcomeKind::Error,
+ metadata: EditEvalMetadata {
+ tags: 0,
+ mismatched_tags: 0,
+ },
+ },
+ }
}
#[derive(Clone)]
-struct EvalOutput {
+struct EditEvalOutput {
sample: EvalSample,
assertion: EvalAssertionOutcome,
}
-impl Display for EvalOutput {
+impl Display for EditEvalOutput {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Score: {:?}", self.assertion.score)?;
if let Some(message) = self.assertion.message.as_ref() {
@@ -1439,22 +1392,6 @@ impl Display for EvalOutput {
}
}
-fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
- let passed_count = evaluated_count - failed_count;
- let passed_ratio = if evaluated_count == 0 {
- 0.0
- } else {
- passed_count as f64 / evaluated_count as f64
- };
- print!(
- "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
- evaluated_count,
- iterations,
- passed_ratio * 100.0
- );
- std::io::stdout().flush().unwrap();
-}
-
struct EditAgentTest {
agent: EditAgent,
project: Entity<Project>,
@@ -1550,7 +1487,10 @@ impl EditAgentTest {
})
}
- async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
+ async fn eval(&self, mut eval: EvalInput, cx: &mut TestAppContext) -> Result<EditEvalOutput> {
+ // Make sure the last message in the conversation is cached.
+ eval.conversation.last_mut().unwrap().cache = true;
+
let path = self
.project
.read_with(cx, |project, cx| {
@@ -1656,7 +1596,7 @@ impl EditAgentTest {
.run(&sample, self.judge_model.clone(), cx)
.await?;
- Ok(EvalOutput { assertion, sample })
+ Ok(EditEvalOutput { assertion, sample })
}
}
@@ -13,7 +13,8 @@ path = "src/agent_ui.rs"
doctest = false
[features]
-test-support = ["gpui/test-support", "language/test-support"]
+test-support = ["gpui/test-support", "language/test-support", "reqwest_client"]
+unit-eval = []
[dependencies]
acp_thread.workspace = true
@@ -47,6 +48,7 @@ fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
+gpui_tokio.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -98,14 +100,17 @@ workspace.workspace = true
zed_actions.workspace = true
image.workspace = true
async-fs.workspace = true
+reqwest_client = { workspace = true, optional = true }
[dev-dependencies]
acp_thread = { workspace = true, features = ["test-support"] }
agent = { workspace = true, features = ["test-support"] }
assistant_text_thread = { workspace = true, features = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] }
+clock.workspace = true
db = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
+eval_utils.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }
@@ -115,5 +120,6 @@ pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
semver.workspace = true
rand.workspace = true
+reqwest_client.workspace = true
tree-sitter-md.workspace = true
unindent.workspace = true
@@ -2685,16 +2685,17 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
return;
};
let project = workspace.read(cx).project().downgrade();
+ let thread_store = panel.read(cx).thread_store().clone();
assistant.assist(
prompt_editor,
self.workspace.clone(),
project,
- panel.read(cx).thread_store().clone(),
+ thread_store,
None,
initial_prompt,
window,
cx,
- )
+ );
})
}
@@ -7,6 +7,8 @@ mod buffer_codegen;
mod completion_provider;
mod context;
mod context_server_configuration;
+#[cfg(test)]
+mod evals;
mod inline_assistant;
mod inline_prompt_editor;
mod language_model_selector;
@@ -719,6 +719,7 @@ impl CodegenAlternative {
output_tokens = usage.output_tokens,
)
}
+
cx.emit(CodegenEvent::Finished);
cx.notify();
})
@@ -0,0 +1,89 @@
+use std::str::FromStr;
+
+use crate::inline_assistant::test::run_inline_assistant_test;
+
+use eval_utils::{EvalOutput, NoProcessor};
+use gpui::TestAppContext;
+use language_model::{LanguageModelRegistry, SelectedModel};
+use rand::{SeedableRng as _, rngs::StdRng};
+
+/// Evaluates a rename requested with a single cursor inside the buffer.
+///
+/// Runs 20 iterations through `eval_utils::eval` and requires every one of them to pass.
+/// The test is ignored unless the `unit-eval` feature is enabled.
+#[test]
+#[cfg_attr(not(feature = "unit-eval"), ignore)]
+fn eval_single_cursor_edit() {
+    eval_utils::eval(20, 1.0, NoProcessor, move || {
+        run_eval(
+            &EvalInput {
+                prompt: "Rename this variable to buffer_text".to_string(),
+                // `ˇ` marks the cursor position; it is stripped from the text by
+                // `marked_text_ranges` before the buffer is handed to the assistant.
+                buffer: indoc::indoc! {"
+                    struct EvalExampleStruct {
+                    text: Strˇing,
+                    prompt: String,
+                    }
+                "}
+                .to_string(),
+            },
+            &|_, output| {
+                let expected = indoc::indoc! {"
+                    struct EvalExampleStruct {
+                    buffer_text: String,
+                    prompt: String,
+                    }
+                "};
+                // Only an exact match of the whole buffer counts as a pass.
+                if output == expected {
+                    EvalOutput {
+                        outcome: eval_utils::OutcomeKind::Passed,
+                        data: "Passed!".to_string(),
+                        metadata: (),
+                    }
+                } else {
+                    EvalOutput {
+                        outcome: eval_utils::OutcomeKind::Failed,
+                        data: format!("Failed to rename variable, output: {}", output),
+                        metadata: (),
+                    }
+                }
+            },
+        )
+    });
+}
+
+/// Input for one inline-assistant eval run.
+struct EvalInput {
+ /// Initial editor text, passed to `run_inline_assistant_test` as `base_buffer`.
+ buffer: String,
+ /// Prompt handed to the inline assistant.
+ prompt: String,
+}
+
+/// Runs the inline assistant once against a real language model and judges the result.
+///
+/// A fresh `TestAppContext` is built with drawing disabled. The model is taken from the
+/// `ZED_AGENT_MODEL` environment variable, falling back to
+/// `anthropic/claude-sonnet-4-latest`. The buffer text returned by
+/// `run_inline_assistant_test` is passed to `judge` together with `input`.
+///
+/// # Panics
+///
+/// Panics if the model name cannot be parsed as `provider/model-id`.
+fn run_eval(
+    input: &EvalInput,
+    judge: &dyn Fn(&EvalInput, &str) -> eval_utils::EvalOutput<()>,
+) -> eval_utils::EvalOutput<()> {
+    let dispatcher = gpui::TestDispatcher::new(StdRng::from_os_rng());
+    let mut cx = TestAppContext::build(dispatcher, None);
+    cx.skip_drawing();
+
+    let buffer_text = run_inline_assistant_test(
+        input.buffer.clone(),
+        input.prompt.clone(),
+        |cx| {
+            // Reconfigure to use a real model instead of the fake one.
+            // The default is only built when the variable is missing or unreadable.
+            let model_name = std::env::var("ZED_AGENT_MODEL")
+                .unwrap_or_else(|_| "anthropic/claude-sonnet-4-latest".into());
+
+            let selected_model = SelectedModel::from_str(&model_name)
+                .expect("Invalid model format. Use 'provider/model-id'");
+
+            log::info!("Selected model: {selected_model:?}");
+
+            cx.update(|_, cx| {
+                LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+                    registry.select_inline_assistant_model(Some(&selected_model), cx);
+                });
+            });
+        },
+        |_cx| {
+            log::info!("Waiting for actual response from the LLM...");
+        },
+        &mut cx,
+    );
+
+    judge(input, &buffer_text)
+}
@@ -32,7 +32,7 @@ use editor::{
},
};
use fs::Fs;
-use futures::FutureExt;
+use futures::{FutureExt, channel::mpsc};
use gpui::{
App, Context, Entity, Focusable, Global, HighlightStyle, Subscription, Task, UpdateGlobal,
WeakEntity, Window, point,
@@ -102,6 +102,7 @@ pub struct InlineAssistant {
prompt_builder: Arc<PromptBuilder>,
telemetry: Arc<Telemetry>,
fs: Arc<dyn Fs>,
+ _inline_assistant_completions: Option<mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>>,
}
impl Global for InlineAssistant {}
@@ -123,9 +124,18 @@ impl InlineAssistant {
prompt_builder,
telemetry,
fs,
+ _inline_assistant_completions: None,
}
}
+ /// Stores the channel on which inline-assist results are reported to tests:
+ /// `Ok(assist_id)`, or an `Err` carrying the inline assistant error.
+ ///
+ /// Despite the name, the argument is the *sending* half of the channel; the caller keeps
+ /// the receiving half.
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn set_completion_receiver(
+ &mut self,
+ sender: mpsc::UnboundedSender<anyhow::Result<InlineAssistId>>,
+ ) {
+ self._inline_assistant_completions = Some(sender);
+ }
+
pub fn register_workspace(
&mut self,
workspace: &Entity<Workspace>,
@@ -287,7 +297,7 @@ impl InlineAssistant {
action.prompt.clone(),
window,
cx,
- )
+ );
})
}
InlineAssistTarget::Terminal(active_terminal) => {
@@ -301,8 +311,8 @@ impl InlineAssistant {
action.prompt.clone(),
window,
cx,
- )
- })
+ );
+ });
}
};
@@ -598,13 +608,13 @@ impl InlineAssistant {
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut App,
- ) {
+ ) -> Option<InlineAssistId> {
let snapshot = editor.update(cx, |editor, cx| editor.snapshot(window, cx));
let Some((codegen_ranges, newest_selection)) =
self.codegen_ranges(editor, &snapshot, window, cx)
else {
- return;
+ return None;
};
let assist_to_focus = self.batch_assist(
@@ -624,6 +634,8 @@ impl InlineAssistant {
if let Some(assist_id) = assist_to_focus {
self.focus_assist(assist_id, window, cx);
}
+
+ assist_to_focus
}
pub fn suggest_assist(
@@ -1740,6 +1752,16 @@ impl InlineAssist {
&& assist.decorations.is_none()
&& let Some(workspace) = assist.workspace.upgrade()
{
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(sender) = &mut this._inline_assistant_completions {
+ sender
+ .unbounded_send(Err(anyhow::anyhow!(
+ "Inline assistant error: {}",
+ error
+ )))
+ .ok();
+ }
+
let error = format!("Inline assistant error: {}", error);
workspace.update(cx, |workspace, cx| {
struct InlineAssistantError;
@@ -1750,6 +1772,11 @@ impl InlineAssist {
workspace.show_toast(Toast::new(id, error), cx);
})
+ } else {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(sender) = &mut this._inline_assistant_completions {
+ sender.unbounded_send(Ok(assist_id)).ok();
+ }
}
if assist.decorations.is_none() {
@@ -1943,3 +1970,160 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
}
}
}
+
+#[cfg(any(test, feature = "test-support"))]
+pub mod test {
+ use std::sync::Arc;
+
+ use agent::HistoryStore;
+ use assistant_text_thread::TextThreadStore;
+ use client::{Client, UserStore};
+ use editor::{Editor, MultiBuffer, MultiBufferOffset};
+ use fs::FakeFs;
+ use futures::channel::mpsc;
+ use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
+ use language::Buffer;
+ use language_model::LanguageModelRegistry;
+ use project::Project;
+ use prompt_store::PromptBuilder;
+ use smol::stream::StreamExt as _;
+ use util::test::marked_text_ranges;
+ use workspace::Workspace;
+
+ use crate::InlineAssistant;
+
+ /// Drives one inline-assist request against an editor seeded from `base_buffer` and
+ /// returns the buffer text after the assistant has reported a completion.
+ ///
+ /// `base_buffer` is parsed with `marked_text_ranges`; the unmarked text becomes the editor
+ /// contents and the parsed ranges become its selections. `setup` runs once the workspace
+ /// window exists and before the editor is created. `test` runs after the assist has been
+ /// started and the executor has parked.
+ ///
+ /// Panics if `assist` does not return an assist id.
+ pub fn run_inline_assistant_test<SetupF, TestF>(
+ base_buffer: String,
+ prompt: String,
+ setup: SetupF,
+ test: TestF,
+ cx: &mut TestAppContext,
+ ) -> String
+ where
+ SetupF: FnOnce(&mut gpui::VisualTestContext),
+ TestF: FnOnce(&mut gpui::VisualTestContext),
+ {
+ let fs = FakeFs::new(cx.executor());
+ let app_state = cx.update(|cx| workspace::AppState::test(cx));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ // A reqwest-backed HTTP client is installed so that a non-fake model can be reached.
+ let http = Arc::new(reqwest_client::ReqwestClient::user_agent("agent tests").unwrap());
+ let client = cx.update(|cx| {
+ cx.set_http_client(http);
+ Client::production(cx)
+ });
+ let mut inline_assistant =
+ InlineAssistant::new(fs.clone(), prompt_builder, client.telemetry().clone());
+
+ // The assistant reports the outcome of each assist on this channel.
+ let (tx, mut completion_rx) = mpsc::unbounded();
+ inline_assistant.set_completion_receiver(tx);
+
+ // Initialize settings and client
+ cx.update(|cx| {
+ gpui_tokio::init(cx);
+ settings::init(cx);
+ client::init(&client, cx);
+ workspace::init(app_state.clone(), cx);
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store, client.clone(), cx);
+
+ cx.set_global(inline_assistant);
+ });
+
+ let project = cx
+ .executor()
+ .block_test(async { Project::test(fs.clone(), [], cx).await });
+
+ // Create workspace with window
+ let (workspace, cx) = cx.add_window_view(|window, cx| {
+ window.activate_window();
+ Workspace::new(None, project.clone(), app_state.clone(), window, cx)
+ });
+
+ setup(cx);
+
+ let (_editor, buffer) = cx.update(|window, cx| {
+ let buffer = cx.new(|cx| Buffer::local("", cx));
+ let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx));
+ let editor = cx.new(|cx| Editor::for_multibuffer(multibuffer, None, window, cx));
+ editor.update(cx, |editor, cx| {
+ // Split the marked input into plain text and the selection ranges it encodes.
+ let (unmarked_text, selection_ranges) = marked_text_ranges(&base_buffer, true);
+ editor.set_text(unmarked_text, window, cx);
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges(
+ selection_ranges.into_iter().map(|range| {
+ MultiBufferOffset(range.start)..MultiBufferOffset(range.end)
+ }),
+ )
+ })
+ });
+
+ let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
+
+ // Add editor to workspace
+ workspace.update(cx, |workspace, cx| {
+ workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, window, cx);
+ });
+
+ // Call assist method
+ InlineAssistant::update_global(cx, |inline_assistant, cx| {
+ let assist_id = inline_assistant
+ .assist(
+ &editor,
+ workspace.downgrade(),
+ project.downgrade(),
+ history_store, // thread_store
+ None, // prompt_store
+ Some(prompt),
+ window,
+ cx,
+ )
+ .unwrap();
+
+ inline_assistant.start_assist(assist_id, window, cx);
+ });
+
+ (editor, buffer)
+ });
+
+ cx.run_until_parked();
+
+ test(cx);
+
+ // Block until the assistant reports the first completion.
+ // NOTE(review): the received value, including any `Err`, is discarded here, so a failed
+ // assist is only visible through the returned buffer text.
+ cx.executor()
+ .block_test(async { completion_rx.next().await });
+
+ buffer.read_with(cx, |buffer, _| buffer.text())
+ }
+
+ /// Runs `run_inline_assistant_test` with the fake language model, feeding it `llm_output`
+ /// as the model's streamed response, and returns the resulting buffer text.
+ ///
+ /// The prompt is a placeholder because the fake model's output is supplied directly.
+ #[allow(unused)]
+ pub fn test_inline_assistant(
+ base_buffer: &'static str,
+ llm_output: &'static str,
+ cx: &mut TestAppContext,
+ ) -> String {
+ run_inline_assistant_test(
+ base_buffer.to_string(),
+ "Prompt doesn't matter because we're using a fake model".to_string(),
+ |cx| {
+ cx.update(|_, cx| LanguageModelRegistry::test(cx));
+ },
+ |cx| {
+ let fake_model = cx.update(|_, cx| {
+ LanguageModelRegistry::global(cx)
+ .update(cx, |registry, _| registry.fake_model())
+ });
+ let fake = fake_model.as_fake();
+
+ // Stream the canned response, then close the stream.
+ fake.send_last_completion_stream_text_chunk(llm_output.to_string());
+ fake.end_last_completion_stream();
+
+ // Run again to process the model's response
+ cx.run_until_parked();
+ },
+ cx,
+ )
+ }
+}
@@ -0,0 +1,18 @@
+[package]
+name = "eval_utils"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/eval_utils.rs"
+doctest = false
+
+[dependencies]
+gpui.workspace = true
+serde.workspace = true
+smol.workspace = true
@@ -0,0 +1 @@
+LICENSE-GPL
@@ -0,0 +1,3 @@
+# eval_utils
+
+Utilities for evals of agents.
@@ -0,0 +1,128 @@
+//! Utilities for evaluation and benchmarking.
+
+use std::{
+ collections::HashMap,
+ sync::{Arc, mpsc},
+};
+
+/// Prints a progress line: how many of `iterations` runs have been evaluated and the
+/// percentage of evaluated runs that passed (0% while nothing has been evaluated).
+fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
+    let passed = evaluated_count - failed_count;
+    let ratio = match evaluated_count {
+        0 => 0.0,
+        total => passed as f64 / total as f64,
+    };
+    let percent_passed = ratio * 100.0;
+    println!(
+        "\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
+        evaluated_count, iterations, percent_passed
+    )
+}
+
+/// Result category of a single eval run.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum OutcomeKind {
+ /// The run met its criteria.
+ Passed,
+ /// The run completed but did not meet its criteria; counted as a failure by `eval`.
+ Failed,
+ /// The run could not be completed; also counted as a failure by `eval`, which groups
+ /// these runs by their `data` message.
+ Error,
+}
+
+/// Hook for collecting per-run metadata during `eval`.
+pub trait EvalOutputProcessor {
+ /// Extra data attached to every `EvalOutput`; it is sent across threads by `eval`.
+ type Metadata: 'static + Send;
+ /// Called by `eval` once for every output it receives.
+ fn process(&mut self, output: &EvalOutput<Self::Metadata>);
+ /// Called by `eval` once at the end, after the pass-ratio check has succeeded.
+ fn assert(&mut self);
+}
+
+/// Outcome of one eval run together with caller-defined metadata.
+#[derive(Clone, Debug)]
+pub struct EvalOutput<M> {
+ /// Whether the run passed, failed, or errored.
+ pub outcome: OutcomeKind,
+ /// Human-readable detail; `eval` prints it for failed runs and uses it to group errors.
+ pub data: String,
+ /// Extra data made available to the `EvalOutputProcessor`.
+ pub metadata: M,
+}
+
+/// An `EvalOutputProcessor` that ignores every output and asserts nothing.
+pub struct NoProcessor;
+impl EvalOutputProcessor for NoProcessor {
+ type Metadata = ();
+
+ // Intentionally a no-op.
+ fn process(&mut self, _output: &EvalOutput<Self::Metadata>) {}
+
+ // Intentionally a no-op.
+ fn assert(&mut self) {}
+}
+
+/// Runs `evalf` `iterations` times and panics unless enough of the runs pass.
+///
+/// The first run executes synchronously on the calling thread. The remaining runs are
+/// spawned on the gpui background executor, with at most 32 of them inside `evalf` at a
+/// time. Every output is passed to `processor.process`, and `processor.assert` is called
+/// at the end if the pass-ratio check succeeds.
+///
+/// `Failed` and `Error` outcomes both count as failures, as do runs that never report a
+/// result. A request for zero iterations is treated as one, because the first run always
+/// executes.
+///
+/// # Panics
+///
+/// Panics when the fraction of passing runs is below `expected_pass_ratio`, after printing
+/// the collected errors and failures.
+pub fn eval<P>(
+    iterations: usize,
+    expected_pass_ratio: f32,
+    mut processor: P,
+    evalf: impl Fn() -> EvalOutput<P::Metadata> + Send + Sync + 'static,
+) where
+    P: EvalOutputProcessor,
+{
+    // The warm-up run below always executes, so treat zero iterations as one. This keeps
+    // the pass-ratio arithmetic well defined (no 0/0 and no `usize` underflow).
+    let iterations = iterations.max(1);
+    let mut evaluated_count = 0;
+    let mut failed_count = 0;
+    let evalf = Arc::new(evalf);
+    report_progress(evaluated_count, failed_count, iterations);
+
+    let (tx, rx) = mpsc::channel();
+
+    let executor = gpui::background_executor();
+    let semaphore = Arc::new(smol::lock::Semaphore::new(32));
+    // Warm the cache once
+    let first_output = evalf();
+    tx.send(first_output).ok();
+
+    for _ in 1..iterations {
+        let tx = tx.clone();
+        let semaphore = semaphore.clone();
+        let evalf = evalf.clone();
+        executor
+            .spawn(async move {
+                let _guard = semaphore.acquire().await;
+                let output = evalf();
+                tx.send(output).ok();
+            })
+            .detach();
+    }
+    // Drop the original sender so `rx.recv()` fails once every spawned run is done.
+    drop(tx);
+
+    let mut failed_evals = Vec::new();
+    let mut errored_evals = HashMap::new();
+    while let Ok(output) = rx.recv() {
+        processor.process(&output);
+
+        match output.outcome {
+            OutcomeKind::Passed => {}
+            OutcomeKind::Failed => {
+                failed_count += 1;
+                failed_evals.push(output);
+            }
+            OutcomeKind::Error => {
+                failed_count += 1;
+                *errored_evals.entry(output.data).or_insert(0) += 1;
+            }
+        }
+
+        evaluated_count += 1;
+        report_progress(evaluated_count, failed_count, iterations);
+    }
+
+    // A run whose task ended without sending a result (for example because `evalf`
+    // panicked) never reaches the loop above; count it as a failure, not an implicit pass.
+    let unreported_count = iterations.saturating_sub(evaluated_count);
+    if unreported_count > 0 {
+        println!("{} eval run(s) did not report a result", unreported_count);
+    }
+
+    let passed_count = iterations.saturating_sub(failed_count + unreported_count);
+    let actual_pass_ratio = passed_count as f32 / iterations as f32;
+    println!("Actual pass ratio: {}\n", actual_pass_ratio);
+    if actual_pass_ratio < expected_pass_ratio {
+        for (error, count) in errored_evals {
+            println!("Eval errored {} times. Error: {}", count, error);
+        }
+
+        for failed in failed_evals {
+            println!("Eval failed");
+            println!("{}", failed.data);
+        }
+
+        panic!(
+            "Actual pass ratio: {}\nExpected pass ratio: {}",
+            actual_pass_ratio, expected_pass_ratio
+        );
+    }
+
+    processor.assert();
+}
@@ -551,12 +551,39 @@ impl SystemWindowTabController {
}
}
+/// Mode an `App` runs in; stored in `App::mode`.
+pub(crate) enum GpuiMode {
+ /// Test mode, only available in test builds or with the `test-support` feature.
+ #[cfg(any(test, feature = "test-support"))]
+ Test {
+ /// When `true`, `skip_drawing()` reports that drawing should be skipped.
+ skip_drawing: bool,
+ },
+ /// Normal operation; drawing is never skipped.
+ Production,
+}
+
+impl GpuiMode {
+    /// Returns the test mode with drawing left enabled.
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn test() -> Self {
+        Self::Test {
+            skip_drawing: false,
+        }
+    }
+
+    /// Reports whether drawing should be skipped; always `false` in production mode.
+    #[inline]
+    pub(crate) fn skip_drawing(&self) -> bool {
+        match *self {
+            #[cfg(any(test, feature = "test-support"))]
+            Self::Test { skip_drawing } => skip_drawing,
+            Self::Production => false,
+        }
+    }
+}
+
/// Contains the state of the full application, and passed as a reference to a variety of callbacks.
/// Other [Context] derefs to this type.
/// You need a reference to an `App` to access the state of a [Entity].
pub struct App {
pub(crate) this: Weak<AppCell>,
pub(crate) platform: Rc<dyn Platform>,
+ pub(crate) mode: GpuiMode,
text_system: Arc<TextSystem>,
flushing_effects: bool,
pending_updates: usize,
@@ -635,6 +662,7 @@ impl App {
this: this.clone(),
platform: platform.clone(),
text_system,
+ mode: GpuiMode::Production,
actions: Rc::new(ActionRegistry::default()),
flushing_effects: false,
pending_updates: 0,
@@ -5,7 +5,7 @@ use crate::{
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels,
Platform, Point, Render, Result, Size, Task, TestDispatcher, TestPlatform,
TestScreenCaptureSource, TestWindow, TextSystem, VisualContext, Window, WindowBounds,
- WindowHandle, WindowOptions,
+ WindowHandle, WindowOptions, app::GpuiMode,
};
use anyhow::{anyhow, bail};
use futures::{Stream, StreamExt, channel::oneshot};
@@ -132,8 +132,11 @@ impl TestAppContext {
let http_client = http_client::FakeHttpClient::with_404_response();
let text_system = Arc::new(TextSystem::new(platform.text_system()));
+ let mut app = App::new_app(platform.clone(), asset_source, http_client);
+ app.borrow_mut().mode = GpuiMode::test();
+
Self {
- app: App::new_app(platform.clone(), asset_source, http_client),
+ app,
background_executor,
foreground_executor,
dispatcher,
@@ -144,6 +147,11 @@ impl TestAppContext {
}
}
+    /// Switches the app into test mode with drawing disabled, so drawing operations are
+    /// skipped for the rest of this test.
+    pub fn skip_drawing(&mut self) {
+        let mut app = self.app.borrow_mut();
+        app.mode = GpuiMode::Test { skip_drawing: true };
+    }
+
/// Create a single TestAppContext, for non-multi-client tests
pub fn single() -> Self {
let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
@@ -2006,7 +2006,9 @@ impl Window {
if let Some(input_handler) = self.platform_window.take_input_handler() {
self.rendered_frame.input_handlers.push(Some(input_handler));
}
- self.draw_roots(cx);
+ if !cx.mode.skip_drawing() {
+ self.draw_roots(cx);
+ }
self.dirty_views.clear();
self.next_frame.window_active = self.active.get();
@@ -408,6 +408,7 @@ impl FakeHttpClient {
}
pub fn with_404_response() -> Arc<HttpClientWithUrl> {
+ log::warn!("Using fake HTTP client with 404 response");
Self::create(|_| async move {
Ok(Response::builder()
.status(404)
@@ -417,6 +418,7 @@ impl FakeHttpClient {
}
pub fn with_200_response() -> Arc<HttpClientWithUrl> {
+ log::warn!("Using fake HTTP client with 200 response");
Self::create(|_| async move {
Ok(Response::builder()
.status(200)
@@ -135,6 +135,11 @@ impl LanguageModelRegistry {
fake_provider
}
+    /// Returns the model currently set as the registry's default model.
+    ///
+    /// The name reflects its use in tests, where the default model is expected to be the
+    /// fake one; this method itself does not check that.
+    ///
+    /// # Panics
+    ///
+    /// Panics if no default model is set.
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn fake_model(&self) -> Arc<dyn LanguageModel> {
+        self.default_model
+            .as_ref()
+            .expect("fake_model() requires a default model to be set")
+            .model
+            .clone()
+    }
+
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
&mut self,
provider: Arc<T>,