1use crate::{
2 PredictArgs,
3 example::{Example, ExampleScore},
4 headless::EpAppState,
5 metrics::{self, ClassificationMetrics},
6 predict::run_prediction,
7 progress::{Progress, Step},
8};
9use edit_prediction::udiff::DiffLine;
10use gpui::AsyncApp;
11use std::sync::Arc;
12
13pub async fn run_scoring(
14 example: &mut Example,
15 args: &PredictArgs,
16 app_state: Arc<EpAppState>,
17 progress: Arc<Progress>,
18 cx: AsyncApp,
19) {
20 run_prediction(
21 example,
22 Some(args.provider),
23 args.repetitions,
24 app_state,
25 progress.clone(),
26 cx,
27 )
28 .await;
29
30 let _progress = progress.start(Step::Score, &example.name);
31
32 let expected_patch = parse_patch(&example.expected_patch);
33
34 let mut scores = vec![];
35
36 for pred in &example.predictions {
37 let actual_patch = parse_patch(&pred.actual_patch);
38 let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
39 let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
40
41 scores.push(ExampleScore {
42 delta_chr_f,
43 line_match,
44 });
45 }
46
47 example.score = scores;
48}
49
50fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
51 patch.lines().map(DiffLine::parse).collect()
52}
53
54pub fn print_report(examples: &[Example]) {
55 eprintln!(
56 "──────────────────────────────────────────────────────────────────────────────────────"
57 );
58 eprintln!(
59 "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
60 "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
61 );
62 eprintln!(
63 "──────────────────────────────────────────────────────────────────────────────────────"
64 );
65
66 let mut all_line_match_scores = Vec::new();
67 let mut all_delta_chr_f_scores = Vec::new();
68
69 for example in examples {
70 for score in example.score.iter() {
71 let line_match = &score.line_match;
72
73 eprintln!(
74 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
75 truncate_name(&example.name, 30),
76 line_match.true_positives,
77 line_match.false_positives,
78 line_match.false_negatives,
79 line_match.precision() * 100.0,
80 line_match.recall() * 100.0,
81 line_match.f1_score() * 100.0,
82 score.delta_chr_f
83 );
84
85 all_line_match_scores.push(line_match.clone());
86 all_delta_chr_f_scores.push(score.delta_chr_f);
87 }
88 }
89
90 eprintln!(
91 "──────────────────────────────────────────────────────────────────────────────────────"
92 );
93
94 if !all_line_match_scores.is_empty() {
95 let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
96 let avg_delta_chr_f: f32 =
97 all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
98
99 eprintln!(
100 "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
101 "TOTAL",
102 total_line_match.true_positives,
103 total_line_match.false_positives,
104 total_line_match.false_negatives,
105 total_line_match.precision() * 100.0,
106 total_line_match.recall() * 100.0,
107 total_line_match.f1_score() * 100.0,
108 avg_delta_chr_f
109 );
110 eprintln!(
111 "──────────────────────────────────────────────────────────────────────────────────────"
112 );
113 }
114
115 eprintln!("\n");
116}
117
/// Shortens `name` to at most `max_len` characters for display.
///
/// Names that already fit are returned unchanged. Longer names are cut and
/// end in `"..."`, so the result is exactly `max_len` characters long. When
/// `max_len` is smaller than the ellipsis itself, the result is just
/// `max_len` dots.
///
/// Lengths are measured in `char`s rather than bytes, so the cut can never
/// land inside a multi-byte UTF-8 sequence (which would panic when slicing).
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // If there is no character at index `max_len`, the name has at most
    // `max_len` characters and fits as-is.
    if name.chars().nth(max_len).is_none() {
        return name.to_string();
    }

    // Clamp the ellipsis to the available budget so tiny limits cannot
    // underflow the subtraction below.
    let ellipsis_len = ELLIPSIS.len().min(max_len);
    let kept_chars = max_len - ellipsis_len;

    let mut truncated: String = name.chars().take(kept_chars).collect();
    // `ELLIPSIS` is ASCII, so byte slicing here is always on a char boundary.
    truncated.push_str(&ELLIPSIS[..ellipsis_len]);
    truncated
}