1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod reorder_patch;
13mod retrieve_context;
14mod score;
15mod split_commit;
16mod synthesize;
17
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{group_examples_by_repo, read_examples, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[command(subcommand)]
46 command: Option<Command>,
47 #[clap(global = true)]
48 inputs: Vec<PathBuf>,
49 #[arg(long, short, global = true)]
50 output: Option<PathBuf>,
51 #[arg(long, short, global = true)]
52 in_place: bool,
53 #[arg(long, short, global = true)]
54 failfast: bool,
55}
56
// The set of `ep` subcommands, dispatched in `main`. Clean, Synthesize, and
// SplitCommit are handled with early returns; all other variants run over the
// input examples inside the headless application.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
83
84impl Display for Command {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 match self {
87 Command::ParseExample => write!(f, "parse-example"),
88 Command::LoadProject => write!(f, "load-project"),
89 Command::Context => write!(f, "context"),
90 Command::FormatPrompt(format_prompt_args) => write!(
91 f,
92 "format-prompt --prompt-format={}",
93 format_prompt_args
94 .prompt_format
95 .to_possible_value()
96 .unwrap()
97 .get_name()
98 ),
99 Command::Predict(predict_args) => {
100 write!(
101 f,
102 "predict --provider={:?}",
103 predict_args
104 .provider
105 .to_possible_value()
106 .unwrap()
107 .get_name()
108 )
109 }
110 Command::Score(predict_args) => {
111 write!(
112 f,
113 "score --provider={:?}",
114 predict_args
115 .provider
116 .to_possible_value()
117 .unwrap()
118 .get_name()
119 )
120 }
121 Command::Distill => write!(f, "distill"),
122 Command::Eval(predict_args) => write!(
123 f,
124 "eval --provider={:?}",
125 predict_args
126 .provider
127 .to_possible_value()
128 .unwrap()
129 .get_name()
130 ),
131 Command::Synthesize(args) => {
132 write!(f, "synthesize --repo={}", args.repo)
133 }
134 Command::Clean => write!(f, "clean"),
135 Command::SplitCommit(_) => write!(f, "split-commit"),
136 }
137 }
138}
139
// Arguments for the `format-prompt` subcommand.
// (`//` comments used so clap help text stays unchanged.)
#[derive(Debug, Args)]
struct FormatPromptArgs {
    // Which prompt layout to emit; passed through to `run_format_prompt`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
145
// Prompt layout variants accepted by `--prompt-format`.
// Derives Serialize/Deserialize, so variant names are part of the stored
// example format — renaming them is a breaking change.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
151
// Arguments shared by the `predict`, `score`, and `eval` subcommands.
// (`//` comments used so clap help text stays unchanged.)
#[derive(Debug, Args)]
struct PredictArgs {
    // Prediction backend to run against.
    #[clap(long)]
    provider: PredictionProvider,
    // How many prediction runs per example; forwarded to `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
159
// Backends that can produce edit predictions (value of `--provider`).
// Derives Serialize/Deserialize, so variant names are part of the stored
// example format — renaming them is a breaking change.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    // NOTE(review): presumably the same teacher backend without request
    // batching — confirm against the predict module.
    TeacherNonBatching,
}
169
// Arguments for the `synthesize` subcommand; `main` copies these into a
// `SynthesizeConfig` together with the required global --output directory.
#[derive(Debug, Args)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
188
189impl EpArgs {
190 fn output_path(&self) -> Option<PathBuf> {
191 if self.in_place {
192 if self.inputs.len() == 1 {
193 self.inputs.first().cloned()
194 } else {
195 panic!("--in-place requires exactly one input file")
196 }
197 } else {
198 self.output.clone()
199 }
200 }
201}
202
/// Entry point for the `ep` CLI.
///
/// Parses arguments, handles the subcommands that need no example pipeline
/// (env dump, help, clean, synthesize, split-commit) with early returns, and
/// runs every remaining subcommand over the input examples inside a headless
/// GPUI application.
fn main() {
    let args = EpArgs::parse();

    // --printenv: dump the shell environment and do nothing else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output destination (--output / --in-place) before
    // `args.command` is moved out below.
    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't run the per-example pipeline return early here.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            // Synthesis is async but needs no GPUI app state; run it on smol.
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(split_commit_args) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining subcommands operate on examples loaded from the input files.
    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        // Ensure the global edit-prediction store exists before any work runs.
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                // For `predict`, sync any outstanding batch results first.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

                // Examples in the same repo group run sequentially; up to
                // --max-parallelism repo groups run concurrently per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the selected stage for this example; the
                            // error is captured here rather than aborting the
                            // whole batch.
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        // These returned early before app.run.
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                // Persist the failing example and its error so
                                // they can be inspected and re-run later.
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                // A single-example run fails fast regardless
                                // of --failfast.
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Eval is read-only unless an explicit --output was given;
                // every other pipeline stage writes its results back.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}