1use std::{
2 io::IsTerminal,
3 path::{Path, PathBuf},
4 sync::Arc,
5};
6
7use anyhow::Result;
8use clap::Args;
9use collections::HashSet;
10use gpui::AsyncApp;
11use zeta2::udiff::DiffLine;
12
13use crate::{
14 PromptFormat,
15 example::{Example, NamedExample},
16 headless::ZetaCliAppState,
17 paths::print_run_data_dir,
18 predict::{CacheMode, PredictionDetails, zeta2_predict},
19};
20
// Arguments for the `evaluate` subcommand: score edit predictions for one or
// more example files.
//
// NOTE(review): plain `//` comments are used on purpose — `///` doc comments
// on clap derive fields become user-visible `--help` text and would change
// CLI output.
#[derive(Debug, Args)]
pub struct EvaluateArguments {
    // Positional paths of the example files to evaluate.
    example_paths: Vec<PathBuf>,
    // Prompt format forwarded to the prediction backend.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Presumably makes prediction use the example's expected context instead
    // of running retrieval — TODO confirm against `zeta2_predict`.
    #[arg(long)]
    use_expected_context: bool,
    // Prediction cache behavior (see `CacheMode`).
    #[clap(long, value_enum, default_value_t = CacheMode::default())]
    cache: CacheMode,
}
31
32pub async fn run_evaluate(
33 args: EvaluateArguments,
34 app_state: &Arc<ZetaCliAppState>,
35 cx: &mut AsyncApp,
36) {
37 let example_len = args.example_paths.len();
38 let all_tasks = args.example_paths.into_iter().map(|path| {
39 let app_state = app_state.clone();
40 cx.spawn(async move |cx| {
41 run_evaluate_one(
42 &path,
43 args.prompt_format,
44 args.use_expected_context,
45 args.cache,
46 app_state.clone(),
47 cx,
48 )
49 .await
50 })
51 });
52 let all_results = futures::future::try_join_all(all_tasks).await;
53
54 if let Ok(all_results) = &all_results {
55 let aggregated_result = EvaluationResult {
56 context: Scores::aggregate(all_results.iter().map(|r| &r.context)),
57 edit_prediction: Scores::aggregate(all_results.iter().map(|r| &r.edit_prediction)),
58 };
59
60 if example_len > 1 {
61 println!("\n{}", "-".repeat(80));
62 println!("\n## TOTAL SCORES");
63 println!("{}", aggregated_result.to_markdown());
64 }
65 }
66
67 print_run_data_dir();
68
69 all_results.unwrap();
70}
71
72pub async fn run_evaluate_one(
73 example_path: &Path,
74 prompt_format: PromptFormat,
75 use_expected_context: bool,
76 cache_mode: CacheMode,
77 app_state: Arc<ZetaCliAppState>,
78 cx: &mut AsyncApp,
79) -> Result<EvaluationResult> {
80 let example = NamedExample::load(&example_path).unwrap();
81 let predictions = zeta2_predict(
82 example.clone(),
83 prompt_format,
84 use_expected_context,
85 cache_mode,
86 &app_state,
87 cx,
88 )
89 .await
90 .unwrap();
91
92 let evaluation_result = evaluate(&example.example, &predictions);
93
94 println!(
95 "## Expected edit prediction:\n\n```diff\n{}\n```\n",
96 compare_diffs(&example.example.expected_patch, &predictions.diff)
97 );
98 println!(
99 "## Actual edit prediction:\n\n```diff\n{}\n```\n",
100 compare_diffs(&predictions.diff, &example.example.expected_patch)
101 );
102
103 println!("{}", evaluation_result.to_markdown());
104
105 anyhow::Ok(evaluation_result)
106}
107
/// Scores for one evaluation run, split by scoring dimension.
/// Populated by [`evaluate`] and rendered by `to_markdown`.
#[derive(Debug, Default)]
pub struct EvaluationResult {
    /// Set-comparison scores of predicted vs. expected patch lines.
    pub edit_prediction: Scores,
    /// Set-comparison scores of retrieved vs. expected context lines.
    pub context: Scores,
}
113
/// Precision/recall bookkeeping over a set comparison: counts of true
/// positives, false positives, and false negatives, with derived metrics.
#[derive(Default, Debug)]
pub struct Scores {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl Scores {
    /// Compares `actual` against `expected`: shared items are true positives,
    /// extra items in `actual` are false positives, and items missing from
    /// `actual` are false negatives.
    pub fn new(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
        Scores {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// Renders the raw counts and derived metrics as a plain-text block.
    pub fn to_markdown(&self) -> String {
        format!(
            "
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
            self.precision(),
            self.recall(),
            self.f1_score(),
            self.true_positives,
            self.false_positives,
            self.false_negatives
        )
    }

    /// Sums the raw counts across several score sets into a single total.
    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
        scores.fold(Scores::default(), |mut total, score| {
            total.true_positives += score.true_positives;
            total.false_positives += score.false_positives;
            total.false_negatives += score.false_negatives;
            total
        })
    }

    /// TP / (TP + FP); 0.0 when nothing was predicted positive.
    pub fn precision(&self) -> f64 {
        let denominator = self.true_positives + self.false_positives;
        match denominator {
            0 => 0.0,
            _ => self.true_positives as f64 / denominator as f64,
        }
    }

    /// TP / (TP + FN); 0.0 when nothing was expected positive.
    pub fn recall(&self) -> f64 {
        let denominator = self.true_positives + self.false_negatives;
        match denominator {
            0 => 0.0,
            _ => self.true_positives as f64 / denominator as f64,
        }
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    pub fn f1_score(&self) -> f64 {
        let precision = self.precision();
        let recall = self.recall();
        let sum = precision + recall;
        if sum == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / sum
        }
    }
}
196
197impl EvaluationResult {
198 pub fn to_markdown(&self) -> String {
199 format!(
200 r#"
201### Context Scores
202{}
203
204### Edit Prediction Scores
205{}
206"#,
207 self.context.to_markdown(),
208 self.edit_prediction.to_markdown()
209 )
210 }
211}
212
213pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
214 let mut eval_result = EvaluationResult::default();
215
216 let actual_context_lines: HashSet<_> = preds
217 .excerpts
218 .iter()
219 .flat_map(|excerpt| {
220 excerpt
221 .text
222 .lines()
223 .map(|line| format!("{}: {line}", excerpt.path.display()))
224 })
225 .collect();
226
227 let mut false_positive_lines = actual_context_lines.clone();
228
229 for entry in &example.expected_context {
230 let mut best_alternative_score = Scores::default();
231
232 for alternative in &entry.alternatives {
233 let expected: HashSet<_> = alternative
234 .excerpts
235 .iter()
236 .flat_map(|excerpt| {
237 excerpt
238 .text
239 .lines()
240 .map(|line| format!("{}: {line}", excerpt.path.display()))
241 })
242 .collect();
243
244 let scores = Scores::new(&expected, &actual_context_lines);
245
246 false_positive_lines.retain(|line| !actual_context_lines.contains(line));
247
248 if scores.recall() > best_alternative_score.recall() {
249 best_alternative_score = scores;
250 }
251 }
252
253 eval_result.context.false_negatives += best_alternative_score.false_negatives;
254 eval_result.context.true_positives += best_alternative_score.true_positives;
255 }
256
257 eval_result.context.false_positives = false_positive_lines.len();
258
259 // todo: alternatives for patches
260 let expected_patch_lines = example
261 .expected_patch
262 .lines()
263 .map(DiffLine::parse)
264 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
265 .map(|line| line.to_string())
266 .collect();
267
268 let actual_patch_lines = preds
269 .diff
270 .lines()
271 .map(DiffLine::parse)
272 .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
273 .map(|line| line.to_string())
274 .collect();
275
276 eval_result.edit_prediction = Scores::new(&expected_patch_lines, &actual_patch_lines);
277 eval_result
278}
279
280/// Return annotated `patch_a` so that:
281/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
282/// Additions and deletions that are present in `patch_b` will be highlighted in green.
283pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
284 let use_color = std::io::stdout().is_terminal();
285 let green = if use_color { "\x1b[32m✓ " } else { "" };
286 let red = if use_color { "\x1b[31m✗ " } else { "" };
287 let neutral = if use_color { " " } else { "" };
288 let reset = if use_color { "\x1b[0m" } else { "" };
289 let lines_a = patch_a.lines().map(DiffLine::parse);
290 let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
291
292 let annotated = lines_a
293 .map(|line| match line {
294 DiffLine::Addition(_) | DiffLine::Deletion(_) => {
295 if lines_b.contains(&line) {
296 format!("{green}{line}{reset}")
297 } else {
298 format!("{red}{line}{reset}")
299 }
300 }
301 _ => format!("{neutral}{line}{reset}"),
302 })
303 .collect::<Vec<String>>();
304
305 annotated.join("\n")
306}