1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
// Top-level CLI arguments for the `ep` binary. Field-level comments use `//`
// (not `///`) on purpose: doc comments on clap-derive fields become --help
// text, and we don't want to change the generated help output here.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately (debug aid).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Max number of repo groups processed concurrently per batch in main().
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total examples loaded; also budgets the Snowflake fetch
    // (see load_examples).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Optional: when absent, main() prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP for the full grammar).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; mutually interacts with --in-place via output_path().
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the (single) input file; panics in output_path()
    // if more than one input is given.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
}
58
// Long-form help text attached to the positional `inputs` argument of EpArgs.
// Kept as a raw string so the embedded newlines and examples reach the
// terminal verbatim.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
88
// Subcommands of `ep`. The existing `///` comments double as clap help text,
// so they are left untouched; additional notes use `//`.
//
// Clean, Synthesize, and SplitCommit are handled synchronously in main()
// before the headless app starts; every other variant runs per-example inside
// the app's async loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
115
116impl Display for Command {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Command::ParseExample => write!(f, "parse-example"),
120 Command::LoadProject => write!(f, "load-project"),
121 Command::Context => write!(f, "context"),
122 Command::FormatPrompt(format_prompt_args) => write!(
123 f,
124 "format-prompt --prompt-format={}",
125 format_prompt_args
126 .prompt_format
127 .to_possible_value()
128 .unwrap()
129 .get_name()
130 ),
131 Command::Predict(predict_args) => {
132 write!(
133 f,
134 "predict --provider={:?}",
135 predict_args
136 .provider
137 .to_possible_value()
138 .unwrap()
139 .get_name()
140 )
141 }
142 Command::Score(predict_args) => {
143 write!(
144 f,
145 "score --provider={:?}",
146 predict_args
147 .provider
148 .to_possible_value()
149 .unwrap()
150 .get_name()
151 )
152 }
153 Command::Distill => write!(f, "distill"),
154 Command::Eval(predict_args) => write!(
155 f,
156 "eval --provider={:?}",
157 predict_args
158 .provider
159 .to_possible_value()
160 .unwrap()
161 .get_name()
162 ),
163 Command::Synthesize(args) => {
164 write!(f, "synthesize --repo={}", args.repo)
165 }
166 Command::Clean => write!(f, "clean"),
167 Command::SplitCommit(_) => write!(f, "split-commit"),
168 }
169 }
170}
171
// Arguments for the `format-prompt` subcommand. `//` comments only, so the
// clap-generated help text is unchanged.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which prompt template to render (see PromptFormat).
    #[clap(long)]
    prompt_format: PromptFormat,
}
177
// Prompt template selector for `format-prompt`; values are exposed on the CLI
// in kebab-case via ValueEnum and round-trip through serde.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
183
// Shared arguments for `predict`, `score`, and `eval`. `//` comments only,
// so the clap-generated help text is unchanged.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use (see PredictionProvider).
    #[clap(long)]
    provider: PredictionProvider,
    // How many times to run prediction per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
// Prediction backends selectable via `--provider`; CLI names are the
// kebab-case ValueEnum forms (e.g. `teacher-non-batching`). Semantics of each
// backend live in the `predict` module — presumably TeacherNonBatching is the
// Teacher provider without request batching; confirm there if it matters.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
201
// Arguments for the `synthesize` subcommand. The `///` comments are clap help
// text and are left untouched; main() additionally requires --output as the
// destination directory for generated examples.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
220
221impl EpArgs {
222 fn output_path(&self) -> Option<PathBuf> {
223 if self.in_place {
224 if self.inputs.len() == 1 {
225 self.inputs.first().cloned()
226 } else {
227 panic!("--in-place requires exactly one input file")
228 }
229 } else {
230 self.output.clone()
231 }
232 }
233}
234
235async fn load_examples(
236 http_client: Arc<dyn http_client::HttpClient>,
237 args: &EpArgs,
238) -> anyhow::Result<Vec<Example>> {
239 let mut captured_after_timestamps = Vec::new();
240 let mut file_inputs = Vec::new();
241
242 for input in &args.inputs {
243 let input_string = input.to_string_lossy();
244 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
245 captured_after_timestamps.push(timestamp.to_string());
246 } else {
247 file_inputs.push(input.clone());
248 }
249 }
250
251 let mut examples = read_example_files(&file_inputs);
252 let total_steps = examples.len() + captured_after_timestamps.len();
253 Progress::global().set_total_steps(total_steps);
254
255 let remaining_limit_for_snowflake =
256 args.limit.map(|limit| limit.saturating_sub(examples.len()));
257
258 if let Some(0) = remaining_limit_for_snowflake {
259 log::info!(
260 "skipping captured-after inputs because --limit is already satisfied by example files"
261 );
262 } else if !captured_after_timestamps.is_empty() {
263 captured_after_timestamps.sort();
264
265 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
266
267 let mut captured_examples = pull_examples::fetch_captured_examples_after(
268 http_client,
269 &captured_after_timestamps,
270 max_rows_per_timestamp,
271 )
272 .await?;
273 examples.append(&mut captured_examples);
274 }
275
276 crate::example::sort_examples_by_repo_and_rev(&mut examples);
277
278 if let Some(limit) = args.limit {
279 if examples.len() > limit {
280 examples.truncate(limit);
281 }
282 }
283
284 Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
285
286 Ok(examples)
287}
288
/// Entry point: parses CLI args, handles the synchronous subcommands
/// (`clean`, `synthesize`, `split-commit`) directly, and runs the remaining
/// per-example commands inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else (environment debugging aid).
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the effective output path (--output, or the single input for
    // --in-place; panics on invalid --in-place usage).
    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // These subcommands don't operate on loaded examples, so handle them
    // before spinning up the headless app.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize needs a destination directory, taken from the raw
            // --output flag (not the resolved `output` path).
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so commands can
    // use the app's client, fs, and executor.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // For predict, reconcile any previously submitted batches
                // before processing new examples.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // With exactly one example there's nothing to continue past,
                // so treat any failure as fatal regardless of --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Group examples by repository (examples in one repo run
                // serially within a group), then process up to
                // --max-parallelism groups concurrently per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the selected command on this example; the
                            // error (if any) is captured per example rather
                            // than aborting the whole batch.
                            let result = async {
                                match &command {
                                    // Parsing happened in load_examples; the
                                    // combined output is written after the loop.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    // Eval shares the scoring pass; it only
                                    // differs in the report printed afterwards.
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before app.run().
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                // Persist the failing example and its error
                                // side by side for later inspection/re-runs.
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, format!("{e:?}").as_bytes())
                                    .await
                                    .unwrap();

                                // \x1b[36m…\x1b[0m renders the paths in cyan;
                                // `command`'s Display impl reconstructs the
                                // CLI invocation for the re-run hint.
                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing \"{}\":

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || failfast_on_single_example {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // NOTE(review): this checks the raw --output flag rather than
                // the resolved `output`, so `--in-place` with `eval` would
                // skip writing — confirm that's intended.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Submit/reconcile any batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}