1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
46
// Top-level CLI arguments for the `ep` tool. Most flags are `global` so they
// can appear after any subcommand. NOTE: plain `//` comments are used here
// (not `///`) because clap derives turn doc comments into --help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (used by tooling).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of repository groups processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total number of examples to process.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; mutually relevant with --in-place (see `output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the (single) input file.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failing example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
}
73
/// Long-form help text for the positional `inputs` argument on [`EpArgs`]
/// (attached via `#[clap(help = INPUTS_HELP)]`). Shown verbatim by `--help`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
103
// All `ep` subcommands. The `Display` impl below renders each variant as a
// re-runnable CLI invocation string for failure messages. The `///` variant
// docs double as clap subcommand help — do not edit them casually.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
132
133impl Display for Command {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Command::ParseExample => write!(f, "parse-example"),
137 Command::LoadProject => write!(f, "load-project"),
138 Command::Context => write!(f, "context"),
139 Command::FormatPrompt(format_prompt_args) => write!(
140 f,
141 "format-prompt --prompt-format={}",
142 format_prompt_args
143 .prompt_format
144 .to_possible_value()
145 .unwrap()
146 .get_name()
147 ),
148 Command::Predict(predict_args) => {
149 write!(
150 f,
151 "predict --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Score(predict_args) => {
160 write!(
161 f,
162 "score --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 )
169 }
170 Command::Distill => write!(f, "distill"),
171 Command::Eval(predict_args) => write!(
172 f,
173 "eval --provider={:?}",
174 predict_args
175 .provider
176 .to_possible_value()
177 .unwrap()
178 .get_name()
179 ),
180 Command::Synthesize(args) => {
181 write!(f, "synthesize --repo={}", args.repo)
182 }
183 Command::Clean => write!(f, "clean"),
184 Command::SplitCommit(_) => write!(f, "split-commit"),
185 Command::Split(_) => write!(f, "split"),
186 }
187 }
188}
189
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt flavor to emit; see `PromptFormat`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
195
// Prompt flavors supported by `format-prompt`. Serialized into example
// JSON, so variant names are part of the on-disk format.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
201
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend that produces the predictions.
    #[clap(long)]
    provider: PredictionProvider,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
209
// Available prediction backends. Serialized into example JSON, so variant
// names are part of the on-disk format.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
219
// Arguments for the `synthesize` subcommand (example generation from git
// history). `--output` on EpArgs is also required by this command.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
238
239impl EpArgs {
240 fn output_path(&self) -> Option<PathBuf> {
241 if self.in_place {
242 if self.inputs.len() == 1 {
243 self.inputs.first().cloned()
244 } else {
245 panic!("--in-place requires exactly one input file")
246 }
247 } else {
248 self.output.clone()
249 }
250 }
251}
252
/// Loads the working set of examples from the CLI inputs: example files plus
/// any `captured-after:{timestamp}` specifiers (fetched from Snowflake), then
/// applies `--name` / `--repo` / `--limit` filters and finally drops examples
/// already present in `output_path` (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into `captured-after:` specifiers and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total; updated again below once filters/resume have run.
    Progress::global().set_total_examples(examples.len());

    // How many more examples --limit allows after the file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // No --limit: fall back to a fixed per-timestamp row cap.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Group by repo/rev so worktree setup is shared between neighbors.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples that already have results in the output file.
    // Intentionally runs after the limit so already-done work counts toward it.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
318
319fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
320 let mut hasher = collections::FxHasher::default();
321 spec.hash(&mut hasher);
322 hasher.finish()
323}
324
325fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
326 let file = match File::open(path) {
327 Ok(f) => f,
328 Err(_) => return,
329 };
330
331 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
332
333 let reader = BufReader::new(file);
334 let mut kept_lines = Vec::new();
335 let mut kept_hashes = HashSet::default();
336
337 for line in reader.lines() {
338 let line = match line {
339 Ok(l) => l,
340 Err(_) => continue,
341 };
342
343 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
344 let hash = spec_hash(&output_example.spec);
345 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
346 kept_hashes.insert(hash);
347 kept_lines.push(line);
348 }
349 }
350 }
351
352 let total = examples.len();
353 let already_processed = kept_hashes.len();
354
355 eprintln!(
356 "Resuming: {}/{} examples already processed",
357 already_processed, total
358 );
359
360 let file = OpenOptions::new()
361 .write(true)
362 .truncate(true)
363 .open(path)
364 .expect("Failed to open output file for rewriting");
365 let mut writer = BufWriter::new(file);
366 for line in &kept_lines {
367 writeln!(writer, "{}", line).expect("Failed to write to output file");
368 }
369 writer.flush().expect("Failed to flush output file");
370
371 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
372}
373
/// Entry point: parses CLI args, dispatches the simple synchronous commands
/// inline, and runs the example-processing commands inside a headless GPUI
/// application with per-repo parallelism.
fn main() {
    let args = EpArgs::parse();

    // Tooling hook: dump the captured shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: show help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app / example pipeline are
    // handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands operate on loaded examples inside a headless app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples =
                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;

                // Sync any pending provider batches before starting work.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so the error surfaces
                // immediately instead of being logged and skipped.
                let failfast_on_single_example = examples.len() == 1;

                // When an output file is in use, a background task owns the
                // writer and serializes lines received over this channel.
                // `None` means results are printed to stdout instead.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so partial results survive a crash.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples from the same repo run sequentially (they share
                // state); distinct repos run concurrently, at most
                // --max-parallelism repo groups per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    // Parsing already happened in load_examples.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // These were dispatched before app.run.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }

                            // Emit the (possibly updated) example: to the file
                            // writer task if present, else to stdout (except
                            // for `eval`, which only prints a report).
                            if let Some(ref mut sender) = output_sender.clone() {
                                let line = serde_json::to_string(example).unwrap();
                                sender
                                    .send(line)
                                    .await
                                    .expect("Failed to send to output writer");
                            } else if args.output.is_none() && !matches!(command, Command::Eval(_))
                            {
                                let line = serde_json::to_string(example).unwrap();
                                println!("{}", line);
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush provider batches accumulated during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
591
/// Records a failed example: bumps the failure counter, writes the example
/// JSON and the error text to `FAILED_EXAMPLES_DIR`, and prints a colored
/// diagnostic with a re-run hint. Panics when `--failfast` is set or when
/// only a single example is being processed; otherwise just logs the error.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    // Persist the failing example so it can be re-run in isolation.
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Sibling file with the debug-formatted error.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // NOTE(review): `repo_name().unwrap()` assumes the example's repo was
    // resolved before failure — if an error can occur earlier, this panics.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // \x1b[31m/36m/0m are ANSI red/cyan/reset escape codes.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}