//! `ep`: a command-line harness that runs edit prediction examples through
//! parsing, project loading, context retrieval, prompt formatting, prediction,
//! scoring, and evaluation.

mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod retrieve_context;
mod score;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;

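/// Top-level arguments for the `ep` binary. `--max-parallelism`, `inputs`,
/// `--output`, `--in-place`, and `--failfast` are global, so they can be
/// passed to any subcommand.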
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[arg(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[command(subcommand)]
    command: Option<Command>,
    #[arg(global = true)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
}

#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}

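/// Renders a subcommand (plus its key flags) as a shell-ready string, used to
/// print a "re-run this example with" hint when an example fails.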
impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(format_prompt_args) => write!(
                f,
                "format-prompt --prompt-format={}",
                format_prompt_args
                    .prompt_format
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Predict(predict_args) => write!(
                f,
                "predict --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Score(predict_args) => write!(
                f,
                "score --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Distill => write!(f, "distill"),
            Command::Eval(predict_args) => write!(
                f,
                "eval --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Clean => write!(f, "clean"),
        }
    }
}

#[derive(Debug, Args)]
struct FormatPromptArgs {
    #[arg(long)]
    prompt_format: PromptFormat,
}

#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}

#[derive(Debug, Args)]
struct PredictArgs {
    #[arg(long)]
    provider: PredictionProvider,
    #[arg(long, default_value_t = 1)]
    repetitions: usize,
}

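/// Backends that can produce edit predictions.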
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

impl EpArgs {
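    /// Resolves the output destination: with `--in-place`, the single input
    /// file itself; otherwise whatever `--output` was given, if anything.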
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    if matches!(command, Command::Clean) {
        std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
        return;
    }

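    // The remaining commands run inside a headless GPUI application so they
    // can share `app_state` and the async executor.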
    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                let total_examples = examples.len();
                Progress::global().set_total_examples(total_examples);

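                // Group examples by repo (examples within a group run
                // sequentially) and process up to `--max-parallelism` groups
                // concurrently per batch.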
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.iter_mut().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

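                            // On failure, persist the example and its error to
                            // FAILED_EXAMPLES_DIR so it can be inspected and
                            // re-run in isolation.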
                            if let Err(e) = result {
                                Progress::global().increment_failed();
                                let failed_example_path =
                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
                                app_state
                                    .fs
                                    .write(
                                        &failed_example_path,
                                        &serde_json::to_vec_pretty(&example).unwrap(),
                                    )
                                    .await
                                    .unwrap();
                                let err_path = FAILED_EXAMPLES_DIR
                                    .join(format!("{}_err.txt", example.spec.name));
                                app_state
                                    .fs
                                    .write(&err_path, e.to_string().as_bytes())
                                    .await
                                    .unwrap();

                                let msg = format!(
                                    indoc::indoc! {"
                                        While processing {}:

                                        {:?}

                                        Written to: \x1b[36m{}\x1b[0m

                                        Explore this example data with:
                                        fx \x1b[36m{}\x1b[0m

                                        Re-run this example with:
                                        cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
                                    "},
                                    example.spec.name,
                                    e,
                                    err_path.display(),
                                    failed_example_path.display(),
                                    command,
                                    failed_example_path.display(),
                                );
                                if args.failfast || total_examples == 1 {
                                    Progress::global().finalize();
                                    panic!("{}", msg);
                                } else {
                                    log::error!("{}", msg);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

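                // Predict syncs its request batches once more; Eval prints the
                // aggregated score report.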
                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => {}
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}