main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod retrieve_context;
 13mod score;
 14mod synthesize;
 15
 16use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 17use edit_prediction::EditPredictionStore;
 18use gpui::Application;
 19use reqwest_client::ReqwestClient;
 20use serde::{Deserialize, Serialize};
 21use std::fmt::Display;
 22use std::{path::PathBuf, sync::Arc};
 23
 24use crate::distill::run_distill;
 25use crate::example::{group_examples_by_repo, read_examples, write_examples};
 26use crate::format_prompt::run_format_prompt;
 27use crate::load_project::run_load_project;
 28use crate::paths::FAILED_EXAMPLES_DIR;
 29use crate::predict::run_prediction;
 30use crate::progress::Progress;
 31use crate::retrieve_context::run_context_retrieval;
 32use crate::score::run_scoring;
 33use crate::synthesize::{SynthesizeConfig, run_synthesize};
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "ep")]
 37struct EpArgs {
 38    #[arg(long, default_value_t = false)]
 39    printenv: bool,
 40    #[clap(long, default_value_t = 10, global = true)]
 41    max_parallelism: usize,
 42    #[command(subcommand)]
 43    command: Option<Command>,
 44    #[clap(global = true)]
 45    inputs: Vec<PathBuf>,
 46    #[arg(long, short, global = true)]
 47    output: Option<PathBuf>,
 48    #[arg(long, short, global = true)]
 49    in_place: bool,
 50    #[arg(long, short, global = true)]
 51    failfast: bool,
 52}
 53
// Subcommands of `ep`. The `///` lines on the variants are clap's `--help` text, i.e. runtime
// output; change them only when the help output should change. Variant names map to
// kebab-case subcommand names (e.g. `ParseExample` -> `parse-example`).
//
// `Clean` and `Synthesize` are handled in `main` before any examples are read; every other
// variant is applied to each input example in turn.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    // `Score` and `Eval` reuse `PredictArgs`, so they also require `--provider`.
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
}
 78
 79impl Display for Command {
 80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 81        match self {
 82            Command::ParseExample => write!(f, "parse-example"),
 83            Command::LoadProject => write!(f, "load-project"),
 84            Command::Context => write!(f, "context"),
 85            Command::FormatPrompt(format_prompt_args) => write!(
 86                f,
 87                "format-prompt --prompt-format={}",
 88                format_prompt_args
 89                    .prompt_format
 90                    .to_possible_value()
 91                    .unwrap()
 92                    .get_name()
 93            ),
 94            Command::Predict(predict_args) => {
 95                write!(
 96                    f,
 97                    "predict --provider={:?}",
 98                    predict_args
 99                        .provider
100                        .to_possible_value()
101                        .unwrap()
102                        .get_name()
103                )
104            }
105            Command::Score(predict_args) => {
106                write!(
107                    f,
108                    "score --provider={:?}",
109                    predict_args
110                        .provider
111                        .to_possible_value()
112                        .unwrap()
113                        .get_name()
114                )
115            }
116            Command::Distill => write!(f, "distill"),
117            Command::Eval(predict_args) => write!(
118                f,
119                "eval --provider={:?}",
120                predict_args
121                    .provider
122                    .to_possible_value()
123                    .unwrap()
124                    .get_name()
125            ),
126            Command::Synthesize(args) => {
127                write!(f, "synthesize --repo={}", args.repo)
128            }
129            Command::Clean => write!(f, "clean"),
130        }
131    }
132}
133
134#[derive(Debug, Args)]
135struct FormatPromptArgs {
136    #[clap(long)]
137    prompt_format: PromptFormat,
138}
139
// Prompt layout selected with `format-prompt --prompt-format`. On the command line clap's
// `ValueEnum` derive spells the variants in kebab-case (`teacher`, `zeta2`); serde keeps the
// variant names as written because no `rename` attribute is applied.
// Plain `//` comments on purpose: `///` on `ValueEnum` variants would become help text.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
145
146#[derive(Debug, Args)]
147struct PredictArgs {
148    #[clap(long)]
149    provider: PredictionProvider,
150    #[clap(long, default_value_t = 1)]
151    repetitions: usize,
152}
153
// Backend selected with `--provider` on `predict`, `score` and `eval`. On the command line
// clap's `ValueEnum` derive spells the variants in kebab-case (e.g. `teacher-non-batching`);
// serde keeps the variant names as written (e.g. `TeacherNonBatching`) because no `rename`
// attribute is applied, so the two spellings differ for multi-word variants.
// Plain `//` comments on purpose: `///` on `ValueEnum` variants would become help text.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
164#[derive(Debug, Args)]
165struct SynthesizeArgs {
166    /// Repository URL (git@github.com:owner/repo or https://...)
167    #[clap(long)]
168    repo: String,
169
170    /// Number of examples to generate
171    #[clap(long, default_value_t = 5)]
172    count: usize,
173
174    /// Maximum commits to scan before giving up
175    #[clap(long, default_value_t = 100)]
176    max_commits: usize,
177
178    /// Ignore state file and reprocess all commits
179    #[clap(long)]
180    fresh: bool,
181}
182
183impl EpArgs {
184    fn output_path(&self) -> Option<PathBuf> {
185        if self.in_place {
186            if self.inputs.len() == 1 {
187                self.inputs.first().cloned()
188            } else {
189                panic!("--in-place requires exactly one input file")
190            }
191        } else {
192            self.output.clone()
193        }
194    }
195}
196
197fn main() {
198    let args = EpArgs::parse();
199
200    if args.printenv {
201        ::util::shell_env::print_env();
202        return;
203    }
204
205    let output = args.output_path();
206    let command = match args.command {
207        Some(cmd) => cmd,
208        None => {
209            EpArgs::command().print_help().unwrap();
210            return;
211        }
212    };
213
214    match &command {
215        Command::Clean => {
216            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
217            return;
218        }
219        Command::Synthesize(synth_args) => {
220            let Some(output_dir) = args.output else {
221                panic!("output dir is required");
222            };
223            let config = SynthesizeConfig {
224                repo_url: synth_args.repo.clone(),
225                count: synth_args.count,
226                max_commits: synth_args.max_commits,
227                output_dir,
228                fresh: synth_args.fresh,
229            };
230            smol::block_on(async {
231                if let Err(e) = run_synthesize(config).await {
232                    eprintln!("Error: {:?}", e);
233                    std::process::exit(1);
234                }
235            });
236            return;
237        }
238        _ => {}
239    }
240
241    let mut examples = read_examples(&args.inputs);
242    let http_client = Arc::new(ReqwestClient::new());
243    let app = Application::headless().with_http_client(http_client);
244
245    app.run(move |cx| {
246        let app_state = Arc::new(headless::init(cx));
247        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
248
249        cx.spawn(async move |cx| {
250            let result = async {
251                if let Command::Predict(args) = &command {
252                    predict::sync_batches(&args.provider).await?;
253                }
254
255                let total_examples = examples.len();
256                Progress::global().set_total_examples(total_examples);
257
258                let mut grouped_examples = group_examples_by_repo(&mut examples);
259                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
260
261                for example_batch in example_batches {
262                    let futures = example_batch.into_iter().map(|repo_examples| async {
263                        for example in repo_examples.iter_mut() {
264                            let result = async {
265                                match &command {
266                                    Command::ParseExample => {}
267                                    Command::LoadProject => {
268                                        run_load_project(example, app_state.clone(), cx.clone())
269                                            .await?;
270                                    }
271                                    Command::Context => {
272                                        run_context_retrieval(
273                                            example,
274                                            app_state.clone(),
275                                            cx.clone(),
276                                        )
277                                        .await?;
278                                    }
279                                    Command::FormatPrompt(args) => {
280                                        run_format_prompt(
281                                            example,
282                                            args.prompt_format,
283                                            app_state.clone(),
284                                            cx.clone(),
285                                        )
286                                        .await?;
287                                    }
288                                    Command::Predict(args) => {
289                                        run_prediction(
290                                            example,
291                                            Some(args.provider),
292                                            args.repetitions,
293                                            app_state.clone(),
294                                            cx.clone(),
295                                        )
296                                        .await?;
297                                    }
298                                    Command::Distill => {
299                                        run_distill(example).await?;
300                                    }
301                                    Command::Score(args) | Command::Eval(args) => {
302                                        run_scoring(example, &args, app_state.clone(), cx.clone())
303                                            .await?;
304                                    }
305                                    Command::Clean | Command::Synthesize(_) => {
306                                        unreachable!()
307                                    }
308                                }
309                                anyhow::Ok(())
310                            }
311                            .await;
312
313                            if let Err(e) = result {
314                                Progress::global().increment_failed();
315                                let failed_example_path =
316                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
317                                app_state
318                                    .fs
319                                    .write(
320                                        &failed_example_path,
321                                        &serde_json::to_vec_pretty(&example).unwrap(),
322                                    )
323                                    .await
324                                    .unwrap();
325                                let err_path = FAILED_EXAMPLES_DIR
326                                    .join(format!("{}_err.txt", example.spec.name));
327                                app_state
328                                    .fs
329                                    .write(&err_path, e.to_string().as_bytes())
330                                    .await
331                                    .unwrap();
332
333                                let msg = format!(
334                                    indoc::indoc! {"
335                                        While processing {}:
336
337                                        {:?}
338
339                                        Written to: \x1b[36m{}\x1b[0m
340
341                                        Explore this example data with:
342                                            fx \x1b[36m{}\x1b[0m
343
344                                        Re-run this example with:
345                                            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
346                                    "},
347                                    example.spec.name,
348                                    e,
349                                    err_path.display(),
350                                    failed_example_path.display(),
351                                    command,
352                                    failed_example_path.display(),
353                                );
354                                if args.failfast || total_examples == 1 {
355                                    Progress::global().finalize();
356                                    panic!("{}", msg);
357                                } else {
358                                    log::error!("{}", msg);
359                                }
360                            }
361                        }
362                    });
363                    futures::future::join_all(futures).await;
364                }
365                Progress::global().finalize();
366
367                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
368                    write_examples(&examples, output.as_ref());
369                }
370
371                match &command {
372                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
373                    Command::Eval(_) => score::print_report(&examples),
374                    _ => (),
375                };
376
377                anyhow::Ok(())
378            }
379            .await;
380
381            if let Err(e) = result {
382                panic!("Fatal error: {:?}", e);
383            }
384
385            let _ = cx.update(|cx| cx.quit());
386        })
387        .detach();
388    });
389}