1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 parse_output::parse_prediction_output,
7 predict::run_prediction,
8 progress::{ExampleProgress, Step},
9};
10use anyhow::Context as _;
11use edit_prediction::udiff::apply_diff_to_string;
12use gpui::AsyncApp;
13use std::sync::Arc;
14
15pub async fn run_scoring(
16 example: &mut Example,
17 args: &PredictArgs,
18 app_state: Arc<EpAppState>,
19 example_progress: &ExampleProgress,
20 cx: AsyncApp,
21) -> anyhow::Result<()> {
22 run_prediction(example, args, app_state, example_progress, cx).await?;
23
24 let progress = example_progress.start(Step::Score);
25
26 progress.set_substatus("applying patches");
27 let original_text = &example
28 .prompt_inputs
29 .as_ref()
30 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
31 .content;
32 let expected_texts: Vec<String> = example
33 .spec
34 .expected_patches
35 .iter()
36 .map(|patch| {
37 apply_diff_to_string(patch, original_text)
38 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
39 })
40 .collect::<Result<Vec<_>, _>>()?;
41
42 let zero_scores = ExampleScore {
43 delta_chr_f: 0.0,
44 braces_disbalance: 0,
45 };
46
47 progress.set_substatus("computing metrics");
48 let mut scores = vec![];
49 for prediction in &example.predictions {
50 let actual_patch = prediction.actual_patch.clone().or_else(|| {
51 parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
52 });
53
54 let Some(actual_patch) = actual_patch else {
55 scores.push(zero_scores.clone());
56 continue;
57 };
58
59 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
60 Ok(text) => text,
61 Err(_) => {
62 scores.push(zero_scores.clone());
63 continue;
64 }
65 };
66 let best_delta_chr_f = expected_texts
67 .iter()
68 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
69 .fold(0.0, f32::max);
70
71 let disbalance_before = metrics::braces_disbalance(&original_text);
72 let disbalance_after = metrics::braces_disbalance(&actual_text);
73 let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
74 if braces_disbalance > 0 {
75 std::fs::write(
76 "/tmp/unbalanced-count.before",
77 disbalance_before.to_string(),
78 )
79 .ok();
80 std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
81 std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
82 std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
83 }
84
85 scores.push(ExampleScore {
86 delta_chr_f: best_delta_chr_f,
87 braces_disbalance,
88 });
89 }
90
91 example.score = scores;
92 Ok(())
93}
94
95pub fn print_report(examples: &[Example]) {
96 eprintln!(
97 "──────────────────────────────────────────────────────────────────────────────────────"
98 );
99 eprintln!(
100 "{:<50} {:>14} {:>10}",
101 "Example name", "BracesDisbalance", "DeltaChrF"
102 );
103 eprintln!(
104 "──────────────────────────────────────────────────────────────────────────────────────"
105 );
106
107 let mut all_delta_chr_f_scores = Vec::new();
108 let mut braces_disbalance_sum: usize = 0;
109 let mut total_scores: usize = 0;
110
111 for example in examples {
112 for score in example.score.iter() {
113 eprintln!(
114 "{:<50} {:>14} {:>9.2}",
115 truncate_name(&example.spec.name, 50),
116 score.braces_disbalance,
117 score.delta_chr_f
118 );
119
120 all_delta_chr_f_scores.push(score.delta_chr_f);
121 total_scores += 1;
122 braces_disbalance_sum += score.braces_disbalance;
123 }
124 }
125
126 eprintln!(
127 "──────────────────────────────────────────────────────────────────────────────────────"
128 );
129
130 if !all_delta_chr_f_scores.is_empty() {
131 let avg_delta_chr_f: f32 =
132 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
133 let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
134 let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg);
135
136 eprintln!(
137 "{:<50} {:>14} {:>9.2}",
138 "AVERAGE", braces_disbalance_display, avg_delta_chr_f
139 );
140 eprintln!(
141 "──────────────────────────────────────────────────────────────────────────────────────"
142 );
143 }
144
145 eprintln!("\n");
146}
147
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names keep their leading
/// characters followed by `"..."`. If `max_len` is too small to hold the ellipsis,
/// only the first `max_len` characters are kept.
///
/// Lengths are counted in `char`s, not bytes. This never splits a multi-byte
/// UTF-8 character, which byte slicing would panic on. It also matches how
/// `format!` width padding counts characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}