1use crate::{
2 PredictArgs, PredictionProvider,
3 example::{ActualCursor, Example, ExampleScore},
4 format_prompt::TeacherPrompt,
5 headless::EpAppState,
6 metrics,
7 parse_output::parse_prediction_output,
8 predict::run_prediction,
9 progress::{ExampleProgress, Step},
10 reversal_tracking,
11};
12use anyhow::Context as _;
13use gpui::AsyncApp;
14use serde::Serialize;
15use std::fs::File;
16use std::io::BufWriter;
17use std::path::Path;
18use std::sync::Arc;
19use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
20
21pub async fn run_scoring(
22 example: &mut Example,
23 args: &PredictArgs,
24 app_state: Arc<EpAppState>,
25 example_progress: &ExampleProgress,
26 cx: AsyncApp,
27) -> anyhow::Result<()> {
28 run_prediction(example, args, app_state, example_progress, cx).await?;
29
30 let progress = example_progress.start(Step::Score);
31
32 progress.set_substatus("applying patches");
33 let prompt_inputs = example
34 .prompt_inputs
35 .as_ref()
36 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
37 let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
38 let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
39
40 let expected_texts: Vec<String> = expected_patches_with_cursors
41 .iter()
42 .map(|(patch, _)| {
43 apply_diff_to_string(patch, original_text)
44 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
45 })
46 .collect::<Result<Vec<_>, _>>()?;
47
48 // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
49 // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
50 // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
51 // region to find where the hunk matched, then compute the expected cursor position.
52 let old_editable_region = if let Some(p) = example.prompt.as_ref() {
53 if matches!(
54 p.provider,
55 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
56 ) {
57 Some(
58 TeacherPrompt::extract_editable_region(&p.input)?
59 .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
60 )
61 } else {
62 None
63 }
64 } else {
65 None
66 };
67
68 let zero_scores = ExampleScore {
69 delta_chr_f: 0.0,
70 delta_chr_f_true_positives: 0,
71 delta_chr_f_false_positives: 0,
72 delta_chr_f_false_negatives: 0,
73 delta_chr_f_precision: 0.0,
74 delta_chr_f_recall: 0.0,
75 delta_chr_f_beta: metrics::delta_chr_f_beta(),
76 braces_disbalance: 0,
77 exact_lines_tp: 0,
78 exact_lines_fp: 0,
79 exact_lines_fn: 0,
80 reversal_ratio: 0.0,
81 cursor_distance: None,
82 cursor_exact_match: None,
83 wrong_editable_region: None,
84 has_isolated_whitespace_changes: false,
85 inserted_tokens: 0,
86 deleted_tokens: 0,
87 kept_rate: None,
88 recall_rate: None,
89 kept_chars: None,
90 correctly_deleted_chars: None,
91 discarded_chars: None,
92 cumulative_logprob: None,
93 avg_logprob: None,
94 };
95
96 let cursor_path = example.spec.cursor_path.as_ref();
97
98 progress.set_substatus("computing metrics");
99 let mut scores = vec![];
100 for prediction in &example.predictions {
101 let actual_patch = prediction.actual_patch.clone().or_else(|| {
102 parse_prediction_output(example, &prediction.actual_output, prediction.provider)
103 .ok()
104 .map(|(patch, _)| patch)
105 });
106
107 let Some(actual_patch) = actual_patch else {
108 scores.push(zero_scores.clone());
109 continue;
110 };
111
112 let token_changes = metrics::count_patch_token_changes(&actual_patch);
113
114 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
115 Ok(text) => text,
116 Err(_) => {
117 let mut s = zero_scores.clone();
118 s.inserted_tokens = token_changes.inserted_tokens;
119 s.deleted_tokens = token_changes.deleted_tokens;
120 scores.push(s);
121 continue;
122 }
123 };
124
125 let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
126 let mut best_expected_cursor: Option<usize> = None;
127 let mut best_patch_idx: Option<usize> = None;
128 let mut best_expected_text: Option<&str> = None;
129
130 for (idx, expected) in expected_texts.iter().enumerate() {
131 let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
132 if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
133 best_delta_chr_f_metrics = delta_chr_f_metrics;
134 best_patch_idx = Some(idx);
135 best_expected_text = Some(expected);
136 }
137 }
138
139 if let Some(idx) = best_patch_idx {
140 // Get the raw cursor offset from the expected patch (relative to hunk new text)
141 let expected_cursor_in_patch = expected_patches_with_cursors
142 .get(idx)
143 .and_then(|(_, cursor)| *cursor);
144
145 // For Teacher prompts, we need to apply the patch to the editable region
146 // to find where the hunk matched, then compute the actual cursor position
147 if let (Some(editable_region), Some(cursor_in_patch)) =
148 (&old_editable_region, expected_cursor_in_patch)
149 {
150 let (patch, _) = &expected_patches_with_cursors[idx];
151 if let Ok((_, hunk_offset)) =
152 apply_diff_to_string_with_hunk_offset(patch, editable_region)
153 {
154 let hunk_start = hunk_offset.unwrap_or(0);
155 best_expected_cursor = Some(hunk_start + cursor_in_patch);
156 }
157 } else {
158 // For non-Teacher prompts or if we can't compute, use raw offset
159 best_expected_cursor = expected_cursor_in_patch;
160 }
161 }
162
163 let disbalance_before = metrics::braces_disbalance(&original_text);
164 let disbalance_after = metrics::braces_disbalance(&actual_text);
165 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
166
167 // Compute exact lines match against best matching expected patch
168 let best_exact_lines = expected_patches_with_cursors
169 .iter()
170 .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
171 .max_by_key(|m| m.true_positives)
172 .unwrap_or_default();
173
174 // Compute reversal ratio
175 let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
176 prompt_inputs,
177 &actual_text,
178 cursor_path,
179 );
180
181 // Compute cursor position metrics
182 let (cursor_distance, cursor_exact_match) =
183 compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
184
185 // Compute approximation of editable region correctness
186 let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
187
188 // Check for isolated whitespace changes.
189 let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
190 &actual_patch,
191 prediction.actual_cursor.as_ref(),
192 );
193
194 let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) =
195 best_expected_text
196 .map(|reference_text| {
197 let result =
198 metrics::compute_kept_rate(original_text, &actual_text, reference_text);
199 (
200 Some(result.kept_rate),
201 Some(result.recall_rate),
202 Some(result.kept_chars),
203 Some(result.correctly_deleted_chars),
204 Some(result.discarded_chars),
205 )
206 })
207 .unwrap_or((None, None, None, None, None));
208
209 scores.push(ExampleScore {
210 delta_chr_f: best_delta_chr_f_metrics.score as f32,
211 delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
212 delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
213 delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
214 delta_chr_f_precision: best_delta_chr_f_metrics.precision,
215 delta_chr_f_recall: best_delta_chr_f_metrics.recall,
216 delta_chr_f_beta: best_delta_chr_f_metrics.beta,
217 braces_disbalance,
218 exact_lines_tp: best_exact_lines.true_positives,
219 exact_lines_fp: best_exact_lines.false_positives,
220 exact_lines_fn: best_exact_lines.false_negatives,
221 reversal_ratio,
222 cursor_distance,
223 cursor_exact_match,
224 wrong_editable_region,
225 has_isolated_whitespace_changes,
226 inserted_tokens: token_changes.inserted_tokens,
227 deleted_tokens: token_changes.deleted_tokens,
228 kept_rate,
229 recall_rate,
230 kept_chars,
231 correctly_deleted_chars,
232 discarded_chars,
233 cumulative_logprob: prediction.cumulative_logprob,
234 avg_logprob: prediction.avg_logprob,
235 });
236 }
237
238 example.score = scores;
239 Ok(())
240}
241
242fn compute_cursor_metrics(
243 expected_cursor_editable_region_offset: Option<usize>,
244 actual_cursor: Option<&ActualCursor>,
245) -> (Option<usize>, Option<bool>) {
246 match (expected_cursor_editable_region_offset, actual_cursor) {
247 (Some(expected), Some(actual)) => {
248 let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
249 let exact_match = distance == 0;
250 (Some(distance), Some(exact_match))
251 }
252 (None, None) => {
253 // Neither has cursor position - skip cursor scoring
254 (None, None)
255 }
256 (Some(_), None) | (None, Some(_)) => {
257 // Only one has cursor position - count as miss
258 (None, Some(false))
259 }
260 }
261}
262
263pub fn print_report(examples: &[Example], verbose: bool) {
264 const MAX_EXAMPLES_DEFAULT: usize = 20;
265 use crate::metrics::ClassificationMetrics;
266
267 const LINE_WIDTH: usize = 101;
268 let separator = "─".repeat(LINE_WIDTH);
269
270 println!("{}", separator);
271 println!(
272 "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
273 "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
274 );
275 println!("{}", separator);
276
277 let mut all_delta_chr_f_scores = Vec::new();
278 let mut all_reversal_ratios = Vec::new();
279 let mut braces_disbalance_sum: usize = 0;
280 let mut total_delta_chr_f = ClassificationMetrics::default();
281 let mut total_delta_chr_f_precision = 0.0;
282 let mut total_delta_chr_f_recall = 0.0;
283 let mut delta_chr_f_beta = 0.0;
284 let mut total_exact_lines = ClassificationMetrics::default();
285 let mut total_scores: usize = 0;
286 let mut qa_reverts_count: usize = 0;
287 let mut qa_reverts_total: usize = 0;
288 let mut qa_confidence_sum: u64 = 0;
289 let mut qa_confidence_count: usize = 0;
290 let mut cursor_exact_matches: usize = 0;
291 let mut cursor_total: usize = 0;
292 let mut cursor_distance_sum: usize = 0;
293 let mut cursor_distance_count: usize = 0;
294 let mut wrong_editable_region_count: usize = 0;
295 let mut wrong_editable_region_total: usize = 0;
296 let mut isolated_whitespace_count: usize = 0;
297 let mut kept_rate_sum: f64 = 0.0;
298 let mut kept_rate_count: usize = 0;
299 let mut kept_chars_total: usize = 0;
300 let mut correctly_deleted_chars_total: usize = 0;
301 let mut discarded_chars_total: usize = 0;
302 let mut recall_rate_sum: f64 = 0.0;
303 let mut recall_rate_count: usize = 0;
304 let mut patch_inserted_tokens: Vec<usize> = Vec::new();
305 let mut patch_deleted_tokens: Vec<usize> = Vec::new();
306 let mut predictions_with_patch: usize = 0;
307
308 let mut printed_lines: usize = 0;
309 let mut skipped_lines: usize = 0;
310
311 for example in examples {
312 for (score_idx, score) in example.score.iter().enumerate() {
313 let exact_lines = score.exact_lines_counts();
314
315 // Get QA results for this prediction if available
316 let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
317 let qa_reverts_str = qa_result
318 .and_then(|q| q.reverts_edits)
319 .map(|v| if v { "yes" } else { "no" })
320 .unwrap_or("-");
321 let qa_conf_str = qa_result
322 .and_then(|q| q.confidence)
323 .map(|v| format!("{}", v))
324 .unwrap_or("-".to_string());
325
326 // Format wrong editable region metric
327 let wrong_er_str = match score.wrong_editable_region {
328 Some(true) => "✗",
329 Some(false) => "",
330 None => "",
331 };
332
333 // Format cursor metric
334 let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
335 (Some(true), _) => "✓".to_string(),
336 (Some(false), Some(dist)) => format!("±{}", dist),
337 (Some(false), None) => "✗".to_string(),
338 (None, _) => "-".to_string(),
339 };
340
341 if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
342 println!(
343 "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
344 truncate_name(&example.spec.name, 40),
345 score.delta_chr_f,
346 score.braces_disbalance,
347 exact_lines.f1() * 100.0,
348 score.reversal_ratio * 100.0,
349 qa_reverts_str,
350 qa_conf_str,
351 cursor_str,
352 wrong_er_str
353 );
354 printed_lines += 1;
355 } else {
356 skipped_lines += 1;
357 }
358
359 all_delta_chr_f_scores.push(score.delta_chr_f);
360 all_reversal_ratios.push(score.reversal_ratio);
361 total_scores += 1;
362 braces_disbalance_sum += score.braces_disbalance;
363 total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
364 total_delta_chr_f_precision += score.delta_chr_f_precision;
365 total_delta_chr_f_recall += score.delta_chr_f_recall;
366 delta_chr_f_beta = score.delta_chr_f_beta;
367 total_exact_lines.accumulate(&score.exact_lines_counts());
368
369 // Accumulate QA metrics
370 if let Some(qa) = qa_result {
371 if let Some(reverts) = qa.reverts_edits {
372 qa_reverts_total += 1;
373 if reverts {
374 qa_reverts_count += 1;
375 }
376 }
377 if let Some(conf) = qa.confidence {
378 qa_confidence_sum += conf as u64;
379 qa_confidence_count += 1;
380 }
381 }
382
383 // Accumulate wrong editable region metrics
384 if let Some(wrong) = score.wrong_editable_region {
385 wrong_editable_region_total += 1;
386 if wrong {
387 wrong_editable_region_count += 1;
388 }
389 }
390
391 // Accumulate isolated whitespace metrics
392 if score.has_isolated_whitespace_changes {
393 isolated_whitespace_count += 1;
394 }
395
396 // Accumulate kept and recall rate metrics
397 if let Some(kr) = score.kept_rate {
398 kept_rate_sum += kr;
399 kept_rate_count += 1;
400 }
401 if let Some(kept_chars) = score.kept_chars {
402 kept_chars_total += kept_chars;
403 }
404 if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
405 correctly_deleted_chars_total += correctly_deleted_chars;
406 }
407 if let Some(discarded_chars) = score.discarded_chars {
408 discarded_chars_total += discarded_chars;
409 }
410 if let Some(rr) = score.recall_rate {
411 recall_rate_sum += rr;
412 recall_rate_count += 1;
413 }
414
415 // Accumulate token change metrics (only for predictions that produced a patch)
416 let has_patch = example
417 .predictions
418 .get(score_idx)
419 .and_then(|p| p.actual_patch.as_ref())
420 .is_some_and(|p| !p.is_empty());
421 if has_patch {
422 predictions_with_patch += 1;
423 patch_inserted_tokens.push(score.inserted_tokens);
424 patch_deleted_tokens.push(score.deleted_tokens);
425 }
426
427 // Accumulate cursor metrics
428 if let Some(exact_match) = score.cursor_exact_match {
429 cursor_total += 1;
430 if exact_match {
431 cursor_exact_matches += 1;
432 }
433 }
434 if let Some(dist) = score.cursor_distance {
435 cursor_distance_sum += dist;
436 cursor_distance_count += 1;
437 }
438 }
439 }
440
441 if skipped_lines > 0 {
442 println!(
443 "{:<40} (use --verbose to see all {} examples)",
444 format!("... and {} more", skipped_lines),
445 printed_lines + skipped_lines
446 );
447 }
448 println!("{}", separator);
449
450 if !all_delta_chr_f_scores.is_empty() {
451 let avg_delta_chr_f: f32 =
452 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
453 let avg_reversal_ratio: f32 =
454 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
455 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
456
457 let qa_reverts_str = if qa_reverts_total > 0 {
458 format!(
459 "{:.1}%",
460 qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
461 )
462 } else {
463 "-".to_string()
464 };
465 let qa_conf_str = if qa_confidence_count > 0 {
466 format!(
467 "{:.1}",
468 qa_confidence_sum as f32 / qa_confidence_count as f32
469 )
470 } else {
471 "-".to_string()
472 };
473 let cursor_str = if cursor_total > 0 {
474 format!(
475 "{:.0}%",
476 cursor_exact_matches as f32 / cursor_total as f32 * 100.0
477 )
478 } else {
479 "-".to_string()
480 };
481 let wrong_er_str = if wrong_editable_region_total > 0 {
482 format!(
483 "{:.2}%",
484 wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
485 )
486 } else {
487 "-".to_string()
488 };
489 let isolated_ws_str = if total_scores > 0 {
490 format!(
491 "{}/{} ({:.1}%)",
492 isolated_whitespace_count,
493 total_scores,
494 isolated_whitespace_count as f32 / total_scores as f32 * 100.0
495 )
496 } else {
497 "-".to_string()
498 };
499 let avg_cursor_distance = if cursor_distance_count > 0 {
500 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
501 } else {
502 None
503 };
504
505 println!(
506 "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
507 "TOTAL / AVERAGE",
508 avg_delta_chr_f,
509 braces_disbalance_avg,
510 total_exact_lines.f1() * 100.0,
511 avg_reversal_ratio * 100.0,
512 qa_reverts_str,
513 qa_conf_str,
514 cursor_str,
515 wrong_er_str
516 );
517 println!("{}", separator);
518 println!(
519 "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%",
520 delta_chr_f_beta,
521 total_delta_chr_f.true_positives,
522 total_delta_chr_f.false_positives,
523 total_delta_chr_f.false_negatives,
524 total_delta_chr_f_precision / total_scores as f64 * 100.0,
525 total_delta_chr_f_recall / total_scores as f64 * 100.0
526 );
527
528 // Print additional cursor metrics if available
529 if let Some(avg_dist) = avg_cursor_distance {
530 println!(
531 "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
532 cursor_exact_matches,
533 cursor_total,
534 cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
535 avg_dist
536 );
537 }
538
539 // Print isolated whitespace metrics
540 if total_scores > 0 {
541 println!("Isolated whitespace changes: {}", isolated_ws_str);
542 }
543
544 // Print kept and recall rate metrics
545 if kept_rate_count > 0 {
546 let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
547 println!(
548 "Kept rate: {:.1}% avg ({} evaluated, kept chars: {}, correctly deleted chars: {}, discarded chars: {})",
549 avg_kept_rate * 100.0,
550 kept_rate_count,
551 kept_chars_total,
552 correctly_deleted_chars_total,
553 discarded_chars_total
554 );
555 }
556 if recall_rate_count > 0 {
557 let avg_recall_rate = recall_rate_sum / recall_rate_count as f64;
558 println!(
559 "Recall rate: {:.1}% avg ({} evaluated)",
560 avg_recall_rate * 100.0,
561 recall_rate_count
562 );
563 }
564
565 // Print token change percentile summary (only for predictions with a patch)
566 if !patch_inserted_tokens.is_empty() {
567 patch_inserted_tokens.sort_unstable();
568 patch_deleted_tokens.sort_unstable();
569 let mut patch_total_tokens: Vec<usize> = patch_inserted_tokens
570 .iter()
571 .zip(patch_deleted_tokens.iter())
572 .map(|(i, d)| i + d)
573 .collect();
574 patch_total_tokens.sort_unstable();
575
576 let patch_rate = predictions_with_patch as f32 / total_scores as f32 * 100.0;
577 println!();
578 println!(
579 "Token changes ({}/{} predictions produced a patch, {:.1}% — table includes only those)",
580 predictions_with_patch, total_scores, patch_rate
581 );
582 println!(
583 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
584 "", "p25", "p50", "p75", "p90", "p99"
585 );
586 println!("{}", "─".repeat(LINE_WIDTH));
587 println!(
588 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
589 "Inserted tokens",
590 percentile(&patch_inserted_tokens, 25),
591 percentile(&patch_inserted_tokens, 50),
592 percentile(&patch_inserted_tokens, 75),
593 percentile(&patch_inserted_tokens, 90),
594 percentile(&patch_inserted_tokens, 99),
595 );
596 println!(
597 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
598 "Deleted tokens",
599 percentile(&patch_deleted_tokens, 25),
600 percentile(&patch_deleted_tokens, 50),
601 percentile(&patch_deleted_tokens, 75),
602 percentile(&patch_deleted_tokens, 90),
603 percentile(&patch_deleted_tokens, 99),
604 );
605 println!(
606 "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
607 "Total tokens",
608 percentile(&patch_total_tokens, 25),
609 percentile(&patch_total_tokens, 50),
610 percentile(&patch_total_tokens, 75),
611 percentile(&patch_total_tokens, 90),
612 percentile(&patch_total_tokens, 99),
613 );
614 }
615 }
616
617 println!("\n");
618}
619
/// Returns the `p`-th percentile of `sorted_values` by nearest rank.
///
/// The slice must already be sorted in ascending order. The rank is
/// `p / 100 * (len - 1)` rounded to the nearest index and clamped to the last
/// element, so `p` above 100 yields the maximum. An empty slice yields 0.
fn percentile(sorted_values: &[usize], p: usize) -> usize {
    let Some(last_index) = sorted_values.len().checked_sub(1) else {
        return 0;
    };
    let fraction = p as f64 / 100.0;
    let rank = (fraction * (sorted_values.len() as f64 - 1.0)).round() as usize;
    sorted_values[rank.min(last_index)]
}
627
/// Shortens `name` to at most `max_len` bytes for table output.
///
/// Names that already fit are returned unchanged. Longer names are cut to
/// `max_len - 3` bytes and suffixed with `"..."`. The cut is moved back to the
/// nearest UTF-8 character boundary so that multi-byte names never cause a
/// slicing panic. When `max_len` is smaller than 3, only `"..."` is returned
/// for names that do not fit.
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.len() <= max_len {
        return name.to_string();
    }

    // `saturating_sub` avoids underflow for tiny limits.
    let mut cut = max_len.saturating_sub(3);
    // Index 0 is always a boundary, so this loop terminates.
    while !name.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &name[..cut])
}
635
/// Aggregate scoring statistics, serialized as the summary JSON file.
///
/// Populated by `compute_summary`. All `*_rate` fields and the QA revert
/// average are fractions (count / total), not percentages. Optional fields are
/// `None` when no score contributed to them.
#[derive(Serialize)]
pub struct SummaryJson {
    /// Number of scores aggregated (one per prediction score across all examples).
    pub total_examples: usize,
    /// Mean of the per-score delta chrF values; 0.0 when there are no scores.
    pub avg_delta_chr_f: f32,
    /// Beta of the last score visited; 0.0 when there are no scores.
    pub delta_chr_f_beta: f64,
    /// Delta chrF true positives summed over all scores.
    pub delta_chr_f_true_positives: usize,
    /// Delta chrF false positives summed over all scores.
    pub delta_chr_f_false_positives: usize,
    /// Delta chrF false negatives summed over all scores.
    pub delta_chr_f_false_negatives: usize,
    /// Mean of the per-score delta chrF precision; 0.0 when there are no scores.
    pub delta_chr_f_precision: f64,
    /// Mean of the per-score delta chrF recall; 0.0 when there are no scores.
    pub delta_chr_f_recall: f64,
    /// Mean braces disbalance per score; 0.0 when there are no scores.
    pub avg_braces_disbalance: f32,
    /// Exact-lines true positives summed over all scores.
    pub exact_lines_true_positives: usize,
    /// Exact-lines false positives summed over all scores.
    pub exact_lines_false_positives: usize,
    /// Exact-lines false negatives summed over all scores.
    pub exact_lines_false_negatives: usize,
    /// Precision derived from the accumulated exact-lines counts.
    pub exact_lines_precision: f64,
    /// Recall derived from the accumulated exact-lines counts.
    pub exact_lines_recall: f64,
    /// F1 derived from the accumulated exact-lines counts.
    pub exact_lines_f1: f64,
    /// Mean of the per-score reversal ratios; 0.0 when there are no scores.
    pub avg_reversal_ratio: f32,
    /// Fraction of QA results reporting `reverts_edits == true`, among those that reported it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_reverts_edits: Option<f32>,
    /// Mean QA confidence over the QA results that reported one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_confidence: Option<f32>,
    /// Fraction of cursor-evaluated scores whose cursor matched exactly.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match_rate: Option<f32>,
    /// Mean cursor distance over the scores that have a distance.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_avg_distance: Option<f32>,
    /// Number of scores with a cursor exact-match verdict.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_total_evaluated: Option<usize>,
    /// Fraction of evaluated scores flagged as having a wrong editable region.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub wrong_editable_region_rate: Option<f32>,
    /// Fraction of all scores with isolated whitespace changes.
    // NOTE(review): unlike the other optional fields this one has no
    // `skip_serializing_if`, so it is emitted as `null` when `None` — confirm
    // this is intended.
    pub isolated_whitespace_rate: Option<f32>,
    /// Mean kept rate over the scores that have one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_kept_rate: Option<f64>,
    /// Mean recall rate over the scores that have one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_recall_rate: Option<f64>,
    /// Kept characters summed over the scores that report them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_kept_chars: Option<usize>,
    /// Correctly deleted characters summed over the scores that report them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_correctly_deleted_chars: Option<usize>,
    /// Discarded characters summed over the scores that report them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_discarded_chars: Option<usize>,
}
678
679pub fn compute_summary(examples: &[Example]) -> SummaryJson {
680 use crate::metrics::ClassificationMetrics;
681
682 let mut all_delta_chr_f_scores = Vec::new();
683 let mut all_reversal_ratios = Vec::new();
684 let mut braces_disbalance_sum: usize = 0;
685 let mut total_delta_chr_f = ClassificationMetrics::default();
686 let mut total_delta_chr_f_precision = 0.0;
687 let mut total_delta_chr_f_recall = 0.0;
688 let mut delta_chr_f_beta = 0.0;
689 let mut total_exact_lines = ClassificationMetrics::default();
690 let mut total_scores: usize = 0;
691 let mut qa_reverts_count: usize = 0;
692 let mut qa_reverts_total: usize = 0;
693 let mut qa_confidence_sum: u64 = 0;
694 let mut qa_confidence_count: usize = 0;
695 let mut cursor_exact_matches: usize = 0;
696 let mut cursor_total: usize = 0;
697 let mut cursor_distance_sum: usize = 0;
698 let mut cursor_distance_count: usize = 0;
699 let mut wrong_editable_region_count: usize = 0;
700 let mut wrong_editable_region_total: usize = 0;
701 let mut isolated_whitespace_count: usize = 0;
702 let mut kept_rate_sum: f64 = 0.0;
703 let mut kept_rate_count: usize = 0;
704 let mut kept_chars_total: usize = 0;
705 let mut kept_chars_count: usize = 0;
706 let mut correctly_deleted_chars_total: usize = 0;
707 let mut correctly_deleted_chars_count: usize = 0;
708 let mut discarded_chars_total: usize = 0;
709 let mut discarded_chars_count: usize = 0;
710 let mut recall_rate_sum: f64 = 0.0;
711 let mut recall_rate_count: usize = 0;
712
713 for example in examples {
714 for (score_idx, score) in example.score.iter().enumerate() {
715 all_delta_chr_f_scores.push(score.delta_chr_f);
716 all_reversal_ratios.push(score.reversal_ratio);
717 total_scores += 1;
718 braces_disbalance_sum += score.braces_disbalance;
719 total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
720 total_delta_chr_f_precision += score.delta_chr_f_precision;
721 total_delta_chr_f_recall += score.delta_chr_f_recall;
722 delta_chr_f_beta = score.delta_chr_f_beta;
723 total_exact_lines.accumulate(&score.exact_lines_counts());
724
725 // Accumulate QA metrics
726 if let Some(Some(qa)) = example.qa.get(score_idx) {
727 if let Some(reverts) = qa.reverts_edits {
728 qa_reverts_total += 1;
729 if reverts {
730 qa_reverts_count += 1;
731 }
732 }
733 if let Some(conf) = qa.confidence {
734 qa_confidence_sum += conf as u64;
735 qa_confidence_count += 1;
736 }
737 }
738
739 // Accumulate wrong editable region metrics
740 if let Some(wrong) = score.wrong_editable_region {
741 wrong_editable_region_total += 1;
742 if wrong {
743 wrong_editable_region_count += 1;
744 }
745 }
746
747 // Accumulate isolated whitespace metrics
748 if score.has_isolated_whitespace_changes {
749 isolated_whitespace_count += 1;
750 }
751
752 // Accumulate kept and recall rate metrics
753 if let Some(kr) = score.kept_rate {
754 kept_rate_sum += kr;
755 kept_rate_count += 1;
756 }
757 if let Some(kept_chars) = score.kept_chars {
758 kept_chars_total += kept_chars;
759 kept_chars_count += 1;
760 }
761 if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
762 correctly_deleted_chars_total += correctly_deleted_chars;
763 correctly_deleted_chars_count += 1;
764 }
765 if let Some(discarded_chars) = score.discarded_chars {
766 discarded_chars_total += discarded_chars;
767 discarded_chars_count += 1;
768 }
769 if let Some(rr) = score.recall_rate {
770 recall_rate_sum += rr;
771 recall_rate_count += 1;
772 }
773
774 // Accumulate cursor metrics
775 if let Some(exact_match) = score.cursor_exact_match {
776 cursor_total += 1;
777 if exact_match {
778 cursor_exact_matches += 1;
779 }
780 }
781 if let Some(dist) = score.cursor_distance {
782 cursor_distance_sum += dist;
783 cursor_distance_count += 1;
784 }
785 }
786 }
787
788 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
789 0.0
790 } else {
791 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
792 };
793
794 let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
795 0.0
796 } else {
797 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
798 };
799
800 let avg_braces_disbalance = if total_scores == 0 {
801 0.0
802 } else {
803 braces_disbalance_sum as f32 / total_scores as f32
804 };
805
806 let qa_avg_reverts_edits = if qa_reverts_total > 0 {
807 Some(qa_reverts_count as f32 / qa_reverts_total as f32)
808 } else {
809 None
810 };
811
812 let qa_avg_confidence = if qa_confidence_count > 0 {
813 Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
814 } else {
815 None
816 };
817
818 let cursor_exact_match_rate = if cursor_total > 0 {
819 Some(cursor_exact_matches as f32 / cursor_total as f32)
820 } else {
821 None
822 };
823
824 let cursor_avg_distance = if cursor_distance_count > 0 {
825 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
826 } else {
827 None
828 };
829
830 let cursor_total_evaluated = if cursor_total > 0 {
831 Some(cursor_total)
832 } else {
833 None
834 };
835
836 let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
837 Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
838 } else {
839 None
840 };
841
842 let isolated_whitespace_rate = if total_scores > 0 {
843 Some(isolated_whitespace_count as f32 / total_scores as f32)
844 } else {
845 None
846 };
847
848 let avg_kept_rate = if kept_rate_count > 0 {
849 Some(kept_rate_sum / kept_rate_count as f64)
850 } else {
851 None
852 };
853
854 let avg_recall_rate = if recall_rate_count > 0 {
855 Some(recall_rate_sum / recall_rate_count as f64)
856 } else {
857 None
858 };
859
860 let total_kept_chars = if kept_chars_count > 0 {
861 Some(kept_chars_total)
862 } else {
863 None
864 };
865
866 let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 {
867 Some(correctly_deleted_chars_total)
868 } else {
869 None
870 };
871
872 let total_discarded_chars = if discarded_chars_count > 0 {
873 Some(discarded_chars_total)
874 } else {
875 None
876 };
877
878 SummaryJson {
879 total_examples: total_scores,
880 avg_delta_chr_f,
881 delta_chr_f_beta,
882 delta_chr_f_true_positives: total_delta_chr_f.true_positives,
883 delta_chr_f_false_positives: total_delta_chr_f.false_positives,
884 delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
885 delta_chr_f_precision: if total_scores == 0 {
886 0.0
887 } else {
888 total_delta_chr_f_precision / total_scores as f64
889 },
890 delta_chr_f_recall: if total_scores == 0 {
891 0.0
892 } else {
893 total_delta_chr_f_recall / total_scores as f64
894 },
895 avg_braces_disbalance,
896 exact_lines_true_positives: total_exact_lines.true_positives,
897 exact_lines_false_positives: total_exact_lines.false_positives,
898 exact_lines_false_negatives: total_exact_lines.false_negatives,
899 exact_lines_precision: total_exact_lines.precision(),
900 exact_lines_recall: total_exact_lines.recall(),
901 exact_lines_f1: total_exact_lines.f1(),
902 avg_reversal_ratio,
903 qa_avg_reverts_edits,
904 qa_avg_confidence,
905 cursor_exact_match_rate,
906 cursor_avg_distance,
907 cursor_total_evaluated,
908 wrong_editable_region_rate,
909 isolated_whitespace_rate,
910 avg_kept_rate,
911 avg_recall_rate,
912 total_kept_chars,
913 total_correctly_deleted_chars,
914 total_discarded_chars,
915 }
916}
917
918pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
919 let summary = compute_summary(examples);
920 let file = File::create(path)
921 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
922 let writer = BufWriter::new(file);
923 serde_json::to_writer_pretty(writer, &summary)
924 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
925 eprintln!("Wrote summary JSON to: {}", path.display());
926 Ok(())
927}