1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use anyhow::Context as _;
10use edit_prediction::udiff::apply_diff_to_string;
11use gpui::AsyncApp;
12use std::sync::Arc;
13
14pub async fn run_scoring(
15 example: &mut Example,
16 args: &PredictArgs,
17 app_state: Arc<EpAppState>,
18 cx: AsyncApp,
19) -> anyhow::Result<()> {
20 run_prediction(
21 example,
22 Some(args.provider),
23 args.repetitions,
24 app_state,
25 cx,
26 )
27 .await?;
28
29 let progress = Progress::global().start(Step::Score, &example.spec.name);
30
31 progress.set_substatus("applying patches");
32 let original_text = &example.buffer.as_ref().unwrap().content;
33 let expected_texts: Vec<String> = example
34 .spec
35 .expected_patches
36 .iter()
37 .map(|patch| {
38 apply_diff_to_string(patch, original_text)
39 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
40 })
41 .collect::<Result<Vec<_>, _>>()?;
42
43 progress.set_substatus("computing metrics");
44 let mut scores = vec![];
45 for prediction in &example.predictions {
46 let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
47 Ok(text) => text,
48 Err(_) => {
49 scores.push(ExampleScore { delta_chr_f: 0.0 });
50 continue;
51 }
52 };
53 let best_delta_chr_f = expected_texts
54 .iter()
55 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
56 .fold(0.0, f32::max);
57 scores.push(ExampleScore {
58 delta_chr_f: best_delta_chr_f,
59 });
60 }
61
62 example.score = scores;
63 Ok(())
64}
65
66pub fn print_report(examples: &[Example]) {
67 eprintln!(
68 "──────────────────────────────────────────────────────────────────────────────────────"
69 );
70 eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
71 eprintln!(
72 "──────────────────────────────────────────────────────────────────────────────────────"
73 );
74
75 let mut all_delta_chr_f_scores = Vec::new();
76
77 for example in examples {
78 for score in example.score.iter() {
79 eprintln!(
80 "{:<50} {:>9.2}",
81 truncate_name(&example.spec.name, 50),
82 score.delta_chr_f
83 );
84
85 all_delta_chr_f_scores.push(score.delta_chr_f);
86 }
87 }
88
89 eprintln!(
90 "──────────────────────────────────────────────────────────────────────────────────────"
91 );
92
93 if !all_delta_chr_f_scores.is_empty() {
94 let avg_delta_chr_f: f32 =
95 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
96
97 eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
98 eprintln!(
99 "──────────────────────────────────────────────────────────────────────────────────────"
100 );
101 }
102
103 eprintln!("\n");
104}
105
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `"..."` so the result is exactly `max_len` characters. If
/// `max_len` is too small to hold the ellipsis, the name is simply cut to
/// `max_len` characters.
///
/// Lengths are counted in `char`s rather than bytes. This matches how `{:<N}`
/// padding measures width and never splits a multi-byte UTF-8 character.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}