main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod reorder_patch;
 13mod retrieve_context;
 14mod score;
 15mod split_commit;
 16mod synthesize;
 17
 18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 19use edit_prediction::EditPredictionStore;
 20use gpui::Application;
 21use reqwest_client::ReqwestClient;
 22use serde::{Deserialize, Serialize};
 23use std::fmt::Display;
 24use std::{path::PathBuf, sync::Arc};
 25
 26use crate::distill::run_distill;
 27use crate::example::{group_examples_by_repo, read_examples, write_examples};
 28use crate::format_prompt::run_format_prompt;
 29use crate::load_project::run_load_project;
 30use crate::paths::FAILED_EXAMPLES_DIR;
 31use crate::predict::run_prediction;
 32use crate::progress::Progress;
 33use crate::retrieve_context::run_context_retrieval;
 34use crate::score::run_scoring;
 35use crate::split_commit::SplitCommitArgs;
 36use crate::synthesize::{SynthesizeConfig, run_synthesize};
 37
 38#[derive(Parser, Debug)]
 39#[command(name = "ep")]
 40struct EpArgs {
 41    #[arg(long, default_value_t = false)]
 42    printenv: bool,
 43    #[clap(long, default_value_t = 10, global = true)]
 44    max_parallelism: usize,
 45    #[command(subcommand)]
 46    command: Option<Command>,
 47    #[clap(global = true)]
 48    inputs: Vec<PathBuf>,
 49    #[arg(long, short, global = true)]
 50    output: Option<PathBuf>,
 51    #[arg(long, short, global = true)]
 52    in_place: bool,
 53    #[arg(long, short, global = true)]
 54    failfast: bool,
 55}
 56
 57#[derive(Subcommand, Debug)]
 58enum Command {
 59    /// Parse markdown examples and output a combined .jsonl file
 60    ParseExample,
 61    /// Create git worktrees for each example and load file contents
 62    LoadProject,
 63    /// Retrieve context for input examples.
 64    Context,
 65    /// Generate a prompt string for a specific model
 66    FormatPrompt(FormatPromptArgs),
 67    /// Runs edit prediction
 68    Predict(PredictArgs),
 69    /// Computes a score based on actual and expected patches
 70    Score(PredictArgs),
 71    /// Prepares a distillation dataset by copying expected outputs to
 72    /// predicted outputs and removing actual outputs and prompts.
 73    Distill,
 74    /// Print aggregated scores
 75    Eval(PredictArgs),
 76    /// Generate eval examples by analyzing git commits from a repository
 77    Synthesize(SynthesizeArgs),
 78    /// Remove git repositories and worktrees
 79    Clean,
 80    /// Generate an evaluation example by splitting a chronologically-ordered commit
 81    SplitCommit(SplitCommitArgs),
 82}
 83
 84impl Display for Command {
 85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 86        match self {
 87            Command::ParseExample => write!(f, "parse-example"),
 88            Command::LoadProject => write!(f, "load-project"),
 89            Command::Context => write!(f, "context"),
 90            Command::FormatPrompt(format_prompt_args) => write!(
 91                f,
 92                "format-prompt --prompt-format={}",
 93                format_prompt_args
 94                    .prompt_format
 95                    .to_possible_value()
 96                    .unwrap()
 97                    .get_name()
 98            ),
 99            Command::Predict(predict_args) => {
100                write!(
101                    f,
102                    "predict --provider={:?}",
103                    predict_args
104                        .provider
105                        .to_possible_value()
106                        .unwrap()
107                        .get_name()
108                )
109            }
110            Command::Score(predict_args) => {
111                write!(
112                    f,
113                    "score --provider={:?}",
114                    predict_args
115                        .provider
116                        .to_possible_value()
117                        .unwrap()
118                        .get_name()
119                )
120            }
121            Command::Distill => write!(f, "distill"),
122            Command::Eval(predict_args) => write!(
123                f,
124                "eval --provider={:?}",
125                predict_args
126                    .provider
127                    .to_possible_value()
128                    .unwrap()
129                    .get_name()
130            ),
131            Command::Synthesize(args) => {
132                write!(f, "synthesize --repo={}", args.repo)
133            }
134            Command::Clean => write!(f, "clean"),
135            Command::SplitCommit(_) => write!(f, "split-commit"),
136        }
137    }
138}
139
140#[derive(Debug, Args)]
141struct FormatPromptArgs {
142    #[clap(long)]
143    prompt_format: PromptFormat,
144}
145
146#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
147enum PromptFormat {
148    Teacher,
149    Zeta2,
150}
151
152#[derive(Debug, Args)]
153struct PredictArgs {
154    #[clap(long)]
155    provider: PredictionProvider,
156    #[clap(long, default_value_t = 1)]
157    repetitions: usize,
158}
159
160#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
161enum PredictionProvider {
162    Sweep,
163    Mercury,
164    Zeta1,
165    Zeta2,
166    Teacher,
167    TeacherNonBatching,
168}
169
170#[derive(Debug, Args)]
171struct SynthesizeArgs {
172    /// Repository URL (git@github.com:owner/repo or https://...)
173    #[clap(long)]
174    repo: String,
175
176    /// Number of examples to generate
177    #[clap(long, default_value_t = 5)]
178    count: usize,
179
180    /// Maximum commits to scan before giving up
181    #[clap(long, default_value_t = 100)]
182    max_commits: usize,
183
184    /// Ignore state file and reprocess all commits
185    #[clap(long)]
186    fresh: bool,
187}
188
189impl EpArgs {
190    fn output_path(&self) -> Option<PathBuf> {
191        if self.in_place {
192            if self.inputs.len() == 1 {
193                self.inputs.first().cloned()
194            } else {
195                panic!("--in-place requires exactly one input file")
196            }
197        } else {
198            self.output.clone()
199        }
200    }
201}
202
203fn main() {
204    let args = EpArgs::parse();
205
206    if args.printenv {
207        ::util::shell_env::print_env();
208        return;
209    }
210
211    let output = args.output_path();
212    let command = match args.command {
213        Some(cmd) => cmd,
214        None => {
215            EpArgs::command().print_help().unwrap();
216            return;
217        }
218    };
219
220    match &command {
221        Command::Clean => {
222            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
223            return;
224        }
225        Command::Synthesize(synth_args) => {
226            let Some(output_dir) = args.output else {
227                panic!("output dir is required");
228            };
229            let config = SynthesizeConfig {
230                repo_url: synth_args.repo.clone(),
231                count: synth_args.count,
232                max_commits: synth_args.max_commits,
233                output_dir,
234                fresh: synth_args.fresh,
235            };
236            smol::block_on(async {
237                if let Err(e) = run_synthesize(config).await {
238                    eprintln!("Error: {:?}", e);
239                    std::process::exit(1);
240                }
241            });
242            return;
243        }
244        Command::SplitCommit(split_commit_args) => {
245            if let Err(error) = split_commit::run_split_commit(split_commit_args) {
246                eprintln!("{error:#}");
247                std::process::exit(1);
248            }
249            return;
250        }
251        _ => {}
252    }
253
254    let mut examples = read_examples(&args.inputs);
255    let http_client = Arc::new(ReqwestClient::new());
256    let app = Application::headless().with_http_client(http_client);
257
258    app.run(move |cx| {
259        let app_state = Arc::new(headless::init(cx));
260        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
261
262        cx.spawn(async move |cx| {
263            let result = async {
264                if let Command::Predict(args) = &command {
265                    predict::sync_batches(&args.provider).await?;
266                }
267
268                let total_examples = examples.len();
269                Progress::global().set_total_examples(total_examples);
270
271                let mut grouped_examples = group_examples_by_repo(&mut examples);
272                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
273
274                for example_batch in example_batches {
275                    let futures = example_batch.into_iter().map(|repo_examples| async {
276                        for example in repo_examples.iter_mut() {
277                            let result = async {
278                                match &command {
279                                    Command::ParseExample => {}
280                                    Command::LoadProject => {
281                                        run_load_project(example, app_state.clone(), cx.clone())
282                                            .await?;
283                                    }
284                                    Command::Context => {
285                                        run_context_retrieval(
286                                            example,
287                                            app_state.clone(),
288                                            cx.clone(),
289                                        )
290                                        .await?;
291                                    }
292                                    Command::FormatPrompt(args) => {
293                                        run_format_prompt(
294                                            example,
295                                            args.prompt_format,
296                                            app_state.clone(),
297                                            cx.clone(),
298                                        )
299                                        .await?;
300                                    }
301                                    Command::Predict(args) => {
302                                        run_prediction(
303                                            example,
304                                            Some(args.provider),
305                                            args.repetitions,
306                                            app_state.clone(),
307                                            cx.clone(),
308                                        )
309                                        .await?;
310                                    }
311                                    Command::Distill => {
312                                        run_distill(example).await?;
313                                    }
314                                    Command::Score(args) | Command::Eval(args) => {
315                                        run_scoring(example, &args, app_state.clone(), cx.clone())
316                                            .await?;
317                                    }
318                                    Command::Clean
319                                    | Command::Synthesize(_)
320                                    | Command::SplitCommit(_) => {
321                                        unreachable!()
322                                    }
323                                }
324                                anyhow::Ok(())
325                            }
326                            .await;
327
328                            if let Err(e) = result {
329                                Progress::global().increment_failed();
330                                let failed_example_path =
331                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
332                                app_state
333                                    .fs
334                                    .write(
335                                        &failed_example_path,
336                                        &serde_json::to_vec_pretty(&example).unwrap(),
337                                    )
338                                    .await
339                                    .unwrap();
340                                let err_path = FAILED_EXAMPLES_DIR
341                                    .join(format!("{}_err.txt", example.spec.name));
342                                app_state
343                                    .fs
344                                    .write(&err_path, e.to_string().as_bytes())
345                                    .await
346                                    .unwrap();
347
348                                let msg = format!(
349                                    indoc::indoc! {"
350                                        While processing {}:
351
352                                        {:?}
353
354                                        Written to: \x1b[36m{}\x1b[0m
355
356                                        Explore this example data with:
357                                            fx \x1b[36m{}\x1b[0m
358
359                                        Re-run this example with:
360                                            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
361                                    "},
362                                    example.spec.name,
363                                    e,
364                                    err_path.display(),
365                                    failed_example_path.display(),
366                                    command,
367                                    failed_example_path.display(),
368                                );
369                                if args.failfast || total_examples == 1 {
370                                    Progress::global().finalize();
371                                    panic!("{}", msg);
372                                } else {
373                                    log::error!("{}", msg);
374                                }
375                            }
376                        }
377                    });
378                    futures::future::join_all(futures).await;
379                }
380                Progress::global().finalize();
381
382                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
383                    write_examples(&examples, output.as_ref());
384                }
385
386                match &command {
387                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
388                    Command::Eval(_) => score::print_report(&examples),
389                    _ => (),
390                };
391
392                anyhow::Ok(())
393            }
394            .await;
395
396            if let Err(e) = result {
397                panic!("Fatal error: {:?}", e);
398            }
399
400            let _ = cx.update(|cx| cx.quit());
401        })
402        .detach();
403    });
404}