evaluate.rs

use std::{
    io::{IsTerminal, Write},
    path::PathBuf,
    sync::Arc,
};

use anyhow::Result;
use clap::Args;
use collections::HashSet;
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use zeta2::{Zeta, udiff::DiffLine};

use crate::{
    PromptFormat,
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    paths::print_run_data_dir,
    predict::{CacheMode, PredictionDetails, zeta2_predict},
};

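/// Command-line arguments for the `evaluate` subcommand.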
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long)]
    use_expected_context: bool,
    #[arg(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
    #[arg(short, long, default_value_t = 1, alias = "repeat")]
    repetitions: u16,
}

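/// Loads each example, runs predictions for all of them concurrently (once per
/// repetition), then writes aggregated scores to stdout and to
/// `aggregated_results.md` in the run directory.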
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Only tag results with a repetition index when running more than
                // once; `None` also tells `run_evaluate_one` to print to stdout.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);

                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        args.cache,
                        cx,
                    )
                    .await
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    }
    print_run_data_dir(args.repetitions == 1);
}

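/// Writes an error section for failed runs followed by aggregate precision,
/// recall, and F1 over all successful runs.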
 91fn write_aggregated_scores(
 92    w: &mut impl std::io::Write,
 93    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
 94) -> Result<()> {
 95    let mut successful = Vec::new();
 96    let mut failed_count = 0;
 97    writeln!(w, "## Errors\n")?;
 98    for result in all_results.iter().flatten() {
 99        match result {
100            Ok(eval_result) => successful.push(eval_result),
101            Err((err, name, repetition_ix)) => {
102                failed_count += 1;
103                let err = err
104                    .to_string()
105                    .replace("<edits", "```xml\n<edits")
106                    .replace("</edits>", "</edits>\n```");
107                writeln!(
108                    w,
109                    "### ERROR {name}{}\n\n{err}\n",
110                    repetition_ix
111                        .map(|ix| format!(" [RUN {ix:03}]"))
112                        .unwrap_or_default()
113                )?;
114            }
115        }
116    }
117    let aggregated_result = EvaluationResult {
118        context: Scores::aggregate(successful.iter().map(|r| &r.context)),
119        edit_prediction: Scores::aggregate(successful.iter().map(|r| &r.edit_prediction)),
120    };
121
122    writeln!(w, "\n{}", "-".repeat(80))?;
123    writeln!(w, "\n## TOTAL SCORES")?;
124    writeln!(w, "\n### Success Rate")?;
125    writeln!(
126        w,
127        "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
128        successful.len(),
129        successful.len() + failed_count,
130        (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
131    )?;
132    writeln!(w, "{}", aggregated_result)?;
133
134    Ok(())
135}
136
137pub async fn run_evaluate_one(
138    example: NamedExample,
139    repetition_ix: Option<u16>,
140    project: Entity<Project>,
141    zeta: Entity<Zeta>,
142    prompt_format: PromptFormat,
143    use_expected_context: bool,
144    cache_mode: CacheMode,
145    cx: &mut AsyncApp,
146) -> Result<EvaluationResult> {
147    let predict_result = zeta2_predict(
148        example.clone(),
149        project,
150        zeta,
151        repetition_ix,
152        prompt_format,
153        use_expected_context,
154        cache_mode,
155        cx,
156    )
157    .await?;
158
159    let evaluation_result = evaluate(&example.example, &predict_result);
160
161    if repetition_ix.is_none() {
162        write_eval_result(
163            &example,
164            &predict_result,
165            &evaluation_result,
166            &mut std::io::stdout(),
167        )?;
168    }
169
170    if let Some(mut results_file) =
171        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
172    {
173        write_eval_result(
174            &example,
175            &predict_result,
176            &evaluation_result,
177            &mut results_file,
178        )
179        .log_err();
180    }
181
182    anyhow::Ok(evaluation_result)
183}
184
185fn write_eval_result(
186    example: &NamedExample,
187    predictions: &PredictionDetails,
188    evaluation_result: &EvaluationResult,
189    out: &mut impl Write,
190) -> Result<()> {
191    writeln!(
192        out,
193        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
194        compare_diffs(&example.example.expected_patch, &predictions.diff)
195    )?;
196    writeln!(
197        out,
198        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
199        compare_diffs(&predictions.diff, &example.example.expected_patch)
200    )?;
201    writeln!(out, "{}", evaluation_result)?;
202
203    anyhow::Ok(())
204}
205
206#[derive(Debug, Default)]
207pub struct EvaluationResult {
208    pub edit_prediction: Scores,
209    pub context: Scores,
210}
211
212#[derive(Default, Debug)]
213pub struct Scores {
214    pub true_positives: usize,
215    pub false_positives: usize,
216    pub false_negatives: usize,
217}
218
219impl Scores {
220    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
221        let true_positives = expected.intersection(actual).count();
222        let false_positives = actual.difference(expected).count();
223        let false_negatives = expected.difference(actual).count();
224
225        Scores {
226            true_positives,
227            false_positives,
228            false_negatives,
229        }
230    }
231
232    pub fn to_markdown(&self) -> String {
233        format!(
234            "
235Precision       : {:.4}
236Recall          : {:.4}
237F1 Score        : {:.4}
238True Positives  : {}
239False Positives : {}
240False Negatives : {}",
241            self.precision(),
242            self.recall(),
243            self.f1_score(),
244            self.true_positives,
245            self.false_positives,
246            self.false_negatives
247        )
248    }
249
250    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
251        let mut true_positives = 0;
252        let mut false_positives = 0;
253        let mut false_negatives = 0;
254
255        for score in scores {
256            true_positives += score.true_positives;
257            false_positives += score.false_positives;
258            false_negatives += score.false_negatives;
259        }
260
261        Scores {
262            true_positives,
263            false_positives,
264            false_negatives,
265        }
266    }
267
268    pub fn precision(&self) -> f64 {
269        if self.true_positives + self.false_positives == 0 {
270            0.0
271        } else {
272            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
273        }
274    }
275
276    pub fn recall(&self) -> f64 {
277        if self.true_positives + self.false_negatives == 0 {
278            0.0
279        } else {
280            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
281        }
282    }
283
284    pub fn f1_score(&self) -> f64 {
285        let recall = self.recall();
286        let precision = self.precision();
287        if precision + recall == 0.0 {
288            0.0
289        } else {
290            2.0 * precision * recall / (precision + recall)
291        }
292    }
293}
294
295impl std::fmt::Display for EvaluationResult {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        write!(
298            f,
299            r#"
300### Context Scores
301{}
302
303### Edit Prediction Scores
304{}
305"#,
306            self.context.to_markdown(),
307            self.edit_prediction.to_markdown()
308        )
309    }
310}
311
312pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
313    let mut eval_result = EvaluationResult::default();
314
315    let actual_context_lines: HashSet<_> = preds
316        .excerpts
317        .iter()
318        .flat_map(|excerpt| {
319            excerpt
320                .text
321                .lines()
322                .map(|line| format!("{}: {line}", excerpt.path.display()))
323        })
324        .collect();
325
326    let mut false_positive_lines = actual_context_lines.clone();
327
328    for entry in &example.expected_context {
329        let mut best_alternative_score = Scores::default();
330
331        for alternative in &entry.alternatives {
332            let expected: HashSet<_> = alternative
333                .excerpts
334                .iter()
335                .flat_map(|excerpt| {
336                    excerpt
337                        .text
338                        .lines()
339                        .map(|line| format!("{}: {line}", excerpt.path.display()))
340                })
341                .collect();
342
343            let scores = Scores::new(&expected, &actual_context_lines);
344
345            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
346
347            if scores.recall() > best_alternative_score.recall() {
348                best_alternative_score = scores;
349            }
350        }
351
352        eval_result.context.false_negatives += best_alternative_score.false_negatives;
353        eval_result.context.true_positives += best_alternative_score.true_positives;
354    }
355
356    eval_result.context.false_positives = false_positive_lines.len();
357
358    // todo: alternatives for patches
359    let expected_patch_lines = example
360        .expected_patch
361        .lines()
362        .map(DiffLine::parse)
363        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
364        .map(|line| line.to_string())
365        .collect();
366
367    let actual_patch_lines = preds
368        .diff
369        .lines()
370        .map(DiffLine::parse)
371        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
372        .map(|line| line.to_string())
373        .collect();
374
375    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
376    eval_result
377}
378
379/// Return annotated `patch_a` so that:
380/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
381/// Additions and deletions that are present in `patch_b` will be highlighted in green.
382pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
383    let use_color = std::io::stdout().is_terminal();
384    let green = if use_color { "\x1b[32m✓ " } else { "" };
385    let red = if use_color { "\x1b[31m✗ " } else { "" };
386    let neutral = if use_color { "  " } else { "" };
387    let reset = if use_color { "\x1b[0m" } else { "" };
388    let lines_a = patch_a.lines().map(DiffLine::parse);
389    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
390
391    let annotated = lines_a
392        .map(|line| match line {
393            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
394                if lines_b.contains(&line) {
395                    format!("{green}{line}{reset}")
396                } else {
397                    format!("{red}{line}{reset}")
398                }
399            }
400            _ => format!("{neutral}{line}{reset}"),
401        })
402        .collect::<Vec<String>>();
403
404    annotated.join("\n")
405}