evaluate.rs

  1use std::{
  2    fs,
  3    io::IsTerminal,
  4    path::{Path, PathBuf},
  5    sync::Arc,
  6};
  7
  8use anyhow::Result;
  9use clap::Args;
 10use cloud_llm_client::udiff::DiffLine;
 11use collections::HashSet;
 12use gpui::AsyncApp;
 13
 14use crate::{
 15    example::{Example, NamedExample},
 16    headless::ZetaCliAppState,
 17    predict::{PredictionDetails, zeta2_predict},
 18};
 19
// CLI arguments for the `evaluate` subcommand.
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// clap-derived fields would become user-visible help text.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Example files to evaluate; each produces its own printed report.
    example_paths: Vec<PathBuf>,
    // When set, ignore any cached prediction on disk and query the model again.
    #[clap(long)]
    re_run: bool,
}
 26
 27pub async fn run_evaluate(
 28    args: EvaluateArguments,
 29    app_state: &Arc<ZetaCliAppState>,
 30    cx: &mut AsyncApp,
 31) {
 32    let example_len = args.example_paths.len();
 33    let all_tasks = args.example_paths.into_iter().map(|path| {
 34        let app_state = app_state.clone();
 35        cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
 36    });
 37    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
 38
 39    let aggregated_result = EvaluationResult {
 40        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
 41        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
 42    };
 43
 44    if example_len > 1 {
 45        println!("\n{}", "-".repeat(80));
 46        println!("# TOTAL SCORES:");
 47        println!("{}", aggregated_result.to_markdown());
 48    }
 49}
 50
 51pub async fn run_evaluate_one(
 52    example_path: &Path,
 53    re_run: bool,
 54    app_state: Arc<ZetaCliAppState>,
 55    cx: &mut AsyncApp,
 56) -> Result<EvaluationResult> {
 57    let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
 58        .join("../../target/zeta-prediction-cache");
 59    let example = NamedExample::load(&example_path).unwrap();
 60    let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
 61
 62    let predictions = if !re_run && example_cache_path.exists() {
 63        let file_contents = fs::read_to_string(&example_cache_path)?;
 64        let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
 65        log::debug!(
 66            "Loaded predictions from cache: {}",
 67            example_cache_path.display()
 68        );
 69        as_json
 70    } else {
 71        zeta2_predict(example.clone(), &app_state, cx)
 72            .await
 73            .unwrap()
 74    };
 75
 76    if !example_cache_path.exists() {
 77        fs::create_dir_all(&cache_dir).unwrap();
 78        fs::write(
 79            example_cache_path,
 80            serde_json::to_string(&predictions).unwrap(),
 81        )
 82        .unwrap();
 83    }
 84
 85    let evaluation_result = evaluate(&example.example, &predictions);
 86
 87    println!("# {}\n", example.name);
 88    println!(
 89        "## Expected Context: \n\n```\n{}\n```\n\n",
 90        compare_context(&example.example, &predictions)
 91    );
 92    println!(
 93        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
 94        compare_diffs(&example.example.expected_patch, &predictions.diff)
 95    );
 96    println!(
 97        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
 98        compare_diffs(&predictions.diff, &example.example.expected_patch)
 99    );
100
101    println!("{}", evaluation_result.to_markdown());
102
103    anyhow::Ok(evaluation_result)
104}
105
/// Scores for one evaluated example, or an aggregate over several examples.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// How well the predicted context excerpts match the expected excerpts.
    pub context: Scores,
    /// How well the predicted diff lines match the expected patch.
    pub edit_prediction: Scores,
}
111
/// Precision/recall-style metrics computed from line-level set overlap.
///
/// The raw counts (`true_positives` etc.) are the ground truth; the float
/// fields are derived from them.
#[derive(Default, Debug)]
pub struct Scores {
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Render the scores as an aligned plain-text table (markdown-safe).
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision,
            self.recall,
            self.f1_score,
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sum the raw counts of several `Scores` and recompute the derived
    /// metrics from the totals (micro-averaging).
    ///
    /// Every derived metric falls back to `0.0` when its denominator is
    /// zero, so an empty iterator (or all-zero counts) yields all-zero
    /// scores instead of `NaN` — consistent with `precision_recall`.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        // Guard each division: the previous version produced NaN precision
        // and recall for empty input and only patched up the f1 score.
        let precision = if true_positives + false_positives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_positives) as f64
        };
        let recall = if true_positives + false_negatives == 0 {
            0.0
        } else {
            true_positives as f64 / (true_positives + false_negatives) as f64
        };
        let f1_score = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Scores {
            precision,
            recall,
            f1_score,
            true_positives,
            false_positives,
            false_negatives,
        }
    }
}
171
172impl EvaluationResult {
173    pub fn to_markdown(&self) -> String {
174        format!(
175            r#"
176### Context Scores
177{}
178
179### Edit Prediction Scores
180{}
181"#,
182            self.context.to_markdown(),
183            self.edit_prediction.to_markdown()
184        )
185    }
186}
187
188pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
189    let mut result = EvaluationResult::default();
190
191    let expected_context_lines = example
192        .expected_excerpts
193        .iter()
194        .flat_map(|excerpt| {
195            excerpt
196                .text
197                .lines()
198                .map(|line| format!("{}: {line}", excerpt.path.display()))
199        })
200        .collect();
201    let actual_context_lines = preds
202        .excerpts
203        .iter()
204        .flat_map(|excerpt| {
205            excerpt
206                .text
207                .lines()
208                .map(|line| format!("{}: {line}", excerpt.path.display()))
209        })
210        .collect();
211
212    result.context = precision_recall(&expected_context_lines, &actual_context_lines);
213
214    let expected_patch_lines = example
215        .expected_patch
216        .lines()
217        .map(DiffLine::parse)
218        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
219        .map(|line| line.to_string())
220        .collect();
221
222    let actual_patch_lines = preds
223        .diff
224        .lines()
225        .map(DiffLine::parse)
226        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
227        .map(|line| line.to_string())
228        .collect();
229
230    result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
231
232    result
233}
234
235fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
236    let true_positives = expected.intersection(actual).count();
237    let false_positives = actual.difference(expected).count();
238    let false_negatives = expected.difference(actual).count();
239
240    let precision = if true_positives + false_positives == 0 {
241        0.0
242    } else {
243        true_positives as f64 / (true_positives + false_positives) as f64
244    };
245    let recall = if true_positives + false_negatives == 0 {
246        0.0
247    } else {
248        true_positives as f64 / (true_positives + false_negatives) as f64
249    };
250    let f1_score = if precision + recall == 0.0 {
251        0.0
252    } else {
253        2.0 * precision * recall / (precision + recall)
254    };
255
256    Scores {
257        precision,
258        recall,
259        f1_score,
260        true_positives,
261        false_positives,
262        false_negatives,
263    }
264}
265
266/// Compare actual and expected context.
267///
268/// Return expected context annotated with these markers:
269///
270/// `✓ context line`  -- line was correctly predicted
271/// `✗ context line`  -- line is missing from predictions
272pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
273    let use_color = std::io::stdout().is_terminal();
274    let green = if use_color { "\x1b[32m" } else { "" };
275    let red = if use_color { "\x1b[31m" } else { "" };
276    let reset = if use_color { "\x1b[0m" } else { "" };
277    let expected: Vec<_> = example
278        .expected_excerpts
279        .iter()
280        .flat_map(|excerpt| {
281            excerpt
282                .text
283                .lines()
284                .map(|line| (excerpt.path.clone(), line))
285        })
286        .collect();
287    let actual: HashSet<_> = preds
288        .excerpts
289        .iter()
290        .flat_map(|excerpt| {
291            excerpt
292                .text
293                .lines()
294                .map(|line| (excerpt.path.clone(), line))
295        })
296        .collect();
297
298    let annotated = expected
299        .iter()
300        .map(|(path, line)| {
301            if actual.contains(&(path.to_path_buf(), line)) {
302                format!("{green}{line}{reset}")
303            } else {
304                format!("{red}{line}{reset}")
305            }
306        })
307        .collect::<Vec<String>>();
308
309    annotated.join("\n")
310}
311
312/// Return annotated `patch_a` so that:
313/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
314/// Additions and deletions that are present in `patch_b` will be highlighted in green.
315pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
316    let use_color = std::io::stdout().is_terminal();
317    let green = if use_color { "\x1b[32m✓ " } else { "" };
318    let red = if use_color { "\x1b[31m✗ " } else { "" };
319    let neutral = if use_color { "  " } else { "" };
320    let reset = if use_color { "\x1b[0m" } else { "" };
321    let lines_a = patch_a.lines().map(DiffLine::parse);
322    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
323
324    let annotated = lines_a
325        .map(|line| match line {
326            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
327                if lines_b.contains(&line) {
328                    format!("{green}{line}{reset}")
329                } else {
330                    format!("{red}{line}{reset}")
331                }
332            }
333            _ => format!("{neutral}{line}{reset}"),
334        })
335        .collect::<Vec<String>>();
336
337    annotated.join("\n")
338}