1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 parse_output::parse_prediction_output,
7 predict::run_prediction,
8 progress::{ExampleProgress, Step},
9};
10use anyhow::Context as _;
11use edit_prediction::udiff::apply_diff_to_string;
12use gpui::AsyncApp;
13use serde::Serialize;
14use std::fs::File;
15use std::io::BufWriter;
16use std::path::Path;
17use std::sync::Arc;
18
19pub async fn run_scoring(
20 example: &mut Example,
21 args: &PredictArgs,
22 app_state: Arc<EpAppState>,
23 example_progress: &ExampleProgress,
24 cx: AsyncApp,
25) -> anyhow::Result<()> {
26 run_prediction(example, args, app_state, example_progress, cx).await?;
27
28 let progress = example_progress.start(Step::Score);
29
30 progress.set_substatus("applying patches");
31 let original_text = &example
32 .prompt_inputs
33 .as_ref()
34 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
35 .content;
36 let expected_texts: Vec<String> = example
37 .spec
38 .expected_patches
39 .iter()
40 .map(|patch| {
41 apply_diff_to_string(patch, original_text)
42 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
43 })
44 .collect::<Result<Vec<_>, _>>()?;
45
46 let zero_scores = ExampleScore {
47 delta_chr_f: 0.0,
48 braces_disbalance: 0,
49 exact_lines_tp: 0,
50 exact_lines_fp: 0,
51 exact_lines_fn: 0,
52 };
53
54 progress.set_substatus("computing metrics");
55 let mut scores = vec![];
56 for prediction in &example.predictions {
57 let actual_patch = prediction.actual_patch.clone().or_else(|| {
58 parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
59 });
60
61 let Some(actual_patch) = actual_patch else {
62 scores.push(zero_scores.clone());
63 continue;
64 };
65
66 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
67 Ok(text) => text,
68 Err(_) => {
69 scores.push(zero_scores.clone());
70 continue;
71 }
72 };
73 let best_delta_chr_f = expected_texts
74 .iter()
75 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
76 .fold(0.0, f32::max);
77
78 let disbalance_before = metrics::braces_disbalance(&original_text);
79 let disbalance_after = metrics::braces_disbalance(&actual_text);
80 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
81 if braces_disbalance > 0 {
82 std::fs::write(
83 "/tmp/unbalanced-count.before",
84 disbalance_before.to_string(),
85 )
86 .ok();
87 std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
88 std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
89 std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
90 }
91
92 // Compute exact lines match against best matching expected patch
93 let best_exact_lines = example
94 .spec
95 .expected_patches
96 .iter()
97 .map(|expected_patch| metrics::exact_lines_match(expected_patch, &actual_patch))
98 .max_by_key(|m| m.true_positives)
99 .unwrap_or_default();
100
101 scores.push(ExampleScore {
102 delta_chr_f: best_delta_chr_f,
103 braces_disbalance,
104 exact_lines_tp: best_exact_lines.true_positives,
105 exact_lines_fp: best_exact_lines.false_positives,
106 exact_lines_fn: best_exact_lines.false_negatives,
107 });
108 }
109
110 example.score = scores;
111 Ok(())
112}
113
114pub fn print_report(examples: &[Example]) {
115 use crate::metrics::ClassificationMetrics;
116
117 const LINE_WIDTH: usize = 100;
118 let separator = "─".repeat(LINE_WIDTH);
119
120 println!("{}", separator);
121 println!(
122 "{:<40} {:>8} {:>5} {:>4} {:>4} {:>4} {:>7} {:>7} {:>7}",
123 "Example", "DeltaChrF", "Brace", "TP", "FP", "FN", "Prec", "Rec", "F1"
124 );
125 println!("{}", separator);
126
127 let mut all_delta_chr_f_scores = Vec::new();
128 let mut braces_disbalance_sum: usize = 0;
129 let mut total_exact_lines = ClassificationMetrics::default();
130 let mut total_scores: usize = 0;
131
132 for example in examples {
133 for score in example.score.iter() {
134 let exact_lines = ClassificationMetrics {
135 true_positives: score.exact_lines_tp,
136 false_positives: score.exact_lines_fp,
137 false_negatives: score.exact_lines_fn,
138 };
139
140 println!(
141 "{:<40} {:>8.2} {:>5} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
142 truncate_name(&example.spec.name, 40),
143 score.delta_chr_f,
144 score.braces_disbalance,
145 score.exact_lines_tp,
146 score.exact_lines_fp,
147 score.exact_lines_fn,
148 exact_lines.precision() * 100.0,
149 exact_lines.recall() * 100.0,
150 exact_lines.f1() * 100.0
151 );
152
153 all_delta_chr_f_scores.push(score.delta_chr_f);
154 total_scores += 1;
155 braces_disbalance_sum += score.braces_disbalance;
156 total_exact_lines.true_positives += score.exact_lines_tp;
157 total_exact_lines.false_positives += score.exact_lines_fp;
158 total_exact_lines.false_negatives += score.exact_lines_fn;
159 }
160 }
161
162 println!("{}", separator);
163
164 if !all_delta_chr_f_scores.is_empty() {
165 let avg_delta_chr_f: f32 =
166 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
167 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
168
169 println!(
170 "{:<40} {:>8.2} {:>5.1} {:>4} {:>4} {:>4} {:>6.1}% {:>6.1}% {:>6.1}%",
171 "TOTAL / AVERAGE",
172 avg_delta_chr_f,
173 braces_disbalance_avg,
174 total_exact_lines.true_positives,
175 total_exact_lines.false_positives,
176 total_exact_lines.false_negatives,
177 total_exact_lines.precision() * 100.0,
178 total_exact_lines.recall() * 100.0,
179 total_exact_lines.f1() * 100.0
180 );
181 println!("{}", separator);
182 }
183
184 println!("\n");
185}
186
/// Shortens `name` to at most `max_len` characters, marking the cut with a trailing `...`.
///
/// Names that already fit are returned unchanged. Lengths are measured in characters rather
/// than bytes, so names containing multi-byte UTF-8 characters are never split in the middle of
/// a character. When `max_len` is smaller than the ellipsis itself, the ellipsis is shortened so
/// the result still fits in `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // `nth(max_len)` is `None` exactly when the name has at most `max_len` characters.
    if name.char_indices().nth(max_len).is_none() {
        return name.to_string();
    }

    // The ellipsis is ASCII, so slicing it by byte count is always on a character boundary.
    let ellipsis = &ELLIPSIS[..ELLIPSIS.len().min(max_len)];
    let kept_chars = max_len - ellipsis.len();

    // Byte offset at which the first `kept_chars` characters end.
    let cut = name
        .char_indices()
        .nth(kept_chars)
        .map_or(name.len(), |(index, _)| index);

    format!("{}{}", &name[..cut], ellipsis)
}
194
195#[derive(Serialize)]
196pub struct SummaryJson {
197 pub total_examples: usize,
198 pub avg_delta_chr_f: f32,
199 pub avg_braces_disbalance: f32,
200 pub exact_lines_true_positives: usize,
201 pub exact_lines_false_positives: usize,
202 pub exact_lines_false_negatives: usize,
203 pub exact_lines_precision: f64,
204 pub exact_lines_recall: f64,
205 pub exact_lines_f1: f64,
206}
207
208pub fn compute_summary(examples: &[Example]) -> SummaryJson {
209 use crate::metrics::ClassificationMetrics;
210
211 let mut all_delta_chr_f_scores = Vec::new();
212 let mut braces_disbalance_sum: usize = 0;
213 let mut total_exact_lines = ClassificationMetrics::default();
214 let mut total_scores: usize = 0;
215
216 for example in examples {
217 for score in example.score.iter() {
218 all_delta_chr_f_scores.push(score.delta_chr_f);
219 total_scores += 1;
220 braces_disbalance_sum += score.braces_disbalance;
221 total_exact_lines.true_positives += score.exact_lines_tp;
222 total_exact_lines.false_positives += score.exact_lines_fp;
223 total_exact_lines.false_negatives += score.exact_lines_fn;
224 }
225 }
226
227 let avg_delta_chr_f = if all_delta_chr_f_scores.is_empty() {
228 0.0
229 } else {
230 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32
231 };
232
233 let avg_braces_disbalance = if total_scores == 0 {
234 0.0
235 } else {
236 braces_disbalance_sum as f32 / total_scores as f32
237 };
238
239 SummaryJson {
240 total_examples: total_scores,
241 avg_delta_chr_f,
242 avg_braces_disbalance,
243 exact_lines_true_positives: total_exact_lines.true_positives,
244 exact_lines_false_positives: total_exact_lines.false_positives,
245 exact_lines_false_negatives: total_exact_lines.false_negatives,
246 exact_lines_precision: total_exact_lines.precision(),
247 exact_lines_recall: total_exact_lines.recall(),
248 exact_lines_f1: total_exact_lines.f1(),
249 }
250}
251
252pub fn write_summary_json(examples: &[Example], path: &Path) -> anyhow::Result<()> {
253 let summary = compute_summary(examples);
254 let file = File::create(path)
255 .with_context(|| format!("Failed to create summary JSON file: {}", path.display()))?;
256 let writer = BufWriter::new(file);
257 serde_json::to_writer_pretty(writer, &summary)
258 .with_context(|| format!("Failed to write summary JSON to: {}", path.display()))?;
259 eprintln!("Wrote summary JSON to: {}", path.display());
260 Ok(())
261}