1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 parse_output::parse_prediction_output,
7 predict::run_prediction,
8 progress::{ExampleProgress, Step},
9};
10use anyhow::Context as _;
11use edit_prediction::udiff::apply_diff_to_string;
12use gpui::AsyncApp;
13use std::sync::Arc;
14
15pub async fn run_scoring(
16 example: &mut Example,
17 args: &PredictArgs,
18 app_state: Arc<EpAppState>,
19 example_progress: &ExampleProgress,
20 cx: AsyncApp,
21) -> anyhow::Result<()> {
22 run_prediction(example, args, app_state, example_progress, cx).await?;
23
24 let progress = example_progress.start(Step::Score);
25
26 progress.set_substatus("applying patches");
27 let original_text = &example
28 .prompt_inputs
29 .as_ref()
30 .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
31 .content;
32 let expected_texts: Vec<String> = example
33 .spec
34 .expected_patches
35 .iter()
36 .map(|patch| {
37 apply_diff_to_string(patch, original_text)
38 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
39 })
40 .collect::<Result<Vec<_>, _>>()?;
41
42 progress.set_substatus("computing metrics");
43 let mut scores = vec![];
44 for prediction in &example.predictions {
45 let actual_patch = match &prediction.actual_patch {
46 Some(patch) => patch.clone(),
47 None => {
48 if prediction.actual_output.is_empty() {
49 scores.push(ExampleScore { delta_chr_f: 0.0 });
50 continue;
51 }
52 match parse_prediction_output(
53 example,
54 &prediction.actual_output,
55 prediction.provider,
56 ) {
57 Ok(patch) => patch,
58 Err(_) => {
59 scores.push(ExampleScore { delta_chr_f: 0.0 });
60 continue;
61 }
62 }
63 }
64 };
65 let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
66 Ok(text) => text,
67 Err(_) => {
68 scores.push(ExampleScore { delta_chr_f: 0.0 });
69 continue;
70 }
71 };
72 let best_delta_chr_f = expected_texts
73 .iter()
74 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
75 .fold(0.0, f32::max);
76 scores.push(ExampleScore {
77 delta_chr_f: best_delta_chr_f,
78 });
79 }
80
81 example.score = scores;
82 Ok(())
83}
84
85pub fn print_report(examples: &[Example]) {
86 eprintln!(
87 "──────────────────────────────────────────────────────────────────────────────────────"
88 );
89 eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
90 eprintln!(
91 "──────────────────────────────────────────────────────────────────────────────────────"
92 );
93
94 let mut all_delta_chr_f_scores = Vec::new();
95
96 for example in examples {
97 for score in example.score.iter() {
98 eprintln!(
99 "{:<50} {:>9.2}",
100 truncate_name(&example.spec.name, 50),
101 score.delta_chr_f
102 );
103
104 all_delta_chr_f_scores.push(score.delta_chr_f);
105 }
106 }
107
108 eprintln!(
109 "──────────────────────────────────────────────────────────────────────────────────────"
110 );
111
112 if !all_delta_chr_f_scores.is_empty() {
113 let avg_delta_chr_f: f32 =
114 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
115
116 eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
117 eprintln!(
118 "──────────────────────────────────────────────────────────────────────────────────────"
119 );
120 }
121
122 eprintln!("\n");
123}
124
/// Shortens `name` to at most `max_len` characters, ending it with `...` when
/// anything was cut off.
///
/// Lengths are measured in `char`s rather than bytes: byte-offset slicing
/// panics when the cut lands inside a multi-byte UTF-8 character, and the
/// width padding applied by `format!` counts characters as well.
///
/// When `max_len` is smaller than the ellipsis itself, the ellipsis is
/// shortened so the result still never exceeds `max_len` characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // ELLIPSIS is ASCII, so slicing it by byte index is always on a boundary.
    let suffix = &ELLIPSIS[..max_len.min(ELLIPSIS.len())];
    // `saturating_sub` avoids underflow when `max_len < 3`.
    let keep = max_len.saturating_sub(ELLIPSIS.len());

    let mut truncated: String = name.chars().take(keep).collect();
    truncated.push_str(suffix);
    truncated
}