1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
// CLI arguments for the `evaluate` subcommand.
// NOTE(review): comments here deliberately use `//`, not `///` — clap turns
// doc comments on Args structs/fields into user-visible help text, which
// would change the CLI's output.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional: paths of the example files to evaluate.
    example_paths: Vec<PathBuf>,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Presumably makes prediction use the example's expected context instead
    // of retrieved context — confirm against `zeta2_predict`.
    #[arg(long)]
    use_expected_context: bool,
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
    // How many times to run each example; `--repeat` is accepted as an alias.
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
}
35
/// Entry point for the `evaluate` subcommand.
///
/// Loads every example, then runs predictions concurrently: one outer task
/// per example file, each fanned out into one inner task per repetition.
/// Aggregated scores are written to stdout and, best-effort, to
/// `RUN_DIR/aggregated_results.md`.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // NOTE(review): `load` and `setup_project` panic via `unwrap()` instead of
    // being attributed like the per-run errors below — confirm this is
    // intentional for CLI usage.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            // One Zeta entity per repetition. The run index is kept only when
            // actually repeating, so single runs print without a run label.
            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        args.cache,
                        cx,
                    )
                    .await
                    // Attach example name and run index so failures can be
                    // attributed in the aggregated report.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    // Best-effort duplicate of the report into the run data directory;
    // failures are logged, not fatal.
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
89
90fn write_aggregated_scores(
91 w: &mut impl std::io::Write,
92 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
93) -> Result<()> {
94 let mut successful = Vec::new();
95 let mut failed_count = 0;
96
97 for result in all_results.iter().flatten() {
98 match result {
99 Ok(eval_result) => successful.push(eval_result),
100 Err((err, name, repetition_ix)) => {
101 if failed_count == 0 {
102 writeln!(w, "## Errors\n")?;
103 }
104
105 failed_count += 1;
106 let err = format!("{err:?}")
107 .replace("<edits", "```xml\n<edits")
108 .replace("</edits>", "</edits>\n```");
109 writeln!(
110 w,
111 "### ERROR {name}{}\n\n{err}\n",
112 repetition_ix
113 .map(|ix| format!(" [RUN {ix:03}]"))
114 .unwrap_or_default()
115 )?;
116 }
117 }
118 }
119
120 if successful.len() > 1 {
121 let aggregated_result = EvaluationResult {
122 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
123 edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
124 };
125
126 writeln!(w, "\n{}", "-".repeat(80))?;
127 writeln!(w, "\n## TOTAL SCORES")?;
128 writeln!(w, "\n### Success Rate")?;
129 writeln!(w, "{}", aggregated_result)?;
130 }
131
132 if successful.len() + failed_count > 1 {
133 writeln!(
134 w,
135 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
136 successful.len(),
137 successful.len() + failed_count,
138 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
139 )?;
140 }
141
142 Ok(())
143}
144
145pub async fn run_evaluate_one(
146 example: NamedExample,
147 repetition_ix: Option<u16>,
148 project: Entity<Project>,
149 zeta: Entity<Zeta>,
150 prompt_format: PromptFormat,
151 use_expected_context: bool,
152 cache_mode: CacheMode,
153 cx: &mut AsyncApp,
154) -> Result<EvaluationResult> {
155 let predict_result = zeta2_predict(
156 example.clone(),
157 project,
158 zeta,
159 repetition_ix,
160 prompt_format,
161 use_expected_context,
162 cache_mode,
163 cx,
164 )
165 .await?;
166
167 let evaluation_result = evaluate(&example.example, &predict_result);
168
169 if repetition_ix.is_none() {
170 write_eval_result(
171 &example,
172 &predict_result,
173 &evaluation_result,
174 &mut std::io::stdout(),
175 std::io::stdout().is_terminal(),
176 )?;
177 }
178
179 if let Some(mut results_file) =
180 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
181 {
182 write_eval_result(
183 &example,
184 &predict_result,
185 &evaluation_result,
186 &mut results_file,
187 false,
188 )
189 .log_err();
190 }
191
192 anyhow::Ok(evaluation_result)
193}
194
195fn write_eval_result(
196 example: &NamedExample,
197 predictions: &PredictionDetails,
198 evaluation_result: &EvaluationResult,
199 out: &mut impl Write,
200 use_color: bool,
201) -> Result<()> {
202 writeln!(
203 out,
204 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
205 compare_diffs(
206 &example.example.expected_patch,
207 &predictions.diff,
208 use_color
209 )
210 )?;
211 writeln!(
212 out,
213 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
214 compare_diffs(
215 &predictions.diff,
216 &example.example.expected_patch,
217 use_color
218 )
219 )?;
220 writeln!(out, "{:#}", evaluation_result)?;
221
222 anyhow::Ok(())
223}
224
225#[derive(Debug, Default)]
226pub struct EvaluationResult {
227 pub edit_prediction: Scores,
228 pub context: Scores,
229}
230
231#[derive(Default, Debug)]
232pub struct Scores {
233 pub true_positives: usize,
234 pub false_positives: usize,
235 pub false_negatives: usize,
236}
237
238impl Scores {
239 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
240 let true_positives = expected.intersection(actual).count();
241 let false_positives = actual.difference(expected).count();
242 let false_negatives = expected.difference(actual).count();
243
244 Scores {
245 true_positives,
246 false_positives,
247 false_negatives,
248 }
249 }
250
251 pub fn to_markdown(&self) -> String {
252 format!(
253 "
254Precision : {:.4}
255Recall : {:.4}
256F1 Score : {:.4}
257True Positives : {}
258False Positives : {}
259False Negatives : {}",
260 self.precision(),
261 self.recall(),
262 self.f1_score(),
263 self.true_positives,
264 self.false_positives,
265 self.false_negatives
266 )
267 }
268
269 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
270 let mut true_positives = 0;
271 let mut false_positives = 0;
272 let mut false_negatives = 0;
273
274 for score in scores {
275 true_positives += score.true_positives;
276 false_positives += score.false_positives;
277 false_negatives += score.false_negatives;
278 }
279
280 Scores {
281 true_positives,
282 false_positives,
283 false_negatives,
284 }
285 }
286
287 pub fn precision(&self) -> f64 {
288 if self.true_positives + self.false_positives == 0 {
289 0.0
290 } else {
291 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
292 }
293 }
294
295 pub fn recall(&self) -> f64 {
296 if self.true_positives + self.false_negatives == 0 {
297 0.0
298 } else {
299 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
300 }
301 }
302
303 pub fn f1_score(&self) -> f64 {
304 let recall = self.recall();
305 let precision = self.precision();
306 if precision + recall == 0.0 {
307 0.0
308 } else {
309 2.0 * precision * recall / (precision + recall)
310 }
311 }
312}
313
314impl std::fmt::Display for EvaluationResult {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 if f.alternate() {
317 self.fmt_table(f)
318 } else {
319 self.fmt_markdown(f)
320 }
321 }
322}
323
324impl EvaluationResult {
325 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 write!(
327 f,
328 r#"
329### Context Scores
330{}
331
332### Edit Prediction Scores
333{}
334"#,
335 self.context.to_markdown(),
336 self.edit_prediction.to_markdown()
337 )
338 }
339
340 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 writeln!(f, "### Scores\n")?;
342 writeln!(
343 f,
344 " TP FP FN Precision Recall F1"
345 )?;
346 writeln!(
347 f,
348 "──────────────────────────────────────────────────────────────────"
349 )?;
350 writeln!(
351 f,
352 "Context Retrieval {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
353 self.context.true_positives,
354 self.context.false_positives,
355 self.context.false_negatives,
356 self.context.precision() * 100.0,
357 self.context.recall() * 100.0,
358 self.context.f1_score() * 100.0
359 )?;
360 writeln!(
361 f,
362 "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
363 self.edit_prediction.true_positives,
364 self.edit_prediction.false_positives,
365 self.edit_prediction.false_negatives,
366 self.edit_prediction.precision() * 100.0,
367 self.edit_prediction.recall() * 100.0,
368 self.edit_prediction.f1_score() * 100.0
369 )
370 }
371}
372
373pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
374 let mut eval_result = EvaluationResult::default();
375
376 let actual_context_lines: HashSet<_> = preds
377 .excerpts
378 .iter()
379 .flat_map(|excerpt| {
380 excerpt
381 .text
382 .lines()
383 .map(|line| format!("{}: {line}", excerpt.path.display()))
384 })
385 .collect();
386
387 let mut false_positive_lines = actual_context_lines.clone();
388
389 for entry in &example.expected_context {
390 let mut best_alternative_score: Option<Scores> = None;
391
392 for alternative in &entry.alternatives {
393 let expected: HashSet<_> = alternative
394 .excerpts
395 .iter()
396 .flat_map(|excerpt| {
397 excerpt
398 .text
399 .lines()
400 .map(|line| format!("{}: {line}", excerpt.path.display()))
401 })
402 .collect();
403
404 let scores = Scores::new(&expected, &actual_context_lines);
405
406 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
407
408 if best_alternative_score
409 .as_ref()
410 .is_none_or(|best| scores.recall() > best.recall())
411 {
412 best_alternative_score = Some(scores);
413 }
414 }
415
416 let best_alternative = best_alternative_score.unwrap_or_default();
417 eval_result.context.false_negatives += best_alternative.false_negatives;
418 eval_result.context.true_positives += best_alternative.true_positives;
419 }
420
421 eval_result.context.false_positives = false_positive_lines.len();
422
423 // todo: alternatives for patches
424 let expected_patch_lines = example
425 .expected_patch
426 .lines()
427 .map(DiffLine::parse)
428 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
429 .map(|line| line.to_string())
430 .collect();
431
432 let actual_patch_lines = preds
433 .diff
434 .lines()
435 .map(DiffLine::parse)
436 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
437 .map(|line| line.to_string())
438 .collect();
439
440 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
441 eval_result
442}
443
444/// Return annotated `patch_a` so that:
445/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
446/// Additions and deletions that are present in `patch_b` will be highlighted in green.
447pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
448 let green = if use_color { "\x1b[32m✓ " } else { "" };
449 let red = if use_color { "\x1b[31m✗ " } else { "" };
450 let neutral = if use_color { " " } else { "" };
451 let reset = if use_color { "\x1b[0m" } else { "" };
452 let lines_a = patch_a.lines().map(DiffLine::parse);
453 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
454
455 let annotated = lines_a
456 .map(|line| match line {
457 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
458 if lines_b.contains(&line) {
459 format!("{green}{line}{reset}")
460 } else {
461 format!("{red}{line}{reset}")
462 }
463 }
464 _ => format!("{neutral}{line}{reset}"),
465 })
466 .collect::<Vec<String>>();
467
468 annotated.join("\n")
469}