1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaVersion;
35
36use crate::distill::run_distill;
37use crate::example::{Example, group_examples_by_repo, read_example_files};
38use crate::format_prompt::run_format_prompt;
39use crate::load_project::run_load_project;
40use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
41use crate::predict::run_prediction;
42use crate::progress::Progress;
43use crate::retrieve_context::run_context_retrieval;
44use crate::score::run_scoring;
45use crate::split_commit::SplitCommitArgs;
46use crate::split_dataset::SplitArgs;
47use crate::synthesize::{SynthesizeConfig, run_synthesize};
48
49#[derive(Parser, Debug)]
50#[command(name = "ep")]
51struct EpArgs {
52 #[arg(long, default_value_t = false)]
53 printenv: bool,
54 #[clap(long, default_value_t = 10, global = true)]
55 max_parallelism: usize,
56 #[clap(long, global = true)]
57 limit: Option<usize>,
58 /// Filter examples by name
59 #[clap(long, global = true)]
60 name: Option<String>,
61 /// Filter examples by repository
62 #[clap(long, global = true)]
63 repo: Option<String>,
64 #[command(subcommand)]
65 command: Option<Command>,
66 #[clap(global = true, help = INPUTS_HELP)]
67 inputs: Vec<PathBuf>,
68 #[arg(long, short, global = true)]
69 output: Option<PathBuf>,
70 #[arg(long, short, global = true)]
71 in_place: bool,
72 #[arg(long, short, global = true)]
73 failfast: bool,
74 /// How to handle failed examples in output: keep them or skip them.
75 /// Failed examples are always logged to the run's failed directory.
76 #[arg(long, global = true, default_value = "keep")]
77 failed: FailedHandling,
78}
79
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `ValueEnum` derive exposes the variants on the CLI as `keep` / `skip`
// (see `--failed`'s `default_value = "keep"` in `EpArgs`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
90
// Long help text attached to the positional `inputs` argument of `EpArgs`.
// Documents both plain file inputs and the `captured-after:` specifier parsed
// by `pull_examples::parse_captured_after_input` in `load_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
120
// Subcommands of the `ep` CLI. The `///` comments below double as clap help
// text, so they are kept exactly as shown to users. `Clean`, `Synthesize`,
// `SplitCommit`, and `Split` are handled before the headless app starts; the
// rest run through the per-example worker loop in `main`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
149
150impl Display for Command {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Command::ParseExample => write!(f, "parse-example"),
154 Command::LoadProject => write!(f, "load-project"),
155 Command::Context => write!(f, "context"),
156 Command::FormatPrompt(args) => {
157 write!(f, "format-prompt --provider={}", args.provider)
158 }
159 Command::Predict(args) => {
160 write!(f, "predict --provider={}", args.provider)
161 }
162 Command::Score(args) => {
163 write!(f, "score --provider={}", args.provider)
164 }
165 Command::Distill => write!(f, "distill"),
166 Command::Eval(args) => {
167 write!(f, "eval --provider={}", args.provider)
168 }
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repos {}", args.repos.join(" "))
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 Command::Split(_) => write!(f, "split"),
175 }
176 }
177}
178
179#[derive(Debug, Args, Clone)]
180struct FormatPromptArgs {
181 #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
182 provider: PredictionProvider,
183}
184
185#[derive(Debug, Args, Clone)]
186struct PredictArgs {
187 #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
188 provider: PredictionProvider,
189 #[clap(long, default_value_t = 1)]
190 repetitions: usize,
191}
192
/// Backend used to produce edit predictions.
///
/// Variants carrying a [`ZetaVersion`] take it from an optional `name:version`
/// CLI suffix (see the `FromStr` impl); for the other variants any suffix is
/// validated and then discarded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably the same model as `Teacher` but without request
    // batching — confirm against the predict module.
    TeacherNonBatching(ZetaVersion),
}
202
203impl Default for PredictionProvider {
204 fn default() -> Self {
205 PredictionProvider::Zeta2(ZetaVersion::default())
206 }
207}
208
209impl std::fmt::Display for PredictionProvider {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 match self {
212 PredictionProvider::Sweep => write!(f, "sweep"),
213 PredictionProvider::Mercury => write!(f, "mercury"),
214 PredictionProvider::Zeta1 => write!(f, "zeta1"),
215 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
216 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
217 PredictionProvider::TeacherNonBatching(version) => {
218 write!(f, "teacher-non-batching:{version}")
219 }
220 }
221 }
222}
223
224impl std::str::FromStr for PredictionProvider {
225 type Err = anyhow::Error;
226
227 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
228 let mut version = ZetaVersion::default();
229 if let Some((first, second)) = s.split_once(':') {
230 version = ZetaVersion::parse(second)?;
231 s = first;
232 }
233
234 let s_lower = s.to_lowercase();
235 match s_lower.as_str() {
236 "sweep" => Ok(PredictionProvider::Sweep),
237 "mercury" => Ok(PredictionProvider::Mercury),
238 "zeta1" => Ok(PredictionProvider::Zeta1),
239 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
240 "teacher" => Ok(PredictionProvider::Teacher(version)),
241 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
242 Ok(PredictionProvider::TeacherNonBatching(version))
243 }
244 _ => {
245 anyhow::bail!(
246 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
247 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
248 Available zeta versions:\n{}",
249 ZetaVersion::options_as_string()
250 )
251 }
252 }
253 }
254}
255
256impl Serialize for PredictionProvider {
257 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: Serializer,
260 {
261 serializer.serialize_str(&self.to_string())
262 }
263}
264
265impl<'de> Deserialize<'de> for PredictionProvider {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: Deserializer<'de>,
269 {
270 let s = String::deserialize(deserializer)?;
271 s.parse().map_err(serde::de::Error::custom)
272 }
273}
274
275#[derive(Debug, Args, Clone)]
276struct SynthesizeArgs {
277 /// Repository URLs (git@github.com:owner/repo or https://...)
278 #[clap(long, required = true, num_args = 1..)]
279 repos: Vec<String>,
280
281 /// Number of examples to generate per repository
282 #[clap(long, default_value_t = 5)]
283 count: usize,
284
285 /// Maximum commits to scan per repository before giving up
286 #[clap(long, default_value_t = 100)]
287 max_commits: usize,
288
289 /// Ignore state file and reprocess all commits
290 #[clap(long)]
291 fresh: bool,
292}
293
294impl EpArgs {
295 fn output_path(&self) -> Option<PathBuf> {
296 if self.in_place {
297 if self.inputs.len() == 1 {
298 self.inputs.first().cloned()
299 } else {
300 panic!("--in-place requires exactly one input file")
301 }
302 } else {
303 self.output.clone()
304 }
305 }
306}
307
/// Loads the examples for this run from file inputs and/or Snowflake
/// `captured-after:` inputs, then applies the global `--name`/`--repo`
/// filters, the `--limit` cap, and resume-from-output deduplication.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Split raw inputs into `captured-after:` timestamps and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    // How many more examples Snowflake may contribute before hitting --limit.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit --limit, cap each timestamp query at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Order examples so that all entries for one repo/rev are adjacent; the
    // downstream grouping by repo in `main` benefits from this ordering.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // When resuming into an existing output file, drop examples that were
    // already processed there (and prune stale lines from that file).
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    // Re-publish the total now that filters/limit/resume have been applied.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
375
/// Stable-within-a-build hash of an example spec, used to match examples
/// between the input set and a previous run's output (see `resume_from_output`).
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
381
382fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
383 let file = match File::open(path) {
384 Ok(f) => f,
385 Err(_) => return,
386 };
387
388 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
389
390 let reader = BufReader::new(file);
391 let mut kept_lines = Vec::new();
392 let mut kept_hashes = HashSet::default();
393
394 for line in reader.lines() {
395 let line = match line {
396 Ok(l) => l,
397 Err(_) => continue,
398 };
399
400 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
401 let hash = spec_hash(&output_example.spec);
402 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
403 kept_hashes.insert(hash);
404 kept_lines.push(line);
405 }
406 }
407 }
408
409 let total = examples.len();
410 let already_processed = kept_hashes.len();
411
412 eprintln!(
413 "Resuming: {}/{} examples already processed",
414 already_processed, total
415 );
416
417 let file = OpenOptions::new()
418 .write(true)
419 .truncate(true)
420 .open(path)
421 .expect("Failed to open output file for rewriting");
422 let mut writer = BufWriter::new(file);
423 for line in &kept_lines {
424 writeln!(writer, "{}", line).expect("Failed to write to output file");
425 }
426 writer.flush().expect("Failed to flush output file");
427
428 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
429}
430
/// Entry point: parses CLI args, handles the "simple" subcommands directly,
/// and runs everything else through a headless gpui application with a pool of
/// worker tasks processing examples grouped by repository.
fn main() {
    let args = EpArgs::parse();

    // --printenv dumps the shell environment and exits; nothing else runs.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: show help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync batch state for provider-backed commands before the run
                // (see `predict::sync_batches`).
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so the error surfaces.
                let failfast_on_single_example = examples.len() == 1;

                // When an output file is in use, a detached background task
                // owns the writer; workers send serialized lines to it over an
                // unbounded channel so file writes are serialized.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                            ;
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Shared work queue of per-repository example groups, plus the
                // accumulator for examples that finished processing.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn `--max-parallelism` workers; each claims one whole
                // repo group at a time so per-repo project state is reused.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch this example to the selected subcommand.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            // These were handled before the app started.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures; `handle_error` panics when
                                // failfast applies, otherwise logs and returns.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples still go to the main output
                                // unless --failed=skip was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: stream results to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Group done: drop its project from the prediction
                            // store and the project cache, then release each
                            // example's loaded state before collecting them.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batch state again after the run for provider-backed
                // commands.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregate report over all finished examples.
                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
709
/// Records a failed example and decides whether to abort the run.
///
/// Always: bumps the failure counter, writes the example (pretty JSON) and its
/// error text into `FAILED_EXAMPLES_DIR`, and appends the example to the run's
/// `failed.jsonl`. Then panics with a summary when `--failfast` is set or the
/// run had a single example; otherwise logs the summary and continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Pretty-printed copy of the failing example, usable as a re-run input.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Companion file with the debug-formatted error.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Run-wide failure log: one compact JSON line per failed example.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` panics if the example lacks a repo
    // name — assumed always present by this point; confirm upstream.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // Human-readable summary with ANSI colors and a ready-to-paste re-run line.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}