score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics::{self, ClassificationMetrics},
  6    predict::run_prediction,
  7    progress::{Progress, Step},
  8};
  9use edit_prediction::udiff::DiffLine;
 10use gpui::AsyncApp;
 11use std::sync::Arc;
 12
 13pub async fn run_scoring(
 14    example: &mut Example,
 15    args: &PredictArgs,
 16    app_state: Arc<EpAppState>,
 17    cx: AsyncApp,
 18) -> anyhow::Result<()> {
 19    run_prediction(
 20        example,
 21        Some(args.provider),
 22        args.repetitions,
 23        app_state,
 24        cx,
 25    )
 26    .await?;
 27
 28    let _progress = Progress::global().start(Step::Score, &example.spec.name);
 29
 30    let expected_patch = parse_patch(&example.spec.expected_patch);
 31
 32    let mut scores = vec![];
 33
 34    for pred in &example.predictions {
 35        let actual_patch = parse_patch(&pred.actual_patch);
 36        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
 37        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
 38
 39        scores.push(ExampleScore {
 40            delta_chr_f,
 41            line_match,
 42        });
 43    }
 44
 45    example.score = scores;
 46    Ok(())
 47}
 48
 49fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
 50    patch.lines().map(DiffLine::parse).collect()
 51}
 52
 53pub fn print_report(examples: &[Example]) {
 54    eprintln!(
 55        "──────────────────────────────────────────────────────────────────────────────────────"
 56    );
 57    eprintln!(
 58        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
 59        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
 60    );
 61    eprintln!(
 62        "──────────────────────────────────────────────────────────────────────────────────────"
 63    );
 64
 65    let mut all_line_match_scores = Vec::new();
 66    let mut all_delta_chr_f_scores = Vec::new();
 67
 68    for example in examples {
 69        for score in example.score.iter() {
 70            let line_match = &score.line_match;
 71
 72            eprintln!(
 73                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 74                truncate_name(&example.spec.name, 30),
 75                line_match.true_positives,
 76                line_match.false_positives,
 77                line_match.false_negatives,
 78                line_match.precision() * 100.0,
 79                line_match.recall() * 100.0,
 80                line_match.f1_score() * 100.0,
 81                score.delta_chr_f
 82            );
 83
 84            all_line_match_scores.push(line_match.clone());
 85            all_delta_chr_f_scores.push(score.delta_chr_f);
 86        }
 87    }
 88
 89    eprintln!(
 90        "──────────────────────────────────────────────────────────────────────────────────────"
 91    );
 92
 93    if !all_line_match_scores.is_empty() {
 94        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
 95        let avg_delta_chr_f: f32 =
 96            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 97
 98        eprintln!(
 99            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
100            "TOTAL",
101            total_line_match.true_positives,
102            total_line_match.false_positives,
103            total_line_match.false_negatives,
104            total_line_match.precision() * 100.0,
105            total_line_match.recall() * 100.0,
106            total_line_match.f1_score() * 100.0,
107            avg_delta_chr_f
108        );
109        eprintln!(
110            "──────────────────────────────────────────────────────────────────────────────────────"
111        );
112    }
113
114    eprintln!("\n");
115}
116
/// Shortens `name` to at most `max_len` characters for fixed-width table output.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. If `max_len` is too small to hold
/// the ellipsis, the name is cut to `max_len` characters with no ellipsis.
///
/// Lengths are counted in `char`s, not bytes, for two reasons:
/// - non-ASCII names are never sliced inside a UTF-8 sequence, which would panic;
/// - the result lines up with `format!` width padding, which also counts `char`s.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // How many chars of the name to keep, and what to append after them.
    let (keep, suffix) = match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => (keep, ELLIPSIS),
        None => (max_len, ""),
    };

    // Byte offset of the first dropped char; `char_indices` guarantees a char boundary.
    let cut = name
        .char_indices()
        .nth(keep)
        .map_or(name.len(), |(index, _)| index);

    format!("{}{suffix}", &name[..cut])
}