1use std::{
2 collections::{BTreeSet, HashMap},
3 io::{IsTerminal, Write},
4 sync::Arc,
5};
6
7use anyhow::Result;
8use collections::HashSet;
9use gpui::{AsyncApp, Entity};
10use project::Project;
11use util::ResultExt as _;
12use zeta2::{Zeta, udiff::DiffLine};
13
14use crate::{
15 EvaluateArguments, PredictionOptions,
16 example::{Example, NamedExample},
17 headless::ZetaCliAppState,
18 paths::print_run_data_dir,
19 predict::{PredictionDetails, perform_predict, setup_zeta},
20};
21
/// Data captured from a single prediction run, kept around so repeated runs
/// can be grouped and reported in the bucketed analysis.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Identifier for this run: the zero-padded repetition index when the
    // example is run multiple times, otherwise the example name.
    execution_id: String,
    // Unified diff of the predicted edit (empty when nothing was predicted).
    diff: String,
    // Raw contents of this run's `prediction_response.md`; empty string when
    // the file is missing or unreadable.
    reasoning: String,
}
28
/// Entry point for the `evaluate` subcommand: loads every example, runs
/// predictions for each (optionally repeated), then writes aggregated scores
/// to stdout and to the run directory, plus a bucketed analysis when repeating.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    // One outer task per example; each fans out into one inner task per
    // repetition, so results come back as Vec-per-example of Vec-per-repetition.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            // A fresh zeta provider per repetition so runs don't share state.
            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            // Replays the example's edit history into the project buffers
            // before predicting.
            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    // Results are only tagged with a repetition index when
                    // actually repeating.
                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        // Attach example name and repetition so failures can
                        // be attributed in the reports.
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Print the aggregate to stdout and also persist it in the run directory;
    // the file write is best-effort (errors logged, not fatal).
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing identical diffs is only meaningful across repetitions.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
97
98fn write_aggregated_scores(
99 w: &mut impl std::io::Write,
100 all_results: &Vec<
101 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
102 >,
103) -> Result<()> {
104 let mut successful = Vec::new();
105 let mut failed_count = 0;
106
107 for result in all_results.iter().flatten() {
108 match result {
109 Ok((eval_result, _execution_data)) => successful.push(eval_result),
110 Err((err, name, repetition_ix)) => {
111 if failed_count == 0 {
112 writeln!(w, "## Errors\n")?;
113 }
114
115 failed_count += 1;
116 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
117 }
118 }
119 }
120
121 if successful.len() > 1 {
122 let mut edit_predictions = successful
123 .iter()
124 .filter_map(|r| r.edit_prediction.as_ref())
125 .peekable();
126 let has_edit_predictions = edit_predictions.peek().is_some();
127 let aggregated_result = EvaluationResult {
128 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
129 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
130 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
131 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
132 / successful.len(),
133 context_lines_found_in_context: successful
134 .iter()
135 .map(|r| r.context_lines_found_in_context)
136 .sum::<usize>()
137 / successful.len(),
138 context_lines_in_expected_patch: successful
139 .iter()
140 .map(|r| r.context_lines_in_expected_patch)
141 .sum::<usize>()
142 / successful.len(),
143 };
144
145 writeln!(w, "\n{}", "-".repeat(80))?;
146 writeln!(w, "\n## TOTAL SCORES")?;
147 writeln!(w, "{:#}", aggregated_result)?;
148 }
149
150 if successful.len() + failed_count > 1 {
151 writeln!(
152 w,
153 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
154 successful.len(),
155 successful.len() + failed_count,
156 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
157 )?;
158 }
159
160 Ok(())
161}
162
/// Runs prediction for a single example (one repetition) and scores it.
///
/// Writes a per-run `results.md` into the run's example directory; also
/// prints results to stdout when this is a single (non-repeated) run.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prediction_options: PredictionOptions,
    predict: bool,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = perform_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prediction_options,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // Only print directly when not repeating; repeated runs are summarized
    // by the caller's aggregated/bucketed reports instead.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Persisting results is best-effort: failures are logged, not fatal.
    // Color is disabled for the on-disk copy.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repetitions are identified by a zero-padded index; single runs by
        // the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Model reasoning as captured by `perform_predict`; empty when the
        // response file is missing or unreadable.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
226
227fn write_eval_result(
228 example: &NamedExample,
229 predictions: &PredictionDetails,
230 evaluation_result: &EvaluationResult,
231 out: &mut impl Write,
232 use_color: bool,
233 predict: bool,
234) -> Result<()> {
235 if predict {
236 writeln!(
237 out,
238 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
239 compare_diffs(
240 &example.example.expected_patch,
241 &predictions.diff,
242 use_color
243 )
244 )?;
245 writeln!(
246 out,
247 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
248 compare_diffs(
249 &predictions.diff,
250 &example.example.expected_patch,
251 use_color
252 )
253 )?;
254 }
255
256 writeln!(out, "{:#}", evaluation_result)?;
257
258 anyhow::Ok(())
259}
260
/// Aggregate outcome of evaluating one run (or several, when combined by
/// `Scores::aggregate` in the caller).
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Line-level scores for the predicted edit; `None` when prediction was
    // skipped.
    pub edit_prediction: Option<Scores>,
    // Line-level scores for context retrieval.
    pub context: Scores,
    // Size of the prompt sent to the model — units per `PredictionDetails`
    // (presumably characters or tokens; confirm against `perform_predict`).
    pub prompt_len: usize,
    // Size of the model's generated output (same units as `prompt_len`).
    pub generated_len: usize,
    // Number of distinct context (unchanged) lines in the expected patch.
    pub context_lines_in_expected_patch: usize,
    // How many of those expected-patch context lines appeared in the
    // retrieved context excerpts.
    pub context_lines_found_in_context: usize,
}
270
/// Precision/recall bookkeeping for a set-vs-set line comparison.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Tallies `actual` against `expected`: lines in both are true
    /// positives, extra actual lines are false positives, and missed
    /// expected lines are false negatives.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the scores as a small markdown fragment.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the raw counts across several score sets; derived metrics are
    /// recomputed from the summed counts (micro-averaging).
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut totals, score| {
            totals.true_positives += score.true_positives;
            totals.false_positives += score.false_positives;
            totals.false_negatives += score.false_negatives;
            totals
        })
    }

    /// TP / (TP + FP); defined as zero when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let denominator = (self.true_positives + self.false_positives) as f64;
        if denominator == 0.0 {
            0.0
        } else {
            self.true_positives as f64 / denominator
        }
    }

    /// TP / (TP + FN); defined as zero when nothing was expected.
    pub fn recall(&self) -> f64 {
        let denominator = (self.true_positives + self.false_negatives) as f64;
        if denominator == 0.0 {
            0.0
        } else {
            self.true_positives as f64 / denominator
        }
    }

    /// Harmonic mean of precision and recall; zero when both are zero.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}
353
354impl std::fmt::Display for EvaluationResult {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 if f.alternate() {
357 self.fmt_table(f)
358 } else {
359 self.fmt_markdown(f)
360 }
361 }
362}
363
impl EvaluationResult {
    // Markdown rendering: one "### Context Scores" section, plus an
    // "### Edit Prediction Scores" section when a prediction was scored.
    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            r#"
### Context Scores
{}
"#,
            self.context.to_markdown(),
        )?;
        if let Some(prediction) = &self.edit_prediction {
            // NOTE(review): the leading space inside this raw string is
            // emitted verbatim in the output — confirm it is intentional.
            write!(
                f,
                r#"
 ### Edit Prediction Scores
 {}"#,
                prediction.to_markdown()
            )?;
        }
        Ok(())
    }

    // Fixed-width table rendering used by the `{:#}` format; column widths
    // here must stay in sync with the header line.
    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "### Scores\n")?;
        writeln!(
            f,
            "                  Prompt     Generated  RetrievedContext  PatchContext    TP     FP     FN     Precision  Recall  F1"
        )?;
        writeln!(
            f,
            "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
        )?;
        // The context row leaves the length/count columns blank — those only
        // apply to the edit-prediction row.
        writeln!(
            f,
            "Context Retrieval {:<7}    {:<9}  {:<16}  {:<16}  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
            "",
            "",
            "",
            "",
            self.context.true_positives,
            self.context.false_positives,
            self.context.false_negatives,
            self.context.precision() * 100.0,
            self.context.recall() * 100.0,
            self.context.f1_score() * 100.0
        )?;
        if let Some(edit_prediction) = &self.edit_prediction {
            writeln!(
                f,
                "Edit Prediction   {:<7}    {:<9}  {:<16}  {:<16}  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
                self.prompt_len,
                self.generated_len,
                self.context_lines_found_in_context,
                self.context_lines_in_expected_patch,
                edit_prediction.true_positives,
                edit_prediction.false_positives,
                edit_prediction.false_negatives,
                edit_prediction.precision() * 100.0,
                edit_prediction.recall() * 100.0,
                edit_prediction.f1_score() * 100.0
            )?;
        }
        Ok(())
    }
}
429
/// Scores one prediction run against the example's expectations.
///
/// Context retrieval is always scored; edit-prediction scores are only
/// computed when `predict` is set.
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
    let mut eval_result = EvaluationResult {
        prompt_len: preds.prompt_len,
        generated_len: preds.generated_len,
        ..Default::default()
    };

    // Every retrieved excerpt line, keyed as "<path>: <line>" so identical
    // text from different files doesn't collide.
    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    // Starts as all retrieved lines; lines expected by ANY alternative of
    // ANY entry are removed below, so whatever remains counts as false
    // positives.
    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        // Each entry may list several acceptable alternatives; the one with
        // the best recall is the one that gets scored.
        let mut best_alternative_score: Option<Scores> = None;

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // Retrieved lines matching this alternative are not false
            // positives, even if a different alternative ends up scoring
            // best for the entry.
            false_positive_lines.retain(|line| !expected.contains(line));

            if best_alternative_score
                .as_ref()
                .is_none_or(|best| scores.recall() > best.recall())
            {
                best_alternative_score = Some(scores);
            }
        }

        // Entries with no alternatives contribute zeros.
        let best_alternative = best_alternative_score.unwrap_or_default();
        eval_result.context.false_negatives += best_alternative.false_negatives;
        eval_result.context.true_positives += best_alternative.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    if predict {
        // todo: alternatives for patches
        let expected_patch = example
            .expected_patch
            .lines()
            .map(DiffLine::parse)
            .collect::<Vec<_>>();
        // Added/removed lines are what the edit prediction is scored on.
        let expected_patch_lines = expected_patch
            .iter()
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();
        // Unchanged (context) lines of the expected patch, used to measure
        // whether retrieval surfaced the code being edited.
        let expected_context_lines = expected_patch
            .iter()
            .filter_map(|line| {
                if let DiffLine::Context(str) = line {
                    Some(String::from(*str))
                } else {
                    None
                }
            })
            .collect::<BTreeSet<_>>();
        // Unlike the scoring above, these are keyed by bare line text with
        // no path prefix (shadows the outer `actual_context_lines`).
        let actual_context_lines = preds
            .excerpts
            .iter()
            .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
            .collect::<BTreeSet<_>>();

        // How many of the expected patch's context lines were retrieved.
        let matched = expected_context_lines
            .intersection(&actual_context_lines)
            .count();

        let actual_patch_lines = preds
            .diff
            .lines()
            .map(DiffLine::parse)
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();

        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
        eval_result.context_lines_in_expected_patch = expected_context_lines.len();
        eval_result.context_lines_found_in_context = matched;
    }

    eval_result
}
531
532/// Return annotated `patch_a` so that:
533/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
534/// Additions and deletions that are present in `patch_b` will be highlighted in green.
535pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
536 let green = if use_color { "\x1b[32m✓ " } else { "" };
537 let red = if use_color { "\x1b[31m✗ " } else { "" };
538 let neutral = if use_color { " " } else { "" };
539 let reset = if use_color { "\x1b[0m" } else { "" };
540 let lines_a = patch_a.lines().map(DiffLine::parse);
541 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
542
543 let annotated = lines_a
544 .map(|line| match line {
545 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
546 if lines_b.contains(&line) {
547 format!("{green}{line}{reset}")
548 } else {
549 format!("{red}{line}{reset}")
550 }
551 }
552 _ => format!("{neutral}{line}{reset}"),
553 })
554 .collect::<Vec<String>>();
555
556 annotated.join("\n")
557}
558
559fn write_bucketed_analysis(
560 all_results: &Vec<
561 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
562 >,
563) -> Result<()> {
564 #[derive(Debug)]
565 struct EditBucket {
566 diff: String,
567 is_correct: bool,
568 execution_indices: Vec<String>,
569 reasoning_samples: Vec<String>,
570 }
571
572 let mut total_executions = 0;
573 let mut empty_predictions = Vec::new();
574 let mut errors = Vec::new();
575
576 let mut buckets: HashMap<String, EditBucket> = HashMap::new();
577
578 for result in all_results.iter().flatten() {
579 total_executions += 1;
580
581 let (evaluation_result, execution_data) = match result {
582 Ok((eval_result, execution_data)) => {
583 if execution_data.diff.is_empty() {
584 empty_predictions.push(execution_data);
585 continue;
586 }
587 (eval_result, execution_data)
588 }
589 Err(err) => {
590 errors.push(err);
591 continue;
592 }
593 };
594
595 buckets
596 .entry(execution_data.diff.clone())
597 .and_modify(|bucket| {
598 bucket
599 .execution_indices
600 .push(execution_data.execution_id.clone());
601 bucket
602 .reasoning_samples
603 .push(execution_data.reasoning.clone());
604 })
605 .or_insert_with(|| EditBucket {
606 diff: execution_data.diff.clone(),
607 is_correct: {
608 evaluation_result
609 .edit_prediction
610 .as_ref()
611 .map_or(false, |edit_prediction| {
612 edit_prediction.false_positives == 0
613 && edit_prediction.false_negatives == 0
614 && edit_prediction.true_positives > 0
615 })
616 },
617 execution_indices: vec![execution_data.execution_id.clone()],
618 reasoning_samples: vec![execution_data.reasoning.clone()],
619 });
620 }
621
622 let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
623 sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
624 (true, false) => std::cmp::Ordering::Less,
625 (false, true) => std::cmp::Ordering::Greater,
626 _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
627 });
628
629 let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
630 let mut output = std::fs::File::create(&output_path)?;
631
632 writeln!(output, "# Bucketed Edit Analysis\n")?;
633
634 writeln!(output, "## Summary\n")?;
635 writeln!(output, "- **Total executions**: {}", total_executions)?;
636
637 let correct_count: usize = sorted_buckets
638 .iter()
639 .filter(|b| b.is_correct)
640 .map(|b| b.execution_indices.len())
641 .sum();
642
643 let incorrect_count: usize = sorted_buckets
644 .iter()
645 .filter(|b| !b.is_correct)
646 .map(|b| b.execution_indices.len())
647 .sum();
648
649 writeln!(
650 output,
651 "- **Correct predictions**: {} ({:.1}%)",
652 correct_count,
653 (correct_count as f64 / total_executions as f64) * 100.0
654 )?;
655
656 writeln!(
657 output,
658 "- **Incorrect predictions**: {} ({:.1}%)",
659 incorrect_count,
660 (incorrect_count as f64 / total_executions as f64) * 100.0
661 )?;
662
663 writeln!(
664 output,
665 "- **No Predictions**: {} ({:.1}%)",
666 empty_predictions.len(),
667 (empty_predictions.len() as f64 / total_executions as f64) * 100.0
668 )?;
669
670 let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
671 writeln!(
672 output,
673 "- **Unique incorrect edit patterns**: {}\n",
674 unique_incorrect
675 )?;
676
677 writeln!(output, "---\n")?;
678
679 for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
680 if idx == 0 {
681 writeln!(
682 output,
683 "## Correct Predictions ({} occurrences)\n",
684 bucket.execution_indices.len()
685 )?;
686 }
687
688 writeln!(output, "**Predicted Edit:**\n")?;
689 writeln!(output, "```diff")?;
690 writeln!(output, "{}", bucket.diff)?;
691 writeln!(output, "```\n")?;
692
693 writeln!(
694 output,
695 "**Executions:** {}\n",
696 bucket.execution_indices.join(", ")
697 )?;
698 writeln!(output, "---\n")?;
699 }
700
701 for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
702 writeln!(
703 output,
704 "## Incorrect Prediction #{} ({} occurrences)\n",
705 idx + 1,
706 bucket.execution_indices.len()
707 )?;
708
709 writeln!(output, "**Predicted Edit:**\n")?;
710 writeln!(output, "```diff")?;
711 writeln!(output, "{}", bucket.diff)?;
712 writeln!(output, "```\n")?;
713
714 writeln!(
715 output,
716 "**Executions:** {}\n",
717 bucket.execution_indices.join(", ")
718 )?;
719
720 for (exec_id, reasoning) in bucket
721 .execution_indices
722 .iter()
723 .zip(bucket.reasoning_samples.iter())
724 {
725 writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
726 }
727
728 writeln!(output, "\n---\n")?;
729 }
730
731 if !empty_predictions.is_empty() {
732 writeln!(
733 output,
734 "## No Predictions ({} occurrences)\n",
735 empty_predictions.len()
736 )?;
737
738 for execution_data in &empty_predictions {
739 writeln!(
740 output,
741 "{}",
742 fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
743 )?;
744 }
745 writeln!(output, "\n---\n")?;
746 }
747
748 if !errors.is_empty() {
749 writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
750
751 for (err, name, repetition_ix) in &errors {
752 writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
753 }
754 writeln!(output, "\n---\n")?;
755 }
756
757 fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
758 let exec_content = format!(
759 "\n### Execution {} `{}/{}/prediction_response.md`{}",
760 exec_id,
761 crate::paths::RUN_DIR.display(),
762 exec_id,
763 indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
764 );
765 indent_text(&exec_content, 2)
766 }
767
768 fn indent_text(text: &str, spaces: usize) -> String {
769 let indent = " ".repeat(spaces);
770 text.lines()
771 .collect::<Vec<_>>()
772 .join(&format!("\n{}", indent))
773 }
774
775 Ok(())
776}
777
778fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
779 let err = format!("{err:?}")
780 .replace("<edits", "```xml\n<edits")
781 .replace("</edits>", "</edits>\n```");
782 format!(
783 "### ERROR {name}{}\n\n{err}\n",
784 repetition_ix
785 .map(|ix| format!(" [RUN {ix:03}]"))
786 .unwrap_or_default()
787 )
788}