use std::{
    io::IsTerminal,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::Result;
use clap::Args;
use collections::HashSet;
use gpui::AsyncApp;
use zeta2::udiff::DiffLine;

use crate::{
    PromptFormat,
    example::{Example, NamedExample},
    headless::ZetaCliAppState,
    predict::{PredictionDetails, zeta2_predict},
};

#[derive(Debug, Args)]
pub struct EvaluateArguments {
    example_paths: Vec<PathBuf>,
    #[arg(long)]
    skip_cache: bool,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
}

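/// Evaluates every example in `args.example_paths` concurrently and, when more
/// than one example was given, prints aggregate scores across all of them.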
pub async fn run_evaluate(
    args: EvaluateArguments,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) {
    let example_len = args.example_paths.len();
    let all_tasks = args.example_paths.into_iter().map(|path| {
        let app_state = app_state.clone();
        cx.spawn(async move |cx| {
            run_evaluate_one(&path, args.skip_cache, args.prompt_format, app_state, cx).await
        })
    });
    let all_results = futures::future::try_join_all(all_tasks).await.unwrap();

    let aggregated_result = EvaluationResult {
        context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
        edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
    };

    if example_len > 1 {
        println!("\n{}", "-".repeat(80));
        println!("# TOTAL SCORES:");
        println!("{}", aggregated_result.to_markdown());
    }
}

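/// Loads a single example, runs prediction on it, prints the expected and
/// actual diffs annotated against each other, and returns the resulting scores.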
pub async fn run_evaluate_one(
    example_path: &Path,
    skip_cache: bool,
    prompt_format: PromptFormat,
    app_state: Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
    let example = NamedExample::load(example_path)?;
    let predictions =
        zeta2_predict(example.clone(), skip_cache, prompt_format, &app_state, cx).await?;

    let evaluation_result = evaluate(&example.example, &predictions);

    println!(
        "## Expected edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&example.example.expected_patch, &predictions.diff)
    );
    println!(
        "## Actual edit prediction:\n\n```diff\n{}\n```\n",
        compare_diffs(&predictions.diff, &example.example.expected_patch)
    );

    println!("{}", evaluation_result.to_markdown());

    Ok(evaluation_result)
}

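/// Combined scores for a single evaluation run: how well the retrieved context
/// matched the expected context, and how well the predicted edits matched the
/// expected patch.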
#[derive(Debug, Default)]
pub struct EvaluationResult {
    pub edit_prediction: Scores,
    pub context: Scores,
}

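/// Classification counts over sets of lines, from which precision, recall, and
/// F1 are derived.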
#[derive(Debug, Default)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
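    /// Builds scores by set comparison: lines in both sets are true positives,
    /// lines only in `actual` are false positives, and lines only in
    /// `expected` are false negatives.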
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        let true_positives = expected.intersection(actual).count();
        let false_positives = actual.difference(expected).count();
        let false_negatives = expected.difference(actual).count();

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision       : {:.4}
Recall          : {:.4}
F1 Score        : {:.4}
True Positives  : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

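    /// Sums raw counts across runs (micro-averaging), so precision and recall
    /// on the aggregate weight every line equally rather than every example
    /// equally.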
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        let mut true_positives = 0;
        let mut false_positives = 0;
        let mut false_negatives = 0;

        for score in scores {
            true_positives += score.true_positives;
            false_positives += score.false_positives;
            false_negatives += score.false_negatives;
        }

        Scores {
            true_positives,
            false_positives,
            false_negatives,
        }
    }

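    /// Precision = TP / (TP + FP); defined as 0.0 when nothing was predicted.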
    pub fn precision(&self) -> f64 {
        if self.true_positives + self.false_positives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_positives) as f64
        }
    }

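    /// Recall = TP / (TP + FN); defined as 0.0 when nothing was expected.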
    pub fn recall(&self) -> f64 {
        if self.true_positives + self.false_negatives == 0 {
            0.0
        } else {
            self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
        }
    }

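    /// F1 = 2 * precision * recall / (precision + recall), the harmonic mean;
    /// defined as 0.0 when both precision and recall are zero.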
    pub fn f1_score(&self) -> f64 {
        let recall = self.recall();
        let precision = self.precision();
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}

impl EvaluationResult {
    pub fn to_markdown(&self) -> String {
        format!(
            r#"
### Context Scores
{}

### Edit Prediction Scores
{}
"#,
            self.context.to_markdown(),
            self.edit_prediction.to_markdown()
        )
    }
}

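/// Scores `preds` against `example`. Context is scored line-by-line against the
/// best-matching alternative (by recall) for each expected-context entry; edit
/// predictions are scored on the added and deleted lines of the unified diffs.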
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
    let mut eval_result = EvaluationResult::default();

    let actual_context_lines: HashSet<_> = preds
        .excerpts
        .iter()
        .flat_map(|excerpt| {
            excerpt
                .text
                .lines()
                .map(|line| format!("{}: {line}", excerpt.path.display()))
        })
        .collect();

    let mut false_positive_lines = actual_context_lines.clone();

    for entry in &example.expected_context {
        let mut best_alternative_score = Scores::default();

        for alternative in &entry.alternatives {
            let expected: HashSet<_> = alternative
                .excerpts
                .iter()
                .flat_map(|excerpt| {
                    excerpt
                        .text
                        .lines()
                        .map(|line| format!("{}: {line}", excerpt.path.display()))
                })
                .collect();

            let scores = Scores::new(&expected, &actual_context_lines);

            // Any actual line matched by this alternative is not a false positive.
            false_positive_lines.retain(|line| !expected.contains(line));

            if scores.recall() > best_alternative_score.recall() {
                best_alternative_score = scores;
            }
        }

        eval_result.context.false_negatives += best_alternative_score.false_negatives;
        eval_result.context.true_positives += best_alternative_score.true_positives;
    }

    eval_result.context.false_positives = false_positive_lines.len();

    // TODO: alternatives for patches
    let expected_patch_lines = example
        .expected_patch
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    let actual_patch_lines = preds
        .diff
        .lines()
        .map(DiffLine::parse)
        .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
        .map(|line| line.to_string())
        .collect();

    eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
    eval_result
}

/// Returns `patch_a` annotated against `patch_b`:
/// additions and deletions that are also present in `patch_b` are highlighted
/// in green, while those not present in `patch_b` are highlighted in red.
pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
    let use_color = std::io::stdout().is_terminal();
    let green = if use_color { "\x1b[32m✓ " } else { "" };
    let red = if use_color { "\x1b[31m✗ " } else { "" };
    let neutral = if use_color { "  " } else { "" };
    let reset = if use_color { "\x1b[0m" } else { "" };
    let lines_a = patch_a.lines().map(DiffLine::parse);
    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();

    let annotated = lines_a
        .map(|line| match line {
            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
                if lines_b.contains(&line) {
                    format!("{green}{line}{reset}")
                } else {
                    format!("{red}{line}{reset}")
                }
            }
            _ => format!("{neutral}{line}{reset}"),
        })
        .collect::<Vec<String>>();

    annotated.join("\n")
}
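
// A minimal sanity check for the score math, added here as an illustrative
// sketch rather than part of the original evaluation harness. It assumes
// `collections::HashSet` matches std's `HashSet` API (in this workspace it is
// a re-export).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scores_precision_recall_f1() {
        let expected: HashSet<String> = ["a", "b", "c"].into_iter().map(String::from).collect();
        let actual: HashSet<String> = ["b", "c", "d"].into_iter().map(String::from).collect();

        // TP = {b, c}, FP = {d}, FN = {a}.
        let scores = Scores::new(&expected, &actual);
        assert_eq!(scores.true_positives, 2);
        assert_eq!(scores.false_positives, 1);
        assert_eq!(scores.false_negatives, 1);
        assert!((scores.precision() - 2.0 / 3.0).abs() < 1e-9);
        assert!((scores.recall() - 2.0 / 3.0).abs() < 1e-9);
        // Precision == recall here, so F1 (their harmonic mean) equals both.
        assert!((scores.f1_score() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn aggregate_sums_raw_counts() {
        let a = Scores { true_positives: 1, false_positives: 0, false_negatives: 2 };
        let b = Scores { true_positives: 3, false_positives: 1, false_negatives: 0 };
        let total = Scores::aggregate([&a, &b].into_iter());
        assert_eq!(total.true_positives, 4);
        assert_eq!(total.false_positives, 1);
        assert_eq!(total.false_negatives, 2);
    }
}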