evaluate.rs

  1use std::{
  2    io::{IsTerminal, Write},
  3    path::PathBuf,
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use clap::Args;
  9use collections::HashSet;
 10use gpui::{AsyncApp, Entity};
 11use project::Project;
 12use util::ResultExt as _;
 13use zeta2::{Zeta, udiff::DiffLine};
 14
 15use crate::{
 16    PromptFormat,
 17    example::{Example, NamedExample},
 18    headless::ZetaCliAppState,
 19    paths::print_run_data_dir,
 20    predict::{CacheMode, PredictionDetails, zeta2_predict},
 21};
 22
 23#[derive(Debug, Args)]
 24pub struct EvaluateArguments {
 25    example_paths: Vec<PathBuf>,
 26    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 27    prompt_format: PromptFormat,
 28    #[arg(long)]
 29    use_expected_context: bool,
 30    #[clap(long, value_enum, default_value_t = CacheMode::default())]
 31    cache: CacheMode,
 32    #[clap(short, long, default_value_t = 1, alias = "repeat")]
 33    repetitions: u16,
 34}
 35
 36pub async fn run_evaluate(
 37    args: EvaluateArguments,
 38    app_state: &Arc<ZetaCliAppState>,
 39    cx: &mut AsyncApp,
 40) {
 41    if args.example_paths.is_empty() {
 42        eprintln!("No examples provided");
 43        return;
 44    }
 45    let all_tasks = args.example_paths.into_iter().map(|path| {
 46        let app_state = app_state.clone();
 47        let example = NamedExample::load(&path).unwrap();
 48
 49        cx.spawn(async move |cx| {
 50            let (project, zetas, _edited_buffers) = example
 51                .setup_project(&app_state, args.repetitions, cx)
 52                .await
 53                .unwrap();
 54
 55            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
 56                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
 57                let example = example.clone();
 58                let project = project.clone();
 59
 60                cx.spawn(async move |cx| {
 61                    let name = example.name.clone();
 62                    run_evaluate_one(
 63                        example,
 64                        repetition_ix,
 65                        project,
 66                        zeta,
 67                        args.prompt_format,
 68                        args.use_expected_context,
 69                        args.cache,
 70                        cx,
 71                    )
 72                    .await
 73                    .map_err(|err| (err, name, repetition_ix))
 74                })
 75            });
 76            futures::future::join_all(tasks).await
 77        })
 78    });
 79    let all_results = futures::future::join_all(all_tasks).await;
 80
 81    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
 82    if let Some(mut output_file) =
 83        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
 84    {
 85        write_aggregated_scores(&mut output_file, &all_results).log_err();
 86    };
 87    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
 88}
 89
 90fn write_aggregated_scores(
 91    w: &mut impl std::io::Write,
 92    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
 93) -> Result<()> {
 94    let mut successful = Vec::new();
 95    let mut failed_count = 0;
 96
 97    for result in all_results.iter().flatten() {
 98        match result {
 99            Ok(eval_result) => successful.push(eval_result),
100            Err((err, name, repetition_ix)) => {
101                if failed_count == 0 {
102                    writeln!(w, "## Errors\n")?;
103                }
104
105                failed_count += 1;
106                let err = format!("{err:?}")
107                    .replace("<edits", "```xml\n<edits")
108                    .replace("</edits>", "</edits>\n```");
109                writeln!(
110                    w,
111                    "### ERROR {name}{}\n\n{err}\n",
112                    repetition_ix
113                        .map(|ix| format!(" [RUN {ix:03}]"))
114                        .unwrap_or_default()
115                )?;
116            }
117        }
118    }
119
120    if successful.len() > 1 {
121        let aggregated_result = EvaluationResult {
122            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
123            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
124        };
125
126        writeln!(w, "\n{}", "-".repeat(80))?;
127        writeln!(w, "\n## TOTAL SCORES")?;
128        writeln!(w, "\n### Success Rate")?;
129        writeln!(w, "{}", aggregated_result)?;
130    }
131
132    if successful.len() + failed_count > 1 {
133        writeln!(
134            w,
135            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
136            successful.len(),
137            successful.len() + failed_count,
138            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
139        )?;
140    }
141
142    Ok(())
143}
144
145pub async fn run_evaluate_one(
146    example: NamedExample,
147    repetition_ix: Option<u16>,
148    project: Entity<Project>,
149    zeta: Entity<Zeta>,
150    prompt_format: PromptFormat,
151    use_expected_context: bool,
152    cache_mode: CacheMode,
153    cx: &mut AsyncApp,
154) -> Result<EvaluationResult> {
155    let predict_result = zeta2_predict(
156        example.clone(),
157        project,
158        zeta,
159        repetition_ix,
160        prompt_format,
161        use_expected_context,
162        cache_mode,
163        cx,
164    )
165    .await?;
166
167    let evaluation_result = evaluate(&example.example, &predict_result);
168
169    if repetition_ix.is_none() {
170        write_eval_result(
171            &example,
172            &predict_result,
173            &evaluation_result,
174            &mut std::io::stdout(),
175            std::io::stdout().is_terminal(),
176        )?;
177    }
178
179    if let Some(mut results_file) =
180        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
181    {
182        write_eval_result(
183            &example,
184            &predict_result,
185            &evaluation_result,
186            &mut results_file,
187            false,
188        )
189        .log_err();
190    }
191
192    anyhow::Ok(evaluation_result)
193}
194
195fn write_eval_result(
196    example: &NamedExample,
197    predictions: &PredictionDetails,
198    evaluation_result: &EvaluationResult,
199    out: &mut impl Write,
200    use_color: bool,
201) -> Result<()> {
202    writeln!(
203        out,
204        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
205        compare_diffs(
206            &example.example.expected_patch,
207            &predictions.diff,
208            use_color
209        )
210    )?;
211    writeln!(
212        out,
213        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
214        compare_diffs(
215            &predictions.diff,
216            &example.example.expected_patch,
217            use_color
218        )
219    )?;
220    writeln!(out, "{:#}", evaluation_result)?;
221
222    anyhow::Ok(())
223}
224
225#[derive(Debug, Default)]
226pub struct EvaluationResult {
227    pub edit_prediction: Scores,
228    pub context: Scores,
229}
230
231#[derive(Default, Debug)]
232pub struct Scores {
233    pub true_positives: usize,
234    pub false_positives: usize,
235    pub false_negatives: usize,
236}
237
238impl Scores {
239    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
240        let true_positives = expected.intersection(actual).count();
241        let false_positives = actual.difference(expected).count();
242        let false_negatives = expected.difference(actual).count();
243
244        Scores {
245            true_positives,
246            false_positives,
247            false_negatives,
248        }
249    }
250
251    pub fn to_markdown(&self) -> String {
252        format!(
253            "
254Precision       : {:.4}
255Recall          : {:.4}
256F1 Score        : {:.4}
257True Positives  : {}
258False Positives : {}
259False Negatives : {}",
260            self.precision(),
261            self.recall(),
262            self.f1_score(),
263            self.true_positives,
264            self.false_positives,
265            self.false_negatives
266        )
267    }
268
269    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
270        let mut true_positives = 0;
271        let mut false_positives = 0;
272        let mut false_negatives = 0;
273
274        for score in scores {
275            true_positives += score.true_positives;
276            false_positives += score.false_positives;
277            false_negatives += score.false_negatives;
278        }
279
280        Scores {
281            true_positives,
282            false_positives,
283            false_negatives,
284        }
285    }
286
287    pub fn precision(&self) -> f64 {
288        if self.true_positives + self.false_positives == 0 {
289            0.0
290        } else {
291            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
292        }
293    }
294
295    pub fn recall(&self) -> f64 {
296        if self.true_positives + self.false_negatives == 0 {
297            0.0
298        } else {
299            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
300        }
301    }
302
303    pub fn f1_score(&self) -> f64 {
304        let recall = self.recall();
305        let precision = self.precision();
306        if precision + recall == 0.0 {
307            0.0
308        } else {
309            2.0 * precision * recall / (precision + recall)
310        }
311    }
312}
313
314impl std::fmt::Display for EvaluationResult {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        if f.alternate() {
317            self.fmt_table(f)
318        } else {
319            self.fmt_markdown(f)
320        }
321    }
322}
323
324impl EvaluationResult {
325    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(
327            f,
328            r#"
329### Context Scores
330{}
331
332### Edit Prediction Scores
333{}
334"#,
335            self.context.to_markdown(),
336            self.edit_prediction.to_markdown()
337        )
338    }
339
340    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        writeln!(f, "### Scores\n")?;
342        writeln!(
343            f,
344            "                   TP     FP     FN     Precision   Recall     F1"
345        )?;
346        writeln!(
347            f,
348            "──────────────────────────────────────────────────────────────────"
349        )?;
350        writeln!(
351            f,
352            "Context Retrieval  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
353            self.context.true_positives,
354            self.context.false_positives,
355            self.context.false_negatives,
356            self.context.precision() * 100.0,
357            self.context.recall() * 100.0,
358            self.context.f1_score() * 100.0
359        )?;
360        writeln!(
361            f,
362            "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
363            self.edit_prediction.true_positives,
364            self.edit_prediction.false_positives,
365            self.edit_prediction.false_negatives,
366            self.edit_prediction.precision() * 100.0,
367            self.edit_prediction.recall() * 100.0,
368            self.edit_prediction.f1_score() * 100.0
369        )
370    }
371}
372
373pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
374    let mut eval_result = EvaluationResult::default();
375
376    let actual_context_lines: HashSet<_> = preds
377        .excerpts
378        .iter()
379        .flat_map(|excerpt| {
380            excerpt
381                .text
382                .lines()
383                .map(|line| format!("{}: {line}", excerpt.path.display()))
384        })
385        .collect();
386
387    let mut false_positive_lines = actual_context_lines.clone();
388
389    for entry in &example.expected_context {
390        let mut best_alternative_score: Option<Scores> = None;
391
392        for alternative in &entry.alternatives {
393            let expected: HashSet<_> = alternative
394                .excerpts
395                .iter()
396                .flat_map(|excerpt| {
397                    excerpt
398                        .text
399                        .lines()
400                        .map(|line| format!("{}: {line}", excerpt.path.display()))
401                })
402                .collect();
403
404            let scores = Scores::new(&expected, &actual_context_lines);
405
406            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
407
408            if best_alternative_score
409                .as_ref()
410                .is_none_or(|best| scores.recall() > best.recall())
411            {
412                best_alternative_score = Some(scores);
413            }
414        }
415
416        let best_alternative = best_alternative_score.unwrap_or_default();
417        eval_result.context.false_negatives += best_alternative.false_negatives;
418        eval_result.context.true_positives += best_alternative.true_positives;
419    }
420
421    eval_result.context.false_positives = false_positive_lines.len();
422
423    // todo: alternatives for patches
424    let expected_patch_lines = example
425        .expected_patch
426        .lines()
427        .map(DiffLine::parse)
428        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
429        .map(|line| line.to_string())
430        .collect();
431
432    let actual_patch_lines = preds
433        .diff
434        .lines()
435        .map(DiffLine::parse)
436        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
437        .map(|line| line.to_string())
438        .collect();
439
440    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
441    eval_result
442}
443
444/// Return annotated `patch_a` so that:
445/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
446/// Additions and deletions that are present in `patch_b` will be highlighted in green.
447pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
448    let green = if use_color { "\x1b[32m✓ " } else { "" };
449    let red = if use_color { "\x1b[31m✗ " } else { "" };
450    let neutral = if use_color { "  " } else { "" };
451    let reset = if use_color { "\x1b[0m" } else { "" };
452    let lines_a = patch_a.lines().map(DiffLine::parse);
453    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
454
455    let annotated = lines_a
456        .map(|line| match line {
457            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
458                if lines_b.contains(&line) {
459                    format!("{green}{line}{reset}")
460                } else {
461                    format!("{red}{line}{reset}")
462                }
463            }
464            _ => format!("{neutral}{line}{reset}"),
465        })
466        .collect::<Vec<String>>();
467
468    annotated.join("\n")
469}