1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics::{self, ClassificationMetrics},
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use edit_prediction::udiff::DiffLine;
10use gpui::AsyncApp;
11use std::sync::Arc;
12
13pub async fn run_scoring(
14 example: &mut Example,
15 args: &PredictArgs,
16 app_state: Arc<EpAppState>,
17 cx: AsyncApp,
18) -> anyhow::Result<()> {
19 run_prediction(
20 example,
21 Some(args.provider),
22 args.repetitions,
23 app_state,
24 cx,
25 )
26 .await?;
27
28 let _progress = Progress::global().start(Step::Score, &example.name);
29
30 let expected_patch = parse_patch(&example.expected_patch);
31
32 let mut scores = vec![];
33
34 for pred in &example.predictions {
35 let actual_patch = parse_patch(&pred.actual_patch);
36 let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
37 let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
38
39 scores.push(ExampleScore {
40 delta_chr_f,
41 line_match,
42 });
43 }
44
45 example.score = scores;
46 Ok(())
47}
48
49fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
50 patch.lines().map(DiffLine::parse).collect()
51}
52
53pub fn print_report(examples: &[Example]) {
54 eprintln!(
55 "──────────────────────────────────────────────────────────────────────────────────────"
56 );
57 eprintln!(
58 "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
59 "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
60 );
61 eprintln!(
62 "──────────────────────────────────────────────────────────────────────────────────────"
63 );
64
65 let mut all_line_match_scores = Vec::new();
66 let mut all_delta_chr_f_scores = Vec::new();
67
68 for example in examples {
69 for score in example.score.iter() {
70 let line_match = &score.line_match;
71
72 eprintln!(
73 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
74 truncate_name(&example.name, 30),
75 line_match.true_positives,
76 line_match.false_positives,
77 line_match.false_negatives,
78 line_match.precision() * 100.0,
79 line_match.recall() * 100.0,
80 line_match.f1_score() * 100.0,
81 score.delta_chr_f
82 );
83
84 all_line_match_scores.push(line_match.clone());
85 all_delta_chr_f_scores.push(score.delta_chr_f);
86 }
87 }
88
89 eprintln!(
90 "──────────────────────────────────────────────────────────────────────────────────────"
91 );
92
93 if !all_line_match_scores.is_empty() {
94 let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
95 let avg_delta_chr_f: f32 =
96 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
97
98 eprintln!(
99 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
100 "TOTAL",
101 total_line_match.true_positives,
102 total_line_match.false_positives,
103 total_line_match.false_negatives,
104 total_line_match.precision() * 100.0,
105 total_line_match.recall() * 100.0,
106 total_line_match.f1_score() * 100.0,
107 avg_delta_chr_f
108 );
109 eprintln!(
110 "──────────────────────────────────────────────────────────────────────────────────────"
111 );
112 }
113
114 eprintln!("\n");
115}
116
/// Shortens `name` to at most `max_len` characters for display in the report table.
///
/// Names that already fit are returned unchanged. Longer names keep their
/// first `max_len - 3` characters followed by `"..."`. If `max_len` is smaller
/// than the ellipsis itself, the name is simply cut to `max_len` characters,
/// so the result never exceeds `max_len`.
///
/// Lengths are counted in `char`s, the same unit `{:<30}`-style padding uses.
/// As a result, multi-byte UTF-8 names are never split mid-character; byte
/// slicing here would panic on such names.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Too narrow to fit any prefix plus the ellipsis: hard cut instead of underflowing.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}