main.rs
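//! Command-line harness (`ep`) for the edit prediction pipeline: parse markdown
//! examples, load project worktrees, retrieve context, format prompts, run
//! predictions, and score or evaluate the results.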

mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod retrieve_context;
mod score;
mod synthesize;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

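/// Top-level arguments. The input paths and most flags are `global`, so they
/// can be passed alongside any subcommand.
///
/// Example invocations (paths are illustrative):
///   ep parse-example examples/*.md -o examples.jsonl
///   ep predict --provider=zeta2 examples.jsonl -o predictions.jsonl
///   ep eval --provider=zeta2 predictions.jsonl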
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
}

#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
}

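// Renders a command back into the CLI arguments that would reproduce it; used
// in the "Re-run this example with" hint printed when an example fails.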
impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(format_prompt_args) => write!(
                f,
                "format-prompt --prompt-format={}",
                format_prompt_args
                    .prompt_format
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Predict(predict_args) => {
                write!(
                    f,
                    "predict --provider={:?}",
                    predict_args
                        .provider
                        .to_possible_value()
                        .unwrap()
                        .get_name()
                )
            }
            Command::Score(predict_args) => {
                write!(
                    f,
                    "score --provider={:?}",
                    predict_args
                        .provider
                        .to_possible_value()
                        .unwrap()
                        .get_name()
                )
            }
            Command::Distill => write!(f, "distill"),
            Command::Eval(predict_args) => write!(
                f,
                "eval --provider={:?}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Synthesize(args) => {
                write!(f, "synthesize --repo={}", args.repo)
            }
            Command::Clean => write!(f, "clean"),
        }
    }
}

#[derive(Debug, Args)]
struct FormatPromptArgs {
    #[clap(long)]
    prompt_format: PromptFormat,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}

#[derive(Debug, Args)]
struct PredictArgs {
    #[clap(long)]
    provider: PredictionProvider,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

#[derive(Debug, Args)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Output directory for draft examples
    #[clap(long, default_value = "staging")]
    output_dir: PathBuf,

    /// Only generate examples that require retrieved context to make a correct prediction
    #[clap(long)]
    require_context: bool,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

impl EpArgs {
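    /// Resolve where results are written: `--in-place` writes back to the
    /// single input file, otherwise the optional `--output` path is used.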
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

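    // `clean` and `synthesize` don't operate on example inputs, so handle them
    // before reading examples or starting the headless app.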
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir: synth_args.output_dir.clone(),
                require_context: synth_args.require_context,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

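    // Everything else runs inside a headless GPUI application so that project
    // loading, context retrieval, and prediction can use the editor's
    // infrastructure.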
    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

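                // Group examples by repository: up to `max_parallelism` repo
                // groups run concurrently, while examples within a repo are
                // processed sequentially.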
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean | Command::Synthesize(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

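                            // On failure, persist the example and the error to
                            // FAILED_EXAMPLES_DIR so the failure can be
                            // inspected and replayed later.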
                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                            fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

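                // Write results back out, unless this is an `eval` run with no
                // explicit --output (eval only prints a report).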
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

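                // Post-processing: `predict` syncs any pending batches, `eval`
                // prints the aggregated score report.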
                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}