1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34 #[arg(long)]
35 skip_prediction: bool,
36}
37
/// Entry point for the `evaluate` subcommand: loads every example in
/// `args.example_paths`, runs the requested number of prediction repetitions
/// for each, prints aggregated scores to stdout, and persists them under the
/// run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One top-level task per example file; each fans out into one task per
    // repetition below.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        // NOTE(review): a malformed example path panics here instead of being
        // reported like the per-run errors collected below — confirm intended.
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            // `setup_project` returns one Zeta instance per repetition.
            // `repetition_ix` is None for a single run so that output paths
            // and headers stay unnumbered.
            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with the example name and repetition so the
                    // aggregated report can attribute them.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Print the aggregate report to stdout and also persist it to the run dir.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
92
93fn write_aggregated_scores(
94 w: &mut impl std::io::Write,
95 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
96) -> Result<()> {
97 let mut successful = Vec::new();
98 let mut failed_count = 0;
99
100 for result in all_results.iter().flatten() {
101 match result {
102 Ok(eval_result) => successful.push(eval_result),
103 Err((err, name, repetition_ix)) => {
104 if failed_count == 0 {
105 writeln!(w, "## Errors\n")?;
106 }
107
108 failed_count += 1;
109 let err = format!("{err:?}")
110 .replace("<edits", "```xml\n<edits")
111 .replace("</edits>", "</edits>\n```");
112 writeln!(
113 w,
114 "### ERROR {name}{}\n\n{err}\n",
115 repetition_ix
116 .map(|ix| format!(" [RUN {ix:03}]"))
117 .unwrap_or_default()
118 )?;
119 }
120 }
121 }
122
123 if successful.len() > 1 {
124 let mut edit_predictions = successful
125 .iter()
126 .filter_map(|r| r.edit_prediction.as_ref())
127 .peekable();
128 let has_edit_predictions = edit_predictions.peek().is_some();
129 let aggregated_result = EvaluationResult {
130 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
131 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
132 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
133 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
134 / successful.len(),
135 };
136
137 writeln!(w, "\n{}", "-".repeat(80))?;
138 writeln!(w, "\n## TOTAL SCORES")?;
139 writeln!(w, "\n### Success Rate")?;
140 writeln!(w, "{:#}", aggregated_result)?;
141 }
142
143 if successful.len() + failed_count > 1 {
144 writeln!(
145 w,
146 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
147 successful.len(),
148 successful.len() + failed_count,
149 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
150 )?;
151 }
152
153 Ok(())
154}
155
156pub async fn run_evaluate_one(
157 example: NamedExample,
158 repetition_ix: Option<u16>,
159 project: Entity<Project>,
160 zeta: Entity<Zeta>,
161 prompt_format: PromptFormat,
162 use_expected_context: bool,
163 predict: bool,
164 cache_mode: CacheMode,
165 cx: &mut AsyncApp,
166) -> Result<EvaluationResult> {
167 let predict_result = zeta2_predict(
168 example.clone(),
169 project,
170 zeta,
171 repetition_ix,
172 prompt_format,
173 use_expected_context,
174 cache_mode,
175 cx,
176 )
177 .await?;
178
179 let evaluation_result = evaluate(&example.example, &predict_result, predict);
180
181 if repetition_ix.is_none() {
182 write_eval_result(
183 &example,
184 &predict_result,
185 &evaluation_result,
186 &mut std::io::stdout(),
187 std::io::stdout().is_terminal(),
188 predict,
189 )?;
190 }
191
192 if let Some(mut results_file) =
193 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
194 {
195 write_eval_result(
196 &example,
197 &predict_result,
198 &evaluation_result,
199 &mut results_file,
200 false,
201 predict,
202 )
203 .log_err();
204 }
205
206 anyhow::Ok(evaluation_result)
207}
208
209fn write_eval_result(
210 example: &NamedExample,
211 predictions: &PredictionDetails,
212 evaluation_result: &EvaluationResult,
213 out: &mut impl Write,
214 use_color: bool,
215 predict: bool,
216) -> Result<()> {
217 if predict {
218 writeln!(
219 out,
220 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
221 compare_diffs(
222 &example.example.expected_patch,
223 &predictions.diff,
224 use_color
225 )
226 )?;
227 writeln!(
228 out,
229 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
230 compare_diffs(
231 &predictions.diff,
232 &example.example.expected_patch,
233 use_color
234 )
235 )?;
236 }
237
238 writeln!(out, "{:#}", evaluation_result)?;
239
240 anyhow::Ok(())
241}
242
243#[derive(Debug, Default)]
244pub struct EvaluationResult {
245 pub edit_prediction: Option<Scores>,
246 pub context: Scores,
247 pub prompt_len: usize,
248 pub generated_len: usize,
249}
250
251#[derive(Default, Debug)]
252pub struct Scores {
253 pub true_positives: usize,
254 pub false_positives: usize,
255 pub false_negatives: usize,
256}
257
258impl Scores {
259 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
260 let true_positives = expected.intersection(actual).count();
261 let false_positives = actual.difference(expected).count();
262 let false_negatives = expected.difference(actual).count();
263
264 Scores {
265 true_positives,
266 false_positives,
267 false_negatives,
268 }
269 }
270
271 pub fn to_markdown(&self) -> String {
272 format!(
273 "
274Precision : {:.4}
275Recall : {:.4}
276F1 Score : {:.4}
277True Positives : {}
278False Positives : {}
279False Negatives : {}",
280 self.precision(),
281 self.recall(),
282 self.f1_score(),
283 self.true_positives,
284 self.false_positives,
285 self.false_negatives
286 )
287 }
288
289 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
290 let mut true_positives = 0;
291 let mut false_positives = 0;
292 let mut false_negatives = 0;
293
294 for score in scores {
295 true_positives += score.true_positives;
296 false_positives += score.false_positives;
297 false_negatives += score.false_negatives;
298 }
299
300 Scores {
301 true_positives,
302 false_positives,
303 false_negatives,
304 }
305 }
306
307 pub fn precision(&self) -> f64 {
308 if self.true_positives + self.false_positives == 0 {
309 0.0
310 } else {
311 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
312 }
313 }
314
315 pub fn recall(&self) -> f64 {
316 if self.true_positives + self.false_negatives == 0 {
317 0.0
318 } else {
319 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
320 }
321 }
322
323 pub fn f1_score(&self) -> f64 {
324 let recall = self.recall();
325 let precision = self.precision();
326 if precision + recall == 0.0 {
327 0.0
328 } else {
329 2.0 * precision * recall / (precision + recall)
330 }
331 }
332}
333
334impl std::fmt::Display for EvaluationResult {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 if f.alternate() {
337 self.fmt_table(f)
338 } else {
339 self.fmt_markdown(f)
340 }
341 }
342}
343
344impl EvaluationResult {
345 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 write!(
347 f,
348 r#"
349### Context Scores
350{}
351"#,
352 self.context.to_markdown(),
353 )?;
354 if let Some(prediction) = &self.edit_prediction {
355 write!(
356 f,
357 r#"
358 ### Edit Prediction Scores
359 {}"#,
360 prediction.to_markdown()
361 )?;
362 }
363 Ok(())
364 }
365
366 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 writeln!(f, "### Scores\n")?;
368 writeln!(
369 f,
370 " Prompt Generated TP FP FN Precision Recall F1"
371 )?;
372 writeln!(
373 f,
374 "────────────────────────────────────────────────────────────────────────────────────"
375 )?;
376 writeln!(
377 f,
378 "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
379 "",
380 "",
381 self.context.true_positives,
382 self.context.false_positives,
383 self.context.false_negatives,
384 self.context.precision() * 100.0,
385 self.context.recall() * 100.0,
386 self.context.f1_score() * 100.0
387 )?;
388 if let Some(edit_prediction) = &self.edit_prediction {
389 writeln!(
390 f,
391 "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
392 self.prompt_len,
393 self.generated_len,
394 edit_prediction.true_positives,
395 edit_prediction.false_positives,
396 edit_prediction.false_negatives,
397 edit_prediction.precision() * 100.0,
398 edit_prediction.recall() * 100.0,
399 edit_prediction.f1_score() * 100.0
400 )?;
401 }
402 Ok(())
403 }
404}
405
406pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
407 let mut eval_result = EvaluationResult {
408 prompt_len: preds.prompt_len,
409 generated_len: preds.generated_len,
410 ..Default::default()
411 };
412
413 let actual_context_lines: HashSet<_> = preds
414 .excerpts
415 .iter()
416 .flat_map(|excerpt| {
417 excerpt
418 .text
419 .lines()
420 .map(|line| format!("{}: {line}", excerpt.path.display()))
421 })
422 .collect();
423
424 let mut false_positive_lines = actual_context_lines.clone();
425
426 for entry in &example.expected_context {
427 let mut best_alternative_score: Option<Scores> = None;
428
429 for alternative in &entry.alternatives {
430 let expected: HashSet<_> = alternative
431 .excerpts
432 .iter()
433 .flat_map(|excerpt| {
434 excerpt
435 .text
436 .lines()
437 .map(|line| format!("{}: {line}", excerpt.path.display()))
438 })
439 .collect();
440
441 let scores = Scores::new(&expected, &actual_context_lines);
442
443 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
444
445 if best_alternative_score
446 .as_ref()
447 .is_none_or(|best| scores.recall() > best.recall())
448 {
449 best_alternative_score = Some(scores);
450 }
451 }
452
453 let best_alternative = best_alternative_score.unwrap_or_default();
454 eval_result.context.false_negatives += best_alternative.false_negatives;
455 eval_result.context.true_positives += best_alternative.true_positives;
456 }
457
458 eval_result.context.false_positives = false_positive_lines.len();
459
460 if predict {
461 // todo: alternatives for patches
462 let expected_patch_lines = example
463 .expected_patch
464 .lines()
465 .map(DiffLine::parse)
466 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
467 .map(|line| line.to_string())
468 .collect();
469
470 let actual_patch_lines = preds
471 .diff
472 .lines()
473 .map(DiffLine::parse)
474 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
475 .map(|line| line.to_string())
476 .collect();
477
478 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
479 }
480
481 eval_result
482}
483
484/// Return annotated `patch_a` so that:
485/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
486/// Additions and deletions that are present in `patch_b` will be highlighted in green.
487pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
488 let green = if use_color { "\x1b[32m✓ " } else { "" };
489 let red = if use_color { "\x1b[31m✗ " } else { "" };
490 let neutral = if use_color { " " } else { "" };
491 let reset = if use_color { "\x1b[0m" } else { "" };
492 let lines_a = patch_a.lines().map(DiffLine::parse);
493 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
494
495 let annotated = lines_a
496 .map(|line| match line {
497 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
498 if lines_b.contains(&line) {
499 format!("{green}{line}{reset}")
500 } else {
501 format!("{red}{line}{reset}")
502 }
503 }
504 _ => format!("{neutral}{line}{reset}"),
505 })
506 .collect::<Vec<String>>();
507
508 annotated.join("\n")
509}