1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34}
35
36pub async fn run_evaluate(
37 args: EvaluateArguments,
38 app_state: &Arc<ZetaCliAppState>,
39 cx: &mut AsyncApp,
40) {
41 if args.example_paths.is_empty() {
42 eprintln!("No examples provided");
43 return;
44 }
45 let all_tasks = args.example_paths.into_iter().map(|path| {
46 let app_state = app_state.clone();
47 let example = NamedExample::load(&path).unwrap();
48
49 cx.spawn(async move |cx| {
50 let (project, zetas, _edited_buffers) = example
51 .setup_project(&app_state, args.repetitions, cx)
52 .await
53 .unwrap();
54
55 let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
56 let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
57
58 let example = example.clone();
59 let project = project.clone();
60
61 cx.spawn(async move |cx| {
62 let name = example.name.clone();
63 run_evaluate_one(
64 example,
65 repetition_ix,
66 project,
67 zeta,
68 args.prompt_format,
69 args.use_expected_context,
70 args.cache,
71 cx,
72 )
73 .await
74 .map_err(|err| (err, name, repetition_ix))
75 })
76 });
77 futures::future::join_all(tasks).await
78 })
79 });
80 let all_results = futures::future::join_all(all_tasks).await;
81
82 write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
83 if let Some(mut output_file) =
84 std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
85 {
86 write_aggregated_scores(&mut output_file, &all_results).log_err();
87 };
88 print_run_data_dir(args.repetitions == 1);
89}
90
91fn write_aggregated_scores(
92 w: &mut impl std::io::Write,
93 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
94) -> Result<()> {
95 let mut successful = Vec::new();
96 let mut failed_count = 0;
97
98 for result in all_results.iter().flatten() {
99 match result {
100 Ok(eval_result) => successful.push(eval_result),
101 Err((err, name, repetition_ix)) => {
102 if failed_count == 0 {
103 writeln!(w, "## Errors\n")?;
104 }
105
106 failed_count += 1;
107 let err = err
108 .to_string()
109 .replace("<edits", "```xml\n<edits")
110 .replace("</edits>", "</edits>\n```");
111 writeln!(
112 w,
113 "### ERROR {name}{}\n\n{err}\n",
114 repetition_ix
115 .map(|ix| format!(" [RUN {ix:03}]"))
116 .unwrap_or_default()
117 )?;
118 }
119 }
120 }
121
122 if successful.len() > 1 {
123 let aggregated_result = EvaluationResult {
124 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
125 edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
126 };
127
128 writeln!(w, "\n{}", "-".repeat(80))?;
129 writeln!(w, "\n## TOTAL SCORES")?;
130 writeln!(w, "\n### Success Rate")?;
131 writeln!(w, "{}", aggregated_result)?;
132 }
133
134 if successful.len() + failed_count > 1 {
135 writeln!(
136 w,
137 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
138 successful.len(),
139 successful.len() + failed_count,
140 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
141 )?;
142 }
143
144 Ok(())
145}
146
147pub async fn run_evaluate_one(
148 example: NamedExample,
149 repetition_ix: Option<u16>,
150 project: Entity<Project>,
151 zeta: Entity<Zeta>,
152 prompt_format: PromptFormat,
153 use_expected_context: bool,
154 cache_mode: CacheMode,
155 cx: &mut AsyncApp,
156) -> Result<EvaluationResult> {
157 let predict_result = zeta2_predict(
158 example.clone(),
159 project,
160 zeta,
161 repetition_ix,
162 prompt_format,
163 use_expected_context,
164 cache_mode,
165 cx,
166 )
167 .await?;
168
169 let evaluation_result = evaluate(&example.example, &predict_result);
170
171 if repetition_ix.is_none() {
172 write_eval_result(
173 &example,
174 &predict_result,
175 &evaluation_result,
176 &mut std::io::stdout(),
177 )?;
178 }
179
180 if let Some(mut results_file) =
181 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
182 {
183 write_eval_result(
184 &example,
185 &predict_result,
186 &evaluation_result,
187 &mut results_file,
188 )
189 .log_err();
190 }
191
192 anyhow::Ok(evaluation_result)
193}
194
195fn write_eval_result(
196 example: &NamedExample,
197 predictions: &PredictionDetails,
198 evaluation_result: &EvaluationResult,
199 out: &mut impl Write,
200) -> Result<()> {
201 writeln!(
202 out,
203 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
204 compare_diffs(&example.example.expected_patch, &predictions.diff)
205 )?;
206 writeln!(
207 out,
208 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
209 compare_diffs(&predictions.diff, &example.example.expected_patch)
210 )?;
211 writeln!(out, "{}", evaluation_result)?;
212
213 anyhow::Ok(())
214}
215
216#[derive(Debug, Default)]
217pub struct EvaluationResult {
218 pub edit_prediction: Scores,
219 pub context: Scores,
220}
221
222#[derive(Default, Debug)]
223pub struct Scores {
224 pub true_positives: usize,
225 pub false_positives: usize,
226 pub false_negatives: usize,
227}
228
229impl Scores {
230 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
231 let true_positives = expected.intersection(actual).count();
232 let false_positives = actual.difference(expected).count();
233 let false_negatives = expected.difference(actual).count();
234
235 Scores {
236 true_positives,
237 false_positives,
238 false_negatives,
239 }
240 }
241
242 pub fn to_markdown(&self) -> String {
243 format!(
244 "
245Precision : {:.4}
246Recall : {:.4}
247F1 Score : {:.4}
248True Positives : {}
249False Positives : {}
250False Negatives : {}",
251 self.precision(),
252 self.recall(),
253 self.f1_score(),
254 self.true_positives,
255 self.false_positives,
256 self.false_negatives
257 )
258 }
259
260 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
261 let mut true_positives = 0;
262 let mut false_positives = 0;
263 let mut false_negatives = 0;
264
265 for score in scores {
266 true_positives += score.true_positives;
267 false_positives += score.false_positives;
268 false_negatives += score.false_negatives;
269 }
270
271 Scores {
272 true_positives,
273 false_positives,
274 false_negatives,
275 }
276 }
277
278 pub fn precision(&self) -> f64 {
279 if self.true_positives + self.false_positives == 0 {
280 0.0
281 } else {
282 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
283 }
284 }
285
286 pub fn recall(&self) -> f64 {
287 if self.true_positives + self.false_negatives == 0 {
288 0.0
289 } else {
290 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
291 }
292 }
293
294 pub fn f1_score(&self) -> f64 {
295 let recall = self.recall();
296 let precision = self.precision();
297 if precision + recall == 0.0 {
298 0.0
299 } else {
300 2.0 * precision * recall / (precision + recall)
301 }
302 }
303}
304
305impl std::fmt::Display for EvaluationResult {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 write!(
308 f,
309 r#"
310### Context Scores
311{}
312
313### Edit Prediction Scores
314{}
315"#,
316 self.context.to_markdown(),
317 self.edit_prediction.to_markdown()
318 )
319 }
320}
321
322pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
323 let mut eval_result = EvaluationResult::default();
324
325 let actual_context_lines: HashSet<_> = preds
326 .excerpts
327 .iter()
328 .flat_map(|excerpt| {
329 excerpt
330 .text
331 .lines()
332 .map(|line| format!("{}: {line}", excerpt.path.display()))
333 })
334 .collect();
335
336 let mut false_positive_lines = actual_context_lines.clone();
337
338 for entry in &example.expected_context {
339 let mut best_alternative_score: Option<Scores> = None;
340
341 for alternative in &entry.alternatives {
342 let expected: HashSet<_> = alternative
343 .excerpts
344 .iter()
345 .flat_map(|excerpt| {
346 excerpt
347 .text
348 .lines()
349 .map(|line| format!("{}: {line}", excerpt.path.display()))
350 })
351 .collect();
352
353 let scores = Scores::new(&expected, &actual_context_lines);
354
355 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
356
357 if best_alternative_score
358 .as_ref()
359 .is_none_or(|best| scores.recall() > best.recall())
360 {
361 best_alternative_score = Some(scores);
362 }
363 }
364
365 let best_alternative = best_alternative_score.unwrap_or_default();
366 eval_result.context.false_negatives += best_alternative.false_negatives;
367 eval_result.context.true_positives += best_alternative.true_positives;
368 }
369
370 eval_result.context.false_positives = false_positive_lines.len();
371
372 // todo: alternatives for patches
373 let expected_patch_lines = example
374 .expected_patch
375 .lines()
376 .map(DiffLine::parse)
377 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
378 .map(|line| line.to_string())
379 .collect();
380
381 let actual_patch_lines = preds
382 .diff
383 .lines()
384 .map(DiffLine::parse)
385 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
386 .map(|line| line.to_string())
387 .collect();
388
389 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
390 eval_result
391}
392
393/// Return annotated `patch_a` so that:
394/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
395/// Additions and deletions that are present in `patch_b` will be highlighted in green.
396pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
397 let use_color = std::io::stdout().is_terminal();
398 let green = if use_color { "\x1b[32m✓ " } else { "" };
399 let red = if use_color { "\x1b[31m✗ " } else { "" };
400 let neutral = if use_color { " " } else { "" };
401 let reset = if use_color { "\x1b[0m" } else { "" };
402 let lines_a = patch_a.lines().map(DiffLine::parse);
403 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
404
405 let annotated = lines_a
406 .map(|line| match line {
407 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
408 if lines_b.contains(&line) {
409 format!("{green}{line}{reset}")
410 } else {
411 format!("{red}{line}{reset}")
412 }
413 }
414 _ => format!("{neutral}{line}{reset}"),
415 })
416 .collect::<Vec<String>>();
417
418 annotated.join("\n")
419}