1use std::{
2 collections::HashMap,
3 io::{IsTerminal, Write},
4 path::PathBuf,
5 sync::Arc,
6};
7
8use anyhow::Result;
9use clap::Args;
10use collections::HashSet;
11use gpui::{AsyncApp, Entity};
12use project::Project;
13use util::ResultExt as _;
14use zeta2::{Zeta, udiff::DiffLine};
15
16use crate::{
17 PromptFormat,
18 example::{Example, NamedExample},
19 headless::ZetaCliAppState,
20 paths::print_run_data_dir,
21 predict::{CacheMode, PredictionDetails, zeta2_predict},
22};
23
24#[derive(Debug, Args)]
25pub struct EvaluateArguments {
26 example_paths: Vec<PathBuf>,
27 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
28 prompt_format: PromptFormat,
29 #[arg(long)]
30 use_expected_context: bool,
31 #[clap(long, value_enum, default_value_t = CacheMode::default())]
32 cache: CacheMode,
33 #[clap(short, long, default_value_t = 1, alias = "repeat")]
34 repetitions: u16,
35 #[arg(long)]
36 skip_prediction: bool,
37}
38
/// Raw data captured from one prediction execution, consumed by the
/// aggregated report and the bucketed analysis.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index when repeating, otherwise the example name.
    execution_id: String,
    // Unified diff of the predicted edits (empty when nothing was predicted).
    diff: String,
    // Raw model response text read from `prediction_response.md`; empty if missing.
    reasoning: String,
}
45
/// Entry point for the `evaluate` subcommand.
///
/// Loads each example, runs it `args.repetitions` times concurrently,
/// prints aggregated scores to stdout, mirrors them into the run
/// directory, and (for repeated runs) writes a bucketed analysis.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One outer task per example: set up its project once, then spawn one
    // inner task per repetition, each with its own Zeta entity.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // A repetition index is only attached when actually repeating.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with example name + repetition so the
                    // reports can attribute them.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    // Best-effort mirror of the aggregated report into the run directory;
    // failures are logged, not fatal.
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing identical predictions is only meaningful across repetitions.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
107
108fn write_aggregated_scores(
109 w: &mut impl std::io::Write,
110 all_results: &Vec<
111 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
112 >,
113) -> Result<()> {
114 let mut successful = Vec::new();
115 let mut failed_count = 0;
116
117 for result in all_results.iter().flatten() {
118 match result {
119 Ok((eval_result, _execution_data)) => successful.push(eval_result),
120 Err((err, name, repetition_ix)) => {
121 if failed_count == 0 {
122 writeln!(w, "## Errors\n")?;
123 }
124
125 failed_count += 1;
126 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
127 }
128 }
129 }
130
131 if successful.len() > 1 {
132 let mut edit_predictions = successful
133 .iter()
134 .filter_map(|r| r.edit_prediction.as_ref())
135 .peekable();
136 let has_edit_predictions = edit_predictions.peek().is_some();
137 let aggregated_result = EvaluationResult {
138 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
139 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
140 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
141 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
142 / successful.len(),
143 };
144
145 writeln!(w, "\n{}", "-".repeat(80))?;
146 writeln!(w, "\n## TOTAL SCORES")?;
147 writeln!(w, "{:#}", aggregated_result)?;
148 }
149
150 if successful.len() + failed_count > 1 {
151 writeln!(
152 w,
153 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
154 successful.len(),
155 successful.len() + failed_count,
156 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
157 )?;
158 }
159
160 Ok(())
161}
162
/// Runs one prediction for a single example/repetition and scores it.
///
/// Writes `results.md` into the run's example directory; when not
/// repeating, also echoes the result to stdout. Returns the scores plus
/// the raw execution data used by the bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    predict: bool,
    cache_mode: CacheMode,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = zeta2_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prompt_format,
        use_expected_context,
        cache_mode,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // Printing every repetition would interleave concurrent output; only
    // print directly when the example runs once.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Best-effort: persist the same report (without color) next to the
    // prediction data; failures are logged, not fatal.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repetitions are identified by a zero-padded index; single runs
        // by the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Model reasoning, if the response file was written; empty otherwise.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
230
231fn write_eval_result(
232 example: &NamedExample,
233 predictions: &PredictionDetails,
234 evaluation_result: &EvaluationResult,
235 out: &mut impl Write,
236 use_color: bool,
237 predict: bool,
238) -> Result<()> {
239 if predict {
240 writeln!(
241 out,
242 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
243 compare_diffs(
244 &example.example.expected_patch,
245 &predictions.diff,
246 use_color
247 )
248 )?;
249 writeln!(
250 out,
251 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
252 compare_diffs(
253 &predictions.diff,
254 &example.example.expected_patch,
255 use_color
256 )
257 )?;
258 }
259
260 writeln!(out, "{:#}", evaluation_result)?;
261
262 anyhow::Ok(())
263}
264
/// Scores for one evaluated run (or an aggregate of several).
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Present only when edit prediction was performed and scored.
    pub edit_prediction: Option<Scores>,
    // Context-retrieval scores; always computed.
    pub context: Scores,
    // Prompt length in bytes (mean across runs when aggregated).
    pub prompt_len: usize,
    // Generated-output length in bytes (mean across runs when aggregated).
    pub generated_len: usize,
}
272
/// Raw confusion counts for a set-overlap comparison, with derived
/// precision/recall/F1 accessors.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Builds counts from the overlap between `expected` and `actual`.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the scores as a small markdown fragment.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the counts of several score sets into one.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut total, score| {
            total.true_positives += score.true_positives;
            total.false_positives += score.false_positives;
            total.false_negatives += score.false_negatives;
            total
        })
    }

    /// TP / (TP + FP); defined as 0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let denominator = self.true_positives + self.false_positives;
        if denominator > 0 {
            self.true_positives as f64 / denominator as f64
        } else {
            0.0
        }
    }

    /// TP / (TP + FN); defined as 0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        let denominator = self.true_positives + self.false_negatives;
        if denominator > 0 {
            self.true_positives as f64 / denominator as f64
        } else {
            0.0
        }
    }

    /// Harmonic mean of precision and recall; 0 when both are 0.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        match precision + recall {
            sum if sum == 0.0 => 0.0,
            sum => 2.0 * precision * recall / sum,
        }
    }
}
355
356impl std::fmt::Display for EvaluationResult {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 if f.alternate() {
359 self.fmt_table(f)
360 } else {
361 self.fmt_markdown(f)
362 }
363 }
364}
365
366impl EvaluationResult {
367 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 write!(
369 f,
370 r#"
371### Context Scores
372{}
373"#,
374 self.context.to_markdown(),
375 )?;
376 if let Some(prediction) = &self.edit_prediction {
377 write!(
378 f,
379 r#"
380 ### Edit Prediction Scores
381 {}"#,
382 prediction.to_markdown()
383 )?;
384 }
385 Ok(())
386 }
387
388 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 writeln!(f, "### Scores\n")?;
390 writeln!(
391 f,
392 " Prompt Generated TP FP FN Precision Recall F1"
393 )?;
394 writeln!(
395 f,
396 "────────────────────────────────────────────────────────────────────────────────────"
397 )?;
398 writeln!(
399 f,
400 "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
401 "",
402 "",
403 self.context.true_positives,
404 self.context.false_positives,
405 self.context.false_negatives,
406 self.context.precision() * 100.0,
407 self.context.recall() * 100.0,
408 self.context.f1_score() * 100.0
409 )?;
410 if let Some(edit_prediction) = &self.edit_prediction {
411 writeln!(
412 f,
413 "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
414 self.prompt_len,
415 self.generated_len,
416 edit_prediction.true_positives,
417 edit_prediction.false_positives,
418 edit_prediction.false_negatives,
419 edit_prediction.precision() * 100.0,
420 edit_prediction.recall() * 100.0,
421 edit_prediction.f1_score() * 100.0
422 )?;
423 }
424 Ok(())
425 }
426}
427
428pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
429 let mut eval_result = EvaluationResult {
430 prompt_len: preds.prompt_len,
431 generated_len: preds.generated_len,
432 ..Default::default()
433 };
434
435 let actual_context_lines: HashSet<_> = preds
436 .excerpts
437 .iter()
438 .flat_map(|excerpt| {
439 excerpt
440 .text
441 .lines()
442 .map(|line| format!("{}: {line}", excerpt.path.display()))
443 })
444 .collect();
445
446 let mut false_positive_lines = actual_context_lines.clone();
447
448 for entry in &example.expected_context {
449 let mut best_alternative_score: Option<Scores> = None;
450
451 for alternative in &entry.alternatives {
452 let expected: HashSet<_> = alternative
453 .excerpts
454 .iter()
455 .flat_map(|excerpt| {
456 excerpt
457 .text
458 .lines()
459 .map(|line| format!("{}: {line}", excerpt.path.display()))
460 })
461 .collect();
462
463 let scores = Scores::new(&expected, &actual_context_lines);
464
465 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
466
467 if best_alternative_score
468 .as_ref()
469 .is_none_or(|best| scores.recall() > best.recall())
470 {
471 best_alternative_score = Some(scores);
472 }
473 }
474
475 let best_alternative = best_alternative_score.unwrap_or_default();
476 eval_result.context.false_negatives += best_alternative.false_negatives;
477 eval_result.context.true_positives += best_alternative.true_positives;
478 }
479
480 eval_result.context.false_positives = false_positive_lines.len();
481
482 if predict {
483 // todo: alternatives for patches
484 let expected_patch_lines = example
485 .expected_patch
486 .lines()
487 .map(DiffLine::parse)
488 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
489 .map(|line| line.to_string())
490 .collect();
491
492 let actual_patch_lines = preds
493 .diff
494 .lines()
495 .map(DiffLine::parse)
496 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
497 .map(|line| line.to_string())
498 .collect();
499
500 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
501 }
502
503 eval_result
504}
505
506/// Return annotated `patch_a` so that:
507/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
508/// Additions and deletions that are present in `patch_b` will be highlighted in green.
509pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
510 let green = if use_color { "\x1b[32m✓ " } else { "" };
511 let red = if use_color { "\x1b[31m✗ " } else { "" };
512 let neutral = if use_color { " " } else { "" };
513 let reset = if use_color { "\x1b[0m" } else { "" };
514 let lines_a = patch_a.lines().map(DiffLine::parse);
515 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
516
517 let annotated = lines_a
518 .map(|line| match line {
519 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
520 if lines_b.contains(&line) {
521 format!("{green}{line}{reset}")
522 } else {
523 format!("{red}{line}{reset}")
524 }
525 }
526 _ => format!("{neutral}{line}{reset}"),
527 })
528 .collect::<Vec<String>>();
529
530 annotated.join("\n")
531}
532
/// Groups repeated executions by their exact predicted diff and writes
/// `bucketed_analysis.md` into the run directory: a summary, correct
/// buckets, incorrect buckets (largest first), empty predictions, and
/// errors.
fn write_bucketed_analysis(
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    // One bucket per distinct predicted diff string.
    #[derive(Debug)]
    struct EditBucket {
        diff: String,
        // Whether this diff matched the expected patch exactly.
        is_correct: bool,
        // Ids of the executions that produced this diff.
        execution_indices: Vec<String>,
        // Model reasoning per execution, parallel to `execution_indices`.
        reasoning_samples: Vec<String>,
    }

    let mut total_executions = 0;
    let mut empty_predictions = Vec::new();
    let mut errors = Vec::new();

    // Keyed by the diff text itself.
    let mut buckets: HashMap<String, EditBucket> = HashMap::new();

    for result in all_results.iter().flatten() {
        total_executions += 1;

        let (evaluation_result, execution_data) = match result {
            Ok((eval_result, execution_data)) => {
                // Runs that predicted nothing are reported in their own section.
                if execution_data.diff.is_empty() {
                    empty_predictions.push(execution_data);
                    continue;
                }
                (eval_result, execution_data)
            }
            Err(err) => {
                errors.push(err);
                continue;
            }
        };

        buckets
            .entry(execution_data.diff.clone())
            .and_modify(|bucket| {
                bucket
                    .execution_indices
                    .push(execution_data.execution_id.clone());
                bucket
                    .reasoning_samples
                    .push(execution_data.reasoning.clone());
            })
            .or_insert_with(|| EditBucket {
                diff: execution_data.diff.clone(),
                // "Correct" means a perfect match: no false positives or
                // negatives, and at least one true positive.
                is_correct: {
                    evaluation_result
                        .edit_prediction
                        .as_ref()
                        .map_or(false, |edit_prediction| {
                            edit_prediction.false_positives == 0
                                && edit_prediction.false_negatives == 0
                                && edit_prediction.true_positives > 0
                        })
                },
                execution_indices: vec![execution_data.execution_id.clone()],
                reasoning_samples: vec![execution_data.reasoning.clone()],
            });
    }

    // Correct buckets first, then by descending occurrence count.
    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
    });

    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
    let mut output = std::fs::File::create(&output_path)?;

    writeln!(output, "# Bucketed Edit Analysis\n")?;

    writeln!(output, "## Summary\n")?;
    writeln!(output, "- **Total executions**: {}", total_executions)?;

    let correct_count: usize = sorted_buckets
        .iter()
        .filter(|b| b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    let incorrect_count: usize = sorted_buckets
        .iter()
        .filter(|b| !b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    writeln!(
        output,
        "- **Correct predictions**: {} ({:.1}%)",
        correct_count,
        (correct_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **Incorrect predictions**: {} ({:.1}%)",
        incorrect_count,
        (incorrect_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **No Predictions**: {} ({:.1}%)",
        empty_predictions.len(),
        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
    )?;

    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
    writeln!(
        output,
        "- **Unique incorrect edit patterns**: {}\n",
        unique_incorrect
    )?;

    writeln!(output, "---\n")?;

    // All correct buckets share a single section header, emitted before
    // the first one.
    // NOTE(review): the header shows the first bucket's occurrence count,
    // not `correct_count` — confirm whether that is intended.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
        if idx == 0 {
            writeln!(
                output,
                "## Correct Predictions ({} occurrences)\n",
                bucket.execution_indices.len()
            )?;
        }

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;
        writeln!(output, "---\n")?;
    }

    // Incorrect buckets each get a numbered section, including the model
    // reasoning for every execution in the bucket.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
        writeln!(
            output,
            "## Incorrect Prediction #{} ({} occurrences)\n",
            idx + 1,
            bucket.execution_indices.len()
        )?;

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;

        for (exec_id, reasoning) in bucket
            .execution_indices
            .iter()
            .zip(bucket.reasoning_samples.iter())
        {
            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
        }

        writeln!(output, "\n---\n")?;
    }

    if !empty_predictions.is_empty() {
        writeln!(
            output,
            "## No Predictions ({} occurrences)\n",
            empty_predictions.len()
        )?;

        for execution_data in &empty_predictions {
            writeln!(
                output,
                "{}",
                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
            )?;
        }
        writeln!(output, "\n---\n")?;
    }

    if !errors.is_empty() {
        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;

        for (err, name, repetition_ix) in &errors {
            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
        }
        writeln!(output, "\n---\n")?;
    }

    // Renders one execution's reasoning as an indented fenced block with a
    // pointer to its `prediction_response.md` on disk.
    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
        let exec_content = format!(
            "\n### Execution {} `{}/{}/prediction_response.md`{}",
            exec_id,
            crate::paths::RUN_DIR.display(),
            exec_id,
            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
        );
        indent_text(&exec_content, 2)
    }

    // Indents every line after the first by `spaces` spaces (the first
    // line keeps its position — callers rely on this when nesting).
    fn indent_text(text: &str, spaces: usize) -> String {
        let indent = " ".repeat(spaces);
        text.lines()
            .collect::<Vec<_>>()
            .join(&format!("\n{}", indent))
    }

    Ok(())
}
751
752fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
753 let err = format!("{err:?}")
754 .replace("<edits", "```xml\n<edits")
755 .replace("</edits>", "</edits>\n```");
756 format!(
757 "### ERROR {name}{}\n\n{err}\n",
758 repetition_ix
759 .map(|ix| format!(" [RUN {ix:03}]"))
760 .unwrap_or_default()
761 )
762}