1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34}
35
36pub async fn run_evaluate(
37 args: EvaluateArguments,
38 app_state: &Arc<ZetaCliAppState>,
39 cx: &mut AsyncApp,
40) {
41 if args.example_paths.is_empty() {
42 eprintln!("No examples provided");
43 return;
44 }
45 let all_tasks = args.example_paths.into_iter().map(|path| {
46 let app_state = app_state.clone();
47 let example = NamedExample::load(&path).unwrap();
48
49 cx.spawn(async move |cx| {
50 let (project, zetas, _edited_buffers) = example
51 .setup_project(&app_state, args.repetitions, cx)
52 .await
53 .unwrap();
54
55 let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
56 let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
57 let example = example.clone();
58 let project = project.clone();
59
60 cx.spawn(async move |cx| {
61 let name = example.name.clone();
62 run_evaluate_one(
63 example,
64 repetition_ix,
65 project,
66 zeta,
67 args.prompt_format,
68 args.use_expected_context,
69 args.cache,
70 cx,
71 )
72 .await
73 .map_err(|err| (err, name, repetition_ix))
74 })
75 });
76 futures::future::join_all(tasks).await
77 })
78 });
79 let all_results = futures::future::join_all(all_tasks).await;
80
81 write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
82 if let Some(mut output_file) =
83 std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
84 {
85 write_aggregated_scores(&mut output_file, &all_results).log_err();
86 };
87 print_run_data_dir(args.repetitions == 1);
88}
89
90fn write_aggregated_scores(
91 w: &mut impl std::io::Write,
92 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
93) -> Result<()> {
94 let mut successful = Vec::new();
95 let mut failed_count = 0;
96
97 for result in all_results.iter().flatten() {
98 match result {
99 Ok(eval_result) => successful.push(eval_result),
100 Err((err, name, repetition_ix)) => {
101 if failed_count == 0 {
102 writeln!(w, "## Errors\n")?;
103 }
104
105 failed_count += 1;
106 let err = err
107 .to_string()
108 .replace("<edits", "```xml\n<edits")
109 .replace("</edits>", "</edits>\n```");
110 writeln!(
111 w,
112 "### ERROR {name}{}\n\n{err}\n",
113 repetition_ix
114 .map(|ix| format!(" [RUN {ix:03}]"))
115 .unwrap_or_default()
116 )?;
117 }
118 }
119 }
120
121 if successful.len() > 1 {
122 let aggregated_result = EvaluationResult {
123 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
124 edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
125 };
126
127 writeln!(w, "\n{}", "-".repeat(80))?;
128 writeln!(w, "\n## TOTAL SCORES")?;
129 writeln!(w, "\n### Success Rate")?;
130 writeln!(w, "{}", aggregated_result)?;
131 }
132
133 if successful.len() + failed_count > 1 {
134 writeln!(
135 w,
136 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
137 successful.len(),
138 successful.len() + failed_count,
139 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
140 )?;
141 }
142
143 Ok(())
144}
145
146pub async fn run_evaluate_one(
147 example: NamedExample,
148 repetition_ix: Option<u16>,
149 project: Entity<Project>,
150 zeta: Entity<Zeta>,
151 prompt_format: PromptFormat,
152 use_expected_context: bool,
153 cache_mode: CacheMode,
154 cx: &mut AsyncApp,
155) -> Result<EvaluationResult> {
156 let predict_result = zeta2_predict(
157 example.clone(),
158 project,
159 zeta,
160 repetition_ix,
161 prompt_format,
162 use_expected_context,
163 cache_mode,
164 cx,
165 )
166 .await?;
167
168 let evaluation_result = evaluate(&example.example, &predict_result);
169
170 if repetition_ix.is_none() {
171 write_eval_result(
172 &example,
173 &predict_result,
174 &evaluation_result,
175 &mut std::io::stdout(),
176 )?;
177 }
178
179 if let Some(mut results_file) =
180 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
181 {
182 write_eval_result(
183 &example,
184 &predict_result,
185 &evaluation_result,
186 &mut results_file,
187 )
188 .log_err();
189 }
190
191 anyhow::Ok(evaluation_result)
192}
193
194fn write_eval_result(
195 example: &NamedExample,
196 predictions: &PredictionDetails,
197 evaluation_result: &EvaluationResult,
198 out: &mut impl Write,
199) -> Result<()> {
200 writeln!(
201 out,
202 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
203 compare_diffs(&example.example.expected_patch, &predictions.diff)
204 )?;
205 writeln!(
206 out,
207 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
208 compare_diffs(&predictions.diff, &example.example.expected_patch)
209 )?;
210 writeln!(out, "{:#}", evaluation_result)?;
211
212 anyhow::Ok(())
213}
214
215#[derive(Debug, Default)]
216pub struct EvaluationResult {
217 pub edit_prediction: Scores,
218 pub context: Scores,
219}
220
221#[derive(Default, Debug)]
222pub struct Scores {
223 pub true_positives: usize,
224 pub false_positives: usize,
225 pub false_negatives: usize,
226}
227
228impl Scores {
229 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
230 let true_positives = expected.intersection(actual).count();
231 let false_positives = actual.difference(expected).count();
232 let false_negatives = expected.difference(actual).count();
233
234 Scores {
235 true_positives,
236 false_positives,
237 false_negatives,
238 }
239 }
240
241 pub fn to_markdown(&self) -> String {
242 format!(
243 "
244Precision : {:.4}
245Recall : {:.4}
246F1 Score : {:.4}
247True Positives : {}
248False Positives : {}
249False Negatives : {}",
250 self.precision(),
251 self.recall(),
252 self.f1_score(),
253 self.true_positives,
254 self.false_positives,
255 self.false_negatives
256 )
257 }
258
259 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
260 let mut true_positives = 0;
261 let mut false_positives = 0;
262 let mut false_negatives = 0;
263
264 for score in scores {
265 true_positives += score.true_positives;
266 false_positives += score.false_positives;
267 false_negatives += score.false_negatives;
268 }
269
270 Scores {
271 true_positives,
272 false_positives,
273 false_negatives,
274 }
275 }
276
277 pub fn precision(&self) -> f64 {
278 if self.true_positives + self.false_positives == 0 {
279 0.0
280 } else {
281 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
282 }
283 }
284
285 pub fn recall(&self) -> f64 {
286 if self.true_positives + self.false_negatives == 0 {
287 0.0
288 } else {
289 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
290 }
291 }
292
293 pub fn f1_score(&self) -> f64 {
294 let recall = self.recall();
295 let precision = self.precision();
296 if precision + recall == 0.0 {
297 0.0
298 } else {
299 2.0 * precision * recall / (precision + recall)
300 }
301 }
302}
303
304impl std::fmt::Display for EvaluationResult {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 if f.alternate() {
307 self.fmt_table(f)
308 } else {
309 self.fmt_markdown(f)
310 }
311 }
312}
313
314impl EvaluationResult {
315 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 write!(
317 f,
318 r#"
319### Context Scores
320{}
321
322### Edit Prediction Scores
323{}
324"#,
325 self.context.to_markdown(),
326 self.edit_prediction.to_markdown()
327 )
328 }
329
330 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 writeln!(f, "### Scores\n")?;
332 writeln!(
333 f,
334 " TP FP FN Precision Recall F1"
335 )?;
336 writeln!(
337 f,
338 "──────────────────────────────────────────────────────────────────"
339 )?;
340 writeln!(
341 f,
342 "Context Retrieval {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
343 self.context.true_positives,
344 self.context.false_positives,
345 self.context.false_negatives,
346 self.context.precision() * 100.0,
347 self.context.recall() * 100.0,
348 self.context.f1_score() * 100.0
349 )?;
350 writeln!(
351 f,
352 "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
353 self.edit_prediction.true_positives,
354 self.edit_prediction.false_positives,
355 self.edit_prediction.false_negatives,
356 self.edit_prediction.precision() * 100.0,
357 self.edit_prediction.recall() * 100.0,
358 self.edit_prediction.f1_score() * 100.0
359 )
360 }
361}
362
363pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
364 let mut eval_result = EvaluationResult::default();
365
366 let actual_context_lines: HashSet<_> = preds
367 .excerpts
368 .iter()
369 .flat_map(|excerpt| {
370 excerpt
371 .text
372 .lines()
373 .map(|line| format!("{}: {line}", excerpt.path.display()))
374 })
375 .collect();
376
377 let mut false_positive_lines = actual_context_lines.clone();
378
379 for entry in &example.expected_context {
380 let mut best_alternative_score: Option<Scores> = None;
381
382 for alternative in &entry.alternatives {
383 let expected: HashSet<_> = alternative
384 .excerpts
385 .iter()
386 .flat_map(|excerpt| {
387 excerpt
388 .text
389 .lines()
390 .map(|line| format!("{}: {line}", excerpt.path.display()))
391 })
392 .collect();
393
394 let scores = Scores::new(&expected, &actual_context_lines);
395
396 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
397
398 if best_alternative_score
399 .as_ref()
400 .is_none_or(|best| scores.recall() > best.recall())
401 {
402 best_alternative_score = Some(scores);
403 }
404 }
405
406 let best_alternative = best_alternative_score.unwrap_or_default();
407 eval_result.context.false_negatives += best_alternative.false_negatives;
408 eval_result.context.true_positives += best_alternative.true_positives;
409 }
410
411 eval_result.context.false_positives = false_positive_lines.len();
412
413 // todo: alternatives for patches
414 let expected_patch_lines = example
415 .expected_patch
416 .lines()
417 .map(DiffLine::parse)
418 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
419 .map(|line| line.to_string())
420 .collect();
421
422 let actual_patch_lines = preds
423 .diff
424 .lines()
425 .map(DiffLine::parse)
426 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
427 .map(|line| line.to_string())
428 .collect();
429
430 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
431 eval_result
432}
433
434/// Return annotated `patch_a` so that:
435/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
436/// Additions and deletions that are present in `patch_b` will be highlighted in green.
437pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
438 let use_color = std::io::stdout().is_terminal();
439 let green = if use_color { "\x1b[32m✓ " } else { "" };
440 let red = if use_color { "\x1b[31m✗ " } else { "" };
441 let neutral = if use_color { " " } else { "" };
442 let reset = if use_color { "\x1b[0m" } else { "" };
443 let lines_a = patch_a.lines().map(DiffLine::parse);
444 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
445
446 let annotated = lines_a
447 .map(|line| match line {
448 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
449 if lines_b.contains(&line) {
450 format!("{green}{line}{reset}")
451 } else {
452 format!("{red}{line}{reset}")
453 }
454 }
455 _ => format!("{neutral}{line}{reset}"),
456 })
457 .collect::<Vec<String>>();
458
459 annotated.join("\n")
460}