evaluate.rs

  1use std::{
  2    io::{IsTerminal, Write},
  3    path::PathBuf,
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use clap::Args;
  9use collections::HashSet;
 10use gpui::{AsyncApp, Entity};
 11use project::Project;
 12use util::ResultExt as _;
 13use zeta2::{Zeta, udiff::DiffLine};
 14
 15use crate::{
 16    PromptFormat,
 17    example::{Example, NamedExample},
 18    headless::ZetaCliAppState,
 19    paths::print_run_data_dir,
 20    predict::{CacheMode, PredictionDetails, zeta2_predict},
 21};
 22
 23#[derive(Debug, Args)]
 24pub struct EvaluateArguments {
 25    example_paths: Vec<PathBuf>,
 26    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 27    prompt_format: PromptFormat,
 28    #[arg(long)]
 29    use_expected_context: bool,
 30    #[clap(long, value_enum, default_value_t = CacheMode::default())]
 31    cache: CacheMode,
 32    #[clap(short, long, default_value_t = 1, alias = "repeat")]
 33    repetitions: u16,
 34}
 35
/// Evaluates every example in `args.example_paths`, running each one
/// `args.repetitions` times, then prints aggregated scores to stdout and
/// writes them to `aggregated_results.md` in the run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        // NOTE(review): panics on a malformed or missing example file, and the
        // panic message won't say which path failed — consider adding context.
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            // One Zeta entity per repetition; each repetition runs as its own
            // spawned task against the shared project.
            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Tag results with a run index only when repeating, so single
                // runs print without a "[RUN nnn]" suffix.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);

                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        args.cache,
                        cx,
                    )
                    .await
                    // Attach the example name and run index so failures can be
                    // reported per-run in the aggregated output.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Report to stdout and, best-effort, to a markdown file in the run dir.
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1);
}
 90
 91fn write_aggregated_scores(
 92    w: &mut impl std::io::Write,
 93    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
 94) -> Result<()> {
 95    let mut successful = Vec::new();
 96    let mut failed_count = 0;
 97
 98    for result in all_results.iter().flatten() {
 99        match result {
100            Ok(eval_result) => successful.push(eval_result),
101            Err((err, name, repetition_ix)) => {
102                if failed_count == 0 {
103                    writeln!(w, "## Errors\n")?;
104                }
105
106                failed_count += 1;
107                let err = err
108                    .to_string()
109                    .replace("<edits", "```xml\n<edits")
110                    .replace("</edits>", "</edits>\n```");
111                writeln!(
112                    w,
113                    "### ERROR {name}{}\n\n{err}\n",
114                    repetition_ix
115                        .map(|ix| format!(" [RUN {ix:03}]"))
116                        .unwrap_or_default()
117                )?;
118            }
119        }
120    }
121
122    if successful.len() > 1 {
123        let aggregated_result = EvaluationResult {
124            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
125            edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
126        };
127
128        writeln!(w, "\n{}", "-".repeat(80))?;
129        writeln!(w, "\n## TOTAL SCORES")?;
130        writeln!(w, "\n### Success Rate")?;
131        writeln!(w, "{}", aggregated_result)?;
132    }
133
134    if successful.len() + failed_count > 1 {
135        writeln!(
136            w,
137            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
138            successful.len(),
139            successful.len() + failed_count,
140            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
141        )?;
142    }
143
144    Ok(())
145}
146
147pub async fn run_evaluate_one(
148    example: NamedExample,
149    repetition_ix: Option<u16>,
150    project: Entity<Project>,
151    zeta: Entity<Zeta>,
152    prompt_format: PromptFormat,
153    use_expected_context: bool,
154    cache_mode: CacheMode,
155    cx: &mut AsyncApp,
156) -> Result<EvaluationResult> {
157    let predict_result = zeta2_predict(
158        example.clone(),
159        project,
160        zeta,
161        repetition_ix,
162        prompt_format,
163        use_expected_context,
164        cache_mode,
165        cx,
166    )
167    .await?;
168
169    let evaluation_result = evaluate(&example.example, &predict_result);
170
171    if repetition_ix.is_none() {
172        write_eval_result(
173            &example,
174            &predict_result,
175            &evaluation_result,
176            &mut std::io::stdout(),
177        )?;
178    }
179
180    if let Some(mut results_file) =
181        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
182    {
183        write_eval_result(
184            &example,
185            &predict_result,
186            &evaluation_result,
187            &mut results_file,
188        )
189        .log_err();
190    }
191
192    anyhow::Ok(evaluation_result)
193}
194
195fn write_eval_result(
196    example: &NamedExample,
197    predictions: &PredictionDetails,
198    evaluation_result: &EvaluationResult,
199    out: &mut impl Write,
200) -> Result<()> {
201    writeln!(
202        out,
203        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
204        compare_diffs(&example.example.expected_patch, &predictions.diff)
205    )?;
206    writeln!(
207        out,
208        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
209        compare_diffs(&predictions.diff, &example.example.expected_patch)
210    )?;
211    writeln!(out, "{}", evaluation_result)?;
212
213    anyhow::Ok(())
214}
215
216#[derive(Debug, Default)]
217pub struct EvaluationResult {
218    pub edit_prediction: Scores,
219    pub context: Scores,
220}
221
222#[derive(Default, Debug)]
223pub struct Scores {
224    pub true_positives: usize,
225    pub false_positives: usize,
226    pub false_negatives: usize,
227}
228
229impl Scores {
230    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
231        let true_positives = expected.intersection(actual).count();
232        let false_positives = actual.difference(expected).count();
233        let false_negatives = expected.difference(actual).count();
234
235        Scores {
236            true_positives,
237            false_positives,
238            false_negatives,
239        }
240    }
241
242    pub fn to_markdown(&self) -> String {
243        format!(
244            "
245Precision       : {:.4}
246Recall          : {:.4}
247F1 Score        : {:.4}
248True Positives  : {}
249False Positives : {}
250False Negatives : {}",
251            self.precision(),
252            self.recall(),
253            self.f1_score(),
254            self.true_positives,
255            self.false_positives,
256            self.false_negatives
257        )
258    }
259
260    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
261        let mut true_positives = 0;
262        let mut false_positives = 0;
263        let mut false_negatives = 0;
264
265        for score in scores {
266            true_positives += score.true_positives;
267            false_positives += score.false_positives;
268            false_negatives += score.false_negatives;
269        }
270
271        Scores {
272            true_positives,
273            false_positives,
274            false_negatives,
275        }
276    }
277
278    pub fn precision(&self) -> f64 {
279        if self.true_positives + self.false_positives == 0 {
280            0.0
281        } else {
282            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
283        }
284    }
285
286    pub fn recall(&self) -> f64 {
287        if self.true_positives + self.false_negatives == 0 {
288            0.0
289        } else {
290            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
291        }
292    }
293
294    pub fn f1_score(&self) -> f64 {
295        let recall = self.recall();
296        let precision = self.precision();
297        if precision + recall == 0.0 {
298            0.0
299        } else {
300            2.0 * precision * recall / (precision + recall)
301        }
302    }
303}
304
305impl std::fmt::Display for EvaluationResult {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        write!(
308            f,
309            r#"
310### Context Scores
311{}
312
313### Edit Prediction Scores
314{}
315"#,
316            self.context.to_markdown(),
317            self.edit_prediction.to_markdown()
318        )
319    }
320}
321
322pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
323    let mut eval_result = EvaluationResult::default();
324
325    let actual_context_lines: HashSet<_> = preds
326        .excerpts
327        .iter()
328        .flat_map(|excerpt| {
329            excerpt
330                .text
331                .lines()
332                .map(|line| format!("{}: {line}", excerpt.path.display()))
333        })
334        .collect();
335
336    let mut false_positive_lines = actual_context_lines.clone();
337
338    for entry in &example.expected_context {
339        let mut best_alternative_score: Option<Scores> = None;
340
341        for alternative in &entry.alternatives {
342            let expected: HashSet<_> = alternative
343                .excerpts
344                .iter()
345                .flat_map(|excerpt| {
346                    excerpt
347                        .text
348                        .lines()
349                        .map(|line| format!("{}: {line}", excerpt.path.display()))
350                })
351                .collect();
352
353            let scores = Scores::new(&expected, &actual_context_lines);
354
355            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
356
357            if best_alternative_score
358                .as_ref()
359                .is_none_or(|best| scores.recall() > best.recall())
360            {
361                best_alternative_score = Some(scores);
362            }
363        }
364
365        let best_alternative = best_alternative_score.unwrap_or_default();
366        eval_result.context.false_negatives += best_alternative.false_negatives;
367        eval_result.context.true_positives += best_alternative.true_positives;
368    }
369
370    eval_result.context.false_positives = false_positive_lines.len();
371
372    // todo: alternatives for patches
373    let expected_patch_lines = example
374        .expected_patch
375        .lines()
376        .map(DiffLine::parse)
377        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
378        .map(|line| line.to_string())
379        .collect();
380
381    let actual_patch_lines = preds
382        .diff
383        .lines()
384        .map(DiffLine::parse)
385        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
386        .map(|line| line.to_string())
387        .collect();
388
389    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
390    eval_result
391}
392
393/// Return annotated `patch_a` so that:
394/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
395/// Additions and deletions that are present in `patch_b` will be highlighted in green.
396pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
397    let use_color = std::io::stdout().is_terminal();
398    let green = if use_color { "\x1b[32m✓ " } else { "" };
399    let red = if use_color { "\x1b[31m✗ " } else { "" };
400    let neutral = if use_color { "  " } else { "" };
401    let reset = if use_color { "\x1b[0m" } else { "" };
402    let lines_a = patch_a.lines().map(DiffLine::parse);
403    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
404
405    let annotated = lines_a
406        .map(|line| match line {
407            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
408                if lines_b.contains(&line) {
409                    format!("{green}{line}{reset}")
410                } else {
411                    format!("{red}{line}{reset}")
412                }
413            }
414            _ => format!("{neutral}{line}{reset}"),
415        })
416        .collect::<Vec<String>>();
417
418    annotated.join("\n")
419}