1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics,
6 predict::run_prediction,
7 progress::{ExampleProgress, Step},
8};
9use anyhow::Context as _;
10use edit_prediction::udiff::apply_diff_to_string;
11use gpui::AsyncApp;
12use std::sync::Arc;
13
14pub async fn run_scoring(
15 example: &mut Example,
16 args: &PredictArgs,
17 app_state: Arc<EpAppState>,
18 example_progress: &ExampleProgress,
19 cx: AsyncApp,
20) -> anyhow::Result<()> {
21 run_prediction(example, args, app_state, example_progress, cx).await?;
22
23 let progress = example_progress.start(Step::Score);
24
25 progress.set_substatus("applying patches");
26 let original_text = &example.prompt_inputs.as_ref().unwrap().content;
27 let expected_texts: Vec<String> = example
28 .spec
29 .expected_patches
30 .iter()
31 .map(|patch| {
32 apply_diff_to_string(patch, original_text)
33 .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
34 })
35 .collect::<Result<Vec<_>, _>>()?;
36
37 progress.set_substatus("computing metrics");
38 let mut scores = vec![];
39 for prediction in &example.predictions {
40 let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
41 Ok(text) => text,
42 Err(_) => {
43 scores.push(ExampleScore { delta_chr_f: 0.0 });
44 continue;
45 }
46 };
47 let best_delta_chr_f = expected_texts
48 .iter()
49 .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
50 .fold(0.0, f32::max);
51 scores.push(ExampleScore {
52 delta_chr_f: best_delta_chr_f,
53 });
54 }
55
56 example.score = scores;
57 Ok(())
58}
59
60pub fn print_report(examples: &[Example]) {
61 eprintln!(
62 "──────────────────────────────────────────────────────────────────────────────────────"
63 );
64 eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
65 eprintln!(
66 "──────────────────────────────────────────────────────────────────────────────────────"
67 );
68
69 let mut all_delta_chr_f_scores = Vec::new();
70
71 for example in examples {
72 for score in example.score.iter() {
73 eprintln!(
74 "{:<50} {:>9.2}",
75 truncate_name(&example.spec.name, 50),
76 score.delta_chr_f
77 );
78
79 all_delta_chr_f_scores.push(score.delta_chr_f);
80 }
81 }
82
83 eprintln!(
84 "──────────────────────────────────────────────────────────────────────────────────────"
85 );
86
87 if !all_delta_chr_f_scores.is_empty() {
88 let avg_delta_chr_f: f32 =
89 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
90
91 eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
92 eprintln!(
93 "──────────────────────────────────────────────────────────────────────────────────────"
94 );
95 }
96
97 eprintln!("\n");
98}
99
/// Shortens `name` so that it occupies at most `max_len` characters.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `...` so the result is exactly `max_len` characters long.
/// When `max_len` is too small to hold the ellipsis (less than 3), the name
/// is simply cut to `max_len` characters.
///
/// Lengths are measured in `char`s rather than bytes, so a name containing
/// multi-byte UTF-8 characters is never split in the middle of a character.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // `checked_sub` avoids underflow when the limit is shorter than the ellipsis.
    match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => {
            let mut shortened: String = name.chars().take(keep).collect();
            shortened.push_str(ELLIPSIS);
            shortened
        }
        None => name.chars().take(max_len).collect(),
    }
}