1use crate::{
2 PredictArgs, PredictionProvider,
3 example::{ActualCursor, Example, ExampleScore},
4 format_prompt::TeacherPrompt,
5 headless::EpAppState,
6 metrics,
7 parse_output::parse_prediction_output,
8 predict::run_prediction,
9 progress::{ExampleProgress, Step},
10 reversal_tracking,
11};
12use anyhow::Context as _;
13use gpui::AsyncApp;
14use serde::Serialize;
15use std::fs::File;
16use std::io::BufWriter;
17use std::path::Path;
18use std::sync::Arc;
19use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
20
21pub async fn run_scoring(
22 example: &mut Example,
23 args: &PredictArgs,
24 app_state: Arc<EpAppState>,
25 example_progress: &ExampleProgress,
26 cx: AsyncApp,
27) -> anyhow::Result<()> {
28 run_prediction(example, args, app_state, example_progress, cx).await?;
29
30 let progress = example_progress.start(Step::Score);
31
32 progress.set_substatus("applying patches");
33 let prompt_inputs = example
34 .prompt_inputs
35 .as_ref()
36 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
37 let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
38 let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
39
40 let expected_texts: Vec<String> = expected_patches_with_cursors
41 .iter()
42 .map(|(patch, _)| {
43 apply_diff_to_string(patch, original_text)
44 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
45 })
46 .collect::<Result<Vec<_>, _>>()?;
47
48 // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
49 // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
50 // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
51 // region to find where the hunk matched, then compute the expected cursor position.
52 let old_editable_region = if let Some(p) = example.prompt.as_ref() {
53 if matches!(
54 p.provider,
55 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
56 ) {
57 Some(
58 TeacherPrompt::extract_editable_region(&p.input)?
59 .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
60 )
61 } else {
62 None
63 }
64 } else {
65 None
66 };
67
68 let zero_scores = ExampleScore {
69 delta_chr_f: 0.0,
70 delta_chr_f_true_positives: 0,
71 delta_chr_f_false_positives: 0,
72 delta_chr_f_false_negatives: 0,
73 delta_chr_f_precision: 0.0,
74 delta_chr_f_recall: 0.0,
75 delta_chr_f_beta: metrics::delta_chr_f_beta(),
76 braces_disbalance: 0,
77 exact_lines_tp: 0,
78 exact_lines_fp: 0,
79 exact_lines_fn: 0,
80 reversal_ratio: 0.0,
81 cursor_distance: None,
82 cursor_exact_match: None,
83 wrong_editable_region: None,
84 has_isolated_whitespace_changes: false,
85 inserted_tokens: 0,
86 deleted_tokens: 0,
87 kept_rate: None,
88 recall_rate: None,
89 cumulative_logprob: None,
90 avg_logprob: None,
91 };
92
93 let cursor_path = example.spec.cursor_path.as_ref();
94
95 progress.set_substatus("computing metrics");
96 let mut scores = vec![];
97 for prediction in &example.predictions {
98 let actual_patch = prediction.actual_patch.clone().or_else(|| {
99 parse_prediction_output(example, &prediction.actual_output, prediction.provider)
100 .ok()
101 .map(|(patch, _)| patch)
102 });
103
104 let Some(actual_patch) = actual_patch else {
105 scores.push(zero_scores.clone());
106 continue;
107 };
108
109 let token_changes = metrics::count_patch_token_changes(&actual_patch);
110
111 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
112 Ok(text) => text,
113 Err(_) => {
114 let mut s = zero_scores.clone();
115 s.inserted_tokens = token_changes.inserted_tokens;
116 s.deleted_tokens = token_changes.deleted_tokens;
117 scores.push(s);
118 continue;
119 }
120 };
121
122 let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
123 let mut best_expected_cursor: Option<usize> = None;
124 let mut best_patch_idx: Option<usize> = None;
125 let mut best_expected_text: Option<&str> = None;
126
127 for (idx, expected) in expected_texts.iter().enumerate() {
128 let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
129 if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
130 best_delta_chr_f_metrics = delta_chr_f_metrics;
131 best_patch_idx = Some(idx);
132 best_expected_text = Some(expected);
133 }
134 }
135
136 if let Some(idx) = best_patch_idx {
137 // Get the raw cursor offset from the expected patch (relative to hunk new text)
138 let expected_cursor_in_patch = expected_patches_with_cursors
139 .get(idx)
140 .and_then(|(_, cursor)| *cursor);
141
142 // For Teacher prompts, we need to apply the patch to the editable region
143 // to find where the hunk matched, then compute the actual cursor position
144 if let (Some(editable_region), Some(cursor_in_patch)) =
145 (&old_editable_region, expected_cursor_in_patch)
146 {
147 let (patch, _) = &expected_patches_with_cursors[idx];
148 if let Ok((_, hunk_offset)) =
149 apply_diff_to_string_with_hunk_offset(patch, editable_region)
150 {
151 let hunk_start = hunk_offset.unwrap_or(0);
152 best_expected_cursor = Some(hunk_start + cursor_in_patch);
153 }
154 } else {
155 // For non-Teacher prompts or if we can't compute, use raw offset
156 best_expected_cursor = expected_cursor_in_patch;
157 }
158 }
159
160 let disbalance_before = metrics::braces_disbalance(&original_text);
161 let disbalance_after = metrics::braces_disbalance(&actual_text);
162 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
163
164 // Compute exact lines match against best matching expected patch
165 let best_exact_lines = expected_patches_with_cursors
166 .iter()
167 .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
168 .max_by_key(|m| m.true_positives)
169 .unwrap_or_default();
170
171 // Compute reversal ratio
172 let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
173 prompt_inputs,
174 &actual_text,
175 cursor_path,
176 );
177
178 // Compute cursor position metrics
179 let (cursor_distance, cursor_exact_match) =
180 compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
181
182 // Compute approximation of editable region correctness
183 let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
184
185 // Check for isolated whitespace changes.
186 let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
187 &actual_patch,
188 prediction.actual_cursor.as_ref(),
189 );
190
191 let (kept_rate, recall_rate) = best_expected_text
192 .map(|reference_text| {
193 let result =
194 metrics::compute_kept_rate(original_text, &actual_text, reference_text);
195 (Some(result.kept_rate), Some(result.recall_rate))
196 })
197 .unwrap_or((None, None));
198
199 scores.push(ExampleScore {
200 delta_chr_f: best_delta_chr_f_metrics.score as f32,
201 delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
202 delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
203 delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
204 delta_chr_f_precision: best_delta_chr_f_metrics.precision,
205 delta_chr_f_recall: best_delta_chr_f_metrics.recall,
206 delta_chr_f_beta: best_delta_chr_f_metrics.beta,
207 braces_disbalance,
208 exact_lines_tp: best_exact_lines.true_positives,
209 exact_lines_fp: best_exact_lines.false_positives,
210 exact_lines_fn: best_exact_lines.false_negatives,
211 reversal_ratio,
212 cursor_distance,
213 cursor_exact_match,
214 wrong_editable_region,
215 has_isolated_whitespace_changes,
216 inserted_tokens: token_changes.inserted_tokens,
217 deleted_tokens: token_changes.deleted_tokens,
218 kept_rate,
219 recall_rate,
220 cumulative_logprob: prediction.cumulative_logprob,
221 avg_logprob: prediction.avg_logprob,
222 });
223 }
224
225 example.score = scores;
226 Ok(())
227}
228
229fn compute_cursor_metrics(
230 expected_cursor_editable_region_offset: Option<usize>,
231 actual_cursor: Option<&ActualCursor>,
232) -> (Option<usize>, Option<bool>) {
233 match (expected_cursor_editable_region_offset, actual_cursor) {
234 (Some(expected), Some(actual)) => {
235 let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
236 let exact_match = distance == 0;
237 (Some(distance), Some(exact_match))
238 }
239 (None, None) => {
240 // Neither has cursor position - skip cursor scoring
241 (None, None)
242 }
243 (Some(_), None) | (None, Some(_)) => {
244 // Only one has cursor position - count as miss
245 (None, Some(false))
246 }
247 }
248}
249
250pub fn print_report(examples: &[Example], verbose: bool) {
251 const MAX_EXAMPLES_DEFAULT: usize = 20;
252 use crate::metrics::ClassificationMetrics;
253
254 const LINE_WIDTH: usize = 101;
255 let separator = "─".repeat(LINE_WIDTH);
256
257 println!("{}", separator);
258 println!(
259 "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
260 "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
261 );
262 println!("{}", separator);
263
264 let mut all_delta_chr_f_scores = Vec::new();
265 let mut all_reversal_ratios = Vec::new();
266 let mut braces_disbalance_sum: usize = 0;
267 let mut total_delta_chr_f = ClassificationMetrics::default();
268 let mut total_delta_chr_f_precision = 0.0;
269 let mut total_delta_chr_f_recall = 0.0;
270 let mut delta_chr_f_beta = 0.0;
271 let mut total_exact_lines = ClassificationMetrics::default();
272 let mut total_scores: usize = 0;
273 let mut qa_reverts_count: usize = 0;
274 let mut qa_reverts_total: usize = 0;
275 let mut qa_confidence_sum: u64 = 0;
276 let mut qa_confidence_count: usize = 0;
277 let mut cursor_exact_matches: usize = 0;
278 let mut cursor_total: usize = 0;
279 let mut cursor_distance_sum: usize = 0;
280 let mut cursor_distance_count: usize = 0;
281 let mut wrong_editable_region_count: usize = 0;
282 let mut wrong_editable_region_total: usize = 0;
283 let mut isolated_whitespace_count: usize = 0;
284 let mut kept_rate_sum: f64 = 0.0;
285 let mut kept_rate_count: usize = 0;
286 let mut recall_rate_sum: f64 = 0.0;
287 let mut recall_rate_count: usize = 0;
288 let mut patch_inserted_tokens: Vec<usize> = Vec::new();
289 let mut patch_deleted_tokens: Vec<usize> = Vec::new();
290 let mut predictions_with_patch: usize = 0;
291
292 let mut printed_lines: usize = 0;
293 let mut skipped_lines: usize = 0;
294
295 for example in examples {
296 for (score_idx, score) in example.score.iter().enumerate() {
297 let exact_lines = score.exact_lines_counts();
298
299 // Get QA results for this prediction if available
300 let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
301 let qa_reverts_str = qa_result
302 .and_then(|q| q.reverts_edits)
303 .map(|v| if v { "yes" } else { "no" })
304 .unwrap_or("-");
305 let qa_conf_str = qa_result
306 .and_then(|q| q.confidence)
307 .map(|v| format!("{}", v))
308 .unwrap_or("-".to_string());
309
310 // Format wrong editable region metric
311 let wrong_er_str = match score.wrong_editable_region {
312 Some(true) => "✗",
313 Some(false) => "",
314 None => "",
315 };
316
317 // Format cursor metric
318 let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
319 (Some(true), _) => "✓".to_string(),
320 (Some(false), Some(dist)) => format!("±{}", dist),
321 (Some(false), None) => "✗".to_string(),
322 (None, _) => "-".to_string(),
323 };
324
325 if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
326 println!(
327 "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
328 truncate_name(&example.spec.name, 40),
329 score.delta_chr_f,
330 score.braces_disbalance,
331 exact_lines.f1() * 100.0,
332 score.reversal_ratio * 100.0,
333 qa_reverts_str,
334 qa_conf_str,
335 cursor_str,
336 wrong_er_str
337 );
338 printed_lines += 1;
339 } else {
340 skipped_lines += 1;
341 }
342
343 all_delta_chr_f_scores.push(score.delta_chr_f);
344 all_reversal_ratios.push(score.reversal_ratio);
345 total_scores += 1;
346 braces_disbalance_sum += score.braces_disbalance;
347 total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
348 total_delta_chr_f_precision += score.delta_chr_f_precision;
349 total_delta_chr_f_recall += score.delta_chr_f_recall;
350 delta_chr_f_beta = score.delta_chr_f_beta;
351 total_exact_lines.accumulate(&score.exact_lines_counts());
352
353 // Accumulate QA metrics
354 if let Some(qa) = qa_result {
355 if let Some(reverts) = qa.reverts_edits {
356 qa_reverts_total += 1;
357 if reverts {
358 qa_reverts_count += 1;
359 }
360 }
361 if let Some(conf) = qa.confidence {
362 qa_confidence_sum += conf as u64;
363 qa_confidence_count += 1;
364 }
365 }
366
367 // Accumulate wrong editable region metrics
368 if let Some(wrong) = score.wrong_editable_region {
369 wrong_editable_region_total += 1;
370 if wrong {
371 wrong_editable_region_count += 1;
372 }
373 }
374
375 // Accumulate isolated whitespace metrics
376 if score.has_isolated_whitespace_changes {
377 isolated_whitespace_count += 1;
378 }
379
380 // Accumulate kept and recall rate metrics
381 if let Some(kr) = score.kept_rate {
382 kept_rate_sum += kr;
383 kept_rate_count += 1;
384 }
385 if let Some(rr) = score.recall_rate {
386 recall_rate_sum += rr;
387 recall_rate_count += 1;
388 }
389
390 // Accumulate token change metrics (only for predictions that produced a patch)
391 let has_patch = example
392 .predictions
393 .get(score_idx)
394 .and_then(|p| p.actual_patch.as_ref())
395 .is_some_and(|p| !p.is_empty());
396 if has_patch {
397 predictions_with_patch += 1;
398 patch_inserted_tokens.push(score.inserted_tokens);
399 patch_deleted_tokens.push(score.deleted_tokens);
400 }
401
402 // Accumulate cursor metrics
403 if let Some(exact_match) = score.cursor_exact_match {
404 cursor_total += 1;
405 if exact_match {
406 cursor_exact_matches += 1;
407 }
408 }
409 if let Some(dist) = score.cursor_distance {
410 cursor_distance_sum += dist;
411 cursor_distance_count += 1;
412 }
413 }
414 }
415
416 if skipped_lines > 0 {
417 println!(
418 "{:<40} (use --verbose to see all {} examples)",
419 format!("... and {} more", skipped_lines),
420 printed_lines + skipped_lines
421 );
422 }
423 println!("{}", separator);
424
425 if !all_delta_chr_f_scores.is_empty() {
426 let avg_delta_chr_f: f32 =
427 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
428 let avg_reversal_ratio: f32 =
429 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
430 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
431
432 let qa_reverts_str = if qa_reverts_total > 0 {
433 format!(
434 "{:.1}%",
435 qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
436 )
437 } else {
438 "-".to_string()
439 };
440 let qa_conf_str = if qa_confidence_count > 0 {
441 format!(
442 "{:.1}",
443 qa_confidence_sum as f32 / qa_confidence_count as f32
444 )
445 } else {
446 "-".to_string()
447 };
448 let cursor_str = if cursor_total > 0 {
449 format!(
450 "{:.0}%",
451 cursor_exact_matches as f32 / cursor_total as f32 * 100.0
452 )
453 } else {
454 "-".to_string()
455 };
456 let wrong_er_str = if wrong_editable_region_total > 0 {
457 format!(
458 "{:.2}%",
459 wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
460 )
461 } else {
462 "-".to_string()
463 };
464 let isolated_ws_str = if total_scores > 0 {
465 format!(
466 "{}/{} ({:.1}%)",
467 isolated_whitespace_count,
468 total_scores,
469 isolated_whitespace_count as f32 / total_scores as f32 * 100.0
470 )
471 } else {
472 "-".to_string()
473 };
474 let avg_cursor_distance = if cursor_distance_count > 0 {
475 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
476 } else {
477 None
478 };
479
480 println!(
481 "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
482 "TOTAL / AVERAGE",
483 avg_delta_chr_f,
484 braces_disbalance_avg,
485 total_exact_lines.f1() * 100.0,
486 avg_reversal_ratio * 100.0,
487 qa_reverts_str,
488 qa_conf_str,
489 cursor_str,
490 wrong_er_str
491 );
492 println!("{}", separator);
493 println!(
494 "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%",
495 delta_chr_f_beta,
496 total_delta_chr_f.true_positives,
497 total_delta_chr_f.false_positives,
498 total_delta_chr_f.false_negatives,
499 total_delta_chr_f_precision / total_scores as f64 * 100.0,
500 total_delta_chr_f_recall / total_scores as f64 * 100.0
501 );
502
503 // Print additional cursor metrics if available
504 if let Some(avg_dist) = avg_cursor_distance {
505 println!(
506 "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
507 cursor_exact_matches,
508 cursor_total,
509 cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
510 avg_dist
511 );
512 }
513
514 // Print isolated whitespace metrics
515 if total_scores > 0 {
516 println!("Isolated whitespace changes: {}", isolated_ws_str);
517 }
518
519 // Print kept and recall rate metrics
520 if kept_rate_count > 0 {
521 let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
522 println!(
523 "Kept rate: {:.1}% avg ({} evaluated)",
524 avg_kept_rate * 100.0,
525 kept_rate_count
526 );
527 }
528 if recall_rate_count > 0 {
529 let avg_recall_rate = recall_rate_sum / recall_rate_count as f64;
530 println!(
531 "Recall rate: {:.1}% avg ({} evaluated)",
532 avg_recall_rate * 100.0,
533 recall_rate_count
534 );
535 }
536
537 // Print token change percentile summary (only for predictions with a patch)
538 if !patch_inserted_tokens.is_empty() {
539 patch_inserted_tokens.sort_unstable();
540 patch_deleted_tokens.sort_unstable();
541 let mut patch_total_tokens: Vec<usize> = patch_inserted_tokens
542 .iter()
543 .zip(patch_deleted_tokens.iter())
544 .map(|(i, d)| i + d)
545 .collect();
546 patch_total_tokens.sort_unstable();
547
548 let patch_rate = predictions_with_patch as f32 / total_scores as f32 * 100.0;
549 println!();
550 println!(
551 "Token changes ({}/{} predictions produced a patch, {:.1}% — table includes only those)",
552 predictions_with_patch, total_scores, patch_rate
553 );
554 println!(
555 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
556 "", "p25", "p50", "p75", "p90", "p99"
557 );
558 println!("{}", "─".repeat(LINE_WIDTH));
559 println!(
560 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
561 "Inserted tokens",
562 percentile(&patch_inserted_tokens, 25),
563 percentile(&patch_inserted_tokens, 50),
564 percentile(&patch_inserted_tokens, 75),
565 percentile(&patch_inserted_tokens, 90),
566 percentile(&patch_inserted_tokens, 99),
567 );
568 println!(
569 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
570 "Deleted tokens",
571 percentile(&patch_deleted_tokens, 25),
572 percentile(&patch_deleted_tokens, 50),
573 percentile(&patch_deleted_tokens, 75),
574 percentile(&patch_deleted_tokens, 90),
575 percentile(&patch_deleted_tokens, 99),
576 );
577 println!(
578 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
579 "Total tokens",
580 percentile(&patch_total_tokens, 25),
581 percentile(&patch_total_tokens, 50),
582 percentile(&patch_total_tokens, 75),
583 percentile(&patch_total_tokens, 90),
584 percentile(&patch_total_tokens, 99),
585 );
586 }
587 }
588
589 println!("\n");
590}
591
/// Returns the `p`-th percentile of `sorted_values` using rounded linear rank.
///
/// `sorted_values` must already be sorted in ascending order. The rank is
/// `p / 100 * (len - 1)` rounded to the nearest index and clamped to the last
/// element, so `p > 100` yields the maximum. An empty slice yields `0`.
fn percentile(sorted_values: &[usize], p: usize) -> usize {
    let Some(last_index) = sorted_values.len().checked_sub(1) else {
        return 0;
    };
    let fraction = p as f64 / 100.0;
    let rank = (fraction * last_index as f64).round() as usize;
    sorted_values[rank.min(last_index)]
}
599
/// Shortens `name` to at most `max_len` characters for table display.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`.
///
/// Length is measured in characters rather than bytes, so a name containing
/// multi-byte UTF-8 characters is never cut in the middle of a character (byte
/// slicing would panic there). When `max_len` is below 3 the result is just
/// `"..."` instead of underflowing.
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        return name.to_string();
    }

    let kept_chars = max_len.saturating_sub(3);
    let mut truncated: String = name.chars().take(kept_chars).collect();
    truncated.push_str("...");
    truncated
}
607
608#[derive(Serialize)]
609pub struct SummaryJson {
610 pub total_examples: usize,
611 pub avg_delta_chr_f: f32,
612 pub delta_chr_f_beta: f64,
613 pub delta_chr_f_true_positives: usize,
614 pub delta_chr_f_false_positives: usize,
615 pub delta_chr_f_false_negatives: usize,
616 pub delta_chr_f_precision: f64,
617 pub delta_chr_f_recall: f64,
618 pub avg_braces_disbalance: f32,
619 pub exact_lines_true_positives: usize,
620 pub exact_lines_false_positives: usize,
621 pub exact_lines_false_negatives: usize,
622 pub exact_lines_precision: f64,
623 pub exact_lines_recall: f64,
624 pub exact_lines_f1: f64,
625 pub avg_reversal_ratio: f32,
626 #[serde(skip_serializing_if = "Option::is_none")]
627 pub qa_avg_reverts_edits: Option<f32>,
628 #[serde(skip_serializing_if = "Option::is_none")]
629 pub qa_avg_confidence: Option<f32>,
630 #[serde(skip_serializing_if = "Option::is_none")]
631 pub cursor_exact_match_rate: Option<f32>,
632 #[serde(skip_serializing_if = "Option::is_none")]
633 pub cursor_avg_distance: Option<f32>,
634 #[serde(skip_serializing_if = "Option::is_none")]
635 pub cursor_total_evaluated: Option<usize>,
636 #[serde(skip_serializing_if = "Option::is_none")]
637 pub wrong_editable_region_rate: Option<f32>,
638 pub isolated_whitespace_rate: Option<f32>,
639 #[serde(skip_serializing_if = "Option::is_none")]
640 pub avg_kept_rate: Option<f64>,
641 #[serde(skip_serializing_if = "Option::is_none")]
642 pub avg_recall_rate: Option<f64>,
643}
644
645pub fn compute_summary(examples: &[Example]) -> SummaryJson {
646 use crate::metrics::ClassificationMetrics;
647
648 let mut all_delta_chr_f_scores = Vec::new();
649 let mut all_reversal_ratios = Vec::new();
650 let mut braces_disbalance_sum: usize = 0;
651 let mut total_delta_chr_f = ClassificationMetrics::default();
652 let mut total_delta_chr_f_precision = 0.0;
653 let mut total_delta_chr_f_recall = 0.0;
654 let mut delta_chr_f_beta = 0.0;
655 let mut total_exact_lines = ClassificationMetrics::default();
656 let mut total_scores: usize = 0;
657 let mut qa_reverts_count: usize = 0;
658 let mut qa_reverts_total: usize = 0;
659 let mut qa_confidence_sum: u64 = 0;
660 let mut qa_confidence_count: usize = 0;
661 let mut cursor_exact_matches: usize = 0;
662 let mut cursor_total: usize = 0;
663 let mut cursor_distance_sum: usize = 0;
664 let mut cursor_distance_count: usize = 0;
665 let mut wrong_editable_region_count: usize = 0;
666 let mut wrong_editable_region_total: usize = 0;
667 let mut isolated_whitespace_count: usize = 0;
668 let mut kept_rate_sum: f64 = 0.0;
669 let mut kept_rate_count: usize = 0;
670 let mut recall_rate_sum: f64 = 0.0;
671 let mut recall_rate_count: usize = 0;
672
673 for example in examples {
674 for (score_idx, score) in example.score.iter().enumerate() {
675 all_delta_chr_f_scores.push(score.delta_chr_f);
676 all_reversal_ratios.push(score.reversal_ratio);
677 total_scores += 1;
678 braces_disbalance_sum += score.braces_disbalance;
679 total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
680 total_delta_chr_f_precision += score.delta_chr_f_precision;
681 total_delta_chr_f_recall += score.delta_chr_f_recall;
682 delta_chr_f_beta = score.delta_chr_f_beta;
683 total_exact_lines.accumulate(&score.exact_lines_counts());
684
685 // Accumulate QA metrics
686 if let Some(Some(qa)) = example.qa.get(score_idx) {
687 if let Some(reverts) = qa.reverts_edits {
688 qa_reverts_total += 1;
689 if reverts {
690 qa_reverts_count += 1;
691 }
692 }
693 if let Some(conf) = qa.confidence {
694 qa_confidence_sum += conf as u64;
695 qa_confidence_count += 1;
696 }
697 }
698
699 // Accumulate wrong editable region metrics
700 if let Some(wrong) = score.wrong_editable_region {
701 wrong_editable_region_total += 1;
702 if wrong {
703 wrong_editable_region_count += 1;
704 }
705 }
706
707 // Accumulate isolated whitespace metrics
708 if score.has_isolated_whitespace_changes {
709 isolated_whitespace_count += 1;
710 }
711
712 // Accumulate kept and recall rate metrics
713 if let Some(kr) = score.kept_rate {
714 kept_rate_sum += kr;
715 kept_rate_count += 1;
716 }
717 if let Some(rr) = score.recall_rate {
718 recall_rate_sum += rr;
719 recall_rate_count += 1;
720 }
721
722 // Accumulate cursor metrics
723 if let Some(exact_match) = score.cursor_exact_match {
724 cursor_total += 1;
725 if exact_match {
726 cursor_exact_matches += 1;
727 }
728 }
729 if let Some(dist) = score.cursor_distance {
730 cursor_distance_sum += dist;
731 cursor_distance_count += 1;
732 }
733 }
734 }
735
736 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
737 0.0
738 } else {
739 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
740 };
741
742 let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
743 0.0
744 } else {
745 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
746 };
747
748 let avg_braces_disbalance = if total_scores == 0 {
749 0.0
750 } else {
751 braces_disbalance_sum as f32 / total_scores as f32
752 };
753
754 let qa_avg_reverts_edits = if qa_reverts_total > 0 {
755 Some(qa_reverts_count as f32 / qa_reverts_total as f32)
756 } else {
757 None
758 };
759
760 let qa_avg_confidence = if qa_confidence_count > 0 {
761 Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
762 } else {
763 None
764 };
765
766 let cursor_exact_match_rate = if cursor_total > 0 {
767 Some(cursor_exact_matches as f32 / cursor_total as f32)
768 } else {
769 None
770 };
771
772 let cursor_avg_distance = if cursor_distance_count > 0 {
773 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
774 } else {
775 None
776 };
777
778 let cursor_total_evaluated = if cursor_total > 0 {
779 Some(cursor_total)
780 } else {
781 None
782 };
783
784 let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
785 Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
786 } else {
787 None
788 };
789
790 let isolated_whitespace_rate = if total_scores > 0 {
791 Some(isolated_whitespace_count as f32 / total_scores as f32)
792 } else {
793 None
794 };
795
796 let avg_kept_rate = if kept_rate_count > 0 {
797 Some(kept_rate_sum / kept_rate_count as f64)
798 } else {
799 None
800 };
801
802 let avg_recall_rate = if recall_rate_count > 0 {
803 Some(recall_rate_sum / recall_rate_count as f64)
804 } else {
805 None
806 };
807
808 SummaryJson {
809 total_examples: total_scores,
810 avg_delta_chr_f,
811 delta_chr_f_beta,
812 delta_chr_f_true_positives: total_delta_chr_f.true_positives,
813 delta_chr_f_false_positives: total_delta_chr_f.false_positives,
814 delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
815 delta_chr_f_precision: if total_scores == 0 {
816 0.0
817 } else {
818 total_delta_chr_f_precision / total_scores as f64
819 },
820 delta_chr_f_recall: if total_scores == 0 {
821 0.0
822 } else {
823 total_delta_chr_f_recall / total_scores as f64
824 },
825 avg_braces_disbalance,
826 exact_lines_true_positives: total_exact_lines.true_positives,
827 exact_lines_false_positives: total_exact_lines.false_positives,
828 exact_lines_false_negatives: total_exact_lines.false_negatives,
829 exact_lines_precision: total_exact_lines.precision(),
830 exact_lines_recall: total_exact_lines.recall(),
831 exact_lines_f1: total_exact_lines.f1(),
832 avg_reversal_ratio,
833 qa_avg_reverts_edits,
834 qa_avg_confidence,
835 cursor_exact_match_rate,
836 cursor_avg_distance,
837 cursor_total_evaluated,
838 wrong_editable_region_rate,
839 isolated_whitespace_rate,
840 avg_kept_rate,
841 avg_recall_rate,
842 }
843}
844
845pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
846 let summary = compute_summary(examples);
847 let file = File::create(path)
848 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
849 let writer = BufWriter::new(file);
850 serde_json::to_writer_pretty(writer, &summary)
851 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
852 eprintln!("Wrote summary JSON to: {}", path.display());
853 Ok(())
854}