1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[clap(long, global = true)]
46 limit: Option<usize>,
47 /// Filter examples by name
48 #[clap(long, global = true)]
49 name: Option<String>,
50 /// Filter examples by repository
51 #[clap(long, global = true)]
52 repo: Option<String>,
53 #[command(subcommand)]
54 command: Option<Command>,
55 #[clap(global = true, help = INPUTS_HELP)]
56 inputs: Vec<PathBuf>,
57 #[arg(long, short, global = true)]
58 output: Option<PathBuf>,
59 #[arg(long, short, global = true)]
60 in_place: bool,
61 #[arg(long, short, global = true)]
62 failfast: bool,
63}
64
// Long help text clap attaches to the positional `inputs` argument of
// `EpArgs`. This is a user-facing runtime string — keep its wording and
// formatting exactly as shown when editing.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
94
// Subcommands of the `ep` CLI. The `///` comments below are surfaced by clap
// as user-facing help text, so they must not be reworded casually.
// `Clean`, `Synthesize`, and `SplitCommit` are handled synchronously in
// `main` before the headless app starts; all others run through the
// per-example pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Reuses PredictArgs so scoring knows which provider's output to check.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Also reuses PredictArgs; Eval runs scoring then prints a report.
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
121
122impl Display for Command {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 Command::ParseExample => write!(f, "parse-example"),
126 Command::LoadProject => write!(f, "load-project"),
127 Command::Context => write!(f, "context"),
128 Command::FormatPrompt(format_prompt_args) => write!(
129 f,
130 "format-prompt --prompt-format={}",
131 format_prompt_args
132 .prompt_format
133 .to_possible_value()
134 .unwrap()
135 .get_name()
136 ),
137 Command::Predict(predict_args) => {
138 write!(
139 f,
140 "predict --provider={:?}",
141 predict_args
142 .provider
143 .to_possible_value()
144 .unwrap()
145 .get_name()
146 )
147 }
148 Command::Score(predict_args) => {
149 write!(
150 f,
151 "score --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Distill => write!(f, "distill"),
160 Command::Eval(predict_args) => write!(
161 f,
162 "eval --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 ),
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repo={}", args.repo)
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt flavor to emit; see `PromptFormat`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
183
// Prompt layouts selectable via `--prompt-format`. Derives
// Serialize/Deserialize — presumably persisted alongside examples, so
// renaming variants may break existing files; confirm before renaming.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
189
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use; see `PredictionProvider`.
    #[clap(long)]
    provider: PredictionProvider,
    // Number of prediction runs per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
197
// Prediction backends selectable via `--provider`. Derives
// Serialize/Deserialize — presumably written into example/result files, so
// variant names are part of the persisted format; confirm before renaming.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    // Presumably the teacher backend without the batch API (see
    // `predict::sync_batches` usage in `main`) — confirm.
    TeacherNonBatching,
}
207
// Arguments for the `synthesize` subcommand (handled synchronously in `main`;
// also requires the global --output flag as the destination directory).
// The `///` comments below are clap help text — user-facing.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
226
227impl EpArgs {
228 fn output_path(&self) -> Option<PathBuf> {
229 if self.in_place {
230 if self.inputs.len() == 1 {
231 self.inputs.first().cloned()
232 } else {
233 panic!("--in-place requires exactly one input file")
234 }
235 } else {
236 self.output.clone()
237 }
238 }
239}
240
/// Loads examples from the CLI inputs: plain example files plus any
/// `captured-after:{timestamp}` specifiers fetched from Snowflake, then
/// applies the global `--name` / `--repo` substring filters and `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into captured-after timestamps vs. file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Initial progress estimate; refined again after filtering below.
    let total_steps = examples.len() + captured_after_timestamps.len();
    Progress::global().set_total_steps(total_steps);

    // How many more examples Snowflake may contribute before --limit is hit
    // (None when no limit was given).
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without an explicit --limit, cap each Snowflake query at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Deterministic ordering so repo grouping/batching is stable across runs.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // NOTE(review): the timestamp count is added on top of the filtered
    // example count here, even though fetched examples are already included
    // in `examples` — confirm this matches how progress steps are consumed.
    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());

    Ok(examples)
}
302
/// CLI entry point: handles the synchronous commands (`printenv`, help,
/// `clean`, `synthesize`, `split-commit`) up front, then runs the remaining
/// per-example commands inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the example pipeline.
    match &command {
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples to the --output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so project
    // loading / prediction infrastructure is available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // For predict, reconcile any outstanding provider batches
                // before processing new examples.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // With a single example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Examples from the same repo run sequentially within one
                // future; distinct repos run concurrently, at most
                // --max-parallelism repos per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                // Dispatch the per-example stage for the
                                // selected command.
                                match &command {
                                    // Parsing already happened when the
                                    // example files were read.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    // Eval is scoring plus a final report.
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // These were handled (and returned from)
                                    // before the app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Per-example failures are recorded and (unless
                            // fail-fast) don't abort the run.
                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Eval only writes examples when --output was given; all
                // other commands always persist results.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Flush/collect provider batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
470
/// Records a failed example: bumps the global failure counter, writes the
/// example JSON and the error text under `FAILED_EXAMPLES_DIR`, then either
/// panics (fail-fast mode) or logs the failure and lets the run continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the full example so it can be inspected and re-run in isolation.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Error text goes into a sibling `<name>_err.txt` file.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Absolute path of the file the example's cursor points at, inside the
    // example's worktree.
    let file_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // `\x1b[36m…\x1b[0m` wraps paths in cyan; `{}` on `command` renders the
    // Display impl above as a re-runnable CLI invocation.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            {:?}

            Written to: \x1b[36m{}\x1b[0m

            Cursor File: \x1b[36m{}\x1b[0m

            Explore this example data with:
            fx \x1b[36m{}\x1b[0m

            Re-run this example with:
            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        err_path.display(),
        file_path.display(),
        failed_example_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}