score.rs

  1use crate::{
  2    PredictArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExampleScore},
  4    format_prompt::TeacherPrompt,
  5    headless::EpAppState,
  6    metrics,
  7    parse_output::parse_prediction_output,
  8    predict::run_prediction,
  9    progress::{ExampleProgress, Step},
 10    reversal_tracking,
 11};
 12use anyhow::Context as _;
 13use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
 14use gpui::AsyncApp;
 15use serde::Serialize;
 16use std::fs::File;
 17use std::io::BufWriter;
 18use std::path::Path;
 19use std::sync::Arc;
 20
 21pub async fn run_scoring(
 22    example: &mut Example,
 23    args: &PredictArgs,
 24    app_state: Arc<EpAppState>,
 25    example_progress: &ExampleProgress,
 26    cx: AsyncApp,
 27) -> anyhow::Result<()> {
 28    run_prediction(example, args, app_state, example_progress, cx).await?;
 29
 30    let progress = example_progress.start(Step::Score);
 31
 32    progress.set_substatus("applying patches");
 33    let original_text = &example
 34        .prompt_inputs
 35        .as_ref()
 36        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 37        .content;
 38    let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
 39
 40    let expected_texts: Vec<String> = expected_patches_with_cursors
 41        .iter()
 42        .map(|(patch, _)| {
 43            apply_diff_to_string(patch, original_text)
 44                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 45        })
 46        .collect::<Result<Vec<_>, _>>()?;
 47
 48    // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
 49    // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
 50    // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
 51    // region to find where the hunk matched, then compute the expected cursor position.
 52    let old_editable_region = if let Some(p) = example.prompt.as_ref() {
 53        if matches!(
 54            p.provider,
 55            PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
 56        ) {
 57            Some(
 58                TeacherPrompt::extract_editable_region(&p.input)?
 59                    .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
 60            )
 61        } else {
 62            None
 63        }
 64    } else {
 65        None
 66    };
 67
 68    let zero_scores = ExampleScore {
 69        delta_chr_f: 0.0,
 70        braces_disbalance: 0,
 71        exact_lines_tp: 0,
 72        exact_lines_fp: 0,
 73        exact_lines_fn: 0,
 74        reversal_ratio: 0.0,
 75        cursor_distance: None,
 76        cursor_exact_match: None,
 77        wrong_editable_region: None,
 78        has_isolated_whitespace_changes: false,
 79    };
 80
 81    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 82    let cursor_path = example.spec.cursor_path.as_ref();
 83
 84    progress.set_substatus("computing metrics");
 85    let mut scores = vec![];
 86    for prediction in &example.predictions {
 87        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 88            parse_prediction_output(example, &prediction.actual_output, prediction.provider)
 89                .ok()
 90                .map(|(patch, _)| patch)
 91        });
 92
 93        let Some(actual_patch) = actual_patch else {
 94            scores.push(zero_scores.clone());
 95            continue;
 96        };
 97
 98        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 99            Ok(text) => text,
100            Err(_) => {
101                scores.push(zero_scores.clone());
102                continue;
103            }
104        };
105
106        let mut best_delta_chr_f = 0.0f32;
107        let mut best_expected_cursor: Option<usize> = None;
108        let mut best_patch_idx: Option<usize> = None;
109
110        for (idx, expected) in expected_texts.iter().enumerate() {
111            let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
112            if delta_chr_f > best_delta_chr_f {
113                best_delta_chr_f = delta_chr_f;
114                best_patch_idx = Some(idx);
115            }
116        }
117
118        if let Some(idx) = best_patch_idx {
119            // Get the raw cursor offset from the expected patch (relative to hunk new text)
120            let expected_cursor_in_patch = expected_patches_with_cursors
121                .get(idx)
122                .and_then(|(_, cursor)| *cursor);
123
124            // For Teacher prompts, we need to apply the patch to the editable region
125            // to find where the hunk matched, then compute the actual cursor position
126            if let (Some(editable_region), Some(cursor_in_patch)) =
127                (&old_editable_region, expected_cursor_in_patch)
128            {
129                let (patch, _) = &expected_patches_with_cursors[idx];
130                if let Ok((_, hunk_offset)) =
131                    apply_diff_to_string_with_hunk_offset(patch, editable_region)
132                {
133                    let hunk_start = hunk_offset.unwrap_or(0);
134                    best_expected_cursor = Some(hunk_start + cursor_in_patch);
135                }
136            } else {
137                // For non-Teacher prompts or if we can't compute, use raw offset
138                best_expected_cursor = expected_cursor_in_patch;
139            }
140        }
141
142        let disbalance_before = metrics::braces_disbalance(&original_text);
143        let disbalance_after = metrics::braces_disbalance(&actual_text);
144        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
145
146        // Compute exact lines match against best matching expected patch
147        let best_exact_lines = expected_patches_with_cursors
148            .iter()
149            .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
150            .max_by_key(|m| m.true_positives)
151            .unwrap_or_default();
152
153        // Compute reversal ratio
154        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
155            prompt_inputs,
156            &actual_text,
157            cursor_path,
158        );
159
160        // Compute cursor position metrics
161        let (cursor_distance, cursor_exact_match) =
162            compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
163
164        // Compute approximation of editable region correctness
165        let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
166
167        // Check for isolated whitespace changes.
168        let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
169            &actual_patch,
170            prediction.actual_cursor.as_ref(),
171        );
172
173        scores.push(ExampleScore {
174            delta_chr_f: best_delta_chr_f,
175            braces_disbalance,
176            exact_lines_tp: best_exact_lines.true_positives,
177            exact_lines_fp: best_exact_lines.false_positives,
178            exact_lines_fn: best_exact_lines.false_negatives,
179            reversal_ratio,
180            cursor_distance,
181            cursor_exact_match,
182            wrong_editable_region,
183            has_isolated_whitespace_changes,
184        });
185    }
186
187    example.score = scores;
188    Ok(())
189}
190
191fn compute_cursor_metrics(
192    expected_cursor_editable_region_offset: Option<usize>,
193    actual_cursor: Option<&ActualCursor>,
194) -> (Option<usize>, Option<bool>) {
195    match (expected_cursor_editable_region_offset, actual_cursor) {
196        (Some(expected), Some(actual)) => {
197            let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
198            let exact_match = distance == 0;
199            (Some(distance), Some(exact_match))
200        }
201        (None, None) => {
202            // Neither has cursor position - skip cursor scoring
203            (None, None)
204        }
205        (Some(_), None) | (None, Some(_)) => {
206            // Only one has cursor position - count as miss
207            (None, Some(false))
208        }
209    }
210}
211
212pub fn print_report(examples: &[Example]) {
213    use crate::metrics::ClassificationMetrics;
214
215    const LINE_WIDTH: usize = 101;
216    let separator = "".repeat(LINE_WIDTH);
217
218    println!("{}", separator);
219    println!(
220        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
221        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
222    );
223    println!("{}", separator);
224
225    let mut all_delta_chr_f_scores = Vec::new();
226    let mut all_reversal_ratios = Vec::new();
227    let mut braces_disbalance_sum: usize = 0;
228    let mut total_exact_lines = ClassificationMetrics::default();
229    let mut total_scores: usize = 0;
230    let mut qa_reverts_count: usize = 0;
231    let mut qa_reverts_total: usize = 0;
232    let mut qa_confidence_sum: u64 = 0;
233    let mut qa_confidence_count: usize = 0;
234    let mut cursor_exact_matches: usize = 0;
235    let mut cursor_total: usize = 0;
236    let mut cursor_distance_sum: usize = 0;
237    let mut cursor_distance_count: usize = 0;
238    let mut wrong_editable_region_count: usize = 0;
239    let mut wrong_editable_region_total: usize = 0;
240    let mut isolated_whitespace_count: usize = 0;
241
242    for example in examples {
243        for (score_idx, score) in example.score.iter().enumerate() {
244            let exact_lines = ClassificationMetrics {
245                true_positives: score.exact_lines_tp,
246                false_positives: score.exact_lines_fp,
247                false_negatives: score.exact_lines_fn,
248            };
249
250            // Get QA results for this prediction if available
251            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
252            let qa_reverts_str = qa_result
253                .and_then(|q| q.reverts_edits)
254                .map(|v| if v { "yes" } else { "no" })
255                .unwrap_or("-");
256            let qa_conf_str = qa_result
257                .and_then(|q| q.confidence)
258                .map(|v| format!("{}", v))
259                .unwrap_or("-".to_string());
260
261            // Format wrong editable region metric
262            let wrong_er_str = match score.wrong_editable_region {
263                Some(true) => "",
264                Some(false) => "",
265                None => "",
266            };
267
268            // Format cursor metric
269            let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
270                (Some(true), _) => "".to_string(),
271                (Some(false), Some(dist)) => format!("±{}", dist),
272                (Some(false), None) => "".to_string(),
273                (None, _) => "-".to_string(),
274            };
275
276            println!(
277                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
278                truncate_name(&example.spec.name, 40),
279                score.delta_chr_f,
280                score.braces_disbalance,
281                exact_lines.f1() * 100.0,
282                score.reversal_ratio * 100.0,
283                qa_reverts_str,
284                qa_conf_str,
285                cursor_str,
286                wrong_er_str
287            );
288
289            all_delta_chr_f_scores.push(score.delta_chr_f);
290            all_reversal_ratios.push(score.reversal_ratio);
291            total_scores += 1;
292            braces_disbalance_sum += score.braces_disbalance;
293            total_exact_lines.true_positives += score.exact_lines_tp;
294            total_exact_lines.false_positives += score.exact_lines_fp;
295            total_exact_lines.false_negatives += score.exact_lines_fn;
296
297            // Accumulate QA metrics
298            if let Some(qa) = qa_result {
299                if let Some(reverts) = qa.reverts_edits {
300                    qa_reverts_total += 1;
301                    if reverts {
302                        qa_reverts_count += 1;
303                    }
304                }
305                if let Some(conf) = qa.confidence {
306                    qa_confidence_sum += conf as u64;
307                    qa_confidence_count += 1;
308                }
309            }
310
311            // Accumulate wrong editable region metrics
312            if let Some(wrong) = score.wrong_editable_region {
313                wrong_editable_region_total += 1;
314                if wrong {
315                    wrong_editable_region_count += 1;
316                }
317            }
318
319            // Accumulate isolated whitespace metrics
320            if score.has_isolated_whitespace_changes {
321                isolated_whitespace_count += 1;
322            }
323
324            // Accumulate cursor metrics
325            if let Some(exact_match) = score.cursor_exact_match {
326                cursor_total += 1;
327                if exact_match {
328                    cursor_exact_matches += 1;
329                }
330            }
331            if let Some(dist) = score.cursor_distance {
332                cursor_distance_sum += dist;
333                cursor_distance_count += 1;
334            }
335        }
336    }
337
338    println!("{}", separator);
339
340    if !all_delta_chr_f_scores.is_empty() {
341        let avg_delta_chr_f: f32 =
342            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
343        let avg_reversal_ratio: f32 =
344            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
345        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
346
347        let qa_reverts_str = if qa_reverts_total > 0 {
348            format!(
349                "{:.1}%",
350                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
351            )
352        } else {
353            "-".to_string()
354        };
355        let qa_conf_str = if qa_confidence_count > 0 {
356            format!(
357                "{:.1}",
358                qa_confidence_sum as f32 / qa_confidence_count as f32
359            )
360        } else {
361            "-".to_string()
362        };
363        let cursor_str = if cursor_total > 0 {
364            format!(
365                "{:.0}%",
366                cursor_exact_matches as f32 / cursor_total as f32 * 100.0
367            )
368        } else {
369            "-".to_string()
370        };
371        let wrong_er_str = if wrong_editable_region_total > 0 {
372            format!(
373                "{:.2}%",
374                wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
375            )
376        } else {
377            "-".to_string()
378        };
379        let isolated_ws_str = if total_scores > 0 {
380            format!(
381                "{}/{} ({:.1}%)",
382                isolated_whitespace_count,
383                total_scores,
384                isolated_whitespace_count as f32 / total_scores as f32 * 100.0
385            )
386        } else {
387            "-".to_string()
388        };
389        let avg_cursor_distance = if cursor_distance_count > 0 {
390            Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
391        } else {
392            None
393        };
394
395        println!(
396            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
397            "TOTAL / AVERAGE",
398            avg_delta_chr_f,
399            braces_disbalance_avg,
400            total_exact_lines.f1() * 100.0,
401            avg_reversal_ratio * 100.0,
402            qa_reverts_str,
403            qa_conf_str,
404            cursor_str,
405            wrong_er_str
406        );
407        println!("{}", separator);
408
409        // Print additional cursor metrics if available
410        if let Some(avg_dist) = avg_cursor_distance {
411            println!(
412                "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
413                cursor_exact_matches,
414                cursor_total,
415                cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
416                avg_dist
417            );
418        }
419
420        // Print isolated whitespace metrics
421        if total_scores > 0 {
422            println!("Isolated whitespace changes: {}", isolated_ws_str);
423        }
424    }
425
426    println!("\n");
427}
428
/// Shortens `name` to at most `max_len` characters for display in a fixed-width column.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. Lengths are counted in `char`s rather than
/// bytes, so a multi-byte UTF-8 name is never cut inside a character (byte slicing would panic
/// there). If `max_len` is too small to hold the ellipsis, the name is cut to `max_len`
/// characters without one.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }
    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
436
437#[derive(Serialize)]
438pub struct SummaryJson {
439    pub total_examples: usize,
440    pub avg_delta_chr_f: f32,
441    pub avg_braces_disbalance: f32,
442    pub exact_lines_true_positives: usize,
443    pub exact_lines_false_positives: usize,
444    pub exact_lines_false_negatives: usize,
445    pub exact_lines_precision: f64,
446    pub exact_lines_recall: f64,
447    pub exact_lines_f1: f64,
448    pub avg_reversal_ratio: f32,
449    #[serde(skip_serializing_if = "Option::is_none")]
450    pub qa_avg_reverts_edits: Option<f32>,
451    #[serde(skip_serializing_if = "Option::is_none")]
452    pub qa_avg_confidence: Option<f32>,
453    #[serde(skip_serializing_if = "Option::is_none")]
454    pub cursor_exact_match_rate: Option<f32>,
455    #[serde(skip_serializing_if = "Option::is_none")]
456    pub cursor_avg_distance: Option<f32>,
457    #[serde(skip_serializing_if = "Option::is_none")]
458    pub cursor_total_evaluated: Option<usize>,
459    #[serde(skip_serializing_if = "Option::is_none")]
460    pub wrong_editable_region_rate: Option<f32>,
461    pub isolated_whitespace_rate: Option<f32>,
462}
463
464pub fn compute_summary(examples: &[Example]) -> SummaryJson {
465    use crate::metrics::ClassificationMetrics;
466
467    let mut all_delta_chr_f_scores = Vec::new();
468    let mut all_reversal_ratios = Vec::new();
469    let mut braces_disbalance_sum: usize = 0;
470    let mut total_exact_lines = ClassificationMetrics::default();
471    let mut total_scores: usize = 0;
472    let mut qa_reverts_count: usize = 0;
473    let mut qa_reverts_total: usize = 0;
474    let mut qa_confidence_sum: u64 = 0;
475    let mut qa_confidence_count: usize = 0;
476    let mut cursor_exact_matches: usize = 0;
477    let mut cursor_total: usize = 0;
478    let mut cursor_distance_sum: usize = 0;
479    let mut cursor_distance_count: usize = 0;
480    let mut wrong_editable_region_count: usize = 0;
481    let mut wrong_editable_region_total: usize = 0;
482    let mut isolated_whitespace_count: usize = 0;
483
484    for example in examples {
485        for (score_idx, score) in example.score.iter().enumerate() {
486            all_delta_chr_f_scores.push(score.delta_chr_f);
487            all_reversal_ratios.push(score.reversal_ratio);
488            total_scores += 1;
489            braces_disbalance_sum += score.braces_disbalance;
490            total_exact_lines.true_positives += score.exact_lines_tp;
491            total_exact_lines.false_positives += score.exact_lines_fp;
492            total_exact_lines.false_negatives += score.exact_lines_fn;
493
494            // Accumulate QA metrics
495            if let Some(Some(qa)) = example.qa.get(score_idx) {
496                if let Some(reverts) = qa.reverts_edits {
497                    qa_reverts_total += 1;
498                    if reverts {
499                        qa_reverts_count += 1;
500                    }
501                }
502                if let Some(conf) = qa.confidence {
503                    qa_confidence_sum += conf as u64;
504                    qa_confidence_count += 1;
505                }
506            }
507
508            // Accumulate wrong editable region metrics
509            if let Some(wrong) = score.wrong_editable_region {
510                wrong_editable_region_total += 1;
511                if wrong {
512                    wrong_editable_region_count += 1;
513                }
514            }
515
516            // Accumulate isolated whitespace metrics
517            if score.has_isolated_whitespace_changes {
518                isolated_whitespace_count += 1;
519            }
520
521            // Accumulate cursor metrics
522            if let Some(exact_match) = score.cursor_exact_match {
523                cursor_total += 1;
524                if exact_match {
525                    cursor_exact_matches += 1;
526                }
527            }
528            if let Some(dist) = score.cursor_distance {
529                cursor_distance_sum += dist;
530                cursor_distance_count += 1;
531            }
532        }
533    }
534
535    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
536        0.0
537    } else {
538        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
539    };
540
541    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
542        0.0
543    } else {
544        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
545    };
546
547    let avg_braces_disbalance = if total_scores == 0 {
548        0.0
549    } else {
550        braces_disbalance_sum as f32 / total_scores as f32
551    };
552
553    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
554        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
555    } else {
556        None
557    };
558
559    let qa_avg_confidence = if qa_confidence_count > 0 {
560        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
561    } else {
562        None
563    };
564
565    let cursor_exact_match_rate = if cursor_total > 0 {
566        Some(cursor_exact_matches as f32 / cursor_total as f32)
567    } else {
568        None
569    };
570
571    let cursor_avg_distance = if cursor_distance_count > 0 {
572        Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
573    } else {
574        None
575    };
576
577    let cursor_total_evaluated = if cursor_total > 0 {
578        Some(cursor_total)
579    } else {
580        None
581    };
582
583    let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
584        Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
585    } else {
586        None
587    };
588
589    let isolated_whitespace_rate = if total_scores > 0 {
590        Some(isolated_whitespace_count as f32 / total_scores as f32)
591    } else {
592        None
593    };
594
595    SummaryJson {
596        total_examples: total_scores,
597        avg_delta_chr_f,
598        avg_braces_disbalance,
599        exact_lines_true_positives: total_exact_lines.true_positives,
600        exact_lines_false_positives: total_exact_lines.false_positives,
601        exact_lines_false_negatives: total_exact_lines.false_negatives,
602        exact_lines_precision: total_exact_lines.precision(),
603        exact_lines_recall: total_exact_lines.recall(),
604        exact_lines_f1: total_exact_lines.f1(),
605        avg_reversal_ratio,
606        qa_avg_reverts_edits,
607        qa_avg_confidence,
608        cursor_exact_match_rate,
609        cursor_avg_distance,
610        cursor_total_evaluated,
611        wrong_editable_region_rate,
612        isolated_whitespace_rate,
613    }
614}
615
616pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
617    let summary = compute_summary(examples);
618    let file = File::create(path)
619        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
620    let writer = BufWriter::new(file);
621    serde_json::to_writer_pretty(writer, &summary)
622        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
623    eprintln!("Wrote summary JSON to: {}", path.display());
624    Ok(())
625}