Detailed changes
@@ -4895,6 +4895,7 @@ dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
+ "async-trait",
"async-watch",
"chrono",
"clap",
@@ -4915,13 +4916,14 @@ dependencies = [
"language_models",
"languages",
"node_runtime",
- "parking_lot",
"paths",
"project",
"prompt_store",
+ "regex",
"release_channel",
"reqwest_client",
"serde",
+ "serde_json",
"settings",
"shellexpand 2.1.2",
"smol",
@@ -315,6 +315,7 @@ pub struct Thread {
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
+ remaining_turns: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -368,6 +369,7 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
+ remaining_turns: u32::MAX,
}
}
@@ -442,6 +444,7 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
+ remaining_turns: u32::MAX,
}
}
@@ -522,7 +525,7 @@ impl Thread {
self.messages.iter().find(|message| message.id == id)
}
- pub fn messages(&self) -> impl Iterator<Item = &Message> {
+ pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
self.messages.iter()
}
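
Switching the return type from `impl Iterator` to `impl ExactSizeIterator` lets callers take `.len()` without collecting; the eval harness relies on this when it records `thread.messages().len()` before sending a turn. A minimal standalone sketch of the same signature change, using a stand-in `Message` type:

```rust
// Minimal sketch of the iterator signature change, with a stand-in Message type.
struct Message;

struct Thread {
    messages: Vec<Message>,
}

impl Thread {
    // `impl Iterator` only promises streaming; `impl ExactSizeIterator`
    // additionally exposes `.len()`, which `slice::Iter` already implements.
    fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
        self.messages.iter()
    }
}

fn main() {
    let thread = Thread { messages: vec![Message, Message] };
    // Callers can now take a count up front without collecting,
    // as the eval harness does before sending a turn.
    assert_eq!(thread.messages().len(), 2);
}
```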
@@ -958,7 +961,21 @@ impl Thread {
})
}
+ pub fn remaining_turns(&self) -> u32 {
+ self.remaining_turns
+ }
+
+ pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
+ self.remaining_turns = remaining_turns;
+ }
+
pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
+ if self.remaining_turns == 0 {
+ return;
+ }
+
+ self.remaining_turns -= 1;
+
let mut request = self.to_completion_request(cx);
if model.supports_tools() {
request.tools = {
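
The new turn budget is a plain counter: `send_to_model` refuses to run once it reaches zero and otherwise burns one turn per call, with `u32::MAX` meaning "unlimited". A self-contained sketch of just that gating logic (the real method goes on to build the completion request):

```rust
// Stand-in for the turn-gating logic added to Thread::send_to_model.
struct Thread {
    remaining_turns: u32,
}

impl Thread {
    fn new() -> Self {
        // u32::MAX effectively means "unlimited turns".
        Thread { remaining_turns: u32::MAX }
    }

    fn set_remaining_turns(&mut self, remaining_turns: u32) {
        self.remaining_turns = remaining_turns;
    }

    /// Returns false when the turn budget is exhausted.
    fn send_to_model(&mut self) -> bool {
        if self.remaining_turns == 0 {
            return false;
        }
        self.remaining_turns -= 1;
        // ...build and send the completion request here...
        true
    }
}

fn main() {
    let mut thread = Thread::new();
    thread.set_remaining_turns(1);
    assert!(thread.send_to_model());  // first turn is allowed
    assert!(!thread.send_to_model()); // budget exhausted
}
```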
@@ -56,6 +56,8 @@ use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
+pub use path_search_tool::PathSearchToolInput;
+
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
@@ -9,6 +9,7 @@ agent.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
+async-trait.workspace = true
async-watch.workspace = true
chrono.workspace = true
clap.workspace = true
@@ -29,13 +30,14 @@ language_model.workspace = true
language_models.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
node_runtime.workspace = true
-parking_lot.workspace = true
paths.workspace = true
project.workspace = true
prompt_store.workspace = true
+regex.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
+serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
@@ -45,7 +47,6 @@ unindent.workspace = true
util.workspace = true
uuid = { version = "1.6", features = ["v4"] }
workspace-hack.workspace = true
-
[[bin]]
name = "eval"
path = "src/eval.rs"
@@ -1,3 +0,0 @@
-url = "https://github.com/zed-industries/zed.git"
-revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
-language_extension = "rs"
@@ -1,2 +0,0 @@
-- The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the task we were returning before, and a new `card` field that contains a view for the card.
-- The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged.
@@ -1,3 +0,0 @@
-Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait.
-
-The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.
@@ -1,3 +0,0 @@
-- The first tool call should be to path search including "find_replace_file_tool.rs" in the string. (*Not* grep, for example, or reading the file based on a guess at the path.) This is because we gave the model a filename and it needs to turn that into a real path.
-- After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path.
-- When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information on what path the Render trait might be in.
@@ -0,0 +1,157 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::Write;
+use std::fmt::{self};
+
+#[derive(Default, Debug, Serialize, Deserialize, Clone)]
+pub struct AssertionsReport {
+ pub ran: Vec<RanAssertion>,
+ pub max: Option<usize>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct RanAssertion {
+ pub id: String,
+ pub result: Result<RanAssertionResult, String>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct RanAssertionResult {
+ pub analysis: Option<String>,
+ pub passed: bool,
+}
+
+impl AssertionsReport {
+ pub fn new(max: Option<usize>) -> Self {
+ AssertionsReport {
+ ran: Vec::new(),
+ max,
+ }
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.ran.is_empty()
+ }
+
+ pub fn total_count(&self) -> usize {
+ self.run_count().max(self.max.unwrap_or(0))
+ }
+
+ pub fn run_count(&self) -> usize {
+ self.ran.len()
+ }
+
+ pub fn passed_count(&self) -> usize {
+ self.ran
+ .iter()
+ .filter(|a| a.result.as_ref().map_or(false, |result| result.passed))
+ .count()
+ }
+
+ pub fn passed_percentage(&self) -> f32 {
+ if self.total_count() == 0 {
+ 0.0
+ } else {
+ (self.passed_count() as f32 / self.total_count() as f32) * 100.0
+ }
+ }
+}
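
Note the interplay between `run_count` and `max`: `total_count` takes the larger of the two, so declared-but-never-run assertions still drag `passed_percentage` down. A usage sketch, assuming the types above are in scope:

```rust
// Usage sketch for AssertionsReport; assumes the types above are in scope.
fn main() {
    // Declared max of 4 assertions, but only 2 were actually run.
    let mut report = AssertionsReport::new(Some(4));
    report.ran.push(RanAssertion {
        id: "called `path_search`".to_string(),
        result: Ok(RanAssertionResult { analysis: None, passed: true }),
    });
    report.ran.push(RanAssertion {
        id: "read the file".to_string(),
        result: Ok(RanAssertionResult { analysis: None, passed: false }),
    });

    assert_eq!(report.run_count(), 2);
    // total_count is max(run_count, max), so the 2 assertions that never
    // ran still count against the score: 1 passed out of 4.
    assert_eq!(report.total_count(), 4);
    assert_eq!(report.passed_percentage(), 25.0);
}
```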
+
+const ROUND_WIDTH: usize = "Round".len();
+const ASSERTIONS_WIDTH: usize = 42;
+const RESULTS_WIDTH: usize = 8;
+
+pub fn print_table_header() {
+ println!(
+ "┌─{}─┬─{}─┬─{}─┐",
+ "─".repeat(ROUND_WIDTH),
+ "─".repeat(ASSERTIONS_WIDTH),
+ "─".repeat(RESULTS_WIDTH)
+ );
+
+ println!(
+ "│ {:^ROUND_WIDTH$} │ {:^ASSERTIONS_WIDTH$} │ {:^RESULTS_WIDTH$} │",
+ "Round", "Assertion", "Result"
+ );
+
+ println!(
+ "├─{}─┼─{}─┼─{}─┤",
+ "─".repeat(ROUND_WIDTH),
+ "─".repeat(ASSERTIONS_WIDTH),
+ "─".repeat(RESULTS_WIDTH)
+ )
+}
+
+pub fn display_error_row(f: &mut String, round: usize, error: String) -> fmt::Result {
+ // Merge the assertion and result cells (and the " │ " separator between
+ // them) into one wide cell so the error row lines up with the header.
+ let last_two_columns = ASSERTIONS_WIDTH + 3 + RESULTS_WIDTH;
+ writeln!(
+ f,
+ "│ {:^ROUND_WIDTH$} │ {:<last_two_columns$} │",
+ round,
+ truncate(&error, last_two_columns)
+ )
+}
+
+pub fn display_table_row(f: &mut String, round: usize, assertion: &RanAssertion) -> fmt::Result {
+ let result = match &assertion.result {
+ Ok(result) if result.passed => "\x1b[32m✔︎ Passed\x1b[0m",
+ Ok(_) => "\x1b[31m✗ Failed\x1b[0m",
+ Err(_) => "\x1b[31m💥 Judge Error\x1b[0m",
+ };
+
+ writeln!(
+ f,
+ "│ {:^ROUND_WIDTH$} │ {:<ASSERTIONS_WIDTH$} │ {:>RESULTS_WIDTH$} │",
+ round,
+ truncate(&assertion.id, ASSERTIONS_WIDTH),
+ result
+ )
+}
+
+pub fn print_table_round_summary<'a>(
+ round: &str,
+ reports: impl Iterator<Item = &'a AssertionsReport>,
+) {
+ let mut passed = 0;
+ let mut total = 0;
+ for report in reports {
+ passed += report.passed_count();
+ total += report.total_count();
+ }
+
+ let percentage = if total == 0 {
+ 0.0
+ } else {
+ (passed as f32 / total as f32 * 100.0).floor()
+ };
+ println!(
+ "│ {:^ROUND_WIDTH$} │ {:<ASSERTIONS_WIDTH$} │ {:>RESULTS_WIDTH$} │",
+ round,
+ "total",
+ format!("{}%", percentage)
+ )
+}
+
+pub fn print_table_footer() {
+ println!(
+ "└─{}─┴─{}─┴─{}─┘",
+ "─".repeat(ROUND_WIDTH),
+ "─".repeat(ASSERTIONS_WIDTH),
+ "─".repeat(RESULTS_WIDTH)
+ )
+}
+
+pub fn print_table_divider() {
+ println!(
+ "├─{}─┼─{}─┼─{}─┤",
+ "─".repeat(ROUND_WIDTH),
+ "─".repeat(ASSERTIONS_WIDTH),
+ "─".repeat(RESULTS_WIDTH)
+ )
+}
+
+fn truncate(assertion: &str, max_width: usize) -> String {
+ if assertion.len() <= max_width {
+ assertion.to_string()
+ } else {
+ // Step back until the cut lands on a char boundary, so multi-byte
+ // UTF-8 characters are never split before appending the ellipsis.
+ let mut end_ix = max_width - 1;
+ while !assertion.is_char_boundary(end_ix) {
+ end_ix -= 1;
+ }
+ format!("{}…", &assertion[..end_ix])
+ }
+}
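
Taken together, the table helpers are meant to be called in a fixed order: header, buffered rows, divider, summary, footer. A minimal driver, assuming the items above are in scope:

```rust
// Minimal driver for the table helpers; assumes the items above are in scope.
fn main() -> std::fmt::Result {
    let assertion = RanAssertion {
        id: "called `path_search`".to_string(),
        result: Ok(RanAssertionResult { analysis: None, passed: true }),
    };

    // Rows are accumulated into a String first so the header can be
    // skipped entirely when there are no rows to show.
    let mut rows = String::new();
    display_table_row(&mut rows, 0, &assertion)?;

    print_table_header();
    print!("{rows}");
    print_table_divider();
    let report = AssertionsReport::new(Some(1));
    print_table_round_summary("avg", std::iter::once(&report));
    print_table_footer();
    Ok(())
}
```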
@@ -1,13 +1,16 @@
+mod assertions;
mod example;
+mod examples;
mod ids;
+mod instance;
mod tool_metrics;
-pub(crate) use example::*;
-use parking_lot::Mutex;
+use assertions::display_error_row;
+use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
pub(crate) use tool_metrics::*;
use ::fs::RealFs;
-use anyhow::{Result, anyhow};
+use anyhow::anyhow;
use clap::Parser;
use client::{Client, ProxySettings, UserStore};
use collections::{HashMap, HashSet};
@@ -25,18 +28,20 @@ use prompt_store::PromptBuilder;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
+use std::cell::RefCell;
use std::collections::VecDeque;
use std::env;
use std::path::{Path, PathBuf};
+use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt as _;
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
- /// Runs all examples that contain these substrings. If unspecified, all examples are run.
+ /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")]
- examples: Vec<String>,
+ filter: Vec<String>,
/// Model to use (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model: String,
@@ -66,43 +71,30 @@ fn main() {
.parent()
.unwrap()
.parent()
+ .unwrap()
+ .canonicalize()
.unwrap();
- let eval_crate_dir = root_dir.join("crates/eval");
+ let eval_crate_dir = root_dir.join("crates").join("eval");
let repos_dir = eval_crate_dir.join("repos");
let worktrees_dir = eval_crate_dir.join("worktrees");
- let examples_dir = eval_crate_dir.join("examples");
- let runs_dir = eval_crate_dir.join("runs");
- let run_dir = runs_dir.join(format!("{}", run_timestamp));
+ let examples_dir = eval_crate_dir.join("src").join("examples");
+ let run_dir = eval_crate_dir
+ .join("runs")
+ .join(format!("{}", run_timestamp));
std::fs::create_dir_all(&run_dir).unwrap();
std::fs::create_dir_all(&repos_dir).unwrap();
std::fs::create_dir_all(&worktrees_dir).unwrap();
std::fs::create_dir_all(&examples_dir).unwrap();
std::fs::create_dir_all(&paths::config_dir()).unwrap();
- let zed_commit_sha = commit_sha_for_path(root_dir);
- let zed_branch_name = git_branch_for_path(root_dir);
+ let zed_commit_sha = commit_sha_for_path(&root_dir);
+ let zed_branch_name = git_branch_for_path(&root_dir);
let args = Args::parse();
- let all_available_examples = list_all_examples(&examples_dir).unwrap();
-
- let example_paths = all_available_examples
- .iter()
- .filter_map(|example_path| {
- let name = example_path.file_name()?.to_string_lossy();
- if args.examples.is_empty()
- || args
- .examples
- .iter()
- .any(|name_substring| name.contains(name_substring))
- {
- Some(example_path.clone())
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
+ let languages: HashSet<String> = args.languages.into_iter().collect();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
+ let all_threads = examples::all(&examples_dir);
app.run(move |cx| {
let app_state = init(cx);
@@ -163,28 +155,40 @@ fn main() {
let mut skipped = Vec::new();
- for example_path in &example_paths {
- let example = Example::load_from_directory(
- example_path,
- &run_dir,
- &worktrees_dir,
- &repos_dir,
- )?;
-
- if !example
- .base
- .language_extension
- .as_ref()
- .map_or(false, |lang| args.languages.contains(lang))
+ for thread in all_threads {
+ let meta = thread.meta();
+ if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
{
- skipped.push(example.name);
+ skipped.push(meta.name);
continue;
}
- examples.extend(example.repeat(args.repetitions));
+ if meta.language_server.map_or(false, |language| {
+ !languages.contains(&language.file_extension)
+ }) {
+ skipped.push(meta.name);
+ continue;
+ }
+
+ // TODO: This creates a worktree per repetition. Ideally these examples should
+ // either be run sequentially on the same worktree, or reuse worktrees when there
+ // are more examples to run than the concurrency limit.
+ for repetition_number in 0..args.repetitions {
+ let example_instance = ExampleInstance::new(
+ thread.clone(),
+ &repos_dir,
+ &run_dir,
+ &worktrees_dir,
+ repetition_number,
+ );
+
+ examples.push(example_instance);
+ }
}
- println!("Skipped examples: {}\n", skipped.join(", "));
+ if !skipped.is_empty() {
+ println!("Skipped threads: {}", skipped.join(", "));
+ }
if examples.is_empty() {
eprintln!("Filter matched no examples");
@@ -196,22 +200,23 @@ fn main() {
let max_name_width = examples
.iter()
- .map(|e| e.repetition_name().len())
+ .map(|e| e.worktree_name().len())
.max()
.unwrap_or(0);
- for (i, example) in examples.iter_mut().enumerate() {
+
+ for (i, example_instance) in examples.iter_mut().enumerate() {
let color = COLORS[i % COLORS.len()].to_string();
- example.set_log_prefix_style(&color, max_name_width);
+ example_instance.set_log_prefix_style(&color, max_name_width);
println!(
"{}Logging to: {}",
- example.log_prefix,
- example.run_directory_path().display()
+ example_instance.log_prefix,
+ example_instance.run_directory.display()
);
- let repo_url = example.base.url.clone();
+ let repo_url = example_instance.repo_url();
if repo_urls.insert(repo_url.clone()) {
- let repo_path = example.repo_path.clone();
+ let repo_path = example_instance.repo_path.clone();
if !repo_path.join(".git").is_dir() {
println!(
@@ -251,12 +256,12 @@ fn main() {
future::join_all(clone_tasks).await;
- for example in examples.iter_mut() {
- example.fetch().await?;
+ for example_instance in examples.iter_mut() {
+ example_instance.fetch().await?;
}
- let examples = Arc::new(Mutex::new(VecDeque::from(examples)));
- let results_by_example_name = Arc::new(Mutex::new(HashMap::default()));
+ let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
+ let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone();
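
With the work now on a single-threaded executor, `Arc<Mutex<…>>` becomes `Rc<RefCell<…>>`: each worker borrows the shared queue only long enough to pop one item, so no borrow is held across an await point. A standalone sketch of that pattern, using smol's `LocalExecutor` as a stand-in for the GPUI foreground executor used here:

```rust
// Sketch of the single-threaded work-queue pattern; smol's LocalExecutor
// stands in for the GPUI foreground executor used by the eval binary.
use smol::LocalExecutor;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;

fn main() {
    let executor = LocalExecutor::new();
    let queue = Rc::new(RefCell::new(VecDeque::from(vec![1, 2, 3, 4])));
    let concurrency = 2;

    let workers: Vec<_> = (0..concurrency)
        .map(|_| {
            let queue = queue.clone();
            executor.spawn(async move {
                // Each borrow_mut is scoped to a single pop, so it is
                // released before the task yields at an await point.
                while let Some(item) = queue.borrow_mut().pop_front() {
                    println!("processing example {item}");
                }
            })
        })
        .collect();

    smol::block_on(executor.run(async {
        for worker in workers {
            worker.await;
        }
    }));
}
```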
@@ -268,7 +273,7 @@ fn main() {
let results = results_by_example_name.clone();
cx.spawn(async move |cx| {
loop {
- let Some(mut example) = examples.lock().pop_front() else {
+ let Some(mut example) = examples.borrow_mut().pop_front() else {
break;
};
let result = async {
@@ -291,7 +296,7 @@ fn main() {
}
.await;
results
- .lock()
+ .borrow_mut()
.entry(example.name.clone())
.or_insert(Vec::new())
.push((example.clone(), result));
@@ -300,97 +305,155 @@ fn main() {
}))
.await;
- println!("\n\n");
- print_header("EVAL RESULTS");
+ print_h1("EVAL RESULTS");
let mut diff_scores = Vec::new();
let mut thread_scores = Vec::new();
+ let mut programmatic_scores = Vec::new();
let mut error_count = 0;
- for (example_name, results) in results_by_example_name.lock().iter_mut() {
- print_header(&example_name);
+ for (example_name, results) in results_by_example_name.borrow_mut().iter_mut() {
+ print_h2(&example_name);
results.sort_unstable_by_key(|(example, _)| example.repetition);
let mut example_cumulative_tool_metrics = ToolMetrics::default();
- println!("┌───────┬──────┬────────┐");
- println!("│ Round │ Diff │ Thread │");
- println!("├───────┼──────┼────────┤");
- for (example, result) in results {
- let run_dir_path = example.run_directory_path();
- let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap();
+ let mut table_rows = String::new();
+ for (example, result) in results.iter() {
match result {
Err(err) => {
- println!(
- "|{:^7}│{:^6}│{:^8}│ {:?}{}",
+ display_error_row(
+ &mut table_rows,
example.repetition,
- "N/A",
- "N/A",
- err,
- relative_run_dir_path.display()
- );
+ err.to_string(),
+ )?;
error_count += 1;
}
- Ok((run_output, judge_result)) => {
+ Ok((run_output, judge_output)) => {
cumulative_tool_metrics.merge(&run_output.tool_metrics);
example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
- match judge_result {
- Ok(judge_output) => {
- diff_scores.push(judge_output.diff.score());
- thread_scores.push(judge_output.thread.score());
- println!(
- "|{:^7}│{:^6}│{:^8}│ {}",
+ if run_output.programmatic_assertions.total_count() > 0 {
+ for assertion in &run_output.programmatic_assertions.ran {
+ assertions::display_table_row(
+ &mut table_rows,
+ example.repetition,
+ assertion,
+ )?;
+ }
+
+ programmatic_scores
+ .push(run_output.programmatic_assertions.passed_percentage())
+ }
+
+ if !judge_output.diff.is_empty() {
+ diff_scores.push(judge_output.diff.passed_percentage());
+
+ for assertion in &judge_output.diff.ran {
+ assertions::display_table_row(
+ &mut table_rows,
example.repetition,
- format!("{}%", judge_output.diff.score()),
- format!("{}%", judge_output.thread.score()),
- relative_run_dir_path.display()
- );
+ assertion,
+ )?;
}
- Err(err) => {
- println!(
- "|{:^7}│{:^6}│{:^8}│{:?}│ {}",
+ }
+
+ if !judge_output.thread.is_empty() {
+ thread_scores.push(judge_output.thread.passed_percentage());
+
+ for assertion in &judge_output.thread.ran {
+ assertions::display_table_row(
+ &mut table_rows,
example.repetition,
- "N/A",
- "N/A",
- err,
- relative_run_dir_path.display()
- );
+ assertion,
+ )?;
}
}
}
}
}
- println!("└───────┴──────┴────────┘");
- println!("{}", example_cumulative_tool_metrics);
- }
+ if !table_rows.is_empty() {
+ assertions::print_table_header();
+ print!("{}", table_rows);
+
+ assertions::print_table_divider();
+
+ for (example, result) in results.iter() {
+ if let Ok((run_output, judge_output)) = result {
+ assertions::print_table_round_summary(
+ &example.repetition.to_string(),
+ [
+ &run_output.programmatic_assertions,
+ &judge_output.diff,
+ &judge_output.thread,
+ ]
+ .into_iter(),
+ )
+ }
+ }
- let diff_score_count = diff_scores.len();
- let average_diff_score = diff_scores
- .into_iter()
- .map(|score| score as f32)
- .sum::<f32>()
- / (diff_score_count as f32);
+ assertions::print_table_divider();
+
+ assertions::print_table_round_summary(
+ "avg",
+ results.iter().flat_map(|(_, result)| {
+ result.iter().flat_map(|(run_output, judge_output)| {
+ [
+ &run_output.programmatic_assertions,
+ &judge_output.diff,
+ &judge_output.thread,
+ ]
+ .into_iter()
+ })
+ }),
+ );
+
+ assertions::print_table_footer();
+ }
- if error_count > 0 {
- println!("\n{error_count} examples failed to run!");
+ if !example_cumulative_tool_metrics.is_empty() {
+ println!("{}", &example_cumulative_tool_metrics);
+ }
}
- println!("\nAverage code diff score: {average_diff_score}");
+ if results_by_example_name.borrow().len() > 1 {
+ print_h1("AGGREGATE");
- let thread_score_count = thread_scores.len();
- let average_thread_score = thread_scores
- .into_iter()
- .map(|score| score as f32)
- .sum::<f32>()
- / (thread_score_count as f32);
+ if error_count > 0 {
+ println!("\n{error_count} examples failed to run!");
+ }
- println!("\nAverage thread score: {average_thread_score}");
+ let programmatic_score_count = programmatic_scores.len();
+ if programmatic_score_count > 0 {
+ let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
+ / (programmatic_score_count as f32))
+ .floor();
+ println!("Average programmatic score: {average_programmatic_score}%");
+ }
- print_header("CUMULATIVE TOOL METRICS");
- println!("{}", cumulative_tool_metrics);
+ let diff_score_count = diff_scores.len();
+ if diff_score_count > 0 {
+ let average_diff_score =
+ (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
+ println!("Average diff score: {average_diff_score}%");
+ }
+
+ let thread_score_count = thread_scores.len();
+
+ if thread_score_count > 0 {
+ let average_thread_score = (thread_scores.into_iter().sum::<f32>()
+ / (thread_score_count as f32))
+ .floor();
+ println!("Average thread score: {average_thread_score}%");
+ }
+
+ println!("");
+
+ print_h2("CUMULATIVE TOOL METRICS");
+ println!("{}", cumulative_tool_metrics);
+ }
app_state.client.telemetry().flush_events().await;
@@ -400,20 +463,6 @@ fn main() {
});
}
-fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
- let path = std::fs::canonicalize(examples_dir).unwrap();
- let entries = std::fs::read_dir(path).unwrap();
- let mut result_paths = Vec::new();
- for entry in entries {
- let entry = entry?;
- let path = entry.path();
- if path.is_dir() {
- result_paths.push(path);
- }
- }
- Ok(result_paths)
-}
-
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
@@ -570,7 +619,7 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
}
async fn judge_example(
- example: Example,
+ example: ExampleInstance,
model: Arc<dyn LanguageModel>,
zed_commit_sha: &str,
zed_branch_name: &str,
@@ -578,19 +627,9 @@ async fn judge_example(
run_output: &RunOutput,
enable_telemetry: bool,
cx: &AsyncApp,
-) -> Result<JudgeOutput> {
+) -> JudgeOutput {
let judge_output = example.judge(model.clone(), &run_output, cx).await;
- let diff_evaluation;
- let thread_evaluation;
- if let Ok(output) = judge_output.as_ref() {
- diff_evaluation = Some(output.diff.clone());
- thread_evaluation = Some(output.thread.clone());
- } else {
- diff_evaluation = None;
- thread_evaluation = None;
- }
-
if enable_telemetry {
telemetry::event!(
"Agent Example Evaluated",
@@ -599,15 +638,15 @@ async fn judge_example(
run_id = run_id,
example_name = example.name.clone(),
example_repetition = example.repetition,
- diff_evaluation = diff_evaluation,
- thread_evaluation = thread_evaluation,
+ diff_evaluation = judge_output.diff.clone(),
+ thread_evaluation = judge_output.thread.clone(),
tool_metrics = run_output.tool_metrics,
response_count = run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
- repository_url = example.base.url.clone(),
- repository_revision = example.base.revision.clone(),
+ repository_url = example.repo_url(),
+ repository_revision = example.revision(),
diagnostic_summary_before = run_output.diagnostic_summary_before,
diagnostic_summary_after = run_output.diagnostic_summary_after,
diagnostics_before = run_output.diagnostics_before,
@@ -618,8 +657,16 @@ async fn judge_example(
judge_output
}
-fn print_header(header: &str) {
- println!("\n========================================");
- println!("{:^40}", header);
- println!("========================================\n");
+const HEADER_WIDTH: usize = 65;
+
+fn print_h1(header: &str) {
+ println!("\n\n{:=^HEADER_WIDTH$}", "");
+ println!("{:^HEADER_WIDTH$}", header);
+ println!("{:=^HEADER_WIDTH$}\n", "");
+}
+
+fn print_h2(header: &str) {
+ println!("\n{:-^HEADER_WIDTH$}", "");
+ println!("{:^HEADER_WIDTH$}", header);
+ println!("{:-^HEADER_WIDTH$}\n", "");
}
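
The headers lean on Rust's fill/alignment format syntax: `{:=^HEADER_WIDTH$}` with an empty argument prints `HEADER_WIDTH` copies of the `=` fill character, and `{:^HEADER_WIDTH$}` centers the title. For example:

```rust
// The fill/center format trick used by print_h1 and print_h2.
const HEADER_WIDTH: usize = 65;

fn main() {
    // An empty string centered with '=' as the fill char yields a full rule.
    println!("{:=^HEADER_WIDTH$}", "");
    println!("{:^HEADER_WIDTH$}", "EVAL RESULTS");
    println!("{:=^HEADER_WIDTH$}", "");
}
```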
@@ -1,53 +1,57 @@
-use crate::{AgentAppState, ToolMetrics};
-use agent::{ThreadEvent, ThreadStore};
-use anyhow::{Context as _, Result, anyhow};
-use assistant_tool::ToolWorkingSet;
-use client::proto::LspWorkProgress;
-use futures::channel::mpsc;
-use futures::{FutureExt, StreamExt as _, select_biased};
-use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
-use handlebars::Handlebars;
-use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
-use language_model::{
- LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
- MessageContent, Role, StopReason, TokenUsage,
-};
-use project::{DiagnosticSummary, Project, ProjectPath};
-use serde::{Deserialize, Serialize};
-use std::cell::RefCell;
-use std::fmt::Write as _;
-use std::fs::File;
-use std::io::Write as _;
-use std::rc::Rc;
-use std::sync::{Arc, Mutex};
-use std::time::Duration;
use std::{
- fs,
- path::{Path, PathBuf},
+ error::Error,
+ fmt::{self, Debug},
+ sync::{Arc, Mutex},
+ time::Duration,
};
-use unindent::Unindent as _;
-use util::ResultExt as _;
-use util::command::new_smol_command;
-use util::markdown::MarkdownString;
-use util::serde::default_true;
-const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
+use crate::{
+ ToolMetrics,
+ assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
+};
+use agent::ThreadEvent;
+use anyhow::{Result, anyhow};
+use async_trait::async_trait;
+use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
+use gpui::{AppContext, AsyncApp, Entity};
+use language_model::{LanguageModel, Role, StopReason};
+
+pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
+
+#[async_trait(?Send)]
+pub trait Example {
+ fn meta(&self) -> ExampleMetadata;
+ async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
+ fn diff_assertions(&self) -> Vec<JudgeAssertion> {
+ Vec::new()
+ }
+ fn thread_assertions(&self) -> Vec<JudgeAssertion> {
+ Vec::new()
+ }
+}
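
A new eval is now just a type implementing this trait. A hedged sketch of a minimal implementation (the `FileSummary` example, its metadata values, and the `read_file` tool name are hypothetical; `ExampleContext`, `ExampleMetadata`, and `JudgeAssertion` are defined below in this module, so this is not runnable outside the eval crate):

```rust
// Hypothetical Example implementation; the types it uses are defined
// in this module. Metadata values and tool name are illustrative only.
struct FileSummary;

#[async_trait(?Send)]
impl Example for FileSummary {
    fn meta(&self) -> ExampleMetadata {
        ExampleMetadata {
            name: "file_summary".to_string(),
            url: "https://github.com/zed-industries/zed.git".to_string(),
            revision: "main".to_string(),
            language_server: None,
            max_assertions: Some(2),
        }
    }

    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
        cx.push_user_message("Summarize the README in this repository.");
        // Run a single agentic turn, then make a programmatic assertion.
        let response = cx.run_turn().await?;
        response.expect_tool("read_file", cx)?;
        Ok(())
    }

    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
        vec![JudgeAssertion {
            id: "mentions-readme".to_string(),
            description: "The summary discusses the README contents.".to_string(),
        }]
    }
}
```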
-const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
+#[derive(Clone, Debug)]
+pub struct JudgeAssertion {
+ pub id: String,
+ pub description: String,
+}
-#[derive(Clone, Debug, Deserialize)]
-pub struct ExampleBase {
+#[derive(Clone, Debug)]
+pub struct ExampleMetadata {
+ pub name: String,
pub url: String,
pub revision: String,
- pub language_extension: Option<String>,
- pub insert_id: Option<String>,
- #[serde(default = "default_true")]
- pub require_lsp: bool,
- #[serde(default)]
+ pub language_server: Option<LanguageServer>,
+ pub max_assertions: Option<usize>,
+}
+
+#[derive(Clone, Debug)]
+pub struct LanguageServer {
+ pub file_extension: String,
pub allow_preexisting_diagnostics: bool,
}
-impl ExampleBase {
+impl ExampleMetadata {
pub fn repo_name(&self) -> String {
self.url
.split('/')
@@ -58,1042 +62,310 @@ impl ExampleBase {
}
}
-#[derive(Clone, Debug)]
-pub struct Example {
- pub name: String,
- /// Content of `base.toml`
- pub base: ExampleBase,
- /// Content of `prompt.md`
- pub prompt: String,
- /// Content of `diff_criteria.md`
- pub diff_criteria: String,
- /// Content of `thread_criteria.md`, if that file exists (it's optional)
- pub thread_criteria: String,
- /// Prefix used for logging that identifies this example
- pub log_prefix: String,
- /// The repetition number for this example (0-based)
- /// When running multiple repetitions of the same example, each instance is assigned a unique repetition number.
- /// This affects the worktree path and log prefix to avoid clobbering results between runs.
- pub repetition: usize,
- pub repo_path: PathBuf,
- /// Path to the directory containing the requests and responses for the agentic loop
- run_dir_path: PathBuf,
- worktrees_dir: PathBuf,
-}
+pub struct FailedAssertion(pub String);
-#[derive(Debug, Serialize, Clone)]
-pub struct RunOutput {
- pub repository_diff: String,
- pub ran_diagnostics_check: bool,
- pub diagnostic_summary_before: DiagnosticSummary,
- pub diagnostic_summary_after: DiagnosticSummary,
- pub diagnostics_before: Option<String>,
- pub diagnostics_after: Option<String>,
- pub response_count: usize,
- pub token_usage: TokenUsage,
- pub tool_metrics: ToolMetrics,
- pub last_request: LanguageModelRequest,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct JudgeDiffInput {
- pub repository_diff: String,
- pub criteria: String,
+impl fmt::Debug for FailedAssertion {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Assertion failure: {}", self.0)
+ }
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct JudgeThreadInput {
- pub messages: String,
- pub criteria: String,
+impl fmt::Display for FailedAssertion {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct JudgeResponse {
- pub analysis: String,
- pub passing_criteria: u32,
- pub total_criteria: u32,
-}
+impl Error for FailedAssertion {}
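
Because `FailedAssertion` implements `std::error::Error`, it round-trips through `anyhow::Error` and can be recovered by downcasting. A small standalone demonstration (duplicating the type locally and assuming an `anyhow` dependency):

```rust
use std::{error::Error, fmt};

// Local copy of the FailedAssertion error type for demonstration.
pub struct FailedAssertion(pub String);

impl fmt::Debug for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Assertion failure: {}", self.0)
    }
}

impl fmt::Display for FailedAssertion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Error for FailedAssertion {}

fn main() {
    let err = anyhow::Error::from(FailedAssertion("called `grep` first".to_string()));
    // Downcasting distinguishes assertion failures from other run errors.
    assert!(err.downcast_ref::<FailedAssertion>().is_some());
    println!("{err:#}");
}
```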
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct JudgeOutput {
- pub thread: JudgeResponse,
- pub diff: JudgeResponse,
+pub struct ExampleContext {
+ meta: ExampleMetadata,
+ log_prefix: String,
+ agent_thread: Entity<agent::Thread>,
+ app: AsyncApp,
+ model: Arc<dyn LanguageModel>,
+ pub assertions: AssertionsReport,
+ pub tool_metrics: Arc<Mutex<ToolMetrics>>,
}
-impl Example {
- /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
- pub fn load_from_directory(
- dir_path: &Path,
- run_dir: &Path,
- worktrees_dir: &Path,
- repos_dir: &Path,
- ) -> Result<Self> {
- let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
- let base_path = dir_path.join("base.toml");
- let prompt_path = dir_path.join("prompt.md");
- let diff_criteria_path = dir_path.join("diff_criteria.md");
- let thread_criteria_path = dir_path.join("thread_criteria.md");
- let base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
-
- let repo_path = repo_path_for_url(repos_dir, &base.url);
-
- Ok(Example {
- name: name.clone(),
- base,
- prompt: fs::read_to_string(prompt_path.clone())?,
- thread_criteria: fs::read_to_string(thread_criteria_path.clone())?,
- diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
- run_dir_path: run_dir.to_path_buf(),
- worktrees_dir: worktrees_dir.to_path_buf(),
- repo_path,
- log_prefix: name,
- repetition: 0,
- })
- }
-
- pub fn repetition_name(&self) -> String {
- format!("{}-{}", self.name, self.repetition)
- }
-
- pub fn worktree_path(&self) -> PathBuf {
- self.worktrees_dir
- .canonicalize()
- .unwrap()
- .join(self.repetition_name())
- .join(&self.base.repo_name())
+impl ExampleContext {
+ pub fn new(
+ meta: ExampleMetadata,
+ log_prefix: String,
+ agent_thread: Entity<agent::Thread>,
+ model: Arc<dyn LanguageModel>,
+ app: AsyncApp,
+ ) -> Self {
+ let assertions = AssertionsReport::new(meta.max_assertions);
+
+ Self {
+ meta,
+ log_prefix,
+ agent_thread,
+ assertions,
+ model,
+ app,
+ tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
+ }
}
- pub fn run_directory_path(&self) -> PathBuf {
- self.run_dir_path.join(self.repetition_name())
+ pub fn push_user_message(&mut self, text: impl ToString) {
+ self.app
+ .update_entity(&self.agent_thread, |thread, cx| {
+ thread.insert_user_message(text.to_string(), vec![], None, cx);
+ })
+ .unwrap();
}
- /// Create an iterator that returns copies of this example with different repetition numbers
- /// Each copy will have a different repetition number and worktree path based on the repetition
- pub fn repeat(self, repetitions: usize) -> impl Iterator<Item = Self> {
- (0..repetitions).map(move |repetition| {
- let mut example = self.clone();
- example.repetition = repetition;
- example
- })
+ pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
+ let message = message.to_string();
+ self.log_assertion(
+ if expected {
+ Ok(())
+ } else {
+ Err(anyhow::Error::from(FailedAssertion(message.clone())))
+ },
+ message,
+ )
}
- pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
- self.log_prefix = format!(
- "{}{:<width$}\x1b[0m | ",
- color,
- self.repetition_name(),
- width = name_width
- );
+ pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
+ let message = message.to_string();
+ self.log_assertion(
+ match option {
+ Some(value) => Ok(value),
+ None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
+ },
+ message,
+ )
}
- /// Set up the example by checking out the specified Git revision
- pub async fn fetch(&mut self) -> Result<()> {
- let revision_exists = run_git(
- &self.repo_path,
- &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
+ #[allow(dead_code)]
+ pub fn assert_eq<T: PartialEq + Debug>(
+ &mut self,
+ left: T,
+ right: T,
+ message: impl ToString,
+ ) -> Result<()> {
+ let message = message.to_string();
+ self.log_assertion(
+ if left == right {
+ Ok(())
+ } else {
+ println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
+ Err(anyhow::Error::from(FailedAssertion(message.clone())))
+ },
+ message,
)
- .await
- .is_ok();
+ }
- if !revision_exists {
- println!(
- "{}Fetching revision {}",
- self.log_prefix, &self.base.revision
- );
- run_git(
- &self.repo_path,
- &["fetch", "--depth", "1", "origin", &self.base.revision],
- )
- .await?;
+ fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
+ if let Some(max) = self.meta.max_assertions {
+ if self.assertions.run_count() >= max {
+ return Err(anyhow!(
+ "More assertions were run than the stated max_assertions of {}",
+ max
+ ));
+ }
}
- Ok(())
- }
- /// Set up the example by checking out the specified Git revision
- pub async fn setup(&mut self) -> Result<()> {
- let worktree_path = self.worktree_path();
- if worktree_path.is_dir() {
- println!("{}Resetting existing worktree", self.log_prefix);
+ self.assertions.ran.push(RanAssertion {
+ id: message.clone(),
+ result: Ok(RanAssertionResult {
+ analysis: None,
+ passed: result.is_ok(),
+ }),
+ });
- // TODO: consider including "-x" to remove ignored files. The downside of this is that
- // it will also remove build artifacts, and so prevent incremental reuse there.
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
+ if result.is_ok() {
+ println!("{}✅ {}", self.log_prefix, message);
} else {
- println!("{}Creating worktree", self.log_prefix);
-
- let worktree_path_string = worktree_path.to_string_lossy().to_string();
-
- run_git(
- &self.repo_path,
- &[
- "worktree",
- "add",
- "-f",
- &worktree_path_string,
- &self.base.revision,
- ],
- )
- .await?;
+ println!("{}❌ {}", self.log_prefix, message);
}
- if self.base.url == ZED_REPO_URL {
- std::fs::write(worktree_path.join(".rules"), std::fs::read(".rules")?)?;
- }
-
- std::fs::create_dir_all(self.run_directory_path())?;
-
- Ok(())
+ result
}
- pub fn run(
- &self,
- model: Arc<dyn LanguageModel>,
- app_state: Arc<AgentAppState>,
- cx: &mut App,
- ) -> Task<Result<RunOutput>> {
- let project = Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- );
-
- let worktree = project.update(cx, |project, cx| {
- project.create_worktree(self.worktree_path(), true, cx)
- });
-
- let tools = cx.new(|_| ToolWorkingSet::default());
- let thread_store =
- ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
- let this = self.clone();
-
- cx.spawn(async move |cx| {
- let worktree = worktree.await?;
-
- // Wait for worktree scan to finish before choosing a file to open.
- worktree
- .update(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- let lsp = if this.base.require_lsp {
- let language_extension = this.base.language_extension.as_deref().context(
- "language_extension field is required in base.toml when `require_lsp == true`",
- )?;
-
- // Open a file that matches the language to cause LSP to start.
- let language_file = worktree.read_with(cx, |worktree, _cx| {
- worktree
- .files(false, 0)
- .find_map(|e| {
- if e.path.clone().extension().and_then(|ext| ext.to_str())
- == Some(language_extension)
- {
- Some(ProjectPath {
- worktree_id: worktree.id(),
- path: e.path.clone(),
- })
- } else {
- None
- }
- })
- .context("Failed to find a file for example language")
- })??;
-
- let open_language_file_buffer_task = project.update(cx, |project, cx| {
- project.open_buffer(language_file.clone(), cx)
- })?;
-
- let language_file_buffer = open_language_file_buffer_task.await?;
-
- let lsp_open_handle = project.update(cx, |project, cx| {
- project.register_buffer_with_language_servers(&language_file_buffer, cx)
- })?;
-
- wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
-
- Some((lsp_open_handle, language_file_buffer))
- } else {
- None
- };
-
- let diagnostic_summary_before = project.read_with(cx, |project, cx| {
- project.diagnostic_summary(false, cx)
- })?;
- let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
- if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
- return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
- }
-
- if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
- return Err(anyhow!("Setup only mode"));
- }
-
- let example_output_dir = this.run_directory_path();
- let last_diff_file_path = example_output_dir.join("last.diff");
-
- // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
- // history using undo/redo.
- std::fs::write(&last_diff_file_path, "")?;
-
- let thread_store = thread_store.await?;
- let thread =
- thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
- let last_request = Rc::new(RefCell::new(None));
-
- thread.update(cx, |thread, _cx| {
- let mut request_count = 0;
- let last_request = Rc::clone(&last_request);
- let previous_diff = Rc::new(RefCell::new("".to_string()));
- let example_output_dir = example_output_dir.clone();
- let last_diff_file_path = last_diff_file_path.clone();
- let this = this.clone();
- thread.set_request_callback(move |request, response_events| {
- *last_request.borrow_mut() = Some(request.clone());
+ pub async fn run_to_end(&mut self) -> Result<Response> {
+ self.run_turns(u32::MAX).await
+ }
- request_count += 1;
- let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
- let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
- let last_messages_file_path = example_output_dir.join("last.messages.md");
- let request_markdown = RequestMarkdown::new(request);
- let response_events_markdown = response_events_to_markdown(response_events);
+ pub async fn run_turn(&mut self) -> Result<Response> {
+ self.run_turns(1).await
+ }
- let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
- fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
- fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
+ pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
+ let (mut tx, mut rx) = mpsc::channel(1);
- let diff_result = smol::block_on(this.repository_diff());
- match diff_result {
- Ok(diff) => {
- if diff != previous_diff.borrow().clone() {
- fs::write(&diff_file_path, &diff).expect("failed to write diff file");
- fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
- *previous_diff.borrow_mut() = diff;
- }
- }
- Err(err) => {
- let error_message = format!("{err:?}");
- fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
- fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
+ let tool_metrics = self.tool_metrics.clone();
+ let log_prefix = self.log_prefix.clone();
+ let _subscription = self.app.subscribe(
+ &self.agent_thread,
+ move |thread, event: &ThreadEvent, cx| match event {
+ ThreadEvent::ShowError(thread_error) => {
+ tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
+ }
+ ThreadEvent::Stopped(reason) => match reason {
+ Ok(StopReason::EndTurn) => {
+ tx.close_channel();
+ }
+ Ok(StopReason::ToolUse) => {
+ if thread.read(cx).remaining_turns() == 0 {
+ tx.close_channel();
}
}
-
- if request_count == 1 {
- let tools_file_path = example_output_dir.join("tools.md");
- fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
+ Ok(StopReason::MaxTokens) => {
+ tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
}
- });
- })?;
-
- let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default()));
-
- let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
-
- let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
- thread_event_tx.unbounded_send(event.clone()).log_err();
- });
-
- let event_handler_task = cx.spawn({
- let log_prefix = this.log_prefix.clone();
- let tool_metrics = tool_metrics.clone();
- let thread = thread.downgrade();
- async move |cx| {
- loop {
- let event = select_biased! {
- event = thread_event_rx.next() => event,
- _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
- return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
- }
- };
- let Some(event) = event else {
- return Err(anyhow!("ThreadEvent channel ended early"));
- };
-
- match event {
- ThreadEvent::Stopped(reason) => match reason {
- Ok(StopReason::EndTurn) => {
- return Ok(());
- }
- Ok(StopReason::MaxTokens) => {
- return Err(anyhow!("Exceeded maximum tokens"));
- }
- Ok(StopReason::ToolUse) => {
- if std::env::var("ZED_EVAL_DEBUG").is_ok() {
- println!("{}StopReason: Tool use", log_prefix);
- }
- }
- Err(error) => {
- return Err(anyhow!(error.clone()));
- }
- },
- ThreadEvent::ShowError(thread_error) => {
- break Err(anyhow!(thread_error.clone()));
- }
- ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
- }
- ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- ..
- } => {
- thread.update(cx, |thread, _cx| {
- if let Some(tool_use) = pending_tool_use {
- let mut tool_metrics = tool_metrics.lock().unwrap();
- if let Some(tool_result) = thread.tool_result(&tool_use_id) {
- let message = if tool_result.is_error {
- format!("TOOL FAILED: {}", tool_use.name)
- } else {
- format!("TOOL FINISHED: {}", tool_use.name)
- };
- println!("{log_prefix}{message}");
- tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error);
- } else {
- let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
- println!("{log_prefix}{message}");
- tool_metrics.insert(tool_use.name.clone(), true);
- }
- }
- })?;
- }
- ThreadEvent::ToolConfirmationNeeded => {
- panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
- },
- ThreadEvent::StreamedToolUse { .. } |
- ThreadEvent::StreamedCompletion |
- ThreadEvent::MessageAdded(_) |
- ThreadEvent::MessageEdited(_) |
- ThreadEvent::MessageDeleted(_) |
- ThreadEvent::SummaryChanged |
- ThreadEvent::SummaryGenerated |
- ThreadEvent::CheckpointChanged |
- ThreadEvent::ReceivedTextChunk |
- ThreadEvent::UsageUpdated(_) => {
- if std::env::var("ZED_EVAL_DEBUG").is_ok() {
- println!("{}Event: {:#?}", log_prefix, event);
- }
+ Err(err) => {
+ tx.try_send(Err(anyhow!(err.clone()))).ok();
+ }
+ },
+ ThreadEvent::StreamedAssistantText(_, _)
+ | ThreadEvent::StreamedAssistantThinking(_, _)
+ | ThreadEvent::UsePendingTools { .. } => {}
+ ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ ..
+ } => {
+ thread.update(cx, |thread, _cx| {
+ if let Some(tool_use) = pending_tool_use {
+ let mut tool_metrics = tool_metrics.lock().unwrap();
+ if let Some(tool_result) = thread.tool_result(&tool_use_id) {
+ let message = if tool_result.is_error {
+ format!("TOOL FAILED: {}", tool_use.name)
+ } else {
+ format!("TOOL FINISHED: {}", tool_use.name)
+ };
+ println!("{log_prefix}{message}");
+ tool_metrics
+ .insert(tool_result.tool_name.clone(), !tool_result.is_error);
+ } else {
+ let message =
+ format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
+ println!("{log_prefix}{message}");
+ tool_metrics.insert(tool_use.name.clone(), true);
}
}
- }
+ });
}
- });
-
- thread.update(cx, |thread, cx| {
- let context = vec![];
- thread.insert_user_message(this.prompt.clone(), context, None, cx);
- thread.send_to_model(model, cx);
- })?;
-
- event_handler_task.await?;
-
- println!("{}Stopped", this.log_prefix);
-
- if let Some((_, language_file_buffer)) = lsp.as_ref() {
- wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
- }
-
- println!("{}Getting repository diff", this.log_prefix);
- let repository_diff = this.repository_diff().await?;
- std::fs::write(last_diff_file_path, &repository_diff)?;
-
- println!("{}Getting diagnostics", this.log_prefix);
- let diagnostic_summary_after = project.read_with(cx, |project, cx| {
- project.diagnostic_summary(false, cx)
- })?;
- let diagnostics_after = cx
- .update(move |cx| {
- cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
- })?
- .await?;
- println!("{}Got diagnostics", this.log_prefix);
-
- let Some(last_request) = last_request.borrow_mut().take() else {
- return Err(anyhow!("No requests ran."));
- };
-
- drop(subscription);
- drop(lsp);
-
- if let Some(diagnostics_before) = &diagnostics_before {
- fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
- }
-
- if let Some(diagnostics_after) = &diagnostics_after {
- fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
- }
-
-
- thread.update(cx, |thread, _cx| {
- let response_count = thread
- .messages()
- .filter(|message| message.role == language_model::Role::Assistant)
- .count();
- RunOutput {
- repository_diff,
- ran_diagnostics_check: this.base.require_lsp,
- diagnostic_summary_before,
- diagnostic_summary_after,
- diagnostics_before,
- diagnostics_after,
- response_count,
- token_usage: thread.cumulative_token_usage(),
- tool_metrics: tool_metrics.lock().unwrap().clone(),
- last_request,
+ ThreadEvent::ToolConfirmationNeeded => {
+ panic!(
+ "{}Bug: Tool confirmation should not be required in eval",
+ log_prefix
+ );
+ }
+ ThreadEvent::StreamedCompletion
+ | ThreadEvent::MessageAdded(_)
+ | ThreadEvent::MessageEdited(_)
+ | ThreadEvent::MessageDeleted(_)
+ | ThreadEvent::SummaryChanged
+ | ThreadEvent::SummaryGenerated
+ | ThreadEvent::ReceivedTextChunk
+ | ThreadEvent::StreamedToolUse { .. }
+ | ThreadEvent::CheckpointChanged
+ | ThreadEvent::UsageUpdated(_) => {
+ tx.try_send(Ok(())).ok();
+ if std::env::var("ZED_EVAL_DEBUG").is_ok() {
+ println!("{}Event: {:#?}", log_prefix, event);
+ }
}
- })
- })
- }
-
- async fn judge_diff(
- &self,
- model: Arc<dyn LanguageModel>,
- run_output: &RunOutput,
- cx: &AsyncApp,
- ) -> Result<(String, JudgeResponse)> {
- let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
- let judge_diff_prompt_name = "judge_diff_prompt";
- let mut hbs = Handlebars::new();
- hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
-
- let diff_prompt = hbs.render(
- judge_diff_prompt_name,
- &JudgeDiffInput {
- repository_diff: run_output.repository_diff.clone(),
- criteria: self.diff_criteria.clone(),
- },
- )?;
-
- let request = LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::Text(diff_prompt)],
- cache: false,
- }],
- temperature: None,
- tools: Vec::new(),
- stop: Vec::new(),
- };
-
- let diff_response = send_language_model_request(model, request, cx).await?;
- let diff_output = JudgeResponse::parse(&diff_response)?;
-
- println!(
- "{}Judge - Diff score: {}%",
- self.log_prefix,
- diff_output.score()
- );
-
- Ok((diff_response, diff_output))
- }
-
- async fn judge_thread(
- &self,
- model: Arc<dyn LanguageModel>,
- run_output: &RunOutput,
- cx: &AsyncApp,
- ) -> Result<(String, JudgeResponse)> {
- let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
- let judge_thread_prompt_name = "judge_thread_prompt";
- let mut hbs = Handlebars::new();
- hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
-
- let request_markdown = RequestMarkdown::new(&run_output.last_request);
- let thread_prompt = hbs.render(
- judge_thread_prompt_name,
- &JudgeThreadInput {
- messages: request_markdown.messages,
- criteria: self.thread_criteria.clone(),
},
- )?;
-
- let request = LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::Text(thread_prompt)],
- cache: false,
- }],
- temperature: None,
- tools: Vec::new(),
- stop: Vec::new(),
- };
-
- let thread_response = send_language_model_request(model, request, cx).await?;
- let thread_output = JudgeResponse::parse(&thread_response)?;
-
- println!(
- "{}Judge - Thread score: {}%",
- self.log_prefix,
- thread_output.score()
);
- Ok((thread_response, thread_output))
- }
-
- pub async fn judge(
- &self,
- model: Arc<dyn LanguageModel>,
- run_output: &RunOutput,
- cx: &AsyncApp,
- ) -> Result<JudgeOutput> {
- let mut output_file = File::create(self.run_directory_path().join("judge.md"))
- .expect("failed to create judge.md");
-
- println!("{}Running judge", self.log_prefix);
-
- let diff_task = self.judge_diff(model.clone(), &run_output, cx);
- let thread_task = self.judge_thread(model.clone(), &run_output, cx);
-
- let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
-
- let (diff_response, diff_output) = diff_result?;
- let (thread_response, thread_output) = thread_result?;
-
- writeln!(
- &mut output_file,
- "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
- )
- .log_err();
-
- Ok(JudgeOutput {
- thread: thread_output,
- diff: diff_output,
- })
- }
-
- async fn repository_diff(&self) -> Result<String> {
- let worktree_path = self.worktree_path();
-
- run_git(&worktree_path, &["add", "."]).await?;
- let mut diff_args = vec!["diff", "--staged"];
- if self.base.url == ZED_REPO_URL {
- diff_args.push(":(exclude).rules");
- }
- run_git(&worktree_path, &diff_args).await
- }
-}
-
-fn wait_for_lang_server(
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- log_prefix: String,
- cx: &mut AsyncApp,
-) -> Task<Result<()>> {
- println!("{}⏵ Waiting for language server", log_prefix);
-
- let (mut tx, mut rx) = mpsc::channel(1);
-
- let lsp_store = project
- .update(cx, |project, _| project.lsp_store())
- .unwrap();
-
- let has_lang_server = buffer
- .update(cx, |buffer, cx| {
- lsp_store.update(cx, |lsp_store, cx| {
- lsp_store
- .language_servers_for_local_buffer(&buffer, cx)
- .next()
- .is_some()
- })
- })
- .unwrap_or(false);
-
- if has_lang_server {
- project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .unwrap()
- .detach();
- }
-
- let subscriptions =
- [
- cx.subscribe(&lsp_store, {
- let log_prefix = log_prefix.clone();
- move |_, event, _| match event {
- project::LspStoreEvent::LanguageServerUpdate {
- message:
- client::proto::update_language_server::Variant::WorkProgress(
- LspWorkProgress {
- message: Some(message),
- ..
- },
- ),
- ..
- } => println!("{}⟲ {message}", log_prefix),
- _ => {}
- }
- }),
- cx.subscribe(&project, {
- let buffer = buffer.clone();
- move |project, event, cx| match event {
- project::Event::LanguageServerAdded(_, _, _) => {
- let buffer = buffer.clone();
- project
- .update(cx, |project, cx| project.save_buffer(buffer, cx))
- .detach();
+ let model = self.model.clone();
+
+ let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
+ thread.set_remaining_turns(iterations);
+ thread.send_to_model(model, cx);
+ thread.messages().len()
+ })?;
+
+ loop {
+ select_biased! {
+ result = rx.next() => {
+ if let Some(result) = result {
+ result?;
+ } else {
+ break;
}
- project::Event::DiskBasedDiagnosticsFinished { .. } => {
- tx.try_send(()).ok();
- }
- _ => {}
}
- }),
- ];
-
- cx.spawn(async move |cx| {
- let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
- let result = futures::select! {
- _ = rx.next() => {
- println!("{}⚑ Language server idle", log_prefix);
- anyhow::Ok(())
- },
- _ = timeout.fuse() => {
- Err(anyhow!("LSP wait timed out after 5 minutes"))
+ _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
+ return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
+ }
}
- };
- drop(subscriptions);
- result
- })
-}
-
-async fn query_lsp_diagnostics(
- project: Entity<Project>,
- cx: &mut AsyncApp,
-) -> Result<Option<String>> {
- let paths_with_diagnostics = project.update(cx, |project, cx| {
- project
- .diagnostic_summaries(true, cx)
- .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
- .map(|(project_path, _, _)| project_path)
- .collect::<Vec<_>>()
- })?;
-
- if paths_with_diagnostics.is_empty() {
- return Ok(None);
- }
-
- let mut output = String::new();
- for project_path in paths_with_diagnostics {
- let buffer = project
- .update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?;
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- for (_, group) in snapshot.diagnostic_groups(None) {
- let entry = &group.entries[group.primary_ix];
- let range = entry.range.to_point(&snapshot);
- let severity = match entry.diagnostic.severity {
- DiagnosticSeverity::ERROR => "error",
- DiagnosticSeverity::WARNING => "warning",
- _ => continue,
- };
-
- writeln!(
- output,
- "{} at line {}: {}",
- severity,
- range.start.row + 1,
- entry.diagnostic.message
- )?;
}
- }
- anyhow::Ok(Some(output))
-}
-
-impl JudgeResponse {
- fn parse(response: &str) -> Result<Self> {
- let analysis = get_tag("analysis", response)?.to_string();
- let passing_criteria = get_tag("passing_criteria", response)?
- .parse()
- .context("error parsing score")?;
- let total_criteria = get_tag("total_criteria", response)?
- .parse()
- .context("error parsing score")?;
- Ok(Self {
- analysis,
- total_criteria,
- passing_criteria,
- })
- }
-
- pub fn score(&self) -> u32 {
- (100.0 * self.passing_criteria as f32 / self.total_criteria as f32).round() as u32
- }
-}
-fn get_tag(name: &'static str, response: &str) -> Result<String> {
- let start_tag = format!("<{}>", name);
- let end_tag = format!("</{}>", name);
-
- let start_ix = response
- .find(&start_tag)
- .context(format!("{} start tag not found", name))?;
- let content_start_ix = start_ix + start_tag.len();
-
- let end_ix = content_start_ix
- + response[content_start_ix..]
- .find(&end_tag)
- .context(format!("{} end tag not found", name))?;
+ let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
+ let mut messages = Vec::new();
+ for message in thread.messages().skip(message_count_before) {
+ messages.push(Message {
+ _role: message.role,
+ _text: message.to_string(),
+ tool_use: thread
+ .tool_uses_for_message(message.id, cx)
+ .into_iter()
+ .map(|tool_use| ToolUse {
+ name: tool_use.name.to_string(),
+ value: tool_use.input,
+ })
+ .collect(),
+ });
+ }
+ messages
+ })?;
- let content = response[content_start_ix..end_ix].trim().unindent();
+ let response = Response::new(messages);
- anyhow::Ok(content)
+ Ok(response)
+ }
}
-pub fn repo_path_for_url(repos_dir: &Path, repo_url: &str) -> PathBuf {
- let repo_name = repo_url
- .trim_start_matches("https://")
- .replace(|c: char| !c.is_alphanumeric(), "-");
- Path::new(repos_dir)
- .canonicalize()
- .context(format!("No such directory {}", repos_dir.display()))
- .unwrap()
- .join(repo_name)
+#[derive(Debug)]
+pub struct Response {
+ messages: Vec<Message>,
}
-pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
- let output = new_smol_command("git")
- .current_dir(repo_path)
- .args(args)
- .output()
- .await?;
-
- if output.status.success() {
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
- } else {
- Err(anyhow!(
- "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
- args.join(" "),
- repo_path.display(),
- output.status,
- String::from_utf8_lossy(&output.stderr),
- String::from_utf8_lossy(&output.stdout),
- ))
+impl Response {
+ pub fn new(messages: Vec<Message>) -> Self {
+ Self { messages }
}
-}
-pub async fn send_language_model_request(
- model: Arc<dyn LanguageModel>,
- request: LanguageModelRequest,
- cx: &AsyncApp,
-) -> anyhow::Result<String> {
- match model.stream_completion_text(request, &cx).await {
- Ok(mut stream) => {
- let mut full_response = String::new();
- while let Some(chunk_result) = stream.stream.next().await {
- match chunk_result {
- Ok(chunk_str) => {
- full_response.push_str(&chunk_str);
- }
- Err(err) => {
- return Err(anyhow!(
- "Error receiving response from language model: {err}"
- ));
- }
- }
- }
- Ok(full_response)
- }
- Err(err) => Err(anyhow!(
- "Failed to get response from language model. Error was: {err}"
- )),
+ pub fn expect_tool(
+ &self,
+ tool_name: &'static str,
+ cx: &mut ExampleContext,
+ ) -> Result<&ToolUse> {
+ let result = self.messages.iter().find_map(|msg| {
+ msg.tool_use
+ .iter()
+ .find(|tool_use| tool_use.name == tool_name)
+ });
+ cx.assert_some(result, format!("called `{}`", tool_name))
}
}
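
Inside a `conversation`, the `Response` returned by `run_turn` can then be asserted on directly. A hedged usage sketch (it assumes `anyhow::Result` is in scope, that `ToolUse::value` is a `serde_json::Value`, and that `PathSearchToolInput` — re-exported from `assistant_tools` above — derives `Deserialize`; the `path` field name is illustrative):

```rust
// Hedged usage sketch from inside an Example::conversation implementation;
// tool name, input type shape, and field names are assumptions.
async fn conversation_snippet(cx: &mut ExampleContext) -> Result<()> {
    cx.push_user_message("Find and open find_replace_file_tool.rs");

    // Run exactly one agentic turn, then inspect the tools it called.
    let response = cx.run_turn().await?;
    let tool_use = response.expect_tool("path_search", cx)?;

    // Tool inputs are stored as JSON; deserialize to assert on their fields.
    let input: PathSearchToolInput = serde_json::from_value(tool_use.value.clone())?;
    cx.assert(
        input.path.contains("find_replace_file_tool.rs"),
        "path search includes the file name",
    )?;
    Ok(())
}
```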
-struct RequestMarkdown {
- tools: String,
- messages: String,
+#[derive(Debug)]
+pub struct Message {
+ _role: Role,
+ _text: String,
+ tool_use: Vec<ToolUse>,
}
-impl RequestMarkdown {
- fn new(request: &LanguageModelRequest) -> Self {
- let mut tools = String::new();
- let mut messages = String::new();
- let mut assistant_message_number: u32 = 1;
-
- // Print the tools
- if !request.tools.is_empty() {
- for tool in &request.tools {
- write!(&mut tools, "# {}\n\n", tool.name).unwrap();
- write!(&mut tools, "{}\n\n", tool.description).unwrap();
- write!(
- &mut tools,
- "{}\n",
- MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
- )
- .unwrap();
- }
- }
-
- // Print the messages
- for message in &request.messages {
- match message.role {
- Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
- Role::User => messages.push_str("# 👤 USER\n\n"),
- Role::Assistant => {
- messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
- assistant_message_number += 1;
- }
- };
-
- for content in &message.content {
- match content {
- MessageContent::Text(text) => {
- messages.push_str(text);
- messages.push_str("\n\n");
- }
- MessageContent::Image(_) => {
- messages.push_str("[IMAGE DATA]\n\n");
- }
- MessageContent::Thinking { text, signature } => {
- messages.push_str("**Thinking**:\n\n");
- if let Some(sig) = signature {
- messages.push_str(&format!("Signature: {}\n\n", sig));
- }
- messages.push_str(text);
- messages.push_str("\n");
- }
- MessageContent::RedactedThinking(items) => {
- messages.push_str(&format!(
- "**Redacted Thinking**: {} item(s)\n\n",
- items.len()
- ));
- }
- MessageContent::ToolUse(tool_use) => {
- messages.push_str(&format!(
- "**Tool Use**: {} (ID: {})\n",
- tool_use.name, tool_use.id
- ));
- messages.push_str(&format!(
- "{}\n",
- MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
- ));
- }
- MessageContent::ToolResult(tool_result) => {
- messages.push_str(&format!(
- "**Tool Result**: {} (ID: {})\n\n",
- tool_result.tool_name, tool_result.tool_use_id
- ));
- if tool_result.is_error {
- messages.push_str("**ERROR:**\n");
- }
- messages.push_str(&format!("{}\n\n", tool_result.content));
- }
- }
- }
- }
-
- Self { tools, messages }
- }
-}
-
-fn response_events_to_markdown(
- response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
-) -> String {
- let mut response = String::new();
- // Print the response events if any
- response.push_str("# Response\n\n");
- let mut text_buffer = String::new();
- let mut thinking_buffer = String::new();
-
- let flush_buffers =
- |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
- if !text_buffer.is_empty() {
- output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
- text_buffer.clear();
- }
- if !thinking_buffer.is_empty() {
- output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
- thinking_buffer.clear();
- }
- };
-
- for event in response_events {
- match event {
- Ok(LanguageModelCompletionEvent::Text(text)) => {
- text_buffer.push_str(text);
- }
- Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
- thinking_buffer.push_str(text);
- }
- Ok(LanguageModelCompletionEvent::Stop(reason)) => {
- flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
- response.push_str(&format!("**Stop**: {:?}\n\n", reason));
- }
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
- flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
- response.push_str(&format!(
- "**Tool Use**: {} (ID: {})\n",
- tool_use.name, tool_use.id
- ));
- response.push_str(&format!(
- "{}\n",
- MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
- ));
- }
- Ok(
- LanguageModelCompletionEvent::UsageUpdate(_)
- | LanguageModelCompletionEvent::StartMessage { .. },
- ) => {}
- Err(error) => {
- flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
- response.push_str(&format!("**Error**: {}\n\n", error));
- }
- }
- }
-
- flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
-
- response
+#[derive(Debug)]
+pub struct ToolUse {
+ name: String,
+ value: serde_json::Value,
}
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_parse_judge_output() {
- let response = r#"
- <analysis>The model did a good job but there were still compilations errors.</analysis>
- <passing_criteria>3</passing_criteria>
- <total_criteria>5</total_criteria>
- "#
- .unindent();
-
- let output = JudgeResponse::parse(&response).unwrap();
- assert_eq!(
- output.analysis,
- "The model did a good job but there were still compilations errors."
- );
- assert_eq!(output.passing_criteria, 3);
- assert_eq!(output.total_criteria, 5);
-
- let response = r#"
- Text around ignored
-
- <analysis>
- Failed to compile:
- - Error 1
- - Error 2
- </analysis>
-
- <passing_criteria>1</passing_criteria>
-
- <total_criteria>3</total_criteria>
- "#
- .unindent();
-
- let output = JudgeResponse::parse(&response).unwrap();
- assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
- assert_eq!(output.passing_criteria, 1);
- assert_eq!(output.total_criteria, 3);
+impl ToolUse {
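+    /// Deserializes the recorded tool input as `Input`, logging the result as
+    /// an assertion. Illustrative usage, mirroring the `file_search` example:
+    ///
+    /// ```ignore
+    /// let tool_use = response.expect_tool("path_search", cx)?;
+    /// let input = tool_use.expect_input::<PathSearchToolInput>(cx)?;
+    /// ```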
+ pub fn expect_input<Input>(&self, cx: &mut ExampleContext) -> Result<Input>
+ where
+ Input: for<'de> serde::Deserialize<'de>,
+ {
+ let result =
+ serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err));
+ cx.log_assertion(result, format!("valid `{}` input", &self.name))
}
}
@@ -0,0 +1,53 @@
+use anyhow::Result;
+use assistant_tools::PathSearchToolInput;
+use async_trait::async_trait;
+use regex::Regex;
+
+use crate::example::{Example, ExampleContext, ExampleMetadata};
+
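+/// Programmatic example: given only a filename, the agent should resolve it to
+/// a real path via the `path_search` tool before doing anything else.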
+pub struct FileSearchExample;
+
+#[async_trait(?Send)]
+impl Example for FileSearchExample {
+ fn meta(&self) -> ExampleMetadata {
+ ExampleMetadata {
+ name: "file_search".to_string(),
+ url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "03ecb88fe30794873f191ddb728f597935b3101c".to_string(),
+ language_server: None,
+ max_assertions: Some(4),
+ }
+ }
+
+ async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
+ const FILENAME: &str = "find_replace_file_tool.rs";
+ cx.push_user_message(format!(
+ r#"
+ Look at the `{FILENAME}`. I want to implement a card for it. The card should implement the `Render` trait.
+
+ The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for
+ markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green
+ background for lines that were added. We should have a div per diff line.
+ "#
+ ));
+
+ let response = cx.run_turn().await?;
+ let tool_use = response.expect_tool("path_search", cx)?;
+ let input = tool_use.expect_input::<PathSearchToolInput>(cx)?;
+
+ let glob = input.glob;
+ cx.assert(
+ glob.ends_with(FILENAME),
+ format!("glob ends with `{FILENAME}`"),
+ )?;
+
+ let without_filename = glob.replace(FILENAME, "");
+ let matches = Regex::new("(\\*\\*|zed)/(\\*\\*?/)?")
+ .unwrap()
+ .is_match(&without_filename);
+
+ cx.assert(matches, "glob starts with either `**` or `zed`")?;
+
+ Ok(())
+ }
+}
@@ -0,0 +1,43 @@
+url = "https://github.com/zed-industries/zed.git"
+revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
+language_extension = "rs"
+
+prompt = """
+Look at the `find_replace_file_tool.rs`. I want to implement a card for it.
+The card should implement the `Render` trait.
+
+The card should show a diff. It should be a beautifully presented diff.
+The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`).
+I want to see a red background for lines that were deleted and a green background for lines
+that were added. We should have a div per diff line.
+"""
+
+[diff_assertions]
+
+modify_find_and_replace_tool = """
+The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct.
+The struct should contain an `output` field that is the same as the task we were returning before,
+and a new `card` field that contains a view for the card.
+"""
+
+card_implementation = """
+The card should be a view that displays a diff.
+Each line in the diff should be colored according to whether it was added, removed or unchanged.
+"""
+
+[thread_assertions]
+
+path_search = """
+The first tool call should be a path search that includes "find_replace_file_tool.rs" in its input.
+(*Not* grep, for example, or reading the file based on a guess at the path.)
+This is because we gave the model a filename and it needs to turn that into a real path.
+"""
+
+read_file_from_path_search = """
+After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path.
+"""
+
+symbol_search = """
+When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information
+on what path the Render trait might be in.
+"""
@@ -0,0 +1,128 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use serde::Deserialize;
+use std::collections::BTreeMap;
+use std::fs;
+use std::{
+ path::{Path, PathBuf},
+ rc::Rc,
+};
+use util::serde::default_true;
+
+use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
+
+mod file_search;
+
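+/// Returns every example: the programmatic ones defined in code, plus one
+/// `DeclarativeExample` per `.toml` file found in `examples_dir`.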
+pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
+ let mut threads: Vec<Rc<dyn Example>> = vec![Rc::new(file_search::FileSearchExample)];
+
+ for example_path in list_declarative_examples(examples_dir).unwrap() {
+ threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));
+ }
+
+ threads
+}
+
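+/// An example described entirely by a `.toml` file: a prompt to run, plus diff
+/// and thread assertions for the LLM judge.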
+struct DeclarativeExample {
+ metadata: ExampleMetadata,
+ prompt: String,
+ diff_assertions: Vec<JudgeAssertion>,
+ thread_assertions: Vec<JudgeAssertion>,
+}
+
+impl DeclarativeExample {
+ pub fn load(example_path: &Path) -> Result<Self> {
+ let name = Self::name_from_path(example_path);
+ let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?;
+
+ let language_server = if base.require_lsp {
+ Some(crate::example::LanguageServer {
+ file_extension: base
+ .language_extension
+ .expect("Language extension is required when require_lsp = true"),
+ allow_preexisting_diagnostics: base.allow_preexisting_diagnostics,
+ })
+ } else {
+ None
+ };
+
+ let metadata = ExampleMetadata {
+ name,
+ url: base.url,
+ revision: base.revision,
+ language_server,
+ max_assertions: None,
+ };
+
+ Ok(DeclarativeExample {
+ metadata,
+ prompt: base.prompt,
+ thread_assertions: base
+ .thread_assertions
+ .into_iter()
+ .map(|(id, description)| JudgeAssertion { id, description })
+ .collect(),
+ diff_assertions: base
+ .diff_assertions
+ .into_iter()
+ .map(|(id, description)| JudgeAssertion { id, description })
+ .collect(),
+ })
+ }
+
+ pub fn name_from_path(path: &Path) -> String {
+ path.file_stem().unwrap().to_string_lossy().to_string()
+ }
+}
+
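+/// On-disk schema for declarative example files. A minimal sketch (only `url`,
+/// `revision`, and `prompt` are required, except that the default
+/// `require_lsp = true` also requires `language_extension`):
+///
+/// ```toml
+/// url = "https://github.com/zed-industries/zed.git"
+/// revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
+/// language_extension = "rs"
+/// prompt = "Implement a card for the find/replace tool."
+///
+/// [diff_assertions]
+/// card_implementation = "The card should be a view that displays a diff."
+/// ```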
+#[derive(Clone, Debug, Deserialize)]
+pub struct ExampleToml {
+ pub url: String,
+ pub revision: String,
+ pub language_extension: Option<String>,
+ pub insert_id: Option<String>,
+ #[serde(default = "default_true")]
+ pub require_lsp: bool,
+ #[serde(default)]
+ pub allow_preexisting_diagnostics: bool,
+ pub prompt: String,
+ #[serde(default)]
+ pub diff_assertions: BTreeMap<String, String>,
+ #[serde(default)]
+ pub thread_assertions: BTreeMap<String, String>,
+}
+
+#[async_trait(?Send)]
+impl Example for DeclarativeExample {
+ fn meta(&self) -> ExampleMetadata {
+ self.metadata.clone()
+ }
+
+ async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
+ cx.push_user_message(&self.prompt);
+ let _ = cx.run_to_end().await;
+ Ok(())
+ }
+
+ fn diff_assertions(&self) -> Vec<JudgeAssertion> {
+ self.diff_assertions.clone()
+ }
+
+ fn thread_assertions(&self) -> Vec<JudgeAssertion> {
+ self.thread_assertions.clone()
+ }
+}
+
+fn list_declarative_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
+    let path = std::fs::canonicalize(examples_dir)?;
+    let entries = std::fs::read_dir(path)?;
+ let mut result_paths = Vec::new();
+ for entry in entries {
+ let entry = entry?;
+ let path = entry.path();
+ if path.extension() == Some("toml".as_ref()) {
+ result_paths.push(path);
+ }
+ }
+ Ok(result_paths)
+}
@@ -0,0 +1,1023 @@
+use agent::ThreadStore;
+use anyhow::{Context, Result, anyhow, bail};
+use assistant_tool::ToolWorkingSet;
+use client::proto::LspWorkProgress;
+use futures::channel::mpsc;
+use futures::{FutureExt as _, StreamExt as _, future};
+use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
+use handlebars::Handlebars;
+use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _};
+use language_model::{
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
+ MessageContent, Role, TokenUsage,
+};
+use project::lsp_store::OpenLspBufferHandle;
+use project::{DiagnosticSummary, Project, ProjectPath};
+use serde::{Deserialize, Serialize};
+use std::cell::RefCell;
+use std::fmt::Write as _;
+use std::fs;
+use std::fs::File;
+use std::io::Write as _;
+use std::path::Path;
+use std::path::PathBuf;
+use std::rc::Rc;
+use std::sync::Arc;
+use std::time::Duration;
+use unindent::Unindent as _;
+use util::ResultExt as _;
+use util::command::new_smol_command;
+use util::markdown::MarkdownString;
+
+use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult};
+use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion};
+use crate::{AgentAppState, ToolMetrics};
+
+pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git";
+
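+/// One concrete run of an example: a single repetition, with its own worktree
+/// and output directory.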
+#[derive(Clone)]
+pub struct ExampleInstance {
+ pub thread: Rc<dyn Example>,
+ pub name: String,
+ pub run_directory: PathBuf,
+ pub log_prefix: String,
+ /// The repetition number for this example (0-based)
+ /// When running multiple repetitions of the same example, each instance is assigned a unique repetition number.
+ /// This affects the worktree path and log prefix to avoid clobbering results between runs.
+ pub repetition: usize,
+ pub repo_path: PathBuf,
+    /// Root directory in which this example's worktrees are created
+ worktrees_dir: PathBuf,
+}
+
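+/// Everything captured from a single run, used for reporting and judging.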
+#[derive(Debug, Serialize, Clone)]
+pub struct RunOutput {
+ pub repository_diff: String,
+ pub diagnostic_summary_before: DiagnosticSummary,
+ pub diagnostic_summary_after: DiagnosticSummary,
+ pub diagnostics_before: Option<String>,
+ pub diagnostics_after: Option<String>,
+ pub response_count: usize,
+ pub token_usage: TokenUsage,
+ pub tool_metrics: ToolMetrics,
+ pub last_request: LanguageModelRequest,
+ pub programmatic_assertions: AssertionsReport,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeDiffInput {
+ pub repository_diff: String,
+ pub assertion: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeThreadInput {
+ pub messages: String,
+ pub assertion: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JudgeOutput {
+ pub thread: AssertionsReport,
+ pub diff: AssertionsReport,
+}
+
+impl ExampleInstance {
+ pub fn new(
+ thread: Rc<dyn Example>,
+ repos_dir: &Path,
+ run_dir: &Path,
+ worktrees_dir: &Path,
+ repetition: usize,
+ ) -> Self {
+ let name = thread.meta().name.to_string();
+ let run_directory = run_dir
+ .join(&name)
+ .join(repetition.to_string())
+ .to_path_buf();
+
+ let repo_path = repo_path_for_url(repos_dir, &thread.meta().url);
+
+ Self {
+ name,
+ thread,
+ log_prefix: String::new(),
+ run_directory,
+ repetition,
+ repo_path,
+ worktrees_dir: worktrees_dir.to_path_buf(),
+ }
+ }
+
+ pub fn repo_url(&self) -> String {
+ self.thread.meta().url
+ }
+
+ pub fn revision(&self) -> String {
+ self.thread.meta().revision
+ }
+
+ pub fn worktree_name(&self) -> String {
+ format!("{}-{}", self.name, self.repetition)
+ }
+
+ pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
+ self.log_prefix = format!(
+ "{}{:<width$}\x1b[0m | ",
+ color,
+ self.worktree_name(),
+ width = name_width
+ );
+ }
+
+    /// Fetch the example's Git revision if it isn't already available locally
+ pub async fn fetch(&mut self) -> Result<()> {
+ let meta = self.thread.meta();
+
+ let revision_exists = run_git(
+ &self.repo_path,
+ &["rev-parse", &format!("{}^{{commit}}", &meta.revision)],
+ )
+ .await
+ .is_ok();
+
+ if !revision_exists {
+ println!("{}Fetching revision {}", self.log_prefix, &meta.revision);
+ run_git(
+ &self.repo_path,
+ &["fetch", "--depth", "1", "origin", &meta.revision],
+ )
+ .await?;
+ }
+ Ok(())
+ }
+
+ /// Set up the example by checking out the specified Git revision
+ pub async fn setup(&mut self) -> Result<()> {
+ let worktree_path = self.worktree_path();
+ let meta = self.thread.meta();
+ if worktree_path.is_dir() {
+ println!("{}Resetting existing worktree", self.log_prefix);
+
+ // TODO: consider including "-x" to remove ignored files. The downside of this is that
+ // it will also remove build artifacts, and so prevent incremental reuse there.
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", &meta.revision]).await?;
+ } else {
+ println!("{}Creating worktree", self.log_prefix);
+
+ let worktree_path_string = worktree_path.to_string_lossy().to_string();
+
+ run_git(
+ &self.repo_path,
+ &[
+ "worktree",
+ "add",
+ "-f",
+ &worktree_path_string,
+ &meta.revision,
+ ],
+ )
+ .await?;
+ }
+
+ if meta.url == ZED_REPO_URL {
+ std::fs::write(worktree_path.join(".rules"), std::fs::read(".rules")?)?;
+ }
+
+ std::fs::create_dir_all(&self.run_directory)?;
+
+ Ok(())
+ }
+
+ pub fn worktree_path(&self) -> PathBuf {
+ self.worktrees_dir
+ .join(self.worktree_name())
+ .join(self.thread.meta().repo_name())
+ }
+
+ pub fn run(
+ &self,
+ model: Arc<dyn LanguageModel>,
+ app_state: Arc<AgentAppState>,
+ cx: &mut App,
+ ) -> Task<Result<RunOutput>> {
+ let project = Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ );
+
+ let worktree = project.update(cx, |project, cx| {
+ project.create_worktree(self.worktree_path(), true, cx)
+ });
+
+ let tools = cx.new(|_| ToolWorkingSet::default());
+ let thread_store =
+ ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
+ let meta = self.thread.meta();
+ let this = self.clone();
+
+ cx.spawn(async move |cx| {
+ let worktree = worktree.await?;
+
+ // Wait for worktree scan to finish before choosing a file to open.
+ worktree
+ .update(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ struct LanguageServerState {
+ _lsp_open_handle: OpenLspBufferHandle,
+ language_file_buffer: Entity<Buffer>,
+ }
+
+ let mut diagnostics_before = None;
+ let mut diagnostic_summary_before = DiagnosticSummary::default();
+
+ let lsp = if let Some(language_server) = &meta.language_server {
+ // Open a file that matches the language to cause LSP to start.
+ let language_file = worktree.read_with(cx, |worktree, _cx| {
+ worktree
+ .files(false, 0)
+ .find_map(|e| {
+ if e.path.clone().extension().and_then(|ext| ext.to_str())
+ == Some(&language_server.file_extension)
+ {
+ Some(ProjectPath {
+ worktree_id: worktree.id(),
+ path: e.path.clone(),
+ })
+ } else {
+ None
+ }
+ })
+ .context("Failed to find a file for example language")
+ })??;
+
+ let open_language_file_buffer_task = project.update(cx, |project, cx| {
+ project.open_buffer(language_file.clone(), cx)
+ })?;
+
+ let language_file_buffer = open_language_file_buffer_task.await?;
+
+ let lsp_open_handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&language_file_buffer, cx)
+ })?;
+
+ wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
+
+ diagnostic_summary_before = project.read_with(cx, |project, cx| {
+ project.diagnostic_summary(false, cx)
+ })?;
+
+ diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
+                if diagnostics_before.is_some() && !language_server.allow_preexisting_diagnostics {
+                    return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in the example's TOML file"));
+ }
+
+ Some(LanguageServerState {
+ _lsp_open_handle: lsp_open_handle,
+ language_file_buffer,
+ })
+ } else {
+ None
+ };
+
+ if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
+ return Err(anyhow!("Setup only mode"));
+ }
+
+ let last_diff_file_path = this.run_directory.join("last.diff");
+
+ // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
+ // history using undo/redo.
+ std::fs::write(&last_diff_file_path, "")?;
+
+ let thread_store = thread_store.await?;
+ let thread =
+ thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
+ let last_request = Rc::new(RefCell::new(None));
+
+ thread.update(cx, |thread, _cx| {
+ let mut request_count = 0;
+ let last_request = Rc::clone(&last_request);
+ let previous_diff = Rc::new(RefCell::new("".to_string()));
+ let example_output_dir = this.run_directory.clone();
+ let last_diff_file_path = last_diff_file_path.clone();
+ let this = this.clone();
+ thread.set_request_callback(move |request, response_events| {
+ *last_request.borrow_mut() = Some(request.clone());
+
+ request_count += 1;
+ let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
+ let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
+ let last_messages_file_path = example_output_dir.join("last.messages.md");
+ let request_markdown = RequestMarkdown::new(request);
+ let response_events_markdown = response_events_to_markdown(response_events);
+
+ let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
+ fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file");
+ fs::write(&last_messages_file_path, messages).expect("failed to write last messages file");
+
+ let diff_result = smol::block_on(this.repository_diff());
+ match diff_result {
+ Ok(diff) => {
+ if diff != previous_diff.borrow().clone() {
+ fs::write(&diff_file_path, &diff).expect("failed to write diff file");
+ fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file");
+ *previous_diff.borrow_mut() = diff;
+ }
+ }
+ Err(err) => {
+ let error_message = format!("{err:?}");
+ fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file");
+ fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file");
+ }
+ }
+
+ if request_count == 1 {
+ let tools_file_path = example_output_dir.join("tools.md");
+ fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
+ }
+ });
+ })?;
+
+ let mut example_cx = ExampleContext::new(meta.clone(), this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone());
+ let result = this.thread.conversation(&mut example_cx).await;
+
+ if let Err(err) = result {
+ if !err.is::<FailedAssertion>() {
+ return Err(err);
+ }
+ }
+
+ println!("{}Stopped", this.log_prefix);
+
+ println!("{}Getting repository diff", this.log_prefix);
+ let repository_diff = this.repository_diff().await?;
+
+ std::fs::write(last_diff_file_path, &repository_diff)?;
+
+ let mut diagnostics_after = None;
+ let mut diagnostic_summary_after = Default::default();
+
+ if let Some(language_server_state) = lsp {
+ wait_for_lang_server(&project, &language_server_state.language_file_buffer, this.log_prefix.clone(), cx).await?;
+
+ println!("{}Getting diagnostics", this.log_prefix);
+ diagnostics_after = cx
+ .update(|cx| {
+ let project = project.clone();
+ cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
+ })?
+ .await?;
+ println!("{}Got diagnostics", this.log_prefix);
+
+ diagnostic_summary_after = project.read_with(cx, |project, cx| {
+ project.diagnostic_summary(false, cx)
+ })?;
+
+
+ let Some(last_request) = last_request.borrow_mut().take() else {
+ return Err(anyhow!("No requests ran."));
+ };
+
+ if let Some(diagnostics_before) = &diagnostics_before {
+ fs::write(this.run_directory.join("diagnostics_before.txt"), diagnostics_before)?;
+ }
+
+ if let Some(diagnostics_after) = &diagnostics_after {
+ fs::write(this.run_directory.join("diagnostics_after.txt"), diagnostics_after)?;
+ }
+
+ thread.update(cx, |thread, _cx| {
+ let response_count = thread
+ .messages()
+ .filter(|message| message.role == language_model::Role::Assistant)
+ .count();
+ RunOutput {
+ repository_diff,
+ diagnostic_summary_before,
+ diagnostic_summary_after,
+ diagnostics_before,
+ diagnostics_after,
+ response_count,
+ token_usage: thread.cumulative_token_usage(),
+ tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(),
+ last_request,
+ programmatic_assertions: example_cx.assertions,
+ }
+ })
+ })
+ }
+
+ async fn repository_diff(&self) -> Result<String> {
+ let worktree_path = self.worktree_path();
+ run_git(&worktree_path, &["add", "."]).await?;
+ let mut diff_args = vec!["diff", "--staged"];
+ if self.thread.meta().url == ZED_REPO_URL {
+ diff_args.push(":(exclude).rules");
+ }
+ run_git(&worktree_path, &diff_args).await
+ }
+
+ pub async fn judge(
+ &self,
+ model: Arc<dyn LanguageModel>,
+ run_output: &RunOutput,
+ cx: &AsyncApp,
+ ) -> JudgeOutput {
+ let mut output_file =
+ File::create(self.run_directory.join("judge.md")).expect("failed to create judge.md");
+
+ let diff_task = self.judge_diff(model.clone(), &run_output, cx);
+ let thread_task = self.judge_thread(model.clone(), &run_output, cx);
+
+ let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
+
+ let (diff_response, diff_output) = diff_result;
+ let (thread_response, thread_output) = thread_result;
+
+ writeln!(
+ &mut output_file,
+ "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
+ )
+ .log_err();
+
+ JudgeOutput {
+ thread: thread_output,
+ diff: diff_output,
+ }
+ }
+
+ async fn judge_diff(
+ &self,
+ model: Arc<dyn LanguageModel>,
+ run_output: &RunOutput,
+ cx: &AsyncApp,
+ ) -> (String, AssertionsReport) {
+ let diff_assertions = self.thread.diff_assertions();
+
+ if diff_assertions.is_empty() {
+ return (
+ "No diff assertions".to_string(),
+ AssertionsReport::default(),
+ );
+ }
+
+ println!("{}Running diff judge", self.log_prefix);
+
+ let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
+ let judge_diff_prompt_name = "judge_diff_prompt";
+ let mut hbs = Handlebars::new();
+ hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)
+ .unwrap();
+
+ let to_prompt = |assertion: String| {
+ hbs.render(
+ judge_diff_prompt_name,
+ &JudgeDiffInput {
+ repository_diff: run_output.repository_diff.clone(),
+ assertion,
+ },
+ )
+ .unwrap()
+ };
+
+ let (responses, report) = self
+ .judge_assertions(model, diff_assertions, to_prompt, cx)
+ .await;
+
+ println!(
+ "{}Judge - Diff score: {}%",
+ self.log_prefix,
+ report.passed_percentage()
+ );
+
+ (responses, report)
+ }
+
+ async fn judge_thread(
+ &self,
+ model: Arc<dyn LanguageModel>,
+ run_output: &RunOutput,
+ cx: &AsyncApp,
+ ) -> (String, AssertionsReport) {
+ let thread_assertions = self.thread.thread_assertions();
+
+ if thread_assertions.is_empty() {
+ return (
+ "No diff assertions".to_string(),
+ AssertionsReport::default(),
+ );
+ }
+
+ let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
+        let judge_thread_prompt_name = "judge_thread_prompt";
+        let mut hbs = Handlebars::new();
+        hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)
+ .unwrap();
+
+ let request_markdown = RequestMarkdown::new(&run_output.last_request);
+ let to_prompt = |assertion: String| {
+ hbs.render(
+                judge_thread_prompt_name,
+ &JudgeThreadInput {
+ messages: request_markdown.messages.clone(),
+ assertion,
+ },
+ )
+ .unwrap()
+ };
+
+ let (responses, report) = self
+ .judge_assertions(model, thread_assertions, to_prompt, cx)
+ .await;
+
+ println!(
+ "{}Judge - Thread score: {}%",
+ self.log_prefix,
+ report.passed_percentage()
+ );
+
+ (responses, report)
+ }
+
+ async fn judge_assertions(
+ &self,
+ model: Arc<dyn LanguageModel>,
+ assertions: Vec<JudgeAssertion>,
+ to_prompt: impl Fn(String) -> String,
+ cx: &AsyncApp,
+ ) -> (String, AssertionsReport) {
+ let assertions = assertions.into_iter().map(|assertion| {
+ let request = LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text(to_prompt(assertion.description))],
+ cache: false,
+ }],
+ temperature: None,
+ tools: Vec::new(),
+ stop: Vec::new(),
+ };
+
+ let model = model.clone();
+ let log_prefix = self.log_prefix.clone();
+ async move {
+ let response = send_language_model_request(model, request, cx).await;
+
+ let (response, result) = match response {
+ Ok(response) => (
+ response.clone(),
+ parse_assertion_result(&response).map_err(|err| err.to_string()),
+ ),
+ Err(err) => (err.to_string(), Err(err.to_string())),
+ };
+
+ if result.is_ok() {
+ println!("{}✅ {}", log_prefix, assertion.id);
+ } else {
+ println!("{}❌ {}", log_prefix, assertion.id);
+ }
+
+ (
+ response,
+ RanAssertion {
+ id: assertion.id,
+ result,
+ },
+ )
+ }
+ });
+
+ let mut responses = String::new();
+ let mut report = AssertionsReport::default();
+
+ for (response, assertion) in future::join_all(assertions).await {
+ writeln!(&mut responses, "# {}", assertion.id).unwrap();
+ writeln!(&mut responses, "{}\n\n", response).unwrap();
+ report.ran.push(assertion);
+ }
+
+ (responses, report)
+ }
+}
+
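+/// Waits until the language server has finished its initial diagnostics pass
+/// (signalled by `DiskBasedDiagnosticsFinished`), timing out after 5 minutes.
+/// Set `ZED_EVAL_SKIP_LS` to skip the wait entirely.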
+pub fn wait_for_lang_server(
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ log_prefix: String,
+ cx: &mut AsyncApp,
+) -> Task<Result<()>> {
+ if std::env::var("ZED_EVAL_SKIP_LS").is_ok() {
+ return Task::ready(Ok(()));
+ }
+
+ println!("{}⏵ Waiting for language server", log_prefix);
+
+ let (mut tx, mut rx) = mpsc::channel(1);
+
+ let lsp_store = project
+ .update(cx, |project, _| project.lsp_store())
+ .unwrap();
+
+ let has_lang_server = buffer
+ .update(cx, |buffer, cx| {
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(&buffer, cx)
+ .next()
+ .is_some()
+ })
+ })
+ .unwrap_or(false);
+
+ if has_lang_server {
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .unwrap()
+ .detach();
+ }
+
+ let subscriptions =
+ [
+ cx.subscribe(&lsp_store, {
+ let log_prefix = log_prefix.clone();
+ move |_, event, _| match event {
+ project::LspStoreEvent::LanguageServerUpdate {
+ message:
+ client::proto::update_language_server::Variant::WorkProgress(
+ LspWorkProgress {
+ message: Some(message),
+ ..
+ },
+ ),
+ ..
+ } => println!("{}⟲ {message}", log_prefix),
+ _ => {}
+ }
+ }),
+ cx.subscribe(&project, {
+ let buffer = buffer.clone();
+ move |project, event, cx| match event {
+ project::Event::LanguageServerAdded(_, _, _) => {
+ let buffer = buffer.clone();
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer, cx))
+ .detach();
+ }
+ project::Event::DiskBasedDiagnosticsFinished { .. } => {
+ tx.try_send(()).ok();
+ }
+ _ => {}
+ }
+ }),
+ ];
+
+ cx.spawn(async move |cx| {
+ let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
+ let result = futures::select! {
+ _ = rx.next() => {
+ println!("{}⚑ Language server idle", log_prefix);
+ anyhow::Ok(())
+ },
+ _ = timeout.fuse() => {
+ Err(anyhow!("LSP wait timed out after 5 minutes"))
+ }
+ };
+ drop(subscriptions);
+ result
+ })
+}
+
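+/// Collects the primary error and warning diagnostics across the project,
+/// returning `None` when there are none.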
+pub async fn query_lsp_diagnostics(
+ project: Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Result<Option<String>> {
+ let paths_with_diagnostics = project.update(cx, |project, cx| {
+ project
+ .diagnostic_summaries(true, cx)
+ .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
+ .map(|(project_path, _, _)| project_path)
+ .collect::<Vec<_>>()
+ })?;
+
+ if paths_with_diagnostics.is_empty() {
+ return Ok(None);
+ }
+
+ let mut output = String::new();
+ for project_path in paths_with_diagnostics {
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ for (_, group) in snapshot.diagnostic_groups(None) {
+ let entry = &group.entries[group.primary_ix];
+ let range = entry.range.to_point(&snapshot);
+ let severity = match entry.diagnostic.severity {
+ DiagnosticSeverity::ERROR => "error",
+ DiagnosticSeverity::WARNING => "warning",
+ _ => continue,
+ };
+
+ writeln!(
+ output,
+ "{} at line {}: {}",
+ severity,
+ range.start.row + 1,
+ entry.diagnostic.message
+ )?;
+ }
+ }
+ anyhow::Ok(Some(output))
+}
+
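+/// Parses a judge response of the form
+/// `<analysis>...</analysis><passed>true</passed>` (see the test below).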
+fn parse_assertion_result(response: &str) -> Result<RanAssertionResult> {
+ let analysis = get_tag("analysis", response)?.to_string();
+ let passed = match get_tag("passed", response)?.to_lowercase().as_str() {
+ "true" => true,
+ "false" => false,
+        value => bail!("invalid judge `passed` tag: {value}"),
+ };
+ Ok(RanAssertionResult {
+ analysis: Some(analysis),
+ passed,
+ })
+}
+
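+/// Extracts the trimmed, unindented content between `<name>` and `</name>`.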
+fn get_tag(name: &'static str, response: &str) -> Result<String> {
+ let start_tag = format!("<{}>", name);
+ let end_tag = format!("</{}>", name);
+
+ let start_ix = response
+ .find(&start_tag)
+ .context(format!("{} start tag not found", name))?;
+ let content_start_ix = start_ix + start_tag.len();
+
+ let end_ix = content_start_ix
+ + response[content_start_ix..]
+ .find(&end_tag)
+ .context(format!("{} end tag not found", name))?;
+
+ let content = response[content_start_ix..end_ix].trim().unindent();
+
+ anyhow::Ok(content)
+}
+
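+/// Maps a repository URL to a directory under `repos_dir` by replacing
+/// non-alphanumeric characters with `-`, e.g. the Zed URL becomes
+/// `github-com-zed-industries-zed-git`.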
+pub fn repo_path_for_url(repos_dir: &Path, repo_url: &str) -> PathBuf {
+ let repo_name = repo_url
+ .trim_start_matches("https://")
+ .replace(|c: char| !c.is_alphanumeric(), "-");
+ Path::new(repos_dir).join(repo_name)
+}
+
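+/// Runs `git` with `args` inside `repo_path`, returning trimmed stdout on
+/// success and a detailed error otherwise.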
+pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+ let output = new_smol_command("git")
+ .current_dir(repo_path)
+ .args(args)
+ .output()
+ .await?;
+
+ if output.status.success() {
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+ } else {
+ Err(anyhow!(
+ "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+ args.join(" "),
+ repo_path.display(),
+ output.status,
+ String::from_utf8_lossy(&output.stderr),
+ String::from_utf8_lossy(&output.stdout),
+ ))
+ }
+}
+
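+/// Sends a single request and concatenates the streamed text chunks into one
+/// `String`; used by the judge, which doesn't need tool calls.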
+pub async fn send_language_model_request(
+ model: Arc<dyn LanguageModel>,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+) -> anyhow::Result<String> {
+ match model.stream_completion_text(request, &cx).await {
+ Ok(mut stream) => {
+ let mut full_response = String::new();
+ while let Some(chunk_result) = stream.stream.next().await {
+ match chunk_result {
+ Ok(chunk_str) => {
+ full_response.push_str(&chunk_str);
+ }
+ Err(err) => {
+ return Err(anyhow!(
+ "Error receiving response from language model: {err}"
+ ));
+ }
+ }
+ }
+ Ok(full_response)
+ }
+ Err(err) => Err(anyhow!(
+ "Failed to get response from language model. Error was: {err}"
+ )),
+ }
+}
+
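+/// A request's tools and messages rendered as markdown, for the per-request
+/// `N.messages.md` and `tools.md` log files.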
+pub struct RequestMarkdown {
+ pub tools: String,
+ pub messages: String,
+}
+
+impl RequestMarkdown {
+ pub fn new(request: &LanguageModelRequest) -> Self {
+ let mut tools = String::new();
+ let mut messages = String::new();
+ let mut assistant_message_number: u32 = 1;
+
+ // Print the tools
+ if !request.tools.is_empty() {
+ for tool in &request.tools {
+ write!(&mut tools, "# {}\n\n", tool.name).unwrap();
+ write!(&mut tools, "{}\n\n", tool.description).unwrap();
+ write!(
+ &mut tools,
+ "{}\n",
+ MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
+ )
+ .unwrap();
+ }
+ }
+
+ // Print the messages
+ for message in &request.messages {
+ match message.role {
+ Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
+ Role::User => messages.push_str("# 👤 USER\n\n"),
+ Role::Assistant => {
+ messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
+ assistant_message_number += 1;
+ }
+ };
+
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) => {
+ messages.push_str(text);
+ messages.push_str("\n\n");
+ }
+ MessageContent::Image(_) => {
+ messages.push_str("[IMAGE DATA]\n\n");
+ }
+ MessageContent::Thinking { text, signature } => {
+ messages.push_str("**Thinking**:\n\n");
+ if let Some(sig) = signature {
+ messages.push_str(&format!("Signature: {}\n\n", sig));
+ }
+ messages.push_str(text);
+ messages.push_str("\n");
+ }
+ MessageContent::RedactedThinking(items) => {
+ messages.push_str(&format!(
+ "**Redacted Thinking**: {} item(s)\n\n",
+ items.len()
+ ));
+ }
+ MessageContent::ToolUse(tool_use) => {
+ messages.push_str(&format!(
+ "**Tool Use**: {} (ID: {})\n",
+ tool_use.name, tool_use.id
+ ));
+ messages.push_str(&format!(
+ "{}\n",
+ MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
+ ));
+ }
+ MessageContent::ToolResult(tool_result) => {
+ messages.push_str(&format!(
+ "**Tool Result**: {} (ID: {})\n\n",
+ tool_result.tool_name, tool_result.tool_use_id
+ ));
+ if tool_result.is_error {
+ messages.push_str("**ERROR:**\n");
+ }
+ messages.push_str(&format!("{}\n\n", tool_result.content));
+ }
+ }
+ }
+ }
+
+ Self { tools, messages }
+ }
+}
+
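+/// Renders streamed completion events as a markdown `# Response` section,
+/// flushing accumulated text/thinking buffers whenever a structural event
+/// (stop, tool use, or error) occurs.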
+pub fn response_events_to_markdown(
+ response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
+) -> String {
+ let mut response = String::new();
+ // Print the response events if any
+ response.push_str("# Response\n\n");
+ let mut text_buffer = String::new();
+ let mut thinking_buffer = String::new();
+
+ let flush_buffers =
+ |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
+ if !text_buffer.is_empty() {
+ output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
+ text_buffer.clear();
+ }
+ if !thinking_buffer.is_empty() {
+ output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
+ thinking_buffer.clear();
+ }
+ };
+
+ for event in response_events {
+ match event {
+ Ok(LanguageModelCompletionEvent::Text(text)) => {
+ text_buffer.push_str(text);
+ }
+ Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
+ thinking_buffer.push_str(text);
+ }
+ Ok(LanguageModelCompletionEvent::Stop(reason)) => {
+ flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+ response.push_str(&format!("**Stop**: {:?}\n\n", reason));
+ }
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+ response.push_str(&format!(
+ "**Tool Use**: {} (ID: {})\n",
+ tool_use.name, tool_use.id
+ ));
+ response.push_str(&format!(
+ "{}\n",
+ MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
+ ));
+ }
+ Ok(
+ LanguageModelCompletionEvent::UsageUpdate(_)
+ | LanguageModelCompletionEvent::StartMessage { .. },
+ ) => {}
+ Err(error) => {
+ flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+ response.push_str(&format!("**Error**: {}\n\n", error));
+ }
+ }
+ }
+
+ flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+
+ response
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+    fn test_parse_assertion_result() {
+ let response = r#"
+            <analysis>The model did a good job but there were still compilation errors.</analysis>
+ <passed>true</passed>
+ "#
+ .unindent();
+
+ let output = parse_assertion_result(&response).unwrap();
+ assert_eq!(
+ output.analysis,
+ Some("The model did a good job but there were still compilations errors.".into())
+ );
+ assert_eq!(output.passed, true);
+
+ let response = r#"
+ Text around ignored
+
+ <analysis>
+ Failed to compile:
+ - Error 1
+ - Error 2
+ </analysis>
+
+ <passed>false</passed>
+ "#
+ .unindent();
+
+ let output = parse_assertion_result(&response).unwrap();
+ assert_eq!(
+ output.analysis,
+ Some("Failed to compile:\n- Error 1\n- Error 2".into())
+ );
+ assert_eq!(output.passed, false);
+ }
+}
@@ -1,5 +1,5 @@
-You are an expert software developer. Your task is to evaluate a diff produced by an AI agent in response to a prompt.
-Here is the prompt and the diff:
+You are an expert software developer. Your task is to evaluate a diff produced by an AI agent
+in response to a prompt. Here is the prompt and the diff:
<prompt>
{{{prompt}}}
@@ -9,17 +9,17 @@ Here is the prompt and the diff:
{{{repository_diff}}}
</diff>
-Evaluate how many of the following criteria were satisfied by the diff:
+Evaluate whether or not the diff passes the following assertion:
-<criteria>
-{{criteria}}
-- There are no changes unrelated to the prompt
-</criteria>
+<assertion>
+{{assertion}}
+</assertion>
Analyze the diff hunk by hunk, and structure your answer in the following XML format:
```
<analysis>{YOUR ANALYSIS HERE}</analysis>
-<total_criteria>{THE TOTAL NUMBER OF CRITERIA THAT WERE LISTED}</total_criteria>
-<passing_criteria>{THE NUMBER OF CRITERIA THAT ARE MET BY THE DIFF}</passing_criteria>
+<passed>{PASSED_ASSERTION}</passed>
```
+
+Where `PASSED_ASSERTION` is either `true` or `false`.
@@ -1,19 +1,21 @@
-You are an expert software developer. Your task is to evaluate an AI agent's messages and tool calls in this conversation:
+You are an expert software developer.
+Your task is to evaluate an AI agent's messages and tool calls in this conversation:
<messages>
{{{messages}}}
</messages>
-You must count how many of the following criteria were satisfied by the messages:
+Evaluate whether or not the sequence of messages passes the following assertion:
-<criteria>
-{{{criteria}}}
-</criteria>
+<assertion>
+{{{assertion}}}
+</assertion>
Analyze the messages one by one, and structure your answer in the following XML format:
```
<analysis>{YOUR ANALYSIS HERE}</analysis>
-<total_criteria>{THE TOTAL NUMBER OF CRITERIA THAT WERE LISTED}</total_criteria>
-<passing_criteria>{THE NUMBER OF CRITERIA THAT ARE MET BY THE MESSAGES}</passing_criteria>
+<passed>{PASSED_ASSERTION}</passed>
```
+
+Where `PASSED_ASSERTION` is either `true` or `false`.
@@ -24,6 +24,10 @@ impl ToolMetrics {
*self.failure_counts.entry(tool_name.clone()).or_insert(0) += failure_count;
}
}
+
+ pub fn is_empty(&self) -> bool {
+ self.use_counts.is_empty() && self.failure_counts.is_empty()
+ }
}
impl Display for ToolMetrics {
@@ -79,7 +83,7 @@ impl Display for ToolMetrics {
let failure_count = self.failure_counts.get(&tool_name).cloned().unwrap_or(0);
writeln!(
f,
- "│{:^30}│{:^10}│{:^10}│{:^10}│",
+ "│{:<30}│{:^10}│{:^10}│{:^10}│",
tool_name,
use_count,
failure_count,