// evaluate.rs

  1use std::{
  2    fs,
  3    io::IsTerminal,
  4    path::{Path, PathBuf},
  5    sync::Arc,
  6};
  7
use anyhow::{Context as _, Result};
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;
 13
 14use crate::{
 15    example::{Example, NamedExample},
 16    headless::ZetaCliAppState,
 17    paths::CACHE_DIR,
 18    predict::{PredictionDetails, zeta2_predict},
 19};
 20
// Arguments for the `evaluate` subcommand.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on a
// clap-derive struct/fields would become user-visible `--help` text.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Paths of the example files to evaluate.
    example_paths: Vec<PathBuf>,
    // When set, recompute predictions instead of reading cached results.
    #[clap(long)]
    re_run: bool,
}
 27
 28pub async fn run_evaluate(
 29    args: EvaluateArguments,
 30    app_state: &Arc<ZetaCliAppState>,
 31    cx: &mut AsyncApp,
 32) {
 33    let example_len = args.example_paths.len();
 34    let all_tasks = args.example_paths.into_iter().map(|path| {
 35        let app_state = app_state.clone();
 36        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
 37    });
 38    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
 39
 40    let aggregated_result = EvaluationResult {
 41        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
 42        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
 43    };
 44
 45    if example_len > 1 {
 46        println!("\n{}", "-".repeat(80));
 47        println!("# TOTAL SCORES:");
 48        println!("{}", aggregated_result.to_markdown());
 49    }
 50}
 51
 52pub async fn run_evaluate_one(
 53    example_path: &Path,
 54    re_run: bool,
 55    app_state: Arc<ZetaCliAppState>,
 56    cx: &mut AsyncApp,
 57) -> Result<EvaluationResult> {
 58    let example = NamedExample::load(&example_path).unwrap();
 59    let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
 60
 61    let predictions = if !re_run && example_cache_path.exists() {
 62        let file_contents = fs::read_to_string(&example_cache_path)?;
 63        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
 64        log::debug!(
 65            "Loaded predictions from cache: {}",
 66            example_cache_path.display()
 67        );
 68        as_json
 69    } else {
 70        zeta2_predict(example.clone(), &app_state, cx)
 71            .await
 72            .unwrap()
 73    };
 74
 75    if !example_cache_path.exists() {
 76        fs::create_dir_all(&*CACHE_DIR).unwrap();
 77        fs::write(
 78            example_cache_path,
 79            serde_json::to_string(&predictions).unwrap(),
 80        )
 81        .unwrap();
 82    }
 83
 84    let evaluation_result = evaluate(&example.example, &predictions);
 85
 86    println!(
 87        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
 88        compare_diffs(&example.example.expected_patch, &predictions.diff)
 89    );
 90    println!(
 91        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
 92        compare_diffs(&predictions.diff, &example.example.expected_patch)
 93    );
 94
 95    println!("{}", evaluation_result.to_markdown());
 96
 97    anyhow::Ok(evaluation_result)
 98}
 99
/// Scores for one evaluated example, or an aggregate over several examples.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Line-level scores for the predicted diff vs. the expected patch.
    pub edit_prediction: Scores,
    /// Line-level scores for the retrieved context vs. the expected context.
    pub context: Scores,
}
105
/// Raw classification counts from which precision, recall, and F1 are derived.
#[derive(Default, Debug)]
pub struct Scores {
    /// Items present in both the expected and the actual set.
    pub true_positives: usize,
    /// Items in the actual set that were not expected.
    pub false_positives: usize,
    /// Expected items missing from the actual set.
    pub false_negatives: usize,
}
112
113impl Scores {
114    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
115        let true_positives = expected.intersection(actual).count();
116        let false_positives = actual.difference(expected).count();
117        let false_negatives = expected.difference(actual).count();
118
119        Scores {
120            true_positives,
121            false_positives,
122            false_negatives,
123        }
124    }
125
126    pub fn to_markdown(&self) -> String {
127        format!(
128            "
129Precision       : {:.4}
130Recall          : {:.4}
131F1 Score        : {:.4}
132True Positives  : {}
133False Positives : {}
134False Negatives : {}",
135            self.precision(),
136            self.recall(),
137            self.f1_score(),
138            self.true_positives,
139            self.false_positives,
140            self.false_negatives
141        )
142    }
143
144    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
145        let mut true_positives = 0;
146        let mut false_positives = 0;
147        let mut false_negatives = 0;
148
149        for score in scores {
150            true_positives += score.true_positives;
151            false_positives += score.false_positives;
152            false_negatives += score.false_negatives;
153        }
154
155        Scores {
156            true_positives,
157            false_positives,
158            false_negatives,
159        }
160    }
161
162    pub fn precision(&self) -> f64 {
163        if self.true_positives + self.false_positives == 0 {
164            0.0
165        } else {
166            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
167        }
168    }
169
170    pub fn recall(&self) -> f64 {
171        if self.true_positives + self.false_negatives == 0 {
172            0.0
173        } else {
174            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
175        }
176    }
177
178    pub fn f1_score(&self) -> f64 {
179        let recall = self.recall();
180        let precision = self.precision();
181        if precision + recall == 0.0 {
182            0.0
183        } else {
184            2.0 * precision * recall / (precision + recall)
185        }
186    }
187}
188
189impl EvaluationResult {
190    pub fn to_markdown(&self) -> String {
191        format!(
192            r#"
193### Context Scores
194{}
195
196### Edit Prediction Scores
197{}
198"#,
199            self.context.to_markdown(),
200            self.edit_prediction.to_markdown()
201        )
202    }
203}
204
205pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
206    let mut eval_result = EvaluationResult::default();
207
208    let actual_context_lines: HashSet<_> = preds
209        .excerpts
210        .iter()
211        .flat_map(|excerpt| {
212            excerpt
213                .text
214                .lines()
215                .map(|line| format!("{}: {line}", excerpt.path.display()))
216        })
217        .collect();
218
219    let mut false_positive_lines = actual_context_lines.clone();
220
221    for entry in &example.expected_context {
222        let mut best_alternative_score = Scores::default();
223
224        for alternative in &entry.alternatives {
225            let expected: HashSet<_> = alternative
226                .excerpts
227                .iter()
228                .flat_map(|excerpt| {
229                    excerpt
230                        .text
231                        .lines()
232                        .map(|line| format!("{}: {line}", excerpt.path.display()))
233                })
234                .collect();
235
236            let scores = Scores::new(&expected, &actual_context_lines);
237
238            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
239
240            if scores.recall() > best_alternative_score.recall() {
241                best_alternative_score = scores;
242            }
243        }
244
245        eval_result.context.false_negatives += best_alternative_score.false_negatives;
246        eval_result.context.true_positives += best_alternative_score.true_positives;
247    }
248
249    eval_result.context.false_positives = false_positive_lines.len();
250
251    // todo: alternatives for patches
252    let expected_patch_lines = example
253        .expected_patch
254        .lines()
255        .map(DiffLine::parse)
256        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
257        .map(|line| line.to_string())
258        .collect();
259
260    let actual_patch_lines = preds
261        .diff
262        .lines()
263        .map(DiffLine::parse)
264        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
265        .map(|line| line.to_string())
266        .collect();
267
268    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
269    eval_result
270}
271
272/// Return annotated `patch_a` so that:
273/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
274/// Additions and deletions that are present in `patch_b` will be highlighted in green.
275pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
276    let use_color = std::io::stdout().is_terminal();
277    let green = if use_color { "\x1b[32m✓ " } else { "" };
278    let red = if use_color { "\x1b[31m✗ " } else { "" };
279    let neutral = if use_color { "  " } else { "" };
280    let reset = if use_color { "\x1b[0m" } else { "" };
281    let lines_a = patch_a.lines().map(DiffLine::parse);
282    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
283
284    let annotated = lines_a
285        .map(|line| match line {
286            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
287                if lines_b.contains(&line) {
288                    format!("{green}{line}{reset}")
289                } else {
290                    format!("{red}{line}{reset}")
291                }
292            }
293            _ => format!("{neutral}{line}{reset}"),
294        })
295        .collect::<Vec<String>>();
296
297    annotated.join("\n")
298}