use std::{
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::Result;
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;

use crate::{
    PromptFormat,
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    predict::{PredictionDetails, zeta2_predict},
};

#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[arg(long)]
    skip_cache: bool,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
}

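/// Evaluates every example in `args.example_paths` concurrently and, when more
/// than one example was given, prints aggregate scores across all of them.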
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example_len = args.example_paths.len();
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        cx.spawn(async move |cx| {
            run_evaluate_one(&path, args.skip_cache, args.prompt_format, app_state, cx).await
        })
    });
    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();

    let aggregated_result = EvaluationResult {
        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
    };

    if example_len > 1 {
        println!("\n{}", "-".repeat(80));
        println!("# TOTAL SCORES:");
        println!("{}", aggregated_result.to_markdown());
    }
}

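/// Loads a single example, runs prediction on it, prints the expected and
/// actual diffs annotated against each other, and returns the resulting scores.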
pub async fn run_evaluate_one(
    example_path: &Path,
    skip_cache: bool,
    prompt_format: PromptFormat,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
    let example = NamedExample::load(example_path)?;
    let predictions =
        zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx).await?;

    let evaluation_result = evaluate(&example.example, &predictions);

    println!(
        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&example.example.expected_patch, &predictions.diff)
    );
    println!(
        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&predictions.diff, &example.example.expected_patch)
    );

    println!("{}", evaluation_result.to_markdown());

    Ok(evaluation_result)
}

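/// Combined scores for a single evaluation run: how well the retrieved context
/// matched the expected context, and how well the predicted edits matched the
/// expected patch.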
#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub edit_prediction: Scores,
    pub context: Scores,
}

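/// Classification counts over sets of lines, from which precision, recall, and
/// F1 are derived.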
#[derive(Debug, Default)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
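    /// Builds scores by set comparison: lines in both sets are true positives,
    /// lines only in `actual` are false positives, and lines only in
    /// `expected` are false negatives.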
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

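    /// Sums raw counts across runs (micro-averaging), so precision and recall
    /// on the aggregate weight every line equally rather than every example
    /// equally.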
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

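    /// Precision = TP / (TP + FP); defined as 0.0 when nothing was predicted.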
    pub fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

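    /// Recall = TP / (TP + FN); defined as 0.0 when nothing was expected.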
    pub fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }

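    /// F1 = 2 * precision * recall / (precision + recall), the harmonic mean;
    /// defined as 0.0 when both precision and recall are zero.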
    pub fn f1_score(&self) -> f64 {
        let recall = self.recall();
        let precision = self.precision();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}

impl EvaluationResult {
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_prediction.to_markdown()
        )
    }
}

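/// Scores `preds` against `example`. Context is scored line-by-line against the
/// best-matching alternative (by recall) for each expected-context entry; edit
/// predictions are scored on the added and deleted lines of the unified diffs.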
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
    let mut eval_result = EvaluationResult::default();

    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        let mut best_alternative_score = Scores::default();

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // Any actual line matched by this alternative is not a false positive.
            false_positive_lines.retain(|line| !expected.contains(line));

            if scores.recall() > best_alternative_score.recall() {
                best_alternative_score = scores;
            }
        }

        eval_result.context.false_negatives += best_alternative_score.false_negatives;
        eval_result.context.true_positives += best_alternative_score.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    // TODO: alternatives for patches
    let expected_patch_lines = example
        .expected_patch
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    let actual_patch_lines = preds
        .diff
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
    eval_result
}

/// Returns `patch_a` annotated against `patch_b`:
/// additions and deletions that are also present in `patch_b` are highlighted
/// in green, while those not present in `patch_b` are highlighted in red.
pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
    let use_color = std::io::stdout().is_terminal();
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}
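
// A minimal sanity check for the score math, added here as an illustrative
// sketch rather than part of the original evaluation harness. It assumes
// `collections::HashSet` matches std's `HashSet` API (in this workspace it is
// a re-export).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scores_precision_recall_f1() {
        let expected: HashSet<String> = ["a", "b", "c"].into_iter().map(String::from).collect();
        let actual: HashSet<String> = ["b", "c", "d"].into_iter().map(String::from).collect();

        // TP = {b, c}, FP = {d}, FN = {a}.
        let scores = Scores::new(&expected, &actual);
        assert_eq!(scores.true_positives, 2);
        assert_eq!(scores.false_positives, 1);
        assert_eq!(scores.false_negatives, 1);
        assert!((scores.precision() - 2.0 / 3.0).abs() < 1e-9);
        assert!((scores.recall() - 2.0 / 3.0).abs() < 1e-9);
        // Precision == recall here, so F1 (their harmonic mean) equals both.
        assert!((scores.f1_score() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn aggregate_sums_raw_counts() {
        let a = Scores { true_positives: 1, false_positives: 0, false_negatives: 2 };
        let b = Scores { true_positives: 3, false_positives: 1, false_negatives: 0 };
        let total = Scores::aggregate([&a, &b].into_iter());
        assert_eq!(total.true_positives, 4);
        assert_eq!(total.false_positives, 1);
        assert_eq!(total.false_negatives, 2);
    }
}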