score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics::{self, ClassificationMetrics},
  6    predict::run_prediction,
  7};
  8use edit_prediction::udiff::DiffLine;
  9use gpui::AsyncApp;
 10use std::sync::Arc;
 11
 12pub async fn run_scoring(
 13    example: &mut Example,
 14    args: &PredictArgs,
 15    app_state: Arc<EpAppState>,
 16    cx: AsyncApp,
 17) {
 18    run_prediction(
 19        example,
 20        Some(args.provider),
 21        args.repetitions,
 22        app_state,
 23        cx,
 24    )
 25    .await;
 26
 27    let expected_patch = parse_patch(&example.expected_patch);
 28
 29    let mut scores = vec![];
 30
 31    for pred in &example.predictions {
 32        let actual_patch = parse_patch(&pred.actual_patch);
 33        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
 34        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
 35
 36        scores.push(ExampleScore {
 37            delta_chr_f,
 38            line_match,
 39        });
 40    }
 41
 42    example.score = scores;
 43}
 44
 45fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
 46    patch.lines().map(DiffLine::parse).collect()
 47}
 48
 49pub fn print_report(examples: &[Example]) {
 50    eprintln!(
 51        "──────────────────────────────────────────────────────────────────────────────────────"
 52    );
 53    eprintln!(
 54        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
 55        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
 56    );
 57    eprintln!(
 58        "──────────────────────────────────────────────────────────────────────────────────────"
 59    );
 60
 61    let mut all_line_match_scores = Vec::new();
 62    let mut all_delta_chr_f_scores = Vec::new();
 63
 64    for example in examples {
 65        for score in example.score.iter() {
 66            let line_match = &score.line_match;
 67
 68            eprintln!(
 69                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 70                truncate_name(&example.name, 30),
 71                line_match.true_positives,
 72                line_match.false_positives,
 73                line_match.false_negatives,
 74                line_match.precision() * 100.0,
 75                line_match.recall() * 100.0,
 76                line_match.f1_score() * 100.0,
 77                score.delta_chr_f
 78            );
 79
 80            all_line_match_scores.push(line_match.clone());
 81            all_delta_chr_f_scores.push(score.delta_chr_f);
 82        }
 83    }
 84
 85    eprintln!(
 86        "──────────────────────────────────────────────────────────────────────────────────────"
 87    );
 88
 89    if !all_line_match_scores.is_empty() {
 90        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
 91        let avg_delta_chr_f: f32 =
 92            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 93
 94        eprintln!(
 95            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 96            "TOTAL",
 97            total_line_match.true_positives,
 98            total_line_match.false_positives,
 99            total_line_match.false_negatives,
100            total_line_match.precision() * 100.0,
101            total_line_match.recall() * 100.0,
102            total_line_match.f1_score() * 100.0,
103            avg_delta_chr_f
104        );
105        eprintln!(
106            "──────────────────────────────────────────────────────────────────────────────────────"
107        );
108    }
109
110    eprintln!("\n");
111}
112
/// Shortens `name` to at most `max_len` characters for display in a
/// fixed-width column.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `...` so that the result is exactly `max_len` characters.
/// When `max_len` is too small to hold the ellipsis (fewer than three
/// characters), the name is simply cut to `max_len` characters.
///
/// Lengths are measured in `char`s rather than bytes: slicing by byte offset
/// panics when the cut lands inside a multi-byte UTF-8 character, and the
/// formatter's width padding counts characters as well.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Not enough room for any ellipsis: hard cut instead of underflowing.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}