evaluate.rs

  1use std::{
  2    io::{IsTerminal, Write},
  3    path::PathBuf,
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use clap::Args;
  9use collections::HashSet;
 10use gpui::{AsyncApp, Entity};
 11use project::Project;
 12use util::ResultExt as _;
 13use zeta2::{Zeta, udiff::DiffLine};
 14
 15use crate::{
 16    PromptFormat,
 17    example::{Example, NamedExample},
 18    headless::ZetaCliAppState,
 19    paths::print_run_data_dir,
 20    predict::{CacheMode, PredictionDetails, zeta2_predict},
 21};
 22
 23#[derive(Debug, Args)]
 24pub struct EvaluateArguments {
 25    example_paths: Vec<PathBuf>,
 26    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 27    prompt_format: PromptFormat,
 28    #[arg(long)]
 29    use_expected_context: bool,
 30    #[clap(long, value_enum, default_value_t = CacheMode::default())]
 31    cache: CacheMode,
 32    #[clap(short, long, default_value_t = 1, alias = "repeat")]
 33    repetitions: u16,
 34    #[arg(long)]
 35    skip_prediction: bool,
 36}
 37
/// Evaluates every example in `args.example_paths`, running each one
/// `args.repetitions` times, then writes aggregated scores to stdout and,
/// best-effort, to `aggregated_results.md` in the run directory.
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    if args.example_paths.is_empty() {
        eprintln!("No examples provided");
        return;
    }
    // One outer task per example; each sets up a project and then spawns one
    // inner task per repetition.
    // NOTE(review): `load` and `setup_project` panic on failure rather than
    // flowing through the per-run error channel below — confirm this is the
    // intended fail-fast behavior for the CLI.
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        let example = NamedExample::load(&path).unwrap();

        cx.spawn(async move |cx| {
            let (project, zetas, _edited_buffers) = example
                .setup_project(&app_state, args.repetitions, cx)
                .await
                .unwrap();

            // `setup_project` yields one Zeta instance per repetition.
            let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
                // Only tag runs with an index when there is more than one
                // repetition, so single runs print without a "[RUN …]" label.
                let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
                let example = example.clone();
                let project = project.clone();

                cx.spawn(async move |cx| {
                    let name = example.name.clone();
                    run_evaluate_one(
                        example,
                        repetition_ix,
                        project,
                        zeta,
                        args.prompt_format,
                        args.use_expected_context,
                        !args.skip_prediction,
                        args.cache,
                        cx,
                    )
                    .await
                    // Attach example name and run index so failures can be
                    // attributed in the aggregated error report.
                    .map_err(|err| (err, name, repetition_ix))
                })
            });
            futures::future::join_all(tasks).await
        })
    });
    let all_results = futures::future::join_all(all_tasks).await;

    // Report to stdout (fatal on failure) and to a markdown file (best-effort).
    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
    if let Some(mut output_file) =
        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
    {
        write_aggregated_scores(&mut output_file, &all_results).log_err();
    };
    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
 92
 93fn write_aggregated_scores(
 94    w: &mut impl std::io::Write,
 95    all_results: &Vec<Vec<Result<EvaluationResult, (anyhow::Error, String, Option<u16>)>>>,
 96) -> Result<()> {
 97    let mut successful = Vec::new();
 98    let mut failed_count = 0;
 99
100    for result in all_results.iter().flatten() {
101        match result {
102            Ok(eval_result) => successful.push(eval_result),
103            Err((err, name, repetition_ix)) => {
104                if failed_count == 0 {
105                    writeln!(w, "## Errors\n")?;
106                }
107
108                failed_count += 1;
109                let err = format!("{err:?}")
110                    .replace("<edits", "```xml\n<edits")
111                    .replace("</edits>", "</edits>\n```");
112                writeln!(
113                    w,
114                    "### ERROR {name}{}\n\n{err}\n",
115                    repetition_ix
116                        .map(|ix| format!(" [RUN {ix:03}]"))
117                        .unwrap_or_default()
118                )?;
119            }
120        }
121    }
122
123    if successful.len() > 1 {
124        let mut edit_predictions = successful
125            .iter()
126            .filter_map(|r| r.edit_prediction.as_ref())
127            .peekable();
128        let has_edit_predictions = edit_predictions.peek().is_some();
129        let aggregated_result = EvaluationResult {
130            context: Scores::aggregate(successful.iter().map(|r| &r.context)),
131            edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
132        };
133
134        writeln!(w, "\n{}", "-".repeat(80))?;
135        writeln!(w, "\n## TOTAL SCORES")?;
136        writeln!(w, "\n### Success Rate")?;
137        writeln!(w, "{}", aggregated_result)?;
138    }
139
140    if successful.len() + failed_count > 1 {
141        writeln!(
142            w,
143            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
144            successful.len(),
145            successful.len() + failed_count,
146            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
147        )?;
148    }
149
150    Ok(())
151}
152
153pub async fn run_evaluate_one(
154    example: NamedExample,
155    repetition_ix: Option<u16>,
156    project: Entity<Project>,
157    zeta: Entity<Zeta>,
158    prompt_format: PromptFormat,
159    use_expected_context: bool,
160    predict: bool,
161    cache_mode: CacheMode,
162    cx: &mut AsyncApp,
163) -> Result<EvaluationResult> {
164    let predict_result = zeta2_predict(
165        example.clone(),
166        project,
167        zeta,
168        repetition_ix,
169        prompt_format,
170        use_expected_context,
171        cache_mode,
172        cx,
173    )
174    .await?;
175
176    let evaluation_result = evaluate(&example.example, &predict_result, predict);
177
178    if repetition_ix.is_none() {
179        write_eval_result(
180            &example,
181            &predict_result,
182            &evaluation_result,
183            &mut std::io::stdout(),
184            std::io::stdout().is_terminal(),
185            predict,
186        )?;
187    }
188
189    if let Some(mut results_file) =
190        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
191    {
192        write_eval_result(
193            &example,
194            &predict_result,
195            &evaluation_result,
196            &mut results_file,
197            false,
198            predict,
199        )
200        .log_err();
201    }
202
203    anyhow::Ok(evaluation_result)
204}
205
206fn write_eval_result(
207    example: &NamedExample,
208    predictions: &PredictionDetails,
209    evaluation_result: &EvaluationResult,
210    out: &mut impl Write,
211    use_color: bool,
212    predict: bool,
213) -> Result<()> {
214    if predict {
215        writeln!(
216            out,
217            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
218            compare_diffs(
219                &example.example.expected_patch,
220                &predictions.diff,
221                use_color
222            )
223        )?;
224        writeln!(
225            out,
226            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
227            compare_diffs(
228                &predictions.diff,
229                &example.example.expected_patch,
230                use_color
231            )
232        )?;
233    }
234
235    writeln!(out, "{:#}", evaluation_result)?;
236
237    anyhow::Ok(())
238}
239
240#[derive(Debug, Default)]
241pub struct EvaluationResult {
242    pub edit_prediction: Option<Scores>,
243    pub context: Scores,
244}
245
246#[derive(Default, Debug)]
247pub struct Scores {
248    pub true_positives: usize,
249    pub false_positives: usize,
250    pub false_negatives: usize,
251}
252
253impl Scores {
254    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
255        let true_positives = expected.intersection(actual).count();
256        let false_positives = actual.difference(expected).count();
257        let false_negatives = expected.difference(actual).count();
258
259        Scores {
260            true_positives,
261            false_positives,
262            false_negatives,
263        }
264    }
265
266    pub fn to_markdown(&self) -> String {
267        format!(
268            "
269Precision       : {:.4}
270Recall          : {:.4}
271F1 Score        : {:.4}
272True Positives  : {}
273False Positives : {}
274False Negatives : {}",
275            self.precision(),
276            self.recall(),
277            self.f1_score(),
278            self.true_positives,
279            self.false_positives,
280            self.false_negatives
281        )
282    }
283
284    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
285        let mut true_positives = 0;
286        let mut false_positives = 0;
287        let mut false_negatives = 0;
288
289        for score in scores {
290            true_positives += score.true_positives;
291            false_positives += score.false_positives;
292            false_negatives += score.false_negatives;
293        }
294
295        Scores {
296            true_positives,
297            false_positives,
298            false_negatives,
299        }
300    }
301
302    pub fn precision(&self) -> f64 {
303        if self.true_positives + self.false_positives == 0 {
304            0.0
305        } else {
306            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
307        }
308    }
309
310    pub fn recall(&self) -> f64 {
311        if self.true_positives + self.false_negatives == 0 {
312            0.0
313        } else {
314            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
315        }
316    }
317
318    pub fn f1_score(&self) -> f64 {
319        let recall = self.recall();
320        let precision = self.precision();
321        if precision + recall == 0.0 {
322            0.0
323        } else {
324            2.0 * precision * recall / (precision + recall)
325        }
326    }
327}
328
329impl std::fmt::Display for EvaluationResult {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        if f.alternate() {
332            self.fmt_table(f)
333        } else {
334            self.fmt_markdown(f)
335        }
336    }
337}
338
339impl EvaluationResult {
340    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(
342            f,
343            r#"
344### Context Scores
345{}
346"#,
347            self.context.to_markdown(),
348        )?;
349        if let Some(prediction) = &self.edit_prediction {
350            write!(
351                f,
352                r#"
353                ### Edit Prediction Scores
354                {}"#,
355                prediction.to_markdown()
356            )?;
357        }
358        Ok(())
359    }
360
361    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        writeln!(f, "### Scores\n")?;
363        writeln!(
364            f,
365            "                   TP     FP     FN     Precision   Recall     F1"
366        )?;
367        writeln!(
368            f,
369            "──────────────────────────────────────────────────────────────────"
370        )?;
371        writeln!(
372            f,
373            "Context Retrieval  {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
374            self.context.true_positives,
375            self.context.false_positives,
376            self.context.false_negatives,
377            self.context.precision() * 100.0,
378            self.context.recall() * 100.0,
379            self.context.f1_score() * 100.0
380        )?;
381        if let Some(edit_prediction) = &self.edit_prediction {
382            writeln!(
383                f,
384                "Edit Prediction    {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
385                edit_prediction.true_positives,
386                edit_prediction.false_positives,
387                edit_prediction.false_negatives,
388                edit_prediction.precision() * 100.0,
389                edit_prediction.recall() * 100.0,
390                edit_prediction.f1_score() * 100.0
391            )?;
392        }
393        Ok(())
394    }
395}
396
397pub fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
398    let mut eval_result = EvaluationResult::default();
399
400    let actual_context_lines: HashSet<_> = preds
401        .excerpts
402        .iter()
403        .flat_map(|excerpt| {
404            excerpt
405                .text
406                .lines()
407                .map(|line| format!("{}: {line}", excerpt.path.display()))
408        })
409        .collect();
410
411    let mut false_positive_lines = actual_context_lines.clone();
412
413    for entry in &example.expected_context {
414        let mut best_alternative_score: Option<Scores> = None;
415
416        for alternative in &entry.alternatives {
417            let expected: HashSet<_> = alternative
418                .excerpts
419                .iter()
420                .flat_map(|excerpt| {
421                    excerpt
422                        .text
423                        .lines()
424                        .map(|line| format!("{}: {line}", excerpt.path.display()))
425                })
426                .collect();
427
428            let scores = Scores::new(&expected, &actual_context_lines);
429
430            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
431
432            if best_alternative_score
433                .as_ref()
434                .is_none_or(|best| scores.recall() > best.recall())
435            {
436                best_alternative_score = Some(scores);
437            }
438        }
439
440        let best_alternative = best_alternative_score.unwrap_or_default();
441        eval_result.context.false_negatives += best_alternative.false_negatives;
442        eval_result.context.true_positives += best_alternative.true_positives;
443    }
444
445    eval_result.context.false_positives = false_positive_lines.len();
446
447    if predict {
448        // todo: alternatives for patches
449        let expected_patch_lines = example
450            .expected_patch
451            .lines()
452            .map(DiffLine::parse)
453            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
454            .map(|line| line.to_string())
455            .collect();
456
457        let actual_patch_lines = preds
458            .diff
459            .lines()
460            .map(DiffLine::parse)
461            .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
462            .map(|line| line.to_string())
463            .collect();
464
465        eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
466    }
467
468    eval_result
469}
470
471/// Return annotated `patch_a` so that:
472/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
473/// Additions and deletions that are present in `patch_b` will be highlighted in green.
474pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
475    let green = if use_color { "\x1b[32m✓ " } else { "" };
476    let red = if use_color { "\x1b[31m✗ " } else { "" };
477    let neutral = if use_color { "  " } else { "" };
478    let reset = if use_color { "\x1b[0m" } else { "" };
479    let lines_a = patch_a.lines().map(DiffLine::parse);
480    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
481
482    let annotated = lines_a
483        .map(|line| match line {
484            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
485                if lines_b.contains(&line) {
486                    format!("{green}{line}{reset}")
487                } else {
488                    format!("{red}{line}{reset}")
489                }
490            }
491            _ => format!("{neutral}{line}{reset}"),
492        })
493        .collect::<Vec<String>>();
494
495    annotated.join("\n")
496}