diff --git a/crates/edit_prediction_cli/src/qa.rs b/crates/edit_prediction_cli/src/qa.rs index 59304bed825aabf37df48ead43d8d52525282946..d8ed61e9e1f4c8823a188e8917e74cc04042fec3 100644 --- a/crates/edit_prediction_cli/src/qa.rs +++ b/crates/edit_prediction_cli/src/qa.rs @@ -29,7 +29,7 @@ pub struct QaArgs { pub wait: bool, /// Which LLM provider to use (anthropic or openai) - #[clap(long, default_value = "anthropic")] + #[clap(long, default_value = "openai")] pub backend: BatchProvider, } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 010763e507475088cbe686fc7fbfc6a0e1427ad1..486e433ca0e9a69712023c418c06f331c758ec02 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -127,13 +127,13 @@ pub async fn run_scoring( pub fn print_report(examples: &[Example]) { use crate::metrics::ClassificationMetrics; - const LINE_WIDTH: usize = 110; + const LINE_WIDTH: usize = 82; let separator = "─".repeat(LINE_WIDTH); println!("{}", separator); println!( - "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7} {:>7}", - "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1", "Revert" + "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7}", + "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf" ); println!("{}", separator); @@ -142,27 +142,39 @@ pub fn print_report(examples: &[Example]) { let mut braces_disbalance_sum: usize = 0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; + let mut qa_reverts_count: usize = 0; + let mut qa_reverts_total: usize = 0; + let mut qa_confidence_sum: u64 = 0; + let mut qa_confidence_count: usize = 0; for example in examples { - for score in example.score.iter() { + for (score_idx, score) in example.score.iter().enumerate() { let exact_lines = ClassificationMetrics { true_positives: score.exact_lines_tp, false_positives: score.exact_lines_fp, false_negatives: score.exact_lines_fn, }; 
+ // Get QA results for this prediction if available + let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref()); + let qa_reverts_str = qa_result + .and_then(|q| q.reverts_edits) + .map(|v| if v { "yes" } else { "no" }) + .unwrap_or("-"); + let qa_conf_str = qa_result + .and_then(|q| q.confidence) + .map(|v| format!("{}", v)) + .unwrap_or("-".to_string()); + println!( - "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%", + "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7}", truncate_name(&example.spec.name, 40), score.delta_chr_f, score.braces_disbalance, - score.exact_lines_tp, - score.exact_lines_fp, - score.exact_lines_fn, - exact_lines.precision() * 100.0, - exact_lines.recall() * 100.0, exact_lines.f1() * 100.0, - score.reversal_ratio * 100.0 + score.reversal_ratio * 100.0, + qa_reverts_str, + qa_conf_str ); all_delta_chr_f_scores.push(score.delta_chr_f); @@ -172,6 +184,20 @@ pub fn print_report(examples: &[Example]) { total_exact_lines.true_positives += score.exact_lines_tp; total_exact_lines.false_positives += score.exact_lines_fp; total_exact_lines.false_negatives += score.exact_lines_fn; + + // Accumulate QA metrics + if let Some(qa) = qa_result { + if let Some(reverts) = qa.reverts_edits { + qa_reverts_total += 1; + if reverts { + qa_reverts_count += 1; + } + } + if let Some(conf) = qa.confidence { + qa_confidence_sum += conf as u64; + qa_confidence_count += 1; + } + } } } @@ -184,18 +210,32 @@ pub fn print_report(examples: &[Example]) { all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32; let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32; + let qa_reverts_str = if qa_reverts_total > 0 { + format!( + "{:.1}%", + qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0 + ) + } else { + "-".to_string() + }; + let qa_conf_str = if qa_confidence_count > 0 { + format!( + "{:.1}", + qa_confidence_sum as f32 / qa_confidence_count as f32 + ) + } else { + "-".to_string() + 
}; + println!( - "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%", + "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7}", "TOTAL / AVERAGE", avg_delta_chr_f, braces_disbalance_avg, - total_exact_lines.true_positives, - total_exact_lines.false_positives, - total_exact_lines.false_negatives, - total_exact_lines.precision() * 100.0, - total_exact_lines.recall() * 100.0, total_exact_lines.f1() * 100.0, - avg_reversal_ratio * 100.0 + avg_reversal_ratio * 100.0, + qa_reverts_str, + qa_conf_str ); println!("{}", separator); } @@ -223,6 +263,10 @@ pub struct SummaryJson { pub exact_lines_recall: f64, pub exact_lines_f1: f64, pub avg_reversal_ratio: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub qa_avg_reverts_edits: Option<f32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub qa_avg_confidence: Option<f32>, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -233,9 +277,13 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut braces_disbalance_sum: usize = 0; let mut total_exact_lines = ClassificationMetrics::default(); let mut total_scores: usize = 0; + let mut qa_reverts_count: usize = 0; + let mut qa_reverts_total: usize = 0; + let mut qa_confidence_sum: u64 = 0; + let mut qa_confidence_count: usize = 0; for example in examples { - for score in example.score.iter() { + for (score_idx, score) in example.score.iter().enumerate() { all_delta_chr_f_scores.push(score.delta_chr_f); all_reversal_ratios.push(score.reversal_ratio); total_scores += 1; @@ -243,6 +291,20 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { total_exact_lines.true_positives += score.exact_lines_tp; total_exact_lines.false_positives += score.exact_lines_fp; total_exact_lines.false_negatives += score.exact_lines_fn; + + // Accumulate QA metrics + if let Some(Some(qa)) = example.qa.get(score_idx) { + if let Some(reverts) = qa.reverts_edits { + qa_reverts_total += 1; + if reverts { + qa_reverts_count += 1; + } 
+ } + if let Some(conf) = qa.confidence { + qa_confidence_sum += conf as u64; + qa_confidence_count += 1; + } + } } } @@ -264,6 +326,18 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { braces_disbalance_sum as f32 / total_scores as f32 }; + let qa_avg_reverts_edits = if qa_reverts_total > 0 { + Some(qa_reverts_count as f32 / qa_reverts_total as f32) + } else { + None + }; + + let qa_avg_confidence = if qa_confidence_count > 0 { + Some(qa_confidence_sum as f32 / qa_confidence_count as f32) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -275,6 +349,8 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { exact_lines_recall: total_exact_lines.recall(), exact_lines_f1: total_exact_lines.f1(), avg_reversal_ratio, + qa_avg_reverts_edits, + qa_avg_confidence, } }