use std::{
    fs,
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::{Context as _, Result};
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;

use crate::{
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    paths::CACHE_DIR,
    predict::{PredictionDetails, zeta2_predict},
};

#[derive(Debug, Args)]
pub struct EvaluateArguments {
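    /// Paths of the example files to evaluate.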
    example_paths: Vec<PathBuf>,
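    /// Recompute predictions even if a cached result exists on disk.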
    #[clap(long)]
    re_run: bool,
}

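/// Evaluates all given examples concurrently, printing each example's report
/// and, when more than one example was given, aggregated totals at the end.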
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example_count = args.example_paths.len();
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state, cx).await)
    });
    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();

    let aggregated_result = EvaluationResult {
        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
    };

    if example_count > 1 {
        println!("\n{}", "-".repeat(80));
        println!("# TOTAL SCORES:");
        println!("{}", aggregated_result.to_markdown());
    }
}

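/// Evaluates a single example file. Predictions are loaded from `CACHE_DIR`
/// when a cached copy exists and `re_run` is false; otherwise they are
/// recomputed via `zeta2_predict` and written back to the cache.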
pub async fn run_evaluate_one(
    example_path: &Path,
    re_run: bool,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
    let example = NamedExample::load(example_path)?;
    let example_file_name = example_path
        .file_name()
        .context("example path has no file name")?;
    let example_cache_path = CACHE_DIR.join(example_file_name);

    let use_cached = !re_run && example_cache_path.exists();
    let predictions = if use_cached {
        let file_contents = fs::read_to_string(&example_cache_path)?;
        let predictions = serde_json::from_str::<PredictionDetails>(&file_contents)?;
        log::debug!(
            "Loaded predictions from cache: {}",
            example_cache_path.display()
        );
        predictions
    } else {
        zeta2_predict(example.clone(), &app_state, cx).await?
    };

    if !use_cached {
        fs::create_dir_all(&*CACHE_DIR)?;
        fs::write(&example_cache_path, serde_json::to_string(&predictions)?)?;
    }

    let evaluation_result = evaluate(&example.example, &predictions);

    println!(
        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&example.example.expected_patch, &predictions.diff)
    );
    println!(
        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&predictions.diff, &example.example.expected_patch)
    );

    println!("{}", evaluation_result.to_markdown());

    Ok(evaluation_result)
}

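/// Scores for one evaluated example: how well the retrieved context and the
/// predicted edit match the example's expectations.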
#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub edit_prediction: Scores,
    pub context: Scores,
}

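/// Raw confusion counts from a set comparison; precision, recall, and F1
/// are derived from these.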
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
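    /// Scores `actual` against `expected` by set membership: items in both
    /// are true positives, items only in `actual` are false positives, and
    /// items only in `expected` are false negatives.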
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

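    /// Renders the derived metrics and raw counts as a plain-text block.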
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

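    /// Sums raw counts across multiple score sets (micro-averaging), so
    /// metrics computed on the result pool all counts rather than averaging
    /// per-example scores.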
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

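    /// Fraction of predicted items that were expected; 0.0 when nothing was
    /// predicted.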
    pub fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

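    /// Fraction of expected items that were predicted; 0.0 when nothing was
    /// expected.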
    pub fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }

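    /// Harmonic mean of precision and recall; 0.0 when both are zero.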
    pub fn f1_score(&self) -> f64 {
        let recall = self.recall();
        let precision = self.precision();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}

impl EvaluationResult {
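    /// Renders the context and edit-prediction scores as the markdown report
    /// printed for each example.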
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_prediction.to_markdown()
        )
    }
}

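/// Compares predictions against an example's expectations line by line.
/// Each expected-context entry is scored against its best-matching
/// alternative (highest recall), while the predicted patch is scored on its
/// added and deleted lines only.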
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
    let mut eval_result = EvaluationResult::default();

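    // Flatten the retrieved excerpts into "path: line" strings so context
    // can be compared as sets of individual lines.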
    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

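    // Start from every retrieved line; whatever remains after all
    // expected-context entries have been processed was never expected by any
    // alternative, i.e. it is a false positive.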
    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        let mut best_alternative_score: Option<Scores> = None;

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            false_positive_lines.retain(|line| !expected.contains(line));

            if best_alternative_score
                .as_ref()
                .is_none_or(|best| scores.recall() > best.recall())
            {
                best_alternative_score = Some(scores);
            }
        }

        let best_alternative_score = best_alternative_score.unwrap_or_default();
        eval_result.context.false_negatives += best_alternative_score.false_negatives;
        eval_result.context.true_positives += best_alternative_score.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    // todo: alternatives for patches
    let expected_patch_lines = example
        .expected_patch
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    let actual_patch_lines = preds
        .diff
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
    eval_result
}

/// Returns `patch_a` annotated line by line: additions and deletions that
/// also appear in `patch_b` are highlighted in green, those missing from
/// `patch_b` in red, and context lines are left unhighlighted.
pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
    let use_color = std::io::stdout().is_terminal();
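    // Only emit ANSI color codes (and ✓/✗ markers) when stdout is a terminal.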
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}