score.rs

  1use crate::{
  2    PredictArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExampleScore},
  4    format_prompt::TeacherPrompt,
  5    headless::EpAppState,
  6    metrics,
  7    parse_output::parse_prediction_output,
  8    predict::run_prediction,
  9    progress::{ExampleProgress, Step},
 10    reversal_tracking,
 11};
 12use anyhow::Context as _;
 13use gpui::AsyncApp;
 14use serde::Serialize;
 15use std::fs::File;
 16use std::io::BufWriter;
 17use std::path::Path;
 18use std::sync::Arc;
 19use zeta_prompt::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
 20
 21pub async fn run_scoring(
 22    example: &mut Example,
 23    args: &PredictArgs,
 24    app_state: Arc<EpAppState>,
 25    example_progress: &ExampleProgress,
 26    cx: AsyncApp,
 27) -> anyhow::Result<()> {
 28    run_prediction(example, args, app_state, example_progress, cx).await?;
 29
 30    let progress = example_progress.start(Step::Score);
 31
 32    progress.set_substatus("applying patches");
 33    let prompt_inputs = example
 34        .prompt_inputs
 35        .as_ref()
 36        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?;
 37    let original_text: &str = prompt_inputs.cursor_excerpt.as_ref();
 38    let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
 39
 40    let expected_texts: Vec<String> = expected_patches_with_cursors
 41        .iter()
 42        .map(|(patch, _)| {
 43            apply_diff_to_string(patch, original_text)
 44                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 45        })
 46        .collect::<Result<Vec<_>, _>>()?;
 47
 48    // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
 49    // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
 50    // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
 51    // region to find where the hunk matched, then compute the expected cursor position.
 52    let old_editable_region = if let Some(p) = example.prompt.as_ref() {
 53        if matches!(
 54            p.provider,
 55            PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
 56        ) {
 57            Some(
 58                TeacherPrompt::extract_editable_region(&p.input)?
 59                    .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
 60            )
 61        } else {
 62            None
 63        }
 64    } else {
 65        None
 66    };
 67
 68    let zero_scores = ExampleScore {
 69        delta_chr_f: 0.0,
 70        delta_chr_f_true_positives: 0,
 71        delta_chr_f_false_positives: 0,
 72        delta_chr_f_false_negatives: 0,
 73        delta_chr_f_precision: 0.0,
 74        delta_chr_f_recall: 0.0,
 75        delta_chr_f_beta: metrics::delta_chr_f_beta(),
 76        braces_disbalance: 0,
 77        exact_lines_tp: 0,
 78        exact_lines_fp: 0,
 79        exact_lines_fn: 0,
 80        reversal_ratio: 0.0,
 81        cursor_distance: None,
 82        cursor_exact_match: None,
 83        wrong_editable_region: None,
 84        has_isolated_whitespace_changes: false,
 85        inserted_tokens: 0,
 86        deleted_tokens: 0,
 87        kept_rate: None,
 88        recall_rate: None,
 89        cumulative_logprob: None,
 90        avg_logprob: None,
 91    };
 92
 93    let cursor_path = example.spec.cursor_path.as_ref();
 94
 95    progress.set_substatus("computing metrics");
 96    let mut scores = vec![];
 97    for prediction in &example.predictions {
 98        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 99            parse_prediction_output(example, &prediction.actual_output, prediction.provider)
100                .ok()
101                .map(|(patch, _)| patch)
102        });
103
104        let Some(actual_patch) = actual_patch else {
105            scores.push(zero_scores.clone());
106            continue;
107        };
108
109        let token_changes = metrics::count_patch_token_changes(&actual_patch);
110
111        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
112            Ok(text) => text,
113            Err(_) => {
114                let mut s = zero_scores.clone();
115                s.inserted_tokens = token_changes.inserted_tokens;
116                s.deleted_tokens = token_changes.deleted_tokens;
117                scores.push(s);
118                continue;
119            }
120        };
121
122        let mut best_delta_chr_f_metrics = metrics::DeltaChrFMetrics::default();
123        let mut best_expected_cursor: Option<usize> = None;
124        let mut best_patch_idx: Option<usize> = None;
125        let mut best_expected_text: Option<&str> = None;
126
127        for (idx, expected) in expected_texts.iter().enumerate() {
128            let delta_chr_f_metrics = metrics::delta_chr_f(original_text, expected, &actual_text);
129            if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score {
130                best_delta_chr_f_metrics = delta_chr_f_metrics;
131                best_patch_idx = Some(idx);
132                best_expected_text = Some(expected);
133            }
134        }
135
136        if let Some(idx) = best_patch_idx {
137            // Get the raw cursor offset from the expected patch (relative to hunk new text)
138            let expected_cursor_in_patch = expected_patches_with_cursors
139                .get(idx)
140                .and_then(|(_, cursor)| *cursor);
141
142            // For Teacher prompts, we need to apply the patch to the editable region
143            // to find where the hunk matched, then compute the actual cursor position
144            if let (Some(editable_region), Some(cursor_in_patch)) =
145                (&old_editable_region, expected_cursor_in_patch)
146            {
147                let (patch, _) = &expected_patches_with_cursors[idx];
148                if let Ok((_, hunk_offset)) =
149                    apply_diff_to_string_with_hunk_offset(patch, editable_region)
150                {
151                    let hunk_start = hunk_offset.unwrap_or(0);
152                    best_expected_cursor = Some(hunk_start + cursor_in_patch);
153                }
154            } else {
155                // For non-Teacher prompts or if we can't compute, use raw offset
156                best_expected_cursor = expected_cursor_in_patch;
157            }
158        }
159
160        let disbalance_before = metrics::braces_disbalance(&original_text);
161        let disbalance_after = metrics::braces_disbalance(&actual_text);
162        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
163
164        // Compute exact lines match against best matching expected patch
165        let best_exact_lines = expected_patches_with_cursors
166            .iter()
167            .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
168            .max_by_key(|m| m.true_positives)
169            .unwrap_or_default();
170
171        // Compute reversal ratio
172        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
173            prompt_inputs,
174            &actual_text,
175            cursor_path,
176        );
177
178        // Compute cursor position metrics
179        let (cursor_distance, cursor_exact_match) =
180            compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
181
182        // Compute approximation of editable region correctness
183        let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
184
185        // Check for isolated whitespace changes.
186        let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
187            &actual_patch,
188            prediction.actual_cursor.as_ref(),
189        );
190
191        let (kept_rate, recall_rate) = best_expected_text
192            .map(|reference_text| {
193                let result =
194                    metrics::compute_kept_rate(original_text, &actual_text, reference_text);
195                (Some(result.kept_rate), Some(result.recall_rate))
196            })
197            .unwrap_or((None, None));
198
199        scores.push(ExampleScore {
200            delta_chr_f: best_delta_chr_f_metrics.score as f32,
201            delta_chr_f_true_positives: best_delta_chr_f_metrics.counts.true_positives,
202            delta_chr_f_false_positives: best_delta_chr_f_metrics.counts.false_positives,
203            delta_chr_f_false_negatives: best_delta_chr_f_metrics.counts.false_negatives,
204            delta_chr_f_precision: best_delta_chr_f_metrics.precision,
205            delta_chr_f_recall: best_delta_chr_f_metrics.recall,
206            delta_chr_f_beta: best_delta_chr_f_metrics.beta,
207            braces_disbalance,
208            exact_lines_tp: best_exact_lines.true_positives,
209            exact_lines_fp: best_exact_lines.false_positives,
210            exact_lines_fn: best_exact_lines.false_negatives,
211            reversal_ratio,
212            cursor_distance,
213            cursor_exact_match,
214            wrong_editable_region,
215            has_isolated_whitespace_changes,
216            inserted_tokens: token_changes.inserted_tokens,
217            deleted_tokens: token_changes.deleted_tokens,
218            kept_rate,
219            recall_rate,
220            cumulative_logprob: prediction.cumulative_logprob,
221            avg_logprob: prediction.avg_logprob,
222        });
223    }
224
225    example.score = scores;
226    Ok(())
227}
228
229fn compute_cursor_metrics(
230    expected_cursor_editable_region_offset: Option<usize>,
231    actual_cursor: Option<&ActualCursor>,
232) -> (Option<usize>, Option<bool>) {
233    match (expected_cursor_editable_region_offset, actual_cursor) {
234        (Some(expected), Some(actual)) => {
235            let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
236            let exact_match = distance == 0;
237            (Some(distance), Some(exact_match))
238        }
239        (None, None) => {
240            // Neither has cursor position - skip cursor scoring
241            (None, None)
242        }
243        (Some(_), None) | (None, Some(_)) => {
244            // Only one has cursor position - count as miss
245            (None, Some(false))
246        }
247    }
248}
249
250pub fn print_report(examples: &[Example], verbose: bool) {
251    const MAX_EXAMPLES_DEFAULT: usize = 20;
252    use crate::metrics::ClassificationMetrics;
253
254    const LINE_WIDTH: usize = 101;
255    let separator = "".repeat(LINE_WIDTH);
256
257    println!("{}", separator);
258    println!(
259        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
260        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
261    );
262    println!("{}", separator);
263
264    let mut all_delta_chr_f_scores = Vec::new();
265    let mut all_reversal_ratios = Vec::new();
266    let mut braces_disbalance_sum: usize = 0;
267    let mut total_delta_chr_f = ClassificationMetrics::default();
268    let mut total_delta_chr_f_precision = 0.0;
269    let mut total_delta_chr_f_recall = 0.0;
270    let mut delta_chr_f_beta = 0.0;
271    let mut total_exact_lines = ClassificationMetrics::default();
272    let mut total_scores: usize = 0;
273    let mut qa_reverts_count: usize = 0;
274    let mut qa_reverts_total: usize = 0;
275    let mut qa_confidence_sum: u64 = 0;
276    let mut qa_confidence_count: usize = 0;
277    let mut cursor_exact_matches: usize = 0;
278    let mut cursor_total: usize = 0;
279    let mut cursor_distance_sum: usize = 0;
280    let mut cursor_distance_count: usize = 0;
281    let mut wrong_editable_region_count: usize = 0;
282    let mut wrong_editable_region_total: usize = 0;
283    let mut isolated_whitespace_count: usize = 0;
284    let mut kept_rate_sum: f64 = 0.0;
285    let mut kept_rate_count: usize = 0;
286    let mut recall_rate_sum: f64 = 0.0;
287    let mut recall_rate_count: usize = 0;
288    let mut patch_inserted_tokens: Vec<usize> = Vec::new();
289    let mut patch_deleted_tokens: Vec<usize> = Vec::new();
290    let mut predictions_with_patch: usize = 0;
291
292    let mut printed_lines: usize = 0;
293    let mut skipped_lines: usize = 0;
294
295    for example in examples {
296        for (score_idx, score) in example.score.iter().enumerate() {
297            let exact_lines = score.exact_lines_counts();
298
299            // Get QA results for this prediction if available
300            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
301            let qa_reverts_str = qa_result
302                .and_then(|q| q.reverts_edits)
303                .map(|v| if v { "yes" } else { "no" })
304                .unwrap_or("-");
305            let qa_conf_str = qa_result
306                .and_then(|q| q.confidence)
307                .map(|v| format!("{}", v))
308                .unwrap_or("-".to_string());
309
310            // Format wrong editable region metric
311            let wrong_er_str = match score.wrong_editable_region {
312                Some(true) => "",
313                Some(false) => "",
314                None => "",
315            };
316
317            // Format cursor metric
318            let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
319                (Some(true), _) => "".to_string(),
320                (Some(false), Some(dist)) => format!("±{}", dist),
321                (Some(false), None) => "".to_string(),
322                (None, _) => "-".to_string(),
323            };
324
325            if verbose || printed_lines < MAX_EXAMPLES_DEFAULT {
326                println!(
327                    "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
328                    truncate_name(&example.spec.name, 40),
329                    score.delta_chr_f,
330                    score.braces_disbalance,
331                    exact_lines.f1() * 100.0,
332                    score.reversal_ratio * 100.0,
333                    qa_reverts_str,
334                    qa_conf_str,
335                    cursor_str,
336                    wrong_er_str
337                );
338                printed_lines += 1;
339            } else {
340                skipped_lines += 1;
341            }
342
343            all_delta_chr_f_scores.push(score.delta_chr_f);
344            all_reversal_ratios.push(score.reversal_ratio);
345            total_scores += 1;
346            braces_disbalance_sum += score.braces_disbalance;
347            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
348            total_delta_chr_f_precision += score.delta_chr_f_precision;
349            total_delta_chr_f_recall += score.delta_chr_f_recall;
350            delta_chr_f_beta = score.delta_chr_f_beta;
351            total_exact_lines.accumulate(&score.exact_lines_counts());
352
353            // Accumulate QA metrics
354            if let Some(qa) = qa_result {
355                if let Some(reverts) = qa.reverts_edits {
356                    qa_reverts_total += 1;
357                    if reverts {
358                        qa_reverts_count += 1;
359                    }
360                }
361                if let Some(conf) = qa.confidence {
362                    qa_confidence_sum += conf as u64;
363                    qa_confidence_count += 1;
364                }
365            }
366
367            // Accumulate wrong editable region metrics
368            if let Some(wrong) = score.wrong_editable_region {
369                wrong_editable_region_total += 1;
370                if wrong {
371                    wrong_editable_region_count += 1;
372                }
373            }
374
375            // Accumulate isolated whitespace metrics
376            if score.has_isolated_whitespace_changes {
377                isolated_whitespace_count += 1;
378            }
379
380            // Accumulate kept and recall rate metrics
381            if let Some(kr) = score.kept_rate {
382                kept_rate_sum += kr;
383                kept_rate_count += 1;
384            }
385            if let Some(rr) = score.recall_rate {
386                recall_rate_sum += rr;
387                recall_rate_count += 1;
388            }
389
390            // Accumulate token change metrics (only for predictions that produced a patch)
391            let has_patch = example
392                .predictions
393                .get(score_idx)
394                .and_then(|p| p.actual_patch.as_ref())
395                .is_some_and(|p| !p.is_empty());
396            if has_patch {
397                predictions_with_patch += 1;
398                patch_inserted_tokens.push(score.inserted_tokens);
399                patch_deleted_tokens.push(score.deleted_tokens);
400            }
401
402            // Accumulate cursor metrics
403            if let Some(exact_match) = score.cursor_exact_match {
404                cursor_total += 1;
405                if exact_match {
406                    cursor_exact_matches += 1;
407                }
408            }
409            if let Some(dist) = score.cursor_distance {
410                cursor_distance_sum += dist;
411                cursor_distance_count += 1;
412            }
413        }
414    }
415
416    if skipped_lines > 0 {
417        println!(
418            "{:<40} (use --verbose to see all {} examples)",
419            format!("... and {} more", skipped_lines),
420            printed_lines + skipped_lines
421        );
422    }
423    println!("{}", separator);
424
425    if !all_delta_chr_f_scores.is_empty() {
426        let avg_delta_chr_f: f32 =
427            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
428        let avg_reversal_ratio: f32 =
429            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
430        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
431
432        let qa_reverts_str = if qa_reverts_total > 0 {
433            format!(
434                "{:.1}%",
435                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
436            )
437        } else {
438            "-".to_string()
439        };
440        let qa_conf_str = if qa_confidence_count > 0 {
441            format!(
442                "{:.1}",
443                qa_confidence_sum as f32 / qa_confidence_count as f32
444            )
445        } else {
446            "-".to_string()
447        };
448        let cursor_str = if cursor_total > 0 {
449            format!(
450                "{:.0}%",
451                cursor_exact_matches as f32 / cursor_total as f32 * 100.0
452            )
453        } else {
454            "-".to_string()
455        };
456        let wrong_er_str = if wrong_editable_region_total > 0 {
457            format!(
458                "{:.2}%",
459                wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
460            )
461        } else {
462            "-".to_string()
463        };
464        let isolated_ws_str = if total_scores > 0 {
465            format!(
466                "{}/{} ({:.1}%)",
467                isolated_whitespace_count,
468                total_scores,
469                isolated_whitespace_count as f32 / total_scores as f32 * 100.0
470            )
471        } else {
472            "-".to_string()
473        };
474        let avg_cursor_distance = if cursor_distance_count > 0 {
475            Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
476        } else {
477            None
478        };
479
480        println!(
481            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
482            "TOTAL / AVERAGE",
483            avg_delta_chr_f,
484            braces_disbalance_avg,
485            total_exact_lines.f1() * 100.0,
486            avg_reversal_ratio * 100.0,
487            qa_reverts_str,
488            qa_conf_str,
489            cursor_str,
490            wrong_er_str
491        );
492        println!("{}", separator);
493        println!(
494            "Delta chrF (β={:.1}): TP={}, FP={}, FN={}, P={:.1}%, R={:.1}%",
495            delta_chr_f_beta,
496            total_delta_chr_f.true_positives,
497            total_delta_chr_f.false_positives,
498            total_delta_chr_f.false_negatives,
499            total_delta_chr_f_precision / total_scores as f64 * 100.0,
500            total_delta_chr_f_recall / total_scores as f64 * 100.0
501        );
502
503        // Print additional cursor metrics if available
504        if let Some(avg_dist) = avg_cursor_distance {
505            println!(
506                "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
507                cursor_exact_matches,
508                cursor_total,
509                cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
510                avg_dist
511            );
512        }
513
514        // Print isolated whitespace metrics
515        if total_scores > 0 {
516            println!("Isolated whitespace changes: {}", isolated_ws_str);
517        }
518
519        // Print kept and recall rate metrics
520        if kept_rate_count > 0 {
521            let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
522            println!(
523                "Kept rate: {:.1}% avg ({} evaluated)",
524                avg_kept_rate * 100.0,
525                kept_rate_count
526            );
527        }
528        if recall_rate_count > 0 {
529            let avg_recall_rate = recall_rate_sum / recall_rate_count as f64;
530            println!(
531                "Recall rate: {:.1}% avg ({} evaluated)",
532                avg_recall_rate * 100.0,
533                recall_rate_count
534            );
535        }
536
537        // Print token change percentile summary (only for predictions with a patch)
538        if !patch_inserted_tokens.is_empty() {
539            patch_inserted_tokens.sort_unstable();
540            patch_deleted_tokens.sort_unstable();
541            let mut patch_total_tokens: Vec<usize> = patch_inserted_tokens
542                .iter()
543                .zip(patch_deleted_tokens.iter())
544                .map(|(i, d)| i + d)
545                .collect();
546            patch_total_tokens.sort_unstable();
547
548            let patch_rate = predictions_with_patch as f32 / total_scores as f32 * 100.0;
549            println!();
550            println!(
551                "Token changes ({}/{} predictions produced a patch, {:.1}% — table includes only those)",
552                predictions_with_patch, total_scores, patch_rate
553            );
554            println!(
555                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
556                "", "p25", "p50", "p75", "p90", "p99"
557            );
558            println!("{}", "".repeat(LINE_WIDTH));
559            println!(
560                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
561                "Inserted tokens",
562                percentile(&patch_inserted_tokens, 25),
563                percentile(&patch_inserted_tokens, 50),
564                percentile(&patch_inserted_tokens, 75),
565                percentile(&patch_inserted_tokens, 90),
566                percentile(&patch_inserted_tokens, 99),
567            );
568            println!(
569                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
570                "Deleted tokens",
571                percentile(&patch_deleted_tokens, 25),
572                percentile(&patch_deleted_tokens, 50),
573                percentile(&patch_deleted_tokens, 75),
574                percentile(&patch_deleted_tokens, 90),
575                percentile(&patch_deleted_tokens, 99),
576            );
577            println!(
578                "{:<20} {:>8} {:>8} {:>8} {:>8} {:>8}",
579                "Total tokens",
580                percentile(&patch_total_tokens, 25),
581                percentile(&patch_total_tokens, 50),
582                percentile(&patch_total_tokens, 75),
583                percentile(&patch_total_tokens, 90),
584                percentile(&patch_total_tokens, 99),
585            );
586        }
587    }
588
589    println!("\n");
590}
591
/// Returns the `p`-th percentile of an ascending-sorted slice.
///
/// The index is `p%` of the way through `0..=len-1`, rounded to the nearest
/// index and clamped to the last element (so `p > 100` yields the maximum).
/// An empty slice yields 0.
fn percentile(sorted_values: &[usize], p: usize) -> usize {
    let Some(last_index) = sorted_values.len().checked_sub(1) else {
        return 0;
    };
    let fraction = p as f64 / 100.0;
    let position = (fraction * last_index as f64).round() as usize;
    sorted_values[position.min(last_index)]
}
599
/// Shortens `name` to at most `max_len` characters for a fixed-width column.
///
/// Names that fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. Lengths are counted in
/// `char`s, which is also how `{:<N}` padding measures width, and the cut never
/// splits a multi-byte UTF-8 character. When `max_len < 3` there is no room
/// for the ellipsis, so the name is simply cut to `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        return name.to_string();
    }
    match max_len.checked_sub(3) {
        Some(keep) => {
            let prefix: String = name.chars().take(keep).collect();
            format!("{}...", prefix)
        }
        None => name.chars().take(max_len).collect(),
    }
}
607
608#[derive(Serialize)]
609pub struct SummaryJson {
610    pub total_examples: usize,
611    pub avg_delta_chr_f: f32,
612    pub delta_chr_f_beta: f64,
613    pub delta_chr_f_true_positives: usize,
614    pub delta_chr_f_false_positives: usize,
615    pub delta_chr_f_false_negatives: usize,
616    pub delta_chr_f_precision: f64,
617    pub delta_chr_f_recall: f64,
618    pub avg_braces_disbalance: f32,
619    pub exact_lines_true_positives: usize,
620    pub exact_lines_false_positives: usize,
621    pub exact_lines_false_negatives: usize,
622    pub exact_lines_precision: f64,
623    pub exact_lines_recall: f64,
624    pub exact_lines_f1: f64,
625    pub avg_reversal_ratio: f32,
626    #[serde(skip_serializing_if = "Option::is_none")]
627    pub qa_avg_reverts_edits: Option<f32>,
628    #[serde(skip_serializing_if = "Option::is_none")]
629    pub qa_avg_confidence: Option<f32>,
630    #[serde(skip_serializing_if = "Option::is_none")]
631    pub cursor_exact_match_rate: Option<f32>,
632    #[serde(skip_serializing_if = "Option::is_none")]
633    pub cursor_avg_distance: Option<f32>,
634    #[serde(skip_serializing_if = "Option::is_none")]
635    pub cursor_total_evaluated: Option<usize>,
636    #[serde(skip_serializing_if = "Option::is_none")]
637    pub wrong_editable_region_rate: Option<f32>,
638    pub isolated_whitespace_rate: Option<f32>,
639    #[serde(skip_serializing_if = "Option::is_none")]
640    pub avg_kept_rate: Option<f64>,
641    #[serde(skip_serializing_if = "Option::is_none")]
642    pub avg_recall_rate: Option<f64>,
643}
644
645pub fn compute_summary(examples: &[Example]) -> SummaryJson {
646    use crate::metrics::ClassificationMetrics;
647
648    let mut all_delta_chr_f_scores = Vec::new();
649    let mut all_reversal_ratios = Vec::new();
650    let mut braces_disbalance_sum: usize = 0;
651    let mut total_delta_chr_f = ClassificationMetrics::default();
652    let mut total_delta_chr_f_precision = 0.0;
653    let mut total_delta_chr_f_recall = 0.0;
654    let mut delta_chr_f_beta = 0.0;
655    let mut total_exact_lines = ClassificationMetrics::default();
656    let mut total_scores: usize = 0;
657    let mut qa_reverts_count: usize = 0;
658    let mut qa_reverts_total: usize = 0;
659    let mut qa_confidence_sum: u64 = 0;
660    let mut qa_confidence_count: usize = 0;
661    let mut cursor_exact_matches: usize = 0;
662    let mut cursor_total: usize = 0;
663    let mut cursor_distance_sum: usize = 0;
664    let mut cursor_distance_count: usize = 0;
665    let mut wrong_editable_region_count: usize = 0;
666    let mut wrong_editable_region_total: usize = 0;
667    let mut isolated_whitespace_count: usize = 0;
668    let mut kept_rate_sum: f64 = 0.0;
669    let mut kept_rate_count: usize = 0;
670    let mut recall_rate_sum: f64 = 0.0;
671    let mut recall_rate_count: usize = 0;
672
673    for example in examples {
674        for (score_idx, score) in example.score.iter().enumerate() {
675            all_delta_chr_f_scores.push(score.delta_chr_f);
676            all_reversal_ratios.push(score.reversal_ratio);
677            total_scores += 1;
678            braces_disbalance_sum += score.braces_disbalance;
679            total_delta_chr_f.accumulate(&score.delta_chr_f_counts());
680            total_delta_chr_f_precision += score.delta_chr_f_precision;
681            total_delta_chr_f_recall += score.delta_chr_f_recall;
682            delta_chr_f_beta = score.delta_chr_f_beta;
683            total_exact_lines.accumulate(&score.exact_lines_counts());
684
685            // Accumulate QA metrics
686            if let Some(Some(qa)) = example.qa.get(score_idx) {
687                if let Some(reverts) = qa.reverts_edits {
688                    qa_reverts_total += 1;
689                    if reverts {
690                        qa_reverts_count += 1;
691                    }
692                }
693                if let Some(conf) = qa.confidence {
694                    qa_confidence_sum += conf as u64;
695                    qa_confidence_count += 1;
696                }
697            }
698
699            // Accumulate wrong editable region metrics
700            if let Some(wrong) = score.wrong_editable_region {
701                wrong_editable_region_total += 1;
702                if wrong {
703                    wrong_editable_region_count += 1;
704                }
705            }
706
707            // Accumulate isolated whitespace metrics
708            if score.has_isolated_whitespace_changes {
709                isolated_whitespace_count += 1;
710            }
711
712            // Accumulate kept and recall rate metrics
713            if let Some(kr) = score.kept_rate {
714                kept_rate_sum += kr;
715                kept_rate_count += 1;
716            }
717            if let Some(rr) = score.recall_rate {
718                recall_rate_sum += rr;
719                recall_rate_count += 1;
720            }
721
722            // Accumulate cursor metrics
723            if let Some(exact_match) = score.cursor_exact_match {
724                cursor_total += 1;
725                if exact_match {
726                    cursor_exact_matches += 1;
727                }
728            }
729            if let Some(dist) = score.cursor_distance {
730                cursor_distance_sum += dist;
731                cursor_distance_count += 1;
732            }
733        }
734    }
735
736    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
737        0.0
738    } else {
739        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
740    };
741
742    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
743        0.0
744    } else {
745        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
746    };
747
748    let avg_braces_disbalance = if total_scores == 0 {
749        0.0
750    } else {
751        braces_disbalance_sum as f32 / total_scores as f32
752    };
753
754    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
755        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
756    } else {
757        None
758    };
759
760    let qa_avg_confidence = if qa_confidence_count > 0 {
761        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
762    } else {
763        None
764    };
765
766    let cursor_exact_match_rate = if cursor_total > 0 {
767        Some(cursor_exact_matches as f32 / cursor_total as f32)
768    } else {
769        None
770    };
771
772    let cursor_avg_distance = if cursor_distance_count > 0 {
773        Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
774    } else {
775        None
776    };
777
778    let cursor_total_evaluated = if cursor_total > 0 {
779        Some(cursor_total)
780    } else {
781        None
782    };
783
784    let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
785        Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
786    } else {
787        None
788    };
789
790    let isolated_whitespace_rate = if total_scores > 0 {
791        Some(isolated_whitespace_count as f32 / total_scores as f32)
792    } else {
793        None
794    };
795
796    let avg_kept_rate = if kept_rate_count > 0 {
797        Some(kept_rate_sum / kept_rate_count as f64)
798    } else {
799        None
800    };
801
802    let avg_recall_rate = if recall_rate_count > 0 {
803        Some(recall_rate_sum / recall_rate_count as f64)
804    } else {
805        None
806    };
807
808    SummaryJson {
809        total_examples: total_scores,
810        avg_delta_chr_f,
811        delta_chr_f_beta,
812        delta_chr_f_true_positives: total_delta_chr_f.true_positives,
813        delta_chr_f_false_positives: total_delta_chr_f.false_positives,
814        delta_chr_f_false_negatives: total_delta_chr_f.false_negatives,
815        delta_chr_f_precision: if total_scores == 0 {
816            0.0
817        } else {
818            total_delta_chr_f_precision / total_scores as f64
819        },
820        delta_chr_f_recall: if total_scores == 0 {
821            0.0
822        } else {
823            total_delta_chr_f_recall / total_scores as f64
824        },
825        avg_braces_disbalance,
826        exact_lines_true_positives: total_exact_lines.true_positives,
827        exact_lines_false_positives: total_exact_lines.false_positives,
828        exact_lines_false_negatives: total_exact_lines.false_negatives,
829        exact_lines_precision: total_exact_lines.precision(),
830        exact_lines_recall: total_exact_lines.recall(),
831        exact_lines_f1: total_exact_lines.f1(),
832        avg_reversal_ratio,
833        qa_avg_reverts_edits,
834        qa_avg_confidence,
835        cursor_exact_match_rate,
836        cursor_avg_distance,
837        cursor_total_evaluated,
838        wrong_editable_region_rate,
839        isolated_whitespace_rate,
840        avg_kept_rate,
841        avg_recall_rate,
842    }
843}
844
845pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
846    let summary = compute_summary(examples);
847    let file = File::create(path)
848        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
849    let writer = BufWriter::new(file);
850    serde_json::to_writer_pretty(writer, &summary)
851        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
852    eprintln!("Wrote summary JSON to: {}", path.display());
853    Ok(())
854}