1use std::{
2 collections::{BTreeSet, HashMap},
3 io::{IsTerminal, Write},
4 path::PathBuf,
5 sync::Arc,
6};
7
8use anyhow::Result;
9use clap::Args;
10use collections::HashSet;
11use gpui::{AsyncApp, Entity};
12use project::Project;
13use util::ResultExt as _;
14use zeta2::{Zeta, udiff::DiffLine};
15
16use crate::{
17 PromptFormat,
18 example::{Example, NamedExample},
19 headless::ZetaCliAppState,
20 paths::print_run_data_dir,
21 predict::{CacheMode, PredictionDetails, zeta2_predict},
22};
23
24#[derive(Debug, Args)]
25pub struct EvaluateArguments {
26 example_paths: Vec<PathBuf>,
27 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
28 prompt_format: PromptFormat,
29 #[arg(long)]
30 use_expected_context: bool,
31 #[clap(long, value_enum, default_value_t = CacheMode::default())]
32 cache: CacheMode,
33 #[clap(short, long, default_value_t = 1, alias = "repeat")]
34 repetitions: u16,
35 #[arg(long)]
36 skip_prediction: bool,
37}
38
/// Raw per-run output retained for the bucketed analysis of repeated runs.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index for repeated runs, or the example name
    // for a single run (see `run_evaluate_one`).
    execution_id: String,
    // The predicted diff text; empty when no prediction was produced.
    diff: String,
    // Contents of the run's `prediction_response.md`; empty if the file
    // could not be read.
    reasoning: String,
}
45
/// Entry point for the `evaluate` subcommand: loads every example, runs the
/// configured number of repetitions concurrently, then writes aggregated
/// scores to stdout and the run directory, plus a bucketed analysis when
/// repetitions > 1.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One outer task per example; each sets up its project once and then
    // spawns one inner task per repetition.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Runs are only labeled with a repetition index when there is
                // more than one repetition.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with the example name and repetition index
                    // so they can be reported after the join.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Print the aggregate report to stdout, and also persist it to the run
    // directory (failure to create the file is logged, not fatal).
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing identical predictions only makes sense across repeated runs.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
107
108fn write_aggregated_scores(
109 w: &mut impl std::io::Write,
110 all_results: &Vec<
111 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
112 >,
113) -> Result<()> {
114 let mut successful = Vec::new();
115 let mut failed_count = 0;
116
117 for result in all_results.iter().flatten() {
118 match result {
119 Ok((eval_result, _execution_data)) => successful.push(eval_result),
120 Err((err, name, repetition_ix)) => {
121 if failed_count == 0 {
122 writeln!(w, "## Errors\n")?;
123 }
124
125 failed_count += 1;
126 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
127 }
128 }
129 }
130
131 if successful.len() > 1 {
132 let mut edit_predictions = successful
133 .iter()
134 .filter_map(|r| r.edit_prediction.as_ref())
135 .peekable();
136 let has_edit_predictions = edit_predictions.peek().is_some();
137 let aggregated_result = EvaluationResult {
138 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
139 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
140 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
141 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
142 / successful.len(),
143 context_lines_found_in_context: successful
144 .iter()
145 .map(|r| r.context_lines_found_in_context)
146 .sum::<usize>()
147 / successful.len(),
148 context_lines_in_expected_patch: successful
149 .iter()
150 .map(|r| r.context_lines_in_expected_patch)
151 .sum::<usize>()
152 / successful.len(),
153 };
154
155 writeln!(w, "\n{}", "-".repeat(80))?;
156 writeln!(w, "\n## TOTAL SCORES")?;
157 writeln!(w, "{:#}", aggregated_result)?;
158 }
159
160 if successful.len() + failed_count > 1 {
161 writeln!(
162 w,
163 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
164 successful.len(),
165 successful.len() + failed_count,
166 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
167 )?;
168 }
169
170 Ok(())
171}
172
/// Evaluates a single example repetition: runs prediction, scores the result
/// against the example's expectations, writes a per-run `results.md`, and
/// returns the scores plus the raw data used by the bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    predict: bool,
    cache_mode: CacheMode,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = zeta2_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prompt_format,
        use_expected_context,
        cache_mode,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // Single runs print directly to stdout; repeated runs report only through
    // the aggregated/bucketed outputs written by the caller.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Always persist the per-run report alongside the prediction artifacts;
    // file creation failure is logged rather than fatal.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repetitions are identified by a zero-padded index; single runs use
        // the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Raw model response; a missing file is treated as empty reasoning.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
240
241fn write_eval_result(
242 example: &NamedExample,
243 predictions: &PredictionDetails,
244 evaluation_result: &EvaluationResult,
245 out: &mut impl Write,
246 use_color: bool,
247 predict: bool,
248) -> Result<()> {
249 if predict {
250 writeln!(
251 out,
252 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
253 compare_diffs(
254 &example.example.expected_patch,
255 &predictions.diff,
256 use_color
257 )
258 )?;
259 writeln!(
260 out,
261 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
262 compare_diffs(
263 &predictions.diff,
264 &example.example.expected_patch,
265 use_color
266 )
267 )?;
268 }
269
270 writeln!(out, "{:#}", evaluation_result)?;
271
272 anyhow::Ok(())
273}
274
/// Scores and size metrics for one evaluated run.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Line-level scores for the predicted edit; `None` when prediction was
    // skipped.
    pub edit_prediction: Option<Scores>,
    // Line-level scores for retrieved context vs. the example's expected
    // context.
    pub context: Scores,
    // Length of the prompt, as reported by the prediction step.
    pub prompt_len: usize,
    // Length of the generated output, as reported by the prediction step.
    pub generated_len: usize,
    // Number of unchanged (context) lines in the expected patch.
    pub context_lines_in_expected_patch: usize,
    // How many of those expected-patch context lines were actually retrieved.
    pub context_lines_found_in_context: usize,
}
284
/// Precision/recall bookkeeping for a set-vs-set comparison.
#[derive(Debug, Default)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Compares `actual` against `expected`, counting the overlap and both
    /// kinds of mismatch.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the scores as a small markdown-friendly text table.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the raw counts across several score sets.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |acc, score| Scores {
            true_positives: acc.true_positives + score.true_positives,
            false_positives: acc.false_positives + score.false_positives,
            false_negatives: acc.false_negatives + score.false_negatives,
        })
    }

    /// TP / (TP + FP); zero when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let denominator = self.true_positives + self.false_positives;
        if denominator == 0 {
            0.0
        } else {
            self.true_positives as f64 / denominator as f64
        }
    }

    /// TP / (TP + FN); zero when nothing was expected.
    pub fn recall(&self) -> f64 {
        let denominator = self.true_positives + self.false_negatives;
        if denominator == 0 {
            0.0
        } else {
            self.true_positives as f64 / denominator as f64
        }
    }

    /// Harmonic mean of precision and recall; zero when both are zero.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}
367
368impl std::fmt::Display for EvaluationResult {
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 if f.alternate() {
371 self.fmt_table(f)
372 } else {
373 self.fmt_markdown(f)
374 }
375 }
376}
377
378impl EvaluationResult {
379 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 write!(
381 f,
382 r#"
383### Context Scores
384{}
385"#,
386 self.context.to_markdown(),
387 )?;
388 if let Some(prediction) = &self.edit_prediction {
389 write!(
390 f,
391 r#"
392 ### Edit Prediction Scores
393 {}"#,
394 prediction.to_markdown()
395 )?;
396 }
397 Ok(())
398 }
399
400 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 writeln!(f, "### Scores\n")?;
402 writeln!(
403 f,
404 " Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1"
405 )?;
406 writeln!(
407 f,
408 "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
409 )?;
410 writeln!(
411 f,
412 "Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
413 "",
414 "",
415 "",
416 "",
417 self.context.true_positives,
418 self.context.false_positives,
419 self.context.false_negatives,
420 self.context.precision() * 100.0,
421 self.context.recall() * 100.0,
422 self.context.f1_score() * 100.0
423 )?;
424 if let Some(edit_prediction) = &self.edit_prediction {
425 writeln!(
426 f,
427 "Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
428 self.prompt_len,
429 self.generated_len,
430 self.context_lines_found_in_context,
431 self.context_lines_in_expected_patch,
432 edit_prediction.true_positives,
433 edit_prediction.false_positives,
434 edit_prediction.false_negatives,
435 edit_prediction.precision() * 100.0,
436 edit_prediction.recall() * 100.0,
437 edit_prediction.f1_score() * 100.0
438 )?;
439 }
440 Ok(())
441 }
442}
443
/// Scores a prediction against its example: context-retrieval scores always,
/// plus edit-prediction scores when `predict` is set.
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
    let mut eval_result = EvaluationResult {
        prompt_len: preds.prompt_len,
        generated_len: preds.generated_len,
        ..Default::default()
    };

    // Each retrieved line is keyed as "path: line" so identical text from
    // different files is not conflated.
    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    // Start from every retrieved line; lines matched by any expected
    // alternative are removed below, leaving the true false positives.
    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        // Each expected-context entry may list several acceptable
        // alternatives; score against the one with the best recall.
        let mut best_alternative_score: Option<Scores> = None;

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // A retrieved line matching *any* alternative is not a false
            // positive, even if a different alternative ends up selected.
            false_positive_lines.retain(|line| !expected.contains(line));

            if best_alternative_score
                .as_ref()
                .is_none_or(|best| scores.recall() > best.recall())
            {
                best_alternative_score = Some(scores);
            }
        }

        // Entries with no alternatives contribute zeros.
        let best_alternative = best_alternative_score.unwrap_or_default();
        eval_result.context.false_negatives += best_alternative.false_negatives;
        eval_result.context.true_positives += best_alternative.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    if predict {
        // todo: alternatives for patches
        let expected_patch = example
            .expected_patch
            .lines()
            .map(DiffLine::parse)
            .collect::<Vec<_>>();
        // Only +/- lines count toward the edit-prediction scores.
        let expected_patch_lines = expected_patch
            .iter()
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();
        // Unchanged (context) lines of the expected patch, used to measure
        // how much of the patch's surrounding code was retrieved.
        let expected_context_lines = expected_patch
            .iter()
            .filter_map(|line| {
                if let DiffLine::Context(str) = line {
                    Some(String::from(*str))
                } else {
                    None
                }
            })
            .collect::<BTreeSet<_>>();
        // Note: plain line text here (no path prefix), unlike the
        // path-qualified keys used for context scoring above.
        let actual_context_lines = preds
            .excerpts
            .iter()
            .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
            .collect::<BTreeSet<_>>();

        let matched = expected_context_lines
            .intersection(&actual_context_lines)
            .count();

        let actual_patch_lines = preds
            .diff
            .lines()
            .map(DiffLine::parse)
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();

        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
        eval_result.context_lines_in_expected_patch = expected_context_lines.len();
        eval_result.context_lines_found_in_context = matched;
    }

    eval_result
}
545
546/// Return annotated `patch_a` so that:
547/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
548/// Additions and deletions that are present in `patch_b` will be highlighted in green.
549pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
550 let green = if use_color { "\x1b[32m✓ " } else { "" };
551 let red = if use_color { "\x1b[31m✗ " } else { "" };
552 let neutral = if use_color { " " } else { "" };
553 let reset = if use_color { "\x1b[0m" } else { "" };
554 let lines_a = patch_a.lines().map(DiffLine::parse);
555 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
556
557 let annotated = lines_a
558 .map(|line| match line {
559 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
560 if lines_b.contains(&line) {
561 format!("{green}{line}{reset}")
562 } else {
563 format!("{red}{line}{reset}")
564 }
565 }
566 _ => format!("{neutral}{line}{reset}"),
567 })
568 .collect::<Vec<String>>();
569
570 annotated.join("\n")
571}
572
573fn write_bucketed_analysis(
574 all_results: &Vec<
575 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
576 >,
577) -> Result<()> {
578 #[derive(Debug)]
579 struct EditBucket {
580 diff: String,
581 is_correct: bool,
582 execution_indices: Vec<String>,
583 reasoning_samples: Vec<String>,
584 }
585
586 let mut total_executions = 0;
587 let mut empty_predictions = Vec::new();
588 let mut errors = Vec::new();
589
590 let mut buckets: HashMap<String, EditBucket> = HashMap::new();
591
592 for result in all_results.iter().flatten() {
593 total_executions += 1;
594
595 let (evaluation_result, execution_data) = match result {
596 Ok((eval_result, execution_data)) => {
597 if execution_data.diff.is_empty() {
598 empty_predictions.push(execution_data);
599 continue;
600 }
601 (eval_result, execution_data)
602 }
603 Err(err) => {
604 errors.push(err);
605 continue;
606 }
607 };
608
609 buckets
610 .entry(execution_data.diff.clone())
611 .and_modify(|bucket| {
612 bucket
613 .execution_indices
614 .push(execution_data.execution_id.clone());
615 bucket
616 .reasoning_samples
617 .push(execution_data.reasoning.clone());
618 })
619 .or_insert_with(|| EditBucket {
620 diff: execution_data.diff.clone(),
621 is_correct: {
622 evaluation_result
623 .edit_prediction
624 .as_ref()
625 .map_or(false, |edit_prediction| {
626 edit_prediction.false_positives == 0
627 && edit_prediction.false_negatives == 0
628 && edit_prediction.true_positives > 0
629 })
630 },
631 execution_indices: vec![execution_data.execution_id.clone()],
632 reasoning_samples: vec![execution_data.reasoning.clone()],
633 });
634 }
635
636 let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
637 sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
638 (true, false) => std::cmp::Ordering::Less,
639 (false, true) => std::cmp::Ordering::Greater,
640 _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
641 });
642
643 let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
644 let mut output = std::fs::File::create(&output_path)?;
645
646 writeln!(output, "# Bucketed Edit Analysis\n")?;
647
648 writeln!(output, "## Summary\n")?;
649 writeln!(output, "- **Total executions**: {}", total_executions)?;
650
651 let correct_count: usize = sorted_buckets
652 .iter()
653 .filter(|b| b.is_correct)
654 .map(|b| b.execution_indices.len())
655 .sum();
656
657 let incorrect_count: usize = sorted_buckets
658 .iter()
659 .filter(|b| !b.is_correct)
660 .map(|b| b.execution_indices.len())
661 .sum();
662
663 writeln!(
664 output,
665 "- **Correct predictions**: {} ({:.1}%)",
666 correct_count,
667 (correct_count as f64 / total_executions as f64) * 100.0
668 )?;
669
670 writeln!(
671 output,
672 "- **Incorrect predictions**: {} ({:.1}%)",
673 incorrect_count,
674 (incorrect_count as f64 / total_executions as f64) * 100.0
675 )?;
676
677 writeln!(
678 output,
679 "- **No Predictions**: {} ({:.1}%)",
680 empty_predictions.len(),
681 (empty_predictions.len() as f64 / total_executions as f64) * 100.0
682 )?;
683
684 let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
685 writeln!(
686 output,
687 "- **Unique incorrect edit patterns**: {}\n",
688 unique_incorrect
689 )?;
690
691 writeln!(output, "---\n")?;
692
693 for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
694 if idx == 0 {
695 writeln!(
696 output,
697 "## Correct Predictions ({} occurrences)\n",
698 bucket.execution_indices.len()
699 )?;
700 }
701
702 writeln!(output, "**Predicted Edit:**\n")?;
703 writeln!(output, "```diff")?;
704 writeln!(output, "{}", bucket.diff)?;
705 writeln!(output, "```\n")?;
706
707 writeln!(
708 output,
709 "**Executions:** {}\n",
710 bucket.execution_indices.join(", ")
711 )?;
712 writeln!(output, "---\n")?;
713 }
714
715 for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
716 writeln!(
717 output,
718 "## Incorrect Prediction #{} ({} occurrences)\n",
719 idx + 1,
720 bucket.execution_indices.len()
721 )?;
722
723 writeln!(output, "**Predicted Edit:**\n")?;
724 writeln!(output, "```diff")?;
725 writeln!(output, "{}", bucket.diff)?;
726 writeln!(output, "```\n")?;
727
728 writeln!(
729 output,
730 "**Executions:** {}\n",
731 bucket.execution_indices.join(", ")
732 )?;
733
734 for (exec_id, reasoning) in bucket
735 .execution_indices
736 .iter()
737 .zip(bucket.reasoning_samples.iter())
738 {
739 writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
740 }
741
742 writeln!(output, "\n---\n")?;
743 }
744
745 if !empty_predictions.is_empty() {
746 writeln!(
747 output,
748 "## No Predictions ({} occurrences)\n",
749 empty_predictions.len()
750 )?;
751
752 for execution_data in &empty_predictions {
753 writeln!(
754 output,
755 "{}",
756 fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
757 )?;
758 }
759 writeln!(output, "\n---\n")?;
760 }
761
762 if !errors.is_empty() {
763 writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
764
765 for (err, name, repetition_ix) in &errors {
766 writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
767 }
768 writeln!(output, "\n---\n")?;
769 }
770
771 fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
772 let exec_content = format!(
773 "\n### Execution {} `{}/{}/prediction_response.md`{}",
774 exec_id,
775 crate::paths::RUN_DIR.display(),
776 exec_id,
777 indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
778 );
779 indent_text(&exec_content, 2)
780 }
781
782 fn indent_text(text: &str, spaces: usize) -> String {
783 let indent = " ".repeat(spaces);
784 text.lines()
785 .collect::<Vec<_>>()
786 .join(&format!("\n{}", indent))
787 }
788
789 Ok(())
790}
791
792fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
793 let err = format!("{err:?}")
794 .replace("<edits", "```xml\n<edits")
795 .replace("</edits>", "</edits>\n```");
796 format!(
797 "### ERROR {name}{}\n\n{err}\n",
798 repetition_ix
799 .map(|ix| format!(" [RUN {ix:03}]"))
800 .unwrap_or_default()
801 )
802}