// main.rs

  1mod anthropic_client;
  2mod example;
  3mod format_prompt;
  4mod headless;
  5mod load_project;
  6mod metrics;
  7mod paths;
  8mod predict;
  9mod retrieve_context;
 10mod score;
 11
 12use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 13use edit_prediction::EditPredictionStore;
 14use gpui::Application;
 15use reqwest_client::ReqwestClient;
 16use serde::{Deserialize, Serialize};
 17use std::{path::PathBuf, sync::Arc};
 18
 19use crate::example::{read_examples, write_examples};
 20use crate::format_prompt::run_format_prompt;
 21use crate::load_project::run_load_project;
 22use crate::predict::run_prediction;
 23use crate::retrieve_context::run_context_retrieval;
 24use crate::score::run_scoring;
 25
 26#[derive(Parser, Debug)]
 27#[command(name = "ep")]
 28struct EpArgs {
 29    #[arg(long, default_value_t = false)]
 30    printenv: bool,
 31    #[clap(long, default_value_t = 10)]
 32    max_parallelism: usize,
 33    #[command(subcommand)]
 34    command: Option<Command>,
 35    #[clap(global = true)]
 36    inputs: Vec<PathBuf>,
 37    #[arg(long, short, global = true)]
 38    output: Option<PathBuf>,
 39    #[arg(long, short, global = true)]
 40    in_place: bool,
 41}
 42
// Subcommands dispatched by `main`. The `///` docs below double as clap help
// text, so they are user-visible strings.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}
 62
 63#[derive(Debug, Args)]
 64struct FormatPromptArgs {
 65    #[clap(long)]
 66    prompt_format: PromptFormat,
 67}
 68
// Prompt flavor selected via `--prompt-format` and forwarded to
// `run_format_prompt`. Serde derives are present alongside `ValueEnum`,
// so the value is also (de)serialized — presumably into example records;
// confirm against the format_prompt/example modules.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
 74
 75#[derive(Debug, Args)]
 76struct PredictArgs {
 77    #[clap(long)]
 78    provider: PredictionProvider,
 79    #[clap(long, default_value_t = 1)]
 80    repetitions: usize,
 81}
 82
// Prediction backend selected via `--provider`; passed to `run_prediction`
// and `predict::sync_batches`. Serde derives suggest the choice is also
// persisted with results — confirm against the predict module.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
}
 91
 92impl EpArgs {
 93    fn output_path(&self) -> Option<PathBuf> {
 94        if self.in_place {
 95            if self.inputs.len() == 1 {
 96                self.inputs.first().cloned()
 97            } else {
 98                panic!("--in-place requires exactly one input file")
 99            }
100        } else {
101            self.output.clone()
102        }
103    }
104}
105
106fn main() {
107    zlog::init();
108    zlog::init_output_stderr();
109    let args = EpArgs::parse();
110
111    if args.printenv {
112        ::util::shell_env::print_env();
113        return;
114    }
115
116    let output = args.output_path();
117    let command = match args.command {
118        Some(cmd) => cmd,
119        None => {
120            EpArgs::command().print_help().unwrap();
121            return;
122        }
123    };
124
125    match &command {
126        Command::Clean => {
127            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
128            return;
129        }
130        _ => {}
131    }
132
133    let mut examples = read_examples(&args.inputs);
134    let http_client = Arc::new(ReqwestClient::new());
135    let app = Application::headless().with_http_client(http_client);
136
137    app.run(move |cx| {
138        let app_state = Arc::new(headless::init(cx));
139        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
140
141        cx.spawn(async move |cx| {
142            match &command {
143                Command::Predict(args) => predict::sync_batches(&args.provider).await,
144                _ => (),
145            };
146
147            let chunks = examples.chunks_mut(args.max_parallelism);
148            let total_chunks = chunks.len();
149            for (batch_ix, data) in chunks.enumerate() {
150                let mut futures = Vec::new();
151                eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
152
153                for example in data.iter_mut() {
154                    let cx = cx.clone();
155                    let app_state = app_state.clone();
156                    futures.push(async {
157                        match &command {
158                            Command::ParseExample => {}
159                            Command::LoadProject => {
160                                run_load_project(example, app_state.clone(), cx).await;
161                            }
162                            Command::Context => {
163                                run_context_retrieval(example, app_state, cx).await;
164                            }
165                            Command::FormatPrompt(args) => {
166                                run_format_prompt(example, args.prompt_format, app_state, cx).await;
167                            }
168                            Command::Predict(args) => {
169                                run_prediction(
170                                    example,
171                                    Some(args.provider),
172                                    args.repetitions,
173                                    app_state.clone(),
174                                    cx,
175                                )
176                                .await;
177                            }
178                            Command::Score(args) | Command::Eval(args) => {
179                                run_scoring(example, &args, app_state, cx).await;
180                            }
181                            Command::Clean => {
182                                unreachable!()
183                            }
184                        }
185                    });
186                }
187                futures::future::join_all(futures).await;
188            }
189
190            if args.output.is_some() || !matches!(command, Command::Eval(_)) {
191                write_examples(&examples, output.as_ref());
192            }
193
194            match &command {
195                Command::Predict(args) => predict::sync_batches(&args.provider).await,
196                Command::Eval(_) => score::print_report(&examples),
197                _ => (),
198            };
199
200            let _ = cx.update(|cx| cx.quit());
201        })
202        .detach();
203    });
204}