evaluate.rs

  1use std::{
  2    io::{IsTerminal, Write},
  3    path::PathBuf,
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use clap::Args;
  9use collections::HashSet;
 10use gpui::{AsyncApp, Entity};
 11use project::Project;
 12use util::ResultExt as _;
 13use zeta2::{Zeta, udiff::DiffLine};
 14
 15use crate::{
 16    PromptFormat,
 17    example::{Example, NamedExample},
 18    headless::ZetaCliAppState,
 19    paths::print_run_data_dir,
 20    predict::{CacheMode, PredictionDetails, zeta2_predict},
 21};
 22
// CLI arguments for the `evaluate` subcommand.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on a
// clap `Args` struct would become user-visible help text and change CLI output.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Paths of the example files to evaluate (positional; at least one expected).
    example_paths: Vec<PathBuf>,
    // Prompt format used when building model requests.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Use the example's expected context instead of running context retrieval.
    #[arg(long)]
    use_expected_context: bool,
    // Cache mode for prediction requests.
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
    // How many times to run each example; results are aggregated in the report.
    #[clap(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
}
 35
/// Evaluates every example in `args.example_paths`, running each one
/// `args.repetitions` times, then writes aggregated scores to stdout and to
/// `aggregated_results.md` in the run directory.
///
/// Panics (via `unwrap`) if an example file fails to load or its project
/// fails to set up — this is a CLI entry point where aborting is acceptable.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One task per example; each sets up the project once, then runs all
    // repetitions concurrently against the per-repetition `Zeta` entities.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Only tag results with a repetition index when there is more
                // than one run; a `None` index also enables stdout reporting
                // in `run_evaluate_one`.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        args.cache,
                        cx,
                    )
                    .await
                    // Attach example name + repetition so failures can be
                    // identified in the aggregated error report.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Report to stdout, and best-effort to a markdown file in the run dir.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1);
}
 89
 90fn write_aggregated_scores(
 91    w: &mut impl std::io::Write,
 92    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
 93) -> Result<()> {
 94    let mut successful = Vec::new();
 95    let mut failed_count = 0;
 96
 97    for result in all_results.iter().flatten() {
 98        match result {
 99            Ok(eval_result) => successful.push(eval_result),
100            Err((err, name, repetition_ix)) => {
101                if failed_count == 0 {
102                    writeln!(w, "## Errors\n")?;
103                }
104
105                failed_count += 1;
106                let err = err
107                    .to_string()
108                    .replace("<edits", "```xml\n<edits")
109                    .replace("</edits>", "</edits>\n```");
110                writeln!(
111                    w,
112                    "### ERROR {name}{}\n\n{err}\n",
113                    repetition_ix
114                        .map(|ix| format!(" [RUN {ix:03}]"))
115                        .unwrap_or_default()
116                )?;
117            }
118        }
119    }
120
121    if successful.len() > 1 {
122        let aggregated_result = EvaluationResult {
123            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
124            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
125        };
126
127        writeln!(w, "\n{}", "-".repeat(80))?;
128        writeln!(w, "\n## TOTAL SCORES")?;
129        writeln!(w, "\n### Success Rate")?;
130        writeln!(w, "{}", aggregated_result)?;
131    }
132
133    if successful.len() + failed_count > 1 {
134        writeln!(
135            w,
136            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
137            successful.len(),
138            successful.len() + failed_count,
139            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
140        )?;
141    }
142
143    Ok(())
144}
145
146pub async fn run_evaluate_one(
147    example: NamedExample,
148    repetition_ix: Option<u16>,
149    project: Entity<Project>,
150    zeta: Entity<Zeta>,
151    prompt_format: PromptFormat,
152    use_expected_context: bool,
153    cache_mode: CacheMode,
154    cx: &mut AsyncApp,
155) -> Result<EvaluationResult> {
156    let predict_result = zeta2_predict(
157        example.clone(),
158        project,
159        zeta,
160        repetition_ix,
161        prompt_format,
162        use_expected_context,
163        cache_mode,
164        cx,
165    )
166    .await?;
167
168    let evaluation_result = evaluate(&example.example, &predict_result);
169
170    if repetition_ix.is_none() {
171        write_eval_result(
172            &example,
173            &predict_result,
174            &evaluation_result,
175            &mut std::io::stdout(),
176        )?;
177    }
178
179    if let Some(mut results_file) =
180        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
181    {
182        write_eval_result(
183            &example,
184            &predict_result,
185            &evaluation_result,
186            &mut results_file,
187        )
188        .log_err();
189    }
190
191    anyhow::Ok(evaluation_result)
192}
193
194fn write_eval_result(
195    example: &NamedExample,
196    predictions: &PredictionDetails,
197    evaluation_result: &EvaluationResult,
198    out: &mut impl Write,
199) -> Result<()> {
200    writeln!(
201        out,
202        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
203        compare_diffs(&example.example.expected_patch, &predictions.diff)
204    )?;
205    writeln!(
206        out,
207        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
208        compare_diffs(&predictions.diff, &example.example.expected_patch)
209    )?;
210    writeln!(out, "{:#}", evaluation_result)?;
211
212    anyhow::Ok(())
213}
214
215#[derive(Debug, Default)]
216pub struct EvaluationResult {
217    pub edit_prediction: Scores,
218    pub context: Scores,
219}
220
221#[derive(Default, Debug)]
222pub struct Scores {
223    pub true_positives: usize,
224    pub false_positives: usize,
225    pub false_negatives: usize,
226}
227
228impl Scores {
229    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
230        let true_positives = expected.intersection(actual).count();
231        let false_positives = actual.difference(expected).count();
232        let false_negatives = expected.difference(actual).count();
233
234        Scores {
235            true_positives,
236            false_positives,
237            false_negatives,
238        }
239    }
240
241    pub fn to_markdown(&self) -> String {
242        format!(
243            "
244Precision       : {:.4}
245Recall          : {:.4}
246F1 Score        : {:.4}
247True Positives  : {}
248False Positives : {}
249False Negatives : {}",
250            self.precision(),
251            self.recall(),
252            self.f1_score(),
253            self.true_positives,
254            self.false_positives,
255            self.false_negatives
256        )
257    }
258
259    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
260        let mut true_positives = 0;
261        let mut false_positives = 0;
262        let mut false_negatives = 0;
263
264        for score in scores {
265            true_positives += score.true_positives;
266            false_positives += score.false_positives;
267            false_negatives += score.false_negatives;
268        }
269
270        Scores {
271            true_positives,
272            false_positives,
273            false_negatives,
274        }
275    }
276
277    pub fn precision(&self) -> f64 {
278        if self.true_positives + self.false_positives == 0 {
279            0.0
280        } else {
281            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
282        }
283    }
284
285    pub fn recall(&self) -> f64 {
286        if self.true_positives + self.false_negatives == 0 {
287            0.0
288        } else {
289            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
290        }
291    }
292
293    pub fn f1_score(&self) -> f64 {
294        let recall = self.recall();
295        let precision = self.precision();
296        if precision + recall == 0.0 {
297            0.0
298        } else {
299            2.0 * precision * recall / (precision + recall)
300        }
301    }
302}
303
304impl std::fmt::Display for EvaluationResult {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        if f.alternate() {
307            self.fmt_table(f)
308        } else {
309            self.fmt_markdown(f)
310        }
311    }
312}
313
314impl EvaluationResult {
315    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        write!(
317            f,
318            r#"
319### Context Scores
320{}
321
322### Edit Prediction Scores
323{}
324"#,
325            self.context.to_markdown(),
326            self.edit_prediction.to_markdown()
327        )
328    }
329
330    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        writeln!(f, "### Scores\n")?;
332        writeln!(
333            f,
334            "                   TP     FP     FN     Precision   Recall     F1"
335        )?;
336        writeln!(
337            f,
338            "──────────────────────────────────────────────────────────────────"
339        )?;
340        writeln!(
341            f,
342            "Context Retrieval  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
343            self.context.true_positives,
344            self.context.false_positives,
345            self.context.false_negatives,
346            self.context.precision() * 100.0,
347            self.context.recall() * 100.0,
348            self.context.f1_score() * 100.0
349        )?;
350        writeln!(
351            f,
352            "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
353            self.edit_prediction.true_positives,
354            self.edit_prediction.false_positives,
355            self.edit_prediction.false_negatives,
356            self.edit_prediction.precision() * 100.0,
357            self.edit_prediction.recall() * 100.0,
358            self.edit_prediction.f1_score() * 100.0
359        )
360    }
361}
362
363pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
364    let mut eval_result = EvaluationResult::default();
365
366    let actual_context_lines: HashSet<_> = preds
367        .excerpts
368        .iter()
369        .flat_map(|excerpt| {
370            excerpt
371                .text
372                .lines()
373                .map(|line| format!("{}: {line}", excerpt.path.display()))
374        })
375        .collect();
376
377    let mut false_positive_lines = actual_context_lines.clone();
378
379    for entry in &example.expected_context {
380        let mut best_alternative_score: Option<Scores> = None;
381
382        for alternative in &entry.alternatives {
383            let expected: HashSet<_> = alternative
384                .excerpts
385                .iter()
386                .flat_map(|excerpt| {
387                    excerpt
388                        .text
389                        .lines()
390                        .map(|line| format!("{}: {line}", excerpt.path.display()))
391                })
392                .collect();
393
394            let scores = Scores::new(&expected, &actual_context_lines);
395
396            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
397
398            if best_alternative_score
399                .as_ref()
400                .is_none_or(|best| scores.recall() > best.recall())
401            {
402                best_alternative_score = Some(scores);
403            }
404        }
405
406        let best_alternative = best_alternative_score.unwrap_or_default();
407        eval_result.context.false_negatives += best_alternative.false_negatives;
408        eval_result.context.true_positives += best_alternative.true_positives;
409    }
410
411    eval_result.context.false_positives = false_positive_lines.len();
412
413    // todo: alternatives for patches
414    let expected_patch_lines = example
415        .expected_patch
416        .lines()
417        .map(DiffLine::parse)
418        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
419        .map(|line| line.to_string())
420        .collect();
421
422    let actual_patch_lines = preds
423        .diff
424        .lines()
425        .map(DiffLine::parse)
426        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
427        .map(|line| line.to_string())
428        .collect();
429
430    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
431    eval_result
432}
433
434/// Return annotated `patch_a` so that:
435/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
436/// Additions and deletions that are present in `patch_b` will be highlighted in green.
437pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
438    let use_color = std::io::stdout().is_terminal();
439    let green = if use_color { "\x1b[32m✓ " } else { "" };
440    let red = if use_color { "\x1b[31m✗ " } else { "" };
441    let neutral = if use_color { "  " } else { "" };
442    let reset = if use_color { "\x1b[0m" } else { "" };
443    let lines_a = patch_a.lines().map(DiffLine::parse);
444    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
445
446    let annotated = lines_a
447        .map(|line| match line {
448            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
449                if lines_b.contains(&line) {
450                    format!("{green}{line}{reset}")
451                } else {
452                    format!("{red}{line}{reset}")
453                }
454            }
455            _ => format!("{neutral}{line}{reset}"),
456        })
457        .collect::<Vec<String>>();
458
459    annotated.join("\n")
460}