1use std::{
2 io::{IsTerminal, Write},
3 path::PathBuf,
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::{AsyncApp, Entity};
11use project::Project;
12use util::ResultExt as _;
13use zeta2::{Zeta, udiff::DiffLine};
14
15use crate::{
16 PromptFormat,
17 example::{Example, NamedExample},
18 headless::ZetaCliAppState,
19 paths::print_run_data_dir,
20 predict::{CacheMode, PredictionDetails, zeta2_predict},
21};
22
23#[derive(Debug, Args)]
24pub struct EvaluateArguments {
25 example_paths: Vec<PathBuf>,
26 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
27 prompt_format: PromptFormat,
28 #[arg(long)]
29 use_expected_context: bool,
30 #[clap(long, value_enum, default_value_t = CacheMode::default())]
31 cache: CacheMode,
32 #[clap(short, long, default_value_t = 1, alias = "repeat")]
33 repetitions: u16,
34 #[arg(long)]
35 skip_prediction: bool,
36}
37
/// Entry point for the `evaluate` subcommand: loads every example in
/// `args.example_paths`, runs the requested number of prediction repetitions
/// for each, prints aggregated scores to stdout, and persists them under the
/// run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One top-level task per example file; each fans out into one task per
    // repetition below.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        // NOTE(review): a malformed example path panics here instead of being
        // reported like the per-run errors collected below — confirm intended.
        let example = NamedExample::load(&path).expect("Failed to load example");

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            // `setup_project` returns one Zeta instance per repetition.
            // `repetition_ix` is None for a single run so that output paths
            // and headers stay unnumbered.
            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Tag failures with the example name and repetition so the
                    // aggregated report can attribute them.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Print the aggregate report to stdout and also persist it to the run dir.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
92
93fn write_aggregated_scores(
94 w: &mut impl std::io::Write,
95 all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
96) -> Result<()> {
97 let mut successful = Vec::new();
98 let mut failed_count = 0;
99
100 for result in all_results.iter().flatten() {
101 match result {
102 Ok(eval_result) => successful.push(eval_result),
103 Err((err, name, repetition_ix)) => {
104 if failed_count == 0 {
105 writeln!(w, "## Errors\n")?;
106 }
107
108 failed_count += 1;
109 let err = format!("{err:?}")
110 .replace("<edits", "```xml\n<edits")
111 .replace("</edits>", "</edits>\n```");
112 writeln!(
113 w,
114 "### ERROR {name}{}\n\n{err}\n",
115 repetition_ix
116 .map(|ix| format!(" [RUN {ix:03}]"))
117 .unwrap_or_default()
118 )?;
119 }
120 }
121 }
122
123 if successful.len() > 1 {
124 let mut edit_predictions = successful
125 .iter()
126 .filter_map(|r| r.edit_prediction.as_ref())
127 .peekable();
128 let has_edit_predictions = edit_predictions.peek().is_some();
129 let aggregated_result = EvaluationResult {
130 context: Scores::aggregate(successful.iter().map(|r| &r.context)),
131 edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
132 prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
133 generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
134 / successful.len(),
135 };
136
137 writeln!(w, "\n{}", "-".repeat(80))?;
138 writeln!(w, "\n## TOTAL SCORES")?;
139 writeln!(w, "\n### Success Rate")?;
140 writeln!(w, "{:#}", aggregated_result)?;
141 }
142
143 if successful.len() + failed_count > 1 {
144 writeln!(
145 w,
146 "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
147 successful.len(),
148 successful.len() + failed_count,
149 (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
150 )?;
151 }
152
153 Ok(())
154}
155
156pub async fn run_evaluate_one(
157 example: NamedExample,
158 repetition_ix: Option<u16>,
159 project: Entity<Project>,
160 zeta: Entity<Zeta>,
161 prompt_format: PromptFormat,
162 use_expected_context: bool,
163 predict: bool,
164 cache_mode: CacheMode,
165 cx: &mut AsyncApp,
166) -> Result<EvaluationResult> {
167 let predict_result = zeta2_predict(
168 example.clone(),
169 project,
170 zeta,
171 repetition_ix,
172 prompt_format,
173 use_expected_context,
174 cache_mode,
175 cx,
176 )
177 .await?;
178
179 let evaluation_result = evaluate(&example.example, &predict_result, predict);
180
181 if repetition_ix.is_none() {
182 write_eval_result(
183 &example,
184 &predict_result,
185 &evaluation_result,
186 &mut std::io::stdout(),
187 std::io::stdout().is_terminal(),
188 predict,
189 )?;
190 }
191
192 if let Some(mut results_file) =
193 std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
194 {
195 write_eval_result(
196 &example,
197 &predict_result,
198 &evaluation_result,
199 &mut results_file,
200 false,
201 predict,
202 )
203 .log_err();
204 }
205
206 anyhow::Ok(evaluation_result)
207}
208
209fn write_eval_result(
210 example: &NamedExample,
211 predictions: &PredictionDetails,
212 evaluation_result: &EvaluationResult,
213 out: &mut impl Write,
214 use_color: bool,
215 predict: bool,
216) -> Result<()> {
217 if predict {
218 writeln!(
219 out,
220 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
221 compare_diffs(
222 &example.example.expected_patch,
223 &predictions.diff,
224 use_color
225 )
226 )?;
227 writeln!(
228 out,
229 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
230 compare_diffs(
231 &predictions.diff,
232 &example.example.expected_patch,
233 use_color
234 )
235 )?;
236 }
237
238 writeln!(out, "{:#}", evaluation_result)?;
239
240 anyhow::Ok(())
241}
242
243#[derive(Debug, Default)]
244pub struct EvaluationResult {
245 pub edit_prediction: Option<Scores>,
246 pub context: Scores,
247 pub prompt_len: usize,
248 pub generated_len: usize,
249}
250
251#[derive(Default, Debug)]
252pub struct Scores {
253 pub true_positives: usize,
254 pub false_positives: usize,
255 pub false_negatives: usize,
256}
257
258impl Scores {
259 pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
260 let true_positives = expected.intersection(actual).count();
261 let false_positives = actual.difference(expected).count();
262 let false_negatives = expected.difference(actual).count();
263
264 Scores {
265 true_positives,
266 false_positives,
267 false_negatives,
268 }
269 }
270
271 pub fn to_markdown(&self) -> String {
272 format!(
273 "
274Precision : {:.4}
275Recall : {:.4}
276F1 Score : {:.4}
277True Positives : {}
278False Positives : {}
279False Negatives : {}",
280 self.precision(),
281 self.recall(),
282 self.f1_score(),
283 self.true_positives,
284 self.false_positives,
285 self.false_negatives
286 )
287 }
288
289 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
290 let mut true_positives = 0;
291 let mut false_positives = 0;
292 let mut false_negatives = 0;
293
294 for score in scores {
295 true_positives += score.true_positives;
296 false_positives += score.false_positives;
297 false_negatives += score.false_negatives;
298 }
299
300 Scores {
301 true_positives,
302 false_positives,
303 false_negatives,
304 }
305 }
306
307 pub fn precision(&self) -> f64 {
308 if self.true_positives + self.false_positives == 0 {
309 0.0
310 } else {
311 self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
312 }
313 }
314
315 pub fn recall(&self) -> f64 {
316 if self.true_positives + self.false_negatives == 0 {
317 0.0
318 } else {
319 self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
320 }
321 }
322
323 pub fn f1_score(&self) -> f64 {
324 let recall = self.recall();
325 let precision = self.precision();
326 if precision + recall == 0.0 {
327 0.0
328 } else {
329 2.0 * precision * recall / (precision + recall)
330 }
331 }
332}
333
334impl std::fmt::Display for EvaluationResult {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 if f.alternate() {
337 self.fmt_table(f)
338 } else {
339 self.fmt_markdown(f)
340 }
341 }
342}
343
344impl EvaluationResult {
345 fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 write!(
347 f,
348 r#"
349### Context Scores
350{}
351"#,
352 self.context.to_markdown(),
353 )?;
354 if let Some(prediction) = &self.edit_prediction {
355 write!(
356 f,
357 r#"
358 ### Edit Prediction Scores
359 {}"#,
360 prediction.to_markdown()
361 )?;
362 }
363 Ok(())
364 }
365
366 fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 writeln!(f, "### Scores\n")?;
368 writeln!(
369 f,
370 " Prompt Generated TP FP FN Precision Recall F1"
371 )?;
372 writeln!(
373 f,
374 "────────────────────────────────────────────────────────────────────────────────────"
375 )?;
376 writeln!(
377 f,
378 "Context Retrieval {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
379 "",
380 "",
381 self.context.true_positives,
382 self.context.false_positives,
383 self.context.false_negatives,
384 self.context.precision() * 100.0,
385 self.context.recall() * 100.0,
386 self.context.f1_score() * 100.0
387 )?;
388 if let Some(edit_prediction) = &self.edit_prediction {
389 writeln!(
390 f,
391 "Edit Prediction {:<7} {:<10} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
392 self.prompt_len,
393 self.generated_len,
394 edit_prediction.true_positives,
395 edit_prediction.false_positives,
396 edit_prediction.false_negatives,
397 edit_prediction.precision() * 100.0,
398 edit_prediction.recall() * 100.0,
399 edit_prediction.f1_score() * 100.0
400 )?;
401 }
402 Ok(())
403 }
404}
405
406pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
407 let mut eval_result = EvaluationResult {
408 prompt_len: preds.prompt_len,
409 generated_len: preds.generated_len,
410 ..Default::default()
411 };
412
413 let actual_context_lines: HashSet<_> = preds
414 .excerpts
415 .iter()
416 .flat_map(|excerpt| {
417 excerpt
418 .text
419 .lines()
420 .map(|line| format!("{}: {line}", excerpt.path.display()))
421 })
422 .collect();
423
424 let mut false_positive_lines = actual_context_lines.clone();
425
426 for entry in &example.expected_context {
427 let mut best_alternative_score: Option<Scores> = None;
428
429 for alternative in &entry.alternatives {
430 let expected: HashSet<_> = alternative
431 .excerpts
432 .iter()
433 .flat_map(|excerpt| {
434 excerpt
435 .text
436 .lines()
437 .map(|line| format!("{}: {line}", excerpt.path.display()))
438 })
439 .collect();
440
441 let scores = Scores::new(&expected, &actual_context_lines);
442
443 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
444
445 if best_alternative_score
446 .as_ref()
447 .is_none_or(|best| scores.recall() > best.recall())
448 {
449 best_alternative_score = Some(scores);
450 }
451 }
452
453 let best_alternative = best_alternative_score.unwrap_or_default();
454 eval_result.context.false_negatives += best_alternative.false_negatives;
455 eval_result.context.true_positives += best_alternative.true_positives;
456 }
457
458 eval_result.context.false_positives = false_positive_lines.len();
459
460 if predict {
461 // todo: alternatives for patches
462 let expected_patch_lines = example
463 .expected_patch
464 .lines()
465 .map(DiffLine::parse)
466 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
467 .map(|line| line.to_string())
468 .collect();
469
470 let actual_patch_lines = preds
471 .diff
472 .lines()
473 .map(DiffLine::parse)
474 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
475 .map(|line| line.to_string())
476 .collect();
477
478 eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
479 }
480
481 eval_result
482}
483
484/// Return annotated `patch_a` so that:
485/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
486/// Additions and deletions that are present in `patch_b` will be highlighted in green.
487pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
488 let green = if use_color { "\x1b[32m✓ " } else { "" };
489 let red = if use_color { "\x1b[31m✗ " } else { "" };
490 let neutral = if use_color { " " } else { "" };
491 let reset = if use_color { "\x1b[0m" } else { "" };
492 let lines_a = patch_a.lines().map(DiffLine::parse);
493 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
494
495 let annotated = lines_a
496 .map(|line| match line {
497 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
498 if lines_b.contains(&line) {
499 format!("{green}{line}{reset}")
500 } else {
501 format!("{red}{line}{reset}")
502 }
503 }
504 _ => format!("{neutral}{line}{reset}"),
505 })
506 .collect::<Vec<String>>();
507
508 annotated.join("\n")
509}