score.rs

use crate::{
    PredictArgs,
    example::{Example, ExampleScore},
    headless::EpAppState,
    metrics,
    parse_output::parse_prediction_output,
    predict::run_prediction,
    progress::{ExampleProgress, Step},
};
use anyhow::Context as _;
use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use serde::Serialize;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use std::sync::Arc;

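/// Runs prediction for `example`, then scores each prediction against the
/// example's expected patches, storing one `ExampleScore` per prediction in
/// `example.score`.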
pub async fn run_scoring(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    cx: AsyncApp,
) -> anyhow::Result<()> {
    run_prediction(example, args, app_state, example_progress, cx).await?;

    let progress = example_progress.start(Step::Score);

    progress.set_substatus("applying patches");
    let original_text = &example
        .prompt_inputs
        .as_ref()
        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
        .content;
    let expected_texts: Vec<String> = example
        .spec
        .expected_patches
        .iter()
        .map(|patch| {
            apply_diff_to_string(patch, original_text)
                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Fallback score for predictions that can't be parsed or applied.
    let zero_scores = ExampleScore {
        delta_chr_f: 0.0,
        braces_disbalance: 0,
        exact_lines_tp: 0,
        exact_lines_fp: 0,
        exact_lines_fn: 0,
    };

    progress.set_substatus("computing metrics");
    let mut scores = vec![];
    for prediction in &example.predictions {
        let actual_patch = prediction.actual_patch.clone().or_else(|| {
            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
        });

        let Some(actual_patch) = actual_patch else {
            scores.push(zero_scores.clone());
            continue;
        };

        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
            Ok(text) => text,
            Err(_) => {
                scores.push(zero_scores.clone());
                continue;
            }
        };

        // Score against the closest expected text: take the best delta chrF
        // across all expected patches.
        let best_delta_chr_f = expected_texts
            .iter()
            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
            .fold(0.0, f32::max);

        // Only count disbalance the prediction introduced, not disbalance
        // already present in the original text.
        let disbalance_before = metrics::braces_disbalance(original_text);
        let disbalance_after = metrics::braces_disbalance(&actual_text);
        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
        if braces_disbalance > 0 {
            // Debug dumps for inspecting predictions that unbalance braces.
            std::fs::write(
                "/tmp/unbalanced-count.before",
                disbalance_before.to_string(),
            )
            .ok();
            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
            std::fs::write("/tmp/unbalanced-text.before", original_text).ok();
            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
        }

        // Compute exact line matches against the best-matching expected patch.
        let best_exact_lines = example
            .spec
            .expected_patches
            .iter()
            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
            .max_by_key(|m| m.true_positives)
            .unwrap_or_default();

        scores.push(ExampleScore {
            delta_chr_f: best_delta_chr_f,
            braces_disbalance,
            exact_lines_tp: best_exact_lines.true_positives,
            exact_lines_fp: best_exact_lines.false_positives,
            exact_lines_fn: best_exact_lines.false_negatives,
        });
    }

    example.score = scores;
    Ok(())
}

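/// Prints a per-prediction score table to stdout, followed by a
/// total/average row aggregated across all examples.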
pub fn print_report(examples: &[Example]) {
    use crate::metrics::ClassificationMetrics;

    const LINE_WIDTH: usize = 100;
    let separator = "-".repeat(LINE_WIDTH);

    println!("{}", separator);
    println!(
        "{:<40} {:>9} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
        "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
    );
    println!("{}", separator);

    let mut all_delta_chr_f_scores = Vec::new();
    let mut braces_disbalance_sum: usize = 0;
    let mut total_exact_lines = ClassificationMetrics::default();
    let mut total_scores: usize = 0;

    for example in examples {
        for score in example.score.iter() {
            let exact_lines = ClassificationMetrics {
                true_positives: score.exact_lines_tp,
                false_positives: score.exact_lines_fp,
                false_negatives: score.exact_lines_fn,
            };

            println!(
                "{:<40} {:>9.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
                truncate_name(&example.spec.name, 40),
                score.delta_chr_f,
                score.braces_disbalance,
                score.exact_lines_tp,
                score.exact_lines_fp,
                score.exact_lines_fn,
                exact_lines.precision() * 100.0,
                exact_lines.recall() * 100.0,
                exact_lines.f1() * 100.0
            );

            all_delta_chr_f_scores.push(score.delta_chr_f);
            total_scores += 1;
            braces_disbalance_sum += score.braces_disbalance;
            total_exact_lines.true_positives += score.exact_lines_tp;
            total_exact_lines.false_positives += score.exact_lines_fp;
            total_exact_lines.false_negatives += score.exact_lines_fn;
        }
    }

    println!("{}", separator);

    if !all_delta_chr_f_scores.is_empty() {
        let avg_delta_chr_f: f32 =
            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;

        println!(
            "{:<40} {:>9.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
            "TOTAL / AVERAGE",
            avg_delta_chr_f,
            braces_disbalance_avg,
            total_exact_lines.true_positives,
            total_exact_lines.false_positives,
            total_exact_lines.false_negatives,
            total_exact_lines.precision() * 100.0,
            total_exact_lines.recall() * 100.0,
            total_exact_lines.f1() * 100.0
        );
        println!("{}", separator);
    }

    println!("\n");
}

/// Truncates `name` to at most `max_len` characters, appending "..." when
/// truncated. Operates on characters rather than bytes so multi-byte UTF-8
/// names can't trigger a non-boundary slice panic.
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        name.to_string()
    } else {
        let truncated: String = name.chars().take(max_len.saturating_sub(3)).collect();
        format!("{truncated}...")
    }
}

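/// Aggregate metrics across all scored predictions, serialized by
/// `write_summary_json`.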
#[derive(Serialize)]
pub struct SummaryJson {
    pub total_examples: usize,
    pub avg_delta_chr_f: f32,
    pub avg_braces_disbalance: f32,
    pub exact_lines_true_positives: usize,
    pub exact_lines_false_positives: usize,
    pub exact_lines_false_negatives: usize,
    pub exact_lines_precision: f64,
    pub exact_lines_recall: f64,
    pub exact_lines_f1: f64,
}

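/// Aggregates per-prediction scores across `examples` into a single
/// `SummaryJson`. Averages are taken over scored predictions, not examples,
/// and default to 0.0 when there is nothing to score.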
pub fn compute_summary(examples: &[Example]) -> SummaryJson {
    use crate::metrics::ClassificationMetrics;

    let mut all_delta_chr_f_scores = Vec::new();
    let mut braces_disbalance_sum: usize = 0;
    let mut total_exact_lines = ClassificationMetrics::default();
    let mut total_scores: usize = 0;

    for example in examples {
        for score in example.score.iter() {
            all_delta_chr_f_scores.push(score.delta_chr_f);
            total_scores += 1;
            braces_disbalance_sum += score.braces_disbalance;
            total_exact_lines.true_positives += score.exact_lines_tp;
            total_exact_lines.false_positives += score.exact_lines_fp;
            total_exact_lines.false_negatives += score.exact_lines_fn;
        }
    }

    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
        0.0
    } else {
        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
    };

    let avg_braces_disbalance = if total_scores == 0 {
        0.0
    } else {
        braces_disbalance_sum as f32 / total_scores as f32
    };

    SummaryJson {
        // Counts scored predictions (one per prediction, across all examples).
        total_examples: total_scores,
        avg_delta_chr_f,
        avg_braces_disbalance,
        exact_lines_true_positives: total_exact_lines.true_positives,
        exact_lines_false_positives: total_exact_lines.false_positives,
        exact_lines_false_negatives: total_exact_lines.false_negatives,
        exact_lines_precision: total_exact_lines.precision(),
        exact_lines_recall: total_exact_lines.recall(),
        exact_lines_f1: total_exact_lines.f1(),
    }
}

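/// Computes the summary for `examples` and writes it to `path` as
/// pretty-printed JSON.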
pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
    let summary = compute_summary(examples);
    let file = File::create(path)
        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
    let writer = BufWriter::new(file);
    serde_json::to_writer_pretty(writer, &summary)
        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
    eprintln!("Wrote summary JSON to: {}", path.display());
    Ok(())
}
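
#[cfg(test)]
mod tests {
    use crate::metrics::ClassificationMetrics;

    // A minimal sketch of the precision/recall/F1 math the report relies on,
    // assuming `ClassificationMetrics` implements the standard definitions
    // (precision = TP / (TP + FP), recall = TP / (TP + FN), F1 = their
    // harmonic mean).
    #[test]
    fn classification_metrics_sketch() {
        let metrics = ClassificationMetrics {
            true_positives: 8,
            false_positives: 2,
            false_negatives: 2,
        };
        assert!((metrics.precision() - 0.8).abs() < 1e-6);
        assert!((metrics.recall() - 0.8).abs() < 1e-6);
        assert!((metrics.f1() - 0.8).abs() < 1e-6);
    }

    // Checks that truncation is character-based; a byte-indexed slice could
    // panic on multi-byte UTF-8 names.
    #[test]
    fn truncate_name_handles_multibyte() {
        assert_eq!(super::truncate_name("short", 40), "short");
        assert_eq!(super::truncate_name("éééééééééé", 8), "ééééé...");
    }
}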