From c60343af719a04323b4b6bdbf27d1d549247309d Mon Sep 17 00:00:00 2001 From: Bennet Fenner Date: Wed, 22 Oct 2025 19:55:26 +0200 Subject: [PATCH] eval: Port to agent2 (#40704) Release Notes: - N/A --- Cargo.lock | 55 ++ Cargo.toml | 2 +- crates/acp_thread/src/terminal.rs | 74 ++- crates/agent/Cargo.toml | 2 + crates/agent/src/edit_agent/evals.rs | 29 +- crates/agent/src/tests/mod.rs | 2 +- crates/agent/src/thread.rs | 41 +- crates/agent_servers/src/acp.rs | 74 +-- crates/eval/Cargo.toml | 5 +- crates/eval/runner_settings.json | 5 +- crates/eval/src/eval.rs | 39 +- crates/eval/src/example.rs | 302 +++++------ .../src/examples/add_arg_to_trait_method.rs | 6 +- .../eval/src/examples/code_block_citations.rs | 19 +- .../eval/src/examples/comment_translation.rs | 32 +- .../src/examples/file_change_notification.rs | 6 +- crates/eval/src/examples/file_search.rs | 11 +- .../src/examples/grep_params_escapement.rs | 7 +- crates/eval/src/examples/mod.rs | 3 +- crates/eval/src/examples/overwrite_file.rs | 18 +- crates/eval/src/examples/planets.rs | 15 +- crates/eval/src/instance.rs | 492 ++++++++++++------ 22 files changed, 783 insertions(+), 456 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4628b20e29cee879c9a68d5af52a89c3d684302b..e426bc4ce64d540ea77fcd03decb875ebb76a572 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5692,6 +5692,61 @@ dependencies = [ "num-traits", ] +[[package]] +name = "eval" +version = "0.1.0" +dependencies = [ + "acp_thread", + "agent", + "agent-client-protocol", + "agent_settings", + "agent_ui", + "anyhow", + "async-trait", + "buffer_diff", + "chrono", + "clap", + "client", + "collections", + "debug_adapter_extension", + "dirs 4.0.0", + "dotenvy", + "env_logger 0.11.8", + "extension", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 4.5.0", + "language", + "language_extension", + "language_model", + "language_models", + "languages", + "markdown", + "node_runtime", + "pathdiff", + "paths", + "pretty_assertions", + "project", + "prompt_store", + "rand 0.9.2", + "regex", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "settings", + "shellexpand 2.1.2", + "telemetry", + "terminal_view", + "toml 0.8.23", + "unindent", + "util", + "uuid", + "watch", +] + [[package]] name = "event-listener" version = "2.5.3" diff --git a/Cargo.toml b/Cargo.toml index 792f38e4ce0aa2ad947f60b2962f7711eff846f4..c0c0ffc1508aaa51465db7a30cccfcfa04fd8467 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ members = [ "crates/edit_prediction_context", "crates/zeta2_tools", "crates/editor", - # "crates/eval", + "crates/eval", "crates/explorer_command_injector", "crates/extension", "crates/extension_api", diff --git a/crates/acp_thread/src/terminal.rs b/crates/acp_thread/src/terminal.rs index 888c7698c3d2270769f3afbe712ecba7d08b055f..9ca6d4021b316231930ab7803957dab3a0139f1e 100644 --- a/crates/acp_thread/src/terminal.rs +++ b/crates/acp_thread/src/terminal.rs @@ -1,10 +1,15 @@ use agent_client_protocol as acp; - +use anyhow::Result; use futures::{FutureExt as _, future::Shared}; -use gpui::{App, AppContext, Context, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, Task}; use language::LanguageRegistry; use markdown::Markdown; +use project::Project; +use settings::{Settings as _, SettingsLocation}; use std::{path::PathBuf, process::ExitStatus, sync::Arc, time::Instant}; +use task::Shell; +use terminal::terminal_settings::TerminalSettings; +use util::get_default_system_shell_preferring_bash; pub struct Terminal { id: acp::TerminalId, @@ -170,3 +175,68 @@ impl Terminal { ) } } + +pub async fn create_terminal_entity( + command: String, + args: &[String], + env_vars: Vec<(String, String)>, + cwd: Option, + project: &Entity, + cx: &mut AsyncApp, +) -> Result> { + let mut env = if let Some(dir) = &cwd { + project + .update(cx, |project, cx| { + let worktree = project.find_worktree(dir.as_path(), cx); + let shell = TerminalSettings::get( + worktree.as_ref().map(|(worktree, path)| SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: &path, + }), + cx, + ) + .shell + .clone(); + project.directory_environment(&shell, dir.clone().into(), cx) + })? + .await + .unwrap_or_default() + } else { + Default::default() + }; + + // Disables paging for `git` and hopefully other commands + env.insert("PAGER".into(), "".into()); + env.extend(env_vars); + + // Use remote shell or default system shell, as appropriate + let shell = project + .update(cx, |project, cx| { + project + .remote_client() + .and_then(|r| r.read(cx).default_system_shell()) + .map(Shell::Program) + })? + .unwrap_or_else(|| Shell::Program(get_default_system_shell_preferring_bash())); + let is_windows = project + .read_with(cx, |project, cx| project.path_style(cx).is_windows()) + .unwrap_or(cfg!(windows)); + let (task_command, task_args) = task::ShellBuilder::new(&shell, is_windows) + .redirect_stdin_to_dev_null() + .build(Some(command.clone()), &args); + + project + .update(cx, |project, cx| { + project.create_terminal_task( + task::SpawnInTerminal { + command: Some(task_command), + args: task_args, + cwd, + env, + ..Default::default() + }, + cx, + ) + })? + .await +} diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 86027d01fe3e93d2f6234cce9e935ebace318481..9e5b6ad66096b784bfb496b71ef1ee5cb30005cb 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -10,6 +10,8 @@ path = "src/agent.rs" [features] test-support = ["db/test-support"] +eval = [] +edit-agent-eval = [] e2e = [] [lints] diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index b3043f0a81256568338f5d4be22bfe02de277076..a39ed21c7cde1304e4e955e20f6011672ee70c3e 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -31,7 +31,7 @@ use std::{ use util::path; #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_extract_handle_command_output() { // Test how well agent generates multiple edit hunks. // @@ -108,7 +108,7 @@ fn eval_extract_handle_command_output() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_delete_run_git_blame() { // Model | Pass rate // ----------------------------|---------- @@ -171,7 +171,7 @@ fn eval_delete_run_git_blame() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_translate_doc_comments() { // Model | Pass rate // ============================================ @@ -234,7 +234,7 @@ fn eval_translate_doc_comments() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { // Model | Pass rate // ============================================ @@ -360,7 +360,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_disable_cursor_blinking() { // Model | Pass rate // ============================================ @@ -446,7 +446,7 @@ fn eval_disable_cursor_blinking() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_from_pixels_constructor() { // Results for 2025-06-13 // @@ -656,7 +656,7 @@ fn eval_from_pixels_constructor() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_zode() { // Model | Pass rate // ============================================ @@ -763,7 +763,7 @@ fn eval_zode() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_add_overwrite_test() { // Model | Pass rate // ============================================ @@ -995,7 +995,7 @@ fn eval_add_overwrite_test() { } #[test] -#[cfg_attr(not(feature = "eval"), ignore)] +#[cfg_attr(not(feature = "edit-agent-eval"), ignore)] fn eval_create_empty_file() { // Check that Edit Agent can create a file without writing its // thoughts into it. This issue is not specific to empty files, but @@ -1490,9 +1490,20 @@ impl EditAgentTest { &std::env::var("ZED_JUDGE_MODEL").unwrap_or("anthropic/claude-4-sonnet-latest".into()), ) .unwrap(); + + let authenticate_provider_tasks = cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .providers() + .iter() + .map(|p| p.authenticate(cx)) + .collect::>() + }) + }); let (agent_model, judge_model) = cx .update(|cx| { cx.spawn(async move |cx| { + futures::future::join_all(authenticate_provider_tasks).await; let agent_model = Self::load_model(&agent_model, cx).await; let judge_model = Self::load_model(&judge_model, cx).await; (agent_model.unwrap(), judge_model.unwrap()) diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 6b7d30b37f825bf664ee270bee9f965ee194291c..66b006893e50b9c59701eff850adb7747f96e3b5 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -1995,7 +1995,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { locations: vec![], raw_input: Some(json!({})), raw_output: None, - meta: None, + meta: Some(json!({ "tool_name": "thinking" })), } ); let update = expect_tool_call_update_fields(&mut events).await; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c89ad1df241c3b9c6e07b9a5433dd964244ba2cb..d873e4f26cb22d34c501b1d4d3ffd3af94465af4 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -745,7 +745,13 @@ impl Thread { let title = tool.initial_title(tool_use.input.clone(), cx); let kind = tool.kind(); - stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + stream.send_tool_call( + &tool_use.id, + &tool_use.name, + title, + kind, + tool_use.input.clone(), + ); let output = tool_result .as_ref() @@ -1044,14 +1050,18 @@ impl Thread { Ok(()) } - pub fn latest_token_usage(&self) -> Option { + pub fn latest_request_token_usage(&self) -> Option { let last_user_message = self.last_user_message()?; let tokens = self.request_token_usage.get(&last_user_message.id)?; - let model = self.model.clone()?; + Some(*tokens) + } + pub fn latest_token_usage(&self) -> Option { + let usage = self.latest_request_token_usage()?; + let model = self.model.clone()?; Some(acp_thread::TokenUsage { max_tokens: model.max_token_count_for_mode(self.completion_mode.into()), - used_tokens: tokens.total_tokens(), + used_tokens: usage.total_tokens(), }) } @@ -1094,6 +1104,14 @@ impl Thread { self.run_turn(cx) } + #[cfg(feature = "eval")] + pub fn proceed( + &mut self, + cx: &mut Context, + ) -> Result>> { + self.run_turn(cx) + } + fn run_turn( &mut self, cx: &mut Context, @@ -1461,7 +1479,13 @@ impl Thread { }); if push_new_tool_use { - event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + event_stream.send_tool_call( + &tool_use.id, + &tool_use.name, + title, + kind, + tool_use.input.clone(), + ); last_message .content .push(AgentMessageContent::ToolUse(tool_use.clone())); @@ -2256,6 +2280,7 @@ impl ThreadEventStream { fn send_tool_call( &self, id: &LanguageModelToolUseId, + tool_name: &str, title: SharedString, kind: acp::ToolKind, input: serde_json::Value, @@ -2263,6 +2288,7 @@ impl ThreadEventStream { self.0 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, + tool_name, title.to_string(), kind, input, @@ -2272,12 +2298,15 @@ impl ThreadEventStream { fn initial_tool_call( id: &LanguageModelToolUseId, + tool_name: &str, title: String, kind: acp::ToolKind, input: serde_json::Value, ) -> acp::ToolCall { acp::ToolCall { - meta: None, + meta: Some(serde_json::json!({ + "tool_name": tool_name + })), id: acp::ToolCallId(id.to_string().into()), title, kind, diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index ad205137a44f3fd7e33e4998c023d552e4007b5c..6f92b958b2d94e48539e34b6a58b4789ea376fb5 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -9,9 +9,7 @@ use futures::io::BufReader; use project::Project; use project::agent_server_store::AgentServerCommand; use serde::Deserialize; -use settings::{Settings as _, SettingsLocation}; -use task::Shell; -use util::{ResultExt as _, get_default_system_shell_preferring_bash}; +use util::ResultExt as _; use std::path::PathBuf; use std::{any::Any, cell::RefCell}; @@ -23,7 +21,7 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntit use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent}; use terminal::TerminalBuilder; -use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings}; +use terminal::terminal_settings::{AlternateScroll, CursorShape}; #[derive(Debug, Error)] #[error("Unsupported version")] @@ -816,62 +814,18 @@ impl acp::Client for ClientDelegate { let thread = self.session_thread(&args.session_id)?; let project = thread.read_with(&self.cx, |thread, _cx| thread.project().clone())?; - let mut env = if let Some(dir) = &args.cwd { - project - .update(&mut self.cx.clone(), |project, cx| { - let worktree = project.find_worktree(dir.as_path(), cx); - let shell = TerminalSettings::get( - worktree.as_ref().map(|(worktree, path)| SettingsLocation { - worktree_id: worktree.read(cx).id(), - path: &path, - }), - cx, - ) - .shell - .clone(); - project.directory_environment(&shell, dir.clone().into(), cx) - })? - .await - .unwrap_or_default() - } else { - Default::default() - }; - // Disables paging for `git` and hopefully other commands - env.insert("PAGER".into(), "".into()); - for var in args.env { - env.insert(var.name, var.value); - } - - // Use remote shell or default system shell, as appropriate - let shell = project - .update(&mut self.cx.clone(), |project, cx| { - project - .remote_client() - .and_then(|r| r.read(cx).default_system_shell()) - .map(Shell::Program) - })? - .unwrap_or_else(|| Shell::Program(get_default_system_shell_preferring_bash())); - let is_windows = project - .read_with(&self.cx, |project, cx| project.path_style(cx).is_windows()) - .unwrap_or(cfg!(windows)); - let (task_command, task_args) = task::ShellBuilder::new(&shell, is_windows) - .redirect_stdin_to_dev_null() - .build(Some(args.command.clone()), &args.args); - - let terminal_entity = project - .update(&mut self.cx.clone(), |project, cx| { - project.create_terminal_task( - task::SpawnInTerminal { - command: Some(task_command), - args: task_args, - cwd: args.cwd.clone(), - env, - ..Default::default() - }, - cx, - ) - })? - .await?; + let terminal_entity = acp_thread::create_terminal_entity( + args.command.clone(), + &args.args, + args.env + .into_iter() + .map(|env| (env.name, env.value)) + .collect(), + args.cwd.clone(), + &project, + &mut self.cx.clone(), + ) + .await?; // Register with renderer let terminal_entity = thread.update(&mut self.cx.clone(), |thread, cx| { diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index 6e1cc7a6f554ac4a8d6c84ecf58f4fcbf8ac1d96..30908be1e2fde15c0c32894b266d971b7f0ca54f 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -19,7 +19,7 @@ path = "src/explorer.rs" [dependencies] acp_thread.workspace = true -agent.workspace = true +agent = { workspace = true, features = ["eval"] } agent-client-protocol.workspace = true agent_settings.workspace = true agent_ui.workspace = true @@ -29,7 +29,6 @@ buffer_diff.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true -cloud_llm_client.workspace = true collections.workspace = true debug_adapter_extension.workspace = true dirs.workspace = true @@ -54,13 +53,13 @@ pretty_assertions.workspace = true project.workspace = true prompt_store.workspace = true regex.workspace = true +rand.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true shellexpand.workspace = true -smol.workspace = true telemetry.workspace = true terminal_view.workspace = true toml.workspace = true diff --git a/crates/eval/runner_settings.json b/crates/eval/runner_settings.json index 53d853023c75e78f19c78f797b5751ff79bf1e44..ea2ccb051164c4a6c40aed9d6607db0a8911c5d6 100644 --- a/crates/eval/runner_settings.json +++ b/crates/eval/runner_settings.json @@ -1,6 +1,5 @@ { - "assistant": { - "always_allow_tool_actions": true, - "version": "2" + "agent": { + "always_allow_tool_actions": true } } diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 3afcc32a930ab32746352e81577d55a25c807cb4..c5b34a63eec33a45e6d1c75e73fa473f845c5e36 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -61,9 +61,22 @@ struct Args { /// Maximum number of examples to run concurrently. #[arg(long, default_value = "4")] concurrency: usize, + /// Output current environment variables as JSON to stdout + #[arg(long, hide = true)] + printenv: bool, } fn main() { + let args = Args::parse(); + + // This prevents errors showing up in the logs, because + // project::environment::load_shell_environment() calls + // std::env::current_exe().unwrap() --printenv + if args.printenv { + util::shell_env::print_env(); + return; + } + dotenvy::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok(); env_logger::init(); @@ -99,7 +112,6 @@ fn main() { let zed_commit_sha = commit_sha_for_path(&root_dir); let zed_branch_name = git_branch_for_path(&root_dir); - let args = Args::parse(); let languages: HashSet = args.languages.into_iter().collect(); let http_client = Arc::new(ReqwestClient::new()); @@ -126,19 +138,20 @@ fn main() { let mut cumulative_tool_metrics = ToolMetrics::default(); - let agent_model = load_model(&args.model, cx).unwrap(); - let judge_model = load_model(&args.judge_model, cx).unwrap(); - - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model(Some(agent_model.clone()), cx); + let tasks = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.providers().iter().map(|p| p.authenticate(cx)).collect::>() }); - let auth1 = agent_model.provider.authenticate(cx); - let auth2 = judge_model.provider.authenticate(cx); - cx.spawn(async move |cx| { - auth1.await?; - auth2.await?; + future::join_all(tasks).await; + let judge_model = cx.update(|cx| { + let agent_model = load_model(&args.model, cx).unwrap(); + let judge_model = load_model(&args.judge_model, cx).unwrap(); + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model(Some(agent_model.clone()), cx); + }); + judge_model + })?; let mut examples = Vec::new(); @@ -268,7 +281,6 @@ fn main() { future::join_all((0..args.concurrency).map(|_| { let app_state = app_state.clone(); - let model = agent_model.model.clone(); let judge_model = judge_model.model.clone(); let zed_commit_sha = zed_commit_sha.clone(); let zed_branch_name = zed_branch_name.clone(); @@ -283,7 +295,7 @@ fn main() { let result = async { example.setup().await?; let run_output = cx - .update(|cx| example.run(model.clone(), app_state.clone(), cx))? + .update(|cx| example.run(app_state.clone(), cx))? .await?; let judge_output = judge_example( example.clone(), @@ -524,7 +536,6 @@ async fn judge_example( diff_evaluation = judge_output.diff.clone(), thread_evaluation = judge_output.thread, tool_metrics = run_output.tool_metrics, - response_count = run_output.response_count, token_usage = run_output.token_usage, model = model.telemetry_id(), model_provider = model.provider_id().to_string(), diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 22a8f9484c9f2c1d4ad01a107841b57e8b96f67b..84c47766e96948bccfc01f3b4472b5100c4b7b64 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -3,6 +3,7 @@ use std::{ fmt::{self, Debug}, sync::{Arc, Mutex}, time::Duration, + u32, }; use crate::{ @@ -16,11 +17,10 @@ use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; use buffer_diff::DiffHunkStatus; -use cloud_llm_client::CompletionIntent; use collections::HashMap; -use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; +use futures::{FutureExt as _, StreamExt, select_biased}; use gpui::{App, AppContext, AsyncApp, Entity}; -use language_model::{LanguageModel, Role, StopReason}; +use language_model::Role; use util::rel_path::RelPath; pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); @@ -93,7 +93,6 @@ pub struct ExampleContext { log_prefix: String, agent_thread: Entity, app: AsyncApp, - model: Arc, pub assertions: AssertionsReport, pub tool_metrics: Arc>, } @@ -103,7 +102,6 @@ impl ExampleContext { meta: ExampleMetadata, log_prefix: String, agent_thread: Entity, - model: Arc, app: AsyncApp, ) -> Self { let assertions = AssertionsReport::new(meta.max_assertions); @@ -113,26 +111,11 @@ impl ExampleContext { log_prefix, agent_thread, assertions, - model, app, tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())), } } - pub fn push_user_message(&mut self, text: impl ToString) { - self.app - .update_entity(&self.agent_thread, |thread, cx| { - thread.insert_user_message( - text.to_string(), - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ); - }) - .unwrap(); - } - pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> { let message = message.to_string(); self.log_assertion( @@ -204,156 +187,174 @@ impl ExampleContext { result } - pub async fn run_to_end(&mut self) -> Result { - self.run_turns(u32::MAX).await + pub async fn prompt(&mut self, prompt: impl Into) -> Result { + self.prompt_with_max_turns(prompt, u32::MAX).await } - pub async fn run_turn(&mut self) -> Result { - self.run_turns(1).await + pub async fn prompt_with_max_turns( + &mut self, + prompt: impl Into, + max_turns: u32, + ) -> Result { + let content = vec![UserMessageContent::Text(prompt.into())]; + self.run_turns(Some(content), max_turns).await } - pub async fn run_turns(&mut self, iterations: u32) -> Result { - let (mut tx, mut rx) = mpsc::channel(1); + pub async fn proceed_with_max_turns(&mut self, max_turns: u32) -> Result { + self.run_turns(None, max_turns).await + } + async fn run_turns( + &mut self, + prompt: Option>, + max_turns: u32, + ) -> Result { let tool_metrics = self.tool_metrics.clone(); let log_prefix = self.log_prefix.clone(); - let _subscription = self.app.subscribe( - &self.agent_thread, - move |thread, event: &ThreadEvent, cx| match event { - ThreadEvent::ShowError(thread_error) => { - tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); - } - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn) => { - tx.close_channel(); + + let mut remaining_turns = max_turns; + + let mut event_stream = self.agent_thread.update(&mut self.app, |thread, cx| { + if let Some(prompt) = prompt { + let id = UserMessageId::new(); + thread.send(id, prompt, cx) + } else { + thread.proceed(cx) + } + })??; + + let task = self.app.background_spawn(async move { + let mut messages = Vec::new(); + let mut tool_uses_by_id = HashMap::default(); + while let Some(event) = event_stream.next().await { + match event? { + ThreadEvent::UserMessage(user_message) => { + messages.push(Message { + role: Role::User, + text: user_message.to_markdown(), + tool_use: Vec::new(), + }); } - Ok(StopReason::ToolUse) => { - if thread.read(cx).remaining_turns() == 0 { - tx.close_channel(); + ThreadEvent::AgentThinking(text) | ThreadEvent::AgentText(text) => { + if matches!( + messages.last(), + Some(Message { + role: Role::Assistant, + .. + }) + ) { + messages.last_mut().unwrap().text.push_str(&text); + } else { + messages.push(Message { + role: Role::Assistant, + text, + tool_use: Vec::new(), + }); } } - Ok(StopReason::MaxTokens) => { - tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok(); - } - Ok(StopReason::Refusal) => { - tx.try_send(Err(anyhow!("Model refused to generate content"))) - .ok(); - } - Err(err) => { - tx.try_send(Err(anyhow!(err.clone()))).ok(); + ThreadEvent::ToolCall(tool_call) => { + let meta = tool_call.meta.expect("Missing meta field in tool_call"); + let tool_name = meta + .get("tool_name") + .expect("Missing tool_name field in meta") + .as_str() + .expect("Unknown tool_name content in meta"); + + tool_uses_by_id.insert( + tool_call.id, + ToolUse { + name: tool_name.to_string(), + value: tool_call.raw_input.unwrap_or_default(), + }, + ); + if matches!( + tool_call.status, + acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed + ) { + panic!("Tool call completed without update"); + } } - }, - ThreadEvent::NewRequest - | ThreadEvent::StreamedAssistantText(_, _) - | ThreadEvent::StreamedAssistantThinking(_, _) - | ThreadEvent::UsePendingTools { .. } - | ThreadEvent::CompletionCanceled => {} - ThreadEvent::ToolUseLimitReached => {} - ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - .. - } => { - thread.update(cx, |thread, _cx| { - if let Some(tool_use) = pending_tool_use { - let mut tool_metrics = tool_metrics.lock().unwrap(); - if let Some(tool_result) = thread.tool_result(tool_use_id) { - let message = if tool_result.is_error { - format!("✖︎ {}", tool_use.name) - } else { + ThreadEvent::ToolCallUpdate(tool_call_update) => { + if let acp_thread::ToolCallUpdate::UpdateFields(update) = tool_call_update { + if let Some(raw_input) = update.fields.raw_input { + if let Some(tool_use) = tool_uses_by_id.get_mut(&update.id) { + tool_use.value = raw_input; + } + } + + if matches!( + update.fields.status, + Some(acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed) + ) { + let succeeded = + update.fields.status == Some(acp::ToolCallStatus::Completed); + + let tool_use = tool_uses_by_id + .remove(&update.id) + .expect("Unrecognized tool call completed"); + + let log_message = if succeeded { format!("✔︎ {}", tool_use.name) + } else { + format!("✖︎ {}", tool_use.name) }; - println!("{log_prefix}{message}"); + println!("{log_prefix}{log_message}"); + tool_metrics - .insert(tool_result.tool_name.clone(), !tool_result.is_error); - } else { - let message = - format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); - println!("{log_prefix}{message}"); - tool_metrics.insert(tool_use.name.clone(), true); + .lock() + .unwrap() + .insert(tool_use.name.clone().into(), succeeded); + + if let Some(message) = messages.last_mut() { + message.tool_use.push(tool_use); + } else { + messages.push(Message { + role: Role::Assistant, + text: "".to_string(), + tool_use: vec![tool_use], + }); + } + + remaining_turns -= 1; + if remaining_turns == 0 { + return Ok(messages); + } } } - }); - } - ThreadEvent::InvalidToolInput { .. } => { - println!("{log_prefix} invalid tool input"); - } - ThreadEvent::MissingToolUse { - tool_use_id: _, - ui_text, - } => { - println!("{log_prefix} {ui_text}"); - } - ThreadEvent::ToolConfirmationNeeded => { - panic!( + } + ThreadEvent::ToolCallAuthorization(_) => panic!( "{}Bug: Tool confirmation should not be required in eval", log_prefix - ); - } - ThreadEvent::StreamedCompletion - | ThreadEvent::MessageAdded(_) - | ThreadEvent::MessageEdited(_) - | ThreadEvent::MessageDeleted(_) - | ThreadEvent::SummaryChanged - | ThreadEvent::SummaryGenerated - | ThreadEvent::ProfileChanged - | ThreadEvent::ReceivedTextChunk - | ThreadEvent::StreamedToolUse { .. } - | ThreadEvent::CheckpointChanged - | ThreadEvent::CancelEditing => { - tx.try_send(Ok(())).ok(); - if std::env::var("ZED_EVAL_DEBUG").is_ok() { - println!("{}Event: {:#?}", log_prefix, event); - } - } - }, - ); - - let model = self.model.clone(); - - let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| { - thread.set_remaining_turns(iterations); - thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx); - thread.messages().len() - })?; - - loop { - select_biased! { - result = rx.next() => { - if let Some(result) = result { - result?; - } else { - break; + ), + ThreadEvent::Retry(status) => { + println!("{log_prefix} Got retry: {status:?}"); } - } - _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { - anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events"); + ThreadEvent::Stop(stop_reason) => match stop_reason { + acp::StopReason::EndTurn => {} + acp::StopReason::MaxTokens => { + return Err(anyhow!("Exceeded maximum tokens")); + } + acp::StopReason::MaxTurnRequests => { + return Err(anyhow!("Exceeded maximum turn requests")); + } + acp::StopReason::Refusal => { + return Err(anyhow!("Refusal")); + } + acp::StopReason::Cancelled => return Err(anyhow!("Cancelled")), + }, } } - } + Ok(messages) + }); - let messages = self.app.read_entity(&self.agent_thread, |thread, cx| { - let mut messages = Vec::new(); - for message in thread.messages().skip(message_count_before) { - messages.push(Message { - _role: message.role, - text: message.to_message_content(), - tool_use: thread - .tool_uses_for_message(message.id, cx) - .into_iter() - .map(|tool_use| ToolUse { - name: tool_use.name.to_string(), - value: tool_use.input, - }) - .collect(), - }); + select_biased! { + result = task.fuse() => { + Ok(Response::new(result?)) } - messages - })?; - - let response = Response::new(messages); - - Ok(response) + _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { + anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events"); + } + } } pub fn edits(&self) -> HashMap, FileEdits> { @@ -488,7 +489,7 @@ impl Response { Self { messages } } - pub fn expect_tool( + pub fn expect_tool_call( &self, tool_name: &'static str, cx: &mut ExampleContext, @@ -505,8 +506,7 @@ impl Response { }) } - #[allow(dead_code)] - pub fn tool_uses(&self) -> impl Iterator { + pub fn tool_calls(&self) -> impl Iterator { self.messages.iter().flat_map(|msg| &msg.tool_use) } @@ -517,7 +517,7 @@ impl Response { #[derive(Debug)] pub struct Message { - _role: Role, + role: Role, text: String, tool_use: Vec, } diff --git a/crates/eval/src/examples/add_arg_to_trait_method.rs b/crates/eval/src/examples/add_arg_to_trait_method.rs index 41fa7c3dc6361c25868e2bbe73b71010b5d07d80..1692932b3304e07ebce261afb75877400e0493f4 100644 --- a/crates/eval/src/examples/add_arg_to_trait_method.rs +++ b/crates/eval/src/examples/add_arg_to_trait_method.rs @@ -27,14 +27,12 @@ impl Example for AddArgToTraitMethod { async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { const FILENAME: &str = "assistant_tool.rs"; - cx.push_user_message(format!( + let _ = cx.prompt(format!( r#" Add a `window: Option` argument to the `Tool::run` trait method in {FILENAME}, and update all the implementations of the trait and call sites accordingly. "# - )); - - let _ = cx.run_to_end().await?; + )).await?; // Adds ignored argument to all but `batch_tool` diff --git a/crates/eval/src/examples/code_block_citations.rs b/crates/eval/src/examples/code_block_citations.rs index 8150d68ac3e54772e35fe52f086fb942d8923ffb..c8ba75e99f019b0b0609743b10573bae712f82cd 100644 --- a/crates/eval/src/examples/code_block_citations.rs +++ b/crates/eval/src/examples/code_block_citations.rs @@ -29,16 +29,19 @@ impl Example for CodeBlockCitations { async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { const FILENAME: &str = "assistant_tool.rs"; - cx.push_user_message(format!( - r#" - Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}. - - Please show each method in a separate code snippet. - "# - )); // Verify that the messages all have the correct formatting. - let texts: Vec = cx.run_to_end().await?.texts().collect(); + let texts: Vec = cx + .prompt(format!( + r#" + Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}. + + Please show each method in a separate code snippet. + "# + )) + .await? + .texts() + .collect(); let closing_fence = format!("\n{FENCE}"); for text in texts.iter() { diff --git a/crates/eval/src/examples/comment_translation.rs b/crates/eval/src/examples/comment_translation.rs index 893166f3f13207e3444cb03bb17b2dea650170e7..421999893a5a39b3d6f61c22d405bf90528758e7 100644 --- a/crates/eval/src/examples/comment_translation.rs +++ b/crates/eval/src/examples/comment_translation.rs @@ -22,30 +22,26 @@ impl Example for CommentTranslation { } async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { - cx.push_user_message(r#" - Edit the following files and translate all their comments to italian, in this exact order: + let response = cx.prompt( + r#" + Edit the following files and translate all their comments to italian, in this exact order: - - font-kit/src/family.rs - - font-kit/src/canvas.rs - - font-kit/src/error.rs - "#); - cx.run_to_end().await?; + - font-kit/src/family.rs + - font-kit/src/canvas.rs + - font-kit/src/error.rs + "# + ).await?; let mut create_or_overwrite_count = 0; - cx.agent_thread().read_with(cx, |thread, cx| { - for message in thread.messages() { - for tool_use in thread.tool_uses_for_message(message.id, cx) { - if tool_use.name == "edit_file" { - let input: EditFileToolInput = serde_json::from_value(tool_use.input)?; - if !matches!(input.mode, EditFileMode::Edit) { - create_or_overwrite_count += 1; - } - } + for tool_call in response.tool_calls() { + if tool_call.name == "edit_file" { + let input = tool_call.parse_input::()?; + if !matches!(input.mode, EditFileMode::Edit) { + create_or_overwrite_count += 1; } } + } - anyhow::Ok(()) - })??; cx.assert_eq(create_or_overwrite_count, 0, "no_creation_or_overwrite")?; Ok(()) diff --git a/crates/eval/src/examples/file_change_notification.rs b/crates/eval/src/examples/file_change_notification.rs index 7879ad6f2ebb782bd4a5620f0fdf562c9aad1360..41ce10cd2240f2e81812a51b2ec581422c102c41 100644 --- a/crates/eval/src/examples/file_change_notification.rs +++ b/crates/eval/src/examples/file_change_notification.rs @@ -48,8 +48,8 @@ impl Example for FileChangeNotificationExample { })?; // Start conversation (specific message is not important) - cx.push_user_message("Find all files in this repo"); - cx.run_turn().await?; + cx.prompt_with_max_turns("Find all files in this repo", 1) + .await?; // Edit the README buffer - the model should get a notification on next turn buffer.update(cx, |buffer, cx| { @@ -58,7 +58,7 @@ impl Example for FileChangeNotificationExample { // Run for some more turns. // The model shouldn't thank us for letting it know about the file change. - cx.run_turns(3).await?; + cx.proceed_with_max_turns(3).await?; Ok(()) } diff --git a/crates/eval/src/examples/file_search.rs b/crates/eval/src/examples/file_search.rs index c893aef14299a6086e8c50072d69b0cbed7e9fde..7de7a07d19184b473fd2cb5ba29b270431b71a4c 100644 --- a/crates/eval/src/examples/file_search.rs +++ b/crates/eval/src/examples/file_search.rs @@ -25,18 +25,19 @@ impl Example for FileSearchExample { async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { const FILENAME: &str = "find_replace_file_tool.rs"; - cx.push_user_message(format!( - r#" + + let prompt = format!( + r#" Look at the `{FILENAME}`. I want to implement a card for it. The card should implement the `Render` trait. The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line. "# - )); + ); - let response = cx.run_turn().await?; - let tool_use = response.expect_tool("find_path", cx)?; + let response = cx.prompt_with_max_turns(prompt, 1).await?; + let tool_use = response.expect_tool_call("find_path", cx)?; let input = tool_use.parse_input::()?; let glob = input.glob; diff --git a/crates/eval/src/examples/grep_params_escapement.rs b/crates/eval/src/examples/grep_params_escapement.rs index face6451572725ed402f23aac7bdc2c70a670b67..57086a1b9bd217e04072754539ddea20aa38c7a8 100644 --- a/crates/eval/src/examples/grep_params_escapement.rs +++ b/crates/eval/src/examples/grep_params_escapement.rs @@ -1,3 +1,4 @@ +use agent::GrepToolInput; use agent_settings::AgentProfileId; use anyhow::Result; use async_trait::async_trait; @@ -35,9 +36,9 @@ impl Example for GrepParamsEscapementExample { } async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { - // cx.push_user_message("How does the precedence/specificity work with Keymap contexts? I am seeing that `MessageEditor > Editor` is lower precendence than `Editor` which is surprising to me, but might be how it works"); - cx.push_user_message("Search for files containing the characters `>` or `<`"); - let response = cx.run_turns(2).await?; + let response = cx + .prompt_with_max_turns("Search for files containing the characters `>` or `<`", 2) + .await?; let grep_input = response .find_tool_call("grep") .and_then(|tool_use| tool_use.parse_input::().ok()); diff --git a/crates/eval/src/examples/mod.rs b/crates/eval/src/examples/mod.rs index afe258aa76b1abb5406ce212af4f223c56cb2020..aec1bce07957fb81c17666b3e64b00a1fa47240f 100644 --- a/crates/eval/src/examples/mod.rs +++ b/crates/eval/src/examples/mod.rs @@ -144,9 +144,8 @@ impl Example for DeclarativeExample { } async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { - cx.push_user_message(&self.prompt); let max_turns = self.metadata.max_turns.unwrap_or(1000); - let _ = cx.run_turns(max_turns).await; + let _ = cx.prompt_with_max_turns(&self.prompt, max_turns).await; Ok(()) } diff --git a/crates/eval/src/examples/overwrite_file.rs b/crates/eval/src/examples/overwrite_file.rs index d4b73aaec4d7d9a18be411ba7d453db9ffcb18a1..a4df1e97a3f4d9c66262f8679d93324e53df9d53 100644 --- a/crates/eval/src/examples/overwrite_file.rs +++ b/crates/eval/src/examples/overwrite_file.rs @@ -1,3 +1,4 @@ +use agent::{EditFileMode, EditFileToolInput}; use agent_settings::AgentProfileId; use anyhow::Result; use async_trait::async_trait; @@ -35,17 +36,14 @@ impl Example for FileOverwriteExample { } async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { - let response = cx.run_turns(1).await?; - let file_overwritten = if let Some(tool_use) = response.find_tool_call("edit_file") { - let input = tool_use.parse_input::()?; - match input.mode { - EditFileMode::Edit => false, - EditFileMode::Create | EditFileMode::Overwrite => { - input.path.ends_with("src/language_model_selector.rs") - } + let response = cx.proceed_with_max_turns(1).await?; + let tool_use = response.expect_tool_call("edit_file", cx)?; + let input = tool_use.parse_input::()?; + let file_overwritten = match input.mode { + EditFileMode::Edit => false, + EditFileMode::Create | EditFileMode::Overwrite => { + input.path.ends_with("src/language_model_selector.rs") } - } else { - false }; cx.assert(!file_overwritten, "File should be edited, not overwritten") diff --git a/crates/eval/src/examples/planets.rs b/crates/eval/src/examples/planets.rs index caa15c728400a82b4223fb9ea8522b0815b36b5a..6b6ca0e3fe75633c49f11f24a24835dc58886a01 100644 --- a/crates/eval/src/examples/planets.rs +++ b/crates/eval/src/examples/planets.rs @@ -23,20 +23,19 @@ impl Example for Planets { } async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { - cx.push_user_message( - r#" + let response = cx + .prompt( + r#" Make a plain JavaScript web page which renders an animated 3D solar system. Let me drag to rotate the camera around. Do not use npm. - "# - .to_string(), - ); - - let response = cx.run_to_end().await?; + "#, + ) + .await?; let mut open_tool_uses = 0; let mut terminal_tool_uses = 0; - for tool_use in response.tool_uses() { + for tool_use in response.tool_calls() { if tool_use.name == OpenTool::name() { open_tool_uses += 1; } else if tool_use.name == TerminalTool::name() { diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index e95264c3c3b726244abe4edb61dee474d3bff51a..5317f100456748616dfec63819bc0373aaceb4c1 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -1,36 +1,38 @@ -use agent::Message; +use agent::ContextServerRegistry; +use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow, bail}; use client::proto::LspWorkProgress; use futures::channel::mpsc; +use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _, future}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; use handlebars::Handlebars; use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelToolResultContent, MessageContent, Role, TokenUsage, + LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelToolResultContent, MessageContent, Role, TokenUsage, }; -use project::lsp_store::OpenLspBufferHandle; -use project::{DiagnosticSummary, Project, ProjectPath}; +use project::{DiagnosticSummary, Project, ProjectPath, lsp_store::OpenLspBufferHandle}; +use prompt_store::{ProjectContext, WorktreeContext}; +use rand::{distr, prelude::*}; use serde::{Deserialize, Serialize}; -use std::cell::RefCell; -use std::fmt::Write as _; -use std::fs; -use std::fs::File; -use std::io::Write as _; -use std::path::Path; -use std::path::PathBuf; -use std::rc::Rc; -use std::sync::Arc; -use std::time::Duration; +use std::{ + fmt::Write as _, + fs::{self, File}, + io::Write as _, + path::{Path, PathBuf}, + rc::Rc, + sync::{Arc, Mutex}, + time::Duration, +}; use unindent::Unindent as _; -use util::ResultExt as _; -use util::command::new_smol_command; -use util::markdown::MarkdownCodeBlock; +use util::{ResultExt as _, command::new_smol_command, markdown::MarkdownCodeBlock}; -use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult}; -use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion}; -use crate::{AgentAppState, ToolMetrics}; +use crate::{ + AgentAppState, ToolMetrics, + assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, + example::{Example, ExampleContext, FailedAssertion, JudgeAssertion}, +}; pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git"; @@ -56,10 +58,9 @@ pub struct RunOutput { pub diagnostic_summary_after: DiagnosticSummary, pub diagnostics_before: Option, pub diagnostics_after: Option, - pub response_count: usize, pub token_usage: TokenUsage, pub tool_metrics: ToolMetrics, - pub all_messages: String, + pub thread_markdown: String, pub programmatic_assertions: AssertionsReport, } @@ -193,12 +194,7 @@ impl ExampleInstance { .join(self.thread.meta().repo_name()) } - pub fn run( - &self, - model: Arc, - app_state: Arc, - cx: &mut App, - ) -> Task> { + pub fn run(&self, app_state: Arc, cx: &mut App) -> Task> { let project = Project::local( app_state.client.clone(), app_state.node_runtime.clone(), @@ -213,15 +209,6 @@ impl ExampleInstance { project.create_worktree(self.worktree_path(), true, cx) }); - let tools = cx.new(|_| ToolWorkingSet::default()); - let prompt_store = None; - let thread_store = ThreadStore::load( - project.clone(), - tools, - prompt_store, - app_state.prompt_builder.clone(), - cx, - ); let meta = self.thread.meta(); let this = self.clone(); @@ -300,74 +287,62 @@ impl ExampleInstance { // history using undo/redo. std::fs::write(&last_diff_file_path, "")?; - let thread_store = thread_store.await?; - + let thread = cx.update(|cx| { + //todo: Do we want to load rules files here? + let worktrees = project.read(cx).visible_worktrees(cx).map(|worktree| { + let root_name = worktree.read(cx).root_name_str().into(); + let abs_path = worktree.read(cx).abs_path(); - let thread = - thread_store.update(cx, |thread_store, cx| { - let thread = if let Some(json) = &meta.existing_thread_json { - let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); - thread_store.create_thread_from_serialized(serialized, cx) - } else { - thread_store.create_thread(cx) - }; - thread.update(cx, |thread, cx| { - thread.set_profile(meta.profile_id.clone(), cx); - }); - thread - })?; - - - thread.update(cx, |thread, _cx| { - let mut request_count = 0; - let previous_diff = Rc::new(RefCell::new("".to_string())); - let example_output_dir = this.run_directory.clone(); - let last_diff_file_path = last_diff_file_path.clone(); - let messages_json_file_path = example_output_dir.join("last.messages.json"); - let this = this.clone(); - thread.set_request_callback(move |request, response_events| { - request_count += 1; - let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md")); - let diff_file_path = example_output_dir.join(format!("{request_count}.diff")); - let last_messages_file_path = example_output_dir.join("last.messages.md"); - let request_markdown = RequestMarkdown::new(request); - let response_events_markdown = response_events_to_markdown(response_events); - let dialog = ThreadDialog::new(request, response_events); - let dialog_json = serde_json::to_string_pretty(&dialog.to_combined_request()).unwrap_or_default(); - - let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown); - fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file"); - fs::write(&last_messages_file_path, messages).expect("failed to write last messages file"); - fs::write(&messages_json_file_path, dialog_json).expect("failed to write last.messages.json"); - - let diff_result = smol::block_on(this.repository_diff()); - match diff_result { - Ok(diff) => { - if diff != previous_diff.borrow().clone() { - fs::write(&diff_file_path, &diff).expect("failed to write diff file"); - fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file"); - *previous_diff.borrow_mut() = diff; - } - } - Err(err) => { - let error_message = format!("{err:?}"); - fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file"); - fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file"); - } + WorktreeContext { + root_name, + abs_path, + rules_file: None, } + }).collect::>(); + let project_context = cx.new(|_cx| ProjectContext::new(worktrees, vec![])); + let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + + let thread = if let Some(json) = &meta.existing_thread_json { + let session_id = acp::SessionId( + rand::rng() + .sample_iter(&distr::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + + let db_thread = agent::DbThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); + cx.new(|cx| agent::Thread::from_db(session_id, db_thread, project.clone(), project_context, context_server_registry, agent::Templates::new(), cx)) + } else { + cx.new(|cx| agent::Thread::new(project.clone(), project_context, context_server_registry, agent::Templates::new(), None, cx)) + }; - if request_count == 1 { - let tools_file_path = example_output_dir.join("tools.md"); - fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file"); - } + thread.update(cx, |thread, cx| { + thread.add_default_tools(Rc::new(EvalThreadEnvironment { + project: project.clone(), + }), cx); + thread.set_profile(meta.profile_id.clone()); + thread.set_model( + LanguageModelInterceptor::new( + LanguageModelRegistry::read_global(cx).default_model().expect("Missing model").model.clone(), + this.run_directory.clone(), + last_diff_file_path.clone(), + this.run_directory.join("last.messages.json"), + this.worktree_path(), + this.repo_url(), + ), + cx, + ); }); - })?; + + thread + }).unwrap(); let mut example_cx = ExampleContext::new( meta.clone(), this.log_prefix.clone(), thread.clone(), - model.clone(), cx.clone(), ); let result = this.thread.conversation(&mut example_cx).await; @@ -380,7 +355,7 @@ impl ExampleInstance { println!("{}Stopped", this.log_prefix); println!("{}Getting repository diff", this.log_prefix); - let repository_diff = this.repository_diff().await?; + let repository_diff = Self::repository_diff(this.worktree_path(), &this.repo_url()).await?; std::fs::write(last_diff_file_path, &repository_diff)?; @@ -415,34 +390,28 @@ impl ExampleInstance { } thread.update(cx, |thread, _cx| { - let response_count = thread - .messages() - .filter(|message| message.role == language_model::Role::Assistant) - .count(); RunOutput { repository_diff, diagnostic_summary_before, diagnostic_summary_after, diagnostics_before, diagnostics_after, - response_count, - token_usage: thread.cumulative_token_usage(), + token_usage: thread.latest_request_token_usage().unwrap(), tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(), - all_messages: messages_to_markdown(thread.messages()), + thread_markdown: thread.to_markdown(), programmatic_assertions: example_cx.assertions, } }) }) } - async fn repository_diff(&self) -> Result { - let worktree_path = self.worktree_path(); - run_git(&worktree_path, &["add", "."]).await?; + async fn repository_diff(repository_path: PathBuf, repository_url: &str) -> Result { + run_git(&repository_path, &["add", "."]).await?; let mut diff_args = vec!["diff", "--staged"]; - if self.thread.meta().url == ZED_REPO_URL { + if repository_url == ZED_REPO_URL { diff_args.push(":(exclude).rules"); } - run_git(&worktree_path, &diff_args).await + run_git(&repository_path, &diff_args).await } pub async fn judge( @@ -542,7 +511,7 @@ impl ExampleInstance { hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt) .unwrap(); - let complete_messages = &run_output.all_messages; + let complete_messages = &run_output.thread_markdown; let to_prompt = |assertion: String| { hbs.render( judge_thread_prompt_name, @@ -634,6 +603,273 @@ impl ExampleInstance { } } +struct EvalThreadEnvironment { + project: Entity, +} + +struct EvalTerminalHandle { + terminal: Entity, +} + +impl agent::TerminalHandle for EvalTerminalHandle { + fn id(&self, cx: &AsyncApp) -> Result { + self.terminal.read_with(cx, |term, _cx| term.id().clone()) + } + + fn wait_for_exit(&self, cx: &AsyncApp) -> Result>> { + self.terminal + .read_with(cx, |term, _cx| term.wait_for_exit()) + } + + fn current_output(&self, cx: &AsyncApp) -> Result { + self.terminal + .read_with(cx, |term, cx| term.current_output(cx)) + } +} + +impl agent::ThreadEnvironment for EvalThreadEnvironment { + fn create_terminal( + &self, + command: String, + cwd: Option, + output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>> { + let project = self.project.clone(); + cx.spawn(async move |cx| { + let language_registry = + project.read_with(cx, |project, _cx| project.languages().clone())?; + let id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into()); + let terminal = + acp_thread::create_terminal_entity(command, &[], vec![], cwd.clone(), &project, cx) + .await?; + let terminal = cx.new(|cx| { + acp_thread::Terminal::new( + id, + "", + cwd, + output_byte_limit.map(|limit| limit as usize), + terminal, + language_registry, + cx, + ) + })?; + Ok(Rc::new(EvalTerminalHandle { terminal }) as Rc) + }) + } +} + +struct LanguageModelInterceptor { + model: Arc, + request_count: Arc>, + previous_diff: Arc>, + example_output_dir: PathBuf, + last_diff_file_path: PathBuf, + messages_json_file_path: PathBuf, + repository_path: PathBuf, + repository_url: String, +} + +impl LanguageModelInterceptor { + fn new( + model: Arc, + example_output_dir: PathBuf, + last_diff_file_path: PathBuf, + messages_json_file_path: PathBuf, + repository_path: PathBuf, + repository_url: String, + ) -> Arc { + Arc::new(Self { + model, + request_count: Arc::new(Mutex::new(0)), + previous_diff: Arc::new(Mutex::new("".to_string())), + example_output_dir, + last_diff_file_path, + messages_json_file_path, + repository_path, + repository_url, + }) + } +} + +impl language_model::LanguageModel for LanguageModelInterceptor { + fn id(&self) -> language_model::LanguageModelId { + self.model.id() + } + + fn name(&self) -> language_model::LanguageModelName { + self.model.name() + } + + fn provider_id(&self) -> language_model::LanguageModelProviderId { + self.model.provider_id() + } + + fn provider_name(&self) -> language_model::LanguageModelProviderName { + self.model.provider_name() + } + + fn telemetry_id(&self) -> String { + self.model.telemetry_id() + } + + fn supports_images(&self) -> bool { + self.model.supports_images() + } + + fn supports_tools(&self) -> bool { + self.model.supports_tools() + } + + fn supports_tool_choice(&self, choice: language_model::LanguageModelToolChoice) -> bool { + self.model.supports_tool_choice(choice) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> future::BoxFuture<'static, Result> { + self.model.count_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> future::BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + language_model::LanguageModelCompletionError, + >, + > { + let stream = self.model.stream_completion(request.clone(), cx); + let request_count = self.request_count.clone(); + let previous_diff = self.previous_diff.clone(); + let example_output_dir = self.example_output_dir.clone(); + let last_diff_file_path = self.last_diff_file_path.clone(); + let messages_json_file_path = self.messages_json_file_path.clone(); + let repository_path = self.repository_path.clone(); + let repository_url = self.repository_url.clone(); + + Box::pin(async move { + let stream = stream.await?; + + let response_events = Arc::new(Mutex::new(Vec::new())); + let request_clone = request.clone(); + + let wrapped_stream = stream.then(move |event| { + let response_events = response_events.clone(); + let request = request_clone.clone(); + let request_count = request_count.clone(); + let previous_diff = previous_diff.clone(); + let example_output_dir = example_output_dir.clone(); + let last_diff_file_path = last_diff_file_path.clone(); + let messages_json_file_path = messages_json_file_path.clone(); + let repository_path = repository_path.clone(); + let repository_url = repository_url.clone(); + + async move { + let event_result = match &event { + Ok(ev) => Ok(ev.clone()), + Err(err) => Err(err.to_string()), + }; + response_events.lock().unwrap().push(event_result); + + let should_execute = matches!( + &event, + Ok(LanguageModelCompletionEvent::Stop { .. }) | Err(_) + ); + + if should_execute { + let current_request_count = { + let mut count = request_count.lock().unwrap(); + *count += 1; + *count + }; + + let messages_file_path = + example_output_dir.join(format!("{current_request_count}.messages.md")); + let diff_file_path = + example_output_dir.join(format!("{current_request_count}.diff")); + let last_messages_file_path = example_output_dir.join("last.messages.md"); + + let collected_events = response_events.lock().unwrap().clone(); + let request_markdown = RequestMarkdown::new(&request); + let response_events_markdown = + response_events_to_markdown(&collected_events); + let dialog = ThreadDialog::new(&request, &collected_events); + let dialog_json = + serde_json::to_string_pretty(&dialog.to_combined_request()) + .unwrap_or_default(); + + let messages = format!( + "{}\n\n{}", + request_markdown.messages, response_events_markdown + ); + fs::write(&messages_file_path, messages.clone()) + .expect("failed to write messages file"); + fs::write(&last_messages_file_path, messages) + .expect("failed to write last messages file"); + fs::write(&messages_json_file_path, dialog_json) + .expect("failed to write last.messages.json"); + + // Get repository diff + let diff_result = + ExampleInstance::repository_diff(repository_path, &repository_url) + .await; + + match diff_result { + Ok(diff) => { + let prev_diff = previous_diff.lock().unwrap().clone(); + if diff != prev_diff { + fs::write(&diff_file_path, &diff) + .expect("failed to write diff file"); + fs::write(&last_diff_file_path, &diff) + .expect("failed to write last diff file"); + *previous_diff.lock().unwrap() = diff; + } + } + Err(err) => { + let error_message = format!("{err:?}"); + fs::write(&diff_file_path, &error_message) + .expect("failed to write diff error to file"); + fs::write(&last_diff_file_path, &error_message) + .expect("failed to write last diff file"); + } + } + + if current_request_count == 1 { + let tools_file_path = example_output_dir.join("tools.md"); + fs::write(tools_file_path, request_markdown.tools) + .expect("failed to write tools file"); + } + } + + event + } + }); + + Ok(Box::pin(wrapped_stream) + as futures::stream::BoxStream< + 'static, + Result< + LanguageModelCompletionEvent, + language_model::LanguageModelCompletionError, + >, + >) + }) + } +} + pub fn wait_for_lang_server( project: &Entity, buffer: &Entity, @@ -825,40 +1061,6 @@ pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { Ok(String::from_utf8(output.stdout)?.trim().to_string()) } -fn messages_to_markdown<'a>(message_iter: impl IntoIterator) -> String { - let mut messages = String::new(); - let mut assistant_message_number: u32 = 1; - - for message in message_iter { - push_role(&message.role, &mut messages, &mut assistant_message_number); - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => { - messages.push_str(text); - messages.push_str("\n\n"); - } - MessageSegment::Thinking { text, signature } => { - messages.push_str("**Thinking**:\n\n"); - if let Some(sig) = signature { - messages.push_str(&format!("Signature: {}\n\n", sig)); - } - messages.push_str(text); - messages.push_str("\n"); - } - MessageSegment::RedactedThinking(items) => { - messages.push_str(&format!( - "**Redacted Thinking**: {} item(s)\n\n", - items.len() - )); - } - } - } - } - - messages -} - fn push_role(role: &Role, buf: &mut String, assistant_message_number: &mut u32) { match role { Role::System => buf.push_str("# ⚙️ SYSTEM\n\n"),