evaluate.rs

  1use std::{
  2    fs,
  3    io::IsTerminal,
  4    path::{Path, PathBuf},
  5    sync::Arc,
  6};
  7
  8use anyhow::Result;
  9use clap::Args;
 10use collections::HashSet;
 11use gpui::AsyncApp;
 12use zeta2::udiff::DiffLine;
 13
 14use crate::{
 15    example::{Example, NamedExample},
 16    headless::ZetaCliAppState,
 17    paths::CACHE_DIR,
 18    predict::{PredictionDetails, zeta2_predict},
 19};
 20
 21#[derive(Debug, Args)]
 22pub struct EvaluateArguments {
 23    example_paths: Vec<PathBuf>,
 24    #[clap(long)]
 25    re_run: bool,
 26}
 27
 28pub async fn run_evaluate(
 29    args: EvaluateArguments,
 30    app_state: &Arc<ZetaCliAppState>,
 31    cx: &mut AsyncApp,
 32) {
 33    let example_len = args.example_paths.len();
 34    let all_tasks = args.example_paths.into_iter().map(|path| {
 35        let app_state = app_state.clone();
 36        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
 37    });
 38    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
 39
 40    let aggregated_result = EvaluationResult {
 41        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
 42        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
 43    };
 44
 45    if example_len > 1 {
 46        println!("\n{}", "-".repeat(80));
 47        println!("# TOTAL SCORES:");
 48        println!("{}", aggregated_result.to_markdown());
 49    }
 50}
 51
 52pub async fn run_evaluate_one(
 53    example_path: &Path,
 54    re_run: bool,
 55    app_state: Arc<ZetaCliAppState>,
 56    cx: &mut AsyncApp,
 57) -> Result<EvaluationResult> {
 58    let example = NamedExample::load(&example_path).unwrap();
 59    let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
 60
 61    let predictions = if !re_run && example_cache_path.exists() {
 62        let file_contents = fs::read_to_string(&example_cache_path)?;
 63        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
 64        log::debug!(
 65            "Loaded predictions from cache: {}",
 66            example_cache_path.display()
 67        );
 68        as_json
 69    } else {
 70        zeta2_predict(example.clone(), &app_state, cx)
 71            .await
 72            .unwrap()
 73    };
 74
 75    if !example_cache_path.exists() {
 76        fs::create_dir_all(&*CACHE_DIR).unwrap();
 77        fs::write(
 78            example_cache_path,
 79            serde_json::to_string(&predictions).unwrap(),
 80        )
 81        .unwrap();
 82    }
 83
 84    let evaluation_result = evaluate(&example.example, &predictions);
 85
 86    println!("# {}\n", example.name);
 87    println!(
 88        "## Expected Context: \n\n```\n{}\n```\n\n",
 89        compare_context(&example.example, &predictions)
 90    );
 91    println!(
 92        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
 93        compare_diffs(&example.example.expected_patch, &predictions.diff)
 94    );
 95    println!(
 96        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
 97        compare_diffs(&predictions.diff, &example.example.expected_patch)
 98    );
 99
100    println!("{}", evaluation_result.to_markdown());
101
102    anyhow::Ok(evaluation_result)
103}
104
105#[derive(Debug, Default)]
106pub struct EvaluationResult {
107    pub context: Scores,
108    pub edit_prediction: Scores,
109}
110
/// Precision/recall statistics comparing a set of expected items with a set of actual items.
#[derive(Default, Debug)]
pub struct Scores {
    /// `true_positives / (true_positives + false_positives)`, or 0.0 when nothing was produced.
    pub precision: f64,
    /// `true_positives / (true_positives + false_negatives)`, or 0.0 when nothing was expected.
    pub recall: f64,
    /// Harmonic mean of precision and recall, or 0.0 when both are 0.
    pub f1_score: f64,
    /// Items that were both expected and produced.
    pub true_positives: usize,
    /// Items that were produced but not expected.
    pub false_positives: usize,
    /// Items that were expected but not produced.
    pub false_negatives: usize,
}

impl Scores {
    /// Builds a `Scores` from raw counts, deriving precision, recall and F1.
    ///
    /// A ratio whose denominator is zero is reported as `0.0` rather than `NaN`.
    fn from_counts(
        true_positives: usize,
        false_positives: usize,
        false_negatives: usize,
    ) -> Scores {
        fn ratio(numerator: usize, denominator: usize) -> f64 {
            if denominator == 0 {
                0.0
            } else {
                numerator as f64 / denominator as f64
            }
        }

        let precision = ratio(true_positives, true_positives + false_positives);
        let recall = ratio(true_positives, true_positives + false_negatives);
        let f1_score = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Scores {
            precision,
            recall,
            f1_score,
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    /// Renders the scores as an aligned plain-text table.
    ///
    /// The output starts with a newline. Ratios are shown with four decimals.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision,
            self.recall,
            self.f1_score,
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Combines several `Scores` into one.
    ///
    /// The raw counts are summed and the ratios recomputed from the totals
    /// (micro-averaging), rather than averaging the per-item ratios. An empty
    /// input, or totals with a zero denominator, yields 0.0 ratios instead of
    /// `NaN`.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let (true_pos, false_pos, false_neg) =
            scores.fold((0, 0, 0), |(true_pos, false_pos, false_neg), score| {
                (
                    true_pos + score.true_positives,
                    false_pos + score.false_positives,
                    false_neg + score.false_negatives,
                )
            });
        Scores::from_counts(true_pos, false_pos, false_neg)
    }
}
170
171impl EvaluationResult {
172    pub fn to_markdown(&self) -> String {
173        format!(
174            r#"
175### Context Scores
176{}
177
178### Edit Prediction Scores
179{}
180"#,
181            self.context.to_markdown(),
182            self.edit_prediction.to_markdown()
183        )
184    }
185}
186
187pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
188    let mut result = EvaluationResult::default();
189
190    let expected_context_lines = example
191        .expected_excerpts
192        .iter()
193        .flat_map(|excerpt| {
194            excerpt
195                .text
196                .lines()
197                .map(|line| format!("{}: {line}", excerpt.path.display()))
198        })
199        .collect();
200    let actual_context_lines = preds
201        .excerpts
202        .iter()
203        .flat_map(|excerpt| {
204            excerpt
205                .text
206                .lines()
207                .map(|line| format!("{}: {line}", excerpt.path.display()))
208        })
209        .collect();
210
211    result.context = precision_recall(&expected_context_lines, &actual_context_lines);
212
213    let expected_patch_lines = example
214        .expected_patch
215        .lines()
216        .map(DiffLine::parse)
217        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
218        .map(|line| line.to_string())
219        .collect();
220
221    let actual_patch_lines = preds
222        .diff
223        .lines()
224        .map(DiffLine::parse)
225        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
226        .map(|line| line.to_string())
227        .collect();
228
229    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
230
231    result
232}
233
234fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
235    let true_positives = expected.intersection(actual).count();
236    let false_positives = actual.difference(expected).count();
237    let false_negatives = expected.difference(actual).count();
238
239    let precision = if true_positives + false_positives == 0 {
240        0.0
241    } else {
242        true_positives as f64 / (true_positives + false_positives) as f64
243    };
244    let recall = if true_positives + false_negatives == 0 {
245        0.0
246    } else {
247        true_positives as f64 / (true_positives + false_negatives) as f64
248    };
249    let f1_score = if precision + recall == 0.0 {
250        0.0
251    } else {
252        2.0 * precision * recall / (precision + recall)
253    };
254
255    Scores {
256        precision,
257        recall,
258        f1_score,
259        true_positives,
260        false_positives,
261        false_negatives,
262    }
263}
264
265/// Compare actual and expected context.
266///
267/// Return expected context annotated with these markers:
268///
269/// `✓ context line`  -- line was correctly predicted
270/// `✗ context line`  -- line is missing from predictions
271pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
272    let use_color = std::io::stdout().is_terminal();
273    let green = if use_color { "\x1b[32m" } else { "" };
274    let red = if use_color { "\x1b[31m" } else { "" };
275    let reset = if use_color { "\x1b[0m" } else { "" };
276    let expected: Vec<_> = example
277        .expected_excerpts
278        .iter()
279        .flat_map(|excerpt| {
280            excerpt
281                .text
282                .lines()
283                .map(|line| (excerpt.path.clone(), line))
284        })
285        .collect();
286    let actual: HashSet<_> = preds
287        .excerpts
288        .iter()
289        .flat_map(|excerpt| {
290            excerpt
291                .text
292                .lines()
293                .map(|line| (excerpt.path.clone(), line))
294        })
295        .collect();
296
297    let annotated = expected
298        .iter()
299        .map(|(path, line)| {
300            if actual.contains(&(path.to_path_buf(), line)) {
301                format!("{green}{line}{reset}")
302            } else {
303                format!("{red}{line}{reset}")
304            }
305        })
306        .collect::<Vec<String>>();
307
308    annotated.join("\n")
309}
310
311/// Return annotated `patch_a` so that:
312/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
313/// Additions and deletions that are present in `patch_b` will be highlighted in green.
314pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
315    let use_color = std::io::stdout().is_terminal();
316    let green = if use_color { "\x1b[32m✓ " } else { "" };
317    let red = if use_color { "\x1b[31m✗ " } else { "" };
318    let neutral = if use_color { "  " } else { "" };
319    let reset = if use_color { "\x1b[0m" } else { "" };
320    let lines_a = patch_a.lines().map(DiffLine::parse);
321    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
322
323    let annotated = lines_a
324        .map(|line| match line {
325            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
326                if lines_b.contains(&line) {
327                    format!("{green}{line}{reset}")
328                } else {
329                    format!("{red}{line}{reset}")
330                }
331            }
332            _ => format!("{neutral}{line}{reset}"),
333        })
334        .collect::<Vec<String>>();
335
336    annotated.join("\n")
337}