1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34 #[arg(long)]
35 skip_prediction: bool,
36}
37
38pub async fn run_evaluate(
39 args: EvaluateArguments,
40 app_state: &Arc<ZetaCliAppState>,
41 cx: &mut AsyncApp,
42) {
43 if args.example_paths.is_empty() {
44 eprintln!("No examples provided");
45 return;
46 }
47 let all_tasks = args.example_paths.into_iter().map(|path| {
48 let app_state = app_state.clone();
49 let example = NamedExample::load(&path).unwrap();
50
51 cx.spawn(async move |cx| {
52 let (project, zetas, _edited_buffers) = example
53 .setup_project(&app_state, args.repetitions, cx)
54 .await
55 .unwrap();
56
57 let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
58 let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
59 let example = example.clone();
60 let project = project.clone();
61
62 cx.spawn(async move |cx| {
63 let name = example.name.clone();
64 run_evaluate_one(
65 example,
66 repetition_ix,
67 project,
68 zeta,
69 args.prompt_format,
70 args.use_expected_context,
71 !args.skip_prediction,
72 args.cache,
73 cx,
74 )
75 .await
76 .map_err(|err| (err, name, repetition_ix))
77 })
78 });
79 futures::future::join_all(tasks).await
80 })
81 });
82 let all_results = futures::future::join_all(all_tasks).await;
83
84 write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
85 if let Some(mut output_file) =
86 std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
87 {
88 write_aggregated_scores(&mut output_file, &all_results).log_err();
89 };
90 print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
91}
92
93fn write_aggregated_scores(
94 w: &mut impl std::io::Write,
95 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
96) -> Result<()> {
97 let mut successful = Vec::new();
98 let mut failed_count = 0;
99
100 for result in all_results.iter().flatten() {
101 match result {
102 Ok(eval_result) => successful.push(eval_result),
103 Err((err, name, repetition_ix)) => {
104 if failed_count == 0 {
105 writeln!(w, "## Errors\n")?;
106 }
107
108 failed_count += 1;
109 let err = format!("{err:?}")
110 .replace("<edits", "```xml\n<edits")
111 .replace("</edits>", "</edits>\n```");
112 writeln!(
113 w,
114 "### ERROR {name}{}\n\n{err}\n",
115 repetition_ix
116 .map(|ix| format!(" [RUN {ix:03}]"))
117 .unwrap_or_default()
118 )?;
119 }
120 }
121 }
122
123 if successful.len() > 1 {
124 let mut edit_predictions = successful
125 .iter()
126 .filter_map(|r| r.edit_prediction.as_ref())
127 .peekable();
128 let has_edit_predictions = edit_predictions.peek().is_some();
129 let aggregated_result = EvaluationResult {
130 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
131 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
132 };
133
134 writeln!(w, "\n{}", "-".repeat(80))?;
135 writeln!(w, "\n## TOTAL SCORES")?;
136 writeln!(w, "\n### Success Rate")?;
137 writeln!(w, "{}", aggregated_result)?;
138 }
139
140 if successful.len() + failed_count > 1 {
141 writeln!(
142 w,
143 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
144 successful.len(),
145 successful.len() + failed_count,
146 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
147 )?;
148 }
149
150 Ok(())
151}
152
153pub async fn run_evaluate_one(
154 example: NamedExample,
155 repetition_ix: Option<u16>,
156 project: Entity<Project>,
157 zeta: Entity<Zeta>,
158 prompt_format: PromptFormat,
159 use_expected_context: bool,
160 predict: bool,
161 cache_mode: CacheMode,
162 cx: &mut AsyncApp,
163) -> Result<EvaluationResult> {
164 let predict_result = zeta2_predict(
165 example.clone(),
166 project,
167 zeta,
168 repetition_ix,
169 prompt_format,
170 use_expected_context,
171 cache_mode,
172 cx,
173 )
174 .await?;
175
176 let evaluation_result = evaluate(&example.example, &predict_result, predict);
177
178 if repetition_ix.is_none() {
179 write_eval_result(
180 &example,
181 &predict_result,
182 &evaluation_result,
183 &mut std::io::stdout(),
184 std::io::stdout().is_terminal(),
185 predict,
186 )?;
187 }
188
189 if let Some(mut results_file) =
190 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
191 {
192 write_eval_result(
193 &example,
194 &predict_result,
195 &evaluation_result,
196 &mut results_file,
197 false,
198 predict,
199 )
200 .log_err();
201 }
202
203 anyhow::Ok(evaluation_result)
204}
205
206fn write_eval_result(
207 example: &NamedExample,
208 predictions: &PredictionDetails,
209 evaluation_result: &EvaluationResult,
210 out: &mut impl Write,
211 use_color: bool,
212 predict: bool,
213) -> Result<()> {
214 if predict {
215 writeln!(
216 out,
217 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
218 compare_diffs(
219 &example.example.expected_patch,
220 &predictions.diff,
221 use_color
222 )
223 )?;
224 writeln!(
225 out,
226 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
227 compare_diffs(
228 &predictions.diff,
229 &example.example.expected_patch,
230 use_color
231 )
232 )?;
233 }
234
235 writeln!(out, "{:#}", evaluation_result)?;
236
237 anyhow::Ok(())
238}
239
240#[derive(Debug, Default)]
241pub struct EvaluationResult {
242 pub edit_prediction: Option<Scores>,
243 pub context: Scores,
244}
245
246#[derive(Default, Debug)]
247pub struct Scores {
248 pub true_positives: usize,
249 pub false_positives: usize,
250 pub false_negatives: usize,
251}
252
253impl Scores {
254 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
255 let true_positives = expected.intersection(actual).count();
256 let false_positives = actual.difference(expected).count();
257 let false_negatives = expected.difference(actual).count();
258
259 Scores {
260 true_positives,
261 false_positives,
262 false_negatives,
263 }
264 }
265
266 pub fn to_markdown(&self) -> String {
267 format!(
268 "
269Precision : {:.4}
270Recall : {:.4}
271F1 Score : {:.4}
272True Positives : {}
273False Positives : {}
274False Negatives : {}",
275 self.precision(),
276 self.recall(),
277 self.f1_score(),
278 self.true_positives,
279 self.false_positives,
280 self.false_negatives
281 )
282 }
283
284 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
285 let mut true_positives = 0;
286 let mut false_positives = 0;
287 let mut false_negatives = 0;
288
289 for score in scores {
290 true_positives += score.true_positives;
291 false_positives += score.false_positives;
292 false_negatives += score.false_negatives;
293 }
294
295 Scores {
296 true_positives,
297 false_positives,
298 false_negatives,
299 }
300 }
301
302 pub fn precision(&self) -> f64 {
303 if self.true_positives + self.false_positives == 0 {
304 0.0
305 } else {
306 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
307 }
308 }
309
310 pub fn recall(&self) -> f64 {
311 if self.true_positives + self.false_negatives == 0 {
312 0.0
313 } else {
314 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
315 }
316 }
317
318 pub fn f1_score(&self) -> f64 {
319 let recall = self.recall();
320 let precision = self.precision();
321 if precision + recall == 0.0 {
322 0.0
323 } else {
324 2.0 * precision * recall / (precision + recall)
325 }
326 }
327}
328
329impl std::fmt::Display for EvaluationResult {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 if f.alternate() {
332 self.fmt_table(f)
333 } else {
334 self.fmt_markdown(f)
335 }
336 }
337}
338
339impl EvaluationResult {
340 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 write!(
342 f,
343 r#"
344### Context Scores
345{}
346"#,
347 self.context.to_markdown(),
348 )?;
349 if let Some(prediction) = &self.edit_prediction {
350 write!(
351 f,
352 r#"
353 ### Edit Prediction Scores
354 {}"#,
355 prediction.to_markdown()
356 )?;
357 }
358 Ok(())
359 }
360
361 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 writeln!(f, "### Scores\n")?;
363 writeln!(
364 f,
365 " TP FP FN Precision Recall F1"
366 )?;
367 writeln!(
368 f,
369 "──────────────────────────────────────────────────────────────────"
370 )?;
371 writeln!(
372 f,
373 "Context Retrieval {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
374 self.context.true_positives,
375 self.context.false_positives,
376 self.context.false_negatives,
377 self.context.precision() * 100.0,
378 self.context.recall() * 100.0,
379 self.context.f1_score() * 100.0
380 )?;
381 if let Some(edit_prediction) = &self.edit_prediction {
382 writeln!(
383 f,
384 "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
385 edit_prediction.true_positives,
386 edit_prediction.false_positives,
387 edit_prediction.false_negatives,
388 edit_prediction.precision() * 100.0,
389 edit_prediction.recall() * 100.0,
390 edit_prediction.f1_score() * 100.0
391 )?;
392 }
393 Ok(())
394 }
395}
396
397pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
398 let mut eval_result = EvaluationResult::default();
399
400 let actual_context_lines: HashSet<_> = preds
401 .excerpts
402 .iter()
403 .flat_map(|excerpt| {
404 excerpt
405 .text
406 .lines()
407 .map(|line| format!("{}: {line}", excerpt.path.display()))
408 })
409 .collect();
410
411 let mut false_positive_lines = actual_context_lines.clone();
412
413 for entry in &example.expected_context {
414 let mut best_alternative_score: Option<Scores> = None;
415
416 for alternative in &entry.alternatives {
417 let expected: HashSet<_> = alternative
418 .excerpts
419 .iter()
420 .flat_map(|excerpt| {
421 excerpt
422 .text
423 .lines()
424 .map(|line| format!("{}: {line}", excerpt.path.display()))
425 })
426 .collect();
427
428 let scores = Scores::new(&expected, &actual_context_lines);
429
430 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
431
432 if best_alternative_score
433 .as_ref()
434 .is_none_or(|best| scores.recall() > best.recall())
435 {
436 best_alternative_score = Some(scores);
437 }
438 }
439
440 let best_alternative = best_alternative_score.unwrap_or_default();
441 eval_result.context.false_negatives += best_alternative.false_negatives;
442 eval_result.context.true_positives += best_alternative.true_positives;
443 }
444
445 eval_result.context.false_positives = false_positive_lines.len();
446
447 if predict {
448 // todo: alternatives for patches
449 let expected_patch_lines = example
450 .expected_patch
451 .lines()
452 .map(DiffLine::parse)
453 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
454 .map(|line| line.to_string())
455 .collect();
456
457 let actual_patch_lines = preds
458 .diff
459 .lines()
460 .map(DiffLine::parse)
461 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
462 .map(|line| line.to_string())
463 .collect();
464
465 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
466 }
467
468 eval_result
469}
470
471/// Return annotated `patch_a` so that:
472/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
473/// Additions and deletions that are present in `patch_b` will be highlighted in green.
474pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
475 let green = if use_color { "\x1b[32m✓ " } else { "" };
476 let red = if use_color { "\x1b[31m✗ " } else { "" };
477 let neutral = if use_color { " " } else { "" };
478 let reset = if use_color { "\x1b[0m" } else { "" };
479 let lines_a = patch_a.lines().map(DiffLine::parse);
480 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
481
482 let annotated = lines_a
483 .map(|line| match line {
484 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
485 if lines_b.contains(&line) {
486 format!("{green}{line}{reset}")
487 } else {
488 format!("{red}{line}{reset}")
489 }
490 }
491 _ => format!("{neutral}{line}{reset}"),
492 })
493 .collect::<Vec<String>>();
494
495 annotated.join("\n")
496}