evaluate.rs

use std::{
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::Result;
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;

use crate::{
    PromptFormat,
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    predict::{PredictionDetails, zeta2_predict},
};

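/// Arguments for the `evaluate` subcommand.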
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[arg(long)]
    skip_cache: bool,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long)]
    use_expected_context: bool,
}

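/// Spawns an evaluation task per example, then prints aggregated scores when
/// more than one example was given.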
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example_len = args.example_paths.len();
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        cx.spawn(async move |cx| {
            run_evaluate_one(
                &path,
                args.skip_cache,
                args.prompt_format,
                args.use_expected_context,
                app_state,
                cx,
            )
            .await
        })
    });
    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();

    let aggregated_result = EvaluationResult {
        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
    };

    if example_len > 1 {
        println!("\n{}", "-".repeat(80));
        println!("# TOTAL SCORES:");
        println!("{}", aggregated_result.to_markdown());
    }
}

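/// Loads a single example, runs prediction on it, prints the annotated diffs,
/// and returns the resulting scores.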
pub async fn run_evaluate_one(
    example_path: &Path,
    skip_cache: bool,
    prompt_format: PromptFormat,
    use_expected_context: bool,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
    let example = NamedExample::load(example_path)?;
    let predictions = zeta2_predict(
        example.clone(),
        skip_cache,
        prompt_format,
        use_expected_context,
        &app_state,
        cx,
    )
    .await?;

    let evaluation_result = evaluate(&example.example, &predictions);

    println!(
        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&example.example.expected_patch, &predictions.diff)
    );
    println!(
        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&predictions.diff, &example.example.expected_patch)
    );

    println!("{}", evaluation_result.to_markdown());

    Ok(evaluation_result)
}

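/// Scores for a single evaluation run: one set for the retrieved context and
/// one for the predicted edits.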
#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub edit_prediction: Scores,
    pub context: Scores,
}

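/// Raw classification counts from which precision, recall, and F1 are derived.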
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
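    /// Compares `actual` against `expected`: shared items count as true
    /// positives, extra actual items as false positives, and missing expected
    /// items as false negatives.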
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

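    /// Sums raw counts across score sets (micro-averaging), so larger examples
    /// weigh proportionally more than smaller ones.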
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

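    /// TP / (TP + FP); 0.0 when nothing was predicted.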
    pub fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

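    /// TP / (TP + FN); 0.0 when nothing was expected.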
    pub fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }

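    /// Harmonic mean of precision and recall; 0.0 when both are zero.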
    pub fn f1_score(&self) -> f64 {
        let recall = self.recall();
        let precision = self.precision();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}

impl EvaluationResult {
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_prediction.to_markdown()
        )
    }
}

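/// Scores the predicted context lines and patch lines against the example's
/// expectations. Context lines are keyed by `path: line` so identical lines
/// from different files are not conflated.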
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
    let mut eval_result = EvaluationResult::default();

    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        let mut best_alternative_score = Scores::default();

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // Lines matching any expected alternative are not false positives.
            false_positive_lines.retain(|line| !expected.contains(line));

            if scores.recall() > best_alternative_score.recall() {
                best_alternative_score = scores;
            }
        }

        eval_result.context.false_negatives += best_alternative_score.false_negatives;
        eval_result.context.true_positives += best_alternative_score.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    // todo: alternatives for patches
    let expected_patch_lines = example
        .expected_patch
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    let actual_patch_lines = preds
        .diff
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
    eval_result
}

/// Returns `patch_a` annotated line by line: additions and deletions that also
/// appear in `patch_b` are highlighted in green, those missing from `patch_b`
/// in red, and context lines are left unhighlighted.
pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
    let use_color = std::io::stdout().is_terminal();
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}