use crate::{
    PredictArgs,
    example::{Example, ExampleScore},
    headless::EpAppState,
    metrics,
    parse_output::parse_prediction_output,
    predict::run_prediction,
    progress::{ExampleProgress, Step},
    reversal_tracking,
};
use anyhow::Context as _;
use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use serde::Serialize;
use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use std::sync::Arc;

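/// Runs prediction for `example` and scores each returned prediction against the
/// expected patches, storing the resulting `ExampleScore`s in `example.score`.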
pub async fn run_scoring(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    cx: AsyncApp,
) -> anyhow::Result<()> {
    run_prediction(example, args, app_state, example_progress, cx).await?;

    let progress = example_progress.start(Step::Score);

    progress.set_substatus("applying patches");
    let original_text = &example
        .prompt_inputs
        .as_ref()
        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
        .content;
    let expected_texts: Vec<String> = example
        .spec
        .expected_patches
        .iter()
        .map(|patch| {
            apply_diff_to_string(patch, original_text)
                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
        })
        .collect::<Result<Vec<_>, _>>()?;

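    // Fallback score used when a prediction's output can't be parsed into a patch
    // or the patch doesn't apply to the original text.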
    let zero_scores = ExampleScore {
        delta_chr_f: 0.0,
        braces_disbalance: 0,
        exact_lines_tp: 0,
        exact_lines_fp: 0,
        exact_lines_fn: 0,
        reversal_ratio: 0.0,
    };

    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
    let cursor_path = example.spec.cursor_path.as_ref();

    progress.set_substatus("computing metrics");
    let mut scores = vec![];
    for prediction in &example.predictions {
        let actual_patch = prediction.actual_patch.clone().or_else(|| {
            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
        });

        let Some(actual_patch) = actual_patch else {
            scores.push(zero_scores.clone());
            continue;
        };

        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
            Ok(text) => text,
            Err(_) => {
                scores.push(zero_scores.clone());
                continue;
            }
        };
        let best_delta_chr_f = expected_texts
            .iter()
            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
            .fold(0.0, f32::max);

        let disbalance_before = metrics::braces_disbalance(original_text);
        let disbalance_after = metrics::braces_disbalance(&actual_text);
        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
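        // Debug aid: when a prediction introduces a brace imbalance, dump the
        // before/after counts and texts to /tmp for inspection.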
        if braces_disbalance > 0 {
            std::fs::write(
                "/tmp/unbalanced-count.before",
                disbalance_before.to_string(),
            )
            .ok();
            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
            std::fs::write("/tmp/unbalanced-text.before", original_text).ok();
            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
        }

        // Compute exact lines match against best matching expected patch
        let best_exact_lines = example
            .spec
            .expected_patches
            .iter()
            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
            .max_by_key(|m| m.true_positives)
            .unwrap_or_default();

        // Compute reversal ratio
        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
            prompt_inputs,
            &actual_text,
            cursor_path,
        );

        scores.push(ExampleScore {
            delta_chr_f: best_delta_chr_f,
            braces_disbalance,
            exact_lines_tp: best_exact_lines.true_positives,
            exact_lines_fp: best_exact_lines.false_positives,
            exact_lines_fn: best_exact_lines.false_negatives,
            reversal_ratio,
        });
    }

    example.score = scores;
    Ok(())
}

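/// Prints a per-prediction score table to stdout, followed by a totals/averages row.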
pub fn print_report(examples: &[Example]) {
    use crate::metrics::ClassificationMetrics;

    const LINE_WIDTH: usize = 82;
    let separator = "─".repeat(LINE_WIDTH);

    println!("{}", separator);
    println!(
        "{:<40} {:>8} {:>5} {:>7} {:>7} {:>7} {:>7}",
        "Example", "DeltaChrF", "Brace", "F1", "Revert", "QaRev", "QaConf"
    );
    println!("{}", separator);

    let mut all_delta_chr_f_scores = Vec::new();
    let mut all_reversal_ratios = Vec::new();
    let mut braces_disbalance_sum: usize = 0;
    let mut total_exact_lines = ClassificationMetrics::default();
    let mut total_scores: usize = 0;
    let mut qa_reverts_count: usize = 0;
    let mut qa_reverts_total: usize = 0;
    let mut qa_confidence_sum: u64 = 0;
    let mut qa_confidence_count: usize = 0;

    for example in examples {
        for (score_idx, score) in example.score.iter().enumerate() {
            let exact_lines = ClassificationMetrics {
                true_positives: score.exact_lines_tp,
                false_positives: score.exact_lines_fp,
                false_negatives: score.exact_lines_fn,
            };

            // Get QA results for this prediction if available
            let qa_result = example.qa.get(score_idx).and_then(|q| q.as_ref());
            let qa_reverts_str = qa_result
                .and_then(|q| q.reverts_edits)
                .map(|v| if v { "yes" } else { "no" })
                .unwrap_or("-");
            let qa_conf_str = qa_result
                .and_then(|q| q.confidence)
                .map(|v| format!("{}", v))
                .unwrap_or("-".to_string());

            println!(
                "{:<40} {:>8.2} {:>5} {:>6.1}% {:>6.1}% {:>7} {:>7}",
                truncate_name(&example.spec.name, 40),
                score.delta_chr_f,
                score.braces_disbalance,
                exact_lines.f1() * 100.0,
                score.reversal_ratio * 100.0,
                qa_reverts_str,
                qa_conf_str
            );

            all_delta_chr_f_scores.push(score.delta_chr_f);
            all_reversal_ratios.push(score.reversal_ratio);
            total_scores += 1;
            braces_disbalance_sum += score.braces_disbalance;
            total_exact_lines.true_positives += score.exact_lines_tp;
            total_exact_lines.false_positives += score.exact_lines_fp;
            total_exact_lines.false_negatives += score.exact_lines_fn;

            // Accumulate QA metrics
            if let Some(qa) = qa_result {
                if let Some(reverts) = qa.reverts_edits {
                    qa_reverts_total += 1;
                    if reverts {
                        qa_reverts_count += 1;
                    }
                }
                if let Some(conf) = qa.confidence {
                    qa_confidence_sum += conf as u64;
                    qa_confidence_count += 1;
                }
            }
        }
    }

    println!("{}", separator);

    if !all_delta_chr_f_scores.is_empty() {
        let avg_delta_chr_f: f32 =
            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
        let avg_reversal_ratio: f32 =
            all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;

        let qa_reverts_str = if qa_reverts_total > 0 {
            format!(
                "{:.1}%",
                qa_reverts_count as f32 / qa_reverts_total as f32 * 100.0
            )
        } else {
            "-".to_string()
        };
        let qa_conf_str = if qa_confidence_count > 0 {
            format!(
                "{:.1}",
                qa_confidence_sum as f32 / qa_confidence_count as f32
            )
        } else {
            "-".to_string()
        };

        println!(
            "{:<40} {:>8.2} {:>5.1} {:>6.1}% {:>6.1}% {:>7} {:>7}",
            "TOTAL / AVERAGE",
            avg_delta_chr_f,
            braces_disbalance_avg,
            total_exact_lines.f1() * 100.0,
            avg_reversal_ratio * 100.0,
            qa_reverts_str,
            qa_conf_str
        );
        println!("{}", separator);
    }

    println!("\n");
}

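/// Truncates `name` to at most `max_len` characters, appending "..." when it is cut.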
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        name.to_string()
    } else {
        // Truncate on a character boundary to avoid panicking on multi-byte names.
        let truncated: String = name.chars().take(max_len.saturating_sub(3)).collect();
        format!("{}...", truncated)
    }
}

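/// Aggregate metrics across all scored predictions, serialized to the summary JSON file.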
#[derive(Serialize)]
pub struct SummaryJson {
    pub total_examples: usize,
    pub avg_delta_chr_f: f32,
    pub avg_braces_disbalance: f32,
    pub exact_lines_true_positives: usize,
    pub exact_lines_false_positives: usize,
    pub exact_lines_false_negatives: usize,
    pub exact_lines_precision: f64,
    pub exact_lines_recall: f64,
    pub exact_lines_f1: f64,
    pub avg_reversal_ratio: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_reverts_edits: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub qa_avg_confidence: Option<f32>,
}

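/// Aggregates per-prediction scores (and any QA results) across `examples` into a `SummaryJson`.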
pub fn compute_summary(examples: &[Example]) -> SummaryJson {
    use crate::metrics::ClassificationMetrics;

    let mut all_delta_chr_f_scores = Vec::new();
    let mut all_reversal_ratios = Vec::new();
    let mut braces_disbalance_sum: usize = 0;
    let mut total_exact_lines = ClassificationMetrics::default();
    let mut total_scores: usize = 0;
    let mut qa_reverts_count: usize = 0;
    let mut qa_reverts_total: usize = 0;
    let mut qa_confidence_sum: u64 = 0;
    let mut qa_confidence_count: usize = 0;

    for example in examples {
        for (score_idx, score) in example.score.iter().enumerate() {
            all_delta_chr_f_scores.push(score.delta_chr_f);
            all_reversal_ratios.push(score.reversal_ratio);
            total_scores += 1;
            braces_disbalance_sum += score.braces_disbalance;
            total_exact_lines.true_positives += score.exact_lines_tp;
            total_exact_lines.false_positives += score.exact_lines_fp;
            total_exact_lines.false_negatives += score.exact_lines_fn;

            // Accumulate QA metrics
            if let Some(Some(qa)) = example.qa.get(score_idx) {
                if let Some(reverts) = qa.reverts_edits {
                    qa_reverts_total += 1;
                    if reverts {
                        qa_reverts_count += 1;
                    }
                }
                if let Some(conf) = qa.confidence {
                    qa_confidence_sum += conf as u64;
                    qa_confidence_count += 1;
                }
            }
        }
    }

    let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
        0.0
    } else {
        all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
    };

    let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
        0.0
    } else {
        all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
    };

    let avg_braces_disbalance = if total_scores == 0 {
        0.0
    } else {
        braces_disbalance_sum as f32 / total_scores as f32
    };

    let qa_avg_reverts_edits = if qa_reverts_total > 0 {
        Some(qa_reverts_count as f32 / qa_reverts_total as f32)
    } else {
        None
    };

    let qa_avg_confidence = if qa_confidence_count > 0 {
        Some(qa_confidence_sum as f32 / qa_confidence_count as f32)
    } else {
        None
    };

    SummaryJson {
        total_examples: total_scores,
        avg_delta_chr_f,
        avg_braces_disbalance,
        exact_lines_true_positives: total_exact_lines.true_positives,
        exact_lines_false_positives: total_exact_lines.false_positives,
        exact_lines_false_negatives: total_exact_lines.false_negatives,
        exact_lines_precision: total_exact_lines.precision(),
        exact_lines_recall: total_exact_lines.recall(),
        exact_lines_f1: total_exact_lines.f1(),
        avg_reversal_ratio,
        qa_avg_reverts_edits,
        qa_avg_confidence,
    }
}

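/// Writes the computed summary to `path` as pretty-printed JSON.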
pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
    let summary = compute_summary(examples);
    let file = File::create(path)
        .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
    let writer = BufWriter::new(file);
    serde_json::to_writer_pretty(writer, &summary)
        .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
    eprintln!("Wrote summary JSON to: {}", path.display());
    Ok(())
}