score.rs

  1use crate::{
  2    PredictArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExampleScore},
  4    format_prompt::TeacherPrompt,
  5    headless::EpAppState,
  6    metrics,
  7    parse_output::parse_prediction_output,
  8    predict::run_prediction,
  9    progress::{ExampleProgress, Step},
 10    reversal_tracking,
 11};
 12use anyhow::Context as _;
 13use gpui::AsyncApp;
 14use serde::Serialize;
 15use std::fs::File;
 16use std::io::BufWriter;
 17use std::path::Path;
 18use std::sync::Arc;
 19use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
 20
 21pub async fn run_scoring(
 22    example: &mut Example,
 23    args: &PredictArgs,
 24    app_state: Arc<EpAppState>,
 25    example_progress: &ExampleProgress,
 26    cx: AsyncApp,
 27) -> anyhow::Result<()> {
 28    run_prediction(example, args, app_state, example_progress, cx).await?;
 29
 30    let progress = example_progress.start(Step::Score);
 31
 32    progress.set_substatus("applying patches");
 33    let prompt_inputs = example
 34        .prompt_inputs
 35        .as_ref()
 36        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
 37    let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
 38    let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
 39
 40    let expected_texts: Vec<String> = expected_patches_with_cursors
 41        .iter()
 42        .map(|(patch, _)| {
 43            apply_diff_to_string(patch, original_text)
 44                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 45        })
 46        .collect::<Result<Vec<_>, _>>()?;
 47
 48    // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
 49    // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
 50    // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
 51    // region to find where the hunk matched, then compute the expected cursor position.
 52    let old_editable_region = if let Some(p) = example.prompt.as_ref() {
 53        if matches!(
 54            p.provider,
 55            PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
 56        ) {
 57            Some(
 58                TeacherPrompt::extract_editable_region(&p.input)?
 59                    .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
 60            )
 61        } else {
 62            None
 63        }
 64    } else {
 65        None
 66    };
 67
 68    let zero_scores = ExampleScore {
 69        delta_chr_f: 0.0,
 70        delta_chr_f_true_positives: 0,
 71        delta_chr_f_false_positives: 0,
 72        delta_chr_f_false_negatives: 0,
 73        delta_chr_f_precision: 0.0,
 74        delta_chr_f_recall: 0.0,
 75        delta_chr_f_beta: metrics::delta_chr_f_beta(),
 76        braces_disbalance: 0,
 77        exact_lines_tp: 0,
 78        exact_lines_fp: 0,
 79        exact_lines_fn: 0,
 80        reversal_ratio: 0.0,
 81        cursor_distance: None,
 82        cursor_exact_match: None,
 83        wrong_editable_region: None,
 84        has_isolated_whitespace_changes: false,
 85        inserted_tokens: 0,
 86        deleted_tokens: 0,
 87        kept_rate: None,
 88        recall_rate: None,
 89        kept_chars: None,
 90        correctly_deleted_chars: None,
 91        discarded_chars: None,
 92        cumulative_logprob: None,
 93        avg_logprob: None,
 94    };
 95
 96    let cursor_path = example.spec.cursor_path.as_ref();
 97
 98    progress.set_substatus("computing metrics");
 99    let mut scores = vec![];
100    for prediction in &example.predictions {
101        let actual_patch = prediction.actual_patch.clone().or_else(|| {
102            parse_prediction_output(example, &prediction.actual_output, prediction.provider)
103                .ok()
104                .map(|(patch, _)| patch)
105        });
106
107        let Some(actual_patch) = actual_patch else {
108            scores.push(zero_scores.clone());
109            continue;
110        };
111
112        let token_changes = metrics::count_patch_token_changes(&actual_patch);
113
114        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
115            Ok(text) => text,
116            Err(_) => {
117                let mut s = zero_scores.clone();
118                s.inserted_tokens = token_changes.inserted_tokens;
119                s.deleted_tokens = token_changes.deleted_tokens;
120                scores.push(s);
121                continue;
122            }
123        };
124
125        let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
126        let mut best_expected_cursor: Option<usize> = None;
127        let mut best_patch_idx: Option<usize> = None;
128        let mut best_expected_text: Option<&str> = None;
129
130        for (idx, expected) in expected_texts.iter().enumerate() {
131            let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
132            if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
133                best_delta_chr_f_metrics = delta_chr_f_metrics;
134                best_patch_idx = Some(idx);
135                best_expected_text = Some(expected);
136            }
137        }
138
139        if let Some(idx) = best_patch_idx {
140            // Get the raw cursor offset from the expected patch (relative to hunk new text)
141            let expected_cursor_in_patch = expected_patches_with_cursors
142                .get(idx)
143                .and_then(|(_, cursor)| *cursor);
144
145            // For Teacher prompts, we need to apply the patch to the editable region
146            // to find where the hunk matched, then compute the actual cursor position
147            if let (Some(editable_region), Some(cursor_in_patch)) =
148                (&old_editable_region, expected_cursor_in_patch)
149            {
150                let (patch, _) = &expected_patches_with_cursors[idx];
151                if let Ok((_, hunk_offset)) =
152                    apply_diff_to_string_with_hunk_offset(patch, editable_region)
153                {
154                    let hunk_start = hunk_offset.unwrap_or(0);
155                    best_expected_cursor = Some(hunk_start + cursor_in_patch);
156                }
157            } else {
158                // For non-Teacher prompts or if we can't compute, use raw offset
159                best_expected_cursor = expected_cursor_in_patch;
160            }
161        }
162
163        let disbalance_before = metrics::braces_disbalance(&original_text);
164        let disbalance_after = metrics::braces_disbalance(&actual_text);
165        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
166
167        // Compute exact lines match against best matching expected patch
168        let best_exact_lines = expected_patches_with_cursors
169            .iter()
170            .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
171            .max_by_key(|m| m.true_positives)
172            .unwrap_or_default();
173
174        // Compute reversal ratio
175        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
176            prompt_inputs,
177            &actual_text,
178            cursor_path,
179        );
180
181        // Compute cursor position metrics
182        let (cursor_distance, cursor_exact_match) =
183            compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
184
185        // Compute approximation of editable region correctness
186        let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
187
188        // Check for isolated whitespace changes.
189        let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
190            &actual_patch,
191            prediction.actual_cursor.as_ref(),
192        );
193
194        let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) =
195            best_expected_text
196                .map(|reference_text| {
197                    let result =
198                        metrics::compute_kept_rate(original_text, &actual_text, reference_text);
199                    (
200                        Some(result.kept_rate),
201                        Some(result.recall_rate),
202                        Some(result.kept_chars),
203                        Some(result.correctly_deleted_chars),
204                        Some(result.discarded_chars),
205                    )
206                })
207                .unwrap_or((None, None, None, None, None));
208
209        scores.push(ExampleScore {
210            delta_chr_f: best_delta_chr_f_metrics.score as f32,
211            delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
212            delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
213            delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
214            delta_chr_f_precision: best_delta_chr_f_metrics.precision,
215            delta_chr_f_recall: best_delta_chr_f_metrics.recall,
216            delta_chr_f_beta: best_delta_chr_f_metrics.beta,
217            braces_disbalance,
218            exact_lines_tp: best_exact_lines.true_positives,
219            exact_lines_fp: best_exact_lines.false_positives,
220            exact_lines_fn: best_exact_lines.false_negatives,
221            reversal_ratio,
222            cursor_distance,
223            cursor_exact_match,
224            wrong_editable_region,
225            has_isolated_whitespace_changes,
226            inserted_tokens: token_changes.inserted_tokens,
227            deleted_tokens: token_changes.deleted_tokens,
228            kept_rate,
229            recall_rate,
230            kept_chars,
231            correctly_deleted_chars,
232            discarded_chars,
233            cumulative_logprob: prediction.cumulative_logprob,
234            avg_logprob: prediction.avg_logprob,
235        });
236    }
237
238    example.score = scores;
239    Ok(())
240}
241
242fn compute_cursor_metrics(
243    expected_cursor_editable_region_offset: Option<usize>,
244    actual_cursor: Option<&ActualCursor>,
245) -> (Option<usize>, Option<bool>) {
246    match (expected_cursor_editable_region_offset, actual_cursor) {
247        (Some(expected), Some(actual)) => {
248            let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
249            let exact_match = distance == 0;
250            (Some(distance), Some(exact_match))
251        }
252        (None, None) => {
253            // Neither has cursor position - skip cursor scoring
254            (None, None)
255        }
256        (Some(_), None) | (None, Some(_)) => {
257            // Only one has cursor position - count as miss
258            (None, Some(false))
259        }
260    }
261}
262
263pub fn print_report(examples: &[Example], verbose: bool) {
264    const MAX_EXAMPLES_DEFAULT: usize = 20;
265    use crate::metrics::ClassificationMetrics;
266
267    const LINE_WIDTH: usize = 101;
268    let separator = "".repeat(LINE_WIDTH);
269
270    println!("{}", separator);
271    println!(
272        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
273        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
274    );
275    println!("{}", separator);
276
277    let mut all_delta_chr_f_scores = Vec::new();
278    let mut all_reversal_ratios = Vec::new();
279    let mut braces_disbalance_sum: usize = 0;
280    let mut total_delta_chr_f = ClassificationMetrics::default();
281    let mut total_delta_chr_f_precision = 0.0;
282    let mut total_delta_chr_f_recall = 0.0;
283    let mut delta_chr_f_beta = 0.0;
284    let mut total_exact_lines = ClassificationMetrics::default();
285    let mut total_scores: usize = 0;
286    let mut qa_reverts_count: usize = 0;
287    let mut qa_reverts_total: usize = 0;
288    let mut qa_confidence_sum: u64 = 0;
289    let mut qa_confidence_count: usize = 0;
290    let mut cursor_exact_matches: usize = 0;
291    let mut cursor_total: usize = 0;
292    let mut cursor_distance_sum: usize = 0;
293    let mut cursor_distance_count: usize = 0;
294    let mut wrong_editable_region_count: usize = 0;
295    let mut wrong_editable_region_total: usize = 0;
296    let mut isolated_whitespace_count: usize = 0;
297    let mut kept_rate_sum: f64 = 0.0;
298    let mut kept_rate_count: usize = 0;
299    let mut kept_chars_total: usize = 0;
300    let mut correctly_deleted_chars_total: usize = 0;
301    let mut discarded_chars_total: usize = 0;
302    let mut recall_rate_sum: f64 = 0.0;
303    let mut recall_rate_count: usize = 0;
304    let mut patch_inserted_tokens: Vec<usize> = Vec::new();
305    let mut patch_deleted_tokens: Vec<usize> = Vec::new();
306    let mut predictions_with_patch: usize = 0;
307
308    let mut printed_lines: usize = 0;
309    let mut skipped_lines: usize = 0;
310
311    for example in examples {
312        for (score_idx, score) in example.score.iter().enumerate() {
313            let exact_lines = score.exact_lines_counts();
314
315            // Get QA results for this prediction if available
316            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
317            let qa_reverts_str = qa_result
318                .and_then(|q| q.reverts_edits)
319                .map(|v| if v { "yes" } else { "no" })
320                .unwrap_or("-");
321            let qa_conf_str = qa_result
322                .and_then(|q| q.confidence)
323                .map(|v| format!("{}", v))
324                .unwrap_or("-".to_string());
325
326            // Format wrong editable region metric
327            let wrong_er_str = match score.wrong_editable_region {
328                Some(true) => "",
329                Some(false) => "",
330                None => "",
331            };
332
333            // Format cursor metric
334            let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
335                (Some(true), _) => "".to_string(),
336                (Some(false), Some(dist)) => format!("±{}", dist),
337                (Some(false), None) => "".to_string(),
338                (None, _) => "-".to_string(),
339            };
340
341            if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
342                println!(
343                    "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
344                    truncate_name(&example.spec.name, 40),
345                    score.delta_chr_f,
346                    score.braces_disbalance,
347                    exact_lines.f1() * 100.0,
348                    score.reversal_ratio * 100.0,
349                    qa_reverts_str,
350                    qa_conf_str,
351                    cursor_str,
352                    wrong_er_str
353                );
354                printed_lines += 1;
355            } else {
356                skipped_lines += 1;
357            }
358
359            all_delta_chr_f_scores.push(score.delta_chr_f);
360            all_reversal_ratios.push(score.reversal_ratio);
361            total_scores += 1;
362            braces_disbalance_sum += score.braces_disbalance;
363            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
364            total_delta_chr_f_precision += score.delta_chr_f_precision;
365            total_delta_chr_f_recall += score.delta_chr_f_recall;
366            delta_chr_f_beta = score.delta_chr_f_beta;
367            total_exact_lines.accumulate(&score.exact_lines_counts());
368
369            // Accumulate QA metrics
370            if let Some(qa) = qa_result {
371                if let Some(reverts) = qa.reverts_edits {
372                    qa_reverts_total += 1;
373                    if reverts {
374                        qa_reverts_count += 1;
375                    }
376                }
377                if let Some(conf) = qa.confidence {
378                    qa_confidence_sum += conf as u64;
379                    qa_confidence_count += 1;
380                }
381            }
382
383            // Accumulate wrong editable region metrics
384            if let Some(wrong) = score.wrong_editable_region {
385                wrong_editable_region_total += 1;
386                if wrong {
387                    wrong_editable_region_count += 1;
388                }
389            }
390
391            // Accumulate isolated whitespace metrics
392            if score.has_isolated_whitespace_changes {
393                isolated_whitespace_count += 1;
394            }
395
396            // Accumulate kept and recall rate metrics
397            if let Some(kr) = score.kept_rate {
398                kept_rate_sum += kr;
399                kept_rate_count += 1;
400            }
401            if let Some(kept_chars) = score.kept_chars {
402                kept_chars_total += kept_chars;
403            }
404            if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
405                correctly_deleted_chars_total += correctly_deleted_chars;
406            }
407            if let Some(discarded_chars) = score.discarded_chars {
408                discarded_chars_total += discarded_chars;
409            }
410            if let Some(rr) = score.recall_rate {
411                recall_rate_sum += rr;
412                recall_rate_count += 1;
413            }
414
415            // Accumulate token change metrics (only for predictions that produced a patch)
416            let has_patch = example
417                .predictions
418                .get(score_idx)
419                .and_then(|p| p.actual_patch.as_ref())
420                .is_some_and(|p| !p.is_empty());
421            if has_patch {
422                predictions_with_patch += 1;
423                patch_inserted_tokens.push(score.inserted_tokens);
424                patch_deleted_tokens.push(score.deleted_tokens);
425            }
426
427            // Accumulate cursor metrics
428            if let Some(exact_match) = score.cursor_exact_match {
429                cursor_total += 1;
430                if exact_match {
431                    cursor_exact_matches += 1;
432                }
433            }
434            if let Some(dist) = score.cursor_distance {
435                cursor_distance_sum += dist;
436                cursor_distance_count += 1;
437            }
438        }
439    }
440
441    if skipped_lines > 0 {
442        println!(
443            "{:<40} (use --verbose to see all {} examples)",
444            format!("... and {} more", skipped_lines),
445            printed_lines + skipped_lines
446        );
447    }
448    println!("{}", separator);
449
450    if !all_delta_chr_f_scores.is_empty() {
451        let avg_delta_chr_f: f32 =
452            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
453        let avg_reversal_ratio: f32 =
454            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
455        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
456
457        let qa_reverts_str = if qa_reverts_total > 0 {
458            format!(
459                "{:.1}%",
460                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
461            )
462        } else {
463            "-".to_string()
464        };
465        let qa_conf_str = if qa_confidence_count > 0 {
466            format!(
467                "{:.1}",
468                qa_confidence_sum as f32 / qa_confidence_count as f32
469            )
470        } else {
471            "-".to_string()
472        };
473        let cursor_str = if cursor_total > 0 {
474            format!(
475                "{:.0}%",
476                cursor_exact_matches as f32 / cursor_total as f32 * 100.0
477            )
478        } else {
479            "-".to_string()
480        };
481        let wrong_er_str = if wrong_editable_region_total > 0 {
482            format!(
483                "{:.2}%",
484                wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
485            )
486        } else {
487            "-".to_string()
488        };
489        let isolated_ws_str = if total_scores > 0 {
490            format!(
491                "{}/{} ({:.1}%)",
492                isolated_whitespace_count,
493                total_scores,
494                isolated_whitespace_count as f32 / total_scores as f32 * 100.0
495            )
496        } else {
497            "-".to_string()
498        };
499        let avg_cursor_distance = if cursor_distance_count > 0 {
500            Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
501        } else {
502            None
503        };
504
505        println!(
506            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
507            "TOTAL / AVERAGE",
508            avg_delta_chr_f,
509            braces_disbalance_avg,
510            total_exact_lines.f1() * 100.0,
511            avg_reversal_ratio * 100.0,
512            qa_reverts_str,
513            qa_conf_str,
514            cursor_str,
515            wrong_er_str
516        );
517        println!("{}", separator);
518        println!(
519            "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%",
520            delta_chr_f_beta,
521            total_delta_chr_f.true_positives,
522            total_delta_chr_f.false_positives,
523            total_delta_chr_f.false_negatives,
524            total_delta_chr_f_precision / total_scores as f64 * 100.0,
525            total_delta_chr_f_recall / total_scores as f64 * 100.0
526        );
527
528        // Print additional cursor metrics if available
529        if let Some(avg_dist) = avg_cursor_distance {
530            println!(
531                "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
532                cursor_exact_matches,
533                cursor_total,
534                cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
535                avg_dist
536            );
537        }
538
539        // Print isolated whitespace metrics
540        if total_scores > 0 {
541            println!("Isolated whitespace changes: {}", isolated_ws_str);
542        }
543
544        // Print kept and recall rate metrics
545        if kept_rate_count > 0 {
546            let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
547            println!(
548                "Kept rate: {:.1}% avg ({} evaluated, kept chars: {}, correctly deleted chars: {}, discarded chars: {})",
549                avg_kept_rate * 100.0,
550                kept_rate_count,
551                kept_chars_total,
552                correctly_deleted_chars_total,
553                discarded_chars_total
554            );
555        }
556        if recall_rate_count > 0 {
557            let avg_recall_rate = recall_rate_sum / recall_rate_count as f64;
558            println!(
559                "Recall rate: {:.1}% avg ({} evaluated)",
560                avg_recall_rate * 100.0,
561                recall_rate_count
562            );
563        }
564
565        // Print token change percentile summary (only for predictions with a patch)
566        if !patch_inserted_tokens.is_empty() {
567            patch_inserted_tokens.sort_unstable();
568            patch_deleted_tokens.sort_unstable();
569            let mut patch_total_tokens: Vec<usize> = patch_inserted_tokens
570                .iter()
571                .zip(patch_deleted_tokens.iter())
572                .map(|(i, d)| i + d)
573                .collect();
574            patch_total_tokens.sort_unstable();
575
576            let patch_rate = predictions_with_patch as f32 / total_scores as f32 * 100.0;
577            println!();
578            println!(
579                "Token changes ({}/{} predictions produced a patch, {:.1}% — table includes only those)",
580                predictions_with_patch, total_scores, patch_rate
581            );
582            println!(
583                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
584                "", "p25", "p50", "p75", "p90", "p99"
585            );
586            println!("{}", "".repeat(LINE_WIDTH));
587            println!(
588                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
589                "Inserted tokens",
590                percentile(&patch_inserted_tokens, 25),
591                percentile(&patch_inserted_tokens, 50),
592                percentile(&patch_inserted_tokens, 75),
593                percentile(&patch_inserted_tokens, 90),
594                percentile(&patch_inserted_tokens, 99),
595            );
596            println!(
597                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
598                "Deleted tokens",
599                percentile(&patch_deleted_tokens, 25),
600                percentile(&patch_deleted_tokens, 50),
601                percentile(&patch_deleted_tokens, 75),
602                percentile(&patch_deleted_tokens, 90),
603                percentile(&patch_deleted_tokens, 99),
604            );
605            println!(
606                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
607                "Total tokens",
608                percentile(&patch_total_tokens, 25),
609                percentile(&patch_total_tokens, 50),
610                percentile(&patch_total_tokens, 75),
611                percentile(&patch_total_tokens, 90),
612                percentile(&patch_total_tokens, 99),
613            );
614        }
615    }
616
617    println!("\n");
618}
619
/// Returns the `p`-th percentile of an ascending-sorted slice.
///
/// The rank is `p / 100 * (len - 1)`, rounded to the nearest index and clamped to the last
/// element, so no interpolation takes place. An empty slice yields `0`.
fn percentile(sorted_values: &[usize], p: usize) -> usize {
    let Some(last_index) = sorted_values.len().checked_sub(1) else {
        return 0;
    };
    let fraction = p as f64 / 100.0;
    let rank = (fraction * last_index as f64).round() as usize;
    sorted_values[rank.min(last_index)]
}
627
/// Shortens `name` to at most `max_len` characters for table output.
///
/// Names that already fit are returned unchanged. Longer names are cut to `max_len - 3`
/// characters followed by `"..."`. When `max_len` is smaller than the ellipsis itself, the
/// name is simply cut to `max_len` characters.
///
/// Lengths are measured in `char`s rather than bytes, so the cut never lands inside a
/// multi-byte UTF-8 sequence.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => {
            let mut truncated: String = name.chars().take(keep).collect();
            truncated.push_str(ELLIPSIS);
            truncated
        }
        // No room for the ellipsis: keep as many characters as allowed.
        None => name.chars().take(max_len).collect(),
    }
}
635
/// Aggregate scoring summary that can be serialized via `serde` (derives `Serialize`).
///
/// Fields carrying `skip_serializing_if = "Option::is_none"` are left out of the serialized
/// output entirely when they are `None`.
///
/// NOTE(review): the code that fills this struct is not visible in this part of the file;
/// the per-field notes below are inferred from field names and from the accumulators at the
/// start of `compute_summary` — confirm before relying on them.
#[derive(Serialize)]
pub struct SummaryJson {
    // presumably the number of scored entries — TODO confirm whether this counts examples or
    // individual prediction scores
    pub total_examples: usize,
    // Delta chrF aggregates: average score, beta, summed TP/FP/FN counts, precision, recall.
    pub avg_delta_chr_f: f32,
    pub delta_chr_f_beta: f64,
    pub delta_chr_f_true_positives: usize,
    pub delta_chr_f_false_positives: usize,
    pub delta_chr_f_false_negatives: usize,
    pub delta_chr_f_precision: f64,
    pub delta_chr_f_recall: f64,
    pub avg_braces_disbalance: f32,
    // Exact-line match aggregates: summed TP/FP/FN counts plus derived precision/recall/F1.
    pub exact_lines_true_positives: usize,
    pub exact_lines_false_positives: usize,
    pub exact_lines_false_negatives: usize,
    pub exact_lines_precision: f64,
    pub exact_lines_recall: f64,
    pub exact_lines_f1: f64,
    pub avg_reversal_ratio: f32,
    // QA aggregates; presumably `None` when no QA result carried the respective value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_reverts_edits: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_confidence: Option<f32>,
    // Cursor aggregates; presumably `None` when no score had cursor metrics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_exact_match_rate: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_avg_distance: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor_total_evaluated: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub wrong_editable_region_rate: Option<f32>,
    // NOTE(review): unlike every other optional field, this one has no `skip_serializing_if`,
    // so a `None` is written out explicitly (`null` in JSON) instead of being omitted —
    // confirm this is intentional.
    pub isolated_whitespace_rate: Option<f32>,
    // Kept/recall aggregates; presumably `None` when no score had the respective metric.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_kept_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_recall_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_kept_chars: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_correctly_deleted_chars: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_discarded_chars: Option<usize>,
}
678
679pub fn compute_summary(examples: &[Example]) -> SummaryJson {
680    use crate::metrics::ClassificationMetrics;
681
682    let mut all_delta_chr_f_scores = Vec::new();
683    let mut all_reversal_ratios = Vec::new();
684    let mut braces_disbalance_sum: usize = 0;
685    let mut total_delta_chr_f = ClassificationMetrics::default();
686    let mut total_delta_chr_f_precision = 0.0;
687    let mut total_delta_chr_f_recall = 0.0;
688    let mut delta_chr_f_beta = 0.0;
689    let mut total_exact_lines = ClassificationMetrics::default();
690    let mut total_scores: usize = 0;
691    let mut qa_reverts_count: usize = 0;
692    let mut qa_reverts_total: usize = 0;
693    let mut qa_confidence_sum: u64 = 0;
694    let mut qa_confidence_count: usize = 0;
695    let mut cursor_exact_matches: usize = 0;
696    let mut cursor_total: usize = 0;
697    let mut cursor_distance_sum: usize = 0;
698    let mut cursor_distance_count: usize = 0;
699    let mut wrong_editable_region_count: usize = 0;
700    let mut wrong_editable_region_total: usize = 0;
701    let mut isolated_whitespace_count: usize = 0;
702    let mut kept_rate_sum: f64 = 0.0;
703    let mut kept_rate_count: usize = 0;
704    let mut kept_chars_total: usize = 0;
705    let mut kept_chars_count: usize = 0;
706    let mut correctly_deleted_chars_total: usize = 0;
707    let mut correctly_deleted_chars_count: usize = 0;
708    let mut discarded_chars_total: usize = 0;
709    let mut discarded_chars_count: usize = 0;
710    let mut recall_rate_sum: f64 = 0.0;
711    let mut recall_rate_count: usize = 0;
712
713    for example in examples {
714        for (score_idx, score) in example.score.iter().enumerate() {
715            all_delta_chr_f_scores.push(score.delta_chr_f);
716            all_reversal_ratios.push(score.reversal_ratio);
717            total_scores += 1;
718            braces_disbalance_sum += score.braces_disbalance;
719            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
720            total_delta_chr_f_precision += score.delta_chr_f_precision;
721            total_delta_chr_f_recall += score.delta_chr_f_recall;
722            delta_chr_f_beta = score.delta_chr_f_beta;
723            total_exact_lines.accumulate(&score.exact_lines_counts());
724
725            // Accumulate QA metrics
726            if let Some(Some(qa)) = example.qa.get(score_idx) {
727                if let Some(reverts) = qa.reverts_edits {
728                    qa_reverts_total += 1;
729                    if reverts {
730                        qa_reverts_count += 1;
731                    }
732                }
733                if let Some(conf) = qa.confidence {
734                    qa_confidence_sum += conf as u64;
735                    qa_confidence_count += 1;
736                }
737            }
738
739            // Accumulate wrong editable region metrics
740            if let Some(wrong) = score.wrong_editable_region {
741                wrong_editable_region_total += 1;
742                if wrong {
743                    wrong_editable_region_count += 1;
744                }
745            }
746
747            // Accumulate isolated whitespace metrics
748            if score.has_isolated_whitespace_changes {
749                isolated_whitespace_count += 1;
750            }
751
752            // Accumulate kept and recall rate metrics
753            if let Some(kr) = score.kept_rate {
754                kept_rate_sum += kr;
755                kept_rate_count += 1;
756            }
757            if let Some(kept_chars) = score.kept_chars {
758                kept_chars_total += kept_chars;
759                kept_chars_count += 1;
760            }
761            if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
762                correctly_deleted_chars_total += correctly_deleted_chars;
763                correctly_deleted_chars_count += 1;
764            }
765            if let Some(discarded_chars) = score.discarded_chars {
766                discarded_chars_total += discarded_chars;
767                discarded_chars_count += 1;
768            }
769            if let Some(rr) = score.recall_rate {
770                recall_rate_sum += rr;
771                recall_rate_count += 1;
772            }
773
774            // Accumulate cursor metrics
775            if let Some(exact_match) = score.cursor_exact_match {
776                cursor_total += 1;
777                if exact_match {
778                    cursor_exact_matches += 1;
779                }
780            }
781            if let Some(dist) = score.cursor_distance {
782                cursor_distance_sum += dist;
783                cursor_distance_count += 1;
784            }
785        }
786    }
787
788    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
789        0.0
790    } else {
791        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
792    };
793
794    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
795        0.0
796    } else {
797        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
798    };
799
800    let avg_braces_disbalance = if total_scores == 0 {
801        0.0
802    } else {
803        braces_disbalance_sum as f32 / total_scores as f32
804    };
805
806    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
807        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
808    } else {
809        None
810    };
811
812    let qa_avg_confidence = if qa_confidence_count > 0 {
813        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
814    } else {
815        None
816    };
817
818    let cursor_exact_match_rate = if cursor_total > 0 {
819        Some(cursor_exact_matches as f32 / cursor_total as f32)
820    } else {
821        None
822    };
823
824    let cursor_avg_distance = if cursor_distance_count > 0 {
825        Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
826    } else {
827        None
828    };
829
830    let cursor_total_evaluated = if cursor_total > 0 {
831        Some(cursor_total)
832    } else {
833        None
834    };
835
836    let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
837        Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
838    } else {
839        None
840    };
841
842    let isolated_whitespace_rate = if total_scores > 0 {
843        Some(isolated_whitespace_count as f32 / total_scores as f32)
844    } else {
845        None
846    };
847
848    let avg_kept_rate = if kept_rate_count > 0 {
849        Some(kept_rate_sum / kept_rate_count as f64)
850    } else {
851        None
852    };
853
854    let avg_recall_rate = if recall_rate_count > 0 {
855        Some(recall_rate_sum / recall_rate_count as f64)
856    } else {
857        None
858    };
859
860    let total_kept_chars = if kept_chars_count > 0 {
861        Some(kept_chars_total)
862    } else {
863        None
864    };
865
866    let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 {
867        Some(correctly_deleted_chars_total)
868    } else {
869        None
870    };
871
872    let total_discarded_chars = if discarded_chars_count > 0 {
873        Some(discarded_chars_total)
874    } else {
875        None
876    };
877
878    SummaryJson {
879        total_examples: total_scores,
880        avg_delta_chr_f,
881        delta_chr_f_beta,
882        delta_chr_f_true_positives: total_delta_chr_f.true_positives,
883        delta_chr_f_false_positives: total_delta_chr_f.false_positives,
884        delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
885        delta_chr_f_precision: if total_scores == 0 {
886            0.0
887        } else {
888            total_delta_chr_f_precision / total_scores as f64
889        },
890        delta_chr_f_recall: if total_scores == 0 {
891            0.0
892        } else {
893            total_delta_chr_f_recall / total_scores as f64
894        },
895        avg_braces_disbalance,
896        exact_lines_true_positives: total_exact_lines.true_positives,
897        exact_lines_false_positives: total_exact_lines.false_positives,
898        exact_lines_false_negatives: total_exact_lines.false_negatives,
899        exact_lines_precision: total_exact_lines.precision(),
900        exact_lines_recall: total_exact_lines.recall(),
901        exact_lines_f1: total_exact_lines.f1(),
902        avg_reversal_ratio,
903        qa_avg_reverts_edits,
904        qa_avg_confidence,
905        cursor_exact_match_rate,
906        cursor_avg_distance,
907        cursor_total_evaluated,
908        wrong_editable_region_rate,
909        isolated_whitespace_rate,
910        avg_kept_rate,
911        avg_recall_rate,
912        total_kept_chars,
913        total_correctly_deleted_chars,
914        total_discarded_chars,
915    }
916}
917
918pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
919    let summary = compute_summary(examples);
920    let file = File::create(path)
921        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
922    let writer = BufWriter::new(file);
923    serde_json::to_writer_pretty(writer, &summary)
924        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
925    eprintln!("Wrote summary JSON to: {}", path.display());
926    Ok(())
927}