1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[clap(long, global = true)]
46 limit: Option<usize>,
47 #[command(subcommand)]
48 command: Option<Command>,
49 #[clap(global = true, help = INPUTS_HELP)]
50 inputs: Vec<PathBuf>,
51 #[arg(long, short, global = true)]
52 output: Option<PathBuf>,
53 #[arg(long, short, global = true)]
54 in_place: bool,
55 #[arg(long, short, global = true)]
56 failfast: bool,
57}
58
// Long-form help text for the `inputs` positional argument of `EpArgs`.
// This is user-visible output (rendered verbatim by clap's --help), so the
// string content must not be edited casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
88
// All `ep` subcommands. The `///` docs on each variant double as clap help
// text, so they are left byte-identical. The `Display` impl in this file
// renders a variant back into a re-runnable CLI invocation string.
// Note: `Clean`, `Synthesize`, and `SplitCommit` are handled synchronously
// in `main` before the headless app starts; the rest run per-example.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
115
116impl Display for Command {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Command::ParseExample => write!(f, "parse-example"),
120 Command::LoadProject => write!(f, "load-project"),
121 Command::Context => write!(f, "context"),
122 Command::FormatPrompt(format_prompt_args) => write!(
123 f,
124 "format-prompt --prompt-format={}",
125 format_prompt_args
126 .prompt_format
127 .to_possible_value()
128 .unwrap()
129 .get_name()
130 ),
131 Command::Predict(predict_args) => {
132 write!(
133 f,
134 "predict --provider={:?}",
135 predict_args
136 .provider
137 .to_possible_value()
138 .unwrap()
139 .get_name()
140 )
141 }
142 Command::Score(predict_args) => {
143 write!(
144 f,
145 "score --provider={:?}",
146 predict_args
147 .provider
148 .to_possible_value()
149 .unwrap()
150 .get_name()
151 )
152 }
153 Command::Distill => write!(f, "distill"),
154 Command::Eval(predict_args) => write!(
155 f,
156 "eval --provider={:?}",
157 predict_args
158 .provider
159 .to_possible_value()
160 .unwrap()
161 .get_name()
162 ),
163 Command::Synthesize(args) => {
164 write!(f, "synthesize --repo={}", args.repo)
165 }
166 Command::Clean => write!(f, "clean"),
167 Command::SplitCommit(_) => write!(f, "split-commit"),
168 }
169 }
170}
171
// Arguments for the `format-prompt` subcommand. Comments are `//` (not `///`)
// so clap's help output is unchanged.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt flavor to emit (see `PromptFormat`).
    #[clap(long)]
    prompt_format: PromptFormat,
}
177
// Prompt flavors accepted by `format-prompt --prompt-format`.
// Serialize/Deserialize because the choice is persisted with example data.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
183
// Arguments shared by the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Backend that produces predictions (see `PredictionProvider`).
    #[clap(long)]
    provider: PredictionProvider,
    // Number of prediction runs per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
// Prediction backends selectable via `--provider`. Variant names map to
// kebab-case CLI values through clap's ValueEnum derive.
// NOTE(review): the exact semantics of each provider live in predict.rs —
// not documented here to avoid guessing.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
201
// Arguments for the `synthesize` subcommand, which generates eval examples
// from a repository's git history. Field `///` docs double as clap help text
// and are left byte-identical. An output directory must also be supplied via
// the global `--output` flag (enforced in `main`).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
220
221impl EpArgs {
222 fn output_path(&self) -> Option<PathBuf> {
223 if self.in_place {
224 if self.inputs.len() == 1 {
225 self.inputs.first().cloned()
226 } else {
227 panic!("--in-place requires exactly one input file")
228 }
229 } else {
230 self.output.clone()
231 }
232 }
233}
234
/// Collects the examples for a run from `args.inputs`.
///
/// Inputs are partitioned into local example files (read immediately) and
/// `captured-after:{timestamp}` specifiers (fetched from Snowflake, unless
/// `--limit` is already satisfied by the file inputs alone). The combined
/// set is sorted by repo/revision and truncated to `--limit` when given.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    // Split inputs into captured-after timestamps vs. plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);
    // Each file-based example plus each pending timestamp fetch counts as a
    // step for the global progress display.
    let total_steps = examples.len() + captured_after_timestamps.len();
    Progress::global().set_total_steps(total_steps);

    // How many more examples --limit still allows after the file inputs
    // (None means no limit was given).
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        // File inputs alone already satisfy --limit; skip the remote fetch.
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // With no explicit --limit, cap each fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Keep examples that share a repo/revision adjacent so per-repo grouping
    // downstream processes each worktree's examples together.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // NOTE(review): the timestamp count is added again here even though the
    // fetched examples are already included in `examples.len()` — confirm the
    // apparent double-count in the progress total is intentional.
    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());

    Ok(examples)
}
288
/// CLI entry point.
///
/// Fast-path commands (`--printenv`, `clean`, `synthesize`, `split-commit`)
/// are handled synchronously and return early. Everything else loads the
/// input examples and processes them inside a headless GPUI application,
/// batching work per repository.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        // Debug helper: dump the shell environment and exit.
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that need neither example loading nor the headless app.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // `synthesize` repurposes the global --output flag as its
            // required output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(split_commit_args) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                if let Command::Predict(args) = &command {
                    // Sync outstanding prediction batches before the run
                    // (and again after the loop below).
                    predict::sync_batches(&args.provider).await?;
                }

                // With a single example there is nothing to continue past a
                // failure for, so treat it like --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Group examples by repository: each group is processed
                // sequentially, and up to --max-parallelism groups run
                // concurrently within a batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    // No per-example work; parsed examples are
                                    // written out after the loop.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    // Eval scores exactly like Score; it also
                                    // prints a report after the loop.
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled with early returns before the
                                    // app was started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                // Persist both the failed example and its error
                                // text so the failure can be inspected and the
                                // example re-run later.
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                // \x1b[36m…\x1b[0m renders the paths in cyan.
                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing \"{}\":

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || failfast_on_single_example {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Write results unless this is an eval without an explicit
                // --output (eval is read-only by default).
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so the process can exit.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}