main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod headless;
  6mod load_project;
  7mod metrics;
  8mod paths;
  9mod predict;
 10mod progress;
 11mod retrieve_context;
 12mod score;
 13
 14use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 15use edit_prediction::EditPredictionStore;
 16use gpui::Application;
 17use reqwest_client::ReqwestClient;
 18use serde::{Deserialize, Serialize};
 19use std::{path::PathBuf, sync::Arc};
 20
 21use crate::distill::run_distill;
 22use crate::example::{group_examples_by_repo, read_examples, write_examples};
 23use crate::format_prompt::run_format_prompt;
 24use crate::load_project::run_load_project;
 25use crate::predict::run_prediction;
 26use crate::progress::Progress;
 27use crate::retrieve_context::run_context_retrieval;
 28use crate::score::run_scoring;
 29
 30#[derive(Parser, Debug)]
 31#[command(name = "ep")]
 32struct EpArgs {
 33    #[arg(long, default_value_t = false)]
 34    printenv: bool,
 35    #[clap(long, default_value_t = 10)]
 36    max_parallelism: usize,
 37    #[command(subcommand)]
 38    command: Option<Command>,
 39    #[clap(global = true)]
 40    inputs: Vec<PathBuf>,
 41    #[arg(long, short, global = true)]
 42    output: Option<PathBuf>,
 43    #[arg(long, short, global = true)]
 44    in_place: bool,
 45}
 46
// Subcommands of `ep`. The `///` comments on the variants are the help text clap shows
// for each subcommand, so treat them as user-facing strings; maintainer notes use `//`.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    // `main` does no per-example work for this command: examples are read and written out.
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    // `main` also calls `predict::sync_batches` before and after processing the examples.
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Shares `PredictArgs` with `Predict` and `Eval`.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Runs the same per-example `run_scoring` as `Score`, then `score::print_report`.
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    // Handled in `main` before any examples are read: it deletes `paths::DATA_DIR`.
    Clean,
}
 69
 70#[derive(Debug, Args)]
 71struct FormatPromptArgs {
 72    #[clap(long)]
 73    prompt_format: PromptFormat,
 74}
 75
// Prompt formats selectable with `format-prompt --prompt-format`.
// `ValueEnum` exposes each variant as a CLI value, and serde derives its
// (de)serialized form from the variant name. Renaming or reordering variants
// therefore changes the external interface.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
 81
 82#[derive(Debug, Args)]
 83struct PredictArgs {
 84    #[clap(long)]
 85    provider: PredictionProvider,
 86    #[clap(long, default_value_t = 1)]
 87    repetitions: usize,
 88}
 89
// Prediction backends selectable with `--provider`.
// `main` passes the chosen value to `run_prediction` (as `Some(provider)`) and to
// `predict::sync_batches`.
// Because of the `ValueEnum` and serde derives, the variant names are part of the
// CLI and serialized interface.
// NOTE(review): `Teacher` vs `TeacherNonBatching` presumably differ in whether requests
// go through the batching path that `sync_batches` handles; confirm in `predict`.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
 99
100impl EpArgs {
101    fn output_path(&self) -> Option<PathBuf> {
102        if self.in_place {
103            if self.inputs.len() == 1 {
104                self.inputs.first().cloned()
105            } else {
106                panic!("--in-place requires exactly one input file")
107            }
108        } else {
109            self.output.clone()
110        }
111    }
112}
113
114fn main() {
115    let _ = zlog::try_init(Some("error".into()));
116    zlog::init_output_stderr();
117    let args = EpArgs::parse();
118
119    if args.printenv {
120        ::util::shell_env::print_env();
121        return;
122    }
123
124    let output = args.output_path();
125    let command = match args.command {
126        Some(cmd) => cmd,
127        None => {
128            EpArgs::command().print_help().unwrap();
129            return;
130        }
131    };
132
133    match &command {
134        Command::Clean => {
135            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
136            return;
137        }
138        _ => {}
139    }
140
141    let mut examples = read_examples(&args.inputs);
142    let http_client = Arc::new(ReqwestClient::new());
143    let app = Application::headless().with_http_client(http_client);
144
145    app.run(move |cx| {
146        let app_state = Arc::new(headless::init(cx));
147        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
148
149        cx.spawn(async move |cx| {
150            if let Command::Predict(args) = &command {
151                predict::sync_batches(&args.provider).await
152            };
153
154            let total_examples = examples.len();
155            let progress = Progress::new(total_examples);
156
157            let mut grouped_examples = group_examples_by_repo(&mut examples);
158            let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
159
160            for example_batch in example_batches {
161                let futures = example_batch.into_iter().map(|repo_examples| async {
162                    for example in repo_examples.iter_mut() {
163                        match &command {
164                            Command::ParseExample => {}
165                            Command::LoadProject => {
166                                run_load_project(
167                                    example,
168                                    app_state.clone(),
169                                    progress.clone(),
170                                    cx.clone(),
171                                )
172                                .await;
173                            }
174                            Command::Context => {
175                                run_context_retrieval(
176                                    example,
177                                    app_state.clone(),
178                                    progress.clone(),
179                                    cx.clone(),
180                                )
181                                .await;
182                            }
183                            Command::FormatPrompt(args) => {
184                                run_format_prompt(
185                                    example,
186                                    args.prompt_format,
187                                    app_state.clone(),
188                                    progress.clone(),
189                                    cx.clone(),
190                                )
191                                .await;
192                            }
193                            Command::Predict(args) => {
194                                run_prediction(
195                                    example,
196                                    Some(args.provider),
197                                    args.repetitions,
198                                    app_state.clone(),
199                                    progress.clone(),
200                                    cx.clone(),
201                                )
202                                .await;
203                            }
204                            Command::Distill => {
205                                run_distill(example).await;
206                            }
207                            Command::Score(args) | Command::Eval(args) => {
208                                run_scoring(
209                                    example,
210                                    &args,
211                                    app_state.clone(),
212                                    progress.clone(),
213                                    cx.clone(),
214                                )
215                                .await;
216                            }
217                            Command::Clean => {
218                                unreachable!()
219                            }
220                        }
221                    }
222                });
223                futures::future::join_all(futures).await;
224            }
225            progress.clear();
226
227            if args.output.is_some() || !matches!(command, Command::Eval(_)) {
228                write_examples(&examples, output.as_ref());
229            }
230
231            match &command {
232                Command::Predict(args) => predict::sync_batches(&args.provider).await,
233                Command::Eval(_) => score::print_report(&examples),
234                _ => (),
235            };
236
237            let _ = cx.update(|cx| cx.quit());
238        })
239        .detach();
240    });
241}