score.rs

  1use crate::{
  2    PredictArgs, PredictionProvider,
  3    example::{Example, ExampleScore},
  4    format_prompt::TeacherPrompt,
  5    headless::EpAppState,
  6    metrics,
  7    parse_output::parse_prediction_output,
  8    predict::run_prediction,
  9    progress::{ExampleProgress, Step},
 10    reversal_tracking,
 11};
 12use anyhow::Context as _;
 13use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
 14use gpui::AsyncApp;
 15use serde::Serialize;
 16use std::fs::File;
 17use std::io::BufWriter;
 18use std::path::Path;
 19use std::sync::Arc;
 20
 21pub async fn run_scoring(
 22    example: &mut Example,
 23    args: &PredictArgs,
 24    app_state: Arc<EpAppState>,
 25    example_progress: &ExampleProgress,
 26    cx: AsyncApp,
 27) -> anyhow::Result<()> {
 28    run_prediction(example, args, app_state, example_progress, cx).await?;
 29
 30    let progress = example_progress.start(Step::Score);
 31
 32    progress.set_substatus("applying patches");
 33    let original_text = &example
 34        .prompt_inputs
 35        .as_ref()
 36        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 37        .content;
 38    let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
 39
 40    let expected_texts: Vec<String> = expected_patches_with_cursors
 41        .iter()
 42        .map(|(patch, _)| {
 43            apply_diff_to_string(patch, original_text)
 44                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 45        })
 46        .collect::<Result<Vec<_>, _>>()?;
 47
 48    // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
 49    // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
 50    // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
 51    // region to find where the hunk matched, then compute the expected cursor position.
 52    let old_editable_region = if let Some(p) = example.prompt.as_ref() {
 53        if matches!(
 54            p.provider,
 55            PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
 56        ) {
 57            Some(
 58                TeacherPrompt::extract_editable_region(&p.input)?
 59                    .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
 60            )
 61        } else {
 62            None
 63        }
 64    } else {
 65        None
 66    };
 67
 68    let zero_scores = ExampleScore {
 69        delta_chr_f: 0.0,
 70        braces_disbalance: 0,
 71        exact_lines_tp: 0,
 72        exact_lines_fp: 0,
 73        exact_lines_fn: 0,
 74        reversal_ratio: 0.0,
 75        cursor_distance: None,
 76        cursor_exact_match: None,
 77    };
 78
 79    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 80    let cursor_path = example.spec.cursor_path.as_ref();
 81
 82    progress.set_substatus("computing metrics");
 83    let mut scores = vec![];
 84    for prediction in &example.predictions {
 85        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 86            parse_prediction_output(example, &prediction.actual_output, prediction.provider)
 87                .ok()
 88                .map(|(patch, _)| patch)
 89        });
 90
 91        let Some(actual_patch) = actual_patch else {
 92            scores.push(zero_scores.clone());
 93            continue;
 94        };
 95
 96        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 97            Ok(text) => text,
 98            Err(_) => {
 99                scores.push(zero_scores.clone());
100                continue;
101            }
102        };
103
104        let mut best_delta_chr_f = 0.0f32;
105        let mut best_expected_cursor: Option<usize> = None;
106        let mut best_patch_idx: Option<usize> = None;
107
108        for (idx, expected) in expected_texts.iter().enumerate() {
109            let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
110            if delta_chr_f > best_delta_chr_f {
111                best_delta_chr_f = delta_chr_f;
112                best_patch_idx = Some(idx);
113            }
114        }
115
116        if let Some(idx) = best_patch_idx {
117            // Get the raw cursor offset from the expected patch (relative to hunk new text)
118            let expected_cursor_in_patch = expected_patches_with_cursors
119                .get(idx)
120                .and_then(|(_, cursor)| *cursor);
121
122            // For Teacher prompts, we need to apply the patch to the editable region
123            // to find where the hunk matched, then compute the actual cursor position
124            if let (Some(editable_region), Some(cursor_in_patch)) =
125                (&old_editable_region, expected_cursor_in_patch)
126            {
127                let (patch, _) = &expected_patches_with_cursors[idx];
128                if let Ok((_, hunk_offset)) =
129                    apply_diff_to_string_with_hunk_offset(patch, editable_region)
130                {
131                    let hunk_start = hunk_offset.unwrap_or(0);
132                    best_expected_cursor = Some(hunk_start + cursor_in_patch);
133                }
134            } else {
135                // For non-Teacher prompts or if we can't compute, use raw offset
136                best_expected_cursor = expected_cursor_in_patch;
137            }
138        }
139
140        let disbalance_before = metrics::braces_disbalance(&original_text);
141        let disbalance_after = metrics::braces_disbalance(&actual_text);
142        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
143        if braces_disbalance > 0 {
144            std::fs::write(
145                "/tmp/unbalanced-count.before",
146                disbalance_before.to_string(),
147            )
148            .ok();
149            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
150            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
151            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
152        }
153
154        // Compute exact lines match against best matching expected patch
155        let best_exact_lines = expected_patches_with_cursors
156            .iter()
157            .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
158            .max_by_key(|m| m.true_positives)
159            .unwrap_or_default();
160
161        // Compute reversal ratio
162        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
163            prompt_inputs,
164            &actual_text,
165            cursor_path,
166        );
167
168        // Compute cursor position metrics
169        let (cursor_distance, cursor_exact_match) =
170            compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset);
171
172        scores.push(ExampleScore {
173            delta_chr_f: best_delta_chr_f,
174            braces_disbalance,
175            exact_lines_tp: best_exact_lines.true_positives,
176            exact_lines_fp: best_exact_lines.false_positives,
177            exact_lines_fn: best_exact_lines.false_negatives,
178            reversal_ratio,
179            cursor_distance,
180            cursor_exact_match,
181        });
182    }
183
184    example.score = scores;
185    Ok(())
186}
187
/// Compares an expected cursor offset with the predicted one.
///
/// Returns `(distance, exact_match)`:
/// - both offsets present: `(Some(|expected - actual|), Some(expected == actual))`
/// - neither present: `(None, None)`, meaning cursor scoring is skipped
/// - exactly one present: `(None, Some(false))`, counted as a miss with no distance
fn compute_cursor_metrics(
    expected_cursor: Option<usize>,
    actual_cursor: Option<usize>,
) -> (Option<usize>, Option<bool>) {
    // Both sides known: report the distance and whether they coincide.
    if let Some((want, got)) = expected_cursor.zip(actual_cursor) {
        return (Some(want.abs_diff(got)), Some(want == got));
    }

    // From here on at most one side is present. A lone cursor is a miss;
    // no cursor at all means there is nothing to score.
    let only_one_side = expected_cursor.is_some() || actual_cursor.is_some();
    (None, only_one_side.then_some(false))
}
208
209pub fn print_report(examples: &[Example]) {
210    use crate::metrics::ClassificationMetrics;
211
212    const LINE_WIDTH: usize = 94;
213    let separator = "".repeat(LINE_WIDTH);
214
215    println!("{}", separator);
216    println!(
217        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}",
218        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor"
219    );
220    println!("{}", separator);
221
222    let mut all_delta_chr_f_scores = Vec::new();
223    let mut all_reversal_ratios = Vec::new();
224    let mut braces_disbalance_sum: usize = 0;
225    let mut total_exact_lines = ClassificationMetrics::default();
226    let mut total_scores: usize = 0;
227    let mut qa_reverts_count: usize = 0;
228    let mut qa_reverts_total: usize = 0;
229    let mut qa_confidence_sum: u64 = 0;
230    let mut qa_confidence_count: usize = 0;
231    let mut cursor_exact_matches: usize = 0;
232    let mut cursor_total: usize = 0;
233    let mut cursor_distance_sum: usize = 0;
234    let mut cursor_distance_count: usize = 0;
235
236    for example in examples {
237        for (score_idx, score) in example.score.iter().enumerate() {
238            let exact_lines = ClassificationMetrics {
239                true_positives: score.exact_lines_tp,
240                false_positives: score.exact_lines_fp,
241                false_negatives: score.exact_lines_fn,
242            };
243
244            // Get QA results for this prediction if available
245            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
246            let qa_reverts_str = qa_result
247                .and_then(|q| q.reverts_edits)
248                .map(|v| if v { "yes" } else { "no" })
249                .unwrap_or("-");
250            let qa_conf_str = qa_result
251                .and_then(|q| q.confidence)
252                .map(|v| format!("{}", v))
253                .unwrap_or("-".to_string());
254
255            // Format cursor metric
256            let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
257                (Some(true), _) => "".to_string(),
258                (Some(false), Some(dist)) => format!("±{}", dist),
259                (Some(false), None) => "".to_string(),
260                (None, _) => "-".to_string(),
261            };
262
263            println!(
264                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
265                truncate_name(&example.spec.name, 40),
266                score.delta_chr_f,
267                score.braces_disbalance,
268                exact_lines.f1() * 100.0,
269                score.reversal_ratio * 100.0,
270                qa_reverts_str,
271                qa_conf_str,
272                cursor_str
273            );
274
275            all_delta_chr_f_scores.push(score.delta_chr_f);
276            all_reversal_ratios.push(score.reversal_ratio);
277            total_scores += 1;
278            braces_disbalance_sum += score.braces_disbalance;
279            total_exact_lines.true_positives += score.exact_lines_tp;
280            total_exact_lines.false_positives += score.exact_lines_fp;
281            total_exact_lines.false_negatives += score.exact_lines_fn;
282
283            // Accumulate QA metrics
284            if let Some(qa) = qa_result {
285                if let Some(reverts) = qa.reverts_edits {
286                    qa_reverts_total += 1;
287                    if reverts {
288                        qa_reverts_count += 1;
289                    }
290                }
291                if let Some(conf) = qa.confidence {
292                    qa_confidence_sum += conf as u64;
293                    qa_confidence_count += 1;
294                }
295            }
296
297            // Accumulate cursor metrics
298            if let Some(exact_match) = score.cursor_exact_match {
299                cursor_total += 1;
300                if exact_match {
301                    cursor_exact_matches += 1;
302                }
303            }
304            if let Some(dist) = score.cursor_distance {
305                cursor_distance_sum += dist;
306                cursor_distance_count += 1;
307            }
308        }
309    }
310
311    println!("{}", separator);
312
313    if !all_delta_chr_f_scores.is_empty() {
314        let avg_delta_chr_f: f32 =
315            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
316        let avg_reversal_ratio: f32 =
317            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
318        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
319
320        let qa_reverts_str = if qa_reverts_total > 0 {
321            format!(
322                "{:.1}%",
323                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
324            )
325        } else {
326            "-".to_string()
327        };
328        let qa_conf_str = if qa_confidence_count > 0 {
329            format!(
330                "{:.1}",
331                qa_confidence_sum as f32 / qa_confidence_count as f32
332            )
333        } else {
334            "-".to_string()
335        };
336        let cursor_str = if cursor_total > 0 {
337            format!(
338                "{:.0}%",
339                cursor_exact_matches as f32 / cursor_total as f32 * 100.0
340            )
341        } else {
342            "-".to_string()
343        };
344        let avg_cursor_distance = if cursor_distance_count > 0 {
345            Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
346        } else {
347            None
348        };
349
350        println!(
351            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
352            "TOTAL / AVERAGE",
353            avg_delta_chr_f,
354            braces_disbalance_avg,
355            total_exact_lines.f1() * 100.0,
356            avg_reversal_ratio * 100.0,
357            qa_reverts_str,
358            qa_conf_str,
359            cursor_str
360        );
361        println!("{}", separator);
362
363        // Print additional cursor metrics if available
364        if let Some(avg_dist) = avg_cursor_distance {
365            println!(
366                "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
367                cursor_exact_matches,
368                cursor_total,
369                cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
370                avg_dist
371            );
372        }
373    }
374
375    println!("\n");
376}
377
/// Shortens `name` to at most `max_len` characters for display in a table column.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. Lengths are measured in `char`s
/// (matching how `{:<N}` padding counts width), and the cut always lands on a
/// character boundary, so non-ASCII names never cause a slicing panic. When
/// `max_len` is smaller than the ellipsis itself, the name is simply cut to
/// `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // Byte length is an upper bound on the character count, so this cheap check
    // covers the common case; otherwise look for a character at index `max_len`.
    if name.len() <= max_len || name.char_indices().nth(max_len).is_none() {
        return name.to_string();
    }

    // No room for the ellipsis: fall back to a plain cut instead of underflowing.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let keep_chars = max_len - ELLIPSIS.len();
    // Byte offset where the (keep_chars + 1)-th character starts.
    let cut = name
        .char_indices()
        .nth(keep_chars)
        .map_or(name.len(), |(byte_idx, _)| byte_idx);
    format!("{}{}", &name[..cut], ELLIPSIS)
}
385
386#[derive(Serialize)]
387pub struct SummaryJson {
388    pub total_examples: usize,
389    pub avg_delta_chr_f: f32,
390    pub avg_braces_disbalance: f32,
391    pub exact_lines_true_positives: usize,
392    pub exact_lines_false_positives: usize,
393    pub exact_lines_false_negatives: usize,
394    pub exact_lines_precision: f64,
395    pub exact_lines_recall: f64,
396    pub exact_lines_f1: f64,
397    pub avg_reversal_ratio: f32,
398    #[serde(skip_serializing_if = "Option::is_none")]
399    pub qa_avg_reverts_edits: Option<f32>,
400    #[serde(skip_serializing_if = "Option::is_none")]
401    pub qa_avg_confidence: Option<f32>,
402    #[serde(skip_serializing_if = "Option::is_none")]
403    pub cursor_exact_match_rate: Option<f32>,
404    #[serde(skip_serializing_if = "Option::is_none")]
405    pub cursor_avg_distance: Option<f32>,
406    #[serde(skip_serializing_if = "Option::is_none")]
407    pub cursor_total_evaluated: Option<usize>,
408}
409
410pub fn compute_summary(examples: &[Example]) -> SummaryJson {
411    use crate::metrics::ClassificationMetrics;
412
413    let mut all_delta_chr_f_scores = Vec::new();
414    let mut all_reversal_ratios = Vec::new();
415    let mut braces_disbalance_sum: usize = 0;
416    let mut total_exact_lines = ClassificationMetrics::default();
417    let mut total_scores: usize = 0;
418    let mut qa_reverts_count: usize = 0;
419    let mut qa_reverts_total: usize = 0;
420    let mut qa_confidence_sum: u64 = 0;
421    let mut qa_confidence_count: usize = 0;
422    let mut cursor_exact_matches: usize = 0;
423    let mut cursor_total: usize = 0;
424    let mut cursor_distance_sum: usize = 0;
425    let mut cursor_distance_count: usize = 0;
426
427    for example in examples {
428        for (score_idx, score) in example.score.iter().enumerate() {
429            all_delta_chr_f_scores.push(score.delta_chr_f);
430            all_reversal_ratios.push(score.reversal_ratio);
431            total_scores += 1;
432            braces_disbalance_sum += score.braces_disbalance;
433            total_exact_lines.true_positives += score.exact_lines_tp;
434            total_exact_lines.false_positives += score.exact_lines_fp;
435            total_exact_lines.false_negatives += score.exact_lines_fn;
436
437            // Accumulate QA metrics
438            if let Some(Some(qa)) = example.qa.get(score_idx) {
439                if let Some(reverts) = qa.reverts_edits {
440                    qa_reverts_total += 1;
441                    if reverts {
442                        qa_reverts_count += 1;
443                    }
444                }
445                if let Some(conf) = qa.confidence {
446                    qa_confidence_sum += conf as u64;
447                    qa_confidence_count += 1;
448                }
449            }
450
451            // Accumulate cursor metrics
452            if let Some(exact_match) = score.cursor_exact_match {
453                cursor_total += 1;
454                if exact_match {
455                    cursor_exact_matches += 1;
456                }
457            }
458            if let Some(dist) = score.cursor_distance {
459                cursor_distance_sum += dist;
460                cursor_distance_count += 1;
461            }
462        }
463    }
464
465    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
466        0.0
467    } else {
468        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
469    };
470
471    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
472        0.0
473    } else {
474        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
475    };
476
477    let avg_braces_disbalance = if total_scores == 0 {
478        0.0
479    } else {
480        braces_disbalance_sum as f32 / total_scores as f32
481    };
482
483    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
484        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
485    } else {
486        None
487    };
488
489    let qa_avg_confidence = if qa_confidence_count > 0 {
490        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
491    } else {
492        None
493    };
494
495    let cursor_exact_match_rate = if cursor_total > 0 {
496        Some(cursor_exact_matches as f32 / cursor_total as f32)
497    } else {
498        None
499    };
500
501    let cursor_avg_distance = if cursor_distance_count > 0 {
502        Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
503    } else {
504        None
505    };
506
507    let cursor_total_evaluated = if cursor_total > 0 {
508        Some(cursor_total)
509    } else {
510        None
511    };
512
513    SummaryJson {
514        total_examples: total_scores,
515        avg_delta_chr_f,
516        avg_braces_disbalance,
517        exact_lines_true_positives: total_exact_lines.true_positives,
518        exact_lines_false_positives: total_exact_lines.false_positives,
519        exact_lines_false_negatives: total_exact_lines.false_negatives,
520        exact_lines_precision: total_exact_lines.precision(),
521        exact_lines_recall: total_exact_lines.recall(),
522        exact_lines_f1: total_exact_lines.f1(),
523        avg_reversal_ratio,
524        qa_avg_reverts_edits,
525        qa_avg_confidence,
526        cursor_exact_match_rate,
527        cursor_avg_distance,
528        cursor_total_evaluated,
529    }
530}
531
532pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
533    let summary = compute_summary(examples);
534    let file = File::create(path)
535        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
536    let writer = BufWriter::new(file);
537    serde_json::to_writer_pretty(writer, &summary)
538        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
539    eprintln!("Wrote summary JSON to: {}", path.display());
540    Ok(())
541}