score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics,
  6    parse_output::parse_prediction_output,
  7    predict::run_prediction,
  8    progress::{ExampleProgress, Step},
  9    reversal_tracking,
 10};
 11use anyhow::Context as _;
 12use edit_prediction::udiff::apply_diff_to_string;
 13use gpui::AsyncApp;
 14use serde::Serialize;
 15use std::fs::File;
 16use std::io::BufWriter;
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20pub async fn run_scoring(
 21    example: &mut Example,
 22    args: &PredictArgs,
 23    app_state: Arc<EpAppState>,
 24    example_progress: &ExampleProgress,
 25    cx: AsyncApp,
 26) -> anyhow::Result<()> {
 27    run_prediction(example, args, app_state, example_progress, cx).await?;
 28
 29    let progress = example_progress.start(Step::Score);
 30
 31    progress.set_substatus("applying patches");
 32    let original_text = &example
 33        .prompt_inputs
 34        .as_ref()
 35        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 36        .content;
 37    let expected_texts: Vec<String> = example
 38        .spec
 39        .expected_patches
 40        .iter()
 41        .map(|patch| {
 42            apply_diff_to_string(patch, original_text)
 43                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 44        })
 45        .collect::<Result<Vec<_>, _>>()?;
 46
 47    let zero_scores = ExampleScore {
 48        delta_chr_f: 0.0,
 49        braces_disbalance: 0,
 50        exact_lines_tp: 0,
 51        exact_lines_fp: 0,
 52        exact_lines_fn: 0,
 53        reversal_ratio: 0.0,
 54    };
 55
 56    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 57    let cursor_path = example.spec.cursor_path.as_ref();
 58
 59    progress.set_substatus("computing metrics");
 60    let mut scores = vec![];
 61    for prediction in &example.predictions {
 62        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 63            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
 64        });
 65
 66        let Some(actual_patch) = actual_patch else {
 67            scores.push(zero_scores.clone());
 68            continue;
 69        };
 70
 71        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 72            Ok(text) => text,
 73            Err(_) => {
 74                scores.push(zero_scores.clone());
 75                continue;
 76            }
 77        };
 78        let best_delta_chr_f = expected_texts
 79            .iter()
 80            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
 81            .fold(0.0, f32::max);
 82
 83        let disbalance_before = metrics::braces_disbalance(&original_text);
 84        let disbalance_after = metrics::braces_disbalance(&actual_text);
 85        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
 86        if braces_disbalance > 0 {
 87            std::fs::write(
 88                "/tmp/unbalanced-count.before",
 89                disbalance_before.to_string(),
 90            )
 91            .ok();
 92            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
 93            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
 94            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
 95        }
 96
 97        // Compute exact lines match against best matching expected patch
 98        let best_exact_lines = example
 99            .spec
100            .expected_patches
101            .iter()
102            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
103            .max_by_key(|m| m.true_positives)
104            .unwrap_or_default();
105
106        // Compute reversal ratio
107        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
108            prompt_inputs,
109            &actual_text,
110            cursor_path,
111        );
112
113        scores.push(ExampleScore {
114            delta_chr_f: best_delta_chr_f,
115            braces_disbalance,
116            exact_lines_tp: best_exact_lines.true_positives,
117            exact_lines_fp: best_exact_lines.false_positives,
118            exact_lines_fn: best_exact_lines.false_negatives,
119            reversal_ratio,
120        });
121    }
122
123    example.score = scores;
124    Ok(())
125}
126
127pub fn print_report(examples: &[Example]) {
128    use crate::metrics::ClassificationMetrics;
129
130    const LINE_WIDTH: usize = 110;
131    let separator = "".repeat(LINE_WIDTH);
132
133    println!("{}", separator);
134    println!(
135        "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7} {:>7}",
136        "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1", "Revert"
137    );
138    println!("{}", separator);
139
140    let mut all_delta_chr_f_scores = Vec::new();
141    let mut all_reversal_ratios = Vec::new();
142    let mut braces_disbalance_sum: usize = 0;
143    let mut total_exact_lines = ClassificationMetrics::default();
144    let mut total_scores: usize = 0;
145
146    for example in examples {
147        for score in example.score.iter() {
148            let exact_lines = ClassificationMetrics {
149                true_positives: score.exact_lines_tp,
150                false_positives: score.exact_lines_fp,
151                false_negatives: score.exact_lines_fn,
152            };
153
154            println!(
155                "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%",
156                truncate_name(&example.spec.name, 40),
157                score.delta_chr_f,
158                score.braces_disbalance,
159                score.exact_lines_tp,
160                score.exact_lines_fp,
161                score.exact_lines_fn,
162                exact_lines.precision() * 100.0,
163                exact_lines.recall() * 100.0,
164                exact_lines.f1() * 100.0,
165                score.reversal_ratio * 100.0
166            );
167
168            all_delta_chr_f_scores.push(score.delta_chr_f);
169            all_reversal_ratios.push(score.reversal_ratio);
170            total_scores += 1;
171            braces_disbalance_sum += score.braces_disbalance;
172            total_exact_lines.true_positives += score.exact_lines_tp;
173            total_exact_lines.false_positives += score.exact_lines_fp;
174            total_exact_lines.false_negatives += score.exact_lines_fn;
175        }
176    }
177
178    println!("{}", separator);
179
180    if !all_delta_chr_f_scores.is_empty() {
181        let avg_delta_chr_f: f32 =
182            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
183        let avg_reversal_ratio: f32 =
184            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
185        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
186
187        println!(
188            "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%",
189            "TOTAL / AVERAGE",
190            avg_delta_chr_f,
191            braces_disbalance_avg,
192            total_exact_lines.true_positives,
193            total_exact_lines.false_positives,
194            total_exact_lines.false_negatives,
195            total_exact_lines.precision() * 100.0,
196            total_exact_lines.recall() * 100.0,
197            total_exact_lines.f1() * 100.0,
198            avg_reversal_ratio * 100.0
199        );
200        println!("{}", separator);
201    }
202
203    println!("\n");
204}
205
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `...` so that the result is exactly `max_len` characters.
/// When `max_len` is too small to hold the ellipsis, the name is simply cut
/// to `max_len` characters.
///
/// Length is measured in `char`s rather than bytes: formatter padding such as
/// `{:<40}` counts characters, and cutting on a character boundary cannot
/// panic on multi-byte UTF-8 names.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Not enough room for the ellipsis: hard cut, never exceeding `max_len`.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
213
214#[derive(Serialize)]
215pub struct SummaryJson {
216    pub total_examples: usize,
217    pub avg_delta_chr_f: f32,
218    pub avg_braces_disbalance: f32,
219    pub exact_lines_true_positives: usize,
220    pub exact_lines_false_positives: usize,
221    pub exact_lines_false_negatives: usize,
222    pub exact_lines_precision: f64,
223    pub exact_lines_recall: f64,
224    pub exact_lines_f1: f64,
225    pub avg_reversal_ratio: f32,
226}
227
228pub fn compute_summary(examples: &[Example]) -> SummaryJson {
229    use crate::metrics::ClassificationMetrics;
230
231    let mut all_delta_chr_f_scores = Vec::new();
232    let mut all_reversal_ratios = Vec::new();
233    let mut braces_disbalance_sum: usize = 0;
234    let mut total_exact_lines = ClassificationMetrics::default();
235    let mut total_scores: usize = 0;
236
237    for example in examples {
238        for score in example.score.iter() {
239            all_delta_chr_f_scores.push(score.delta_chr_f);
240            all_reversal_ratios.push(score.reversal_ratio);
241            total_scores += 1;
242            braces_disbalance_sum += score.braces_disbalance;
243            total_exact_lines.true_positives += score.exact_lines_tp;
244            total_exact_lines.false_positives += score.exact_lines_fp;
245            total_exact_lines.false_negatives += score.exact_lines_fn;
246        }
247    }
248
249    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
250        0.0
251    } else {
252        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
253    };
254
255    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
256        0.0
257    } else {
258        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
259    };
260
261    let avg_braces_disbalance = if total_scores == 0 {
262        0.0
263    } else {
264        braces_disbalance_sum as f32 / total_scores as f32
265    };
266
267    SummaryJson {
268        total_examples: total_scores,
269        avg_delta_chr_f,
270        avg_braces_disbalance,
271        exact_lines_true_positives: total_exact_lines.true_positives,
272        exact_lines_false_positives: total_exact_lines.false_positives,
273        exact_lines_false_negatives: total_exact_lines.false_negatives,
274        exact_lines_precision: total_exact_lines.precision(),
275        exact_lines_recall: total_exact_lines.recall(),
276        exact_lines_f1: total_exact_lines.f1(),
277        avg_reversal_ratio,
278    }
279}
280
281pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
282    let summary = compute_summary(examples);
283    let file = File::create(path)
284        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
285    let writer = BufWriter::new(file);
286    serde_json::to_writer_pretty(writer, &summary)
287        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
288    eprintln!("Wrote summary JSON to: {}", path.display());
289    Ok(())
290}