1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use anyhow::Context as _;
10use edit_prediction::udiff::apply_diff_to_string;
11use gpui::AsyncApp;
12use std::sync::Arc;
13
14pub async fn run_scoring(
15 example: &mut Example,
16 args: &PredictArgs,
17 app_state: Arc<EpAppState>,
18 cx: AsyncApp,
19) -> anyhow::Result<()> {
20 run_prediction(example, args, app_state, cx).await?;
21
22 let progress = Progress::global().start(Step::Score, &example.spec.name);
23
24 progress.set_substatus("applying patches");
25 let original_text = &example.prompt_inputs.as_ref().unwrap().content;
26 let expected_texts: Vec<String> = example
27 .spec
28 .expected_patches
29 .iter()
30 .map(|patch| {
31 apply_diff_to_string(patch, original_text)
32 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
33 })
34 .collect::<Result<Vec<_>, _>>()?;
35
36 progress.set_substatus("computing metrics");
37 let mut scores = vec![];
38 for prediction in &example.predictions {
39 let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
40 Ok(text) => text,
41 Err(_) => {
42 scores.push(ExampleScore { delta_chr_f: 0.0 });
43 continue;
44 }
45 };
46 let best_delta_chr_f = expected_texts
47 .iter()
48 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
49 .fold(0.0, f32::max);
50 scores.push(ExampleScore {
51 delta_chr_f: best_delta_chr_f,
52 });
53 }
54
55 example.score = scores;
56 Ok(())
57}
58
59pub fn print_report(examples: &[Example]) {
60 eprintln!(
61 "──────────────────────────────────────────────────────────────────────────────────────"
62 );
63 eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
64 eprintln!(
65 "──────────────────────────────────────────────────────────────────────────────────────"
66 );
67
68 let mut all_delta_chr_f_scores = Vec::new();
69
70 for example in examples {
71 for score in example.score.iter() {
72 eprintln!(
73 "{:<50} {:>9.2}",
74 truncate_name(&example.spec.name, 50),
75 score.delta_chr_f
76 );
77
78 all_delta_chr_f_scores.push(score.delta_chr_f);
79 }
80 }
81
82 eprintln!(
83 "──────────────────────────────────────────────────────────────────────────────────────"
84 );
85
86 if !all_delta_chr_f_scores.is_empty() {
87 let avg_delta_chr_f: f32 =
88 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
89
90 eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
91 eprintln!(
92 "──────────────────────────────────────────────────────────────────────────────────────"
93 );
94 }
95
96 eprintln!("\n");
97}
98
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and given a trailing
/// `...`, with the ellipsis counted against `max_len`. Lengths are measured in `char`s rather
/// than bytes, so the cut never lands inside a multi-byte UTF-8 sequence (byte slicing there
/// would panic). When `max_len` is smaller than 3, only as many dots as fit are emitted.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Clamp the ellipsis so a tiny `max_len` can neither underflow nor overshoot the limit.
    let ellipsis_len = ELLIPSIS.len().min(max_len);
    let kept_chars = max_len - ellipsis_len;

    let mut truncated: String = name.chars().take(kept_chars).collect();
    // `ELLIPSIS` is ASCII, so slicing it by byte offset is always on a char boundary.
    truncated.push_str(&ELLIPSIS[..ellipsis_len]);
    truncated
}