evaluate.rs

  1use std::{
  2    collections::{BTreeSet, HashMap},
  3    io::{IsTerminal, Write},
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use collections::HashSet;
  9use gpui::{AsyncApp, Entity};
 10use project::Project;
 11use util::ResultExt as _;
 12use zeta2::{Zeta, udiff::DiffLine};
 13
 14use crate::{
 15    EvaluateArguments, PredictionOptions,
 16    example::{Example, NamedExample},
 17    headless::ZetaCliAppState,
 18    paths::print_run_data_dir,
 19    predict::{PredictionDetails, perform_predict, setup_zeta},
 20};
 21
/// Data captured from a single prediction run, used for aggregated reporting
/// and for the bucketed analysis written when a run is repeated.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index (e.g. "003") for repeated runs, otherwise
    // the example's name (see `run_evaluate_one`).
    execution_id: String,
    // The unified diff the prediction produced; empty when nothing was predicted.
    diff: String,
    // Raw contents of the run's `prediction_response.md`, surfaced as
    // reasoning samples in the bucketed analysis.
    reasoning: String,
}
 28
/// Entry point for the `evaluate` command: runs the prediction + evaluation
/// pipeline for every example (repeated `args.repetitions` times), writes
/// aggregated scores to stdout and `aggregated_results.md`, and — for
/// repeated runs — a bucketed analysis of the predicted edits.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    // One spawned task per example; each example fans out into one task per repetition.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            // A fresh Zeta provider is set up for each repetition.
            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            // Replay the example's edit history into the project before predicting.
            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    // The repetition index is only surfaced when actually repeating.
                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        // Attach the example name and repetition index so
                        // failures can be attributed in the reports.
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Aggregated scores go both to stdout and to a file in the run directory.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing predictions only makes sense when comparing repeated runs.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
 97
 98fn write_aggregated_scores(
 99    w: &mut impl std::io::Write,
100    all_results: &Vec<
101        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
102    >,
103) -> Result<()> {
104    let mut successful = Vec::new();
105    let mut failed_count = 0;
106
107    for result in all_results.iter().flatten() {
108        match result {
109            Ok((eval_result, _execution_data)) => successful.push(eval_result),
110            Err((err, name, repetition_ix)) => {
111                if failed_count == 0 {
112                    writeln!(w, "## Errors\n")?;
113                }
114
115                failed_count += 1;
116                writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
117            }
118        }
119    }
120
121    if successful.len() > 1 {
122        let mut edit_predictions = successful
123            .iter()
124            .filter_map(|r| r.edit_prediction.as_ref())
125            .peekable();
126        let has_edit_predictions = edit_predictions.peek().is_some();
127        let aggregated_result = EvaluationResult {
128            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
129            edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
130            prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
131            generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
132                / successful.len(),
133            context_lines_found_in_context: successful
134                .iter()
135                .map(|r| r.context_lines_found_in_context)
136                .sum::<usize>()
137                / successful.len(),
138            context_lines_in_expected_patch: successful
139                .iter()
140                .map(|r| r.context_lines_in_expected_patch)
141                .sum::<usize>()
142                / successful.len(),
143        };
144
145        writeln!(w, "\n{}", "-".repeat(80))?;
146        writeln!(w, "\n## TOTAL SCORES")?;
147        writeln!(w, "{:#}", aggregated_result)?;
148    }
149
150    if successful.len() + failed_count > 1 {
151        writeln!(
152            w,
153            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
154            successful.len(),
155            successful.len() + failed_count,
156            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
157        )?;
158    }
159
160    Ok(())
161}
162
/// Runs a single prediction for `example` and scores it against the example's
/// expected context and expected patch.
///
/// Results are printed to stdout only for single (non-repeated) runs, and are
/// always written to `results.md` inside the run's example directory. Returns
/// the scores together with the raw execution data used for bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prediction_options: PredictionOptions,
    predict: bool,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = perform_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prediction_options,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // Only single runs report to stdout; repeated runs rely on the files below.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Best-effort: failures to write the results file are logged, not fatal.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Repeated runs are identified by a zero-padded repetition index;
        // single runs by the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Missing/unreadable response files yield an empty reasoning string.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
226
227fn write_eval_result(
228    example: &NamedExample,
229    predictions: &PredictionDetails,
230    evaluation_result: &EvaluationResult,
231    out: &mut impl Write,
232    use_color: bool,
233    predict: bool,
234) -> Result<()> {
235    if predict {
236        writeln!(
237            out,
238            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
239            compare_diffs(
240                &example.example.expected_patch,
241                &predictions.diff,
242                use_color
243            )
244        )?;
245        writeln!(
246            out,
247            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
248            compare_diffs(
249                &predictions.diff,
250                &example.example.expected_patch,
251                use_color
252            )
253        )?;
254    }
255
256    writeln!(out, "{:#}", evaluation_result)?;
257
258    anyhow::Ok(())
259}
260
/// Scores and size metrics for one evaluation run, or an integer-mean
/// aggregate of several runs.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // Line-level scores for the predicted patch; `None` when prediction was skipped.
    pub edit_prediction: Option<Scores>,
    // Line-level scores for retrieved context vs. the expected context.
    pub context: Scores,
    // Length of the prompt sent to the model.
    pub prompt_len: usize,
    // Length of the model's generated output.
    pub generated_len: usize,
    // Number of distinct context lines in the expected patch.
    pub context_lines_in_expected_patch: usize,
    // How many of those expected-patch context lines were actually retrieved.
    pub context_lines_found_in_context: usize,
}
270
/// Counts for precision/recall-style scoring of an actual set of items
/// against an expected set.
#[derive(Default, Debug)]
pub struct Scores {
    // Items present in both the expected and actual sets.
    pub true_positives: usize,
    // Items present only in the actual set.
    pub false_positives: usize,
    // Items present only in the expected set.
    pub false_negatives: usize,
}
277
278impl Scores {
279    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
280        let true_positives = expected.intersection(actual).count();
281        let false_positives = actual.difference(expected).count();
282        let false_negatives = expected.difference(actual).count();
283
284        Scores {
285            true_positives,
286            false_positives,
287            false_negatives,
288        }
289    }
290
291    pub fn to_markdown(&self) -> String {
292        format!(
293            "
294Precision       : {:.4}
295Recall          : {:.4}
296F1 Score        : {:.4}
297True Positives  : {}
298False Positives : {}
299False Negatives : {}",
300            self.precision(),
301            self.recall(),
302            self.f1_score(),
303            self.true_positives,
304            self.false_positives,
305            self.false_negatives
306        )
307    }
308
309    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
310        let mut true_positives = 0;
311        let mut false_positives = 0;
312        let mut false_negatives = 0;
313
314        for score in scores {
315            true_positives += score.true_positives;
316            false_positives += score.false_positives;
317            false_negatives += score.false_negatives;
318        }
319
320        Scores {
321            true_positives,
322            false_positives,
323            false_negatives,
324        }
325    }
326
327    pub fn precision(&self) -> f64 {
328        if self.true_positives + self.false_positives == 0 {
329            0.0
330        } else {
331            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
332        }
333    }
334
335    pub fn recall(&self) -> f64 {
336        if self.true_positives + self.false_negatives == 0 {
337            0.0
338        } else {
339            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
340        }
341    }
342
343    pub fn f1_score(&self) -> f64 {
344        let recall = self.recall();
345        let precision = self.precision();
346        if precision + recall == 0.0 {
347            0.0
348        } else {
349            2.0 * precision * recall / (precision + recall)
350        }
351    }
352}
353
354impl std::fmt::Display for EvaluationResult {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        if f.alternate() {
357            self.fmt_table(f)
358        } else {
359            self.fmt_markdown(f)
360        }
361    }
362}
363
364impl EvaluationResult {
365    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        write!(
367            f,
368            r#"
369### Context Scores
370{}
371"#,
372            self.context.to_markdown(),
373        )?;
374        if let Some(prediction) = &self.edit_prediction {
375            write!(
376                f,
377                r#"
378                ### Edit Prediction Scores
379                {}"#,
380                prediction.to_markdown()
381            )?;
382        }
383        Ok(())
384    }
385
386    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        writeln!(f, "### Scores\n")?;
388        writeln!(
389            f,
390            "                   Prompt  Generated RetrievedContext PatchContext     TP     FP     FN     Precision   Recall     F1"
391        )?;
392        writeln!(
393            f,
394            "─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────"
395        )?;
396        writeln!(
397            f,
398            "Context Retrieval  {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
399            "",
400            "",
401            "",
402            "",
403            self.context.true_positives,
404            self.context.false_positives,
405            self.context.false_negatives,
406            self.context.precision() * 100.0,
407            self.context.recall() * 100.0,
408            self.context.f1_score() * 100.0
409        )?;
410        if let Some(edit_prediction) = &self.edit_prediction {
411            writeln!(
412                f,
413                "Edit Prediction    {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
414                self.prompt_len,
415                self.generated_len,
416                self.context_lines_found_in_context,
417                self.context_lines_in_expected_patch,
418                edit_prediction.true_positives,
419                edit_prediction.false_positives,
420                edit_prediction.false_negatives,
421                edit_prediction.precision() * 100.0,
422                edit_prediction.recall() * 100.0,
423                edit_prediction.f1_score() * 100.0
424            )?;
425        }
426        Ok(())
427    }
428}
429
/// Scores `preds` against `example`. Context retrieval is always scored;
/// edit-prediction scores are computed only when `predict` is true.
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
    let mut eval_result = EvaluationResult {
        prompt_len: preds.prompt_len,
        generated_len: preds.generated_len,
        ..Default::default()
    };

    // Retrieved context lines, keyed as "path: line" so identical lines from
    // different files don't collide.
    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    // Start by assuming every retrieved line is a false positive; lines are
    // removed below as they match expected alternatives.
    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        // Each expected-context entry may have several acceptable
        // alternatives; only the alternative with the best recall contributes
        // TP/FN counts for this entry.
        let mut best_alternative_score: Option<Scores> = None;

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // A retrieved line matching *any* alternative (not just the best
            // one) is not counted as a false positive.
            false_positive_lines.retain(|line| !expected.contains(line));

            if best_alternative_score
                .as_ref()
                .is_none_or(|best| scores.recall() > best.recall())
            {
                best_alternative_score = Some(scores);
            }
        }

        let best_alternative = best_alternative_score.unwrap_or_default();
        eval_result.context.false_negatives += best_alternative.false_negatives;
        eval_result.context.true_positives += best_alternative.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    if predict {
        // todo: alternatives for patches
        let expected_patch = example
            .expected_patch
            .lines()
            .map(DiffLine::parse)
            .collect::<Vec<_>>();
        // Only +/- lines participate in edit-prediction scoring.
        let expected_patch_lines = expected_patch
            .iter()
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();
        // Context (unchanged) lines of the expected patch, without path
        // prefixes — note this shadows the path-prefixed set above.
        let expected_context_lines = expected_patch
            .iter()
            .filter_map(|line| {
                if let DiffLine::Context(str) = line {
                    Some(String::from(*str))
                } else {
                    None
                }
            })
            .collect::<BTreeSet<_>>();
        let actual_context_lines = preds
            .excerpts
            .iter()
            .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
            .collect::<BTreeSet<_>>();

        // How many of the expected patch's context lines were retrieved.
        let matched = expected_context_lines
            .intersection(&actual_context_lines)
            .count();

        let actual_patch_lines = preds
            .diff
            .lines()
            .map(DiffLine::parse)
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();

        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
        eval_result.context_lines_in_expected_patch = expected_context_lines.len();
        eval_result.context_lines_found_in_context = matched;
    }

    eval_result
}
531
532/// Return annotated `patch_a` so that:
533/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
534/// Additions and deletions that are present in `patch_b` will be highlighted in green.
535pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
536    let green = if use_color { "\x1b[32m✓ " } else { "" };
537    let red = if use_color { "\x1b[31m✗ " } else { "" };
538    let neutral = if use_color { "  " } else { "" };
539    let reset = if use_color { "\x1b[0m" } else { "" };
540    let lines_a = patch_a.lines().map(DiffLine::parse);
541    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
542
543    let annotated = lines_a
544        .map(|line| match line {
545            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
546                if lines_b.contains(&line) {
547                    format!("{green}{line}{reset}")
548                } else {
549                    format!("{red}{line}{reset}")
550                }
551            }
552            _ => format!("{neutral}{line}{reset}"),
553        })
554        .collect::<Vec<String>>();
555
556    annotated.join("\n")
557}
558
/// Groups repeated-run predictions by their exact diff text and writes
/// `bucketed_analysis.md` into the run directory: a summary, correct buckets,
/// incorrect buckets (with reasoning samples), empty predictions, and errors.
fn write_bucketed_analysis(
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    // One bucket per distinct predicted diff, tracking which executions
    // produced it and their recorded model reasoning.
    #[derive(Debug)]
    struct EditBucket {
        diff: String,
        is_correct: bool,
        execution_indices: Vec<String>,
        reasoning_samples: Vec<String>,
    }

    let mut total_executions = 0;
    let mut empty_predictions = Vec::new();
    let mut errors = Vec::new();

    // Keyed by the exact diff text; identical predictions share a bucket.
    let mut buckets: HashMap<String, EditBucket> = HashMap::new();

    for result in all_results.iter().flatten() {
        total_executions += 1;

        let (evaluation_result, execution_data) = match result {
            Ok((eval_result, execution_data)) => {
                // Runs that predicted no edit are reported in their own section.
                if execution_data.diff.is_empty() {
                    empty_predictions.push(execution_data);
                    continue;
                }
                (eval_result, execution_data)
            }
            Err(err) => {
                errors.push(err);
                continue;
            }
        };

        buckets
            .entry(execution_data.diff.clone())
            .and_modify(|bucket| {
                bucket
                    .execution_indices
                    .push(execution_data.execution_id.clone());
                bucket
                    .reasoning_samples
                    .push(execution_data.reasoning.clone());
            })
            .or_insert_with(|| EditBucket {
                diff: execution_data.diff.clone(),
                // A bucket is "correct" when the prediction matched the
                // expected patch exactly: no spurious lines, no missing lines,
                // and at least one matching line.
                is_correct: {
                    evaluation_result
                        .edit_prediction
                        .as_ref()
                        .map_or(false, |edit_prediction| {
                            edit_prediction.false_positives == 0
                                && edit_prediction.false_negatives == 0
                                && edit_prediction.true_positives > 0
                        })
                },
                execution_indices: vec![execution_data.execution_id.clone()],
                reasoning_samples: vec![execution_data.reasoning.clone()],
            });
    }

    // Correct buckets first, then by descending occurrence count.
    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
    });

    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
    let mut output = std::fs::File::create(&output_path)?;

    writeln!(output, "# Bucketed Edit Analysis\n")?;

    writeln!(output, "## Summary\n")?;
    writeln!(output, "- **Total executions**: {}", total_executions)?;

    let correct_count: usize = sorted_buckets
        .iter()
        .filter(|b| b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    let incorrect_count: usize = sorted_buckets
        .iter()
        .filter(|b| !b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    writeln!(
        output,
        "- **Correct predictions**: {} ({:.1}%)",
        correct_count,
        (correct_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **Incorrect predictions**: {} ({:.1}%)",
        incorrect_count,
        (incorrect_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **No Predictions**: {} ({:.1}%)",
        empty_predictions.len(),
        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
    )?;

    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
    writeln!(
        output,
        "- **Unique incorrect edit patterns**: {}\n",
        unique_incorrect
    )?;

    writeln!(output, "---\n")?;

    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
        // The heading is emitted once, before the first correct bucket.
        // NOTE(review): it shows the first bucket's occurrence count, not the
        // total across all correct buckets (`correct_count`) — confirm intended.
        if idx == 0 {
            writeln!(
                output,
                "## Correct Predictions ({} occurrences)\n",
                bucket.execution_indices.len()
            )?;
        }

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;
        writeln!(output, "---\n")?;
    }

    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
        writeln!(
            output,
            "## Incorrect Prediction #{} ({} occurrences)\n",
            idx + 1,
            bucket.execution_indices.len()
        )?;

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;

        // Incorrect buckets additionally include each execution's reasoning.
        for (exec_id, reasoning) in bucket
            .execution_indices
            .iter()
            .zip(bucket.reasoning_samples.iter())
        {
            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
        }

        writeln!(output, "\n---\n")?;
    }

    if !empty_predictions.is_empty() {
        writeln!(
            output,
            "## No Predictions ({} occurrences)\n",
            empty_predictions.len()
        )?;

        for execution_data in &empty_predictions {
            writeln!(
                output,
                "{}",
                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
            )?;
        }
        writeln!(output, "\n---\n")?;
    }

    if !errors.is_empty() {
        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;

        for (err, name, repetition_ix) in &errors {
            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
        }
        writeln!(output, "\n---\n")?;
    }

    // Renders one execution's reasoning as an indented subsection pointing at
    // its `prediction_response.md` in the run directory.
    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
        let exec_content = format!(
            "\n### Execution {} `{}/{}/prediction_response.md`{}",
            exec_id,
            crate::paths::RUN_DIR.display(),
            exec_id,
            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
        );
        indent_text(&exec_content, 2)
    }

    // Indents every line after the first by `spaces` spaces.
    fn indent_text(text: &str, spaces: usize) -> String {
        let indent = " ".repeat(spaces);
        text.lines()
            .collect::<Vec<_>>()
            .join(&format!("\n{}", indent))
    }

    Ok(())
}
777
778fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
779    let err = format!("{err:?}")
780        .replace("<edits", "```xml\n<edits")
781        .replace("</edits>", "</edits>\n```");
782    format!(
783        "### ERROR {name}{}\n\n{err}\n",
784        repetition_ix
785            .map(|ix| format!(" [RUN {ix:03}]"))
786            .unwrap_or_default()
787    )
788}