score.rs

use crate::{
    PredictArgs,
    example::{Example, ExampleScore},
    headless::EpAppState,
    metrics,
    parse_output::parse_prediction_output,
    predict::run_prediction,
    progress::{ExampleProgress, Step},
};
use anyhow::Context as _;
use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use std::sync::Arc;

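/// Runs prediction for the example, then scores every prediction against the
/// example's expected patches.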
pub async fn run_scoring(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    cx: AsyncApp,
) -> anyhow::Result<()> {
    run_prediction(example, args, app_state, example_progress, cx).await?;

    let progress = example_progress.start(Step::Score);

    progress.set_substatus("applying patches");
    let original_text = &example
        .prompt_inputs
        .as_ref()
        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
        .content;
    let expected_texts: Vec<String> = example
        .spec
        .expected_patches
        .iter()
        .map(|patch| {
            apply_diff_to_string(patch, original_text)
                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Fallback score assigned when a prediction cannot be parsed or its patch fails to apply.
    let zero_scores = ExampleScore {
        delta_chr_f: 0.0,
        braces_disbalance: 0,
        exact_lines_tp: 0,
        exact_lines_fp: 0,
        exact_lines_fn: 0,
    };

    progress.set_substatus("computing metrics");
    let mut scores = vec![];
    for prediction in &example.predictions {
        let actual_patch = prediction.actual_patch.clone().or_else(|| {
            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
        });

        let Some(actual_patch) = actual_patch else {
            scores.push(zero_scores.clone());
            continue;
        };

        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
            Ok(text) => text,
            Err(_) => {
                scores.push(zero_scores.clone());
                continue;
            }
        };

        // Take the best chrF delta across all expected texts.
        let best_delta_chr_f = expected_texts
            .iter()
            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
            .fold(0.0, f32::max);

        let disbalance_before = metrics::braces_disbalance(original_text);
        let disbalance_after = metrics::braces_disbalance(&actual_text);
        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
        if braces_disbalance > 0 {
            // Dump before/after counts and texts to /tmp for manual inspection
            // when the prediction introduces a brace imbalance.
            std::fs::write(
                "/tmp/unbalanced-count.before",
                disbalance_before.to_string(),
            )
            .ok();
            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
            std::fs::write("/tmp/unbalanced-text.before", original_text).ok();
            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
        }

        // Compute exact-lines match against the best matching expected patch.
        let best_exact_lines = example
            .spec
            .expected_patches
            .iter()
            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
            .max_by_key(|m| m.true_positives)
            .unwrap_or_default();

        scores.push(ExampleScore {
            delta_chr_f: best_delta_chr_f,
            braces_disbalance,
            exact_lines_tp: best_exact_lines.true_positives,
            exact_lines_fp: best_exact_lines.false_positives,
            exact_lines_fn: best_exact_lines.false_negatives,
        });
    }

    example.score = scores;
    Ok(())
}

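/// Prints a per-prediction score table to stderr, followed by a totals/averages row.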
pub fn print_report(examples: &[Example]) {
    use crate::metrics::ClassificationMetrics;

    const LINE_WIDTH: usize = 100;
    let separator = "-".repeat(LINE_WIDTH);

    eprintln!("{}", separator);
    eprintln!(
        "{:<40} {:>9} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
        "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
    );
    eprintln!("{}", separator);

    let mut all_delta_chr_f_scores = Vec::new();
    let mut braces_disbalance_sum: usize = 0;
    let mut total_exact_lines = ClassificationMetrics::default();
    let mut total_scores: usize = 0;

    for example in examples {
        for score in example.score.iter() {
            let exact_lines = ClassificationMetrics {
                true_positives: score.exact_lines_tp,
                false_positives: score.exact_lines_fp,
                false_negatives: score.exact_lines_fn,
            };

            eprintln!(
                "{:<40} {:>9.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
                truncate_name(&example.spec.name, 40),
                score.delta_chr_f,
                score.braces_disbalance,
                score.exact_lines_tp,
                score.exact_lines_fp,
                score.exact_lines_fn,
                exact_lines.precision() * 100.0,
                exact_lines.recall() * 100.0,
                exact_lines.f1() * 100.0
            );

            // Accumulate per-prediction scores for the summary row.
            all_delta_chr_f_scores.push(score.delta_chr_f);
            total_scores += 1;
            braces_disbalance_sum += score.braces_disbalance;
            total_exact_lines.true_positives += score.exact_lines_tp;
            total_exact_lines.false_positives += score.exact_lines_fp;
            total_exact_lines.false_negatives += score.exact_lines_fn;
        }
    }

    eprintln!("{}", separator);

    if !all_delta_chr_f_scores.is_empty() {
        let avg_delta_chr_f: f32 =
            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;

        eprintln!(
            "{:<40} {:>9.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
            "TOTAL / AVERAGE",
            avg_delta_chr_f,
            braces_disbalance_avg,
            total_exact_lines.true_positives,
            total_exact_lines.false_positives,
            total_exact_lines.false_negatives,
            total_exact_lines.precision() * 100.0,
            total_exact_lines.recall() * 100.0,
            total_exact_lines.f1() * 100.0
        );
        eprintln!("{}", separator);
    }

    eprintln!("\n");
}

fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        name.to_string()
    } else {
        // Truncate on a char boundary so multi-byte names can't cause a slice panic.
        let truncated: String = name.chars().take(max_len.saturating_sub(3)).collect();
        format!("{}...", truncated)
    }
}