mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod retrieve_context;
mod score;
mod synthesize;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

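/// Top-level arguments for the `ep` CLI. The input paths and the output,
/// in-place, failfast, and max-parallelism flags are global, so they can
/// follow any subcommand.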
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[arg(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[command(subcommand)]
    command: Option<Command>,
    #[arg(global = true)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
}

#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
}

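// Illustrative invocations (subcommand and flag names are derived by clap from
// the definitions above; the file paths here are hypothetical):
//
//   ep parse-example examples/*.md -o combined.jsonl
//   ep predict --provider=zeta2 combined.jsonl -o predictions.jsonl
//   ep eval --provider=zeta2 predictions.jsonl

// `Display` output is embedded in failure messages so that a single failed
// example can be re-run with the same subcommand and flags.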
impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(args) => write!(
                f,
                "format-prompt --prompt-format={}",
                args.prompt_format.to_possible_value().unwrap().get_name()
            ),
            Command::Predict(args) => write!(
                f,
                "predict --provider={}",
                args.provider.to_possible_value().unwrap().get_name()
            ),
            Command::Score(args) => write!(
                f,
                "score --provider={}",
                args.provider.to_possible_value().unwrap().get_name()
            ),
            Command::Distill => write!(f, "distill"),
            Command::Eval(args) => write!(
                f,
                "eval --provider={}",
                args.provider.to_possible_value().unwrap().get_name()
            ),
            Command::Synthesize(args) => {
                write!(f, "synthesize --repo={}", args.repo)
            }
            Command::Clean => write!(f, "clean"),
        }
    }
}

#[derive(Debug, Args)]
struct FormatPromptArgs {
    #[arg(long)]
    prompt_format: PromptFormat,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}

#[derive(Debug, Args)]
struct PredictArgs {
    #[arg(long)]
    provider: PredictionProvider,
    #[arg(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

#[derive(Debug, Args)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[arg(long)]
    repo: String,

    /// Number of examples to generate
    #[arg(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[arg(long, default_value_t = 100)]
    max_commits: usize,

    /// Output directory for draft examples
    #[arg(long, default_value = "staging")]
    output_dir: PathBuf,

    /// Only generate examples that require retrieved context to make a correct prediction
    #[arg(long)]
    require_context: bool,

    /// Ignore state file and reprocess all commits
    #[arg(long)]
    fresh: bool,
}

impl EpArgs {
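    /// Resolve where results are written: `--in-place` writes back to the
    /// single input file, otherwise the explicit `--output` path (if any) is
    /// used.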
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

fn main() {
    let args = EpArgs::parse();

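    // With `--printenv`, dump the inherited shell environment and exit before
    // doing anything else.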
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

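    // `clean` and `synthesize` don't operate on example files, so handle them
    // before reading inputs or starting the headless app.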
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir: synth_args.output_dir.clone(),
                require_context: synth_args.require_context,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

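    // Everything else processes examples inside a headless GPUI application.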
    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

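                // Run up to `--max-parallelism` repository groups concurrently;
                // examples within a group are processed sequentially.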
                for example_batch in example_batches {
                    let futures = example_batch.iter_mut().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean | Command::Synthesize(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

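                            // On failure, persist both the example JSON and the
                            // error text under FAILED_EXAMPLES_DIR so the run
                            // can be inspected and replayed.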
                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

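                // Write updated examples back out unless this is an `eval` run
                // without an explicit --output (eval is read-only by default).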
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}