mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}

/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to a file containing one or more examples (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.

        You can specify this multiple times and mix it with file inputs.

        Required environment variables to connect to Snowflake:
            EP_SNOWFLAKE_API_KEY
            EP_SNOWFLAKE_BASE_URL

        Optional:
            EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;

#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(format_prompt_args) => write!(
                f,
                "format-prompt --provider={}",
                format_prompt_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Predict(predict_args) => write!(
                f,
                "predict --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Score(predict_args) => write!(
                f,
                "score --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Distill => write!(f, "distill"),
            Command::Eval(predict_args) => write!(
                f,
                "eval --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Synthesize(args) => write!(f, "synthesize --repo={}", args.repo),
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short)]
    provider: PredictionProvider,
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short)]
    provider: PredictionProvider,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}

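/// Backend used to produce edit predictions.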
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

impl EpArgs {
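    /// The effective output path: with `--in-place`, the single input file (panicking if
    /// more than one input was given); otherwise the value of `--output`, if any.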
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

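/// Collect examples from file inputs and `captured-after:` specifiers, sort them by repo and
/// revision, apply the `--name`/`--repo` filters and `--limit`, and drop any examples that the
/// output file shows were already processed (see `resume_from_output`).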
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}

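/// Hash of an example's spec, used to match examples in the current input set against lines
/// already written to the output file.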
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}

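/// Resume support: keep only the lines of an existing output file whose spec hash matches one
/// of the current inputs (deduplicated), rewrite the file with just those lines, and remove the
/// corresponding examples from `examples` so they are not processed again.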
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

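// Commands that don't need per-example processing (clean, synthesize, split-commit, split)
// are dispatched synchronously below; everything else shares the load -> process -> write
// pipeline that runs inside the headless app further down.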
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

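    // The remaining commands operate on loaded examples, so they run inside a headless GPUI
    // application with an HTTP client (used to fetch captured examples, among other things).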
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

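                // Sync any pending provider batch state before the run; the same sync happens
                // again once all examples have been processed.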
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

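                // Output lines are funneled through a channel to a single background writer
                // task so concurrently processed examples never interleave in the output file.
                // `eval` without an explicit --output skips the writer and prints a report at
                // the end instead.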
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

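                // Group examples by repository so examples sharing a repo run sequentially,
                // then process up to --max-parallelism repo groups concurrently per batch.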
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

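                            // Failed examples are always recorded by `handle_error`; whether
                            // they also land in the main output depends on --failed.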
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}

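/// Record a failed example: bump the failure counter, write the example and its error to the
/// failed-examples directory, append it to the run's `failed.jsonl`, and build a message with
/// a re-run hint. Panics when `--failfast` is set or when only a single example was being
/// processed; otherwise the message is only logged.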
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}