1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaVersion;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
48
49#[derive(Parser, Debug)]
50#[command(name = "ep")]
51struct EpArgs {
52 #[arg(long, default_value_t = false)]
53 printenv: bool,
54 #[clap(long, default_value_t = 10, global = true)]
55 max_parallelism: usize,
56 #[clap(long, global = true)]
57 limit: Option<usize>,
58 /// Filter examples by name
59 #[clap(long, global = true)]
60 name: Option<String>,
61 /// Filter examples by repository
62 #[clap(long, global = true)]
63 repo: Option<String>,
64 #[command(subcommand)]
65 command: Option<Command>,
66 #[clap(global = true, help = INPUTS_HELP)]
67 inputs: Vec<PathBuf>,
68 #[arg(long, short, global = true)]
69 output: Option<PathBuf>,
70 #[arg(long, short, global = true)]
71 in_place: bool,
72 #[arg(long, short, global = true)]
73 failfast: bool,
74 /// How to handle failed examples in output: keep them or skip them.
75 /// Failed examples are always logged to the run's failed directory.
76 #[arg(long, global = true, default_value = "keep")]
77 failed: FailedHandling,
78}
79
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed by clap as the value of the global `--failed` flag (default `keep`);
// consulted per example in the main processing loop before writing output.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
90
// Long-form, user-facing help text attached to the positional `inputs`
// argument of `EpArgs`. NOTE(review): keep the `captured-after:` syntax here
// in sync with `pull_examples::parse_captured_after_input`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
120
// All `ep` subcommands. The `///` comment on each variant doubles as the
// clap-generated help text; the `Display` impl below renders each variant as
// the CLI fragment used in re-run instructions for failed examples.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
149
150impl Display for Command {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Command::ParseExample => write!(f, "parse-example"),
154 Command::LoadProject => write!(f, "load-project"),
155 Command::Context => write!(f, "context"),
156 Command::FormatPrompt(args) => {
157 write!(f, "format-prompt --provider={}", args.provider)
158 }
159 Command::Predict(args) => {
160 write!(f, "predict --provider={}", args.provider)
161 }
162 Command::Score(args) => {
163 write!(f, "score --provider={}", args.provider)
164 }
165 Command::Distill => write!(f, "distill"),
166 Command::Eval(args) => {
167 write!(f, "eval --provider={}", args.provider)
168 }
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repos {}", args.repos.join(" "))
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 Command::Split(_) => write!(f, "split"),
175 }
176 }
177}
178
// Arguments for the `format-prompt` subcommand: which provider's prompt
// format to emit. (`//` comments keep clap help output unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format is generated; defaults to zeta2.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
184
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend used to produce predictions; defaults to zeta2.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
192
/// Backend used to produce edit predictions.
///
/// Zeta-based variants carry a [`ZetaVersion`] selecting the prompt format;
/// string forms like `zeta2:<version>` are handled by the `FromStr` impl
/// below, and `Display`/`Serialize` round-trip through that representation.
// NOTE(review): the two Teacher variants presumably differ in whether
// requests are batched (per their names) — confirm in the `predict` module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
202
203impl Default for PredictionProvider {
204 fn default() -> Self {
205 PredictionProvider::Zeta2(ZetaVersion::default())
206 }
207}
208
209impl std::fmt::Display for PredictionProvider {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 match self {
212 PredictionProvider::Sweep => write!(f, "sweep"),
213 PredictionProvider::Mercury => write!(f, "mercury"),
214 PredictionProvider::Zeta1 => write!(f, "zeta1"),
215 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
216 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
217 PredictionProvider::TeacherNonBatching(version) => {
218 write!(f, "teacher-non-batching:{version}")
219 }
220 }
221 }
222}
223
224impl std::str::FromStr for PredictionProvider {
225 type Err = anyhow::Error;
226
227 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
228 let mut version = ZetaVersion::default();
229 if let Some((first, second)) = s.split_once(':') {
230 version = ZetaVersion::parse(second)?;
231 s = first;
232 }
233
234 let s_lower = s.to_lowercase();
235 match s_lower.as_str() {
236 "sweep" => Ok(PredictionProvider::Sweep),
237 "mercury" => Ok(PredictionProvider::Mercury),
238 "zeta1" => Ok(PredictionProvider::Zeta1),
239 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
240 "teacher" => Ok(PredictionProvider::Teacher(version)),
241 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
242 Ok(PredictionProvider::TeacherNonBatching(version))
243 }
244 _ => {
245 anyhow::bail!(
246 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
247 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
248 Available zeta versions:\n{}",
249 ZetaVersion::options_as_string()
250 )
251 }
252 }
253 }
254}
255
impl Serialize for PredictionProvider {
    /// Serializes the provider as its `Display` string (e.g. `"zeta2:<ver>"`),
    /// so the serialized form round-trips through the `FromStr`-based
    /// `Deserialize` impl below.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}
264
impl<'de> Deserialize<'de> for PredictionProvider {
    /// Deserializes from the string form produced by `Serialize`, delegating
    /// to the `FromStr` impl and surfacing parse failures as serde errors.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
274
// Arguments for the `synthesize` subcommand, converted into a
// `SynthesizeConfig` in `main` (which also supplies `--output` as the
// required output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
293
294impl EpArgs {
295 fn output_path(&self) -> Option<PathBuf> {
296 if self.in_place {
297 if self.inputs.len() == 1 {
298 self.inputs.first().cloned()
299 } else {
300 panic!("--in-place requires exactly one input file")
301 }
302 } else {
303 self.output.clone()
304 }
305 }
306}
307
/// Builds the example set for a run: reads file inputs, fetches Snowflake
/// `captured-after:` inputs, applies the global `--name`/`--repo`/`--limit`
/// filters, and drops examples already present in `output_path` (resume).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into `captured-after:<ts>` specifiers and plain files.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total for progress display; updated again at the end once
    // fetching, filters, limit, and resume have all been applied.
    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake whatever budget --limit leaves after the
    // file-based examples (None means no limit).
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Deterministic order also groups examples sharing a repo/revision.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // --limit applies after filtering, across file and fetched examples.
    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Drop examples a previous interrupted run already wrote to the output.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
375
/// Deterministic hash of an example spec (FxHasher is unseeded), used to
/// match examples between the input set and a previously-written output file
/// when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
381
382fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
383 let file = match File::open(path) {
384 Ok(f) => f,
385 Err(_) => return,
386 };
387
388 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
389
390 let reader = BufReader::new(file);
391 let mut kept_lines = Vec::new();
392 let mut kept_hashes = HashSet::default();
393
394 for line in reader.lines() {
395 let line = match line {
396 Ok(l) => l,
397 Err(_) => continue,
398 };
399
400 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
401 let hash = spec_hash(&output_example.spec);
402 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
403 kept_hashes.insert(hash);
404 kept_lines.push(line);
405 }
406 }
407 }
408
409 let total = examples.len();
410 let already_processed = kept_hashes.len();
411
412 eprintln!(
413 "Resuming: {}/{} examples already processed",
414 already_processed, total
415 );
416
417 let file = OpenOptions::new()
418 .write(true)
419 .truncate(true)
420 .open(path)
421 .expect("Failed to open output file for rewriting");
422 let mut writer = BufWriter::new(file);
423 for line in &kept_lines {
424 writeln!(writer, "{}", line).expect("Failed to write to output file");
425 }
426 writer.flush().expect("Failed to flush output file");
427
428 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
429}
430
/// CLI entry point: parses arguments, dispatches standalone subcommands, and
/// runs the example-processing pipeline inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        // Print the captured shell environment and exit.
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: show help instead of doing work.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Prediction-based commands sync pending batch requests
                // before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast so
                // the error is surfaced immediately.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and receives serialized examples over a channel.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work is grouped per repository so each repo's project state
                // is only held by one worker task at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repository's batch of examples.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched before the
                                        // app started (see the early match).
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written unless
                                // --failed=skip; `eval` without --output
                                // prints no per-example lines.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Release per-repo resources before moving on to
                            // the next repository's batch.
                            if let Some(state) =
                                repo_examples.first().and_then(|e| e.state.as_ref())
                            {
                                let mut cx = cx.clone();
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    let project = state.project.clone();
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state
                                .project_cache
                                .remove(&repo_examples.first().unwrap().spec.repository_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush batches submitted during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
712
/// Records a failed example and either aborts the run (with `--failfast`, or
/// when the run has a single example) or logs the error and continues.
///
/// Regardless of the `--failed` setting, three artifacts are always written:
/// a pretty-printed copy of the example and a `{name}_err.txt` file in
/// `FAILED_EXAMPLES_DIR`, plus one JSONL line appended to the run's
/// `failed.jsonl`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the failing example itself so it can be re-run in isolation.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Persist the debug-formatted error next to it.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the run-wide failure log (one JSON object per line).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // ANSI-colored summary pointing at the artifacts and a re-run command
    // (the `{}` for `command` uses the Display impl above).
    let msg = format!(
        indoc::indoc! {"
        While processing \"{}\":

        \x1b[31m{:?}\x1b[0m

        Example: \x1b[36m{}\x1b[0m
        Error file: \x1b[36m{}\x1b[0m
        Cursor file: \x1b[36m{}\x1b[0m
        Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        // Finalize progress output before panicking so the bar is not left
        // mid-render.
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}