score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics::{self, ClassificationMetrics},
  6    predict::run_prediction,
  7    progress::{Progress, Step},
  8};
  9use edit_prediction::udiff::DiffLine;
 10use gpui::AsyncApp;
 11use std::sync::Arc;
 12
 13pub async fn run_scoring(
 14    example: &mut Example,
 15    args: &PredictArgs,
 16    app_state: Arc<EpAppState>,
 17    cx: AsyncApp,
 18) {
 19    run_prediction(
 20        example,
 21        Some(args.provider),
 22        args.repetitions,
 23        app_state,
 24        cx,
 25    )
 26    .await;
 27
 28    let _progress = Progress::global().start(Step::Score, &example.name);
 29
 30    let expected_patch = parse_patch(&example.expected_patch);
 31
 32    let mut scores = vec![];
 33
 34    for pred in &example.predictions {
 35        let actual_patch = parse_patch(&pred.actual_patch);
 36        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
 37        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
 38
 39        scores.push(ExampleScore {
 40            delta_chr_f,
 41            line_match,
 42        });
 43    }
 44
 45    example.score = scores;
 46}
 47
 48fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
 49    patch.lines().map(DiffLine::parse).collect()
 50}
 51
 52pub fn print_report(examples: &[Example]) {
 53    eprintln!(
 54        "──────────────────────────────────────────────────────────────────────────────────────"
 55    );
 56    eprintln!(
 57        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
 58        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
 59    );
 60    eprintln!(
 61        "──────────────────────────────────────────────────────────────────────────────────────"
 62    );
 63
 64    let mut all_line_match_scores = Vec::new();
 65    let mut all_delta_chr_f_scores = Vec::new();
 66
 67    for example in examples {
 68        for score in example.score.iter() {
 69            let line_match = &score.line_match;
 70
 71            eprintln!(
 72                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 73                truncate_name(&example.name, 30),
 74                line_match.true_positives,
 75                line_match.false_positives,
 76                line_match.false_negatives,
 77                line_match.precision() * 100.0,
 78                line_match.recall() * 100.0,
 79                line_match.f1_score() * 100.0,
 80                score.delta_chr_f
 81            );
 82
 83            all_line_match_scores.push(line_match.clone());
 84            all_delta_chr_f_scores.push(score.delta_chr_f);
 85        }
 86    }
 87
 88    eprintln!(
 89        "──────────────────────────────────────────────────────────────────────────────────────"
 90    );
 91
 92    if !all_line_match_scores.is_empty() {
 93        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
 94        let avg_delta_chr_f: f32 =
 95            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 96
 97        eprintln!(
 98            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 99            "TOTAL",
100            total_line_match.true_positives,
101            total_line_match.false_positives,
102            total_line_match.false_negatives,
103            total_line_match.precision() * 100.0,
104            total_line_match.recall() * 100.0,
105            total_line_match.f1_score() * 100.0,
106            avg_delta_chr_f
107        );
108        eprintln!(
109            "──────────────────────────────────────────────────────────────────────────────────────"
110        );
111    }
112
113    eprintln!("\n");
114}
115
/// Shortens `name` to at most `max_len` characters for display in a fixed-width column.
///
/// Lengths are measured in `char`s, not bytes. Byte slicing could panic when the cut
/// falls inside a multi-byte UTF-8 character, and `{:<N}` formatting pads by character
/// count, so counting characters keeps the table aligned.
///
/// A name that fits is returned unchanged. Otherwise it is cut and suffixed with
/// `...`, keeping the total at `max_len` characters. When `max_len` is too small to
/// hold the ellipsis, the name is simply cut to `max_len` characters, with no
/// arithmetic underflow.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // No room for the ellipsis: return a plain prefix instead of underflowing.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}