1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application, BackgroundExecutor};
25use zeta_prompt::ZetaVersion;
26
27use reqwest_client::ReqwestClient;
28use serde::{Deserialize, Deserializer, Serialize, Serializer};
29use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
35
36use crate::distill::run_distill;
37use crate::example::{Example, group_examples_by_repo, read_example_files};
38use crate::format_prompt::run_format_prompt;
39use crate::load_project::run_load_project;
40use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
41use crate::predict::run_prediction;
42use crate::progress::Progress;
43use crate::retrieve_context::run_context_retrieval;
44use crate::score::run_scoring;
45use crate::split_commit::SplitCommitArgs;
46use crate::split_dataset::SplitArgs;
47use crate::synthesize::{SynthesizeConfig, run_synthesize};
48
// Top-level CLI arguments for the `ep` binary.
//
// NOTE: plain `//` comments are used on fields below (not `///`) so that
// clap's generated --help output stays exactly as it was; doc comments on
// clap derive fields become help strings.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repository groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Maximum number of examples to process, applied after filtering.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:<timestamp>` specifiers
    // (see INPUTS_HELP for the full syntax).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; most commands print JSONL to stdout when this is absent.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
79
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed from `--failed=keep|skip`; the `///` variant docs double as clap help text.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
90
/// Long help text attached to the positional `inputs` argument of [`EpArgs`].
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

      Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
120
// Subcommands of `ep`. The `///` comments on each variant double as clap's
// help text, so they are kept exactly as-is.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
149
150impl Display for Command {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Command::ParseExample => write!(f, "parse-example"),
154 Command::LoadProject => write!(f, "load-project"),
155 Command::Context => write!(f, "context"),
156 Command::FormatPrompt(args) => {
157 write!(f, "format-prompt --provider={}", args.provider)
158 }
159 Command::Predict(args) => {
160 write!(f, "predict --provider={}", args.provider)
161 }
162 Command::Score(args) => {
163 write!(f, "score --provider={}", args.provider)
164 }
165 Command::Distill => write!(f, "distill"),
166 Command::Eval(args) => {
167 write!(f, "eval --provider={}", args.provider)
168 }
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repos {}", args.repos.join(" "))
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 Command::Split(_) => write!(f, "split"),
175 }
176 }
177}
178
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit (`//` comment keeps clap help unchanged).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
184
// Arguments shared by the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider to run against (`//` comment keeps clap help unchanged).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // How many times to repeat prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
192
/// Backend used to produce edit predictions.
///
/// Parsed from strings like `zeta2` or `zeta2:<version>` (see the `FromStr`
/// impl below) and rendered back in the same syntax by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta 2, at a specific prompt-format version.
    Zeta2(ZetaVersion),
    /// Teacher model, at a specific prompt-format version.
    Teacher(ZetaVersion),
    /// Teacher model without batching (per the name; see `predict` for semantics).
    TeacherNonBatching(ZetaVersion),
}
202
203impl Default for PredictionProvider {
204 fn default() -> Self {
205 PredictionProvider::Zeta2(ZetaVersion::default())
206 }
207}
208
209impl std::fmt::Display for PredictionProvider {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 match self {
212 PredictionProvider::Sweep => write!(f, "sweep"),
213 PredictionProvider::Mercury => write!(f, "mercury"),
214 PredictionProvider::Zeta1 => write!(f, "zeta1"),
215 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
216 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
217 PredictionProvider::TeacherNonBatching(version) => {
218 write!(f, "teacher-non-batching:{version}")
219 }
220 }
221 }
222}
223
224impl std::str::FromStr for PredictionProvider {
225 type Err = anyhow::Error;
226
227 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
228 let mut version = ZetaVersion::default();
229 if let Some((first, second)) = s.split_once(':') {
230 version = ZetaVersion::parse(second)?;
231 s = first;
232 }
233
234 let s_lower = s.to_lowercase();
235 match s_lower.as_str() {
236 "sweep" => Ok(PredictionProvider::Sweep),
237 "mercury" => Ok(PredictionProvider::Mercury),
238 "zeta1" => Ok(PredictionProvider::Zeta1),
239 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
240 "teacher" => Ok(PredictionProvider::Teacher(version)),
241 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
242 Ok(PredictionProvider::TeacherNonBatching(version))
243 }
244 _ => {
245 anyhow::bail!(
246 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
247 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
248 Available zeta versions:\n{}",
249 ZetaVersion::options_as_string()
250 )
251 }
252 }
253 }
254}
255
256impl Serialize for PredictionProvider {
257 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: Serializer,
260 {
261 serializer.serialize_str(&self.to_string())
262 }
263}
264
265impl<'de> Deserialize<'de> for PredictionProvider {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: Deserializer<'de>,
269 {
270 let s = String::deserialize(deserializer)?;
271 s.parse().map_err(serde::de::Error::custom)
272 }
273}
274
// Arguments for the `synthesize` subcommand. The `///` field docs double as
// clap help text, so they are kept exactly as-is.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
293
294impl EpArgs {
295 fn output_path(&self) -> Option<PathBuf> {
296 if self.in_place {
297 if self.inputs.len() == 1 {
298 self.inputs.first().cloned()
299 } else {
300 panic!("--in-place requires exactly one input file")
301 }
302 } else {
303 self.output.clone()
304 }
305 }
306}
307
/// Gathers and prepares the set of examples to process.
///
/// Inputs are partitioned into plain example files and `captured-after:<ts>`
/// specifiers (fetched from Snowflake). The merged list is sorted by
/// repo/revision, filtered by `--name`/`--repo`, truncated to `--limit`,
/// and — unless running `--in-place` — pruned of examples already present in
/// `output_path` so interrupted runs resume where they left off.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs: `captured-after:` specifiers vs. file paths.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake however many examples are still needed to
    // reach --limit, given what the files already provided.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // With no --limit, cap each timestamp's query at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Group work by repo/revision so worktrees can be reused across examples.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    // Re-publish the total now that filtering/resume may have shrunk the list.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
379
380fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
381 let mut hasher = collections::FxHasher::default();
382 spec.hash(&mut hasher);
383 hasher.finish()
384}
385
386fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
387 let file = match File::open(path) {
388 Ok(f) => f,
389 Err(_) => return,
390 };
391
392 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
393
394 let reader = BufReader::new(file);
395 let mut kept_lines = Vec::new();
396 let mut kept_hashes = HashSet::default();
397
398 for line in reader.lines() {
399 let line = match line {
400 Ok(l) => l,
401 Err(_) => continue,
402 };
403
404 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
405 let hash = spec_hash(&output_example.spec);
406 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
407 kept_hashes.insert(hash);
408 kept_lines.push(line);
409 }
410 }
411 }
412
413 let total = examples.len();
414 let already_processed = kept_hashes.len();
415
416 eprintln!(
417 "Resuming: {}/{} examples already processed",
418 already_processed, total
419 );
420
421 let file = OpenOptions::new()
422 .write(true)
423 .truncate(true)
424 .open(path)
425 .expect("Failed to open output file for rewriting");
426 let mut writer = BufWriter::new(file);
427 for line in &kept_lines {
428 writeln!(writer, "{}", line).expect("Failed to write to output file");
429 }
430 writer.flush().expect("Failed to flush output file");
431
432 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
433}
434
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else: dump the env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless GPUI application are handled
    // synchronously up front; each arm returns.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // All remaining commands run inside a headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull in any completed prediction batches before starting.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so the error is visible.
                let failfast_on_single_example = examples.len() == 1;

                // Results from parallel workers are funneled through a channel
                // to a single background task that appends to the output file,
                // flushing after every line.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples are grouped by repository; each worker pops a whole
                // group so a repo's worktree is used by one task at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the selected subcommand for this example.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These commands were handled (and returned)
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the (possibly updated) example unless it
                                // failed and `--failed=skip` was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: drop its project from the
                            // prediction store so resources are released.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            // Drop per-example project state before keeping results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any prediction batches accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
716
/// Records a failed example and either aborts (failfast) or logs and continues.
///
/// Side effects, in order:
/// - increments the global failure counter,
/// - writes the example JSON and the error text under `FAILED_EXAMPLES_DIR`,
/// - appends the example to `<run dir>/failed.jsonl`,
/// - panics with a formatted report when `--failfast` is set or the run had
///   exactly one example; otherwise logs the report via `log::error!`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Also append to the run-wide failed.jsonl so failures can be re-run in bulk.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): assumes the example's repo worktree was set up earlier in
    // the run; the `unwrap` panics otherwise — confirm this always holds.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // Human-readable report; \x1b[..m sequences are ANSI colors.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}