1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics::{self, ClassificationMetrics},
6 predict::run_prediction,
7};
8use edit_prediction::udiff::DiffLine;
9use gpui::AsyncApp;
10use std::sync::Arc;
11
12pub async fn run_scoring(
13 example: &mut Example,
14 args: &PredictArgs,
15 app_state: Arc<EpAppState>,
16 cx: AsyncApp,
17) {
18 run_prediction(
19 example,
20 Some(args.provider),
21 args.repetitions,
22 app_state,
23 cx,
24 )
25 .await;
26
27 let expected_patch = parse_patch(&example.expected_patch);
28
29 let mut scores = vec![];
30
31 for pred in &example.predictions {
32 let actual_patch = parse_patch(&pred.actual_patch);
33 let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
34 let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
35
36 scores.push(ExampleScore {
37 delta_chr_f,
38 line_match,
39 });
40 }
41
42 example.score = scores;
43}
44
45fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
46 patch.lines().map(DiffLine::parse).collect()
47}
48
49pub fn print_report(examples: &[Example]) {
50 eprintln!(
51 "──────────────────────────────────────────────────────────────────────────────────────"
52 );
53 eprintln!(
54 "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
55 "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
56 );
57 eprintln!(
58 "──────────────────────────────────────────────────────────────────────────────────────"
59 );
60
61 let mut all_line_match_scores = Vec::new();
62 let mut all_delta_chr_f_scores = Vec::new();
63
64 for example in examples {
65 for score in example.score.iter() {
66 let line_match = &score.line_match;
67
68 eprintln!(
69 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
70 truncate_name(&example.name, 30),
71 line_match.true_positives,
72 line_match.false_positives,
73 line_match.false_negatives,
74 line_match.precision() * 100.0,
75 line_match.recall() * 100.0,
76 line_match.f1_score() * 100.0,
77 score.delta_chr_f
78 );
79
80 all_line_match_scores.push(line_match.clone());
81 all_delta_chr_f_scores.push(score.delta_chr_f);
82 }
83 }
84
85 eprintln!(
86 "──────────────────────────────────────────────────────────────────────────────────────"
87 );
88
89 if !all_line_match_scores.is_empty() {
90 let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
91 let avg_delta_chr_f: f32 =
92 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
93
94 eprintln!(
95 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
96 "TOTAL",
97 total_line_match.true_positives,
98 total_line_match.false_positives,
99 total_line_match.false_negatives,
100 total_line_match.precision() * 100.0,
101 total_line_match.recall() * 100.0,
102 total_line_match.f1_score() * 100.0,
103 avg_delta_chr_f
104 );
105 eprintln!(
106 "──────────────────────────────────────────────────────────────────────────────────────"
107 );
108 }
109
110 eprintln!("\n");
111}
112
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// suffixed with `"..."` so the result is exactly `max_len` characters. When
/// `max_len` is too small to hold the ellipsis (less than 3), the name is
/// simply cut to `max_len` characters.
///
/// Lengths are measured in `char`s, not bytes. That avoids panicking when a
/// cut would land inside a multi-byte UTF-8 character, and it matches how
/// `format!` width padding (e.g. `{:<30}`) counts characters.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    if max_len < ELLIPSIS.len() {
        // No room for the ellipsis; a hard cut is the best we can do.
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}