evaluate.rs

use std::{
    collections::HashMap,
    io::{IsTerminal, Write},
    sync::Arc,
};

use anyhow::Result;
use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use zeta::{Zeta, udiff::DiffLine};

use crate::{
    EvaluateArguments, PredictionOptions,
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    paths::print_run_data_dir,
    predict::{PredictionDetails, perform_predict, setup_zeta},
};

/// Artifacts captured from a single prediction run: an identifier for the
/// run, the predicted diff, and the model's reasoning output.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    execution_id: String,
    diff: String,
    reasoning: String,
}

pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    let repetition_ix = (args.repetitions > 1).then_some(repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    }

    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}

fn write_aggregated_scores(
    w: &mut impl std::io::Write,
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    let mut successful = Vec::new();
    let mut failed_count = 0;

    for result in all_results.iter().flatten() {
        match result {
            Ok((eval_result, _execution_data)) => successful.push(eval_result),
            Err((err, name, repetition_ix)) => {
                if failed_count == 0 {
                    writeln!(w, "## Errors\n")?;
                }

                failed_count += 1;
                writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
            }
        }
    }

    if successful.len() > 1 {
        let mut edit_predictions = successful
            .iter()
            .filter_map(|r| r.edit_prediction.as_ref())
            .peekable();
        let has_edit_predictions = edit_predictions.peek().is_some();
        let aggregated_result = EvaluationResult {
            edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
            prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
            generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
                / successful.len(),
        };

        writeln!(w, "\n{}", "-".repeat(80))?;
        writeln!(w, "\n## TOTAL SCORES")?;
        writeln!(w, "{:#}", aggregated_result)?;
    }

    if successful.len() + failed_count > 1 {
        writeln!(
            w,
            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
            successful.len(),
            successful.len() + failed_count,
            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
        )?;
    }

    Ok(())
}
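
// A minimal sketch of an aggregation test (the module name and fixture
// values are illustrative, not part of the original file): two successful
// runs should produce the "TOTAL SCORES" section and the success tally.
#[cfg(test)]
mod write_aggregated_scores_tests {
    use super::*;

    // The per-run result type, matching `write_aggregated_scores`' input.
    type RunResult =
        Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>;

    fn ok_result() -> RunResult {
        Ok((
            EvaluationResult {
                edit_prediction: Some(Scores {
                    true_positives: 1,
                    false_positives: 0,
                    false_negatives: 0,
                }),
                prompt_len: 100,
                generated_len: 10,
            },
            ExecutionData {
                execution_id: "000".into(),
                diff: "+line".into(),
                reasoning: String::new(),
            },
        ))
    }

    #[test]
    fn totals_are_printed_for_multiple_runs() {
        let results = vec![vec![ok_result(), ok_result()]];
        let mut buf = Vec::new();
        write_aggregated_scores(&mut buf, &results).unwrap();
        let text = String::from_utf8(buf).unwrap();
        assert!(text.contains("## TOTAL SCORES"));
        assert!(text.contains("2/2"));
    }
}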

pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prediction_options: PredictionOptions,
    predict: bool,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = perform_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prediction_options,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}

fn write_eval_result(
    example: &NamedExample,
    predictions: &PredictionDetails,
    evaluation_result: &EvaluationResult,
    out: &mut impl Write,
    use_color: bool,
    predict: bool,
) -> Result<()> {
    if predict {
        writeln!(
            out,
            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
            compare_diffs(
                &example.example.expected_patch,
                &predictions.diff,
                use_color
            )
        )?;
        writeln!(
            out,
            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
            compare_diffs(
                &predictions.diff,
                &example.example.expected_patch,
                use_color
            )
        )?;
    }

    writeln!(out, "{:#}", evaluation_result)?;

    anyhow::Ok(())
}

/// Evaluation scores for a single run, plus prompt and output sizes.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub edit_prediction: Option<Scores>,
    pub prompt_len: usize,
    pub generated_len: usize,
}

/// Line-level confusion counts comparing predicted and expected diff lines.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

    pub fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }

    pub fn f1_score(&self) -> f64 {
        let recall = self.recall();
        let precision = self.precision();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}
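
// A quick sanity check of the arithmetic above (a sketch; the module name
// and the counts are illustrative). With 3 true positives and one each of
// false positives and negatives, precision = recall = 3/4, and so F1 = 3/4.
#[cfg(test)]
mod scores_tests {
    use super::*;

    #[test]
    fn precision_recall_and_f1() {
        let scores = Scores {
            true_positives: 3,
            false_positives: 1,
            false_negatives: 1,
        };
        assert_eq!(scores.precision(), 0.75);
        assert_eq!(scores.recall(), 0.75);
        assert_eq!(scores.f1_score(), 0.75);

        // Degenerate case: no predictions and no expected lines yield zeros
        // rather than dividing by zero.
        let empty = Scores::default();
        assert_eq!(empty.precision(), 0.0);
        assert_eq!(empty.recall(), 0.0);
        assert_eq!(empty.f1_score(), 0.0);
    }
}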

impl std::fmt::Display for EvaluationResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if f.alternate() {
            self.fmt_table(f)
        } else {
            self.fmt_markdown(f)
        }
    }
}

impl EvaluationResult {
    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(prediction) = &self.edit_prediction {
            // Keep the heading at column zero so it renders as a markdown
            // heading rather than an indented code block.
            write!(
                f,
                "\n### Edit Prediction Scores\n{}",
                prediction.to_markdown()
            )?;
        }
        Ok(())
    }

    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "### Scores\n")?;
        writeln!(
            f,
            "                   Prompt  Generated  TP     FP     FN     Precision   Recall      F1"
        )?;
        writeln!(
            f,
            "───────────────────────────────────────────────────────────────────────────────────────────────"
        )?;
        if let Some(edit_prediction) = &self.edit_prediction {
            writeln!(
                f,
                "Edit Prediction    {:<7} {:<9}  {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
                self.prompt_len,
                self.generated_len,
                edit_prediction.true_positives,
                edit_prediction.false_positives,
                edit_prediction.false_negatives,
                edit_prediction.precision() * 100.0,
                edit_prediction.recall() * 100.0,
                edit_prediction.f1_score() * 100.0
            )?;
        }
        Ok(())
    }
}

fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
    let mut eval_result = EvaluationResult {
        prompt_len: preds.prompt_len,
        generated_len: preds.generated_len,
        ..Default::default()
    };

    if predict {
        // todo: alternatives for patches
        let expected_patch = example
            .expected_patch
            .lines()
            .map(DiffLine::parse)
            .collect::<Vec<_>>();
        let expected_patch_lines = expected_patch
            .iter()
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();

        let actual_patch_lines = preds
            .diff
            .lines()
            .map(DiffLine::parse)
            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
            .map(|line| line.to_string())
            .collect();

        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
    }

    eval_result
}

/// Returns `patch_a` annotated so that additions and deletions that also
/// appear in `patch_b` are highlighted in green, while those missing from
/// `patch_b` are highlighted in red.
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}

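// A sketch of the annotation behavior (fixture diffs and the module name are
// illustrative, and this assumes `DiffLine`'s `Display` impl reproduces the
// original line text): shared edits get the green marker, unshared ones the
// red marker, and context lines stay neutral.
#[cfg(test)]
mod compare_diffs_tests {
    use super::*;

    #[test]
    fn marks_shared_and_unshared_edits() {
        let actual = "+shared\n+only_in_actual\n context";
        let expected = "+shared\n+only_in_expected\n context";
        let annotated = compare_diffs(actual, expected, true);
        let lines: Vec<&str> = annotated.lines().collect();
        assert!(lines[0].starts_with("\x1b[32m")); // green: present in both
        assert!(lines[1].starts_with("\x1b[31m")); // red: missing from `expected`
        assert!(lines[2].starts_with("  ")); // context lines are left neutral
    }
}
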
fn write_bucketed_analysis(
    all_results: &Vec<
        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
    >,
) -> Result<()> {
    #[derive(Debug)]
    struct EditBucket {
        diff: String,
        is_correct: bool,
        execution_indices: Vec<String>,
        reasoning_samples: Vec<String>,
    }

    let mut total_executions = 0;
    let mut empty_predictions = Vec::new();
    let mut errors = Vec::new();

    let mut buckets: HashMap<String, EditBucket> = HashMap::new();

    for result in all_results.iter().flatten() {
        total_executions += 1;

        let (evaluation_result, execution_data) = match result {
            Ok((eval_result, execution_data)) => {
                if execution_data.diff.is_empty() {
                    empty_predictions.push(execution_data);
                    continue;
                }
                (eval_result, execution_data)
            }
            Err(err) => {
                errors.push(err);
                continue;
            }
        };

        buckets
            .entry(execution_data.diff.clone())
            .and_modify(|bucket| {
                bucket
                    .execution_indices
                    .push(execution_data.execution_id.clone());
                bucket
                    .reasoning_samples
                    .push(execution_data.reasoning.clone());
            })
            .or_insert_with(|| EditBucket {
                diff: execution_data.diff.clone(),
                is_correct: evaluation_result
                    .edit_prediction
                    .as_ref()
                    .map_or(false, |edit_prediction| {
                        edit_prediction.false_positives == 0
                            && edit_prediction.false_negatives == 0
                            && edit_prediction.true_positives > 0
                    }),
                execution_indices: vec![execution_data.execution_id.clone()],
                reasoning_samples: vec![execution_data.reasoning.clone()],
            });
    }

    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
        (true, false) => std::cmp::Ordering::Less,
        (false, true) => std::cmp::Ordering::Greater,
        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
    });

    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
    let mut output = std::fs::File::create(&output_path)?;

    writeln!(output, "# Bucketed Edit Analysis\n")?;

    writeln!(output, "## Summary\n")?;
    writeln!(output, "- **Total executions**: {}", total_executions)?;

    let correct_count: usize = sorted_buckets
        .iter()
        .filter(|b| b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    let incorrect_count: usize = sorted_buckets
        .iter()
        .filter(|b| !b.is_correct)
        .map(|b| b.execution_indices.len())
        .sum();

    writeln!(
        output,
        "- **Correct predictions**: {} ({:.1}%)",
        correct_count,
        (correct_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **Incorrect predictions**: {} ({:.1}%)",
        incorrect_count,
        (incorrect_count as f64 / total_executions as f64) * 100.0
    )?;

    writeln!(
        output,
        "- **No Predictions**: {} ({:.1}%)",
        empty_predictions.len(),
        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
    )?;

    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
    writeln!(
        output,
        "- **Unique incorrect edit patterns**: {}\n",
        unique_incorrect
    )?;

    writeln!(output, "---\n")?;

    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
        if idx == 0 {
            // Report the total across all correct buckets, not just the first.
            writeln!(
                output,
                "## Correct Predictions ({} occurrences)\n",
                correct_count
            )?;
        }

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;
        writeln!(output, "---\n")?;
    }

    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
        writeln!(
            output,
            "## Incorrect Prediction #{} ({} occurrences)\n",
            idx + 1,
            bucket.execution_indices.len()
        )?;

        writeln!(output, "**Predicted Edit:**\n")?;
        writeln!(output, "```diff")?;
        writeln!(output, "{}", bucket.diff)?;
        writeln!(output, "```\n")?;

        writeln!(
            output,
            "**Executions:** {}\n",
            bucket.execution_indices.join(", ")
        )?;

        for (exec_id, reasoning) in bucket
            .execution_indices
            .iter()
            .zip(bucket.reasoning_samples.iter())
        {
            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
        }

        writeln!(output, "\n---\n")?;
    }

    if !empty_predictions.is_empty() {
        writeln!(
            output,
            "## No Predictions ({} occurrences)\n",
            empty_predictions.len()
        )?;

        for execution_data in &empty_predictions {
            writeln!(
                output,
                "{}",
                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
            )?;
        }
        writeln!(output, "\n---\n")?;
    }

    if !errors.is_empty() {
        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;

        for (err, name, repetition_ix) in &errors {
            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
        }
        writeln!(output, "\n---\n")?;
    }

    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
        let exec_content = format!(
            "\n### Execution {} `{}/{}/prediction_response.md`{}",
            exec_id,
            crate::paths::RUN_DIR.display(),
            exec_id,
            indent_text(&format!("\n\n```\n{}\n```\n", reasoning), 2)
        );
        indent_text(&exec_content, 2)
    }

    // Indents every line after the first by `spaces`; the first line is left
    // as-is, which suits callers whose text begins with a newline.
    fn indent_text(text: &str, spaces: usize) -> String {
        let indent = " ".repeat(spaces);
        text.lines()
            .collect::<Vec<_>>()
            .join(&format!("\n{}", indent))
    }
668
669    Ok(())
670}
671
672fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
673    let err = format!("{err:?}")
674        .replace("<edits", "```xml\n<edits")
675        .replace("</edits>", "</edits>\n```");
676    format!(
677        "### ERROR {name}{}\n\n{err}\n",
678        repetition_ix
679            .map(|ix| format!(" [RUN {ix:03}]"))
680            .unwrap_or_default()
681    )
682}