1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use collections::HashSet;
20use edit_prediction::EditPredictionStore;
21use futures::channel::mpsc;
22use futures::{SinkExt as _, StreamExt as _};
23use gpui::{AppContext as _, Application};
24
25use reqwest_client::ReqwestClient;
26use serde::{Deserialize, Serialize};
27use std::fmt::Display;
28use std::fs::{File, OpenOptions};
29use std::hash::{Hash, Hasher};
30use std::io::{BufRead, BufReader, BufWriter, Write};
31use std::{path::PathBuf, sync::Arc};
32
33use crate::distill::run_distill;
34use crate::example::{Example, group_examples_by_repo, read_example_files};
35use crate::format_prompt::run_format_prompt;
36use crate::load_project::run_load_project;
37use crate::paths::FAILED_EXAMPLES_DIR;
38use crate::predict::run_prediction;
39use crate::progress::Progress;
40use crate::retrieve_context::run_context_retrieval;
41use crate::score::run_scoring;
42use crate::split_commit::SplitCommitArgs;
43use crate::synthesize::{SynthesizeConfig, run_synthesize};
44
45#[derive(Parser, Debug)]
46#[command(name = "ep")]
47struct EpArgs {
48 #[arg(long, default_value_t = false)]
49 printenv: bool,
50 #[clap(long, default_value_t = 10, global = true)]
51 max_parallelism: usize,
52 #[clap(long, global = true)]
53 limit: Option<usize>,
54 /// Filter examples by name
55 #[clap(long, global = true)]
56 name: Option<String>,
57 /// Filter examples by repository
58 #[clap(long, global = true)]
59 repo: Option<String>,
60 #[command(subcommand)]
61 command: Option<Command>,
62 #[clap(global = true, help = INPUTS_HELP)]
63 inputs: Vec<PathBuf>,
64 #[arg(long, short, global = true)]
65 output: Option<PathBuf>,
66 #[arg(long, short, global = true)]
67 in_place: bool,
68 #[arg(long, short, global = true)]
69 failfast: bool,
70}
71
// Help text attached to the positional `inputs` argument of `EpArgs`.
// It is shown verbatim by clap, so every character of the string is user-facing output.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
101
// Subcommands of the `ep` binary.
//
// The `///` comments on the variants are consumed by clap as the per-subcommand help
// text, so they are user-facing strings. `Clean`, `Synthesize` and `SplitCommit` are
// handled directly in `main` before any example is loaded; the remaining variants are
// applied to each loaded example.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
128
129impl Display for Command {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 Command::ParseExample => write!(f, "parse-example"),
133 Command::LoadProject => write!(f, "load-project"),
134 Command::Context => write!(f, "context"),
135 Command::FormatPrompt(format_prompt_args) => write!(
136 f,
137 "format-prompt --prompt-format={}",
138 format_prompt_args
139 .prompt_format
140 .to_possible_value()
141 .unwrap()
142 .get_name()
143 ),
144 Command::Predict(predict_args) => {
145 write!(
146 f,
147 "predict --provider={:?}",
148 predict_args
149 .provider
150 .to_possible_value()
151 .unwrap()
152 .get_name()
153 )
154 }
155 Command::Score(predict_args) => {
156 write!(
157 f,
158 "score --provider={:?}",
159 predict_args
160 .provider
161 .to_possible_value()
162 .unwrap()
163 .get_name()
164 )
165 }
166 Command::Distill => write!(f, "distill"),
167 Command::Eval(predict_args) => write!(
168 f,
169 "eval --provider={:?}",
170 predict_args
171 .provider
172 .to_possible_value()
173 .unwrap()
174 .get_name()
175 ),
176 Command::Synthesize(args) => {
177 write!(f, "synthesize --repo={}", args.repo)
178 }
179 Command::Clean => write!(f, "clean"),
180 Command::SplitCommit(_) => write!(f, "split-commit"),
181 }
182 }
183}
184
185#[derive(Debug, Args, Clone)]
186struct FormatPromptArgs {
187 #[clap(long)]
188 prompt_format: PromptFormat,
189}
190
// Prompt layouts selectable through `format-prompt --prompt-format`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
196
197#[derive(Debug, Args, Clone)]
198struct PredictArgs {
199 #[clap(long)]
200 provider: PredictionProvider,
201 #[clap(long, default_value_t = 1)]
202 repetitions: usize,
203}
204
// Prediction backends selectable through the `--provider` flag of `PredictArgs`.
// (`//` rather than `///`: clap would turn doc comments into help text.)
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
214
215#[derive(Debug, Args, Clone)]
216struct SynthesizeArgs {
217 /// Repository URL (git@github.com:owner/repo or https://...)
218 #[clap(long)]
219 repo: String,
220
221 /// Number of examples to generate
222 #[clap(long, default_value_t = 5)]
223 count: usize,
224
225 /// Maximum commits to scan before giving up
226 #[clap(long, default_value_t = 100)]
227 max_commits: usize,
228
229 /// Ignore state file and reprocess all commits
230 #[clap(long)]
231 fresh: bool,
232}
233
234impl EpArgs {
235 fn output_path(&self) -> Option<PathBuf> {
236 if self.in_place {
237 if self.inputs.len() == 1 {
238 self.inputs.first().cloned()
239 } else {
240 panic!("--in-place requires exactly one input file")
241 }
242 } else {
243 self.output.clone()
244 }
245 }
246}
247
248async fn load_examples(
249 http_client: Arc<dyn http_client::HttpClient>,
250 args: &EpArgs,
251 output_path: Option<&PathBuf>,
252) -> anyhow::Result<Vec<Example>> {
253 let mut captured_after_timestamps = Vec::new();
254 let mut file_inputs = Vec::new();
255
256 for input in &args.inputs {
257 let input_string = input.to_string_lossy();
258 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
259 captured_after_timestamps.push(timestamp.to_string());
260 } else {
261 file_inputs.push(input.clone());
262 }
263 }
264
265 let mut examples = read_example_files(&file_inputs);
266
267 Progress::global().set_total_examples(examples.len());
268
269 let remaining_limit_for_snowflake =
270 args.limit.map(|limit| limit.saturating_sub(examples.len()));
271
272 if let Some(0) = remaining_limit_for_snowflake {
273 log::info!(
274 "skipping captured-after inputs because --limit is already satisfied by example files"
275 );
276 } else if !captured_after_timestamps.is_empty() {
277 captured_after_timestamps.sort();
278
279 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
280
281 let mut captured_examples = pull_examples::fetch_captured_examples_after(
282 http_client,
283 &captured_after_timestamps,
284 max_rows_per_timestamp,
285 )
286 .await?;
287 examples.append(&mut captured_examples);
288 }
289
290 crate::example::sort_examples_by_repo_and_rev(&mut examples);
291
292 if let Some(name_filter) = &args.name {
293 examples.retain(|example| example.spec.name.contains(name_filter));
294 }
295 if let Some(repo_filter) = &args.repo {
296 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
297 }
298
299 if let Some(limit) = args.limit {
300 if examples.len() > limit {
301 examples.truncate(limit);
302 }
303 }
304
305 if let Some(path) = output_path {
306 resume_from_output(path, &mut examples);
307 }
308
309 Progress::global().set_total_examples(examples.len());
310
311 Ok(examples)
312}
313
314fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
315 let mut hasher = collections::FxHasher::default();
316 spec.hash(&mut hasher);
317 hasher.finish()
318}
319
320fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
321 let file = match File::open(path) {
322 Ok(f) => f,
323 Err(_) => return,
324 };
325
326 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
327
328 let reader = BufReader::new(file);
329 let mut kept_lines = Vec::new();
330 let mut kept_hashes = HashSet::default();
331
332 for line in reader.lines() {
333 let line = match line {
334 Ok(l) => l,
335 Err(_) => continue,
336 };
337
338 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
339 let hash = spec_hash(&output_example.spec);
340 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
341 kept_hashes.insert(hash);
342 kept_lines.push(line);
343 }
344 }
345 }
346
347 let total = examples.len();
348 let already_processed = kept_hashes.len();
349
350 eprintln!(
351 "Resuming: {}/{} examples already processed",
352 already_processed, total
353 );
354
355 let file = OpenOptions::new()
356 .write(true)
357 .truncate(true)
358 .open(path)
359 .expect("Failed to open output file for rewriting");
360 let mut writer = BufWriter::new(file);
361 for line in &kept_lines {
362 writeln!(writer, "{}", line).expect("Failed to write to output file");
363 }
364 writer.flush().expect("Failed to flush output file");
365
366 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
367}
368
369fn main() {
370 let args = EpArgs::parse();
371
372 if args.printenv {
373 ::util::shell_env::print_env();
374 return;
375 }
376
377 let output = args.output_path();
378 let command = match &args.command {
379 Some(cmd) => cmd.clone(),
380 None => {
381 EpArgs::command().print_help().unwrap();
382 return;
383 }
384 };
385
386 match &command {
387 Command::Clean => {
388 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
389 return;
390 }
391 Command::Synthesize(synth_args) => {
392 let Some(output_dir) = args.output else {
393 panic!("output dir is required");
394 };
395 let config = SynthesizeConfig {
396 repo_url: synth_args.repo.clone(),
397 count: synth_args.count,
398 max_commits: synth_args.max_commits,
399 output_dir,
400 fresh: synth_args.fresh,
401 };
402 smol::block_on(async {
403 if let Err(e) = run_synthesize(config).await {
404 eprintln!("Error: {:?}", e);
405 std::process::exit(1);
406 }
407 });
408 return;
409 }
410 Command::SplitCommit(split_commit_args) => {
411 if let Err(error) =
412 split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
413 {
414 eprintln!("{error:#}");
415 std::process::exit(1);
416 }
417 return;
418 }
419 _ => {}
420 }
421
422 let http_client = Arc::new(ReqwestClient::new());
423 let app = Application::headless().with_http_client(http_client);
424
425 app.run(move |cx| {
426 let app_state = Arc::new(headless::init(cx));
427 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
428
429 cx.spawn(async move |cx| {
430 let result = async {
431 let mut examples =
432 load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
433
434 if let Command::Predict(args) = &command {
435 predict::sync_batches(&args.provider).await?;
436 }
437
438 let failfast_on_single_example = examples.len() == 1;
439
440 let output_sender: Option<mpsc::UnboundedSender<String>> =
441 if args.output.is_some() || !matches!(command, Command::Eval(_)) {
442 output.as_ref().map(|path| {
443 let file = OpenOptions::new()
444 .create(true)
445 .append(true)
446 .open(path)
447 .expect("Failed to open output file");
448 let mut writer = BufWriter::new(file);
449 let (sender, mut receiver) = mpsc::unbounded::<String>();
450 cx.background_spawn(async move {
451 while let Some(line) = receiver.next().await {
452 writeln!(writer, "{}", line).expect("Failed to write example");
453 writer.flush().expect("Failed to flush output");
454 }
455 })
456 .detach();
457 sender
458 })
459 } else {
460 None
461 };
462
463 let mut grouped_examples = group_examples_by_repo(&mut examples);
464 let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
465
466 for example_batch in example_batches {
467 let futures = example_batch.into_iter().map(|repo_examples| async {
468 for example in repo_examples.iter_mut() {
469 let result = async {
470 match &command {
471 Command::ParseExample => {}
472 Command::LoadProject => {
473 run_load_project(example, app_state.clone(), cx.clone())
474 .await?;
475 }
476 Command::Context => {
477 run_context_retrieval(
478 example,
479 app_state.clone(),
480 cx.clone(),
481 )
482 .await?;
483 }
484 Command::FormatPrompt(args) => {
485 run_format_prompt(
486 example,
487 args.prompt_format,
488 app_state.clone(),
489 cx.clone(),
490 )
491 .await?;
492 }
493 Command::Predict(args) => {
494 run_prediction(
495 example,
496 Some(args.provider),
497 args.repetitions,
498 app_state.clone(),
499 cx.clone(),
500 )
501 .await?;
502 }
503 Command::Distill => {
504 run_distill(example).await?;
505 }
506 Command::Score(args) | Command::Eval(args) => {
507 run_scoring(example, &args, app_state.clone(), cx.clone())
508 .await?;
509 }
510 Command::Clean
511 | Command::Synthesize(_)
512 | Command::SplitCommit(_) => {
513 unreachable!()
514 }
515 }
516 anyhow::Ok(())
517 }
518 .await;
519
520 if let Err(error) = result {
521 handle_error(
522 error,
523 &args,
524 &command,
525 &app_state,
526 failfast_on_single_example,
527 example,
528 )
529 .await;
530 }
531
532 if let Some(ref mut sender) = output_sender.clone() {
533 let line = serde_json::to_string(example).unwrap();
534 sender
535 .send(line)
536 .await
537 .expect("Failed to send to output writer");
538 } else if args.output.is_none() && !matches!(command, Command::Eval(_))
539 {
540 let line = serde_json::to_string(example).unwrap();
541 println!("{}", line);
542 }
543 }
544 });
545 futures::future::join_all(futures).await;
546 }
547
548 Progress::global().finalize();
549
550 match &command {
551 Command::Predict(args) => predict::sync_batches(&args.provider).await?,
552 Command::Eval(_) => score::print_report(&examples),
553 _ => (),
554 };
555
556 anyhow::Ok(())
557 }
558 .await;
559
560 if let Err(e) = result {
561 panic!("Fatal error: {:?}", e);
562 }
563
564 let _ = cx.update(|cx| cx.quit());
565 })
566 .detach();
567 });
568}
569
570async fn handle_error(
571 error: anyhow::Error,
572 args: &EpArgs,
573 command: &Command,
574 app_state: &Arc<headless::EpAppState>,
575 failfast_on_single_example: bool,
576 example: &Example,
577) {
578 Progress::global().increment_failed();
579 let example_name = example.spec.filename();
580 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
581 app_state
582 .fs
583 .write(
584 &failed_example_path,
585 &serde_json::to_vec_pretty(&example).unwrap(),
586 )
587 .await
588 .unwrap();
589 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
590 app_state
591 .fs
592 .write(&err_path, format!("{error:?}").as_bytes())
593 .await
594 .unwrap();
595
596 let file_path = example
597 .repo_name()
598 .unwrap()
599 .worktree_path()
600 .join(&example.spec.cursor_path);
601
602 let msg = format!(
603 indoc::indoc! {"
604 While processing \"{}\":
605
606 {:?}
607
608 Written to: \x1b[36m{}\x1b[0m
609
610 Cursor File: \x1b[36m{}\x1b[0m
611
612 Explore this example data with:
613 fx \x1b[36m{}\x1b[0m
614
615 Re-run this example with:
616 cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
617 "},
618 example.spec.name,
619 error,
620 err_path.display(),
621 file_path.display(),
622 failed_example_path.display(),
623 command,
624 failed_example_path.display(),
625 );
626 if args.failfast || failfast_on_single_example {
627 Progress::global().finalize();
628 panic!("{}", msg);
629 } else {
630 log::error!("{}", msg);
631 }
632}