1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use anyhow::Context as _;
10use edit_prediction::udiff::apply_diff_to_string;
11use gpui::AsyncApp;
12use std::sync::Arc;
13
14pub async fn run_scoring(
15 example: &mut Example,
16 args: &PredictArgs,
17 app_state: Arc<EpAppState>,
18 cx: AsyncApp,
19) -> anyhow::Result<()> {
20 run_prediction(
21 example,
22 Some(args.provider),
23 args.repetitions,
24 app_state,
25 cx,
26 )
27 .await?;
28
29 let _progress = Progress::global().start(Step::Score, &example.spec.name);
30
31 let original_text = &example.buffer.as_ref().unwrap().content;
32 let expected_texts: Vec<String> = example
33 .spec
34 .expected_patches
35 .iter()
36 .map(|patch| {
37 apply_diff_to_string(original_text, patch)
38 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
39 })
40 .collect::<Result<Vec<_>, _>>()?;
41
42 let mut scores = vec![];
43 for prediction in &example.predictions {
44 let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
45 Ok(text) => text,
46 Err(_) => {
47 scores.push(ExampleScore { delta_chr_f: 0.0 });
48 continue;
49 }
50 };
51 let best_delta_chr_f = expected_texts
52 .iter()
53 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
54 .fold(0.0, f32::max);
55 scores.push(ExampleScore {
56 delta_chr_f: best_delta_chr_f,
57 });
58 }
59
60 example.score = scores;
61 Ok(())
62}
63
64pub fn print_report(examples: &[Example]) {
65 eprintln!(
66 "──────────────────────────────────────────────────────────────────────────────────────"
67 );
68 eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
69 eprintln!(
70 "──────────────────────────────────────────────────────────────────────────────────────"
71 );
72
73 let mut all_delta_chr_f_scores = Vec::new();
74
75 for example in examples {
76 for score in example.score.iter() {
77 eprintln!(
78 "{:<50} {:>9.2}",
79 truncate_name(&example.spec.name, 50),
80 score.delta_chr_f
81 );
82
83 all_delta_chr_f_scores.push(score.delta_chr_f);
84 }
85 }
86
87 eprintln!(
88 "──────────────────────────────────────────────────────────────────────────────────────"
89 );
90
91 if !all_delta_chr_f_scores.is_empty() {
92 let avg_delta_chr_f: f32 =
93 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
94
95 eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
96 eprintln!(
97 "──────────────────────────────────────────────────────────────────────────────────────"
98 );
99 }
100
101 eprintln!("\n");
102}
103
/// Shortens `name` to at most `max_len` characters for tabular display.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `...` so the result is exactly `max_len` characters. When
/// `max_len` is too small to hold the ellipsis, the name is simply cut to
/// `max_len` characters.
///
/// Lengths are measured in `char`s rather than bytes, so names containing
/// multi-byte UTF-8 characters never get sliced in the middle of a character.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => {
            let mut shortened: String = name.chars().take(keep).collect();
            shortened.push_str(ELLIPSIS);
            shortened
        }
        // No room for the ellipsis: hard-cut instead of underflowing.
        None => name.chars().take(max_len).collect(),
    }
}