1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[clap(long, global = true)]
46 limit: Option<usize>,
47 /// Filter examples by name
48 #[clap(long, global = true)]
49 name: Option<String>,
50 /// Filter examples by repository
51 #[clap(long, global = true)]
52 repo: Option<String>,
53 #[command(subcommand)]
54 command: Option<Command>,
55 #[clap(global = true, help = INPUTS_HELP)]
56 inputs: Vec<PathBuf>,
57 #[arg(long, short, global = true)]
58 output: Option<PathBuf>,
59 #[arg(long, short, global = true)]
60 in_place: bool,
61 #[arg(long, short, global = true)]
62 failfast: bool,
63}
64
// Long-form help text attached to the positional `inputs` argument of
// `EpArgs` (via `help = INPUTS_HELP`). Emitted verbatim by clap, so the
// literal below must not be reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
94
// All `ep` subcommands. The `///` doc comments double as clap help text,
// so they are part of the CLI surface — edit with care.
//
// `Predict`, `Score`, and `Eval` share `PredictArgs`. `Clean`, `Synthesize`,
// and `SplitCommit` are handled early in `main` and never reach the
// per-example processing loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
121
122impl Display for Command {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 Command::ParseExample => write!(f, "parse-example"),
126 Command::LoadProject => write!(f, "load-project"),
127 Command::Context => write!(f, "context"),
128 Command::FormatPrompt(format_prompt_args) => write!(
129 f,
130 "format-prompt --prompt-format={}",
131 format_prompt_args
132 .prompt_format
133 .to_possible_value()
134 .unwrap()
135 .get_name()
136 ),
137 Command::Predict(predict_args) => {
138 write!(
139 f,
140 "predict --provider={:?}",
141 predict_args
142 .provider
143 .to_possible_value()
144 .unwrap()
145 .get_name()
146 )
147 }
148 Command::Score(predict_args) => {
149 write!(
150 f,
151 "score --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Distill => write!(f, "distill"),
160 Command::Eval(predict_args) => write!(
161 f,
162 "eval --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 ),
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repo={}", args.repo)
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 }
175 }
176}
177
178#[derive(Debug, Args, Clone)]
179struct FormatPromptArgs {
180 #[clap(long)]
181 prompt_format: PromptFormat,
182}
183
// Prompt flavors accepted by `format-prompt --prompt-format`. The clap
// `ValueEnum` derive provides the CLI value names; serde impls suggest the
// choice is also persisted with example data — confirm in format_prompt.rs.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
189
190#[derive(Debug, Args, Clone)]
191struct PredictArgs {
192 #[clap(long)]
193 provider: PredictionProvider,
194 #[clap(long, default_value_t = 1)]
195 repetitions: usize,
196}
197
// Prediction backends selectable via `--provider`. Names are exposed on the
// CLI through the `ValueEnum` derive. `TeacherNonBatching` presumably runs
// the teacher model without the batch API — confirm in predict.rs.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
207
208#[derive(Debug, Args, Clone)]
209struct SynthesizeArgs {
210 /// Repository URL (git@github.com:owner/repo or https://...)
211 #[clap(long)]
212 repo: String,
213
214 /// Number of examples to generate
215 #[clap(long, default_value_t = 5)]
216 count: usize,
217
218 /// Maximum commits to scan before giving up
219 #[clap(long, default_value_t = 100)]
220 max_commits: usize,
221
222 /// Ignore state file and reprocess all commits
223 #[clap(long)]
224 fresh: bool,
225}
226
227impl EpArgs {
228 fn output_path(&self) -> Option<PathBuf> {
229 if self.in_place {
230 if self.inputs.len() == 1 {
231 self.inputs.first().cloned()
232 } else {
233 panic!("--in-place requires exactly one input file")
234 }
235 } else {
236 self.output.clone()
237 }
238 }
239}
240
/// Loads examples from every CLI input: local example files plus, for any
/// `captured-after:{timestamp}` specifiers, examples fetched from Snowflake.
/// Applies the global `--name`, `--repo`, and `--limit` filters and returns
/// the examples sorted by repository and revision.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into `captured-after:` specifiers and plain file paths.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total from file inputs only; updated again below once
    // Snowflake results and filters are applied.
    Progress::global().set_total_examples(examples.len());

    // How many more examples --limit allows after counting the file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        // File inputs already satisfied --limit; don't hit Snowflake at all.
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        // RFC3339 timestamps sort chronologically as strings (when uniformly
        // formatted), so a plain lexicographic sort suffices here.
        captured_after_timestamps.sort();

        // With no --limit, fall back to a 5000-row cap per timestamp.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Enforce --limit after filtering (the Snowflake cap above is per
    // timestamp, so the combined set can still exceed the limit).
    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Final total now that all sources and filters are accounted for.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
301
/// Entry point. Parses CLI arguments, services the commands that need no
/// example loading (`--printenv`, `clean`, `synthesize`, `split-commit`),
/// then runs the remaining commands inside a headless GPUI application:
/// loads examples, processes them grouped by repository with bounded
/// parallelism, and writes results / reports.
fn main() {
    let args = EpArgs::parse();

    // Print the shell environment and exit; used as a standalone probe.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve --output / --in-place before `args` is partially moved below.
    let output = args.output_path();
    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that complete without loading examples or starting GPUI.
    match &command {
        Command::Clean => {
            // Wipe the cached repositories and worktrees wholesale.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into a directory, so a
            // plain --output is mandatory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so project
    // loading and prediction have an app context to work with.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // Sync any outstanding prediction batches before starting.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Examples in the same repo group run sequentially (they
                // presumably share a checkout — see load_project); distinct
                // groups run concurrently, --max-parallelism groups per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                // Dispatch the per-example stage for this command.
                                match &command {
                                    // parse-example only needs the load + write
                                    // performed outside this match.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled by the early-exit match above.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record the failure; may panic when failing fast.
                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Write results unless this is `eval` without an explicit
                // --output (eval is read-only reporting by default).
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Sync again to flush batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Shut the headless app down once processing completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
469
/// Records one failed example: bumps the failure counter, persists the full
/// example (JSON) and the error text under `FAILED_EXAMPLES_DIR`, and emits
/// a message with inspection and re-run instructions. Panics — aborting the
/// whole run — when `--failfast` is set or the run had only one example;
/// otherwise the message is merely logged and processing continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    // Persist the example state as pretty JSON so it can be inspected and
    // fed back in as an input file (see the re-run hint below).
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Store the debug-formatted error alongside the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Path of the file the cursor was in, inside the example's worktree.
    let file_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // \x1b[36m...\x1b[0m wraps the paths in cyan for terminal output.
    // `command` renders via its Display impl as a pasteable CLI fragment.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            {:?}

            Written to: \x1b[36m{}\x1b[0m

            Cursor File: \x1b[36m{}\x1b[0m

            Explore this example data with:
            fx \x1b[36m{}\x1b[0m

            Re-run this example with:
            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        err_path.display(),
        file_path.display(),
        failed_example_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}