1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics::{self, ClassificationMetrics},
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use edit_prediction::udiff::DiffLine;
10use gpui::AsyncApp;
11use std::sync::Arc;
12
13pub async fn run_scoring(
14 example: &mut Example,
15 args: &PredictArgs,
16 app_state: Arc<EpAppState>,
17 cx: AsyncApp,
18) {
19 run_prediction(
20 example,
21 Some(args.provider),
22 args.repetitions,
23 app_state,
24 cx,
25 )
26 .await;
27
28 let _progress = Progress::global().start(Step::Score, &example.name);
29
30 let expected_patch = parse_patch(&example.expected_patch);
31
32 let mut scores = vec![];
33
34 for pred in &example.predictions {
35 let actual_patch = parse_patch(&pred.actual_patch);
36 let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
37 let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
38
39 scores.push(ExampleScore {
40 delta_chr_f,
41 line_match,
42 });
43 }
44
45 example.score = scores;
46}
47
48fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
49 patch.lines().map(DiffLine::parse).collect()
50}
51
52pub fn print_report(examples: &[Example]) {
53 eprintln!(
54 "──────────────────────────────────────────────────────────────────────────────────────"
55 );
56 eprintln!(
57 "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
58 "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
59 );
60 eprintln!(
61 "──────────────────────────────────────────────────────────────────────────────────────"
62 );
63
64 let mut all_line_match_scores = Vec::new();
65 let mut all_delta_chr_f_scores = Vec::new();
66
67 for example in examples {
68 for score in example.score.iter() {
69 let line_match = &score.line_match;
70
71 eprintln!(
72 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
73 truncate_name(&example.name, 30),
74 line_match.true_positives,
75 line_match.false_positives,
76 line_match.false_negatives,
77 line_match.precision() * 100.0,
78 line_match.recall() * 100.0,
79 line_match.f1_score() * 100.0,
80 score.delta_chr_f
81 );
82
83 all_line_match_scores.push(line_match.clone());
84 all_delta_chr_f_scores.push(score.delta_chr_f);
85 }
86 }
87
88 eprintln!(
89 "──────────────────────────────────────────────────────────────────────────────────────"
90 );
91
92 if !all_line_match_scores.is_empty() {
93 let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
94 let avg_delta_chr_f: f32 =
95 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
96
97 eprintln!(
98 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
99 "TOTAL",
100 total_line_match.true_positives,
101 total_line_match.false_positives,
102 total_line_match.false_negatives,
103 total_line_match.precision() * 100.0,
104 total_line_match.recall() * 100.0,
105 total_line_match.f1_score() * 100.0,
106 avg_delta_chr_f
107 );
108 eprintln!(
109 "──────────────────────────────────────────────────────────────────────────────────────"
110 );
111 }
112
113 eprintln!("\n");
114}
115
/// Shortens `name` to at most `max_len` characters for fixed-width display.
///
/// Names that already fit are returned unchanged. Longer names keep their
/// first `max_len - 3` characters and get `"..."` appended.
///
/// Lengths are measured in characters (Unicode scalar values) rather than
/// bytes, so a multi-byte name is never sliced in the middle of a code point
/// and the result lines up with character-counted format widths. When
/// `max_len` is smaller than the ellipsis, the ellipsis itself is clipped so
/// the result still fits within `max_len`.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // A character at index `max_len` exists only when the name is too long.
    if name.chars().nth(max_len).is_none() {
        return name.to_string();
    }

    // `saturating_sub` avoids underflow for `max_len < 3`.
    let kept_chars = max_len.saturating_sub(ELLIPSIS.len());
    // ELLIPSIS is ASCII, so slicing it by byte index is always valid.
    let marker = &ELLIPSIS[..max_len.min(ELLIPSIS.len())];

    let mut shortened: String = name.chars().take(kept_chars).collect();
    shortened.push_str(marker);
    shortened
}