1use crate::metrics::{self, Scores};
2use std::{
3 collections::HashMap,
4 io::{IsTerminal, Write},
5 sync::Arc,
6};
7
8use anyhow::Result;
9use gpui::{AsyncApp, Entity};
10use project::Project;
11use util::ResultExt as _;
12use zeta::{Zeta, udiff::DiffLine};
13
14use crate::{
15 EvaluateArguments, PredictionOptions,
16 example::{Example, NamedExample},
17 headless::ZetaCliAppState,
18 paths::print_run_data_dir,
19 predict::{PredictionDetails, perform_predict, setup_zeta},
20};
21
/// Raw data captured from one prediction run, retained for the bucketed
/// analysis written after all repetitions complete.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Identifier for the run: the zero-padded repetition index when running
    // with repetitions, otherwise the example name (see `run_evaluate_one`).
    execution_id: String,
    // The predicted diff text; empty when the model produced no prediction.
    diff: String,
    // Contents of this run's `prediction_response.md`, or empty if the file
    // could not be read.
    reasoning: String,
}
28
/// Evaluates every example in `args.example_paths`, running each one
/// `args.repetitions` times, then writes aggregated scores to stdout and to
/// `aggregated_results.md` in the run directory. With more than one
/// repetition, a bucketed per-diff analysis is also written.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    // One outer task per example; each fans out into one inner task per
    // repetition so all repetitions of an example share a single project.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        // CLI entry point: a malformed example file is a fatal user error.
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            // A separate Zeta provider per repetition so runs do not share
            // prediction state.
            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    // Tag results with a repetition index only when there is
                    // actually more than one repetition.
                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        // Attach the example name and repetition index so
                        // failures can be reported after all tasks finish.
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Print the aggregate to stdout, and best-effort mirror it to a file in
    // the run directory (file errors are logged, not fatal).
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing by identical predicted diff only makes sense across
    // repetitions of the same inputs.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
97
98fn write_aggregated_scores(
99 w: &mut impl std::io::Write,
100 all_results: &Vec<
101 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
102 >,
103) -> Result<()> {
104 let mut successful = Vec::new();
105 let mut failed_count = 0;
106
107 for result in all_results.iter().flatten() {
108 match result {
109 Ok((eval_result, _execution_data)) => successful.push(eval_result),
110 Err((err, name, repetition_ix)) => {
111 if failed_count == 0 {
112 writeln!(w, "## Errors\n")?;
113 }
114
115 failed_count += 1;
116 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
117 }
118 }
119 }
120
121 if successful.len() > 1 {
122 let edit_scores = successful
123 .iter()
124 .filter_map(|r| r.edit_scores.clone())
125 .collect::<Vec<_>>();
126 let has_edit_predictions = edit_scores.len() > 0;
127 let aggregated_result = EvaluationResult {
128 context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
129 edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
130 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
131 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
132 / successful.len(),
133 };
134
135 writeln!(w, "\n{}", "-".repeat(80))?;
136 writeln!(w, "\n## TOTAL SCORES")?;
137 writeln!(w, "{:#}", aggregated_result)?;
138 }
139
140 if successful.len() + failed_count > 1 {
141 writeln!(
142 w,
143 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
144 successful.len(),
145 successful.len() + failed_count,
146 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
147 )?;
148 }
149
150 Ok(())
151}
152
/// Runs a single prediction for `example` and scores it.
///
/// The result is printed to stdout only when this is not one of several
/// repetitions (`repetition_ix.is_none()`), and is always written to
/// `results.md` in the run's example directory (best effort). Returns the
/// scores together with the raw execution data used for bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prediction_options: PredictionOptions,
    predict: bool,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = perform_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prediction_options,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // With repetitions, per-run stdout output would be noisy; the caller
    // prints aggregated results instead.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Best effort: failing to create or write the results file is logged,
    // not fatal. Color is disabled for file output.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repetitions are identified by a zero-padded index, single runs by
        // the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // The model's saved response, or empty if the file is missing or
        // unreadable.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
216
217fn write_eval_result(
218 example: &NamedExample,
219 predictions: &PredictionDetails,
220 evaluation_result: &EvaluationResult,
221 out: &mut impl Write,
222 use_color: bool,
223 predict: bool,
224) -> Result<()> {
225 if predict {
226 writeln!(
227 out,
228 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
229 compare_diffs(
230 &example.example.expected_patch,
231 &predictions.diff,
232 use_color
233 )
234 )?;
235 writeln!(
236 out,
237 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
238 compare_diffs(
239 &predictions.diff,
240 &example.example.expected_patch,
241 use_color
242 )
243 )?;
244 }
245
246 writeln!(out, "{:#}", evaluation_result)?;
247
248 anyhow::Ok(())
249}
250
/// Scores comparing a predicted edit against the expected patch.
#[derive(Debug, Default, Clone)]
pub struct EditScores {
    /// Exact line-match counts (true/false positives/negatives) between the
    /// expected and actual diff lines.
    pub line_match: Scores,
    /// Character-level F-score over the diffs (rendered as "diff chrF" in the
    /// score table); computed by `metrics::delta_chr_f`.
    pub chr_f: f64,
}
256
257impl EditScores {
258 pub fn aggregate(scores: &[EditScores]) -> EditScores {
259 let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
260 let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
261
262 EditScores { line_match, chr_f }
263 }
264}
265
/// The computed scores for a single evaluation run, or an aggregate across
/// several runs.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Edit-prediction scores; `None` when prediction was skipped.
    pub edit_scores: Option<EditScores>,
    /// Context-retrieval scores.
    pub context_scores: Scores,
    // Length of the prompt as reported by the prediction (mean when
    // aggregated). Units are whatever `PredictionDetails::prompt_len` uses —
    // not determinable from this file.
    pub prompt_len: usize,
    // Length of the generated output (mean when aggregated); same caveat on
    // units as `prompt_len`.
    pub generated_len: usize,
}
273
274impl std::fmt::Display for EvaluationResult {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 if f.alternate() {
277 self.fmt_table(f)
278 } else {
279 self.fmt_markdown(f)
280 }
281 }
282}
283
284impl EvaluationResult {
285 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(
287 f,
288 r#"
289### Context Scores
290{}
291"#,
292 self.context_scores.to_markdown(),
293 )?;
294 if let Some(scores) = &self.edit_scores {
295 write!(
296 f,
297 r#"
298 ### Edit Prediction Scores
299 {}"#,
300 scores.line_match.to_markdown()
301 )?;
302 }
303 Ok(())
304 }
305
306 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 writeln!(f, "#### Prompt Statistics")?;
308 writeln!(f, "─────────────────────────")?;
309 writeln!(f, "Prompt_len Generated_len")?;
310 writeln!(f, "─────────────────────────")?;
311 writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
312 writeln!(f)?;
313 writeln!(f)?;
314 writeln!(f, "#### Performance Scores")?;
315 writeln!(
316 f,
317 "──────────────────────────────────────────────────────────────────"
318 )?;
319 writeln!(
320 f,
321 " TP FP FN Precision Recall F1"
322 )?;
323 writeln!(
324 f,
325 "──────────────────────────────────────────────────────────────────"
326 )?;
327 writeln!(
328 f,
329 "Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
330 self.context_scores.true_positives,
331 self.context_scores.false_positives,
332 self.context_scores.false_negatives,
333 self.context_scores.precision() * 100.0,
334 self.context_scores.recall() * 100.0,
335 self.context_scores.f1_score() * 100.0
336 )?;
337 if let Some(edit_scores) = &self.edit_scores {
338 let line_match = &edit_scores.line_match;
339 writeln!(f, "Edit Prediction")?;
340 writeln!(
341 f,
342 " ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
343 line_match.true_positives,
344 line_match.false_positives,
345 line_match.false_negatives,
346 line_match.precision() * 100.0,
347 line_match.recall() * 100.0,
348 line_match.f1_score() * 100.0
349 )?;
350 writeln!(
351 f,
352 " └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
353 "-", "-", "-", "-", "-", edit_scores.chr_f
354 )?;
355 }
356 Ok(())
357 }
358}
359
360fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
361 let mut eval_result = EvaluationResult {
362 prompt_len: preds.prompt_len,
363 generated_len: preds.generated_len,
364 ..Default::default()
365 };
366
367 if predict {
368 // todo: alternatives for patches
369 let expected_patch = example
370 .expected_patch
371 .lines()
372 .map(DiffLine::parse)
373 .collect::<Vec<_>>();
374 let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
375
376 let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
377 let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
378
379 eval_result.edit_scores = Some(EditScores { line_match, chr_f });
380 }
381
382 eval_result
383}
384
385/// Return annotated `patch_a` so that:
386/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
387/// Additions and deletions that are present in `patch_b` will be highlighted in green.
388pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
389 let green = if use_color { "\x1b[32m✓ " } else { "" };
390 let red = if use_color { "\x1b[31m✗ " } else { "" };
391 let neutral = if use_color { " " } else { "" };
392 let reset = if use_color { "\x1b[0m" } else { "" };
393 let lines_a = patch_a.lines().map(DiffLine::parse);
394 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
395
396 let annotated = lines_a
397 .map(|line| match line {
398 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
399 if lines_b.contains(&line) {
400 format!("{green}{line}{reset}")
401 } else {
402 format!("{red}{line}{reset}")
403 }
404 }
405 _ => format!("{neutral}{line}{reset}"),
406 })
407 .collect::<Vec<String>>();
408
409 annotated.join("\n")
410}
411
/// Writes `bucketed_analysis.md` into the run directory, grouping executions
/// by their (byte-identical) predicted diff so recurring correct and
/// incorrect predictions can be inspected together.
///
/// Only called when running with more than one repetition.
fn write_bucketed_analysis(
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    // One bucket per unique predicted diff, accumulating every execution
    // that produced exactly that diff.
    #[derive(Debug)]
    struct EditBucket {
        diff: String,
        is_correct: bool,
        execution_indices: Vec<String>,
        reasoning_samples: Vec<String>,
    }

    let mut total_executions = 0;
    let mut empty_predictions = Vec::new();
    let mut errors = Vec::new();

    // Keyed by the diff text itself.
    // NOTE(review): buckets are keyed only by diff across ALL examples; two
    // different examples that happen to yield the same diff would share a
    // bucket (and the first one's correctness) — confirm this is intended.
    let mut buckets: HashMap<String, EditBucket> = HashMap::new();

    for result in all_results.iter().flatten() {
        total_executions += 1;

        let (evaluation_result, execution_data) = match result {
            Ok((eval_result, execution_data)) => {
                // Runs that predicted nothing are reported in their own
                // section rather than bucketed.
                if execution_data.diff.is_empty() {
                    empty_predictions.push(execution_data);
                    continue;
                }
                (eval_result, execution_data)
            }
            Err(err) => {
                errors.push(err);
                continue;
            }
        };

        buckets
            .entry(execution_data.diff.clone())
            .and_modify(|bucket| {
                bucket
                    .execution_indices
                    .push(execution_data.execution_id.clone());
                bucket
                    .reasoning_samples
                    .push(execution_data.reasoning.clone());
            })
            .or_insert_with(|| EditBucket {
                diff: execution_data.diff.clone(),
                // "Correct" = exact line match with the expected patch: no
                // spurious or missing lines, and at least one matching line.
                // Determined from the first execution inserted; identical
                // diffs for the same example score identically.
                is_correct: {
                    evaluation_result
                        .edit_scores
                        .as_ref()
                        .map_or(false, |edit_scores| {
                            edit_scores.line_match.false_positives == 0
                                && edit_scores.line_match.false_negatives == 0
                                && edit_scores.line_match.true_positives > 0
                        })
                },
                execution_indices: vec![execution_data.execution_id.clone()],
                reasoning_samples: vec![execution_data.reasoning.clone()],
            });
    }

    // Order: correct buckets first, then by descending occurrence count.
    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
    });

    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
    let mut output = std::fs::File::create(&output_path)?;

    writeln!(output, "# Bucketed Edit Analysis\n")?;

    writeln!(output, "## Summary\n")?;
    writeln!(output, "- **Total executions**: {}", total_executions)?;

    let correct_count: usize = sorted_buckets
        .iter()
        .filter(|b| b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    let incorrect_count: usize = sorted_buckets
        .iter()
        .filter(|b| !b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    // Percentages assume total_executions > 0; the caller only invokes this
    // after at least one run, so the division is safe in practice.
    writeln!(
        output,
        "- **Correct predictions**: {} ({:.1}%)",
        correct_count,
        (correct_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **Incorrect predictions**: {} ({:.1}%)",
        incorrect_count,
        (incorrect_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **No Predictions**: {} ({:.1}%)",
        empty_predictions.len(),
        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
    )?;

    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
    writeln!(
        output,
        "- **Unique incorrect edit patterns**: {}\n",
        unique_incorrect
    )?;

    writeln!(output, "---\n")?;

    // Correct buckets: a single section heading (printed for the first
    // bucket only), then each bucket's diff and execution list.
    // NOTE(review): the heading's occurrence count is only the FIRST correct
    // bucket's count, not the total across correct buckets — confirm intent.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
        if idx == 0 {
            writeln!(
                output,
                "## Correct Predictions ({} occurrences)\n",
                bucket.execution_indices.len()
            )?;
        }

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;
        writeln!(output, "---\n")?;
    }

    // Incorrect buckets additionally include each execution's reasoning so
    // failure modes can be compared.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
        writeln!(
            output,
            "## Incorrect Prediction #{} ({} occurrences)\n",
            idx + 1,
            bucket.execution_indices.len()
        )?;

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;

        // execution_indices and reasoning_samples are pushed in lockstep, so
        // zipping them pairs each id with its own reasoning.
        for (exec_id, reasoning) in bucket
            .execution_indices
            .iter()
            .zip(bucket.reasoning_samples.iter())
        {
            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
        }

        writeln!(output, "\n---\n")?;
    }

    if !empty_predictions.is_empty() {
        writeln!(
            output,
            "## No Predictions ({} occurrences)\n",
            empty_predictions.len()
        )?;

        for execution_data in &empty_predictions {
            writeln!(
                output,
                "{}",
                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
            )?;
        }
        writeln!(output, "\n---\n")?;
    }

    if !errors.is_empty() {
        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;

        for (err, name, repetition_ix) in &errors {
            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
        }
        writeln!(output, "\n---\n")?;
    }

    // Renders one execution's reasoning as an indented subsection, with a
    // pointer to its prediction_response.md on disk.
    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
        let exec_content = format!(
            "\n### Execution {} `{}/{}/prediction_response.md`{}",
            exec_id,
            crate::paths::RUN_DIR.display(),
            exec_id,
            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
        );
        indent_text(&exec_content, 2)
    }

    // Indents every line AFTER the first by `spaces` spaces; the first line
    // keeps its existing position. Also drops any trailing newline.
    fn indent_text(text: &str, spaces: usize) -> String {
        let indent = " ".repeat(spaces);
        text.lines()
            .collect::<Vec<_>>()
            .join(&format!("\n{}", indent))
    }

    Ok(())
}
630
631fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
632 let err = format!("{err:?}")
633 .replace("<edits", "```xml\n<edits")
634 .replace("</edits>", "</edits>\n```");
635 format!(
636 "### ERROR {name}{}\n\n{err}\n",
637 repetition_ix
638 .map(|ix| format!(" [RUN {ix:03}]"))
639 .unwrap_or_default()
640 )
641}