1use crate::{
2 PredictArgs, PredictionProvider,
3 example::{Example, ExampleScore},
4 format_prompt::TeacherPrompt,
5 headless::EpAppState,
6 metrics,
7 parse_output::parse_prediction_output,
8 predict::run_prediction,
9 progress::{ExampleProgress, Step},
10 reversal_tracking,
11};
12use anyhow::Context as _;
13use edit_prediction::udiff::{apply_diff_to_string, apply_diff_to_string_with_hunk_offset};
14use gpui::AsyncApp;
15use serde::Serialize;
16use std::fs::File;
17use std::io::BufWriter;
18use std::path::Path;
19use std::sync::Arc;
20
21pub async fn run_scoring(
22 example: &mut Example,
23 args: &PredictArgs,
24 app_state: Arc<EpAppState>,
25 example_progress: &ExampleProgress,
26 cx: AsyncApp,
27) -> anyhow::Result<()> {
28 run_prediction(example, args, app_state, example_progress, cx).await?;
29
30 let progress = example_progress.start(Step::Score);
31
32 progress.set_substatus("applying patches");
33 let original_text = &example
34 .prompt_inputs
35 .as_ref()
36 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
37 .content;
38 let expected_patches_with_cursors = example.spec.expected_patches_with_cursor_positions();
39
40 let expected_texts: Vec<String> = expected_patches_with_cursors
41 .iter()
42 .map(|(patch, _)| {
43 apply_diff_to_string(patch, original_text)
44 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
45 })
46 .collect::<Result<Vec<_>, _>>()?;
47
48 // For Teacher prompts, we need to extract the editable region to properly compute cursor offsets.
49 // The actual_cursor_offset from Teacher is relative to the editable region, while the expected
50 // cursor from the patch is relative to the hunk. We need to apply the patch to the editable
51 // region to find where the hunk matched, then compute the expected cursor position.
52 let old_editable_region = if let Some(p) = example.prompt.as_ref() {
53 if matches!(
54 p.provider,
55 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_)
56 ) {
57 Some(
58 TeacherPrompt::extract_editable_region(&p.input)?
59 .replace(TeacherPrompt::USER_CURSOR_MARKER, ""),
60 )
61 } else {
62 None
63 }
64 } else {
65 None
66 };
67
68 let zero_scores = ExampleScore {
69 delta_chr_f: 0.0,
70 braces_disbalance: 0,
71 exact_lines_tp: 0,
72 exact_lines_fp: 0,
73 exact_lines_fn: 0,
74 reversal_ratio: 0.0,
75 cursor_distance: None,
76 cursor_exact_match: None,
77 };
78
79 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
80 let cursor_path = example.spec.cursor_path.as_ref();
81
82 progress.set_substatus("computing metrics");
83 let mut scores = vec![];
84 for prediction in &example.predictions {
85 let actual_patch = prediction.actual_patch.clone().or_else(|| {
86 parse_prediction_output(example, &prediction.actual_output, prediction.provider)
87 .ok()
88 .map(|(patch, _)| patch)
89 });
90
91 let Some(actual_patch) = actual_patch else {
92 scores.push(zero_scores.clone());
93 continue;
94 };
95
96 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
97 Ok(text) => text,
98 Err(_) => {
99 scores.push(zero_scores.clone());
100 continue;
101 }
102 };
103
104 let mut best_delta_chr_f = 0.0f32;
105 let mut best_expected_cursor: Option<usize> = None;
106 let mut best_patch_idx: Option<usize> = None;
107
108 for (idx, expected) in expected_texts.iter().enumerate() {
109 let delta_chr_f = metrics::delta_chr_f(original_text, expected, &actual_text) as f32;
110 if delta_chr_f > best_delta_chr_f {
111 best_delta_chr_f = delta_chr_f;
112 best_patch_idx = Some(idx);
113 }
114 }
115
116 if let Some(idx) = best_patch_idx {
117 // Get the raw cursor offset from the expected patch (relative to hunk new text)
118 let expected_cursor_in_patch = expected_patches_with_cursors
119 .get(idx)
120 .and_then(|(_, cursor)| *cursor);
121
122 // For Teacher prompts, we need to apply the patch to the editable region
123 // to find where the hunk matched, then compute the actual cursor position
124 if let (Some(editable_region), Some(cursor_in_patch)) =
125 (&old_editable_region, expected_cursor_in_patch)
126 {
127 let (patch, _) = &expected_patches_with_cursors[idx];
128 if let Ok((_, hunk_offset)) =
129 apply_diff_to_string_with_hunk_offset(patch, editable_region)
130 {
131 let hunk_start = hunk_offset.unwrap_or(0);
132 best_expected_cursor = Some(hunk_start + cursor_in_patch);
133 }
134 } else {
135 // For non-Teacher prompts or if we can't compute, use raw offset
136 best_expected_cursor = expected_cursor_in_patch;
137 }
138 }
139
140 let disbalance_before = metrics::braces_disbalance(&original_text);
141 let disbalance_after = metrics::braces_disbalance(&actual_text);
142 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
143 if braces_disbalance > 0 {
144 std::fs::write(
145 "/tmp/unbalanced-count.before",
146 disbalance_before.to_string(),
147 )
148 .ok();
149 std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
150 std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
151 std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
152 }
153
154 // Compute exact lines match against best matching expected patch
155 let best_exact_lines = expected_patches_with_cursors
156 .iter()
157 .map(|(expected_patch, _)| metrics::exact_lines_match(expected_patch, &actual_patch))
158 .max_by_key(|m| m.true_positives)
159 .unwrap_or_default();
160
161 // Compute reversal ratio
162 let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
163 prompt_inputs,
164 &actual_text,
165 cursor_path,
166 );
167
168 // Compute cursor position metrics
169 let (cursor_distance, cursor_exact_match) =
170 compute_cursor_metrics(best_expected_cursor, prediction.actual_cursor_offset);
171
172 scores.push(ExampleScore {
173 delta_chr_f: best_delta_chr_f,
174 braces_disbalance,
175 exact_lines_tp: best_exact_lines.true_positives,
176 exact_lines_fp: best_exact_lines.false_positives,
177 exact_lines_fn: best_exact_lines.false_negatives,
178 reversal_ratio,
179 cursor_distance,
180 cursor_exact_match,
181 });
182 }
183
184 example.score = scores;
185 Ok(())
186}
187
/// Compares an expected cursor offset with a predicted one.
///
/// Returns `(distance, exact_match)`:
/// - both offsets known: their absolute distance and whether they are equal;
/// - neither known: `(None, None)`, so the prediction is left out of cursor scoring;
/// - only one known: `(None, Some(false))`, which counts as a miss without a distance.
fn compute_cursor_metrics(
    expected_cursor: Option<usize>,
    actual_cursor: Option<usize>,
) -> (Option<usize>, Option<bool>) {
    if let Some((expected, actual)) = expected_cursor.zip(actual_cursor) {
        return (Some(expected.abs_diff(actual)), Some(expected == actual));
    }

    // At most one side is known here: a lone cursor is a miss, two missing cursors are unscored.
    let any_cursor_known = expected_cursor.is_some() || actual_cursor.is_some();
    (None, any_cursor_known.then_some(false))
}
208
209pub fn print_report(examples: &[Example]) {
210 use crate::metrics::ClassificationMetrics;
211
212 const LINE_WIDTH: usize = 94;
213 let separator = "─".repeat(LINE_WIDTH);
214
215 println!("{}", separator);
216 println!(
217 "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7} {:>6}",
218 "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf", "Cursor"
219 );
220 println!("{}", separator);
221
222 let mut all_delta_chr_f_scores = Vec::new();
223 let mut all_reversal_ratios = Vec::new();
224 let mut braces_disbalance_sum: usize = 0;
225 let mut total_exact_lines = ClassificationMetrics::default();
226 let mut total_scores: usize = 0;
227 let mut qa_reverts_count: usize = 0;
228 let mut qa_reverts_total: usize = 0;
229 let mut qa_confidence_sum: u64 = 0;
230 let mut qa_confidence_count: usize = 0;
231 let mut cursor_exact_matches: usize = 0;
232 let mut cursor_total: usize = 0;
233 let mut cursor_distance_sum: usize = 0;
234 let mut cursor_distance_count: usize = 0;
235
236 for example in examples {
237 for (score_idx, score) in example.score.iter().enumerate() {
238 let exact_lines = ClassificationMetrics {
239 true_positives: score.exact_lines_tp,
240 false_positives: score.exact_lines_fp,
241 false_negatives: score.exact_lines_fn,
242 };
243
244 // Get QA results for this prediction if available
245 let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
246 let qa_reverts_str = qa_result
247 .and_then(|q| q.reverts_edits)
248 .map(|v| if v { "yes" } else { "no" })
249 .unwrap_or("-");
250 let qa_conf_str = qa_result
251 .and_then(|q| q.confidence)
252 .map(|v| format!("{}", v))
253 .unwrap_or("-".to_string());
254
255 // Format cursor metric
256 let cursor_str = match (score.cursor_exact_match, score.cursor_distance) {
257 (Some(true), _) => "✓".to_string(),
258 (Some(false), Some(dist)) => format!("±{}", dist),
259 (Some(false), None) => "✗".to_string(),
260 (None, _) => "-".to_string(),
261 };
262
263 println!(
264 "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
265 truncate_name(&example.spec.name, 40),
266 score.delta_chr_f,
267 score.braces_disbalance,
268 exact_lines.f1() * 100.0,
269 score.reversal_ratio * 100.0,
270 qa_reverts_str,
271 qa_conf_str,
272 cursor_str
273 );
274
275 all_delta_chr_f_scores.push(score.delta_chr_f);
276 all_reversal_ratios.push(score.reversal_ratio);
277 total_scores += 1;
278 braces_disbalance_sum += score.braces_disbalance;
279 total_exact_lines.true_positives += score.exact_lines_tp;
280 total_exact_lines.false_positives += score.exact_lines_fp;
281 total_exact_lines.false_negatives += score.exact_lines_fn;
282
283 // Accumulate QA metrics
284 if let Some(qa) = qa_result {
285 if let Some(reverts) = qa.reverts_edits {
286 qa_reverts_total += 1;
287 if reverts {
288 qa_reverts_count += 1;
289 }
290 }
291 if let Some(conf) = qa.confidence {
292 qa_confidence_sum += conf as u64;
293 qa_confidence_count += 1;
294 }
295 }
296
297 // Accumulate cursor metrics
298 if let Some(exact_match) = score.cursor_exact_match {
299 cursor_total += 1;
300 if exact_match {
301 cursor_exact_matches += 1;
302 }
303 }
304 if let Some(dist) = score.cursor_distance {
305 cursor_distance_sum += dist;
306 cursor_distance_count += 1;
307 }
308 }
309 }
310
311 println!("{}", separator);
312
313 if !all_delta_chr_f_scores.is_empty() {
314 let avg_delta_chr_f: f32 =
315 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
316 let avg_reversal_ratio: f32 =
317 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
318 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
319
320 let qa_reverts_str = if qa_reverts_total > 0 {
321 format!(
322 "{:.1}%",
323 qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
324 )
325 } else {
326 "-".to_string()
327 };
328 let qa_conf_str = if qa_confidence_count > 0 {
329 format!(
330 "{:.1}",
331 qa_confidence_sum as f32 / qa_confidence_count as f32
332 )
333 } else {
334 "-".to_string()
335 };
336 let cursor_str = if cursor_total > 0 {
337 format!(
338 "{:.0}%",
339 cursor_exact_matches as f32 / cursor_total as f32 * 100.0
340 )
341 } else {
342 "-".to_string()
343 };
344 let avg_cursor_distance = if cursor_distance_count > 0 {
345 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
346 } else {
347 None
348 };
349
350 println!(
351 "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7} {:>6}",
352 "TOTAL / AVERAGE",
353 avg_delta_chr_f,
354 braces_disbalance_avg,
355 total_exact_lines.f1() * 100.0,
356 avg_reversal_ratio * 100.0,
357 qa_reverts_str,
358 qa_conf_str,
359 cursor_str
360 );
361 println!("{}", separator);
362
363 // Print additional cursor metrics if available
364 if let Some(avg_dist) = avg_cursor_distance {
365 println!(
366 "Cursor: {}/{} exact matches ({:.0}%), avg distance: {:.1} bytes",
367 cursor_exact_matches,
368 cursor_total,
369 cursor_exact_matches as f32 / cursor_total as f32 * 100.0,
370 avg_dist
371 );
372 }
373 }
374
375 println!("\n");
376}
377
/// Shortens `name` to at most `max_len` characters for display in a fixed-width column.
///
/// Names that already fit are returned unchanged. Longer names keep their first
/// `max_len - 3` characters followed by `"..."`. Lengths are counted in `char`s rather than
/// bytes, which matches how `format!` pads `{:<N}` columns. This also means the cut never lands
/// inside a multi-byte UTF-8 character; the previous byte slicing panicked in that case. When
/// `max_len` is too small to hold the ellipsis, the name is simply cut to `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
385
386#[derive(Serialize)]
387pub struct SummaryJson {
388 pub total_examples: usize,
389 pub avg_delta_chr_f: f32,
390 pub avg_braces_disbalance: f32,
391 pub exact_lines_true_positives: usize,
392 pub exact_lines_false_positives: usize,
393 pub exact_lines_false_negatives: usize,
394 pub exact_lines_precision: f64,
395 pub exact_lines_recall: f64,
396 pub exact_lines_f1: f64,
397 pub avg_reversal_ratio: f32,
398 #[serde(skip_serializing_if = "Option::is_none")]
399 pub qa_avg_reverts_edits: Option<f32>,
400 #[serde(skip_serializing_if = "Option::is_none")]
401 pub qa_avg_confidence: Option<f32>,
402 #[serde(skip_serializing_if = "Option::is_none")]
403 pub cursor_exact_match_rate: Option<f32>,
404 #[serde(skip_serializing_if = "Option::is_none")]
405 pub cursor_avg_distance: Option<f32>,
406 #[serde(skip_serializing_if = "Option::is_none")]
407 pub cursor_total_evaluated: Option<usize>,
408}
409
410pub fn compute_summary(examples: &[Example]) -> SummaryJson {
411 use crate::metrics::ClassificationMetrics;
412
413 let mut all_delta_chr_f_scores = Vec::new();
414 let mut all_reversal_ratios = Vec::new();
415 let mut braces_disbalance_sum: usize = 0;
416 let mut total_exact_lines = ClassificationMetrics::default();
417 let mut total_scores: usize = 0;
418 let mut qa_reverts_count: usize = 0;
419 let mut qa_reverts_total: usize = 0;
420 let mut qa_confidence_sum: u64 = 0;
421 let mut qa_confidence_count: usize = 0;
422 let mut cursor_exact_matches: usize = 0;
423 let mut cursor_total: usize = 0;
424 let mut cursor_distance_sum: usize = 0;
425 let mut cursor_distance_count: usize = 0;
426
427 for example in examples {
428 for (score_idx, score) in example.score.iter().enumerate() {
429 all_delta_chr_f_scores.push(score.delta_chr_f);
430 all_reversal_ratios.push(score.reversal_ratio);
431 total_scores += 1;
432 braces_disbalance_sum += score.braces_disbalance;
433 total_exact_lines.true_positives += score.exact_lines_tp;
434 total_exact_lines.false_positives += score.exact_lines_fp;
435 total_exact_lines.false_negatives += score.exact_lines_fn;
436
437 // Accumulate QA metrics
438 if let Some(Some(qa)) = example.qa.get(score_idx) {
439 if let Some(reverts) = qa.reverts_edits {
440 qa_reverts_total += 1;
441 if reverts {
442 qa_reverts_count += 1;
443 }
444 }
445 if let Some(conf) = qa.confidence {
446 qa_confidence_sum += conf as u64;
447 qa_confidence_count += 1;
448 }
449 }
450
451 // Accumulate cursor metrics
452 if let Some(exact_match) = score.cursor_exact_match {
453 cursor_total += 1;
454 if exact_match {
455 cursor_exact_matches += 1;
456 }
457 }
458 if let Some(dist) = score.cursor_distance {
459 cursor_distance_sum += dist;
460 cursor_distance_count += 1;
461 }
462 }
463 }
464
465 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
466 0.0
467 } else {
468 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
469 };
470
471 let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
472 0.0
473 } else {
474 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
475 };
476
477 let avg_braces_disbalance = if total_scores == 0 {
478 0.0
479 } else {
480 braces_disbalance_sum as f32 / total_scores as f32
481 };
482
483 let qa_avg_reverts_edits = if qa_reverts_total > 0 {
484 Some(qa_reverts_count as f32 / qa_reverts_total as f32)
485 } else {
486 None
487 };
488
489 let qa_avg_confidence = if qa_confidence_count > 0 {
490 Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
491 } else {
492 None
493 };
494
495 let cursor_exact_match_rate = if cursor_total > 0 {
496 Some(cursor_exact_matches as f32 / cursor_total as f32)
497 } else {
498 None
499 };
500
501 let cursor_avg_distance = if cursor_distance_count > 0 {
502 Some(cursor_distance_sum as f32 / cursor_distance_count as f32)
503 } else {
504 None
505 };
506
507 let cursor_total_evaluated = if cursor_total > 0 {
508 Some(cursor_total)
509 } else {
510 None
511 };
512
513 SummaryJson {
514 total_examples: total_scores,
515 avg_delta_chr_f,
516 avg_braces_disbalance,
517 exact_lines_true_positives: total_exact_lines.true_positives,
518 exact_lines_false_positives: total_exact_lines.false_positives,
519 exact_lines_false_negatives: total_exact_lines.false_negatives,
520 exact_lines_precision: total_exact_lines.precision(),
521 exact_lines_recall: total_exact_lines.recall(),
522 exact_lines_f1: total_exact_lines.f1(),
523 avg_reversal_ratio,
524 qa_avg_reverts_edits,
525 qa_avg_confidence,
526 cursor_exact_match_rate,
527 cursor_avg_distance,
528 cursor_total_evaluated,
529 }
530}
531
532pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
533 let summary = compute_summary(examples);
534 let file = File::create(path)
535 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
536 let writer = BufWriter::new(file);
537 serde_json::to_writer_pretty(writer, &summary)
538 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
539 eprintln!("Wrote summary JSON to: {}", path.display());
540 Ok(())
541}