1use crate::{
2 PredictArgs, PredictionProvider,
3 example::{ActualCursor, Example, ExampleScore},
4 format_prompt::TeacherPrompt,
5 headless::EpAppState,
6 metrics,
7 parse_output::parse_prediction_output,
8 predict::run_prediction,
9 progress::{ExampleProgress, Step},
10 reversal_tracking,
11};
12use anyhow::Context as _;
13use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
14use gpui::AsyncApp;
15use serde::Serialize;
16use std::fs::File;
17use std::io::BufWriter;
18use std::path::Path;
19use std::sync::Arc;
20
21pub async fn run_scoring(
22 example: &mut Example,
23 args: &PredictArgs,
24 app_state: Arc<EpAppState>,
25 example_progress: &ExampleProgress,
26 cx: AsyncApp,
27) -> anyhow::Result<()> {
28 run_prediction(example, args, app_state, example_progress, cx).await?;
29
30 let progress = example_progress.start(Step::Score);
31
32 progress.set_substatus("applying patches");
33 let original_text = &example
34 .prompt_inputs
35 .as_ref()
36 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
37 .content;
38 let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
39
40 let expected_texts: Vec<String> = expected_patches_with_cursors
41 .iter()
42 .map(|(patch, _)| {
43 apply_diff_to_string(patch, original_text)
44 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
45 })
46 .collect::<Result<Vec<_>, _>>()?;
47
48 // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
49 // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
50 // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
51 // region to find where the hunk matched, then compute the expected cursor position.
52 let old_editable_region = if let Some(p) = example.prompt.as_ref() {
53 if matches!(
54 p.provider,
55 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
56 ) {
57 Some(
58 TeacherPrompt::extract_editable_region(&p.input)?
59 .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
60 )
61 } else {
62 None
63 }
64 } else {
65 None
66 };
67
68 let zero_scores = ExampleScore {
69 delta_chr_f: 0.0,
70 braces_disbalance: 0,
71 exact_lines_tp: 0,
72 exact_lines_fp: 0,
73 exact_lines_fn: 0,
74 reversal_ratio: 0.0,
75 cursor_distance: None,
76 cursor_exact_match: None,
77 wrong_editable_region: None,
78 has_isolated_whitespace_changes: false,
79 };
80
81 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
82 let cursor_path = example.spec.cursor_path.as_ref();
83
84 progress.set_substatus("computing metrics");
85 let mut scores = vec![];
86 for prediction in &example.predictions {
87 let actual_patch = prediction.actual_patch.clone().or_else(|| {
88 parse_prediction_output(example, &prediction.actual_output, prediction.provider)
89 .ok()
90 .map(|(patch, _)| patch)
91 });
92
93 let Some(actual_patch) = actual_patch else {
94 scores.push(zero_scores.clone());
95 continue;
96 };
97
98 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
99 Ok(text) => text,
100 Err(_) => {
101 scores.push(zero_scores.clone());
102 continue;
103 }
104 };
105
106 let mut best_delta_chr_f = 0.0f32;
107 let mut best_expected_cursor: Option<usize> = None;
108 let mut best_patch_idx: Option<usize> = None;
109
110 for (idx, expected) in expected_texts.iter().enumerate() {
111 let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
112 if delta_chr_f > best_delta_chr_f {
113 best_delta_chr_f = delta_chr_f;
114 best_patch_idx = Some(idx);
115 }
116 }
117
118 if let Some(idx) = best_patch_idx {
119 // Get the raw cursor offset from the expected patch (relative to hunk new text)
120 let expected_cursor_in_patch = expected_patches_with_cursors
121 .get(idx)
122 .and_then(|(_, cursor)| *cursor);
123
124 // For Teacher prompts, we need to apply the patch to the editable region
125 // to find where the hunk matched, then compute the actual cursor position
126 if let (Some(editable_region), Some(cursor_in_patch)) =
127 (&old_editable_region, expected_cursor_in_patch)
128 {
129 let (patch, _) = &expected_patches_with_cursors[idx];
130 if let Ok((_, hunk_offset)) =
131 apply_diff_to_string_with_hunk_offset(patch, editable_region)
132 {
133 let hunk_start = hunk_offset.unwrap_or(0);
134 best_expected_cursor = Some(hunk_start + cursor_in_patch);
135 }
136 } else {
137 // For non-Teacher prompts or if we can't compute, use raw offset
138 best_expected_cursor = expected_cursor_in_patch;
139 }
140 }
141
142 let disbalance_before = metrics::braces_disbalance(&original_text);
143 let disbalance_after = metrics::braces_disbalance(&actual_text);
144 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
145
146 // Compute exact lines match against best matching expected patch
147 let best_exact_lines = expected_patches_with_cursors
148 .iter()
149 .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
150 .max_by_key(|m| m.true_positives)
151 .unwrap_or_default();
152
153 // Compute reversal ratio
154 let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
155 prompt_inputs,
156 &actual_text,
157 cursor_path,
158 );
159
160 // Compute cursor position metrics
161 let (cursor_distance, cursor_exact_match) =
162 compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor.as_ref());
163
164 // Compute approximation of editable region correctness
165 let wrong_editable_region = Some(!metrics::is_editable_region_correct(&actual_patch));
166
167 // Check for isolated whitespace changes.
168 let has_isolated_whitespace_changes = metrics::has_isolated_whitespace_changes(
169 &actual_patch,
170 prediction.actual_cursor.as_ref(),
171 );
172
173 scores.push(ExampleScore {
174 delta_chr_f: best_delta_chr_f,
175 braces_disbalance,
176 exact_lines_tp: best_exact_lines.true_positives,
177 exact_lines_fp: best_exact_lines.false_positives,
178 exact_lines_fn: best_exact_lines.false_negatives,
179 reversal_ratio,
180 cursor_distance,
181 cursor_exact_match,
182 wrong_editable_region,
183 has_isolated_whitespace_changes,
184 });
185 }
186
187 example.score = scores;
188 Ok(())
189}
190
191fn compute_cursor_metrics(
192 expected_cursor_editable_region_offset: Option<usize>,
193 actual_cursor: Option<&ActualCursor>,
194) -> (Option<usize>, Option<bool>) {
195 match (expected_cursor_editable_region_offset, actual_cursor) {
196 (Some(expected), Some(actual)) => {
197 let distance = expected.abs_diff(actual.editable_region_offset.unwrap_or_default());
198 let exact_match = distance == 0;
199 (Some(distance), Some(exact_match))
200 }
201 (None, None) => {
202 // Neither has cursor position - skip cursor scoring
203 (None, None)
204 }
205 (Some(_), None) | (None, Some(_)) => {
206 // Only one has cursor position - count as miss
207 (None, Some(false))
208 }
209 }
210}
211
212pub fn print_report(examples: &[Example]) {
213 use crate::metrics::ClassificationMetrics;
214
215 const LINE_WIDTH: usize = 101;
216 let separator = "─".repeat(LINE_WIDTH);
217
218 println!("{}", separator);
219 println!(
220 "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6} {:>5}",
221 "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor", "WrgER"
222 );
223 println!("{}", separator);
224
225 let mut all_delta_chr_f_scores = Vec::new();
226 let mut all_reversal_ratios = Vec::new();
227 let mut braces_disbalance_sum: usize = 0;
228 let mut total_exact_lines = ClassificationMetrics::default();
229 let mut total_scores: usize = 0;
230 let mut qa_reverts_count: usize = 0;
231 let mut qa_reverts_total: usize = 0;
232 let mut qa_confidence_sum: u64 = 0;
233 let mut qa_confidence_count: usize = 0;
234 let mut cursor_exact_matches: usize = 0;
235 let mut cursor_total: usize = 0;
236 let mut cursor_distance_sum: usize = 0;
237 let mut cursor_distance_count: usize = 0;
238 let mut wrong_editable_region_count: usize = 0;
239 let mut wrong_editable_region_total: usize = 0;
240 let mut isolated_whitespace_count: usize = 0;
241
242 for example in examples {
243 for (score_idx, score) in example.score.iter().enumerate() {
244 let exact_lines = ClassificationMetrics {
245 true_positives: score.exact_lines_tp,
246 false_positives: score.exact_lines_fp,
247 false_negatives: score.exact_lines_fn,
248 };
249
250 // Get QA results for this prediction if available
251 let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
252 let qa_reverts_str = qa_result
253 .and_then(|q| q.reverts_edits)
254 .map(|v| if v { "yes" } else { "no" })
255 .unwrap_or("-");
256 let qa_conf_str = qa_result
257 .and_then(|q| q.confidence)
258 .map(|v| format!("{}", v))
259 .unwrap_or("-".to_string());
260
261 // Format wrong editable region metric
262 let wrong_er_str = match score.wrong_editable_region {
263 Some(true) => "✗",
264 Some(false) => "",
265 None => "",
266 };
267
268 // Format cursor metric
269 let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
270 (Some(true), _) => "✓".to_string(),
271 (Some(false), Some(dist)) => format!("±{}", dist),
272 (Some(false), None) => "✗".to_string(),
273 (None, _) => "-".to_string(),
274 };
275
276 println!(
277 "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
278 truncate_name(&example.spec.name, 40),
279 score.delta_chr_f,
280 score.braces_disbalance,
281 exact_lines.f1() * 100.0,
282 score.reversal_ratio * 100.0,
283 qa_reverts_str,
284 qa_conf_str,
285 cursor_str,
286 wrong_er_str
287 );
288
289 all_delta_chr_f_scores.push(score.delta_chr_f);
290 all_reversal_ratios.push(score.reversal_ratio);
291 total_scores += 1;
292 braces_disbalance_sum += score.braces_disbalance;
293 total_exact_lines.true_positives += score.exact_lines_tp;
294 total_exact_lines.false_positives += score.exact_lines_fp;
295 total_exact_lines.false_negatives += score.exact_lines_fn;
296
297 // Accumulate QA metrics
298 if let Some(qa) = qa_result {
299 if let Some(reverts) = qa.reverts_edits {
300 qa_reverts_total += 1;
301 if reverts {
302 qa_reverts_count += 1;
303 }
304 }
305 if let Some(conf) = qa.confidence {
306 qa_confidence_sum += conf as u64;
307 qa_confidence_count += 1;
308 }
309 }
310
311 // Accumulate wrong editable region metrics
312 if let Some(wrong) = score.wrong_editable_region {
313 wrong_editable_region_total += 1;
314 if wrong {
315 wrong_editable_region_count += 1;
316 }
317 }
318
319 // Accumulate isolated whitespace metrics
320 if score.has_isolated_whitespace_changes {
321 isolated_whitespace_count += 1;
322 }
323
324 // Accumulate cursor metrics
325 if let Some(exact_match) = score.cursor_exact_match {
326 cursor_total += 1;
327 if exact_match {
328 cursor_exact_matches += 1;
329 }
330 }
331 if let Some(dist) = score.cursor_distance {
332 cursor_distance_sum += dist;
333 cursor_distance_count += 1;
334 }
335 }
336 }
337
338 println!("{}", separator);
339
340 if !all_delta_chr_f_scores.is_empty() {
341 let avg_delta_chr_f: f32 =
342 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
343 let avg_reversal_ratio: f32 =
344 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
345 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
346
347 let qa_reverts_str = if qa_reverts_total > 0 {
348 format!(
349 "{:.1}%",
350 qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
351 )
352 } else {
353 "-".to_string()
354 };
355 let qa_conf_str = if qa_confidence_count > 0 {
356 format!(
357 "{:.1}",
358 qa_confidence_sum as f32 / qa_confidence_count as f32
359 )
360 } else {
361 "-".to_string()
362 };
363 let cursor_str = if cursor_total > 0 {
364 format!(
365 "{:.0}%",
366 cursor_exact_matches as f32 / cursor_total as f32 * 100.0
367 )
368 } else {
369 "-".to_string()
370 };
371 let wrong_er_str = if wrong_editable_region_total > 0 {
372 format!(
373 "{:.2}%",
374 wrong_editable_region_count as f32 / wrong_editable_region_total as f32 * 100.0
375 )
376 } else {
377 "-".to_string()
378 };
379 let isolated_ws_str = if total_scores > 0 {
380 format!(
381 "{}/{} ({:.1}%)",
382 isolated_whitespace_count,
383 total_scores,
384 isolated_whitespace_count as f32 / total_scores as f32 * 100.0
385 )
386 } else {
387 "-".to_string()
388 };
389 let avg_cursor_distance = if cursor_distance_count > 0 {
390 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
391 } else {
392 None
393 };
394
395 println!(
396 "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6} {:>5}",
397 "TOTAL / AVERAGE",
398 avg_delta_chr_f,
399 braces_disbalance_avg,
400 total_exact_lines.f1() * 100.0,
401 avg_reversal_ratio * 100.0,
402 qa_reverts_str,
403 qa_conf_str,
404 cursor_str,
405 wrong_er_str
406 );
407 println!("{}", separator);
408
409 // Print additional cursor metrics if available
410 if let Some(avg_dist) = avg_cursor_distance {
411 println!(
412 "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
413 cursor_exact_matches,
414 cursor_total,
415 cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
416 avg_dist
417 );
418 }
419
420 // Print isolated whitespace metrics
421 if total_scores > 0 {
422 println!("Isolated whitespace changes: {}", isolated_ws_str);
423 }
424 }
425
426 println!("\n");
427}
428
/// Shortens `name` to at most `max_len` characters for display in a fixed-width column.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. When `max_len` is too small to hold the
/// ellipsis, the name is simply cut to `max_len` characters.
///
/// Lengths are counted in `char`s rather than bytes, so multi-byte UTF-8 names are never
/// sliced in the middle of a character, and the result matches how `{:<N}` pads.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }
    if max_len < ELLIPSIS.len() {
        // No room for the ellipsis: a plain cut avoids the subtraction underflow.
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
436
437#[derive(Serialize)]
438pub struct SummaryJson {
439 pub total_examples: usize,
440 pub avg_delta_chr_f: f32,
441 pub avg_braces_disbalance: f32,
442 pub exact_lines_true_positives: usize,
443 pub exact_lines_false_positives: usize,
444 pub exact_lines_false_negatives: usize,
445 pub exact_lines_precision: f64,
446 pub exact_lines_recall: f64,
447 pub exact_lines_f1: f64,
448 pub avg_reversal_ratio: f32,
449 #[serde(skip_serializing_if = "Option::is_none")]
450 pub qa_avg_reverts_edits: Option<f32>,
451 #[serde(skip_serializing_if = "Option::is_none")]
452 pub qa_avg_confidence: Option<f32>,
453 #[serde(skip_serializing_if = "Option::is_none")]
454 pub cursor_exact_match_rate: Option<f32>,
455 #[serde(skip_serializing_if = "Option::is_none")]
456 pub cursor_avg_distance: Option<f32>,
457 #[serde(skip_serializing_if = "Option::is_none")]
458 pub cursor_total_evaluated: Option<usize>,
459 #[serde(skip_serializing_if = "Option::is_none")]
460 pub wrong_editable_region_rate: Option<f32>,
461 pub isolated_whitespace_rate: Option<f32>,
462}
463
464pub fn compute_summary(examples: &[Example]) -> SummaryJson {
465 use crate::metrics::ClassificationMetrics;
466
467 let mut all_delta_chr_f_scores = Vec::new();
468 let mut all_reversal_ratios = Vec::new();
469 let mut braces_disbalance_sum: usize = 0;
470 let mut total_exact_lines = ClassificationMetrics::default();
471 let mut total_scores: usize = 0;
472 let mut qa_reverts_count: usize = 0;
473 let mut qa_reverts_total: usize = 0;
474 let mut qa_confidence_sum: u64 = 0;
475 let mut qa_confidence_count: usize = 0;
476 let mut cursor_exact_matches: usize = 0;
477 let mut cursor_total: usize = 0;
478 let mut cursor_distance_sum: usize = 0;
479 let mut cursor_distance_count: usize = 0;
480 let mut wrong_editable_region_count: usize = 0;
481 let mut wrong_editable_region_total: usize = 0;
482 let mut isolated_whitespace_count: usize = 0;
483
484 for example in examples {
485 for (score_idx, score) in example.score.iter().enumerate() {
486 all_delta_chr_f_scores.push(score.delta_chr_f);
487 all_reversal_ratios.push(score.reversal_ratio);
488 total_scores += 1;
489 braces_disbalance_sum += score.braces_disbalance;
490 total_exact_lines.true_positives += score.exact_lines_tp;
491 total_exact_lines.false_positives += score.exact_lines_fp;
492 total_exact_lines.false_negatives += score.exact_lines_fn;
493
494 // Accumulate QA metrics
495 if let Some(Some(qa)) = example.qa.get(score_idx) {
496 if let Some(reverts) = qa.reverts_edits {
497 qa_reverts_total += 1;
498 if reverts {
499 qa_reverts_count += 1;
500 }
501 }
502 if let Some(conf) = qa.confidence {
503 qa_confidence_sum += conf as u64;
504 qa_confidence_count += 1;
505 }
506 }
507
508 // Accumulate wrong editable region metrics
509 if let Some(wrong) = score.wrong_editable_region {
510 wrong_editable_region_total += 1;
511 if wrong {
512 wrong_editable_region_count += 1;
513 }
514 }
515
516 // Accumulate isolated whitespace metrics
517 if score.has_isolated_whitespace_changes {
518 isolated_whitespace_count += 1;
519 }
520
521 // Accumulate cursor metrics
522 if let Some(exact_match) = score.cursor_exact_match {
523 cursor_total += 1;
524 if exact_match {
525 cursor_exact_matches += 1;
526 }
527 }
528 if let Some(dist) = score.cursor_distance {
529 cursor_distance_sum += dist;
530 cursor_distance_count += 1;
531 }
532 }
533 }
534
535 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
536 0.0
537 } else {
538 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
539 };
540
541 let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
542 0.0
543 } else {
544 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
545 };
546
547 let avg_braces_disbalance = if total_scores == 0 {
548 0.0
549 } else {
550 braces_disbalance_sum as f32 / total_scores as f32
551 };
552
553 let qa_avg_reverts_edits = if qa_reverts_total > 0 {
554 Some(qa_reverts_count as f32 / qa_reverts_total as f32)
555 } else {
556 None
557 };
558
559 let qa_avg_confidence = if qa_confidence_count > 0 {
560 Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
561 } else {
562 None
563 };
564
565 let cursor_exact_match_rate = if cursor_total > 0 {
566 Some(cursor_exact_matches as f32 / cursor_total as f32)
567 } else {
568 None
569 };
570
571 let cursor_avg_distance = if cursor_distance_count > 0 {
572 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
573 } else {
574 None
575 };
576
577 let cursor_total_evaluated = if cursor_total > 0 {
578 Some(cursor_total)
579 } else {
580 None
581 };
582
583 let wrong_editable_region_rate = if wrong_editable_region_total > 0 {
584 Some(wrong_editable_region_count as f32 / wrong_editable_region_total as f32)
585 } else {
586 None
587 };
588
589 let isolated_whitespace_rate = if total_scores > 0 {
590 Some(isolated_whitespace_count as f32 / total_scores as f32)
591 } else {
592 None
593 };
594
595 SummaryJson {
596 total_examples: total_scores,
597 avg_delta_chr_f,
598 avg_braces_disbalance,
599 exact_lines_true_positives: total_exact_lines.true_positives,
600 exact_lines_false_positives: total_exact_lines.false_positives,
601 exact_lines_false_negatives: total_exact_lines.false_negatives,
602 exact_lines_precision: total_exact_lines.precision(),
603 exact_lines_recall: total_exact_lines.recall(),
604 exact_lines_f1: total_exact_lines.f1(),
605 avg_reversal_ratio,
606 qa_avg_reverts_edits,
607 qa_avg_confidence,
608 cursor_exact_match_rate,
609 cursor_avg_distance,
610 cursor_total_evaluated,
611 wrong_editable_region_rate,
612 isolated_whitespace_rate,
613 }
614}
615
616pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
617 let summary = compute_summary(examples);
618 let file = File::create(path)
619 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
620 let writer = BufWriter::new(file);
621 serde_json::to_writer_pretty(writer, &summary)
622 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
623 eprintln!("Wrote summary JSON to: {}", path.display());
624 Ok(())
625}