use std::{
    fs,
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::{Context as _, Result};
use clap::Args;
use cloud_llm_client::udiff::DiffLine;
use collections::HashSet;
use gpui::AsyncApp;

use crate::{
    example::{Example, Excerpt, NamedExample},
    headless::ZetaCliAppState,
    predict::{PredictionDetails, zeta2_predict},
};

#[derive(Debug, Args)]
pub struct EvaluateArguments {
    /// Paths of the example files to evaluate.
    example_paths: Vec<PathBuf>,
    /// Re-run the prediction even if a cached result exists for the example.
    #[clap(long)]
    re_run: bool,
}

pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example_count = args.example_paths.len();
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state, cx).await)
    });
    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();

    let aggregated_result = EvaluationResult {
        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
        edit_sites_coverage: all_results
            .iter()
            .map(|r| r.edit_sites_coverage)
            .sum::<f64>()
            / all_results.len() as f64,
    };

    if example_count > 1 {
        println!("\n{}", "-".repeat(80));
        println!("# TOTAL SCORES:");
        println!("{}", aggregated_result.to_markdown());
    }
}

pub async fn run_evaluate_one(
    example_path: &Path,
    re_run: bool,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
    let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
        .join("../../target/zeta-prediction-cache");
    let example = NamedExample::load(example_path)?;
    let example_cache_path = cache_dir.join(
        example_path
            .file_name()
            .context("example path has no file name")?,
    );

    let predictions = if !re_run && example_cache_path.exists() {
        let file_contents = fs::read_to_string(&example_cache_path)?;
        let predictions = serde_json::from_str::<PredictionDetails>(&file_contents)?;
        log::debug!(
            "Loaded predictions from cache: {}",
            example_cache_path.display()
        );
        predictions
    } else {
        let predictions = zeta2_predict(example.clone(), &app_state, cx).await?;
        // Cache the fresh result (overwriting any stale entry when `--re-run`
        // was passed) so later runs can skip the prediction step.
        fs::create_dir_all(&cache_dir)?;
        fs::write(&example_cache_path, serde_json::to_string(&predictions)?)?;
        predictions
    };

    let evaluation_result = evaluate(&example.example, &predictions);

    println!("# {}\n", example.name);
    println!(
        "## Expected context:\n\n```\n{}\n```\n",
        compare_context(&example.example, &predictions)
    );
    println!(
        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&example.example.expected_patch, &predictions.diff)
    );
    println!(
        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&predictions.diff, &example.example.expected_patch)
    );

    println!("{}", evaluation_result.to_markdown());

    Ok(evaluation_result)
}

#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub context: Scores,

    /// Ratio of the lines edited in the expected patch that were also
    /// included in the retrieved context:
    /// `num_correctly_retrieved_lines / num_expected_lines`
    pub edit_sites_coverage: f64,

    pub edit_prediction: Scores,
}

#[derive(Default, Debug)]
pub struct Scores {
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Build `Scores` from raw counts, guarding each division so that empty
    /// denominators produce 0.0 instead of NaN.
    fn from_counts(
        true_positives: usize,
        false_positives: usize,
        false_negatives: usize,
    ) -> Scores {
        let precision = if true_positives + false_positives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_positives) as f64
        };
        let recall = if true_positives + false_negatives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_negatives) as f64
        };
        let f1_score = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Scores {
            precision,
            recall,
            f1_score,
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores::from_counts(true_positives, false_positives, false_negatives)
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision,
            self.recall,
            self.f1_score,
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }
}

impl EvaluationResult {
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}
Edit sites coverage: {:.4}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_sites_coverage,
            self.edit_prediction.to_markdown()
        )
    }
}

pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
    let mut result = EvaluationResult::default();

    // Key every line by its file path so that identical lines in different
    // files don't collide when comparing the expected and actual context sets.
    let expected_context_lines: HashSet<String> = example
        .expected_excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();
    let actual_context_lines: HashSet<String> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    result.context = precision_recall(&expected_context_lines, &actual_context_lines);

    // Only additions and deletions count towards the edit prediction scores;
    // context and header lines of the diff are ignored.
    let expected_patch_lines: HashSet<String> = example
        .expected_patch
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    let actual_patch_lines: HashSet<String> = preds
        .diff
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);

    result.edit_sites_coverage =
        calculate_edit_sites_coverage(&example.expected_patch, &preds.excerpts);

    result
}

/// Compute the ratio of lines edited by the expected patch that were included
/// in the retrieved context:
/// `num_correctly_retrieved_lines / num_edited_lines_in_expected_patch`
///
/// In order to edit a line, the model must have access to it: if a line is
/// not included in the retrieved context, there is no chance of editing it.
///
/// This metric reflects that: 1.0 means we retrieved every line to be edited,
/// and 0.0 means we retrieved none of them.
///
/// Example: if the expected patch deletes `banana` and `date`, but only
/// `banana` appears in the retrieved excerpts, the coverage is 0.5.
fn calculate_edit_sites_coverage(patch: &str, excerpts: &[Excerpt]) -> f64 {
    // Only deletions mark edit sites: they are the lines that already exist
    // in the file and must be present in the context to be changed.
    let expected_patch_lines = patch
        .lines()
        .map(DiffLine::parse)
        .filter_map(|line| match line {
            DiffLine::Deletion(text) => Some(text.trim().to_string()),
            _ => None,
        })
        .collect::<Vec<_>>();

    let correct_cases = expected_patch_lines
        .iter()
        .filter(|line| {
            excerpts.iter().any(|excerpt| {
                excerpt
                    .text
                    .lines()
                    .any(|excerpt_line| excerpt_line.trim() == *line)
            })
        })
        .count();
    let total_cases = expected_patch_lines.len();

    if total_cases == 0 {
        0.0
    } else {
        correct_cases as f64 / total_cases as f64
    }
}

fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
    Scores::from_counts(
        expected.intersection(actual).count(),
        actual.difference(expected).count(),
        expected.difference(actual).count(),
    )
}

/// Compare actual and expected context.
///
/// Returns the expected context annotated with these markers:
///
/// `✓ context line` -- line was correctly predicted
/// `✗ context line` -- line is missing from the predictions
pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
    let use_color = std::io::stdout().is_terminal();
    let green = if use_color { "\x1b[32m" } else { "" };
    let red = if use_color { "\x1b[31m" } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let expected: Vec<_> = example
        .expected_excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| (excerpt.path.clone(), line))
        })
        .collect();
    let actual: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| (excerpt.path.clone(), line))
        })
        .collect();

    let annotated = expected
        .iter()
        .map(|(path, line)| {
            if actual.contains(&(path.to_path_buf(), *line)) {
                format!("{green}✓ {line}{reset}")
            } else {
                format!("{red}✗ {line}{reset}")
            }
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}

/// Returns `patch_a` annotated so that additions and deletions that are also
/// present in `patch_b` are highlighted in green, while those missing from
/// `patch_b` are highlighted in red.
pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
    let use_color = std::io::stdout().is_terminal();
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}

#[cfg(test)]
mod tests {
    use super::calculate_edit_sites_coverage;
    use crate::example::Excerpt;

    #[test]
    fn test_calculate_edit_sites_coverage() {
        let patch = indoc::indoc! {"
            --- a/test.txt
            +++ b/test.txt
            @@ -1,4 +1,4 @@
             apple
            -banana
            +BANANA
             cherry
            -date
            +DATE
        "};

        let one_correct_excerpt = vec![Excerpt {
            path: "test.txt".into(),
            text: "apple\nbanana\n".to_string(),
        }];

        assert_eq!(
            calculate_edit_sites_coverage(patch, &one_correct_excerpt),
            0.5,
        );

        let both_correct_excerpts = vec![
            Excerpt {
                path: "test.txt".into(),
                text: "apple\nbanana\n".to_string(),
            },
            Excerpt {
                path: "test.txt".into(),
                text: "cherry\ndate\n".to_string(),
            },
        ];

        assert_eq!(
            calculate_edit_sites_coverage(patch, &both_correct_excerpts),
            1.0,
        );

        let incorrect_excerpts = vec![Excerpt {
            path: "test.txt".into(),
            text: "apple\n".into(),
        }];
        assert_eq!(
            calculate_edit_sites_coverage(patch, &incorrect_excerpts),
            0.0,
        );
    }
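
    // The tests below are additional sketches of the scoring edge cases; the
    // inputs are illustrative values made up for this file, not fixtures from
    // the evaluation corpus.

    #[test]
    fn test_precision_recall_worked_example() {
        use collections::HashSet;

        // expected = {a, b, c}, actual = {b, c, d}: 2 true positives (b, c),
        // 1 false positive (d), and 1 false negative (a), so precision,
        // recall, and F1 all come out to 2/3.
        let expected: HashSet<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let actual: HashSet<String> = ["b", "c", "d"].iter().map(|s| s.to_string()).collect();

        let scores = super::precision_recall(&expected, &actual);
        assert_eq!(scores.true_positives, 2);
        assert_eq!(scores.false_positives, 1);
        assert_eq!(scores.false_negatives, 1);
        assert!((scores.precision - 2.0 / 3.0).abs() < 1e-9);
        assert!((scores.recall - 2.0 / 3.0).abs() < 1e-9);
        assert!((scores.f1_score - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_aggregate_handles_zero_counts() {
        // Aggregating scores with no positives at all must yield zeros rather
        // than NaN, which the guarded divisions in `from_counts` ensure.
        let empty = [super::Scores::default(), super::Scores::default()];
        let aggregated = super::Scores::aggregate(empty.iter());
        assert_eq!(aggregated.precision, 0.0);
        assert_eq!(aggregated.recall, 0.0);
        assert_eq!(aggregated.f1_score, 0.0);
    }

    #[test]
    fn test_coverage_of_empty_patch_is_zero() {
        // A patch with no deletions has no edit sites; coverage is defined
        // to be 0.0 in that case.
        assert_eq!(calculate_edit_sites_coverage("", &[]), 0.0);
    }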
}