score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics,
  6    parse_output::parse_prediction_output,
  7    predict::run_prediction,
  8    progress::{ExampleProgress, Step},
  9    reversal_tracking,
 10};
 11use anyhow::Context as _;
 12use edit_prediction::udiff::apply_diff_to_string;
 13use gpui::AsyncApp;
 14use serde::Serialize;
 15use std::fs::File;
 16use std::io::BufWriter;
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20pub async fn run_scoring(
 21    example: &mut Example,
 22    args: &PredictArgs,
 23    app_state: Arc<EpAppState>,
 24    example_progress: &ExampleProgress,
 25    cx: AsyncApp,
 26) -> anyhow::Result<()> {
 27    run_prediction(example, args, app_state, example_progress, cx).await?;
 28
 29    let progress = example_progress.start(Step::Score);
 30
 31    progress.set_substatus("applying patches");
 32    let original_text = &example
 33        .prompt_inputs
 34        .as_ref()
 35        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 36        .content;
 37    let expected_texts: Vec<String> = example
 38        .spec
 39        .expected_patches
 40        .iter()
 41        .map(|patch| {
 42            apply_diff_to_string(patch, original_text)
 43                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 44        })
 45        .collect::<Result<Vec<_>, _>>()?;
 46
 47    let zero_scores = ExampleScore {
 48        delta_chr_f: 0.0,
 49        braces_disbalance: 0,
 50        exact_lines_tp: 0,
 51        exact_lines_fp: 0,
 52        exact_lines_fn: 0,
 53        reversal_ratio: 0.0,
 54    };
 55
 56    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 57    let cursor_path = example.spec.cursor_path.as_ref();
 58
 59    progress.set_substatus("computing metrics");
 60    let mut scores = vec![];
 61    for prediction in &example.predictions {
 62        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 63            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
 64        });
 65
 66        let Some(actual_patch) = actual_patch else {
 67            scores.push(zero_scores.clone());
 68            continue;
 69        };
 70
 71        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 72            Ok(text) => text,
 73            Err(_) => {
 74                scores.push(zero_scores.clone());
 75                continue;
 76            }
 77        };
 78        let best_delta_chr_f = expected_texts
 79            .iter()
 80            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
 81            .fold(0.0, f32::max);
 82
 83        let disbalance_before = metrics::braces_disbalance(&original_text);
 84        let disbalance_after = metrics::braces_disbalance(&actual_text);
 85        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
 86        if braces_disbalance > 0 {
 87            std::fs::write(
 88                "/tmp/unbalanced-count.before",
 89                disbalance_before.to_string(),
 90            )
 91            .ok();
 92            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
 93            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
 94            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
 95        }
 96
 97        // Compute exact lines match against best matching expected patch
 98        let best_exact_lines = example
 99            .spec
100            .expected_patches
101            .iter()
102            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
103            .max_by_key(|m| m.true_positives)
104            .unwrap_or_default();
105
106        // Compute reversal ratio
107        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
108            prompt_inputs,
109            &actual_text,
110            cursor_path,
111        );
112
113        scores.push(ExampleScore {
114            delta_chr_f: best_delta_chr_f,
115            braces_disbalance,
116            exact_lines_tp: best_exact_lines.true_positives,
117            exact_lines_fp: best_exact_lines.false_positives,
118            exact_lines_fn: best_exact_lines.false_negatives,
119            reversal_ratio,
120        });
121    }
122
123    example.score = scores;
124    Ok(())
125}
126
127pub fn print_report(examples: &[Example]) {
128    use crate::metrics::ClassificationMetrics;
129
130    const LINE_WIDTH: usize = 82;
131    let separator = "".repeat(LINE_WIDTH);
132
133    println!("{}", separator);
134    println!(
135        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7}",
136        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf"
137    );
138    println!("{}", separator);
139
140    let mut all_delta_chr_f_scores = Vec::new();
141    let mut all_reversal_ratios = Vec::new();
142    let mut braces_disbalance_sum: usize = 0;
143    let mut total_exact_lines = ClassificationMetrics::default();
144    let mut total_scores: usize = 0;
145    let mut qa_reverts_count: usize = 0;
146    let mut qa_reverts_total: usize = 0;
147    let mut qa_confidence_sum: u64 = 0;
148    let mut qa_confidence_count: usize = 0;
149
150    for example in examples {
151        for (score_idx, score) in example.score.iter().enumerate() {
152            let exact_lines = ClassificationMetrics {
153                true_positives: score.exact_lines_tp,
154                false_positives: score.exact_lines_fp,
155                false_negatives: score.exact_lines_fn,
156            };
157
158            // Get QA results for this prediction if available
159            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
160            let qa_reverts_str = qa_result
161                .and_then(|q| q.reverts_edits)
162                .map(|v| if v { "yes" } else { "no" })
163                .unwrap_or("-");
164            let qa_conf_str = qa_result
165                .and_then(|q| q.confidence)
166                .map(|v| format!("{}", v))
167                .unwrap_or("-".to_string());
168
169            println!(
170                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7}",
171                truncate_name(&example.spec.name, 40),
172                score.delta_chr_f,
173                score.braces_disbalance,
174                exact_lines.f1() * 100.0,
175                score.reversal_ratio * 100.0,
176                qa_reverts_str,
177                qa_conf_str
178            );
179
180            all_delta_chr_f_scores.push(score.delta_chr_f);
181            all_reversal_ratios.push(score.reversal_ratio);
182            total_scores += 1;
183            braces_disbalance_sum += score.braces_disbalance;
184            total_exact_lines.true_positives += score.exact_lines_tp;
185            total_exact_lines.false_positives += score.exact_lines_fp;
186            total_exact_lines.false_negatives += score.exact_lines_fn;
187
188            // Accumulate QA metrics
189            if let Some(qa) = qa_result {
190                if let Some(reverts) = qa.reverts_edits {
191                    qa_reverts_total += 1;
192                    if reverts {
193                        qa_reverts_count += 1;
194                    }
195                }
196                if let Some(conf) = qa.confidence {
197                    qa_confidence_sum += conf as u64;
198                    qa_confidence_count += 1;
199                }
200            }
201        }
202    }
203
204    println!("{}", separator);
205
206    if !all_delta_chr_f_scores.is_empty() {
207        let avg_delta_chr_f: f32 =
208            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
209        let avg_reversal_ratio: f32 =
210            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
211        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
212
213        let qa_reverts_str = if qa_reverts_total > 0 {
214            format!(
215                "{:.1}%",
216                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
217            )
218        } else {
219            "-".to_string()
220        };
221        let qa_conf_str = if qa_confidence_count > 0 {
222            format!(
223                "{:.1}",
224                qa_confidence_sum as f32 / qa_confidence_count as f32
225            )
226        } else {
227            "-".to_string()
228        };
229
230        println!(
231            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7}",
232            "TOTAL / AVERAGE",
233            avg_delta_chr_f,
234            braces_disbalance_avg,
235            total_exact_lines.f1() * 100.0,
236            avg_reversal_ratio * 100.0,
237            qa_reverts_str,
238            qa_conf_str
239        );
240        println!("{}", separator);
241    }
242
243    println!("\n");
244}
245
/// Returns `name` shortened to at most `max_len` characters.
///
/// Names that already fit are returned unchanged. Longer names keep their
/// first `max_len - 3` characters followed by `...`. Lengths are measured in
/// `char`s rather than bytes, so names containing multi-byte UTF-8 characters
/// are never split inside a character, and the result lines up with
/// width-based formatting such as `{:<40}`, which also counts characters.
///
/// When `max_len` is smaller than the ellipsis itself, the ellipsis is cut
/// down so the result still never exceeds `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // `saturating_sub` avoids underflow for tiny `max_len`; the final `take`
    // trims the ellipsis in that case.
    let kept_chars = max_len.saturating_sub(ELLIPSIS.len());
    name.chars()
        .take(kept_chars)
        .chain(ELLIPSIS.chars())
        .take(max_len)
        .collect()
}
253
254#[derive(Serialize)]
255pub struct SummaryJson {
256    pub total_examples: usize,
257    pub avg_delta_chr_f: f32,
258    pub avg_braces_disbalance: f32,
259    pub exact_lines_true_positives: usize,
260    pub exact_lines_false_positives: usize,
261    pub exact_lines_false_negatives: usize,
262    pub exact_lines_precision: f64,
263    pub exact_lines_recall: f64,
264    pub exact_lines_f1: f64,
265    pub avg_reversal_ratio: f32,
266    #[serde(skip_serializing_if = "Option::is_none")]
267    pub qa_avg_reverts_edits: Option<f32>,
268    #[serde(skip_serializing_if = "Option::is_none")]
269    pub qa_avg_confidence: Option<f32>,
270}
271
272pub fn compute_summary(examples: &[Example]) -> SummaryJson {
273    use crate::metrics::ClassificationMetrics;
274
275    let mut all_delta_chr_f_scores = Vec::new();
276    let mut all_reversal_ratios = Vec::new();
277    let mut braces_disbalance_sum: usize = 0;
278    let mut total_exact_lines = ClassificationMetrics::default();
279    let mut total_scores: usize = 0;
280    let mut qa_reverts_count: usize = 0;
281    let mut qa_reverts_total: usize = 0;
282    let mut qa_confidence_sum: u64 = 0;
283    let mut qa_confidence_count: usize = 0;
284
285    for example in examples {
286        for (score_idx, score) in example.score.iter().enumerate() {
287            all_delta_chr_f_scores.push(score.delta_chr_f);
288            all_reversal_ratios.push(score.reversal_ratio);
289            total_scores += 1;
290            braces_disbalance_sum += score.braces_disbalance;
291            total_exact_lines.true_positives += score.exact_lines_tp;
292            total_exact_lines.false_positives += score.exact_lines_fp;
293            total_exact_lines.false_negatives += score.exact_lines_fn;
294
295            // Accumulate QA metrics
296            if let Some(Some(qa)) = example.qa.get(score_idx) {
297                if let Some(reverts) = qa.reverts_edits {
298                    qa_reverts_total += 1;
299                    if reverts {
300                        qa_reverts_count += 1;
301                    }
302                }
303                if let Some(conf) = qa.confidence {
304                    qa_confidence_sum += conf as u64;
305                    qa_confidence_count += 1;
306                }
307            }
308        }
309    }
310
311    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
312        0.0
313    } else {
314        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
315    };
316
317    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
318        0.0
319    } else {
320        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
321    };
322
323    let avg_braces_disbalance = if total_scores == 0 {
324        0.0
325    } else {
326        braces_disbalance_sum as f32 / total_scores as f32
327    };
328
329    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
330        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
331    } else {
332        None
333    };
334
335    let qa_avg_confidence = if qa_confidence_count > 0 {
336        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
337    } else {
338        None
339    };
340
341    SummaryJson {
342        total_examples: total_scores,
343        avg_delta_chr_f,
344        avg_braces_disbalance,
345        exact_lines_true_positives: total_exact_lines.true_positives,
346        exact_lines_false_positives: total_exact_lines.false_positives,
347        exact_lines_false_negatives: total_exact_lines.false_negatives,
348        exact_lines_precision: total_exact_lines.precision(),
349        exact_lines_recall: total_exact_lines.recall(),
350        exact_lines_f1: total_exact_lines.f1(),
351        avg_reversal_ratio,
352        qa_avg_reverts_edits,
353        qa_avg_confidence,
354    }
355}
356
357pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
358    let summary = compute_summary(examples);
359    let file = File::create(path)
360        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
361    let writer = BufWriter::new(file);
362    serde_json::to_writer_pretty(writer, &summary)
363        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
364    eprintln!("Wrote summary JSON to: {}", path.display());
365    Ok(())
366}