1use std::{
2 io::IsTerminal,
3 path::{Path, PathBuf},
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::AsyncApp;
11use zeta2::udiff::DiffLine;
12
13use crate::{
14 PromptFormat,
15 example::{Example, NamedExample},
16 headless::ZetaCliAppState,
17 predict::{PredictionDetails, zeta2_predict},
18};
19
20#[derive(Debug, Args)]
21pub struct EvaluateArguments {
22 example_paths: Vec<PathBuf>,
23 #[clap(long)]
24 skip_cache: bool,
25 #[arg(long, value_enum, default_value_t = PromptFormat::default())]
26 prompt_format: PromptFormat,
27 #[arg(long)]
28 use_expected_context: bool,
29}
30
31pub async fn run_evaluate(
32 args: EvaluateArguments,
33 app_state: &Arc<ZetaCliAppState>,
34 cx: &mut AsyncApp,
35) {
36 let example_len = args.example_paths.len();
37 let all_tasks = args.example_paths.into_iter().map(|path| {
38 let app_state = app_state.clone();
39 cx.spawn(async move |cx| {
40 run_evaluate_one(
41 &path,
42 args.skip_cache,
43 args.prompt_format,
44 args.use_expected_context,
45 app_state.clone(),
46 cx,
47 )
48 .await
49 })
50 });
51 let all_results = futures::future::try_join_all(all_tasks).await.unwrap();
52
53 let aggregated_result = EvaluationResult {
54 context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
55 edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
56 };
57
58 if example_len > 1 {
59 println!("\n{}", "-".repeat(80));
60 println!("# TOTAL SCORES:");
61 println!("{}", aggregated_result.to_markdown());
62 }
63}
64
65pub async fn run_evaluate_one(
66 example_path: &Path,
67 skip_cache: bool,
68 prompt_format: PromptFormat,
69 use_expected_context: bool,
70 app_state: Arc<ZetaCliAppState>,
71 cx: &mut AsyncApp,
72) -> Result<EvaluationResult> {
73 let example = NamedExample::load(&example_path).unwrap();
74 let predictions = zeta2_predict(
75 example.clone(),
76 skip_cache,
77 prompt_format,
78 use_expected_context,
79 &app_state,
80 cx,
81 )
82 .await
83 .unwrap();
84
85 let evaluation_result = evaluate(&example.example, &predictions);
86
87 println!(
88 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
89 compare_diffs(&example.example.expected_patch, &predictions.diff)
90 );
91 println!(
92 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
93 compare_diffs(&predictions.diff, &example.example.expected_patch)
94 );
95
96 println!("{}", evaluation_result.to_markdown());
97
98 anyhow::Ok(evaluation_result)
99}
100
/// Scores for one evaluated example: how well the retrieved context matched
/// the expected context, and how well the predicted edit matched the expected
/// patch. Also used to hold the element-wise aggregate across examples.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Scores comparing predicted vs. expected patch lines.
    pub edit_prediction: Scores,
    /// Scores comparing retrieved vs. expected context lines.
    pub context: Scores,
}
106
/// Set-overlap tallies (true/false positives, false negatives) together with
/// the precision / recall / F1 metrics derived from them.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

/// Division that yields 0.0 for an empty denominator instead of NaN.
fn safe_ratio(numerator: usize, denominator: usize) -> f64 {
    if denominator == 0 {
        0.0
    } else {
        numerator as f64 / denominator as f64
    }
}

impl Scores {
    /// Tallies how `actual` overlaps `expected`.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the derived metrics and raw counts as a text block.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Element-wise sum over a sequence of score tallies.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut total, s| {
            total.true_positives += s.true_positives;
            total.false_positives += s.false_positives;
            total.false_negatives += s.false_negatives;
            total
        })
    }

    /// TP / (TP + FP); 0.0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        safe_ratio(
            self.true_positives,
            self.true_positives + self.false_positives,
        )
    }

    /// TP / (TP + FN); 0.0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        safe_ratio(
            self.true_positives,
            self.true_positives + self.false_negatives,
        )
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub fn f1_score(&self) -> f64 {
        let (precision, recall) = (self.precision(), self.recall());
        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }
}
189
190impl EvaluationResult {
191 pub fn to_markdown(&self) -> String {
192 format!(
193 r#"
194### Context Scores
195{}
196
197### Edit Prediction Scores
198{}
199"#,
200 self.context.to_markdown(),
201 self.edit_prediction.to_markdown()
202 )
203 }
204}
205
206pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
207 let mut eval_result = EvaluationResult::default();
208
209 let actual_context_lines: HashSet<_> = preds
210 .excerpts
211 .iter()
212 .flat_map(|excerpt| {
213 excerpt
214 .text
215 .lines()
216 .map(|line| format!("{}: {line}", excerpt.path.display()))
217 })
218 .collect();
219
220 let mut false_positive_lines = actual_context_lines.clone();
221
222 for entry in &example.expected_context {
223 let mut best_alternative_score = Scores::default();
224
225 for alternative in &entry.alternatives {
226 let expected: HashSet<_> = alternative
227 .excerpts
228 .iter()
229 .flat_map(|excerpt| {
230 excerpt
231 .text
232 .lines()
233 .map(|line| format!("{}: {line}", excerpt.path.display()))
234 })
235 .collect();
236
237 let scores = Scores::new(&expected, &actual_context_lines);
238
239 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
240
241 if scores.recall() > best_alternative_score.recall() {
242 best_alternative_score = scores;
243 }
244 }
245
246 eval_result.context.false_negatives += best_alternative_score.false_negatives;
247 eval_result.context.true_positives += best_alternative_score.true_positives;
248 }
249
250 eval_result.context.false_positives = false_positive_lines.len();
251
252 // todo: alternatives for patches
253 let expected_patch_lines = example
254 .expected_patch
255 .lines()
256 .map(DiffLine::parse)
257 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
258 .map(|line| line.to_string())
259 .collect();
260
261 let actual_patch_lines = preds
262 .diff
263 .lines()
264 .map(DiffLine::parse)
265 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
266 .map(|line| line.to_string())
267 .collect();
268
269 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
270 eval_result
271}
272
273/// Return annotated `patch_a` so that:
274/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
275/// Additions and deletions that are present in `patch_b` will be highlighted in green.
276pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
277 let use_color = std::io::stdout().is_terminal();
278 let green = if use_color { "\x1b[32m✓ " } else { "" };
279 let red = if use_color { "\x1b[31m✗ " } else { "" };
280 let neutral = if use_color { " " } else { "" };
281 let reset = if use_color { "\x1b[0m" } else { "" };
282 let lines_a = patch_a.lines().map(DiffLine::parse);
283 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
284
285 let annotated = lines_a
286 .map(|line| match line {
287 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
288 if lines_b.contains(&line) {
289 format!("{green}{line}{reset}")
290 } else {
291 format!("{red}{line}{reset}")
292 }
293 }
294 _ => format!("{neutral}{line}{reset}"),
295 })
296 .collect::<Vec<String>>();
297
298 annotated.join("\n")
299}