1use std::{
2 collections::HashMap,
3 io::{IsTerminal, Write},
4 path::PathBuf,
5 sync::Arc,
6};
7
8use anyhow::Result;
9use clap::Args;
10use collections::HashSet;
11use gpui::{AsyncApp, Entity};
12use project::Project;
13use util::ResultExt as _;
14use zeta2::{Zeta, udiff::DiffLine};
15
16use crate::{
17 PromptFormat,
18 example::{Example, NamedExample},
19 headless::ZetaCliAppState,
20 paths::print_run_data_dir,
21 predict::{CacheMode, PredictionDetails, zeta2_predict},
22};
23
24#[derive(Debug, Args)]
25pub struct EvaluateArguments {
26 example_paths: Vec<PathBuf>,
27 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
28 prompt_format: PromptFormat,
29 #[arg(long)]
30 use_expected_context: bool,
31 #[clap(long, value_enum, default_value_t = CacheMode::default())]
32 cache: CacheMode,
33 #[clap(short, long, default_value_t = 1, alias = "repeat")]
34 repetitions: u16,
35 #[arg(long)]
36 skip_prediction: bool,
37}
38
/// Raw data captured from one prediction execution, consumed by the
/// aggregated report and the bucketed analysis.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index when repeating, otherwise the example name.
    execution_id: String,
    // Unified diff of the predicted edits (empty when nothing was predicted).
    diff: String,
    // Raw model response text read from `prediction_response.md`; empty if missing.
    reasoning: String,
}
45
/// Entry point for the `evaluate` subcommand.
///
/// Loads each example, runs it `args.repetitions` times concurrently,
/// prints aggregated scores to stdout, mirrors them into the run
/// directory, and (for repeated runs) writes a bucketed analysis.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One outer task per example: set up its project once, then spawn one
    // inner task per repetition, each with its own Zeta entity.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // A repetition index is only attached when actually repeating.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with example name + repetition so the
                    // reports can attribute them.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    // Best-effort mirror of the aggregated report into the run directory;
    // failures are logged, not fatal.
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing identical predictions is only meaningful across repetitions.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
107
108fn write_aggregated_scores(
109 w: &mut impl std::io::Write,
110 all_results: &Vec<
111 Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
112 >,
113) -> Result<()> {
114 let mut successful = Vec::new();
115 let mut failed_count = 0;
116
117 for result in all_results.iter().flatten() {
118 match result {
119 Ok((eval_result, _execution_data)) => successful.push(eval_result),
120 Err((err, name, repetition_ix)) => {
121 if failed_count == 0 {
122 writeln!(w, "## Errors\n")?;
123 }
124
125 failed_count += 1;
126 writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
127 }
128 }
129 }
130
131 if successful.len() > 1 {
132 let mut edit_predictions = successful
133 .iter()
134 .filter_map(|r| r.edit_prediction.as_ref())
135 .peekable();
136 let has_edit_predictions = edit_predictions.peek().is_some();
137 let aggregated_result = EvaluationResult {
138 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
139 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
140 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
141 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
142 / successful.len(),
143 };
144
145 writeln!(w, "\n{}", "-".repeat(80))?;
146 writeln!(w, "\n## TOTAL SCORES")?;
147 writeln!(w, "{:#}", aggregated_result)?;
148 }
149
150 if successful.len() + failed_count > 1 {
151 writeln!(
152 w,
153 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
154 successful.len(),
155 successful.len() + failed_count,
156 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
157 )?;
158 }
159
160 Ok(())
161}
162
/// Runs one prediction for a single example/repetition and scores it.
///
/// Writes `results.md` into the run's example directory; when not
/// repeating, also echoes the result to stdout. Returns the scores plus
/// the raw execution data used by the bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    predict: bool,
    cache_mode: CacheMode,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = zeta2_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prompt_format,
        use_expected_context,
        cache_mode,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // Printing every repetition would interleave concurrent output; only
    // print directly when the example runs once.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Best-effort: persist the same report (without color) next to the
    // prediction data; failures are logged, not fatal.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repetitions are identified by a zero-padded index; single runs
        // by the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Model reasoning, if the response file was written; empty otherwise.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
230
231fn write_eval_result(
232 example: &NamedExample,
233 predictions: &PredictionDetails,
234 evaluation_result: &EvaluationResult,
235 out: &mut impl Write,
236 use_color: bool,
237 predict: bool,
238) -> Result<()> {
239 if predict {
240 writeln!(
241 out,
242 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
243 compare_diffs(
244 &example.example.expected_patch,
245 &predictions.diff,
246 use_color
247 )
248 )?;
249 writeln!(
250 out,
251 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
252 compare_diffs(
253 &predictions.diff,
254 &example.example.expected_patch,
255 use_color
256 )
257 )?;
258 }
259
260 writeln!(out, "{:#}", evaluation_result)?;
261
262 anyhow::Ok(())
263}
264
/// Scores for one evaluated run (or an aggregate of several).
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Present only when edit prediction was performed and scored.
    pub edit_prediction: Option<Scores>,
    // Context-retrieval scores; always computed.
    pub context: Scores,
    // Prompt length in bytes (mean across runs when aggregated).
    pub prompt_len: usize,
    // Generated-output length in bytes (mean across runs when aggregated).
    pub generated_len: usize,
}
272
/// Raw confusion counts for a set-overlap comparison, with derived
/// precision/recall/F1 accessors.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Builds counts from the overlap between `expected` and `actual`.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the scores as a small markdown fragment.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the counts of several score sets into one.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut total, score| {
            total.true_positives += score.true_positives;
            total.false_positives += score.false_positives;
            total.false_negatives += score.false_negatives;
            total
        })
    }

    /// TP / (TP + FP); defined as 0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let denominator = self.true_positives + self.false_positives;
        if denominator > 0 {
            self.true_positives as f64 / denominator as f64
        } else {
            0.0
        }
    }

    /// TP / (TP + FN); defined as 0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        let denominator = self.true_positives + self.false_negatives;
        if denominator > 0 {
            self.true_positives as f64 / denominator as f64
        } else {
            0.0
        }
    }

    /// Harmonic mean of precision and recall; 0 when both are 0.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        match precision + recall {
            sum if sum == 0.0 => 0.0,
            sum => 2.0 * precision * recall / sum,
        }
    }
}
355
356impl std::fmt::Display for EvaluationResult {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 if f.alternate() {
359 self.fmt_table(f)
360 } else {
361 self.fmt_markdown(f)
362 }
363 }
364}
365
366impl EvaluationResult {
367 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 write!(
369 f,
370 r#"
371### Context Scores
372{}
373"#,
374 self.context.to_markdown(),
375 )?;
376 if let Some(prediction) = &self.edit_prediction {
377 write!(
378 f,
379 r#"
380 ### Edit Prediction Scores
381 {}"#,
382 prediction.to_markdown()
383 )?;
384 }
385 Ok(())
386 }
387
388 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389 writeln!(f, "### Scores\n")?;
390 writeln!(
391 f,
392 " Prompt Generated TP FP FN Precision Recall F1"
393 )?;
394 writeln!(
395 f,
396 "────────────────────────────────────────────────────────────────────────────────────"
397 )?;
398 writeln!(
399 f,
400 "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
401 "",
402 "",
403 self.context.true_positives,
404 self.context.false_positives,
405 self.context.false_negatives,
406 self.context.precision() * 100.0,
407 self.context.recall() * 100.0,
408 self.context.f1_score() * 100.0
409 )?;
410 if let Some(edit_prediction) = &self.edit_prediction {
411 writeln!(
412 f,
413 "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
414 self.prompt_len,
415 self.generated_len,
416 edit_prediction.true_positives,
417 edit_prediction.false_positives,
418 edit_prediction.false_negatives,
419 edit_prediction.precision() * 100.0,
420 edit_prediction.recall() * 100.0,
421 edit_prediction.f1_score() * 100.0
422 )?;
423 }
424 Ok(())
425 }
426}
427
428pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
429 let mut eval_result = EvaluationResult {
430 prompt_len: preds.prompt_len,
431 generated_len: preds.generated_len,
432 ..Default::default()
433 };
434
435 let actual_context_lines: HashSet<_> = preds
436 .excerpts
437 .iter()
438 .flat_map(|excerpt| {
439 excerpt
440 .text
441 .lines()
442 .map(|line| format!("{}: {line}", excerpt.path.display()))
443 })
444 .collect();
445
446 let mut false_positive_lines = actual_context_lines.clone();
447
448 for entry in &example.expected_context {
449 let mut best_alternative_score: Option<Scores> = None;
450
451 for alternative in &entry.alternatives {
452 let expected: HashSet<_> = alternative
453 .excerpts
454 .iter()
455 .flat_map(|excerpt| {
456 excerpt
457 .text
458 .lines()
459 .map(|line| format!("{}: {line}", excerpt.path.display()))
460 })
461 .collect();
462
463 let scores = Scores::new(&expected, &actual_context_lines);
464
465 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
466
467 if best_alternative_score
468 .as_ref()
469 .is_none_or(|best| scores.recall() > best.recall())
470 {
471 best_alternative_score = Some(scores);
472 }
473 }
474
475 let best_alternative = best_alternative_score.unwrap_or_default();
476 eval_result.context.false_negatives += best_alternative.false_negatives;
477 eval_result.context.true_positives += best_alternative.true_positives;
478 }
479
480 eval_result.context.false_positives = false_positive_lines.len();
481
482 if predict {
483 // todo: alternatives for patches
484 let expected_patch_lines = example
485 .expected_patch
486 .lines()
487 .map(DiffLine::parse)
488 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
489 .map(|line| line.to_string())
490 .collect();
491
492 let actual_patch_lines = preds
493 .diff
494 .lines()
495 .map(DiffLine::parse)
496 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
497 .map(|line| line.to_string())
498 .collect();
499
500 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
501 }
502
503 eval_result
504}
505
506/// Return annotated `patch_a` so that:
507/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
508/// Additions and deletions that are present in `patch_b` will be highlighted in green.
509pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
510 let green = if use_color { "\x1b[32m✓ " } else { "" };
511 let red = if use_color { "\x1b[31m✗ " } else { "" };
512 let neutral = if use_color { " " } else { "" };
513 let reset = if use_color { "\x1b[0m" } else { "" };
514 let lines_a = patch_a.lines().map(DiffLine::parse);
515 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
516
517 let annotated = lines_a
518 .map(|line| match line {
519 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
520 if lines_b.contains(&line) {
521 format!("{green}{line}{reset}")
522 } else {
523 format!("{red}{line}{reset}")
524 }
525 }
526 _ => format!("{neutral}{line}{reset}"),
527 })
528 .collect::<Vec<String>>();
529
530 annotated.join("\n")
531}
532
/// Groups repeated executions by their exact predicted diff and writes
/// `bucketed_analysis.md` into the run directory: a summary, correct
/// buckets, incorrect buckets (largest first), empty predictions, and
/// errors.
fn write_bucketed_analysis(
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    // One bucket per distinct predicted diff string.
    #[derive(Debug)]
    struct EditBucket {
        diff: String,
        // Whether this diff matched the expected patch exactly.
        is_correct: bool,
        // Ids of the executions that produced this diff.
        execution_indices: Vec<String>,
        // Model reasoning per execution, parallel to `execution_indices`.
        reasoning_samples: Vec<String>,
    }

    let mut total_executions = 0;
    let mut empty_predictions = Vec::new();
    let mut errors = Vec::new();

    // Keyed by the diff text itself.
    let mut buckets: HashMap<String, EditBucket> = HashMap::new();

    for result in all_results.iter().flatten() {
        total_executions += 1;

        let (evaluation_result, execution_data) = match result {
            Ok((eval_result, execution_data)) => {
                // Runs that predicted nothing are reported in their own section.
                if execution_data.diff.is_empty() {
                    empty_predictions.push(execution_data);
                    continue;
                }
                (eval_result, execution_data)
            }
            Err(err) => {
                errors.push(err);
                continue;
            }
        };

        buckets
            .entry(execution_data.diff.clone())
            .and_modify(|bucket| {
                bucket
                    .execution_indices
                    .push(execution_data.execution_id.clone());
                bucket
                    .reasoning_samples
                    .push(execution_data.reasoning.clone());
            })
            .or_insert_with(|| EditBucket {
                diff: execution_data.diff.clone(),
                // "Correct" means a perfect match: no false positives or
                // negatives, and at least one true positive.
                is_correct: {
                    evaluation_result
                        .edit_prediction
                        .as_ref()
                        .map_or(false, |edit_prediction| {
                            edit_prediction.false_positives == 0
                                && edit_prediction.false_negatives == 0
                                && edit_prediction.true_positives > 0
                        })
                },
                execution_indices: vec![execution_data.execution_id.clone()],
                reasoning_samples: vec![execution_data.reasoning.clone()],
            });
    }

    // Correct buckets first, then by descending occurrence count.
    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
    });

    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
    let mut output = std::fs::File::create(&output_path)?;

    writeln!(output, "# Bucketed Edit Analysis\n")?;

    writeln!(output, "## Summary\n")?;
    writeln!(output, "- **Total executions**: {}", total_executions)?;

    let correct_count: usize = sorted_buckets
        .iter()
        .filter(|b| b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    let incorrect_count: usize = sorted_buckets
        .iter()
        .filter(|b| !b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    writeln!(
        output,
        "- **Correct predictions**: {} ({:.1}%)",
        correct_count,
        (correct_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **Incorrect predictions**: {} ({:.1}%)",
        incorrect_count,
        (incorrect_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **No Predictions**: {} ({:.1}%)",
        empty_predictions.len(),
        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
    )?;

    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
    writeln!(
        output,
        "- **Unique incorrect edit patterns**: {}\n",
        unique_incorrect
    )?;

    writeln!(output, "---\n")?;

    // All correct buckets share a single section header, emitted before
    // the first one.
    // NOTE(review): the header shows the first bucket's occurrence count,
    // not `correct_count` — confirm whether that is intended.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
        if idx == 0 {
            writeln!(
                output,
                "## Correct Predictions ({} occurrences)\n",
                bucket.execution_indices.len()
            )?;
        }

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;
        writeln!(output, "---\n")?;
    }

    // Incorrect buckets each get a numbered section, including the model
    // reasoning for every execution in the bucket.
    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
        writeln!(
            output,
            "## Incorrect Prediction #{} ({} occurrences)\n",
            idx + 1,
            bucket.execution_indices.len()
        )?;

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;

        for (exec_id, reasoning) in bucket
            .execution_indices
            .iter()
            .zip(bucket.reasoning_samples.iter())
        {
            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
        }

        writeln!(output, "\n---\n")?;
    }

    if !empty_predictions.is_empty() {
        writeln!(
            output,
            "## No Predictions ({} occurrences)\n",
            empty_predictions.len()
        )?;

        for execution_data in &empty_predictions {
            writeln!(
                output,
                "{}",
                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
            )?;
        }
        writeln!(output, "\n---\n")?;
    }

    if !errors.is_empty() {
        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;

        for (err, name, repetition_ix) in &errors {
            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
        }
        writeln!(output, "\n---\n")?;
    }

    // Renders one execution's reasoning as an indented fenced block with a
    // pointer to its `prediction_response.md` on disk.
    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
        let exec_content = format!(
            "\n### Execution {} `{}/{}/prediction_response.md`{}",
            exec_id,
            crate::paths::RUN_DIR.display(),
            exec_id,
            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
        );
        indent_text(&exec_content, 2)
    }

    // Indents every line after the first by `spaces` spaces (the first
    // line keeps its position — callers rely on this when nesting).
    fn indent_text(text: &str, spaces: usize) -> String {
        let indent = " ".repeat(spaces);
        text.lines()
            .collect::<Vec<_>>()
            .join(&format!("\n{}", indent))
    }

    Ok(())
}
751
752fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
753 let err = format!("{err:?}")
754 .replace("<edits", "```xml\n<edits")
755 .replace("</edits>", "</edits>\n```");
756 format!(
757 "### ERROR {name}{}\n\n{err}\n",
758 repetition_ix
759 .map(|ix| format!(" [RUN {ix:03}]"))
760 .unwrap_or_default()
761 )
762}