1use std::{
2 fs,
3 io::IsTerminal,
4 path::{Path, PathBuf},
5 sync::Arc,
6};
7
8use anyhow::Result;
9use clap::Args;
10use cloud_llm_client::udiff::DiffLine;
11use collections::HashSet;
12use gpui::AsyncApp;
13
14use crate::{
15 example::{Example, NamedExample},
16 headless::ZetaCliAppState,
17 predict::{PredictionDetails, zeta2_predict},
18};
19
20#[derive(Debug, Args)]
21pub struct EvaluateArguments {
22 example_paths: Vec<PathBuf>,
23 #[clap(long)]
24 re_run: bool,
25}
26
27pub async fn run_evaluate(
28 args: EvaluateArguments,
29 app_state: &Arc<ZetaCliAppState>,
30 cx: &mut AsyncApp,
31) {
32 let example_len = args.example_paths.len();
33 let all_tasks = args.example_paths.into_iter().map(|path| {
34 let app_state = app_state.clone();
35 cx.spawn(async move |cx| run_evaluate_one(&path, args.re_run, app_state.clone(), cx).await)
36 });
37 let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
38
39 let aggregated_result = EvaluationResult {
40 context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
41 edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
42 };
43
44 if example_len > 1 {
45 println!("\n{}", "-".repeat(80));
46 println!("# TOTAL SCORES:");
47 println!("{}", aggregated_result.to_markdown());
48 }
49}
50
51pub async fn run_evaluate_one(
52 example_path: &Path,
53 re_run: bool,
54 app_state: Arc<ZetaCliAppState>,
55 cx: &mut AsyncApp,
56) -> Result<EvaluationResult> {
57 let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
58 .join("../../target/zeta-prediction-cache");
59 let example = NamedExample::load(&example_path).unwrap();
60 let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
61
62 let predictions = if !re_run && example_cache_path.exists() {
63 let file_contents = fs::read_to_string(&example_cache_path)?;
64 let as_json = serde_json::from_str::<PredictionDetails>(&file_contents)?;
65 log::debug!(
66 "Loaded predictions from cache: {}",
67 example_cache_path.display()
68 );
69 as_json
70 } else {
71 zeta2_predict(example.clone(), &app_state, cx)
72 .await
73 .unwrap()
74 };
75
76 if !example_cache_path.exists() {
77 fs::create_dir_all(&cache_dir).unwrap();
78 fs::write(
79 example_cache_path,
80 serde_json::to_string(&predictions).unwrap(),
81 )
82 .unwrap();
83 }
84
85 let evaluation_result = evaluate(&example.example, &predictions);
86
87 println!("# {}\n", example.name);
88 println!(
89 "## Expected Context: \n\n```\n{}\n```\n\n",
90 compare_context(&example.example, &predictions)
91 );
92 println!(
93 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
94 compare_diffs(&example.example.expected_patch, &predictions.diff)
95 );
96 println!(
97 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
98 compare_diffs(&predictions.diff, &example.example.expected_patch)
99 );
100
101 println!("{}", evaluation_result.to_markdown());
102
103 anyhow::Ok(evaluation_result)
104}
105
106#[derive(Debug, Default)]
107pub struct EvaluationResult {
108 pub context: Scores,
109 pub edit_prediction: Scores,
110}
111
/// Precision/recall statistics together with the raw counts they derive from.
#[derive(Debug, Default)]
pub struct Scores {
    /// `true_positives / (true_positives + false_positives)`.
    pub precision: f64,
    /// `true_positives / (true_positives + false_negatives)`.
    pub recall: f64,
    /// Harmonic mean of `precision` and `recall`.
    pub f1_score: f64,
    /// Number of expected items that were also produced.
    pub true_positives: usize,
    /// Number of produced items that were not expected.
    pub false_positives: usize,
    /// Number of expected items that were not produced.
    pub false_negatives: usize,
}
121
122impl Scores {
123 pub fn to_markdown(&self) -> String {
124 format!(
125 "
126Precision : {:.4}
127Recall : {:.4}
128F1 Score : {:.4}
129True Positives : {}
130False Positives : {}
131False Negatives : {}",
132 self.precision,
133 self.recall,
134 self.f1_score,
135 self.true_positives,
136 self.false_positives,
137 self.false_negatives
138 )
139 }
140}
141
142impl Scores {
143 pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
144 let mut true_positives = 0;
145 let mut false_positives = 0;
146 let mut false_negatives = 0;
147
148 for score in scores {
149 true_positives += score.true_positives;
150 false_positives += score.false_positives;
151 false_negatives += score.false_negatives;
152 }
153
154 let precision = true_positives as f64 / (true_positives + false_positives) as f64;
155 let recall = true_positives as f64 / (true_positives + false_negatives) as f64;
156 let mut f1_score = 2.0 * precision * recall / (precision + recall);
157 if f1_score.is_nan() {
158 f1_score = 0.0;
159 }
160
161 Scores {
162 precision,
163 recall,
164 f1_score,
165 true_positives,
166 false_positives,
167 false_negatives,
168 }
169 }
170}
171
172impl EvaluationResult {
173 pub fn to_markdown(&self) -> String {
174 format!(
175 r#"
176### Context Scores
177{}
178
179### Edit Prediction Scores
180{}
181"#,
182 self.context.to_markdown(),
183 self.edit_prediction.to_markdown()
184 )
185 }
186}
187
188pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
189 let mut result = EvaluationResult::default();
190
191 let expected_context_lines = example
192 .expected_excerpts
193 .iter()
194 .flat_map(|excerpt| {
195 excerpt
196 .text
197 .lines()
198 .map(|line| format!("{}: {line}", excerpt.path.display()))
199 })
200 .collect();
201 let actual_context_lines = preds
202 .excerpts
203 .iter()
204 .flat_map(|excerpt| {
205 excerpt
206 .text
207 .lines()
208 .map(|line| format!("{}: {line}", excerpt.path.display()))
209 })
210 .collect();
211
212 result.context = precision_recall(&expected_context_lines, &actual_context_lines);
213
214 let expected_patch_lines = example
215 .expected_patch
216 .lines()
217 .map(DiffLine::parse)
218 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
219 .map(|line| line.to_string())
220 .collect();
221
222 let actual_patch_lines = preds
223 .diff
224 .lines()
225 .map(DiffLine::parse)
226 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
227 .map(|line| line.to_string())
228 .collect();
229
230 result.edit_prediction = precision_recall(&expected_patch_lines, &actual_patch_lines);
231
232 result
233}
234
235fn precision_recall(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
236 let true_positives = expected.intersection(actual).count();
237 let false_positives = actual.difference(expected).count();
238 let false_negatives = expected.difference(actual).count();
239
240 let precision = if true_positives + false_positives == 0 {
241 0.0
242 } else {
243 true_positives as f64 / (true_positives + false_positives) as f64
244 };
245 let recall = if true_positives + false_negatives == 0 {
246 0.0
247 } else {
248 true_positives as f64 / (true_positives + false_negatives) as f64
249 };
250 let f1_score = if precision + recall == 0.0 {
251 0.0
252 } else {
253 2.0 * precision * recall / (precision + recall)
254 };
255
256 Scores {
257 precision,
258 recall,
259 f1_score,
260 true_positives,
261 false_positives,
262 false_negatives,
263 }
264}
265
266/// Compare actual and expected context.
267///
268/// Return expected context annotated with these markers:
269///
270/// `✓ context line` -- line was correctly predicted
271/// `✗ context line` -- line is missing from predictions
272pub fn compare_context(example: &Example, preds: &PredictionDetails) -> String {
273 let use_color = std::io::stdout().is_terminal();
274 let green = if use_color { "\x1b[32m" } else { "" };
275 let red = if use_color { "\x1b[31m" } else { "" };
276 let reset = if use_color { "\x1b[0m" } else { "" };
277 let expected: Vec<_> = example
278 .expected_excerpts
279 .iter()
280 .flat_map(|excerpt| {
281 excerpt
282 .text
283 .lines()
284 .map(|line| (excerpt.path.clone(), line))
285 })
286 .collect();
287 let actual: HashSet<_> = preds
288 .excerpts
289 .iter()
290 .flat_map(|excerpt| {
291 excerpt
292 .text
293 .lines()
294 .map(|line| (excerpt.path.clone(), line))
295 })
296 .collect();
297
298 let annotated = expected
299 .iter()
300 .map(|(path, line)| {
301 if actual.contains(&(path.to_path_buf(), line)) {
302 format!("{green}✓ {line}{reset}")
303 } else {
304 format!("{red}✗ {line}{reset}")
305 }
306 })
307 .collect::<Vec<String>>();
308
309 annotated.join("\n")
310}
311
312/// Return annotated `patch_a` so that:
313/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
314/// Additions and deletions that are present in `patch_b` will be highlighted in green.
315pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
316 let use_color = std::io::stdout().is_terminal();
317 let green = if use_color { "\x1b[32m✓ " } else { "" };
318 let red = if use_color { "\x1b[31m✗ " } else { "" };
319 let neutral = if use_color { " " } else { "" };
320 let reset = if use_color { "\x1b[0m" } else { "" };
321 let lines_a = patch_a.lines().map(DiffLine::parse);
322 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
323
324 let annotated = lines_a
325 .map(|line| match line {
326 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
327 if lines_b.contains(&line) {
328 format!("{green}{line}{reset}")
329 } else {
330 format!("{red}{line}{reset}")
331 }
332 }
333 _ => format!("{neutral}{line}{reset}"),
334 })
335 .collect::<Vec<String>>();
336
337 annotated.join("\n")
338}