// evaluate.rs

  1use std::{
  2    io::IsTerminal,
  3    path::{Path, PathBuf},
  4    sync::Arc,
  5};
  6
  7use anyhow::Result;
  8use clap::Args;
  9use collections::HashSet;
 10use gpui::AsyncApp;
 11use zeta2::udiff::DiffLine;
 12
 13use crate::{
 14    PromptFormat,
 15    example::{Example, NamedExample},
 16    headless::ZetaCliAppState,
 17    paths::print_run_data_dir,
 18    predict::{CacheMode, PredictionDetails, zeta2_predict},
 19};
 20
 21#[derive(Debug, Args)]
 22pub struct EvaluateArguments {
 23    example_paths: Vec<PathBuf>,
 24    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 25    prompt_format: PromptFormat,
 26    #[arg(long)]
 27    use_expected_context: bool,
 28    #[clap(long, value_enum, default_value_t = CacheMode::default())]
 29    cache: CacheMode,
 30}
 31
 32pub async fn run_evaluate(
 33    args: EvaluateArguments,
 34    app_state: &Arc<ZetaCliAppState>,
 35    cx: &mut AsyncApp,
 36) {
 37    let example_len = args.example_paths.len();
 38    let all_tasks = args.example_paths.into_iter().map(|path| {
 39        let app_state = app_state.clone();
 40        cx.spawn(async move |cx| {
 41            run_evaluate_one(
 42                &path,
 43                args.prompt_format,
 44                args.use_expected_context,
 45                args.cache,
 46                app_state.clone(),
 47                cx,
 48            )
 49            .await
 50        })
 51    });
 52    let all_results = futures::future::try_join_all(all_tasks).await;
 53
 54    if let Ok(all_results) = &all_results {
 55        let aggregated_result = EvaluationResult {
 56            context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
 57            edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
 58        };
 59
 60        if example_len > 1 {
 61            println!("\n{}", "-".repeat(80));
 62            println!("\n## TOTAL SCORES");
 63            println!("{}", aggregated_result.to_markdown());
 64        }
 65    }
 66
 67    print_run_data_dir();
 68
 69    all_results.unwrap();
 70}
 71
 72pub async fn run_evaluate_one(
 73    example_path: &Path,
 74    prompt_format: PromptFormat,
 75    use_expected_context: bool,
 76    cache_mode: CacheMode,
 77    app_state: Arc<ZetaCliAppState>,
 78    cx: &mut AsyncApp,
 79) -> Result<EvaluationResult> {
 80    let example = NamedExample::load(&example_path).unwrap();
 81    let predictions = zeta2_predict(
 82        example.clone(),
 83        prompt_format,
 84        use_expected_context,
 85        cache_mode,
 86        &app_state,
 87        cx,
 88    )
 89    .await
 90    .unwrap();
 91
 92    let evaluation_result = evaluate(&example.example, &predictions);
 93
 94    println!(
 95        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
 96        compare_diffs(&example.example.expected_patch, &predictions.diff)
 97    );
 98    println!(
 99        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
100        compare_diffs(&predictions.diff, &example.example.expected_patch)
101    );
102
103    println!("{}", evaluation_result.to_markdown());
104
105    anyhow::Ok(evaluation_result)
106}
107
/// Combined scores for one evaluated example (or, when aggregated in
/// `run_evaluate`, for a whole run).
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Line-level scores for the predicted diff vs. the expected patch.
    pub edit_prediction: Scores,
    /// Line-level scores for the retrieved context vs. the expected context.
    pub context: Scores,
}
113
/// Raw confusion counts for one scoring dimension, with derived
/// precision / recall / F1 accessors.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Score `actual` against `expected` by set overlap:
    /// shared items are true positives, extra actual items are false
    /// positives, missed expected items are false negatives.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Render the counts and derived metrics as a plain-text table.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Element-wise sum of all counts, so derived metrics over the result
    /// weight every item equally across the inputs.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut total, scores| {
            total.true_positives += scores.true_positives;
            total.false_positives += scores.false_positives;
            total.false_negatives += scores.false_negatives;
            total
        })
    }

    /// TP / (TP + FP); defined as 0.0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let predicted = self.true_positives + self.false_positives;
        match predicted {
            0 => 0.0,
            _ => self.true_positives as f64 / predicted as f64,
        }
    }

    /// TP / (TP + FN); defined as 0.0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        let expected = self.true_positives + self.false_negatives;
        match expected {
            0 => 0.0,
            _ => self.true_positives as f64 / expected as f64,
        }
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        match precision + recall {
            sum if sum == 0.0 => 0.0,
            sum => 2.0 * precision * recall / sum,
        }
    }
}
196
197impl EvaluationResult {
198    pub fn to_markdown(&self) -> String {
199        format!(
200            r#"
201### Context Scores
202{}
203
204### Edit Prediction Scores
205{}
206"#,
207            self.context.to_markdown(),
208            self.edit_prediction.to_markdown()
209        )
210    }
211}
212
213pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
214    let mut eval_result = EvaluationResult::default();
215
216    let actual_context_lines: HashSet<_> = preds
217        .excerpts
218        .iter()
219        .flat_map(|excerpt| {
220            excerpt
221                .text
222                .lines()
223                .map(|line| format!("{}: {line}", excerpt.path.display()))
224        })
225        .collect();
226
227    let mut false_positive_lines = actual_context_lines.clone();
228
229    for entry in &example.expected_context {
230        let mut best_alternative_score = Scores::default();
231
232        for alternative in &entry.alternatives {
233            let expected: HashSet<_> = alternative
234                .excerpts
235                .iter()
236                .flat_map(|excerpt| {
237                    excerpt
238                        .text
239                        .lines()
240                        .map(|line| format!("{}: {line}", excerpt.path.display()))
241                })
242                .collect();
243
244            let scores = Scores::new(&expected, &actual_context_lines);
245
246            false_positive_lines.retain(|line| !actual_context_lines.contains(line));
247
248            if scores.recall() > best_alternative_score.recall() {
249                best_alternative_score = scores;
250            }
251        }
252
253        eval_result.context.false_negatives += best_alternative_score.false_negatives;
254        eval_result.context.true_positives += best_alternative_score.true_positives;
255    }
256
257    eval_result.context.false_positives = false_positive_lines.len();
258
259    // todo: alternatives for patches
260    let expected_patch_lines = example
261        .expected_patch
262        .lines()
263        .map(DiffLine::parse)
264        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
265        .map(|line| line.to_string())
266        .collect();
267
268    let actual_patch_lines = preds
269        .diff
270        .lines()
271        .map(DiffLine::parse)
272        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
273        .map(|line| line.to_string())
274        .collect();
275
276    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
277    eval_result
278}
279
280/// Return annotated `patch_a` so that:
281/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
282/// Additions and deletions that are present in `patch_b` will be highlighted in green.
283pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
284    let use_color = std::io::stdout().is_terminal();
285    let green = if use_color { "\x1b[32m✓ " } else { "" };
286    let red = if use_color { "\x1b[31m✗ " } else { "" };
287    let neutral = if use_color { "  " } else { "" };
288    let reset = if use_color { "\x1b[0m" } else { "" };
289    let lines_a = patch_a.lines().map(DiffLine::parse);
290    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
291
292    let annotated = lines_a
293        .map(|line| match line {
294            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
295                if lines_b.contains(&line) {
296                    format!("{green}{line}{reset}")
297                } else {
298                    format!("{red}{line}{reset}")
299                }
300            }
301            _ => format!("{neutral}{line}{reset}"),
302        })
303        .collect::<Vec<String>>();
304
305    annotated.join("\n")
306}