use std::{
    fs,
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::{Context as _, Result};
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;

use crate::{
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    paths::CACHE_DIR,
    predict::{PredictionDetails, zeta2_predict},
};
20
// Arguments for the `evaluate` subcommand.
// Plain `//` comments are used (not `///`) so clap's generated --help text
// is not changed by this documentation.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Paths of the example files to evaluate; each is scored independently.
    example_paths: Vec<PathBuf>,
    // When set, recompute predictions even if a cached result exists on disk.
    #[clap(long)]
    re_run: bool,
}
27
28pub async fn run_evaluate(
29 args: EvaluateArguments,
30 app_state: &Arc<ZetaCliAppState>,
31 cx: &mut AsyncApp,
32) {
33 let example_len = args.example_paths.len();
34 let all_tasks = args.example_paths.into_iter().map(|path| {
35 let app_state = app_state.clone();
36 cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
37 });
38 let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
39
40 let aggregated_result = EvaluationResult {
41 context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
42 edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
43 };
44
45 if example_len > 1 {
46 println!("\n{}", "-".repeat(80));
47 println!("# TOTAL SCORES:");
48 println!("{}", aggregated_result.to_markdown());
49 }
50}
51
52pub async fn run_evaluate_one(
53 example_path: &Path,
54 re_run: bool,
55 app_state: Arc<ZetaCliAppState>,
56 cx: &mut AsyncApp,
57) -> Result<EvaluationResult> {
58 let example = NamedExample::load(&example_path).unwrap();
59 let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
60
61 let predictions = if !re_run && example_cache_path.exists() {
62 let file_contents = fs::read_to_string(&example_cache_path)?;
63 let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
64 log::debug!(
65 "Loaded predictions from cache: {}",
66 example_cache_path.display()
67 );
68 as_json
69 } else {
70 zeta2_predict(example.clone(), &app_state, cx)
71 .await
72 .unwrap()
73 };
74
75 if !example_cache_path.exists() {
76 fs::create_dir_all(&*CACHE_DIR).unwrap();
77 fs::write(
78 example_cache_path,
79 serde_json::to_string(&predictions).unwrap(),
80 )
81 .unwrap();
82 }
83
84 let evaluation_result = evaluate(&example.example, &predictions);
85
86 println!("# {}\n", example.name);
87 println!(
88 "## Expected Context: \n\n```\n{}\n```\n\n",
89 compare_context(&example.example, &predictions)
90 );
91 println!(
92 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
93 compare_diffs(&example.example.expected_patch, &predictions.diff)
94 );
95 println!(
96 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
97 compare_diffs(&predictions.diff, &example.example.expected_patch)
98 );
99
100 println!("{}", evaluation_result.to_markdown());
101
102 anyhow::Ok(evaluation_result)
103}
104
/// Evaluation scores for one example (or aggregated across examples).
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// How well the predicted context matched the expected excerpts.
    pub context: Scores,
    /// How well the predicted edit lines matched the expected patch.
    pub edit_prediction: Scores,
}

/// Precision/recall/F1 plus the raw counts they were derived from.
#[derive(Default, Debug)]
pub struct Scores {
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Sums the raw counts of `scores` and recomputes precision/recall/F1
    /// from the totals (micro-averaging).
    ///
    /// Every ratio with a zero denominator is defined as 0.0, so an empty
    /// iterator yields all-zero scores. (Previously 0/0 produced NaN
    /// precision/recall, which leaked into the printed report.)
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        // Guard each division: 0/0 would otherwise produce NaN.
        let ratio = |numerator: usize, denominator: usize| {
            if denominator == 0 {
                0.0
            } else {
                numerator as f64 / denominator as f64
            }
        };
        let precision = ratio(true_positives, true_positives + false_positives);
        let recall = ratio(true_positives, true_positives + false_negatives);
        let f1_score = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Scores {
            precision,
            recall,
            f1_score,
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    /// Renders the scores as a plain-text fragment for the CLI report.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision,
            self.recall,
            self.f1_score,
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }
}

impl EvaluationResult {
    /// Renders both score sections as markdown for the CLI report.
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_prediction.to_markdown()
        )
    }
}
186
187pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
188 let mut result = EvaluationResult::default();
189
190 let expected_context_lines = example
191 .expected_excerpts
192 .iter()
193 .flat_map(|excerpt| {
194 excerpt
195 .text
196 .lines()
197 .map(|line| format!("{}: {line}", excerpt.path.display()))
198 })
199 .collect();
200 let actual_context_lines = preds
201 .excerpts
202 .iter()
203 .flat_map(|excerpt| {
204 excerpt
205 .text
206 .lines()
207 .map(|line| format!("{}: {line}", excerpt.path.display()))
208 })
209 .collect();
210
211 result.context = precision_recall(&expected_context_lines, &actual_context_lines);
212
213 let expected_patch_lines = example
214 .expected_patch
215 .lines()
216 .map(DiffLine::parse)
217 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
218 .map(|line| line.to_string())
219 .collect();
220
221 let actual_patch_lines = preds
222 .diff
223 .lines()
224 .map(DiffLine::parse)
225 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
226 .map(|line| line.to_string())
227 .collect();
228
229 result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
230
231 result
232}
233
234fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
235 let true_positives = expected.intersection(actual).count();
236 let false_positives = actual.difference(expected).count();
237 let false_negatives = expected.difference(actual).count();
238
239 let precision = if true_positives + false_positives == 0 {
240 0.0
241 } else {
242 true_positives as f64 / (true_positives + false_positives) as f64
243 };
244 let recall = if true_positives + false_negatives == 0 {
245 0.0
246 } else {
247 true_positives as f64 / (true_positives + false_negatives) as f64
248 };
249 let f1_score = if precision + recall == 0.0 {
250 0.0
251 } else {
252 2.0 * precision * recall / (precision + recall)
253 };
254
255 Scores {
256 precision,
257 recall,
258 f1_score,
259 true_positives,
260 false_positives,
261 false_negatives,
262 }
263}
264
265/// Compare actual and expected context.
266///
267/// Return expected context annotated with these markers:
268///
269/// `✓ context line` -- line was correctly predicted
270/// `✗ context line` -- line is missing from predictions
271pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
272 let use_color = std::io::stdout().is_terminal();
273 let green = if use_color { "\x1b[32m" } else { "" };
274 let red = if use_color { "\x1b[31m" } else { "" };
275 let reset = if use_color { "\x1b[0m" } else { "" };
276 let expected: Vec<_> = example
277 .expected_excerpts
278 .iter()
279 .flat_map(|excerpt| {
280 excerpt
281 .text
282 .lines()
283 .map(|line| (excerpt.path.clone(), line))
284 })
285 .collect();
286 let actual: HashSet<_> = preds
287 .excerpts
288 .iter()
289 .flat_map(|excerpt| {
290 excerpt
291 .text
292 .lines()
293 .map(|line| (excerpt.path.clone(), line))
294 })
295 .collect();
296
297 let annotated = expected
298 .iter()
299 .map(|(path, line)| {
300 if actual.contains(&(path.to_path_buf(), line)) {
301 format!("{green}✓ {line}{reset}")
302 } else {
303 format!("{red}✗ {line}{reset}")
304 }
305 })
306 .collect::<Vec<String>>();
307
308 annotated.join("\n")
309}
310
311/// Return annotated `patch_a` so that:
312/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
313/// Additions and deletions that are present in `patch_b` will be highlighted in green.
314pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
315 let use_color = std::io::stdout().is_terminal();
316 let green = if use_color { "\x1b[32m✓ " } else { "" };
317 let red = if use_color { "\x1b[31m✗ " } else { "" };
318 let neutral = if use_color { " " } else { "" };
319 let reset = if use_color { "\x1b[0m" } else { "" };
320 let lines_a = patch_a.lines().map(DiffLine::parse);
321 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
322
323 let annotated = lines_a
324 .map(|line| match line {
325 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
326 if lines_b.contains(&line) {
327 format!("{green}{line}{reset}")
328 } else {
329 format!("{red}{line}{reset}")
330 }
331 }
332 _ => format!("{neutral}{line}{reset}"),
333 })
334 .collect::<Vec<String>>();
335
336 annotated.join("\n")
337}