1use std::{
2 collections::HashMap,
3 io::{IsTerminal, Write},
4 sync::Arc,
5};
6
7use anyhow::Result;
8use collections::HashSet;
9use gpui::{AsyncApp, Entity};
10use project::Project;
11use util::ResultExt as _;
12use zeta::{Zeta, udiff::DiffLine};
13
14use crate::{
15 EvaluateArguments, PredictionOptions,
16 example::{Example, NamedExample},
17 headless::ZetaCliAppState,
18 paths::print_run_data_dir,
19 predict::{PredictionDetails, perform_predict, setup_zeta},
20};
21
/// Data captured for one prediction execution, used by the bucketed
/// analysis to group identical predictions across repeated runs.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index ("007") for repeated runs, or the
    // example name when the example ran only once.
    execution_id: String,
    // The unified diff produced by the prediction (empty when the model
    // predicted nothing).
    diff: String,
    // Raw contents of this run's `prediction_response.md`; empty string if
    // the file could not be read.
    reasoning: String,
}
28
/// Entry point for the `evaluate` subcommand.
///
/// Loads every example given in `args.example_paths`, runs the prediction
/// pipeline `args.repetitions` times per example (in parallel tasks), then
/// writes aggregated scores to stdout and `aggregated_results.md`, plus a
/// bucketed analysis when there was more than one repetition.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        // CLI tool: fail fast with a panic message on a malformed example.
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            // One provider per repetition so repeated runs don't share
            // prediction state.
            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    // Tag results with a repetition index only when the
                    // example runs more than once.
                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        // Carry the example name and repetition along with
                        // the error so failures can be reported at the end.
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Report to stdout, and best-effort to a markdown file in the run dir.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing predictions is only meaningful with repeated runs.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
97
98fn write_aggregated_scores(
99 w: &mut impl std::io::Write,
100 all_results: &Vec<
101 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
102 >,
103) -> Result<()> {
104 let mut successful = Vec::new();
105 let mut failed_count = 0;
106
107 for result in all_results.iter().flatten() {
108 match result {
109 Ok((eval_result, _execution_data)) => successful.push(eval_result),
110 Err((err, name, repetition_ix)) => {
111 if failed_count == 0 {
112 writeln!(w, "## Errors\n")?;
113 }
114
115 failed_count += 1;
116 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
117 }
118 }
119 }
120
121 if successful.len() > 1 {
122 let mut edit_predictions = successful
123 .iter()
124 .filter_map(|r| r.edit_prediction.as_ref())
125 .peekable();
126 let has_edit_predictions = edit_predictions.peek().is_some();
127 let aggregated_result = EvaluationResult {
128 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
129 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
130 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
131 / successful.len(),
132 };
133
134 writeln!(w, "\n{}", "-".repeat(80))?;
135 writeln!(w, "\n## TOTAL SCORES")?;
136 writeln!(w, "{:#}", aggregated_result)?;
137 }
138
139 if successful.len() + failed_count > 1 {
140 writeln!(
141 w,
142 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
143 successful.len(),
144 successful.len() + failed_count,
145 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
146 )?;
147 }
148
149 Ok(())
150}
151
152pub async fn run_evaluate_one(
153 example: NamedExample,
154 repetition_ix: Option<u16>,
155 project: Entity<Project>,
156 zeta: Entity<Zeta>,
157 prediction_options: PredictionOptions,
158 predict: bool,
159 cx: &mut AsyncApp,
160) -> Result<(EvaluationResult, ExecutionData)> {
161 let predict_result = perform_predict(
162 example.clone(),
163 project,
164 zeta,
165 repetition_ix,
166 prediction_options,
167 cx,
168 )
169 .await?;
170
171 let evaluation_result = evaluate(&example.example, &predict_result, predict);
172
173 if repetition_ix.is_none() {
174 write_eval_result(
175 &example,
176 &predict_result,
177 &evaluation_result,
178 &mut std::io::stdout(),
179 std::io::stdout().is_terminal(),
180 predict,
181 )?;
182 }
183
184 if let Some(mut results_file) =
185 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
186 {
187 write_eval_result(
188 &example,
189 &predict_result,
190 &evaluation_result,
191 &mut results_file,
192 false,
193 predict,
194 )
195 .log_err();
196 }
197
198 let execution_data = ExecutionData {
199 execution_id: if let Some(rep_ix) = repetition_ix {
200 format!("{:03}", rep_ix)
201 } else {
202 example.name.clone()
203 },
204 diff: predict_result.diff.clone(),
205 reasoning: std::fs::read_to_string(
206 predict_result
207 .run_example_dir
208 .join("prediction_response.md"),
209 )
210 .unwrap_or_default(),
211 };
212
213 anyhow::Ok((evaluation_result, execution_data))
214}
215
216fn write_eval_result(
217 example: &NamedExample,
218 predictions: &PredictionDetails,
219 evaluation_result: &EvaluationResult,
220 out: &mut impl Write,
221 use_color: bool,
222 predict: bool,
223) -> Result<()> {
224 if predict {
225 writeln!(
226 out,
227 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
228 compare_diffs(
229 &example.example.expected_patch,
230 &predictions.diff,
231 use_color
232 )
233 )?;
234 writeln!(
235 out,
236 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
237 compare_diffs(
238 &predictions.diff,
239 &example.example.expected_patch,
240 use_color
241 )
242 )?;
243 }
244
245 writeln!(out, "{:#}", evaluation_result)?;
246
247 anyhow::Ok(())
248}
249
/// Result of evaluating a single prediction run.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Scores comparing predicted vs. expected edit lines; `None` when
    // prediction was skipped.
    pub edit_prediction: Option<Scores>,
    // Prompt length as reported by `PredictionDetails` — units (chars vs.
    // tokens) are not visible from this file; TODO confirm at the source.
    pub prompt_len: usize,
    // Generated-output length, same caveat as `prompt_len`.
    pub generated_len: usize,
}
256
/// Set-comparison counts between expected and actual edit lines.
#[derive(Default, Debug)]
pub struct Scores {
    // Lines present in both expected and actual.
    pub true_positives: usize,
    // Lines predicted but not expected.
    pub false_positives: usize,
    // Lines expected but not predicted.
    pub false_negatives: usize,
}

impl Scores {
    /// Builds scores by set-comparing `expected` against `actual`.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the scores as a small markdown fragment.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the raw counts across all given score sets.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut acc, score| {
            acc.true_positives += score.true_positives;
            acc.false_positives += score.false_positives;
            acc.false_negatives += score.false_negatives;
            acc
        })
    }

    /// TP / (TP + FP); 0.0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        Self::ratio(
            self.true_positives,
            self.true_positives + self.false_positives,
        )
    }

    /// TP / (TP + FN); 0.0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        Self::ratio(
            self.true_positives,
            self.true_positives + self.false_negatives,
        )
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub fn f1_score(&self) -> f64 {
        let (precision, recall) = (self.precision(), self.recall());
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }

    // Division guarded against a zero denominator.
    fn ratio(numerator: usize, denominator: usize) -> f64 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f64 / denominator as f64
        }
    }
}
339
340impl std::fmt::Display for EvaluationResult {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 if f.alternate() {
343 self.fmt_table(f)
344 } else {
345 self.fmt_markdown(f)
346 }
347 }
348}
349
impl EvaluationResult {
    /// Markdown rendering (plain `{}`); emits nothing when there are no
    /// edit-prediction scores.
    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(prediction) = &self.edit_prediction {
            write!(
                f,
                r#"
 ### Edit Prediction Scores
 {}"#,
                prediction.to_markdown()
            )?;
        }
        Ok(())
    }

    /// Fixed-width table rendering (alternate `{:#}`); the column widths in
    /// the row format are meant to line up with the header line.
    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "### Scores\n")?;
        writeln!(
            f,
            " Prompt Generated TP FP FN Precision Recall F1"
        )?;
        writeln!(
            f,
            "───────────────────────────────────────────────────────────────────────────────────────────────"
        )?;
        if let Some(edit_prediction) = &self.edit_prediction {
            // Percentages, not ratios: scores are scaled by 100 for display.
            writeln!(
                f,
                "Edit Prediction {:<7} {:<9} {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
                self.prompt_len,
                self.generated_len,
                edit_prediction.true_positives,
                edit_prediction.false_positives,
                edit_prediction.false_negatives,
                edit_prediction.precision() * 100.0,
                edit_prediction.recall() * 100.0,
                edit_prediction.f1_score() * 100.0
            )?;
        }
        Ok(())
    }
}
391
392fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
393 let mut eval_result = EvaluationResult {
394 prompt_len: preds.prompt_len,
395 generated_len: preds.generated_len,
396 ..Default::default()
397 };
398
399 if predict {
400 // todo: alternatives for patches
401 let expected_patch = example
402 .expected_patch
403 .lines()
404 .map(DiffLine::parse)
405 .collect::<Vec<_>>();
406 let expected_patch_lines = expected_patch
407 .iter()
408 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
409 .map(|line| line.to_string())
410 .collect();
411
412 let actual_patch_lines = preds
413 .diff
414 .lines()
415 .map(DiffLine::parse)
416 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
417 .map(|line| line.to_string())
418 .collect();
419
420 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
421 }
422
423 eval_result
424}
425
426/// Return annotated `patch_a` so that:
427/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
428/// Additions and deletions that are present in `patch_b` will be highlighted in green.
429pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
430 let green = if use_color { "\x1b[32m✓ " } else { "" };
431 let red = if use_color { "\x1b[31m✗ " } else { "" };
432 let neutral = if use_color { " " } else { "" };
433 let reset = if use_color { "\x1b[0m" } else { "" };
434 let lines_a = patch_a.lines().map(DiffLine::parse);
435 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
436
437 let annotated = lines_a
438 .map(|line| match line {
439 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
440 if lines_b.contains(&line) {
441 format!("{green}{line}{reset}")
442 } else {
443 format!("{red}{line}{reset}")
444 }
445 }
446 _ => format!("{neutral}{line}{reset}"),
447 })
448 .collect::<Vec<String>>();
449
450 annotated.join("\n")
451}
452
453fn write_bucketed_analysis(
454 all_results: &Vec<
455 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
456 >,
457) -> Result<()> {
458 #[derive(Debug)]
459 struct EditBucket {
460 diff: String,
461 is_correct: bool,
462 execution_indices: Vec<String>,
463 reasoning_samples: Vec<String>,
464 }
465
466 let mut total_executions = 0;
467 let mut empty_predictions = Vec::new();
468 let mut errors = Vec::new();
469
470 let mut buckets: HashMap<String, EditBucket> = HashMap::new();
471
472 for result in all_results.iter().flatten() {
473 total_executions += 1;
474
475 let (evaluation_result, execution_data) = match result {
476 Ok((eval_result, execution_data)) => {
477 if execution_data.diff.is_empty() {
478 empty_predictions.push(execution_data);
479 continue;
480 }
481 (eval_result, execution_data)
482 }
483 Err(err) => {
484 errors.push(err);
485 continue;
486 }
487 };
488
489 buckets
490 .entry(execution_data.diff.clone())
491 .and_modify(|bucket| {
492 bucket
493 .execution_indices
494 .push(execution_data.execution_id.clone());
495 bucket
496 .reasoning_samples
497 .push(execution_data.reasoning.clone());
498 })
499 .or_insert_with(|| EditBucket {
500 diff: execution_data.diff.clone(),
501 is_correct: {
502 evaluation_result
503 .edit_prediction
504 .as_ref()
505 .map_or(false, |edit_prediction| {
506 edit_prediction.false_positives == 0
507 && edit_prediction.false_negatives == 0
508 && edit_prediction.true_positives > 0
509 })
510 },
511 execution_indices: vec![execution_data.execution_id.clone()],
512 reasoning_samples: vec![execution_data.reasoning.clone()],
513 });
514 }
515
516 let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
517 sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
518 (true, false) => std::cmp::Ordering::Less,
519 (false, true) => std::cmp::Ordering::Greater,
520 _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
521 });
522
523 let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
524 let mut output = std::fs::File::create(&output_path)?;
525
526 writeln!(output, "# Bucketed Edit Analysis\n")?;
527
528 writeln!(output, "## Summary\n")?;
529 writeln!(output, "- **Total executions**: {}", total_executions)?;
530
531 let correct_count: usize = sorted_buckets
532 .iter()
533 .filter(|b| b.is_correct)
534 .map(|b| b.execution_indices.len())
535 .sum();
536
537 let incorrect_count: usize = sorted_buckets
538 .iter()
539 .filter(|b| !b.is_correct)
540 .map(|b| b.execution_indices.len())
541 .sum();
542
543 writeln!(
544 output,
545 "- **Correct predictions**: {} ({:.1}%)",
546 correct_count,
547 (correct_count as f64 / total_executions as f64) * 100.0
548 )?;
549
550 writeln!(
551 output,
552 "- **Incorrect predictions**: {} ({:.1}%)",
553 incorrect_count,
554 (incorrect_count as f64 / total_executions as f64) * 100.0
555 )?;
556
557 writeln!(
558 output,
559 "- **No Predictions**: {} ({:.1}%)",
560 empty_predictions.len(),
561 (empty_predictions.len() as f64 / total_executions as f64) * 100.0
562 )?;
563
564 let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
565 writeln!(
566 output,
567 "- **Unique incorrect edit patterns**: {}\n",
568 unique_incorrect
569 )?;
570
571 writeln!(output, "---\n")?;
572
573 for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
574 if idx == 0 {
575 writeln!(
576 output,
577 "## Correct Predictions ({} occurrences)\n",
578 bucket.execution_indices.len()
579 )?;
580 }
581
582 writeln!(output, "**Predicted Edit:**\n")?;
583 writeln!(output, "```diff")?;
584 writeln!(output, "{}", bucket.diff)?;
585 writeln!(output, "```\n")?;
586
587 writeln!(
588 output,
589 "**Executions:** {}\n",
590 bucket.execution_indices.join(", ")
591 )?;
592 writeln!(output, "---\n")?;
593 }
594
595 for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
596 writeln!(
597 output,
598 "## Incorrect Prediction #{} ({} occurrences)\n",
599 idx + 1,
600 bucket.execution_indices.len()
601 )?;
602
603 writeln!(output, "**Predicted Edit:**\n")?;
604 writeln!(output, "```diff")?;
605 writeln!(output, "{}", bucket.diff)?;
606 writeln!(output, "```\n")?;
607
608 writeln!(
609 output,
610 "**Executions:** {}\n",
611 bucket.execution_indices.join(", ")
612 )?;
613
614 for (exec_id, reasoning) in bucket
615 .execution_indices
616 .iter()
617 .zip(bucket.reasoning_samples.iter())
618 {
619 writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
620 }
621
622 writeln!(output, "\n---\n")?;
623 }
624
625 if !empty_predictions.is_empty() {
626 writeln!(
627 output,
628 "## No Predictions ({} occurrences)\n",
629 empty_predictions.len()
630 )?;
631
632 for execution_data in &empty_predictions {
633 writeln!(
634 output,
635 "{}",
636 fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
637 )?;
638 }
639 writeln!(output, "\n---\n")?;
640 }
641
642 if !errors.is_empty() {
643 writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
644
645 for (err, name, repetition_ix) in &errors {
646 writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
647 }
648 writeln!(output, "\n---\n")?;
649 }
650
651 fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
652 let exec_content = format!(
653 "\n### Execution {} `{}/{}/prediction_response.md`{}",
654 exec_id,
655 crate::paths::RUN_DIR.display(),
656 exec_id,
657 indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
658 );
659 indent_text(&exec_content, 2)
660 }
661
662 fn indent_text(text: &str, spaces: usize) -> String {
663 let indent = " ".repeat(spaces);
664 text.lines()
665 .collect::<Vec<_>>()
666 .join(&format!("\n{}", indent))
667 }
668
669 Ok(())
670}
671
672fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
673 let err = format!("{err:?}")
674 .replace("<edits", "```xml\n<edits")
675 .replace("</edits>", "</edits>\n```");
676 format!(
677 "### ERROR {name}{}\n\n{err}\n",
678 repetition_ix
679 .map(|ix| format!(" [RUN {ix:03}]"))
680 .unwrap_or_default()
681 )
682}