evaluate.rs

use std::{
    fs,
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::{Context as _, Result};
use clap::Args;
use cloud_llm_client::udiff::DiffLine;
use collections::HashSet;
use gpui::AsyncApp;

use crate::{
    example::{Example, Excerpt, NamedExample},
    headless::ZetaCliAppState,
    predict::{PredictionDetails, zeta2_predict},
};
 19
 20#[derive(Debug, Args)]
 21pub struct EvaluateArguments {
 22    example_paths: Vec<PathBuf>,
 23    #[clap(long)]
 24    re_run: bool,
 25}
 26
 27pub async fn run_evaluate(
 28    args: EvaluateArguments,
 29    app_state: &Arc<ZetaCliAppState>,
 30    cx: &mut AsyncApp,
 31) {
 32    let example_len = args.example_paths.len();
 33    let all_tasks = args.example_paths.into_iter().map(|path| {
 34        let app_state = app_state.clone();
 35        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
 36    });
 37    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
 38
 39    let aggregated_result = EvaluationResult {
 40        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
 41        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
 42        edit_sites_coverage: all_results
 43            .iter()
 44            .map(|r| r.edit_sites_coverage)
 45            .sum::<f64>()
 46            / all_results.len() as f64,
 47    };
 48
 49    if example_len > 1 {
 50        println!("\n{}", "-".repeat(80));
 51        println!("# TOTAL SCORES:");
 52        println!("{}", aggregated_result.to_markdown());
 53    }
 54}
 55
 56pub async fn run_evaluate_one(
 57    example_path: &Path,
 58    re_run: bool,
 59    app_state: Arc<ZetaCliAppState>,
 60    cx: &mut AsyncApp,
 61) -> Result<EvaluationResult> {
 62    let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
 63        .join("../../target/zeta-prediction-cache");
 64    let example = NamedExample::load(&example_path).unwrap();
 65    let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
 66
 67    let predictions = if !re_run && example_cache_path.exists() {
 68        let file_contents = fs::read_to_string(&example_cache_path)?;
 69        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
 70        log::debug!(
 71            "Loaded predictions from cache: {}",
 72            example_cache_path.display()
 73        );
 74        as_json
 75    } else {
 76        zeta2_predict(example.clone(), &app_state, cx)
 77            .await
 78            .unwrap()
 79    };
 80
 81    if !example_cache_path.exists() {
 82        fs::create_dir_all(&cache_dir).unwrap();
 83        fs::write(
 84            example_cache_path,
 85            serde_json::to_string(&predictions).unwrap(),
 86        )
 87        .unwrap();
 88    }
 89
 90    let evaluation_result = evaluate(&example.example, &predictions);
 91
 92    println!("# {}\n", example.name);
 93    println!(
 94        "## Expected Context: \n\n```\n{}\n```\n\n",
 95        compare_context(&example.example, &predictions)
 96    );
 97    println!(
 98        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
 99        compare_diffs(&example.example.expected_patch, &predictions.diff)
100    );
101    println!(
102        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
103        compare_diffs(&predictions.diff, &example.example.expected_patch)
104    );
105
106    println!("{}", evaluation_result.to_markdown());
107
108    anyhow::Ok(evaluation_result)
109}
110
/// Scores for one example, or totals combined across several examples.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Precision/recall of the retrieved context lines against the expected
    /// excerpts. Each line is keyed as `path: text`.
    pub context: Scores,

    /// Ratio of edited lines that we expect to edit (as indicated in the
    /// expected patch) AND were included into the context
    /// num_correctly_retrieved_lines / num_expected_lines
    ///
    /// Only the deleted lines of the expected patch count as expected edit
    /// sites. When results are aggregated, this is the mean of the
    /// per-example ratios.
    pub edit_sites_coverage: f64,

    /// Precision/recall of the added and removed lines of the predicted diff
    /// against those of the expected patch.
    pub edit_prediction: Scores,
}
122
/// Precision/recall statistics comparing a set of expected items (lines)
/// against a set of actual items.
#[derive(Default, Debug)]
pub struct Scores {
    /// `true_positives / (true_positives + false_positives)`.
    pub precision: f64,
    /// `true_positives / (true_positives + false_negatives)`.
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall`.
    pub f1_score: f64,
    /// Number of items that are both expected and actual.
    pub true_positives: usize,
    /// Number of actual items that were not expected.
    pub false_positives: usize,
    /// Number of expected items missing from the actual set.
    pub false_negatives: usize,
}
132
133impl Scores {
134    pub fn to_markdown(&self) -> String {
135        format!(
136            "
137Precision          : {:.4}
138Recall             : {:.4}
139F1 Score           : {:.4}
140True Positives     : {}
141False Positives    : {}
142False Negatives    : {}",
143            self.precision,
144            self.recall,
145            self.f1_score,
146            self.true_positives,
147            self.false_positives,
148            self.false_negatives
149        )
150    }
151}
152
153impl Scores {
154    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
155        let mut true_positives = 0;
156        let mut false_positives = 0;
157        let mut false_negatives = 0;
158
159        for score in scores {
160            true_positives += score.true_positives;
161            false_positives += score.false_positives;
162            false_negatives += score.false_negatives;
163        }
164
165        let precision = true_positives as f64 / (true_positives + false_positives) as f64;
166        let recall = true_positives as f64 / (true_positives + false_negatives) as f64;
167        let mut f1_score = 2.0 * precision * recall / (precision + recall);
168        if f1_score.is_nan() {
169            f1_score = 0.0;
170        }
171
172        Scores {
173            precision,
174            recall,
175            f1_score,
176            true_positives,
177            false_positives,
178            false_negatives,
179        }
180    }
181}
182
/// Appears intended to hold the raw counts behind an edit-site coverage ratio
/// (inferred from the field names).
///
/// NOTE(review): nothing in this file constructs this type or reads its
/// fields; coverage is stored as a plain `f64` in
/// `EvaluationResult::edit_sites_coverage`. Consider removing it, or using it
/// in place of the bare ratio.
#[derive(Debug, Clone)]
struct EditSitesScores {
    num_edit_sites: u32,
    num_correctly_retrieved: u32,
}
188
189impl EvaluationResult {
190    pub fn to_markdown(&self) -> String {
191        format!(
192            r#"
193### Context Scores
194{}
195Edit sites coverage: {}
196
197### Edit Prediction Scores
198{}
199"#,
200            self.context.to_markdown(),
201            self.edit_sites_coverage,
202            self.edit_prediction.to_markdown()
203        )
204    }
205}
206
207pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
208    let mut result = EvaluationResult::default();
209
210    let expected_context_lines = example
211        .expected_excerpts
212        .iter()
213        .flat_map(|excerpt| {
214            excerpt
215                .text
216                .lines()
217                .map(|line| format!("{}: {line}", excerpt.path.display()))
218        })
219        .collect();
220    let actual_context_lines = preds
221        .excerpts
222        .iter()
223        .flat_map(|excerpt| {
224            excerpt
225                .text
226                .lines()
227                .map(|line| format!("{}: {line}", excerpt.path.display()))
228        })
229        .collect();
230
231    result.context = precision_recall(&expected_context_lines, &actual_context_lines);
232
233    let expected_patch_lines = example
234        .expected_patch
235        .lines()
236        .map(DiffLine::parse)
237        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
238        .map(|line| line.to_string())
239        .collect();
240
241    let actual_patch_lines = preds
242        .diff
243        .lines()
244        .map(DiffLine::parse)
245        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
246        .map(|line| line.to_string())
247        .collect();
248
249    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
250
251    result.edit_sites_coverage =
252        calculate_edit_sites_coverage(&example.expected_patch, &preds.excerpts);
253
254    result
255}
256
257/// Compute the ratio of lines that we expect to edit (are in the expected patch) that
258/// were included in the retrieved context
259/// `num_correctly_retrieved_lines / num_edited_lines_in_expected_patch`
260///
261/// In order to make an edit in some line, the model has to have an access to this line.
262/// If we don't include the line in the retrieved context, there's no chance to make an edit.
263///
264/// This metric reflects that, where 1.0 -- we retrieved all lines to be
265/// edited, and 0.0 -- we retrieved none of them.
266///
267/// Example:
268fn calculate_edit_sites_coverage(patch: &str, excerpts: &[Excerpt]) -> EditSitesScores {
269    // todo:
270    let expected_patch_lines = patch
271        .lines()
272        .map(DiffLine::parse)
273        .filter_map(|line| match line {
274            DiffLine::Deletion(text) => Some(text.trim().to_string()),
275            _ => None,
276        })
277        .collect::<Vec<_>>();
278
279    let correct_cases = expected_patch_lines
280        .iter()
281        .filter(|line| {
282            excerpts.iter().any(|excerpt| {
283                excerpt
284                    .text
285                    .lines()
286                    .any(|excerpt_line| excerpt_line == *line)
287            })
288        })
289        .count();
290    let total_cases = expected_patch_lines.len();
291
292    if total_cases == 0 {
293        0.0
294    } else {
295        correct_cases as f64 / total_cases as f64
296    }
297}
298
299fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
300    let true_positives = expected.intersection(actual).count();
301    let false_positives = actual.difference(expected).count();
302    let false_negatives = expected.difference(actual).count();
303
304    let precision = if true_positives + false_positives == 0 {
305        0.0
306    } else {
307        true_positives as f64 / (true_positives + false_positives) as f64
308    };
309    let recall = if true_positives + false_negatives == 0 {
310        0.0
311    } else {
312        true_positives as f64 / (true_positives + false_negatives) as f64
313    };
314    let f1_score = if precision + recall == 0.0 {
315        0.0
316    } else {
317        2.0 * precision * recall / (precision + recall)
318    };
319
320    Scores {
321        precision,
322        recall,
323        f1_score,
324        true_positives,
325        false_positives,
326        false_negatives,
327    }
328}
329
330/// Compare actual and expected context.
331///
332/// Return expected context annotated with these markers:
333///
334/// `✓ context line`  -- line was correctly predicted
335/// `✗ context line`  -- line is missing from predictions
336pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
337    let use_color = std::io::stdout().is_terminal();
338    let green = if use_color { "\x1b[32m" } else { "" };
339    let red = if use_color { "\x1b[31m" } else { "" };
340    let reset = if use_color { "\x1b[0m" } else { "" };
341    let expected: Vec<_> = example
342        .expected_excerpts
343        .iter()
344        .flat_map(|excerpt| {
345            excerpt
346                .text
347                .lines()
348                .map(|line| (excerpt.path.clone(), line))
349        })
350        .collect();
351    let actual: HashSet<_> = preds
352        .excerpts
353        .iter()
354        .flat_map(|excerpt| {
355            excerpt
356                .text
357                .lines()
358                .map(|line| (excerpt.path.clone(), line))
359        })
360        .collect();
361
362    let annotated = expected
363        .iter()
364        .map(|(path, line)| {
365            if actual.contains(&(path.to_path_buf(), line)) {
366                format!("{green}{line}{reset}")
367            } else {
368                format!("{red}{line}{reset}")
369            }
370        })
371        .collect::<Vec<String>>();
372
373    annotated.join("\n")
374}
375
376/// Return annotated `patch_a` so that:
377/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
378/// Additions and deletions that are present in `patch_b` will be highlighted in green.
379pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
380    let use_color = std::io::stdout().is_terminal();
381    let green = if use_color { "\x1b[32m✓ " } else { "" };
382    let red = if use_color { "\x1b[31m✗ " } else { "" };
383    let neutral = if use_color { "  " } else { "" };
384    let reset = if use_color { "\x1b[0m" } else { "" };
385    let lines_a = patch_a.lines().map(DiffLine::parse);
386    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
387
388    let annotated = lines_a
389        .map(|line| match line {
390            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
391                if lines_b.contains(&line) {
392                    format!("{green}{line}{reset}")
393                } else {
394                    format!("{red}{line}{reset}")
395                }
396            }
397            _ => format!("{neutral}{line}{reset}"),
398        })
399        .collect::<Vec<String>>();
400
401    annotated.join("\n")
402}
403
#[cfg(test)]
mod tests {
    use super::calculate_edit_sites_coverage;
    use crate::example::Excerpt;

    /// Builds an excerpt of `test.txt` containing `text`.
    fn excerpt(text: &str) -> Excerpt {
        Excerpt {
            path: "test.txt".into(),
            text: text.to_string(),
        }
    }

    #[test]
    fn test_evaluate_expected_edit_places() {
        // The patch deletes two lines: `banana` and `date`.
        let patch = indoc::indoc! {"
            --- a/test.txt
            +++ b/test.txt
            @@ -1,4 +1,4 @@
             apple
            -banana
            +BANANA
             cherry
            -date
            +DATE
            "};

        let cases = [
            // Only `banana` is retrieved.
            (vec![excerpt("apple\nbanana\n")], 0.5),
            // Both deleted lines are retrieved, from separate excerpts.
            (
                vec![excerpt("apple\nbanana\n"), excerpt("cherry\ndate\n")],
                1.0,
            ),
            // Neither deleted line is retrieved.
            (vec![excerpt("apple\n")], 0.0),
        ];

        for (index, (excerpts, expected)) in cases.into_iter().enumerate() {
            assert_eq!(
                calculate_edit_sites_coverage(patch, &excerpts),
                expected,
                "case {index}",
            );
        }
    }
}