score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics::{self, ClassificationMetrics},
  6    predict::run_prediction,
  7    progress::{Progress, Step},
  8};
  9use edit_prediction::udiff::DiffLine;
 10use gpui::AsyncApp;
 11use std::sync::Arc;
 12
 13pub async fn run_scoring(
 14    example: &mut Example,
 15    args: &PredictArgs,
 16    app_state: Arc<EpAppState>,
 17    progress: Arc<Progress>,
 18    cx: AsyncApp,
 19) {
 20    run_prediction(
 21        example,
 22        Some(args.provider),
 23        args.repetitions,
 24        app_state,
 25        progress.clone(),
 26        cx,
 27    )
 28    .await;
 29
 30    let _progress = progress.start(Step::Score, &example.name);
 31
 32    let expected_patch = parse_patch(&example.expected_patch);
 33
 34    let mut scores = vec![];
 35
 36    for pred in &example.predictions {
 37        let actual_patch = parse_patch(&pred.actual_patch);
 38        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
 39        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
 40
 41        scores.push(ExampleScore {
 42            delta_chr_f,
 43            line_match,
 44        });
 45    }
 46
 47    example.score = scores;
 48}
 49
 50fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
 51    patch.lines().map(DiffLine::parse).collect()
 52}
 53
 54pub fn print_report(examples: &[Example]) {
 55    eprintln!(
 56        "──────────────────────────────────────────────────────────────────────────────────────"
 57    );
 58    eprintln!(
 59        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
 60        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
 61    );
 62    eprintln!(
 63        "──────────────────────────────────────────────────────────────────────────────────────"
 64    );
 65
 66    let mut all_line_match_scores = Vec::new();
 67    let mut all_delta_chr_f_scores = Vec::new();
 68
 69    for example in examples {
 70        for score in example.score.iter() {
 71            let line_match = &score.line_match;
 72
 73            eprintln!(
 74                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
 75                truncate_name(&example.name, 30),
 76                line_match.true_positives,
 77                line_match.false_positives,
 78                line_match.false_negatives,
 79                line_match.precision() * 100.0,
 80                line_match.recall() * 100.0,
 81                line_match.f1_score() * 100.0,
 82                score.delta_chr_f
 83            );
 84
 85            all_line_match_scores.push(line_match.clone());
 86            all_delta_chr_f_scores.push(score.delta_chr_f);
 87        }
 88    }
 89
 90    eprintln!(
 91        "──────────────────────────────────────────────────────────────────────────────────────"
 92    );
 93
 94    if !all_line_match_scores.is_empty() {
 95        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
 96        let avg_delta_chr_f: f32 =
 97            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 98
 99        eprintln!(
100            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
101            "TOTAL",
102            total_line_match.true_positives,
103            total_line_match.false_positives,
104            total_line_match.false_negatives,
105            total_line_match.precision() * 100.0,
106            total_line_match.recall() * 100.0,
107            total_line_match.f1_score() * 100.0,
108            avg_delta_chr_f
109        );
110        eprintln!(
111            "──────────────────────────────────────────────────────────────────────────────────────"
112        );
113    }
114
115    eprintln!("\n");
116}
117
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and end in `...`,
/// with the ellipsis counted towards `max_len`. If `max_len` is too small to hold the
/// ellipsis, the name is cut to `max_len` characters with no ellipsis.
///
/// Length is measured in characters rather than bytes, so the cut never lands inside a
/// multi-byte UTF-8 sequence (byte slicing would panic there), and the result matches the
/// character-based padding of `{:<N}` format specifiers.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => {
            let mut truncated: String = name.chars().take(keep).collect();
            truncated.push_str(ELLIPSIS);
            truncated
        }
        // No room for the ellipsis: fall back to a plain cut instead of underflowing.
        None => name.chars().take(max_len).collect(),
    }
}
124}