1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application};
25use zeta_prompt::ZetaVersion;
26
27use reqwest_client::ReqwestClient;
28use serde::{Deserialize, Serialize};
29use std::fmt::Display;
30use std::fs::{File, OpenOptions};
31use std::hash::{Hash, Hasher};
32use std::io::{BufRead, BufReader, BufWriter, Write};
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
34
35use crate::distill::run_distill;
36use crate::example::{Example, group_examples_by_repo, read_example_files};
37use crate::format_prompt::run_format_prompt;
38use crate::load_project::run_load_project;
39use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
40use crate::predict::run_prediction;
41use crate::progress::Progress;
42use crate::retrieve_context::run_context_retrieval;
43use crate::score::run_scoring;
44use crate::split_commit::SplitCommitArgs;
45use crate::split_dataset::SplitArgs;
46use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
// Top-level CLI for the `ep` edit-prediction tool. Most flags are `global`
// so they may appear before or after the subcommand.
// NOTE: plain `//` comments are used on fields because `///` doc comments on
// clap-derive fields become help text and would change `--help` output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (diagnostic aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of repository groups processed concurrently
    // (used as the chunk size for `chunks_mut` in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on the total number of examples to process (see `load_examples`).
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // File paths or special specifiers such as `captured-after:{timestamp}`;
    // see INPUTS_HELP for the full syntax.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; results are appended as JSONL. When omitted, results go to
    // stdout (except for `eval`, which prints a report).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed from the global `--failed` flag on `EpArgs` (default: keep).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
/// Long-form help text attached to the positional `inputs` argument of
/// `EpArgs` (shown by `ep --help`). The string itself is user-facing output
/// and must not be edited casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of `ep`. Each variant is rendered back into a runnable
// command-line string by the `Display` impl below (used for "Re-run" hints
// when an example fails), so variant/flag naming must stay in sync with it.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(format_prompt_args) => write!(
156 f,
157 "format-prompt --prompt-format={}",
158 format_prompt_args
159 .provider
160 .to_possible_value()
161 .unwrap()
162 .get_name()
163 ),
164 Command::Predict(predict_args) => {
165 write!(
166 f,
167 "predict --provider={:?}",
168 predict_args
169 .provider
170 .to_possible_value()
171 .unwrap()
172 .get_name()
173 )
174 }
175 Command::Score(predict_args) => {
176 write!(
177 f,
178 "score --provider={:?}",
179 predict_args
180 .provider
181 .to_possible_value()
182 .unwrap()
183 .get_name()
184 )
185 }
186 Command::Distill => write!(f, "distill"),
187 Command::Eval(predict_args) => write!(
188 f,
189 "eval --provider={:?}",
190 predict_args
191 .provider
192 .to_possible_value()
193 .unwrap()
194 .get_name()
195 ),
196 Command::Synthesize(args) => {
197 write!(f, "synthesize --repo={}", args.repo)
198 }
199 Command::Clean => write!(f, "clean"),
200 Command::SplitCommit(_) => write!(f, "split-commit"),
201 Command::Split(_) => write!(f, "split"),
202 }
203 }
204}
205
// Flags for `ep format-prompt`.
// NOTE: `//` comments are used so clap help output is unchanged.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which model/provider's prompt format to generate.
    #[clap(long, short)]
    provider: PredictionProvider,
    // Defaults to `ZetaVersion::default()`, so the flag is optional here.
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}
219
// Flags shared by `ep predict`, `ep score`, and `ep eval`.
// NOTE: `//` comments are used so clap help output is unchanged.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use.
    #[clap(long, short)]
    provider: PredictionProvider,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    // NOTE(review): unlike `FormatPromptArgs`, there is no `default_value_t`
    // here, which makes `--version` required even for non-zeta2 providers —
    // confirm this is intentional.
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
    )]
    version: ZetaVersion,
}
234
// Prediction backends selectable via `--provider`. The Serialize/Deserialize
// derives suggest the value is persisted or sent over the wire somewhere
// outside this file — usage not visible here.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
244
// Flags for `ep synthesize`; copied into a `SynthesizeConfig` in `main`
// before calling `run_synthesize`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
263
264impl EpArgs {
265 fn output_path(&self) -> Option<PathBuf> {
266 if self.in_place {
267 if self.inputs.len() == 1 {
268 self.inputs.first().cloned()
269 } else {
270 panic!("--in-place requires exactly one input file")
271 }
272 } else {
273 self.output.clone()
274 }
275 }
276}
277
/// Loads the run's examples from the CLI inputs: plain example files plus
/// `captured-after:{timestamp}` specifiers fetched from Snowflake. Applies
/// the `--name`/`--repo` substring filters and `--limit`, then drops any
/// examples already recorded in the output file (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
) -> anyhow::Result<Vec<Example>> {
    // Split inputs into captured-after timestamps and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress shows early; set again after fetching,
    // filtering, and resume pruning below.
    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake for whatever --limit leaves uncovered by the
    // file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name/--repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples already present in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    // Final total after all pruning.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
343
344fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
345 let mut hasher = collections::FxHasher::default();
346 spec.hash(&mut hasher);
347 hasher.finish()
348}
349
350fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
351 let file = match File::open(path) {
352 Ok(f) => f,
353 Err(_) => return,
354 };
355
356 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
357
358 let reader = BufReader::new(file);
359 let mut kept_lines = Vec::new();
360 let mut kept_hashes = HashSet::default();
361
362 for line in reader.lines() {
363 let line = match line {
364 Ok(l) => l,
365 Err(_) => continue,
366 };
367
368 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
369 let hash = spec_hash(&output_example.spec);
370 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
371 kept_hashes.insert(hash);
372 kept_lines.push(line);
373 }
374 }
375 }
376
377 let total = examples.len();
378 let already_processed = kept_hashes.len();
379
380 eprintln!(
381 "Resuming: {}/{} examples already processed",
382 already_processed, total
383 );
384
385 let file = OpenOptions::new()
386 .write(true)
387 .truncate(true)
388 .open(path)
389 .expect("Failed to open output file for rewriting");
390 let mut writer = BufWriter::new(file);
391 for line in &kept_lines {
392 writeln!(writer, "{}", line).expect("Failed to write to output file");
393 }
394 writer.flush().expect("Failed to flush output file");
395
396 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
397}
398
/// CLI entry point: parses arguments, dispatches the simple synchronous
/// subcommands directly, and runs the per-example pipeline inside a headless
/// GPUI application for the remaining commands.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the example pipeline are
    // handled synchronously here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless app with an HTTP client.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples =
                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;

                // Sync any pending provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and receives serialized examples over an unbounded channel,
                // so the parallel workers below never touch the file directly.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            // Append mode: resume_from_output may have left
                            // already-processed lines in the file.
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples from the same repository run sequentially within a
                // group (presumably because they share per-repo state such as
                // a worktree — TODO confirm); up to --max-parallelism repo
                // groups run concurrently per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the per-example stage for the current command.
                            let result = async {
                                match &command {
                                    // No per-example work: parsing happened in
                                    // load_examples; output is written below.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously above, before app.run.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; handle_error panics under
                            // --failfast (or when there is a single example).
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Failed examples are only emitted with --failed=keep.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file: stream JSONL to stdout
                                    // (except `eval`, which prints a report).
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush provider batches created during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` aggregates scores across all processed examples.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Processing finished: quit the headless app's event loop.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
622
/// Records a failed example: bumps the failure counter, writes the example
/// and its error to the run's failed-examples directory and to the run-level
/// `failed.jsonl`, then either panics (failfast mode) or logs the error and
/// lets processing continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the full example as pretty JSON so it can be re-run later
    // (the "Re-run" hint below points at this file).
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Store the debug-formatted error next to the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Also append the example to a run-level JSONL of all failures.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` will panic for an example without
    // a resolvable repo name, turning an example failure into a crash of the
    // error handler itself — confirm this invariant always holds here.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // Human-readable summary with ANSI colors; `command` renders via its
    // Display impl so the "Re-run" line can be pasted into a shell.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        // In failfast mode, finalize progress output before panicking so the
        // message is not interleaved with the progress display.
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}