1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 parse_output::parse_prediction_output,
7 predict::run_prediction,
8 progress::{ExampleProgress, Step},
9 reversal_tracking,
10};
11use anyhow::Context as _;
12use edit_prediction::udiff::apply_diff_to_string;
13use gpui::AsyncApp;
14use serde::Serialize;
15use std::fs::File;
16use std::io::BufWriter;
17use std::path::Path;
18use std::sync::Arc;
19
/// Runs prediction for `example`, then scores every prediction it produced,
/// storing one `ExampleScore` per prediction in `example.score`.
///
/// Each predicted patch is applied to the original buffer text and compared
/// against the text produced by each expected patch; for specs with multiple
/// expected patches, the best-matching one wins per metric.
pub async fn run_scoring(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    cx: AsyncApp,
) -> anyhow::Result<()> {
    // Scoring needs predictions to exist; run prediction first.
    run_prediction(example, args, app_state, example_progress, cx).await?;

    let progress = example_progress.start(Step::Score);

    progress.set_substatus("applying patches");
    // The pre-edit buffer contents; every patch (expected and actual) is
    // applied against this same baseline.
    let original_text = &example
        .prompt_inputs
        .as_ref()
        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
        .content;
    // An expected patch that fails to apply means the spec itself is broken,
    // so this is a hard error (unlike a non-applying *predicted* patch below).
    let expected_texts: Vec<String> = example
        .spec
        .expected_patches
        .iter()
        .map(|patch| {
            apply_diff_to_string(patch, original_text)
                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Sentinel score recorded for predictions with no usable patch.
    let zero_scores = ExampleScore {
        delta_chr_f: 0.0,
        braces_disbalance: 0,
        exact_lines_tp: 0,
        exact_lines_fp: 0,
        exact_lines_fn: 0,
        reversal_ratio: 0.0,
    };

    // Safe: `prompt_inputs` presence was checked above for `original_text`.
    let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
    let cursor_path = example.spec.cursor_path.as_ref();

    progress.set_substatus("computing metrics");
    let mut scores = vec![];
    for prediction in &example.predictions {
        // Prefer an already-parsed patch; otherwise try to parse the raw
        // model output, treating parse failure as "no patch".
        let actual_patch = prediction.actual_patch.clone().or_else(|| {
            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
        });

        // No parseable patch: record zeros and move on to the next prediction.
        let Some(actual_patch) = actual_patch else {
            scores.push(zero_scores.clone());
            continue;
        };

        // A predicted patch that doesn't apply also scores zero rather than
        // aborting the whole example.
        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
            Ok(text) => text,
            Err(_) => {
                scores.push(zero_scores.clone());
                continue;
            }
        };
        // Best chrF delta across all acceptable expected outcomes.
        let best_delta_chr_f = expected_texts
            .iter()
            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
            .fold(0.0, f32::max);

        // Only count disbalance *introduced* by the prediction: the original
        // text may already be unbalanced, hence the saturating subtraction.
        let disbalance_before = metrics::braces_disbalance(&original_text);
        let disbalance_after = metrics::braces_disbalance(&actual_text);
        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
        if braces_disbalance > 0 {
            // NOTE(review): debugging aid — dumps the offending texts to fixed
            // /tmp paths (unix-only, overwritten on every hit). Write errors
            // are deliberately ignored via `.ok()`.
            std::fs::write(
                "/tmp/unbalanced-count.before",
                disbalance_before.to_string(),
            )
            .ok();
            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
        }

        // Compute exact lines match against best matching expected patch
        // ("best" = the one yielding the most true positives).
        let best_exact_lines = example
            .spec
            .expected_patches
            .iter()
            .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
            .max_by_key(|m| m.true_positives)
            .unwrap_or_default();

        // Compute reversal ratio
        let reversal_ratio = reversal_tracking::compute_prediction_reversal_ratio(
            prompt_inputs,
            &actual_text,
            cursor_path,
        );

        scores.push(ExampleScore {
            delta_chr_f: best_delta_chr_f,
            braces_disbalance,
            exact_lines_tp: best_exact_lines.true_positives,
            exact_lines_fp: best_exact_lines.false_positives,
            exact_lines_fn: best_exact_lines.false_negatives,
            reversal_ratio,
        });
    }

    example.score = scores;
    Ok(())
}
126
127pub fn print_report(examples: &[Example]) {
128 use crate::metrics::ClassificationMetrics;
129
130 const LINE_WIDTH: usize = 110;
131 let separator = "─".repeat(LINE_WIDTH);
132
133 println!("{}", separator);
134 println!(
135 "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7} {:>7}",
136 "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1", "Revert"
137 );
138 println!("{}", separator);
139
140 let mut all_delta_chr_f_scores = Vec::new();
141 let mut all_reversal_ratios = Vec::new();
142 let mut braces_disbalance_sum: usize = 0;
143 let mut total_exact_lines = ClassificationMetrics::default();
144 let mut total_scores: usize = 0;
145
146 for example in examples {
147 for score in example.score.iter() {
148 let exact_lines = ClassificationMetrics {
149 true_positives: score.exact_lines_tp,
150 false_positives: score.exact_lines_fp,
151 false_negatives: score.exact_lines_fn,
152 };
153
154 println!(
155 "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%",
156 truncate_name(&example.spec.name, 40),
157 score.delta_chr_f,
158 score.braces_disbalance,
159 score.exact_lines_tp,
160 score.exact_lines_fp,
161 score.exact_lines_fn,
162 exact_lines.precision() * 100.0,
163 exact_lines.recall() * 100.0,
164 exact_lines.f1() * 100.0,
165 score.reversal_ratio * 100.0
166 );
167
168 all_delta_chr_f_scores.push(score.delta_chr_f);
169 all_reversal_ratios.push(score.reversal_ratio);
170 total_scores += 1;
171 braces_disbalance_sum += score.braces_disbalance;
172 total_exact_lines.true_positives += score.exact_lines_tp;
173 total_exact_lines.false_positives += score.exact_lines_fp;
174 total_exact_lines.false_negatives += score.exact_lines_fn;
175 }
176 }
177
178 println!("{}", separator);
179
180 if !all_delta_chr_f_scores.is_empty() {
181 let avg_delta_chr_f: f32 =
182 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
183 let avg_reversal_ratio: f32 =
184 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32;
185 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
186
187 println!(
188 "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}% {:>6.1}%",
189 "TOTAL / AVERAGE",
190 avg_delta_chr_f,
191 braces_disbalance_avg,
192 total_exact_lines.true_positives,
193 total_exact_lines.false_positives,
194 total_exact_lines.false_negatives,
195 total_exact_lines.precision() * 100.0,
196 total_exact_lines.recall() * 100.0,
197 total_exact_lines.f1() * 100.0,
198 avg_reversal_ratio * 100.0
199 );
200 println!("{}", separator);
201 }
202
203 println!("\n");
204}
205
/// Truncates `name` to at most `max_len` characters for table display,
/// appending `"..."` when truncation occurs.
///
/// Fixes two panics in the previous byte-slicing version:
/// - `&name[..max_len - 3]` panicked when the cut fell inside a multi-byte
///   UTF-8 character (byte indices must land on char boundaries);
/// - `max_len - 3` underflowed for `max_len < 3`.
fn truncate_name(name: &str, max_len: usize) -> String {
    if name.chars().count() <= max_len {
        name.to_string()
    } else {
        // Keep `max_len - 3` characters to make room for the ellipsis;
        // saturate so tiny `max_len` degrades to just "..." instead of panicking.
        let keep = max_len.saturating_sub(3);
        let truncated: String = name.chars().take(keep).collect();
        format!("{truncated}...")
    }
}
213
/// Aggregate metrics across every scored prediction, serialized to the
/// summary JSON file by `write_summary_json`.
#[derive(Serialize)]
pub struct SummaryJson {
    /// Number of individual scores aggregated — one per prediction, so this
    /// can exceed the number of examples when an example holds several
    /// predictions.
    pub total_examples: usize,
    /// Mean delta-chrF score over all predictions.
    pub avg_delta_chr_f: f32,
    /// Mean brace disbalance introduced per prediction.
    pub avg_braces_disbalance: f32,
    /// Exact-line true positives summed over all predictions.
    pub exact_lines_true_positives: usize,
    /// Exact-line false positives summed over all predictions.
    pub exact_lines_false_positives: usize,
    /// Exact-line false negatives summed over all predictions.
    pub exact_lines_false_negatives: usize,
    /// Precision computed from the summed exact-line counts (micro-averaged).
    pub exact_lines_precision: f64,
    /// Recall computed from the summed exact-line counts (micro-averaged).
    pub exact_lines_recall: f64,
    /// F1 computed from the summed exact-line counts (micro-averaged).
    pub exact_lines_f1: f64,
    /// Mean reversal ratio over all predictions.
    pub avg_reversal_ratio: f32,
}
227
228pub fn compute_summary(examples: &[Example]) -> SummaryJson {
229 use crate::metrics::ClassificationMetrics;
230
231 let mut all_delta_chr_f_scores = Vec::new();
232 let mut all_reversal_ratios = Vec::new();
233 let mut braces_disbalance_sum: usize = 0;
234 let mut total_exact_lines = ClassificationMetrics::default();
235 let mut total_scores: usize = 0;
236
237 for example in examples {
238 for score in example.score.iter() {
239 all_delta_chr_f_scores.push(score.delta_chr_f);
240 all_reversal_ratios.push(score.reversal_ratio);
241 total_scores += 1;
242 braces_disbalance_sum += score.braces_disbalance;
243 total_exact_lines.true_positives += score.exact_lines_tp;
244 total_exact_lines.false_positives += score.exact_lines_fp;
245 total_exact_lines.false_negatives += score.exact_lines_fn;
246 }
247 }
248
249 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
250 0.0
251 } else {
252 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
253 };
254
255 let avg_reversal_ratio = if all_reversal_ratios.is_empty() {
256 0.0
257 } else {
258 all_reversal_ratios.iter().sum::<f32>() / all_reversal_ratios.len() as f32
259 };
260
261 let avg_braces_disbalance = if total_scores == 0 {
262 0.0
263 } else {
264 braces_disbalance_sum as f32 / total_scores as f32
265 };
266
267 SummaryJson {
268 total_examples: total_scores,
269 avg_delta_chr_f,
270 avg_braces_disbalance,
271 exact_lines_true_positives: total_exact_lines.true_positives,
272 exact_lines_false_positives: total_exact_lines.false_positives,
273 exact_lines_false_negatives: total_exact_lines.false_negatives,
274 exact_lines_precision: total_exact_lines.precision(),
275 exact_lines_recall: total_exact_lines.recall(),
276 exact_lines_f1: total_exact_lines.f1(),
277 avg_reversal_ratio,
278 }
279}
280
281pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
282 let summary = compute_summary(examples);
283 let file = File::create(path)
284 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
285 let writer = BufWriter::new(file);
286 serde_json::to_writer_pretty(writer, &summary)
287 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
288 eprintln!("Wrote summary JSON to: {}", path.display());
289 Ok(())
290}