1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application, BackgroundExecutor};
25use zeta_prompt::ZetaVersion;
26
27use reqwest_client::ReqwestClient;
28use serde::{Deserialize, Deserializer, Serialize, Serializer};
29use std::fmt::Display;
30use std::fs::{File, OpenOptions};
31use std::hash::{Hash, Hasher};
32use std::io::{BufRead, BufReader, BufWriter, Write};
33use std::{path::PathBuf, sync::Arc};
34
35use crate::distill::run_distill;
36use crate::example::{Example, group_examples_by_repo, read_example_files};
37use crate::format_prompt::run_format_prompt;
38use crate::load_project::run_load_project;
39use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
40use crate::predict::run_prediction;
41use crate::progress::Progress;
42use crate::retrieve_context::run_context_retrieval;
43use crate::score::run_scoring;
44use crate::split_commit::SplitCommitArgs;
45use crate::split_dataset::SplitArgs;
46use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
48#[derive(Parser, Debug)]
49#[command(name = "ep")]
50struct EpArgs {
51 #[arg(long, default_value_t = false)]
52 printenv: bool,
53 #[clap(long, default_value_t = 10, global = true)]
54 max_parallelism: usize,
55 #[clap(long, global = true)]
56 limit: Option<usize>,
57 /// Filter examples by name
58 #[clap(long, global = true)]
59 name: Option<String>,
60 /// Filter examples by repository
61 #[clap(long, global = true)]
62 repo: Option<String>,
63 #[command(subcommand)]
64 command: Option<Command>,
65 #[clap(global = true, help = INPUTS_HELP)]
66 inputs: Vec<PathBuf>,
67 #[arg(long, short, global = true)]
68 output: Option<PathBuf>,
69 #[arg(long, short, global = true)]
70 in_place: bool,
71 #[arg(long, short, global = true)]
72 failfast: bool,
73 /// How to handle failed examples in output: keep them or skip them.
74 /// Failed examples are always logged to the run's failed directory.
75 #[arg(long, global = true, default_value = "keep")]
76 failed: FailedHandling,
77}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The ValueEnum derive exposes the variants as the kebab-case CLI values
// `keep` and `skip`, matching `default_value = "keep"` on `EpArgs::failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
/// Long-form help text attached to the positional `inputs` argument on
/// `EpArgs` (via `help = INPUTS_HELP`). This string is user-facing CLI
/// output, so its contents must not be reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of `ep`. The `///` doc comments double as clap help text and
// are left untouched; extra notes use plain comments.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Runs the same per-example scoring as `Score` (see `main`), then prints
    // an aggregate report via `score::print_report`.
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    // Deletes the entire DATA_DIR; handled synchronously in `main`.
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(args) => {
156 write!(f, "format-prompt --provider={}", args.provider)
157 }
158 Command::Predict(args) => {
159 write!(f, "predict --provider={}", args.provider)
160 }
161 Command::Score(args) => {
162 write!(f, "score --provider={}", args.provider)
163 }
164 Command::Distill => write!(f, "distill"),
165 Command::Eval(args) => {
166 write!(f, "eval --provider={}", args.provider)
167 }
168 Command::Synthesize(args) => {
169 write!(f, "synthesize --repos {}", args.repos.join(" "))
170 }
171 Command::Clean => write!(f, "clean"),
172 Command::SplitCommit(_) => write!(f, "split-commit"),
173 Command::Split(_) => write!(f, "split"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; parsed via
    // `PredictionProvider::from_str`, e.g. `zeta2:<version>`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
183
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider used to produce predictions; parsed via
    // `PredictionProvider::from_str`, e.g. `zeta2:<version>`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // Number of times each example is predicted (defaults to a single pass).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
/// Backend used to produce edit predictions.
///
/// Parsed from strings like `zeta2` or `teacher:<version>` (see the `FromStr`
/// impl below) and rendered back into that same form by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt version to use.
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably the teacher model invoked without request
    // batching (cf. `predict::sync_batches`) — confirm in predict.rs.
    TeacherNonBatching(ZetaVersion),
}
201
202impl Default for PredictionProvider {
203 fn default() -> Self {
204 PredictionProvider::Zeta2(ZetaVersion::default())
205 }
206}
207
208impl std::fmt::Display for PredictionProvider {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PredictionProvider::Sweep => write!(f, "sweep"),
212 PredictionProvider::Mercury => write!(f, "mercury"),
213 PredictionProvider::Zeta1 => write!(f, "zeta1"),
214 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
215 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
216 PredictionProvider::TeacherNonBatching(version) => {
217 write!(f, "teacher-non-batching:{version}")
218 }
219 }
220 }
221}
222
223impl std::str::FromStr for PredictionProvider {
224 type Err = anyhow::Error;
225
226 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
227 let mut version = ZetaVersion::default();
228 if let Some((first, second)) = s.split_once(':') {
229 version = ZetaVersion::parse(second)?;
230 s = first;
231 }
232
233 let s_lower = s.to_lowercase();
234 match s_lower.as_str() {
235 "sweep" => Ok(PredictionProvider::Sweep),
236 "mercury" => Ok(PredictionProvider::Mercury),
237 "zeta1" => Ok(PredictionProvider::Zeta1),
238 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
239 "teacher" => Ok(PredictionProvider::Teacher(version)),
240 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
241 Ok(PredictionProvider::TeacherNonBatching(version))
242 }
243 _ => {
244 anyhow::bail!(
245 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
246 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
247 Available zeta versions:\n{}",
248 ZetaVersion::options_as_string()
249 )
250 }
251 }
252 }
253}
254
255impl Serialize for PredictionProvider {
256 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
257 where
258 S: Serializer,
259 {
260 serializer.serialize_str(&self.to_string())
261 }
262}
263
264impl<'de> Deserialize<'de> for PredictionProvider {
265 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
266 where
267 D: Deserializer<'de>,
268 {
269 let s = String::deserialize(deserializer)?;
270 s.parse().map_err(serde::de::Error::custom)
271 }
272}
273
// Arguments for the `synthesize` subcommand; copied field-by-field into
// `SynthesizeConfig` in `main`. The `///` field docs are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
292
293impl EpArgs {
294 fn output_path(&self) -> Option<PathBuf> {
295 if self.in_place {
296 if self.inputs.len() == 1 {
297 self.inputs.first().cloned()
298 } else {
299 panic!("--in-place requires exactly one input file")
300 }
301 } else {
302 self.output.clone()
303 }
304 }
305}
306
/// Gathers the examples to process: reads file inputs, optionally fetches
/// captured examples from Snowflake for `captured-after:` specifiers, applies
/// the `--name`/`--repo` filters and `--limit`, and drops examples already
/// present in the output file (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into `captured-after:` timestamps and ordinary files.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress is meaningful during the (potentially
    // slow) Snowflake fetch; the final total is set again further down.
    Progress::global().set_total_examples(examples.len());

    // The Snowflake fetch only needs to fill whatever --limit leaves after
    // the file-based examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters; applied after fetching, so --limit above may
    // over-fetch relative to what survives filtering.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples whose results are already in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
374
375fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
376 let mut hasher = collections::FxHasher::default();
377 spec.hash(&mut hasher);
378 hasher.finish()
379}
380
381fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
382 let file = match File::open(path) {
383 Ok(f) => f,
384 Err(_) => return,
385 };
386
387 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
388
389 let reader = BufReader::new(file);
390 let mut kept_lines = Vec::new();
391 let mut kept_hashes = HashSet::default();
392
393 for line in reader.lines() {
394 let line = match line {
395 Ok(l) => l,
396 Err(_) => continue,
397 };
398
399 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
400 let hash = spec_hash(&output_example.spec);
401 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
402 kept_hashes.insert(hash);
403 kept_lines.push(line);
404 }
405 }
406 }
407
408 let total = examples.len();
409 let already_processed = kept_hashes.len();
410
411 eprintln!(
412 "Resuming: {}/{} examples already processed",
413 already_processed, total
414 );
415
416 let file = OpenOptions::new()
417 .write(true)
418 .truncate(true)
419 .open(path)
420 .expect("Failed to open output file for rewriting");
421 let mut writer = BufWriter::new(file);
422 for line in &kept_lines {
423 writeln!(writer, "{}", line).expect("Failed to write to output file");
424 }
425 writer.flush().expect("Failed to flush output file");
426
427 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
428}
429
/// Entry point. Parses arguments, handles the synchronous commands
/// (`clean`, `synthesize`, `split-commit`, `split`) directly, and runs the
/// remaining per-example commands inside a headless gpui application.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app / example pipeline are
    // handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui app so project loading,
    // context retrieval, and prediction can use the editor's infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // Wrapped in an inner async block so `?` failures fall through to
            // the single panic below after cleanup.
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any outstanding provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and examples are streamed to it over a channel. For `eval`
                // without an explicit --output, nothing is written.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples grouped per repo; up to --max-parallelism repo
                // groups run concurrently, examples within a group serially.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the per-example stage for the chosen command.
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before app startup.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Failures are logged (and may panic, per
                            // --failfast) inside handle_error.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Emit the example as JSONL — to the file writer
                            // when one exists, otherwise to stdout (except
                            // for `eval` with no --output).
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush provider batches created during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
658
/// Records a failed example: bumps the failure counter, writes the example
/// JSON and the error text into `FAILED_EXAMPLES_DIR`, appends the example to
/// the run's `failed.jsonl`, then either panics (with `--failfast`, or when
/// only a single example is being processed) or logs the error and continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // The example and its error are written to separate files so the example
    // JSON can be fed straight back into a re-run.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Also append to the run-wide failed.jsonl (one JSON example per line).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` can itself panic inside this error
    // handler if the example's repo name can't be derived — confirm that is
    // acceptable here.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // User-facing failure summary with ANSI colors and a re-run hint that
    // uses the failed example file as input.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}