1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application};
25
26use reqwest_client::ReqwestClient;
27use serde::{Deserialize, Serialize};
28use std::fmt::Display;
29use std::fs::{File, OpenOptions};
30use std::hash::{Hash, Hasher};
31use std::io::{BufRead, BufReader, BufWriter, Write};
32use std::{path::PathBuf, sync::Arc};
33
34use crate::distill::run_distill;
35use crate::example::{Example, group_examples_by_repo, read_example_files};
36use crate::format_prompt::run_format_prompt;
37use crate::load_project::run_load_project;
38use crate::paths::FAILED_EXAMPLES_DIR;
39use crate::predict::run_prediction;
40use crate::progress::Progress;
41use crate::retrieve_context::run_context_retrieval;
42use crate::score::run_scoring;
43use crate::split_commit::SplitCommitArgs;
44use crate::split_dataset::SplitArgs;
45use crate::synthesize::{SynthesizeConfig, run_synthesize};
46
47#[derive(Parser, Debug)]
48#[command(name = "ep")]
49struct EpArgs {
50 #[arg(long, default_value_t = false)]
51 printenv: bool,
52 #[clap(long, default_value_t = 10, global = true)]
53 max_parallelism: usize,
54 #[clap(long, global = true)]
55 limit: Option<usize>,
56 /// Filter examples by name
57 #[clap(long, global = true)]
58 name: Option<String>,
59 /// Filter examples by repository
60 #[clap(long, global = true)]
61 repo: Option<String>,
62 #[command(subcommand)]
63 command: Option<Command>,
64 #[clap(global = true, help = INPUTS_HELP)]
65 inputs: Vec<PathBuf>,
66 #[arg(long, short, global = true)]
67 output: Option<PathBuf>,
68 #[arg(long, short, global = true)]
69 in_place: bool,
70 #[arg(long, short, global = true)]
71 failfast: bool,
72}
73
/// Help text shown for the positional `inputs` argument of [`EpArgs`].
///
/// Describes the two accepted input forms (a file path or a
/// `captured-after:{timestamp}` specifier), the environment variables needed
/// for the Snowflake-backed form, and a few example invocations.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
103
104#[derive(Subcommand, Debug, Clone)]
105enum Command {
106 /// Parse markdown examples and output a combined .jsonl file
107 ParseExample,
108 /// Create git worktrees for each example and load file contents
109 LoadProject,
110 /// Retrieve context for input examples.
111 Context,
112 /// Generate a prompt string for a specific model
113 FormatPrompt(FormatPromptArgs),
114 /// Runs edit prediction
115 Predict(PredictArgs),
116 /// Computes a score based on actual and expected patches
117 Score(PredictArgs),
118 /// Prepares a distillation dataset by copying expected outputs to
119 /// predicted outputs and removing actual outputs and prompts.
120 Distill,
121 /// Print aggregated scores
122 Eval(PredictArgs),
123 /// Generate eval examples by analyzing git commits from a repository
124 Synthesize(SynthesizeArgs),
125 /// Remove git repositories and worktrees
126 Clean,
127 /// Generate an evaluation example by splitting a chronologically-ordered commit
128 SplitCommit(SplitCommitArgs),
129 /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
130 Split(SplitArgs),
131}
132
133impl Display for Command {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Command::ParseExample => write!(f, "parse-example"),
137 Command::LoadProject => write!(f, "load-project"),
138 Command::Context => write!(f, "context"),
139 Command::FormatPrompt(format_prompt_args) => write!(
140 f,
141 "format-prompt --prompt-format={}",
142 format_prompt_args
143 .prompt_format
144 .to_possible_value()
145 .unwrap()
146 .get_name()
147 ),
148 Command::Predict(predict_args) => {
149 write!(
150 f,
151 "predict --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Score(predict_args) => {
160 write!(
161 f,
162 "score --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 )
169 }
170 Command::Distill => write!(f, "distill"),
171 Command::Eval(predict_args) => write!(
172 f,
173 "eval --provider={:?}",
174 predict_args
175 .provider
176 .to_possible_value()
177 .unwrap()
178 .get_name()
179 ),
180 Command::Synthesize(args) => {
181 write!(f, "synthesize --repo={}", args.repo)
182 }
183 Command::Clean => write!(f, "clean"),
184 Command::SplitCommit(_) => write!(f, "split-commit"),
185 Command::Split(_) => write!(f, "split"),
186 }
187 }
188}
189
190#[derive(Debug, Args, Clone)]
191struct FormatPromptArgs {
192 #[clap(long)]
193 prompt_format: PromptFormat,
194}
195
196#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
197enum PromptFormat {
198 Teacher,
199 Zeta2,
200}
201
202#[derive(Debug, Args, Clone)]
203struct PredictArgs {
204 #[clap(long)]
205 provider: PredictionProvider,
206 #[clap(long, default_value_t = 1)]
207 repetitions: usize,
208}
209
210#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
211enum PredictionProvider {
212 Sweep,
213 Mercury,
214 Zeta1,
215 Zeta2,
216 Teacher,
217 TeacherNonBatching,
218}
219
220#[derive(Debug, Args, Clone)]
221struct SynthesizeArgs {
222 /// Repository URL (git@github.com:owner/repo or https://...)
223 #[clap(long)]
224 repo: String,
225
226 /// Number of examples to generate
227 #[clap(long, default_value_t = 5)]
228 count: usize,
229
230 /// Maximum commits to scan before giving up
231 #[clap(long, default_value_t = 100)]
232 max_commits: usize,
233
234 /// Ignore state file and reprocess all commits
235 #[clap(long)]
236 fresh: bool,
237}
238
239impl EpArgs {
240 fn output_path(&self) -> Option<PathBuf> {
241 if self.in_place {
242 if self.inputs.len() == 1 {
243 self.inputs.first().cloned()
244 } else {
245 panic!("--in-place requires exactly one input file")
246 }
247 } else {
248 self.output.clone()
249 }
250 }
251}
252
253async fn load_examples(
254 http_client: Arc<dyn http_client::HttpClient>,
255 args: &EpArgs,
256 output_path: Option<&PathBuf>,
257) -> anyhow::Result<Vec<Example>> {
258 let mut captured_after_timestamps = Vec::new();
259 let mut file_inputs = Vec::new();
260
261 for input in &args.inputs {
262 let input_string = input.to_string_lossy();
263 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
264 captured_after_timestamps.push(timestamp.to_string());
265 } else {
266 file_inputs.push(input.clone());
267 }
268 }
269
270 let mut examples = read_example_files(&file_inputs);
271
272 Progress::global().set_total_examples(examples.len());
273
274 let remaining_limit_for_snowflake =
275 args.limit.map(|limit| limit.saturating_sub(examples.len()));
276
277 if let Some(0) = remaining_limit_for_snowflake {
278 log::info!(
279 "skipping captured-after inputs because --limit is already satisfied by example files"
280 );
281 } else if !captured_after_timestamps.is_empty() {
282 captured_after_timestamps.sort();
283
284 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
285
286 let mut captured_examples = pull_examples::fetch_captured_examples_after(
287 http_client,
288 &captured_after_timestamps,
289 max_rows_per_timestamp,
290 )
291 .await?;
292 examples.append(&mut captured_examples);
293 }
294
295 crate::example::sort_examples_by_repo_and_rev(&mut examples);
296
297 if let Some(name_filter) = &args.name {
298 examples.retain(|example| example.spec.name.contains(name_filter));
299 }
300 if let Some(repo_filter) = &args.repo {
301 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
302 }
303
304 if let Some(limit) = args.limit {
305 if examples.len() > limit {
306 examples.truncate(limit);
307 }
308 }
309
310 if let Some(path) = output_path {
311 resume_from_output(path, &mut examples);
312 }
313
314 Progress::global().set_total_examples(examples.len());
315
316 Ok(examples)
317}
318
319fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
320 let mut hasher = collections::FxHasher::default();
321 spec.hash(&mut hasher);
322 hasher.finish()
323}
324
325fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
326 let file = match File::open(path) {
327 Ok(f) => f,
328 Err(_) => return,
329 };
330
331 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
332
333 let reader = BufReader::new(file);
334 let mut kept_lines = Vec::new();
335 let mut kept_hashes = HashSet::default();
336
337 for line in reader.lines() {
338 let line = match line {
339 Ok(l) => l,
340 Err(_) => continue,
341 };
342
343 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
344 let hash = spec_hash(&output_example.spec);
345 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
346 kept_hashes.insert(hash);
347 kept_lines.push(line);
348 }
349 }
350 }
351
352 let total = examples.len();
353 let already_processed = kept_hashes.len();
354
355 eprintln!(
356 "Resuming: {}/{} examples already processed",
357 already_processed, total
358 );
359
360 let file = OpenOptions::new()
361 .write(true)
362 .truncate(true)
363 .open(path)
364 .expect("Failed to open output file for rewriting");
365 let mut writer = BufWriter::new(file);
366 for line in &kept_lines {
367 writeln!(writer, "{}", line).expect("Failed to write to output file");
368 }
369 writer.flush().expect("Failed to flush output file");
370
371 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
372}
373
374fn main() {
375 let args = EpArgs::parse();
376
377 if args.printenv {
378 ::util::shell_env::print_env();
379 return;
380 }
381
382 let output = args.output_path();
383 let command = match &args.command {
384 Some(cmd) => cmd.clone(),
385 None => {
386 EpArgs::command().print_help().unwrap();
387 return;
388 }
389 };
390
391 match &command {
392 Command::Clean => {
393 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
394 return;
395 }
396 Command::Synthesize(synth_args) => {
397 let Some(output_dir) = args.output else {
398 panic!("output dir is required");
399 };
400 let config = SynthesizeConfig {
401 repo_url: synth_args.repo.clone(),
402 count: synth_args.count,
403 max_commits: synth_args.max_commits,
404 output_dir,
405 fresh: synth_args.fresh,
406 };
407 smol::block_on(async {
408 if let Err(e) = run_synthesize(config).await {
409 eprintln!("Error: {:?}", e);
410 std::process::exit(1);
411 }
412 });
413 return;
414 }
415 Command::SplitCommit(split_commit_args) => {
416 if let Err(error) =
417 split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
418 {
419 eprintln!("{error:#}");
420 std::process::exit(1);
421 }
422 return;
423 }
424 Command::Split(split_args) => {
425 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
426 eprintln!("{error:#}");
427 std::process::exit(1);
428 }
429 return;
430 }
431 _ => {}
432 }
433
434 let http_client = Arc::new(ReqwestClient::new());
435 let app = Application::headless().with_http_client(http_client);
436
437 app.run(move |cx| {
438 let app_state = Arc::new(headless::init(cx));
439 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
440
441 cx.spawn(async move |cx| {
442 let result = async {
443 let mut examples =
444 load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
445
446 if let Command::Predict(args) = &command {
447 predict::sync_batches(&args.provider).await?;
448 }
449
450 let failfast_on_single_example = examples.len() == 1;
451
452 let output_sender: Option<mpsc::UnboundedSender<String>> =
453 if args.output.is_some() || !matches!(command, Command::Eval(_)) {
454 output.as_ref().map(|path| {
455 let file = OpenOptions::new()
456 .create(true)
457 .append(true)
458 .open(path)
459 .expect("Failed to open output file");
460 let mut writer = BufWriter::new(file);
461 let (sender, mut receiver) = mpsc::unbounded::<String>();
462 cx.background_spawn(async move {
463 while let Some(line) = receiver.next().await {
464 writeln!(writer, "{}", line).expect("Failed to write example");
465 writer.flush().expect("Failed to flush output");
466 }
467 })
468 .detach();
469 sender
470 })
471 } else {
472 None
473 };
474
475 let mut grouped_examples = group_examples_by_repo(&mut examples);
476 let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
477
478 for example_batch in example_batches {
479 let futures = example_batch.into_iter().map(|repo_examples| async {
480 for example in repo_examples.iter_mut() {
481 let result = async {
482 match &command {
483 Command::ParseExample => {}
484 Command::LoadProject => {
485 run_load_project(example, app_state.clone(), cx.clone())
486 .await?;
487 }
488 Command::Context => {
489 run_context_retrieval(
490 example,
491 app_state.clone(),
492 cx.clone(),
493 )
494 .await?;
495 }
496 Command::FormatPrompt(args) => {
497 run_format_prompt(
498 example,
499 args.prompt_format,
500 app_state.clone(),
501 cx.clone(),
502 )
503 .await?;
504 }
505 Command::Predict(args) => {
506 run_prediction(
507 example,
508 Some(args.provider),
509 args.repetitions,
510 app_state.clone(),
511 cx.clone(),
512 )
513 .await?;
514 }
515 Command::Distill => {
516 run_distill(example).await?;
517 }
518 Command::Score(args) | Command::Eval(args) => {
519 run_scoring(example, &args, app_state.clone(), cx.clone())
520 .await?;
521 }
522 Command::Clean
523 | Command::Synthesize(_)
524 | Command::SplitCommit(_)
525 | Command::Split(_) => {
526 unreachable!()
527 }
528 }
529 anyhow::Ok(())
530 }
531 .await;
532
533 if let Err(error) = result {
534 handle_error(
535 error,
536 &args,
537 &command,
538 &app_state,
539 failfast_on_single_example,
540 example,
541 )
542 .await;
543 }
544
545 if let Some(ref mut sender) = output_sender.clone() {
546 let line = serde_json::to_string(example).unwrap();
547 sender
548 .send(line)
549 .await
550 .expect("Failed to send to output writer");
551 } else if args.output.is_none() && !matches!(command, Command::Eval(_))
552 {
553 let line = serde_json::to_string(example).unwrap();
554 println!("{}", line);
555 }
556 }
557 });
558 futures::future::join_all(futures).await;
559 }
560
561 Progress::global().finalize();
562
563 match &command {
564 Command::Predict(args) => predict::sync_batches(&args.provider).await?,
565 Command::Eval(_) => score::print_report(&examples),
566 _ => (),
567 };
568
569 anyhow::Ok(())
570 }
571 .await;
572
573 if let Err(e) = result {
574 panic!("Fatal error: {:?}", e);
575 }
576
577 let _ = cx.update(|cx| cx.quit());
578 })
579 .detach();
580 });
581}
582
583async fn handle_error(
584 error: anyhow::Error,
585 args: &EpArgs,
586 command: &Command,
587 app_state: &Arc<headless::EpAppState>,
588 failfast_on_single_example: bool,
589 example: &Example,
590) {
591 Progress::global().increment_failed();
592 let example_name = example.spec.filename();
593 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
594 app_state
595 .fs
596 .write(
597 &failed_example_path,
598 &serde_json::to_vec_pretty(&example).unwrap(),
599 )
600 .await
601 .unwrap();
602 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
603 app_state
604 .fs
605 .write(&err_path, format!("{error:?}").as_bytes())
606 .await
607 .unwrap();
608
609 let file_path = example
610 .repo_name()
611 .unwrap()
612 .worktree_path()
613 .join(&example.spec.cursor_path);
614
615 let msg = format!(
616 indoc::indoc! {"
617 While processing \"{}\":
618
619 \x1b[31m{:?}\x1b[0m
620
621 Example: \x1b[36m{}\x1b[0m
622 Error file: \x1b[36m{}\x1b[0m
623 Cursor file: \x1b[36m{}\x1b[0m
624 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
625 "},
626 example.spec.name,
627 error,
628 err_path.display(),
629 file_path.display(),
630 failed_example_path.display(),
631 command,
632 failed_example_path.display(),
633 );
634 if args.failfast || failfast_on_single_example {
635 Progress::global().finalize();
636 panic!("{}", msg);
637 } else {
638 log::error!("{}", msg);
639 }
640}