1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
// Top-level CLI arguments for the `ep` tool.
//
// NOTE: clap turns `///` doc comments into --help text, so reviewer notes
// below use `//` to leave the generated help unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on how many repo groups are processed concurrently
    // (used as the chunk size in `main`'s batching loop).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total examples; also budgets Snowflake fetches in `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP for the full grammar).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path; superseded by --in-place (see `output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; panics unless exactly one
    // input is given (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example (see `handle_error`).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `ValueEnum` derive supplies the CLI value strings for the `--failed`
// flag on `EpArgs` (its default is spelled "keep" there).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
// Long help text for the positional `inputs` argument of `EpArgs`,
// attached via `#[clap(help = INPUTS_HELP)]`. The raw string is user-facing
// output; do not edit casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of the `ep` CLI. The `///` doc comments double as clap help
// text, so additional notes use `//`.
//
// `Clean`, `Synthesize`, `SplitCommit`, and `Split` are handled up front in
// `main`; the remaining variants run per-example inside the headless app.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Predict, Score, and Eval all share `PredictArgs`.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(args) => {
156 write!(f, "format-prompt --provider={}", args.provider)
157 }
158 Command::Predict(args) => {
159 write!(f, "predict --provider={}", args.provider)
160 }
161 Command::Score(args) => {
162 write!(f, "score --provider={}", args.provider)
163 }
164 Command::Distill => write!(f, "distill"),
165 Command::Eval(args) => {
166 write!(f, "eval --provider={}", args.provider)
167 }
168 Command::Synthesize(args) => {
169 write!(f, "synthesize --repos {}", args.repos.join(" "))
170 }
171 Command::Clean => write!(f, "clean"),
172 Command::SplitCommit(_) => write!(f, "split-commit"),
173 Command::Split(_) => write!(f, "split"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand; only selects which
// provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Parsed through `PredictionProvider::from_str` via `default_value_t`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
183
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use; parsed via `PredictionProvider::from_str`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // Predictions per example — consumed by `run_prediction`; exact
    // semantics live in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
/// Backend used to produce (or score) edit predictions.
///
/// Round-trips through `Display`/`FromStr`, which is also how clap parses
/// `--provider` and how the serde impls below (de)serialize it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries a `ZetaVersion` (from the `zeta_prompt` crate) selecting the
    // zeta2 variant; spelled `zeta2:<version>` on the CLI.
    Zeta2(ZetaVersion),
    Teacher,
    // NOTE(review): presumably `Teacher` without request batching — the name
    // suggests it; confirm against predict.rs.
    TeacherNonBatching,
}
201
202impl Default for PredictionProvider {
203 fn default() -> Self {
204 PredictionProvider::Zeta2(ZetaVersion::default())
205 }
206}
207
208impl std::fmt::Display for PredictionProvider {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PredictionProvider::Sweep => write!(f, "sweep"),
212 PredictionProvider::Mercury => write!(f, "mercury"),
213 PredictionProvider::Zeta1 => write!(f, "zeta1"),
214 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
215 PredictionProvider::Teacher => write!(f, "teacher"),
216 PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
217 }
218 }
219}
220
221impl std::str::FromStr for PredictionProvider {
222 type Err = anyhow::Error;
223
224 fn from_str(s: &str) -> Result<Self, Self::Err> {
225 let s_lower = s.to_lowercase();
226 match s_lower.as_str() {
227 "sweep" => Ok(PredictionProvider::Sweep),
228 "mercury" => Ok(PredictionProvider::Mercury),
229 "zeta1" => Ok(PredictionProvider::Zeta1),
230 // Handle both old format "zeta2" and new format with version
231 "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
232 "teacher" => Ok(PredictionProvider::Teacher),
233 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
234 Ok(PredictionProvider::TeacherNonBatching)
235 }
236 _ if s_lower.starts_with("zeta2:") => {
237 let version_str = &s[6..];
238 let version = ZetaVersion::parse(version_str)?;
239 Ok(PredictionProvider::Zeta2(version))
240 }
241 _ => anyhow::bail!(
242 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
243 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
244 Available zeta versions:\n{}",
245 ZetaVersion::options_as_string()
246 ),
247 }
248 }
249}
250
251impl Serialize for PredictionProvider {
252 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
253 where
254 S: Serializer,
255 {
256 serializer.serialize_str(&self.to_string())
257 }
258}
259
260impl<'de> Deserialize<'de> for PredictionProvider {
261 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
262 where
263 D: Deserializer<'de>,
264 {
265 let s = String::deserialize(deserializer)?;
266 s.parse().map_err(serde::de::Error::custom)
267 }
268}
269
// Arguments for the `synthesize` subcommand; converted into a
// `SynthesizeConfig` in `main` (which also requires --output as the
// destination directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
288
289impl EpArgs {
290 fn output_path(&self) -> Option<PathBuf> {
291 if self.in_place {
292 if self.inputs.len() == 1 {
293 self.inputs.first().cloned()
294 } else {
295 panic!("--in-place requires exactly one input file")
296 }
297 } else {
298 self.output.clone()
299 }
300 }
301}
302
/// Collects and orders the examples for this run.
///
/// Inputs are partitioned into plain file paths and
/// `captured-after:{timestamp}` specifiers; the latter are fetched from
/// Snowflake, budgeted by whatever portion of `--limit` the file-based
/// examples did not already consume. The combined set is sorted, filtered
/// by `--name`/`--repo`, truncated to `--limit`, and finally reduced to the
/// examples not already present in `output_path` (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs: `captured-after:…` specifiers vs. example files.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total; set again below once fetching/filtering is done.
    Progress::global().set_total_examples(examples.len());

    // Remaining --limit budget for Snowflake after counting file examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, fall back to a fixed per-timestamp row cap.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Keep examples from the same repository/revision adjacent, so
    // `group_examples_by_repo` in `main` forms contiguous groups.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples already recorded in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
370
371fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
372 let mut hasher = collections::FxHasher::default();
373 spec.hash(&mut hasher);
374 hasher.finish()
375}
376
377fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
378 let file = match File::open(path) {
379 Ok(f) => f,
380 Err(_) => return,
381 };
382
383 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
384
385 let reader = BufReader::new(file);
386 let mut kept_lines = Vec::new();
387 let mut kept_hashes = HashSet::default();
388
389 for line in reader.lines() {
390 let line = match line {
391 Ok(l) => l,
392 Err(_) => continue,
393 };
394
395 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
396 let hash = spec_hash(&output_example.spec);
397 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
398 kept_hashes.insert(hash);
399 kept_lines.push(line);
400 }
401 }
402 }
403
404 let total = examples.len();
405 let already_processed = kept_hashes.len();
406
407 eprintln!(
408 "Resuming: {}/{} examples already processed",
409 already_processed, total
410 );
411
412 let file = OpenOptions::new()
413 .write(true)
414 .truncate(true)
415 .open(path)
416 .expect("Failed to open output file for rewriting");
417 let mut writer = BufWriter::new(file);
418 for line in &kept_lines {
419 writeln!(writer, "{}", line).expect("Failed to write to output file");
420 }
421 writer.flush().expect("Failed to flush output file");
422
423 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
424}
425
/// Entry point: parses arguments, dispatches the "simple" subcommands
/// synchronously, and runs the per-example pipeline inside a headless gpui
/// application for everything else.
fn main() {
    let args = EpArgs::parse();

    // --printenv dumps the shell environment and exits before any subcommand.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that need neither the headless app nor loaded examples are
    // handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // --output is repurposed as the destination directory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync provider batches before the run — NOTE(review):
                // presumably flushes pending batch requests; confirm in
                // predict.rs.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast so the error report
                // surfaces immediately (see `handle_error`).
                let failfast_on_single_example = examples.len() == 1;

                // Results are streamed as JSONL lines through a channel to a
                // background task that appends to the output file. `eval`
                // writes only when --output was given explicitly.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so partial output survives
                                    // a crash (enables resume_from_output).
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples are grouped per repository; groups within a chunk
                // run concurrently (at most --max-parallelism at a time), and
                // examples inside a group run sequentially.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    // No per-example work: parsing happened in
                                    // `load_examples`; the example is emitted below.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled (with early return) before the
                                    // app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; `handle_error` may panic when
                            // failfast applies.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Emit the example unless it failed and --failed=skip.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file: stream JSONL to stdout.
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Final batch sync after all examples completed.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` prints the aggregated score report to the terminal.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
654
/// Records a failed example and either aborts the run or logs and continues.
///
/// Per failure this persists three artifacts: `<name>.json` (the example)
/// and `<name>_err.txt` (the error) under `FAILED_EXAMPLES_DIR`, plus an
/// appended line in the run directory's `failed.jsonl`. It then builds a
/// colored report with a re-run hint; with `--failfast` — or when the run
/// had only a single example — the process panics with that report,
/// otherwise it is logged at error level.
///
/// Panics if any of the failure artifacts cannot be written, or when
/// failfast applies (see above).
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Snapshot of the failing example, usable directly as a re-run input.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Debug-formatted error, written alongside the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the run-wide failure log (one JSONL line per failure).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // Location of the cursor file inside the example's worktree, for the report.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // ANSI-colored report: red error, cyan paths, plus a ready-to-paste
    // re-run command (the `{}` for `command` uses its Display impl).
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}