1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
46
47#[derive(Parser, Debug)]
48#[command(name = "ep")]
49struct EpArgs {
50 #[arg(long, default_value_t = false)]
51 printenv: bool,
52 #[clap(long, default_value_t = 10, global = true)]
53 max_parallelism: usize,
54 #[clap(long, global = true)]
55 limit: Option<usize>,
56 /// Filter examples by name
57 #[clap(long, global = true)]
58 name: Option<String>,
59 /// Filter examples by repository
60 #[clap(long, global = true)]
61 repo: Option<String>,
62 #[command(subcommand)]
63 command: Option<Command>,
64 #[clap(global = true, help = INPUTS_HELP)]
65 inputs: Vec<PathBuf>,
66 #[arg(long, short, global = true)]
67 output: Option<PathBuf>,
68 #[arg(long, short, global = true)]
69 in_place: bool,
70 #[arg(long, short, global = true)]
71 failfast: bool,
72 /// How to handle failed examples in output: keep them or skip them.
73 /// Failed examples are always logged to the run's failed directory.
74 #[arg(long, global = true, default_value = "keep")]
75 failed: FailedHandling,
76}
77
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
//
// `ValueEnum` exposes the variants as kebab-case CLI values ("keep"/"skip"),
// matching `default_value = "keep"` on `EpArgs::failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
88
// Long help text attached to the positional `inputs` argument on `EpArgs`
// (via `help = INPUTS_HELP`). This string is user-visible CLI output; its
// contents are intentionally left exactly as shipped.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
118
// Subcommands of `ep`. The `///` docs on each variant become the CLI help
// text; the `Display` impl renders a variant back into an equivalent CLI
// invocation for "Re-run:" hints in failure output.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
147
148impl Display for Command {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 Command::ParseExample => write!(f, "parse-example"),
152 Command::LoadProject => write!(f, "load-project"),
153 Command::Context => write!(f, "context"),
154 Command::FormatPrompt(format_prompt_args) => write!(
155 f,
156 "format-prompt --prompt-format={}",
157 format_prompt_args
158 .prompt_format
159 .to_possible_value()
160 .unwrap()
161 .get_name()
162 ),
163 Command::Predict(predict_args) => {
164 write!(
165 f,
166 "predict --provider={:?}",
167 predict_args
168 .provider
169 .to_possible_value()
170 .unwrap()
171 .get_name()
172 )
173 }
174 Command::Score(predict_args) => {
175 write!(
176 f,
177 "score --provider={:?}",
178 predict_args
179 .provider
180 .to_possible_value()
181 .unwrap()
182 .get_name()
183 )
184 }
185 Command::Distill => write!(f, "distill"),
186 Command::Eval(predict_args) => write!(
187 f,
188 "eval --provider={:?}",
189 predict_args
190 .provider
191 .to_possible_value()
192 .unwrap()
193 .get_name()
194 ),
195 Command::Synthesize(args) => {
196 write!(f, "synthesize --repo={}", args.repo)
197 }
198 Command::Clean => write!(f, "clean"),
199 Command::SplitCommit(_) => write!(f, "split-commit"),
200 Command::Split(_) => write!(f, "split"),
201 }
202 }
203}
204
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt layout to emit; see `PromptFormat`.
    #[clap(long, short('p'))]
    prompt_format: PromptFormat,
}
210
// Target prompt layout for `format-prompt`. Carries serde derives because
// the choice is written into example JSON as well as parsed from the CLI.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
216
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend that performs the prediction.
    #[clap(long)]
    provider: PredictionProvider,
    // How many prediction runs to perform per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
224
// Prediction backends selectable via `--provider`.
// NOTE(review): each provider's semantics live in the `predict` module; this
// enum only defines the CLI values and their serialized form.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
234
// Arguments for the `synthesize` subcommand; converted into a
// `SynthesizeConfig` in `main` (which also supplies `--output` as the
// required output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
253
254impl EpArgs {
255 fn output_path(&self) -> Option<PathBuf> {
256 if self.in_place {
257 if self.inputs.len() == 1 {
258 self.inputs.first().cloned()
259 } else {
260 panic!("--in-place requires exactly one input file")
261 }
262 } else {
263 self.output.clone()
264 }
265 }
266}
267
268async fn load_examples(
269 http_client: Arc<dyn http_client::HttpClient>,
270 args: &EpArgs,
271 output_path: Option<&PathBuf>,
272) -> anyhow::Result<Vec<Example>> {
273 let mut captured_after_timestamps = Vec::new();
274 let mut file_inputs = Vec::new();
275
276 for input in &args.inputs {
277 let input_string = input.to_string_lossy();
278 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
279 captured_after_timestamps.push(timestamp.to_string());
280 } else {
281 file_inputs.push(input.clone());
282 }
283 }
284
285 let mut examples = read_example_files(&file_inputs);
286
287 Progress::global().set_total_examples(examples.len());
288
289 let remaining_limit_for_snowflake =
290 args.limit.map(|limit| limit.saturating_sub(examples.len()));
291
292 if let Some(0) = remaining_limit_for_snowflake {
293 log::info!(
294 "skipping captured-after inputs because --limit is already satisfied by example files"
295 );
296 } else if !captured_after_timestamps.is_empty() {
297 captured_after_timestamps.sort();
298
299 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
300
301 let mut captured_examples = pull_examples::fetch_captured_examples_after(
302 http_client,
303 &captured_after_timestamps,
304 max_rows_per_timestamp,
305 )
306 .await?;
307 examples.append(&mut captured_examples);
308 }
309
310 crate::example::sort_examples_by_repo_and_rev(&mut examples);
311
312 if let Some(name_filter) = &args.name {
313 examples.retain(|example| example.spec.name.contains(name_filter));
314 }
315 if let Some(repo_filter) = &args.repo {
316 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
317 }
318
319 if let Some(limit) = args.limit {
320 if examples.len() > limit {
321 examples.truncate(limit);
322 }
323 }
324
325 if let Some(path) = output_path {
326 resume_from_output(path, &mut examples);
327 }
328
329 Progress::global().set_total_examples(examples.len());
330
331 Ok(examples)
332}
333
334fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
335 let mut hasher = collections::FxHasher::default();
336 spec.hash(&mut hasher);
337 hasher.finish()
338}
339
340fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
341 let file = match File::open(path) {
342 Ok(f) => f,
343 Err(_) => return,
344 };
345
346 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
347
348 let reader = BufReader::new(file);
349 let mut kept_lines = Vec::new();
350 let mut kept_hashes = HashSet::default();
351
352 for line in reader.lines() {
353 let line = match line {
354 Ok(l) => l,
355 Err(_) => continue,
356 };
357
358 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
359 let hash = spec_hash(&output_example.spec);
360 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
361 kept_hashes.insert(hash);
362 kept_lines.push(line);
363 }
364 }
365 }
366
367 let total = examples.len();
368 let already_processed = kept_hashes.len();
369
370 eprintln!(
371 "Resuming: {}/{} examples already processed",
372 already_processed, total
373 );
374
375 let file = OpenOptions::new()
376 .write(true)
377 .truncate(true)
378 .open(path)
379 .expect("Failed to open output file for rewriting");
380 let mut writer = BufWriter::new(file);
381 for line in &kept_lines {
382 writeln!(writer, "{}", line).expect("Failed to write to output file");
383 }
384 writer.flush().expect("Failed to flush output file");
385
386 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
387}
388
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything: dump the shell environment and
    // exit (used by the environment-capture mechanism).
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output destination up front; panics if `--in-place` is
    // combined with anything other than exactly one input.
    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: print usage and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // These commands need neither examples nor the headless app; handle them
    // synchronously and return before spinning up gpui.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless gpui application so they can
    // use the project/worktree infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples =
                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;

                // Prediction-like commands first reconcile any outstanding
                // provider-side batches from a previous run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail fast even without --failfast so
                // the error surfaces as a full panic message.
                let failfast_on_single_example = examples.len() == 1;

                // Results are sent over a channel to a background task that
                // appends JSONL lines to the output file. `eval` without an
                // explicit --output gets no writer (report only).
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so completed work survives
                                    // a crash and resuming sees whole lines.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples grouped by repository run sequentially within a
                // group (they share a worktree); groups run concurrently, at
                // most `--max-parallelism` groups at a time.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Dispatch the selected command for this example.
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before the app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; `handle_error` panics when a
                            // failfast condition applies, otherwise logs.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Failed examples still go to the main output
                            // unless --failed=skip was requested.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file configured: stream JSONL
                                    // to stdout instead.
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush any provider batches accumulated during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` prints its aggregate report at the end.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once processing completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
613
/// Records a failed example and either aborts the process (failfast) or logs
/// the failure and lets the run continue.
///
/// Side effects:
/// - increments the global failure counter
/// - writes `<name>.json` (the example) and `<name>_err.txt` (the error)
///   into the run's failed-examples directory
/// - appends the example as a JSONL line to `failed.jsonl` in the run dir
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the cumulative failure log for this run (synchronous std I/O,
    // unlike the async fs writes above).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): this `unwrap()` can panic while we are already reporting
    // a failure (masking the original error) if the example's repo name
    // cannot be resolved — consider a fallback path. TODO confirm
    // `repo_name()`'s failure modes before changing.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    // Failfast (explicit flag, or implicit single-example run) aborts with a
    // panic; otherwise the failure is logged and the run continues.
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}