// evaluate.rs

  1use crate::metrics::{self, Scores};
  2use std::{
  3    collections::HashMap,
  4    io::{IsTerminal, Write},
  5    sync::Arc,
  6};
  7
  8use anyhow::Result;
  9use gpui::{AsyncApp, Entity};
 10use project::Project;
 11use util::ResultExt as _;
 12use zeta::{Zeta, udiff::DiffLine};
 13
 14use crate::{
 15    EvaluateArguments, PredictionOptions,
 16    example::{Example, NamedExample},
 17    headless::ZetaCliAppState,
 18    paths::print_run_data_dir,
 19    predict::{PredictionDetails, perform_predict, setup_zeta},
 20};
 21
/// Per-execution data captured for the bucketed analysis report: which run
/// produced which predicted diff, plus the model's reasoning text for it.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    // Zero-padded repetition index when repeating, otherwise the example name
    // (see how it is built in `run_evaluate_one`).
    execution_id: String,
    // The predicted diff text; empty when no prediction was produced.
    diff: String,
    // Contents of this run's `prediction_response.md`; empty if unreadable.
    reasoning: String,
}
 28
/// Entry point for the `evaluate` command: runs every example (optionally
/// repeated `args.repetitions` times), prints aggregated scores, and writes
/// result files into the run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }

    // One outer task per example; each fans out into one inner task per repetition.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let options = args.options.clone();
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let project = example.setup_project(&app_state, cx).await.unwrap();

            // A separate Zeta provider per repetition, so repeated runs don't
            // share prediction state.
            let providers = (0..args.repetitions)
                .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
                .collect::<Vec<_>>();

            // Replay the example's edit history before predicting.
            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();

            let tasks = providers
                .into_iter()
                .enumerate()
                .map(move |(repetition_ix, zeta)| {
                    // The repetition index is only reported when actually repeating.
                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                    let example = example.clone();
                    let project = project.clone();
                    let options = options.clone();

                    cx.spawn(async move |cx| {
                        let name = example.name.clone();
                        run_evaluate_one(
                            example,
                            repetition_ix,
                            project,
                            zeta,
                            options,
                            !args.skip_prediction,
                            cx,
                        )
                        .await
                        // Attach example name + repetition so failures are attributable.
                        .map_err(|err| (err, name, repetition_ix))
                    })
                });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Aggregated scores go to stdout, and best-effort to a markdown file.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };

    // Bucketing is only meaningful when each example ran more than once.
    if args.repetitions > 1 {
        if let Err(e) = write_bucketed_analysis(&all_results) {
            eprintln!("Failed to write bucketed analysis: {:?}", e);
        }
    }

    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
 97
 98fn write_aggregated_scores(
 99    w: &mut impl std::io::Write,
100    all_results: &Vec<
101        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
102    >,
103) -> Result<()> {
104    let mut successful = Vec::new();
105    let mut failed_count = 0;
106
107    for result in all_results.iter().flatten() {
108        match result {
109            Ok((eval_result, _execution_data)) => successful.push(eval_result),
110            Err((err, name, repetition_ix)) => {
111                if failed_count == 0 {
112                    writeln!(w, "## Errors\n")?;
113                }
114
115                failed_count += 1;
116                writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
117            }
118        }
119    }
120
121    if successful.len() > 1 {
122        let edit_scores = successful
123            .iter()
124            .filter_map(|r| r.edit_scores.clone())
125            .collect::<Vec<_>>();
126        let has_edit_predictions = edit_scores.len() > 0;
127        let aggregated_result = EvaluationResult {
128            context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
129            edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
130            prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
131            generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
132                / successful.len(),
133        };
134
135        writeln!(w, "\n{}", "-".repeat(80))?;
136        writeln!(w, "\n## TOTAL SCORES")?;
137        writeln!(w, "{:#}", aggregated_result)?;
138    }
139
140    if successful.len() + failed_count > 1 {
141        writeln!(
142            w,
143            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
144            successful.len(),
145            successful.len() + failed_count,
146            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
147        )?;
148    }
149
150    Ok(())
151}
152
/// Runs prediction and evaluation for a single example (one repetition).
///
/// Prints results to stdout only for single runs (`repetition_ix == None`);
/// always tries to write `results.md` into the run's example directory.
/// Returns the evaluation scores plus the raw execution data used by the
/// bucketed analysis.
pub async fn run_evaluate_one(
    example: NamedExample,
    repetition_ix: Option<u16>,
    project: Entity<Project>,
    zeta: Entity<Zeta>,
    prediction_options: PredictionOptions,
    predict: bool,
    cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
    let predict_result = perform_predict(
        example.clone(),
        project,
        zeta,
        repetition_ix,
        prediction_options,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predict_result, predict);

    // For repeated runs, per-run stdout output is suppressed; aggregated
    // output is produced by the caller instead.
    if repetition_ix.is_none() {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut std::io::stdout(),
            std::io::stdout().is_terminal(),
            predict,
        )?;
    }

    // Best-effort: failures to create or write the results file are logged,
    // not propagated.
    if let Some(mut results_file) =
        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
    {
        write_eval_result(
            &example,
            &predict_result,
            &evaluation_result,
            &mut results_file,
            false,
            predict,
        )
        .log_err();
    }

    let execution_data = ExecutionData {
        // Zero-padded repetition index when repeating, else the example name.
        execution_id: if let Some(rep_ix) = repetition_ix {
            format!("{:03}", rep_ix)
        } else {
            example.name.clone()
        },
        diff: predict_result.diff.clone(),
        // Reasoning is optional; an unreadable file yields an empty string.
        reasoning: std::fs::read_to_string(
            predict_result
                .run_example_dir
                .join("prediction_response.md"),
        )
        .unwrap_or_default(),
    };

    anyhow::Ok((evaluation_result, execution_data))
}
216
217fn write_eval_result(
218    example: &NamedExample,
219    predictions: &PredictionDetails,
220    evaluation_result: &EvaluationResult,
221    out: &mut impl Write,
222    use_color: bool,
223    predict: bool,
224) -> Result<()> {
225    if predict {
226        writeln!(
227            out,
228            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
229            compare_diffs(
230                &example.example.expected_patch,
231                &predictions.diff,
232                use_color
233            )
234        )?;
235        writeln!(
236            out,
237            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
238            compare_diffs(
239                &predictions.diff,
240                &example.example.expected_patch,
241                use_color
242            )
243        )?;
244    }
245
246    writeln!(out, "{:#}", evaluation_result)?;
247
248    anyhow::Ok(())
249}
250
/// Scores comparing a predicted diff against the expected patch.
#[derive(Debug, Default, Clone)]
pub struct EditScores {
    // Exact per-line match precision/recall between the two diffs.
    pub line_match: Scores,
    // Character-level F-score over the diff deltas (see `metrics::delta_chr_f`).
    pub chr_f: f64,
}
256
257impl EditScores {
258    pub fn aggregate(scores: &[EditScores]) -> EditScores {
259        let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
260        let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
261
262        EditScores { line_match, chr_f }
263    }
264}
265
/// Scores for one evaluation run, or an aggregate of several runs.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    // `None` when edit prediction was skipped for this run.
    pub edit_scores: Option<EditScores>,
    // Precision/recall over retrieved context.
    pub context_scores: Scores,
    // Prompt / generation lengths from the prediction (mean when aggregated).
    pub prompt_len: usize,
    pub generated_len: usize,
}
273
274impl std::fmt::Display for EvaluationResult {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        if f.alternate() {
277            self.fmt_table(f)
278        } else {
279            self.fmt_markdown(f)
280        }
281    }
282}
283
284impl EvaluationResult {
285    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        write!(
287            f,
288            r#"
289### Context Scores
290{}
291"#,
292            self.context_scores.to_markdown(),
293        )?;
294        if let Some(scores) = &self.edit_scores {
295            write!(
296                f,
297                r#"
298                ### Edit Prediction Scores
299                {}"#,
300                scores.line_match.to_markdown()
301            )?;
302        }
303        Ok(())
304    }
305
306    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        writeln!(f, "#### Prompt Statistics")?;
308        writeln!(f, "─────────────────────────")?;
309        writeln!(f, "Prompt_len  Generated_len")?;
310        writeln!(f, "─────────────────────────")?;
311        writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
312        writeln!(f)?;
313        writeln!(f)?;
314        writeln!(f, "#### Performance Scores")?;
315        writeln!(
316            f,
317            "──────────────────────────────────────────────────────────────────"
318        )?;
319        writeln!(
320            f,
321            "                   TP     FP     FN     Precision   Recall     F1"
322        )?;
323        writeln!(
324            f,
325            "──────────────────────────────────────────────────────────────────"
326        )?;
327        writeln!(
328            f,
329            "Context Retrieval  {:<6} {:<6} {:<6} {:>8.2}  {:>7.2}  {:>6.2}",
330            self.context_scores.true_positives,
331            self.context_scores.false_positives,
332            self.context_scores.false_negatives,
333            self.context_scores.precision() * 100.0,
334            self.context_scores.recall() * 100.0,
335            self.context_scores.f1_score() * 100.0
336        )?;
337        if let Some(edit_scores) = &self.edit_scores {
338            let line_match = &edit_scores.line_match;
339            writeln!(f, "Edit Prediction")?;
340            writeln!(
341                f,
342                "  ├─ exact lines   {:<6} {:<6} {:<6} {:>8.2}  {:>7.2}  {:>6.2}",
343                line_match.true_positives,
344                line_match.false_positives,
345                line_match.false_negatives,
346                line_match.precision() * 100.0,
347                line_match.recall() * 100.0,
348                line_match.f1_score() * 100.0
349            )?;
350            writeln!(
351                f,
352                "  └─ diff chrF     {:<6} {:<6} {:<6} {:>8} {:>8}  {:>6.2}",
353                "-", "-", "-", "-", "-", edit_scores.chr_f
354            )?;
355        }
356        Ok(())
357    }
358}
359
360fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
361    let mut eval_result = EvaluationResult {
362        prompt_len: preds.prompt_len,
363        generated_len: preds.generated_len,
364        ..Default::default()
365    };
366
367    if predict {
368        // todo: alternatives for patches
369        let expected_patch = example
370            .expected_patch
371            .lines()
372            .map(DiffLine::parse)
373            .collect::<Vec<_>>();
374        let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
375
376        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
377        let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
378
379        eval_result.edit_scores = Some(EditScores { line_match, chr_f });
380    }
381
382    eval_result
383}
384
385/// Return annotated `patch_a` so that:
386/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
387/// Additions and deletions that are present in `patch_b` will be highlighted in green.
388pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
389    let green = if use_color { "\x1b[32m✓ " } else { "" };
390    let red = if use_color { "\x1b[31m✗ " } else { "" };
391    let neutral = if use_color { "  " } else { "" };
392    let reset = if use_color { "\x1b[0m" } else { "" };
393    let lines_a = patch_a.lines().map(DiffLine::parse);
394    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
395
396    let annotated = lines_a
397        .map(|line| match line {
398            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
399                if lines_b.contains(&line) {
400                    format!("{green}{line}{reset}")
401                } else {
402                    format!("{red}{line}{reset}")
403                }
404            }
405            _ => format!("{neutral}{line}{reset}"),
406        })
407        .collect::<Vec<String>>();
408
409    annotated.join("\n")
410}
411
412fn write_bucketed_analysis(
413    all_results: &Vec<
414        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
415    >,
416) -> Result<()> {
417    #[derive(Debug)]
418    struct EditBucket {
419        diff: String,
420        is_correct: bool,
421        execution_indices: Vec<String>,
422        reasoning_samples: Vec<String>,
423    }
424
425    let mut total_executions = 0;
426    let mut empty_predictions = Vec::new();
427    let mut errors = Vec::new();
428
429    let mut buckets: HashMap<String, EditBucket> = HashMap::new();
430
431    for result in all_results.iter().flatten() {
432        total_executions += 1;
433
434        let (evaluation_result, execution_data) = match result {
435            Ok((eval_result, execution_data)) => {
436                if execution_data.diff.is_empty() {
437                    empty_predictions.push(execution_data);
438                    continue;
439                }
440                (eval_result, execution_data)
441            }
442            Err(err) => {
443                errors.push(err);
444                continue;
445            }
446        };
447
448        buckets
449            .entry(execution_data.diff.clone())
450            .and_modify(|bucket| {
451                bucket
452                    .execution_indices
453                    .push(execution_data.execution_id.clone());
454                bucket
455                    .reasoning_samples
456                    .push(execution_data.reasoning.clone());
457            })
458            .or_insert_with(|| EditBucket {
459                diff: execution_data.diff.clone(),
460                is_correct: {
461                    evaluation_result
462                        .edit_scores
463                        .as_ref()
464                        .map_or(false, |edit_scores| {
465                            edit_scores.line_match.false_positives == 0
466                                && edit_scores.line_match.false_negatives == 0
467                                && edit_scores.line_match.true_positives > 0
468                        })
469                },
470                execution_indices: vec![execution_data.execution_id.clone()],
471                reasoning_samples: vec![execution_data.reasoning.clone()],
472            });
473    }
474
475    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
476    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
477        (true, false) => std::cmp::Ordering::Less,
478        (false, true) => std::cmp::Ordering::Greater,
479        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
480    });
481
482    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
483    let mut output = std::fs::File::create(&output_path)?;
484
485    writeln!(output, "# Bucketed Edit Analysis\n")?;
486
487    writeln!(output, "## Summary\n")?;
488    writeln!(output, "- **Total executions**: {}", total_executions)?;
489
490    let correct_count: usize = sorted_buckets
491        .iter()
492        .filter(|b| b.is_correct)
493        .map(|b| b.execution_indices.len())
494        .sum();
495
496    let incorrect_count: usize = sorted_buckets
497        .iter()
498        .filter(|b| !b.is_correct)
499        .map(|b| b.execution_indices.len())
500        .sum();
501
502    writeln!(
503        output,
504        "- **Correct predictions**: {} ({:.1}%)",
505        correct_count,
506        (correct_count as f64 / total_executions as f64) * 100.0
507    )?;
508
509    writeln!(
510        output,
511        "- **Incorrect predictions**: {} ({:.1}%)",
512        incorrect_count,
513        (incorrect_count as f64 / total_executions as f64) * 100.0
514    )?;
515
516    writeln!(
517        output,
518        "- **No Predictions**: {} ({:.1}%)",
519        empty_predictions.len(),
520        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
521    )?;
522
523    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
524    writeln!(
525        output,
526        "- **Unique incorrect edit patterns**: {}\n",
527        unique_incorrect
528    )?;
529
530    writeln!(output, "---\n")?;
531
532    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
533        if idx == 0 {
534            writeln!(
535                output,
536                "## Correct Predictions ({} occurrences)\n",
537                bucket.execution_indices.len()
538            )?;
539        }
540
541        writeln!(output, "**Predicted Edit:**\n")?;
542        writeln!(output, "```diff")?;
543        writeln!(output, "{}", bucket.diff)?;
544        writeln!(output, "```\n")?;
545
546        writeln!(
547            output,
548            "**Executions:** {}\n",
549            bucket.execution_indices.join(", ")
550        )?;
551        writeln!(output, "---\n")?;
552    }
553
554    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
555        writeln!(
556            output,
557            "## Incorrect Prediction #{} ({} occurrences)\n",
558            idx + 1,
559            bucket.execution_indices.len()
560        )?;
561
562        writeln!(output, "**Predicted Edit:**\n")?;
563        writeln!(output, "```diff")?;
564        writeln!(output, "{}", bucket.diff)?;
565        writeln!(output, "```\n")?;
566
567        writeln!(
568            output,
569            "**Executions:** {}\n",
570            bucket.execution_indices.join(", ")
571        )?;
572
573        for (exec_id, reasoning) in bucket
574            .execution_indices
575            .iter()
576            .zip(bucket.reasoning_samples.iter())
577        {
578            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
579        }
580
581        writeln!(output, "\n---\n")?;
582    }
583
584    if !empty_predictions.is_empty() {
585        writeln!(
586            output,
587            "## No Predictions ({} occurrences)\n",
588            empty_predictions.len()
589        )?;
590
591        for execution_data in &empty_predictions {
592            writeln!(
593                output,
594                "{}",
595                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
596            )?;
597        }
598        writeln!(output, "\n---\n")?;
599    }
600
601    if !errors.is_empty() {
602        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
603
604        for (err, name, repetition_ix) in &errors {
605            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
606        }
607        writeln!(output, "\n---\n")?;
608    }
609
610    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
611        let exec_content = format!(
612            "\n### Execution {} `{}/{}/prediction_response.md`{}",
613            exec_id,
614            crate::paths::RUN_DIR.display(),
615            exec_id,
616            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
617        );
618        indent_text(&exec_content, 2)
619    }
620
621    fn indent_text(text: &str, spaces: usize) -> String {
622        let indent = " ".repeat(spaces);
623        text.lines()
624            .collect::<Vec<_>>()
625            .join(&format!("\n{}", indent))
626    }
627
628    Ok(())
629}
630
631fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
632    let err = format!("{err:?}")
633        .replace("<edits", "```xml\n<edits")
634        .replace("</edits>", "</edits>\n```");
635    format!(
636        "### ERROR {name}{}\n\n{err}\n",
637        repetition_ix
638            .map(|ix| format!(" [RUN {ix:03}]"))
639            .unwrap_or_default()
640    )
641}