1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod parse_output;
10mod paths;
11mod predict;
12mod progress;
13mod pull_examples;
14mod reorder_patch;
15mod retrieve_context;
16mod score;
17mod split_commit;
18mod split_dataset;
19mod synthesize;
20use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
21use collections::HashSet;
22use edit_prediction::EditPredictionStore;
23use futures::channel::mpsc;
24use futures::{SinkExt as _, StreamExt as _};
25use gpui::{AppContext as _, Application, BackgroundExecutor};
26use zeta_prompt::ZetaVersion;
27
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
36
37use crate::distill::run_distill;
38use crate::example::{Example, group_examples_by_repo, read_example_files};
39use crate::format_prompt::run_format_prompt;
40use crate::load_project::run_load_project;
41use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
42use crate::predict::run_prediction;
43use crate::progress::Progress;
44use crate::retrieve_context::run_context_retrieval;
45use crate::score::run_scoring;
46use crate::split_commit::SplitCommitArgs;
47use crate::split_dataset::SplitArgs;
48use crate::synthesize::{SynthesizeConfig, run_synthesize};
49
// Top-level CLI arguments. NOTE: `///` doc comments on clap fields become
// user-visible help text, so reviewer notes below use `//` to leave the
// generated `--help` output unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker claims a whole repo's
    // examples at a time (see the worker loop in main()).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total examples loaded (file inputs + Snowflake captured-after).
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs; see INPUTS_HELP for the accepted forms.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; when absent, results go to stdout (except for `eval`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; output_path() panics if the
    // number of inputs is not exactly one.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example (see handle_error).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
80
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed by clap via ValueEnum; `--failed` defaults to "keep" (see EpArgs::failed).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
91
// Long-form help for the positional `inputs` argument of EpArgs, attached
// there via `#[clap(help = INPUTS_HELP)]`. The raw string below is shown
// verbatim by `--help`; do not reformat it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
121
// Subcommands of the `ep` binary. The `///` comments double as clap's help
// text for each subcommand, so they are left untouched. Clean, Synthesize,
// SplitCommit, and Split are dispatched early in main(); all other variants
// run through the per-example pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
153
154impl Display for Command {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 match self {
157 Command::ParseExample => write!(f, "parse-example"),
158 Command::LoadProject => write!(f, "load-project"),
159 Command::Context => write!(f, "context"),
160 Command::FormatPrompt(args) => {
161 write!(f, "format-prompt --provider={}", args.provider)
162 }
163 Command::Predict(args) => match &args.provider {
164 Some(provider) => write!(f, "predict --provider={}", provider),
165 None => write!(f, "predict"),
166 },
167 Command::ParseOutput => write!(f, "parse-output"),
168 Command::Score(args) => match &args.provider {
169 Some(provider) => write!(f, "score --provider={}", provider),
170 None => write!(f, "score"),
171 },
172 Command::Distill => write!(f, "distill"),
173 Command::Eval(args) => match &args.provider {
174 Some(provider) => write!(f, "eval --provider={}", provider),
175 None => write!(f, "eval"),
176 },
177 Command::Synthesize(args) => {
178 write!(f, "synthesize --repos {}", args.repos.join(" "))
179 }
180 Command::Clean => write!(f, "clean"),
181 Command::SplitCommit(_) => write!(f, "split-commit"),
182 Command::Split(_) => write!(f, "split"),
183 }
184 }
185}
186
// Arguments for the `format-prompt` subcommand. (`//` comments so clap's
// generated help is unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Required here (unlike PredictArgs); defaults to zeta2 with the default
    // ZetaVersion — see `impl Default for PredictionProvider`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
192
// Shared arguments for the predict/score/eval subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when None, the command string omits the flag (see
    // `Display for Command`). How None is resolved downstream lives in the
    // predict/score modules — not visible here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction attempts per example; defaults to a single run.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
200
/// Backend that produces predictions for an example. String form (via
/// `Display`/`FromStr`) is `name` or `name:version`, e.g. `zeta2:ordered`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // The zeta2/teacher variants carry a prompt-format version.
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
210
211impl Default for PredictionProvider {
212 fn default() -> Self {
213 PredictionProvider::Zeta2(ZetaVersion::default())
214 }
215}
216
217impl std::fmt::Display for PredictionProvider {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 PredictionProvider::Sweep => write!(f, "sweep"),
221 PredictionProvider::Mercury => write!(f, "mercury"),
222 PredictionProvider::Zeta1 => write!(f, "zeta1"),
223 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
224 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
225 PredictionProvider::TeacherNonBatching(version) => {
226 write!(f, "teacher-non-batching:{version}")
227 }
228 }
229 }
230}
231
232impl std::str::FromStr for PredictionProvider {
233 type Err = anyhow::Error;
234
235 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
236 let mut version = ZetaVersion::default();
237 if let Some((first, second)) = s.split_once(':') {
238 version = ZetaVersion::parse(second)?;
239 s = first;
240 }
241
242 let s_lower = s.to_lowercase();
243 match s_lower.as_str() {
244 "sweep" => Ok(PredictionProvider::Sweep),
245 "mercury" => Ok(PredictionProvider::Mercury),
246 "zeta1" => Ok(PredictionProvider::Zeta1),
247 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
248 "teacher" => Ok(PredictionProvider::Teacher(version)),
249 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
250 Ok(PredictionProvider::TeacherNonBatching(version))
251 }
252 _ => {
253 anyhow::bail!(
254 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
255 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
256 Available zeta versions:\n{}",
257 ZetaVersion::options_as_string()
258 )
259 }
260 }
261 }
262}
263
264impl Serialize for PredictionProvider {
265 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
266 where
267 S: Serializer,
268 {
269 serializer.serialize_str(&self.to_string())
270 }
271}
272
273impl<'de> Deserialize<'de> for PredictionProvider {
274 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
275 where
276 D: Deserializer<'de>,
277 {
278 let s = String::deserialize(deserializer)?;
279 s.parse().map_err(serde::de::Error::custom)
280 }
281}
282
// Arguments for the `synthesize` subcommand; mapped field-for-field onto
// `SynthesizeConfig` in main(). The `///` comments are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
301
302impl EpArgs {
303 fn output_path(&self) -> Option<PathBuf> {
304 if self.in_place {
305 if self.inputs.len() == 1 {
306 self.inputs.first().cloned()
307 } else {
308 panic!("--in-place requires exactly one input file")
309 }
310 } else {
311 self.output.clone()
312 }
313 }
314}
315
/// Gathers the examples for this run from `args.inputs`, which may mix file
/// paths and `captured-after:{timestamp}` specifiers (fetched from
/// Snowflake). The combined set is then sorted, filtered by `--name` /
/// `--repo` substrings, truncated to `--limit`, and — unless `--in-place` —
/// reduced by resume state read from the output file.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition raw inputs into Snowflake timestamp specifiers and file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total; set again below once filters and resume have run.
    Progress::global().set_total_examples(examples.len());

    // Budget left for Snowflake after counting the file-based examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, fall back to a fixed per-timestamp row cap.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
387
388fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
389 let mut hasher = collections::FxHasher::default();
390 spec.hash(&mut hasher);
391 hasher.finish()
392}
393
394fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
395 let file = match File::open(path) {
396 Ok(f) => f,
397 Err(_) => return,
398 };
399
400 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
401
402 let reader = BufReader::new(file);
403 let mut kept_lines = Vec::new();
404 let mut kept_hashes = HashSet::default();
405
406 for line in reader.lines() {
407 let line = match line {
408 Ok(l) => l,
409 Err(_) => continue,
410 };
411
412 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
413 let hash = spec_hash(&output_example.spec);
414 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
415 kept_hashes.insert(hash);
416 kept_lines.push(line);
417 }
418 }
419 }
420
421 let total = examples.len();
422 let already_processed = kept_hashes.len();
423
424 eprintln!(
425 "Resuming: {}/{} examples already processed",
426 already_processed, total
427 );
428
429 let file = OpenOptions::new()
430 .write(true)
431 .truncate(true)
432 .open(path)
433 .expect("Failed to open output file for rewriting");
434 let mut writer = BufWriter::new(file);
435 for line in &kept_lines {
436 writeln!(writer, "{}", line).expect("Failed to write to output file");
437 }
438 writer.flush().expect("Failed to flush output file");
439
440 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
441}
442
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else: dump the environment, exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the example pipeline are
    // dispatched here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Note: checks --output directly, bypassing output_path()/--in-place.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // All remaining commands run the per-example pipeline inside a headless
    // gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync batch state before processing; repeated again after
                // the run below. (`args` here shadows the outer EpArgs with
                // the subcommand's PredictArgs.)
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Writer task for the output file. Created unless this is an
                // `eval` run without an explicit --output (eval prints a
                // report instead). Lines are appended and flushed as workers
                // send them, which is what makes resume_from_output possible.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue keyed by repo; workers claim a whole repo's
                // examples at a time (presumably so each repo's worktree and
                // project state are touched by one task at a time — the
                // per-repo cleanup below relies on this).
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo group, or stop this worker.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the subcommand against this example.
                                let result = async {
                                    match &command {
                                        // Parsing already happened in load_examples.
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        // Eval scores like Score, then prints a report at the end.
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // Handled by the early dispatch before app.run.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // handle_error logs/records the failure and may
                                // panic depending on the failfast flags.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the (possibly failed) example, honoring
                                // --failed=keep|skip; stdout is used when no
                                // output file exists, except for `eval`.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo-level cleanup: detach the project from the
                            // prediction store once the repo's examples are done.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop cached project and per-example state before
                            // collecting results.
                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Second batch sync after all workers finish.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregate report instead of per-example lines.
                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so the process exits.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
727
/// Records one failed example and decides whether the run continues.
///
/// Side effects, in order:
/// - increments the global failure counter,
/// - writes the example (pretty JSON) and the debug-formatted error to
///   FAILED_EXAMPLES_DIR,
/// - appends the example to the run's failed.jsonl,
/// - then either panics with a summary (when --failfast is set, or when the
///   run had exactly one example) or logs that summary as an error.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Snapshot of the failing example; doubles as the input for the
    // "Re-run:" hint below.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Run-wide failure log, appended to across all workers.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // Cursor file inside the example's git worktree, for the summary.
    // NOTE(review): repo_name() is unwrapped — a failure here panics even
    // without --failfast.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // \x1b[31m / \x1b[36m / \x1b[0m are ANSI color codes (red, cyan, reset).
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}