1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod retrieve_context;
13mod score;
14mod synthesize;
15
16use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
17use edit_prediction::EditPredictionStore;
18use gpui::Application;
19use reqwest_client::ReqwestClient;
20use serde::{Deserialize, Serialize};
21use std::fmt::Display;
22use std::{path::PathBuf, sync::Arc};
23
24use crate::distill::run_distill;
25use crate::example::{group_examples_by_repo, read_examples, write_examples};
26use crate::format_prompt::run_format_prompt;
27use crate::load_project::run_load_project;
28use crate::paths::FAILED_EXAMPLES_DIR;
29use crate::predict::run_prediction;
30use crate::progress::Progress;
31use crate::retrieve_context::run_context_retrieval;
32use crate::score::run_scoring;
33use crate::synthesize::{SynthesizeConfig, run_synthesize};
34
// Top-level CLI arguments for the `ep` tool. Note: `//` comments are used
// deliberately — `///` doc comments on clap derive items become user-visible
// help text, which we don't want to change here.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit immediately (debug aid).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on example repos processed concurrently per batch
    // (used as the chunk size when batching grouped examples in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Optional subcommand; `main` prints help and exits when absent.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional input files containing examples; read via `read_examples`.
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Destination for written results; see `output_path` for how this
    // interacts with `--in-place`.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Overwrite the single input file with the results
    // (requires exactly one input; see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
}
53
// Subcommands of `ep`. The `///` comments double as clap help text, so they
// are left untouched. Most variants run over parsed examples inside the
// headless app loop in `main`; `Clean` and `Synthesize` are handled earlier
// and never reach that loop (see the `unreachable!()` arm there). The
// `Display` impl below renders each variant back into a runnable CLI string.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
}
78
79impl Display for Command {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 Command::ParseExample => write!(f, "parse-example"),
83 Command::LoadProject => write!(f, "load-project"),
84 Command::Context => write!(f, "context"),
85 Command::FormatPrompt(format_prompt_args) => write!(
86 f,
87 "format-prompt --prompt-format={}",
88 format_prompt_args
89 .prompt_format
90 .to_possible_value()
91 .unwrap()
92 .get_name()
93 ),
94 Command::Predict(predict_args) => {
95 write!(
96 f,
97 "predict --provider={:?}",
98 predict_args
99 .provider
100 .to_possible_value()
101 .unwrap()
102 .get_name()
103 )
104 }
105 Command::Score(predict_args) => {
106 write!(
107 f,
108 "score --provider={:?}",
109 predict_args
110 .provider
111 .to_possible_value()
112 .unwrap()
113 .get_name()
114 )
115 }
116 Command::Distill => write!(f, "distill"),
117 Command::Eval(predict_args) => write!(
118 f,
119 "eval --provider={:?}",
120 predict_args
121 .provider
122 .to_possible_value()
123 .unwrap()
124 .get_name()
125 ),
126 Command::Synthesize(args) => {
127 write!(f, "synthesize --repo={}", args.repo)
128 }
129 Command::Clean => write!(f, "clean"),
130 }
131 }
132}
133
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args)]
struct FormatPromptArgs {
    // Which prompt layout to emit; passed through to `run_format_prompt`.
    #[clap(long)]
    prompt_format: PromptFormat,
}
139
// Prompt layouts selectable via `--prompt-format`. As a clap `ValueEnum`,
// variants are exposed on the CLI in kebab-case (`teacher`, `zeta2`);
// Serialize/Deserialize allow the choice to be recorded in example data.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
145
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args)]
struct PredictArgs {
    // Which prediction backend to use.
    #[clap(long)]
    provider: PredictionProvider,
    // How many times to run prediction per example (defaults to a single run).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
153
// Prediction backends selectable via `--provider`. Exposed on the CLI in
// kebab-case by the `ValueEnum` derive (e.g. `teacher-non-batching`).
// NOTE(review): the semantics of each provider live in the `predict` module;
// only `Predict` triggers `predict::sync_batches` before and after the run.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
163
// Arguments for the `synthesize` subcommand. These fields are copied
// one-for-one into `SynthesizeConfig` in `main` (plus the global `--output`
// directory, which is required for this subcommand).
#[derive(Debug, Args)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
182
183impl EpArgs {
184 fn output_path(&self) -> Option<PathBuf> {
185 if self.in_place {
186 if self.inputs.len() == 1 {
187 self.inputs.first().cloned()
188 } else {
189 panic!("--in-place requires exactly one input file")
190 }
191 } else {
192 self.output.clone()
193 }
194 }
195}
196
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything else: dump the environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output destination first so that an invalid --in-place
    // combination panics before any work starts.
    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // `Clean` and `Synthesize` don't operate on parsed examples, so they are
    // handled here and return before the headless app is started. This is why
    // the per-example loop below can mark them `unreachable!()`.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples to a directory, so the
            // global --output flag is mandatory for this subcommand.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            // Runs on a plain smol executor; no gpui Application needed.
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // All remaining subcommands process examples inside a headless gpui app.
    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        // Install the global prediction store before any command runs.
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // Outer async block collects any fatal error so we can still quit
            // the app cleanly below.
            let result = async {
                // For Predict, pull down any previously submitted batch
                // results before starting (and again after the run, below).
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

                // Examples from the same repo are kept in one group and run
                // sequentially within it; groups run concurrently, at most
                // `max_parallelism` at a time.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Per-example async block so a failure in one
                            // example is captured without aborting the batch.
                            let result = async {
                                match &command {
                                    // Parsing already happened in read_examples;
                                    // this subcommand just re-serializes below.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Both were handled (and returned) before
                                    // the app started; see the early match.
                                    Command::Clean | Command::Synthesize(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(e) = result {
                                // On failure, persist both the example state
                                // and the error text for later inspection.
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                // \x1b[36m…\x1b[0m renders the paths in cyan;
                                // `command`'s Display impl yields the CLI
                                // string for the re-run hint.
                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                // A single-example run always fails fast so
                                // the error surfaces as the process result.
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    // Wait for the whole batch before starting the next one.
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Eval is read-only by default: only write examples back when
                // the user explicitly asked for an output path.
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns and the process exits.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}