//! `ep`: a command-line harness that runs edit prediction examples through
//! parsing, project loading, context retrieval, prompt formatting, prediction,
//! scoring, and evaluation.

mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod retrieve_context;
mod score;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;

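/// Top-level arguments for the `ep` binary. `--max-parallelism`, `inputs`,
/// `--output`, `--in-place`, and `--failfast` are global, so they can be
/// passed to any subcommand.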
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[arg(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[command(subcommand)]
    command: Option<Command>,
    #[arg(global = true)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
}

#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}

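/// Renders a subcommand (plus its key flags) as a shell-ready string, used to
/// print a "re-run this example with" hint when an example fails.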
impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(format_prompt_args) => write!(
                f,
                "format-prompt --prompt-format={}",
                format_prompt_args
                    .prompt_format
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Predict(predict_args) => write!(
                f,
                "predict --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Score(predict_args) => write!(
                f,
                "score --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Distill => write!(f, "distill"),
            Command::Eval(predict_args) => write!(
                f,
                "eval --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Clean => write!(f, "clean"),
        }
    }
}

#[derive(Debug, Args)]
struct FormatPromptArgs {
    #[arg(long)]
    prompt_format: PromptFormat,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}

#[derive(Debug, Args)]
struct PredictArgs {
    #[arg(long)]
    provider: PredictionProvider,
    #[arg(long, default_value_t = 1)]
    repetitions: usize,
}

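/// Backends that can produce edit predictions.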
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

impl EpArgs {
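    /// Resolves the output destination: with `--in-place`, the single input
    /// file itself; otherwise whatever `--output` was given, if anything.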
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    if matches!(command, Command::Clean) {
        std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
        return;
    }

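    // The remaining commands run inside a headless GPUI application so they
    // can share `app_state` and the async executor.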
    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

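                // Group examples by repo (examples within a group run
                // sequentially) and process up to `--max-parallelism` groups
                // concurrently per batch.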
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.iter_mut().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

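                            // On failure, persist the example and its error to
                            // FAILED_EXAMPLES_DIR so it can be inspected and
                            // re-run in isolation.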
                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

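                // Predict syncs its request batches once more; Eval prints the
                // aggregated score report.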
                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => {}
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}