@@ -3,15 +3,16 @@ mod ids;
mod tool_metrics;
pub(crate) use example::*;
+use parking_lot::Mutex;
pub(crate) use tool_metrics::*;
use ::fs::RealFs;
use anyhow::{Result, anyhow};
use clap::Parser;
use client::{Client, ProxySettings, UserStore};
-use collections::HashSet;
+use collections::{HashMap, HashSet};
use extension::ExtensionHostProxy;
-use futures::{StreamExt, future};
+use futures::future;
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
@@ -24,6 +25,7 @@ use prompt_store::PromptBuilder;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
+use std::collections::VecDeque;
use std::env;
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -40,13 +42,9 @@ struct Args {
model: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts")]
languages: Vec<String>,
- /// How many times to run each example. Note that this is currently not very efficient as N
- /// worktrees will be created for the examples.
+ /// How many times to run each example.
#[arg(long, default_value = "1")]
- repetitions: u32,
- /// How many times to run the judge on each example run.
- #[arg(long, default_value = "3")]
- judge_repetitions: u32,
+ repetitions: usize,
/// Maximum number of examples to run concurrently.
#[arg(long, default_value = "10")]
concurrency: usize,
@@ -163,7 +161,6 @@ fn main() {
"\x1b[96m", // Bright Cyan
];
- let mut max_name_width = 0;
let mut skipped = Vec::new();
for example_path in &example_paths {
@@ -184,20 +181,7 @@ fn main() {
continue;
}
- // TODO: This creates a worktree per repetition. Ideally these examples should
- // either be run sequentially on the same worktree, or reuse worktrees when there
- // are more examples to run than the concurrency limit.
- for repetition_number in 0..args.repetitions {
- let mut example = example.clone();
- example.set_repetition_number(repetition_number);
-
- let name_len = example.name.len();
- if name_len > max_name_width {
- max_name_width = example.name.len();
- }
-
- examples.push(example);
- }
+ examples.extend(example.repeat(args.repetitions));
}
println!("Skipped examples: {}\n", skipped.join(", "));
@@ -210,6 +194,11 @@ fn main() {
let mut repo_urls = HashSet::default();
let mut clone_tasks = Vec::new();
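+ // Pad every log prefix to the widest repetition name so the per-example log lines stay aligned.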
+ let max_name_width = examples
+ .iter()
+ .map(|e| e.repetition_name().len())
+ .max()
+ .unwrap_or(0);
for (i, example) in examples.iter_mut().enumerate() {
let color = COLORS[i % COLORS.len()].to_string();
example.set_log_prefix_style(&color, max_name_width);
@@ -217,7 +206,7 @@ fn main() {
println!(
"{}Logging to: {}",
example.log_prefix,
- example.example_output_directory().display()
+ example.run_directory_path().display()
);
let repo_url = example.base.url.clone();
@@ -263,49 +252,53 @@ fn main() {
future::join_all(clone_tasks).await;
for example in examples.iter_mut() {
- example.setup().await?;
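+ // Ensure each example's revision is available locally; worktree setup happens per repetition inside the worker loop below.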
+ example.fetch().await?;
}
- let judge_repetitions = args.judge_repetitions;
- let concurrency = args.concurrency;
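+ // Work queue of pending examples and a map of results keyed by example name, shared across the worker tasks below.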
+ let examples = Arc::new(Mutex::new(VecDeque::from(examples)));
+ let results_by_example_name = Arc::new(Mutex::new(HashMap::default()));
- let tasks = examples.iter().map(|example| {
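+ // Spawn `args.concurrency` worker tasks; each pops examples off the shared queue until it is empty.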
+ future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone();
let model = model.clone();
- let example = example.clone();
let zed_commit_sha = zed_commit_sha.clone();
let zed_branch_name = zed_branch_name.clone();
let run_id = run_id.clone();
+ let examples = examples.clone();
+ let results = results_by_example_name.clone();
cx.spawn(async move |cx| {
- let result = async {
- let run_output = cx
- .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
- .await?;
- let judge_tasks = (0..judge_repetitions).map(|round| {
- run_judge_repetition(
+ loop {
+ let Some(mut example) = examples.lock().pop_front() else {
+ break;
+ };
+ let result = async {
+ example.setup().await?;
+ let run_output = cx
+ .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
+ .await?;
+ let judge_output = judge_example(
example.clone(),
model.clone(),
&zed_commit_sha,
&zed_branch_name,
&run_id,
&run_output,
- round,
enable_telemetry,
cx,
)
- });
- let judge_outputs = future::join_all(judge_tasks).await;
- anyhow::Ok((run_output, judge_outputs))
+ .await;
+ anyhow::Ok((run_output, judge_output))
+ }
+ .await;
+ results
+ .lock()
+ .entry(example.name.clone())
+ .or_insert(Vec::new())
+ .push((example.clone(), result));
}
- .await;
- (example, result)
})
- });
-
- let results = futures::stream::iter(tasks)
- .buffer_unordered(concurrency)
- .collect::<Vec<_>>()
- .await;
+ }))
+ .await;
println!("\n\n");
print_header("EVAL RESULTS");
@@ -314,59 +307,64 @@ fn main() {
let mut thread_scores = Vec::new();
let mut error_count = 0;
- for (example, result) in results {
- print_header(&example.name);
-
- match result {
- Err(err) => {
- println!("💥 {}{:?}", example.log_prefix, err);
- error_count += 1;
- }
- Ok((run_output, judge_results)) => {
- cumulative_tool_metrics.merge(&run_output.tool_metrics);
-
- println!("┌───────┬──────┬────────┐");
- println!("│ Judge │ Diff │ Thread │");
- println!("├───────┼──────┼────────┤");
+ for (example_name, results) in results_by_example_name.lock().iter_mut() {
+ print_header(&example_name);
+
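+ // Results arrive in completion order; sort so each example's table lists repetitions in order.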
+ results.sort_unstable_by_key(|(example, _)| example.repetition);
+ let mut example_cumulative_tool_metrics = ToolMetrics::default();
+
+ println!("┌───────┬──────┬────────┐");
+ println!("│ Round │ Diff │ Thread │");
+ println!("├───────┼──────┼────────┤");
+ for (example, result) in results {
+ let run_dir_path = example.run_directory_path();
+ let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap();
+
+ match result {
+ Err(err) => {
+ println!(
+ "|{:^7}│{:^6}│{:^8}│ {:?}{}",
+ example.repetition,
+ "N/A",
+ "N/A",
+ err,
+ relative_run_dir_path.display()
+ );
+ error_count += 1;
+ }
+ Ok((run_output, judge_result)) => {
+ cumulative_tool_metrics.merge(&run_output.tool_metrics);
+ example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
- for (i, judge_result) in judge_results.iter().enumerate() {
match judge_result {
Ok(judge_output) => {
- let diff_score = judge_output.diff.score;
- diff_scores.push(diff_score);
-
- let thread_display = if let Some(thread) = &judge_output.thread
- {
- let thread_score = thread.score;
- thread_scores.push(thread_score);
- format!("{}", thread_score)
- } else {
- "N/A".to_string()
- };
-
+ diff_scores.push(judge_output.diff.score());
+ thread_scores.push(judge_output.thread.score());
println!(
- "|{:^7}│{:^6}│{:^8}│",
- i + 1,
- diff_score,
- thread_display
+ "|{:^7}│{:^6}│{:^8}│ {}",
+ example.repetition,
+ format!("{}%", judge_output.diff.score()),
+ format!("{}%", judge_output.thread.score()),
+ relative_run_dir_path.display()
);
}
Err(err) => {
- println!("|{:^7}│{:^6}│{:^8}│{:?}", i + 1, "N/A", "N/A", err);
+ println!(
+ "|{:^7}│{:^6}│{:^8}│{:?}│ {}",
+ example.repetition,
+ "N/A",
+ "N/A",
+ err,
+ relative_run_dir_path.display()
+ );
}
}
}
-
- println!("└───────┴──────┴────────┘");
-
- println!("{}", run_output.tool_metrics);
}
}
- println!(
- "{} > {}",
- " ".repeat(max_name_width),
- example.example_output_directory().display()
- );
+
+ println!("└───────┴──────┴────────┘");
+ println!("{}", example_cumulative_tool_metrics);
}
let diff_score_count = diff_scores.len();
@@ -380,24 +378,16 @@ fn main() {
println!("\n{error_count} examples failed to run!");
}
- if diff_score_count > 0 {
- println!("\nAverage code diff score: {average_diff_score}");
- }
+ println!("\nAverage code diff score: {average_diff_score}");
let thread_score_count = thread_scores.len();
+ let average_thread_score = thread_scores
+ .into_iter()
+ .map(|score| score as f32)
+ .sum::<f32>()
+ / (thread_score_count as f32);
- // We might have gotten no thread scores if we weren't asked to judge the thread.
- if thread_score_count > 0 {
- let average_thread_score = thread_scores
- .into_iter()
- .map(|score| score as f32)
- .sum::<f32>()
- / (thread_score_count as f32);
-
- if diff_score_count > 0 {
- println!("\nAverage thread score: {average_thread_score}");
- }
- }
+ println!("\nAverage thread score: {average_thread_score}");
print_header("CUMULATIVE TOOL METRICS");
println!("{}", cumulative_tool_metrics);
@@ -579,27 +569,26 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
}
}
-async fn run_judge_repetition(
+async fn judge_example(
example: Example,
model: Arc<dyn LanguageModel>,
zed_commit_sha: &str,
zed_branch_name: &str,
run_id: &str,
run_output: &RunOutput,
- round: u32,
enable_telemetry: bool,
cx: &AsyncApp,
) -> Result<JudgeOutput> {
- let judge_output = example.judge(model.clone(), &run_output, round, cx).await;
+ let judge_output = example.judge(model.clone(), &run_output, cx).await;
let diff_evaluation;
- let thread_diff_evaluation;
+ let thread_evaluation;
if let Ok(output) = judge_output.as_ref() {
diff_evaluation = Some(output.diff.clone());
- thread_diff_evaluation = output.thread.clone();
+ thread_evaluation = Some(output.thread.clone());
} else {
diff_evaluation = None;
- thread_diff_evaluation = None;
+ thread_evaluation = None;
}
if enable_telemetry {
@@ -609,9 +598,9 @@ async fn run_judge_repetition(
zed_branch_name = zed_branch_name,
run_id = run_id,
example_name = example.name.clone(),
- round = round,
+ example_repetition = example.repetition,
diff_evaluation = diff_evaluation,
- thread_evaluation = thread_diff_evaluation,
+ thread_evaluation = thread_evaluation,
tool_metrics = run_output.tool_metrics,
response_count = run_output.response_count,
token_usage = run_output.token_usage,
@@ -619,6 +608,8 @@ async fn run_judge_repetition(
model_provider = model.provider_id().to_string(),
repository_url = example.base.url.clone(),
repository_revision = example.base.revision.clone(),
+ diagnostic_summary_before = run_output.diagnostic_summary_before,
+ diagnostic_summary_after = run_output.diagnostic_summary_after,
diagnostics_before = run_output.diagnostics_before,
diagnostics_after = run_output.diagnostics_after,
);
@@ -12,7 +12,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason, TokenUsage,
};
-use project::{Project, ProjectPath};
+use project::{DiagnosticSummary, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::fmt::Write as _;
@@ -68,19 +68,25 @@ pub struct Example {
/// Content of `diff_criteria.md`
pub diff_criteria: String,
- /// Content of `thread_criteria.md`, if that file exists (it's optional)
+ /// Content of `thread_criteria.md`
- pub thread_criteria: Option<String>,
- /// Path to the directory containing the requests and responses for the agentic loop
- pub run_directory_path: PathBuf,
+ pub thread_criteria: String,
/// Prefix used for logging that identifies this example
pub log_prefix: String,
- pub worktree_path: PathBuf,
+ /// The 0-based repetition number for this example.
+ /// When an example is run multiple times, each run gets a unique repetition number,
+ /// which is included in the worktree path and log prefix so runs don't clobber each other's results.
+ pub repetition: usize,
pub repo_path: PathBuf,
+ /// Base directory for this run; each repetition writes its requests and responses to a subdirectory of it
+ run_dir_path: PathBuf,
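+ /// Directory under which per-repetition worktrees are created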
+ worktrees_dir: PathBuf,
}
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Clone)]
pub struct RunOutput {
pub repository_diff: String,
pub ran_diagnostics_check: bool,
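+ /// Project-wide diagnostic counts captured before and after the agent's edits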
+ pub diagnostic_summary_before: DiagnosticSummary,
+ pub diagnostic_summary_after: DiagnosticSummary,
pub diagnostics_before: Option<String>,
pub diagnostics_after: Option<String>,
pub response_count: usize,
@@ -92,11 +98,6 @@ pub struct RunOutput {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeDiffInput {
pub repository_diff: String,
- pub ran_diagnostics_check: bool,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub diagnostics_before: Option<String>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub diagnostics_after: Option<String>,
pub criteria: String,
}
@@ -109,12 +110,13 @@ pub struct JudgeThreadInput {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeResponse {
pub analysis: String,
- pub score: u32,
+ pub passing_criteria: u32,
+ pub total_criteria: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
- pub thread: Option<JudgeResponse>,
+ pub thread: JudgeResponse,
pub diff: JudgeResponse,
}
@@ -126,65 +128,66 @@ impl Example {
worktrees_dir: &Path,
repos_dir: &Path,
) -> Result<Self> {
- let name = Self::name_from_path(dir_path);
+ let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
let base_path = dir_path.join("base.toml");
let prompt_path = dir_path.join("prompt.md");
let diff_criteria_path = dir_path.join("diff_criteria.md");
let thread_criteria_path = dir_path.join("thread_criteria.md");
- let thread_criteria = if thread_criteria_path.exists() {
- Some(fs::read_to_string(thread_criteria_path.clone())?)
- } else {
- None
- };
-
let base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
let repo_path = repo_path_for_url(repos_dir, &base.url);
- let worktree_path = worktrees_dir
- .canonicalize()
- .unwrap()
- .join(&name)
- .join(&base.repo_name());
-
Ok(Example {
name: name.clone(),
base,
prompt: fs::read_to_string(prompt_path.clone())?,
- thread_criteria,
+ thread_criteria: fs::read_to_string(thread_criteria_path.clone())?,
diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
- run_directory_path: run_dir.to_path_buf(),
- worktree_path,
+ run_dir_path: run_dir.to_path_buf(),
+ worktrees_dir: worktrees_dir.to_path_buf(),
repo_path,
log_prefix: name,
+ repetition: 0,
})
}
- pub fn set_repetition_number(&mut self, repetition_number: u32) {
- if repetition_number > 0 {
- self.name = format!("{}-{}", self.name, repetition_number);
- }
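+ /// Name identifying this repetition, e.g. "my-example-0"; used in worktree paths, run directories, and log prefixes.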
+ pub fn repetition_name(&self) -> String {
+ format!("{}-{}", self.name, self.repetition)
}
- pub fn example_output_directory(&self) -> PathBuf {
- self.run_directory_path.join(&self.name)
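+ /// Worktree checked out for this repetition, under `<worktrees_dir>/<repetition_name>/<repo_name>`.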
+ pub fn worktree_path(&self) -> PathBuf {
+ self.worktrees_dir
+ .canonicalize()
+ .unwrap()
+ .join(self.repetition_name())
+ .join(&self.base.repo_name())
+ }
+
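+ /// Directory where this repetition's output (requests, responses, last.diff, judge.md) is written.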
+ pub fn run_directory_path(&self) -> PathBuf {
+ self.run_dir_path.join(self.repetition_name())
+ }
+
+ /// Returns an iterator of copies of this example, one per repetition.
+ /// Each copy gets its own repetition number, and therefore its own worktree path and run directory.
+ pub fn repeat(self, repetitions: usize) -> impl Iterator<Item = Self> {
+ (0..repetitions).map(move |repetition| {
+ let mut example = self.clone();
+ example.repetition = repetition;
+ example
+ })
}
pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
self.log_prefix = format!(
"{}{:<width$}\x1b[0m | ",
color,
- self.name,
+ self.repetition_name(),
width = name_width
);
}
- pub fn name_from_path(path: &Path) -> String {
- path.file_name().unwrap().to_string_lossy().to_string()
- }
-
- /// Set up the example by checking out the specified Git revision
+ /// Fetch the example's Git revision into the local repository if it isn't already present
- pub async fn setup(&mut self) -> Result<()> {
+ pub async fn fetch(&mut self) -> Result<()> {
let revision_exists = run_git(
&self.repo_path,
&["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
@@ -203,19 +206,24 @@ impl Example {
)
.await?;
}
+ Ok(())
+ }
- if self.worktree_path.is_dir() {
+ /// Set up the example's worktree by checking out the specified Git revision
+ pub async fn setup(&mut self) -> Result<()> {
+ let worktree_path = self.worktree_path();
+ if worktree_path.is_dir() {
println!("{}Resetting existing worktree", self.log_prefix);
// TODO: consider including "-x" to remove ignored files. The downside of this is that
// it will also remove build artifacts, and so prevent incremental reuse there.
- run_git(&self.worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&self.worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&self.worktree_path, &["checkout", &self.base.revision]).await?;
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
} else {
println!("{}Creating worktree", self.log_prefix);
- let worktree_path_string = self.worktree_path.to_string_lossy().to_string();
+ let worktree_path_string = worktree_path.to_string_lossy().to_string();
run_git(
&self.repo_path,
@@ -231,10 +239,10 @@ impl Example {
}
if self.base.url == ZED_REPO_URL {
- std::fs::write(self.worktree_path.join(".rules"), std::fs::read(".rules")?)?;
+ std::fs::write(worktree_path.join(".rules"), std::fs::read(".rules")?)?;
}
- std::fs::create_dir_all(self.example_output_directory())?;
+ std::fs::create_dir_all(self.run_directory_path())?;
Ok(())
}
@@ -256,7 +264,7 @@ impl Example {
);
let worktree = project.update(cx, |project, cx| {
- project.create_worktree(&self.worktree_path, true, cx)
+ project.create_worktree(self.worktree_path(), true, cx)
});
let tools = cx.new(|_| ToolWorkingSet::default());
@@ -315,6 +323,9 @@ impl Example {
None
};
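+ // Snapshot project-wide diagnostic counts before the agent runs, for later comparison and telemetry.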
+ let diagnostic_summary_before = project.read_with(cx, |project, cx| {
+ project.diagnostic_summary(false, cx)
+ })?;
let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
@@ -324,7 +335,7 @@ impl Example {
return Err(anyhow!("Setup only mode"));
}
- let example_output_dir = this.example_output_directory();
+ let example_output_dir = this.run_directory_path();
let last_diff_file_path = example_output_dir.join("last.diff");
// Write an empty "last.diff" so that it can be opened in Zed for convenient view of the
@@ -491,6 +502,9 @@ impl Example {
std::fs::write(last_diff_file_path, &repository_diff)?;
println!("{}Getting diagnostics", this.log_prefix);
+ let diagnostic_summary_after = project.read_with(cx, |project, cx| {
+ project.diagnostic_summary(false, cx)
+ })?;
let diagnostics_after = cx
.update(move |cx| {
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
@@ -522,6 +536,8 @@ impl Example {
RunOutput {
repository_diff,
ran_diagnostics_check: this.base.require_lsp,
+ diagnostic_summary_before,
+ diagnostic_summary_after,
diagnostics_before,
diagnostics_after,
response_count,
@@ -537,7 +553,6 @@ impl Example {
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
- judge_number: u32,
cx: &AsyncApp,
) -> Result<(String, JudgeResponse)> {
let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
@@ -549,9 +564,6 @@ impl Example {
judge_diff_prompt_name,
&JudgeDiffInput {
repository_diff: run_output.repository_diff.clone(),
- ran_diagnostics_check: run_output.ran_diagnostics_check,
- diagnostics_before: run_output.diagnostics_before.clone(),
- diagnostics_after: run_output.diagnostics_after.clone(),
criteria: self.diff_criteria.clone(),
},
)?;
@@ -573,8 +585,9 @@ impl Example {
let diff_output = JudgeResponse::parse(&diff_response)?;
println!(
- "{}Judge #{judge_number} - Diff score: {}",
- self.log_prefix, diff_output.score
+ "{}Judge - Diff score: {}%",
+ self.log_prefix,
+ diff_output.score()
);
Ok((diff_response, diff_output))
@@ -584,69 +597,60 @@ impl Example {
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
- judge_number: u32,
cx: &AsyncApp,
- ) -> Result<(String, Option<JudgeResponse>)> {
- if let Some(criteria) = self.thread_criteria.clone() {
- let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
- let judge_thread_prompt_name = "judge_thread_prompt";
- let mut hbs = Handlebars::new();
- hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
-
- let request_markdown = RequestMarkdown::new(&run_output.last_request);
- let thread_prompt = hbs.render(
- judge_thread_prompt_name,
- &JudgeThreadInput {
- messages: request_markdown.messages,
- criteria,
- },
- )?;
+ ) -> Result<(String, JudgeResponse)> {
+ let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
+ let judge_thread_prompt_name = "judge_thread_prompt";
+ let mut hbs = Handlebars::new();
+ hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
+
+ let request_markdown = RequestMarkdown::new(&run_output.last_request);
+ let thread_prompt = hbs.render(
+ judge_thread_prompt_name,
+ &JudgeThreadInput {
+ messages: request_markdown.messages,
+ criteria: self.thread_criteria.clone(),
+ },
+ )?;
- let request = LanguageModelRequest {
- thread_id: None,
- prompt_id: None,
- messages: vec![LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::Text(thread_prompt)],
- cache: false,
- }],
- temperature: None,
- tools: Vec::new(),
- stop: Vec::new(),
- };
+ let request = LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text(thread_prompt)],
+ cache: false,
+ }],
+ temperature: None,
+ tools: Vec::new(),
+ stop: Vec::new(),
+ };
- let thread_response = send_language_model_request(model, request, cx).await?;
- let thread_output = JudgeResponse::parse(&thread_response)?;
+ let thread_response = send_language_model_request(model, request, cx).await?;
+ let thread_output = JudgeResponse::parse(&thread_response)?;
- println!(
- "{}Judge #{judge_number} - Thread score: {}",
- self.log_prefix, thread_output.score
- );
+ println!(
+ "{}Judge - Thread score: {}%",
+ self.log_prefix,
+ thread_output.score()
+ );
- Ok((thread_response, Some(thread_output)))
- } else {
- let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
- Ok((msg, None))
- }
+ Ok((thread_response, thread_output))
}
pub async fn judge(
&self,
model: Arc<dyn LanguageModel>,
run_output: &RunOutput,
- judge_number: u32,
cx: &AsyncApp,
) -> Result<JudgeOutput> {
- let mut output_file = File::create(
- self.example_output_directory()
- .join(format!("judge_{}.md", judge_number)),
- )
- .expect("failed to create judge.md");
+ let mut output_file = File::create(self.run_directory_path().join("judge.md"))
+ .expect("failed to create judge.md");
- println!("{}Running judge #{judge_number}", self.log_prefix);
+ println!("{}Running judge", self.log_prefix);
- let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
- let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
+ let diff_task = self.judge_diff(model.clone(), &run_output, cx);
+ let thread_task = self.judge_thread(model.clone(), &run_output, cx);
let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
@@ -666,12 +670,14 @@ impl Example {
}
async fn repository_diff(&self) -> Result<String> {
- run_git(&self.worktree_path, &["add", "."]).await?;
+ let worktree_path = self.worktree_path();
+
+ run_git(&worktree_path, &["add", "."]).await?;
let mut diff_args = vec!["diff", "--staged"];
if self.base.url == ZED_REPO_URL {
diff_args.push(":(exclude).rules");
}
- run_git(&self.worktree_path, &diff_args).await
+ run_git(&worktree_path, &diff_args).await
}
}
@@ -805,11 +811,21 @@ async fn query_lsp_diagnostics(
impl JudgeResponse {
fn parse(response: &str) -> Result<Self> {
let analysis = get_tag("analysis", response)?.to_string();
- let score = get_tag("score", response)?
+ let passing_criteria = get_tag("passing_criteria", response)?
.parse()
.context("error parsing score")?;
+ let total_criteria = get_tag("total_criteria", response)?
+ .parse()
+ .context("error parsing score")?;
+ Ok(Self {
+ analysis,
+ total_criteria,
+ passing_criteria,
+ })
+ }
- Ok(Self { analysis, score })
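+ /// Percentage of criteria that passed, rounded to the nearest whole number
+ /// (e.g. 3 passing out of 5 total criteria scores 60).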
+ pub fn score(&self) -> u32 {
+ (100.0 * self.passing_criteria as f32 / self.total_criteria as f32).round() as u32
}
}
@@ -1042,13 +1058,13 @@ fn response_events_to_markdown(
#[cfg(test)]
mod test {
use super::*;
- use handlebars::Handlebars;
#[test]
fn test_parse_judge_output() {
let response = r#"
<analysis>The model did a good job but there were still compilations errors.</analysis>
- <score>3</score>
+ <passing_criteria>3</passing_criteria>
+ <total_criteria>5</total_criteria>
"#
.unindent();
@@ -1057,7 +1073,8 @@ mod test {
output.analysis,
"The model did a good job but there were still compilations errors."
);
- assert_eq!(output.score, 3);
+ assert_eq!(output.passing_criteria, 3);
+ assert_eq!(output.total_criteria, 5);
let response = r#"
Text around ignored
@@ -1068,162 +1085,15 @@ mod test {
- Error 2
</analysis>
- <score>1</score>
+ <passing_criteria>1</passing_criteria>
+
+ <total_criteria>3</total_criteria>
"#
.unindent();
let output = JudgeResponse::parse(&response).unwrap();
assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
- assert_eq!(output.score, 1);
- }
-
- #[test]
- fn test_judge_prompt_with_diagnostics() {
- // Case 1: Both diagnostics before and after are present
- let input = JudgeDiffInput {
- repository_diff: "diff content goes here".to_string(),
- ran_diagnostics_check: true,
- diagnostics_before: Some("Error at line 10: variable not found".to_string()),
- diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
- criteria: "Fix all bugs".to_string(),
- };
-
- let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
-
- let expected_diagnostics_section = r#"
- Take into account the diagnostics before and after applying the change:
-
- <diagnostics_before>
- Error at line 10: variable not found
- </diagnostics_before>
-
- <diagnostics_after>
- Error at line 15: missing semicolon
- </diagnostics_after>
- "#
- .unindent();
-
- assert!(rendered.contains(&expected_diagnostics_section));
- }
-
- #[test]
- fn test_judge_prompt_with_empty_diagnostics() {
- // Case 2: Diagnostics check run but no diagnostics found
- let input = JudgeDiffInput {
- repository_diff: "diff content goes here".to_string(),
- ran_diagnostics_check: true,
- diagnostics_before: None,
- diagnostics_after: None,
- criteria: "Fix all bugs".to_string(),
- };
-
- let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
-
- let expected_diagnostics_section = r#"
- Take into account the diagnostics before and after applying the change:
-
- <diagnostics_before>
- No diagnostics before applying the edits.
- </diagnostics_before>
-
- <diagnostics_after>
- No diagnostics after applying the edits.
- </diagnostics_after>
- "#
- .unindent();
-
- assert!(rendered.contains(&expected_diagnostics_section));
- }
-
- #[test]
- fn test_judge_prompt_with_mixed_diagnostics() {
- let templates = templates();
-
- // Case 3: Before diagnostics present, after diagnostics absent
- let input = JudgeDiffInput {
- repository_diff: "diff content goes here".to_string(),
- ran_diagnostics_check: true,
- diagnostics_before: Some("Error at line 10: variable not found".to_string()),
- diagnostics_after: None,
- criteria: "Fix all bugs".to_string(),
- };
-
- let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
-
- let expected_diagnostics_section = r#"
- Take into account the diagnostics before and after applying the change:
-
- <diagnostics_before>
- Error at line 10: variable not found
- </diagnostics_before>
-
- <diagnostics_after>
- No diagnostics after applying the edits.
- </diagnostics_after>
- "#
- .unindent();
-
- assert!(rendered.contains(&expected_diagnostics_section));
-
- // Case 4: Before diagnostics absent, after diagnostics present
- let input = JudgeDiffInput {
- repository_diff: "diff content goes here".to_string(),
- ran_diagnostics_check: true,
- diagnostics_before: None,
- diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
- criteria: "Fix all bugs".to_string(),
- };
-
- let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
-
- let expected_diagnostics_section = r#"
- Take into account the diagnostics before and after applying the change:
-
- <diagnostics_before>
- No diagnostics before applying the edits.
- </diagnostics_before>
-
- <diagnostics_after>
- Error at line 15: missing semicolon
- </diagnostics_after>
- "#
- .unindent();
-
- assert!(rendered.contains(&expected_diagnostics_section));
- }
-
- #[test]
- fn test_judge_prompt_without_diagnostics() {
- let templates = templates();
-
- // Case 5: No diagnostics check run
- let input = JudgeDiffInput {
- repository_diff: "diff content goes here".to_string(),
- ran_diagnostics_check: false,
- diagnostics_before: None,
- diagnostics_after: None,
- criteria: "Fix all bugs".to_string(),
- };
-
- let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
-
- // Check for the message when no diagnostics were performed
- let diagnostics_message = "No diagnostic checks were performed.";
-
- assert!(rendered.contains(diagnostics_message));
- assert!(!rendered.contains("<diagnostics_before>"));
- assert!(!rendered.contains("<diagnostics_after>"));
- }
-
- const JUDGE_PROMPT_NAME: &str = "judge_prompt";
-
- fn templates() -> Handlebars<'static> {
- let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
- language::LineEnding::normalize(&mut judge_prompt);
- let mut handlebars = Handlebars::new();
- handlebars
- .register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
- .unwrap();
- handlebars
+ assert_eq!(output.passing_criteria, 1);
+ assert_eq!(output.total_criteria, 3);
}
}