1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34}
35
/// Evaluates every example file passed on the command line.
///
/// For each example, one task sets up the project, then one sub-task per
/// repetition runs prediction + scoring. Aggregated results are written to
/// stdout and, best-effort, to `aggregated_results.md` in the run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        // NOTE(review): panics if the example file is missing or malformed —
        // consider reporting the path and skipping instead of unwrapping.
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            // One project per example, shared by all repetitions.
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Only label runs with a repetition index when repeating; a
                // single run is reported without the "[RUN nnn]" suffix.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);

                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with the example name and repetition so the
                    // aggregated error report can attribute them.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Report to stdout first; the on-disk copy is best-effort (errors logged).
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1);
}
90
91fn write_aggregated_scores(
92 w: &mut impl std::io::Write,
93 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
94) -> Result<()> {
95 let mut successful = Vec::new();
96 let mut failed_count = 0;
97 writeln!(w, "## Errors\n")?;
98 for result in all_results.iter().flatten() {
99 match result {
100 Ok(eval_result) => successful.push(eval_result),
101 Err((err, name, repetition_ix)) => {
102 failed_count += 1;
103 let err = err
104 .to_string()
105 .replace("<edits", "```xml\n<edits")
106 .replace("</edits>", "</edits>\n```");
107 writeln!(
108 w,
109 "### ERROR {name}{}\n\n{err}\n",
110 repetition_ix
111 .map(|ix| format!(" [RUN {ix:03}]"))
112 .unwrap_or_default()
113 )?;
114 }
115 }
116 }
117 let aggregated_result = EvaluationResult {
118 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
119 edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
120 };
121
122 writeln!(w, "\n{}", "-".repeat(80))?;
123 writeln!(w, "\n## TOTAL SCORES")?;
124 writeln!(w, "\n### Success Rate")?;
125 writeln!(
126 w,
127 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
128 successful.len(),
129 successful.len() + failed_count,
130 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
131 )?;
132 writeln!(w, "{}", aggregated_result)?;
133
134 Ok(())
135}
136
137pub async fn run_evaluate_one(
138 example: NamedExample,
139 repetition_ix: Option<u16>,
140 project: Entity<Project>,
141 zeta: Entity<Zeta>,
142 prompt_format: PromptFormat,
143 use_expected_context: bool,
144 cache_mode: CacheMode,
145 cx: &mut AsyncApp,
146) -> Result<EvaluationResult> {
147 let predict_result = zeta2_predict(
148 example.clone(),
149 project,
150 zeta,
151 repetition_ix,
152 prompt_format,
153 use_expected_context,
154 cache_mode,
155 cx,
156 )
157 .await?;
158
159 let evaluation_result = evaluate(&example.example, &predict_result);
160
161 if repetition_ix.is_none() {
162 write_eval_result(
163 &example,
164 &predict_result,
165 &evaluation_result,
166 &mut std::io::stdout(),
167 )?;
168 }
169
170 if let Some(mut results_file) =
171 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
172 {
173 write_eval_result(
174 &example,
175 &predict_result,
176 &evaluation_result,
177 &mut results_file,
178 )
179 .log_err();
180 }
181
182 anyhow::Ok(evaluation_result)
183}
184
185fn write_eval_result(
186 example: &NamedExample,
187 predictions: &PredictionDetails,
188 evaluation_result: &EvaluationResult,
189 out: &mut impl Write,
190) -> Result<()> {
191 writeln!(
192 out,
193 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
194 compare_diffs(&example.example.expected_patch, &predictions.diff)
195 )?;
196 writeln!(
197 out,
198 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
199 compare_diffs(&predictions.diff, &example.example.expected_patch)
200 )?;
201 writeln!(out, "{}", evaluation_result)?;
202
203 anyhow::Ok(())
204}
205
206#[derive(Debug, Default)]
207pub struct EvaluationResult {
208 pub edit_prediction: Scores,
209 pub context: Scores,
210}
211
212#[derive(Default, Debug)]
213pub struct Scores {
214 pub true_positives: usize,
215 pub false_positives: usize,
216 pub false_negatives: usize,
217}
218
219impl Scores {
220 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
221 let true_positives = expected.intersection(actual).count();
222 let false_positives = actual.difference(expected).count();
223 let false_negatives = expected.difference(actual).count();
224
225 Scores {
226 true_positives,
227 false_positives,
228 false_negatives,
229 }
230 }
231
232 pub fn to_markdown(&self) -> String {
233 format!(
234 "
235Precision : {:.4}
236Recall : {:.4}
237F1 Score : {:.4}
238True Positives : {}
239False Positives : {}
240False Negatives : {}",
241 self.precision(),
242 self.recall(),
243 self.f1_score(),
244 self.true_positives,
245 self.false_positives,
246 self.false_negatives
247 )
248 }
249
250 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
251 let mut true_positives = 0;
252 let mut false_positives = 0;
253 let mut false_negatives = 0;
254
255 for score in scores {
256 true_positives += score.true_positives;
257 false_positives += score.false_positives;
258 false_negatives += score.false_negatives;
259 }
260
261 Scores {
262 true_positives,
263 false_positives,
264 false_negatives,
265 }
266 }
267
268 pub fn precision(&self) -> f64 {
269 if self.true_positives + self.false_positives == 0 {
270 0.0
271 } else {
272 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
273 }
274 }
275
276 pub fn recall(&self) -> f64 {
277 if self.true_positives + self.false_negatives == 0 {
278 0.0
279 } else {
280 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
281 }
282 }
283
284 pub fn f1_score(&self) -> f64 {
285 let recall = self.recall();
286 let precision = self.precision();
287 if precision + recall == 0.0 {
288 0.0
289 } else {
290 2.0 * precision * recall / (precision + recall)
291 }
292 }
293}
294
295impl std::fmt::Display for EvaluationResult {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 write!(
298 f,
299 r#"
300### Context Scores
301{}
302
303### Edit Prediction Scores
304{}
305"#,
306 self.context.to_markdown(),
307 self.edit_prediction.to_markdown()
308 )
309 }
310}
311
312pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
313 let mut eval_result = EvaluationResult::default();
314
315 let actual_context_lines: HashSet<_> = preds
316 .excerpts
317 .iter()
318 .flat_map(|excerpt| {
319 excerpt
320 .text
321 .lines()
322 .map(|line| format!("{}: {line}", excerpt.path.display()))
323 })
324 .collect();
325
326 let mut false_positive_lines = actual_context_lines.clone();
327
328 for entry in &example.expected_context {
329 let mut best_alternative_score = Scores::default();
330
331 for alternative in &entry.alternatives {
332 let expected: HashSet<_> = alternative
333 .excerpts
334 .iter()
335 .flat_map(|excerpt| {
336 excerpt
337 .text
338 .lines()
339 .map(|line| format!("{}: {line}", excerpt.path.display()))
340 })
341 .collect();
342
343 let scores = Scores::new(&expected, &actual_context_lines);
344
345 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
346
347 if scores.recall() > best_alternative_score.recall() {
348 best_alternative_score = scores;
349 }
350 }
351
352 eval_result.context.false_negatives += best_alternative_score.false_negatives;
353 eval_result.context.true_positives += best_alternative_score.true_positives;
354 }
355
356 eval_result.context.false_positives = false_positive_lines.len();
357
358 // todo: alternatives for patches
359 let expected_patch_lines = example
360 .expected_patch
361 .lines()
362 .map(DiffLine::parse)
363 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
364 .map(|line| line.to_string())
365 .collect();
366
367 let actual_patch_lines = preds
368 .diff
369 .lines()
370 .map(DiffLine::parse)
371 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
372 .map(|line| line.to_string())
373 .collect();
374
375 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
376 eval_result
377}
378
379/// Return annotated `patch_a` so that:
380/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
381/// Additions and deletions that are present in `patch_b` will be highlighted in green.
382pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
383 let use_color = std::io::stdout().is_terminal();
384 let green = if use_color { "\x1b[32m✓ " } else { "" };
385 let red = if use_color { "\x1b[31m✗ " } else { "" };
386 let neutral = if use_color { " " } else { "" };
387 let reset = if use_color { "\x1b[0m" } else { "" };
388 let lines_a = patch_a.lines().map(DiffLine::parse);
389 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
390
391 let annotated = lines_a
392 .map(|line| match line {
393 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
394 if lines_b.contains(&line) {
395 format!("{green}{line}{reset}")
396 } else {
397 format!("{red}{line}{reset}")
398 }
399 }
400 _ => format!("{neutral}{line}{reset}"),
401 })
402 .collect::<Vec<String>>();
403
404 annotated.join("\n")
405}