1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod headless;
6mod load_project;
7mod metrics;
8mod paths;
9mod predict;
10mod retrieve_context;
11mod score;
12
13use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
14use edit_prediction::EditPredictionStore;
15use gpui::Application;
16use reqwest_client::ReqwestClient;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19
20use crate::distill::run_distill;
21use crate::example::{read_examples, write_examples};
22use crate::format_prompt::run_format_prompt;
23use crate::load_project::run_load_project;
24use crate::predict::run_prediction;
25use crate::retrieve_context::run_context_retrieval;
26use crate::score::run_scoring;
27
28#[derive(Parser, Debug)]
29#[command(name = "ep")]
30struct EpArgs {
31 #[arg(long, default_value_t = false)]
32 printenv: bool,
33 #[clap(long, default_value_t = 10)]
34 max_parallelism: usize,
35 #[command(subcommand)]
36 command: Option<Command>,
37 #[clap(global = true)]
38 inputs: Vec<PathBuf>,
39 #[arg(long, short, global = true)]
40 output: Option<PathBuf>,
41 #[arg(long, short, global = true)]
42 in_place: bool,
43}
44
// Pipeline stages of the evaluation tool. The `///` variant docs below double
// as clap help text, so they are left untouched. Most commands transform the
// example files given as `inputs`; `Clean` is handled specially in `main`
// (it runs before any examples are loaded).
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}
67
68#[derive(Debug, Args)]
69struct FormatPromptArgs {
70 #[clap(long)]
71 prompt_format: PromptFormat,
72}
73
// Prompt layout selected via `--prompt-format`. Parsed from the CLI
// (ValueEnum) and round-tripped through example data (Serialize/Deserialize).
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    // NOTE(review): presumably the prompt format for the teacher model used
    // by `Command::Distill` — confirm in format_prompt.rs.
    Teacher,
    Zeta2,
}
79
80#[derive(Debug, Args)]
81struct PredictArgs {
82 #[clap(long)]
83 provider: PredictionProvider,
84 #[clap(long, default_value_t = 1)]
85 repetitions: usize,
86}
87
// Prediction backend selected via `--provider`. Parsed from the CLI
// (ValueEnum) and stored in example data (Serialize/Deserialize).
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    // NOTE(review): `Teacher` appears to use the batching path reconciled by
    // `predict::sync_batches`, while `TeacherNonBatching` does not — confirm
    // in predict.rs.
    Teacher,
    TeacherNonBatching,
}
97
98impl EpArgs {
99 fn output_path(&self) -> Option<PathBuf> {
100 if self.in_place {
101 if self.inputs.len() == 1 {
102 self.inputs.first().cloned()
103 } else {
104 panic!("--in-place requires exactly one input file")
105 }
106 } else {
107 self.output.clone()
108 }
109 }
110}
111
112fn main() {
113 zlog::init();
114 zlog::init_output_stderr();
115 let args = EpArgs::parse();
116
117 if args.printenv {
118 ::util::shell_env::print_env();
119 return;
120 }
121
122 let output = args.output_path();
123 let command = match args.command {
124 Some(cmd) => cmd,
125 None => {
126 EpArgs::command().print_help().unwrap();
127 return;
128 }
129 };
130
131 match &command {
132 Command::Clean => {
133 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
134 return;
135 }
136 _ => {}
137 }
138
139 let mut examples = read_examples(&args.inputs);
140 let http_client = Arc::new(ReqwestClient::new());
141 let app = Application::headless().with_http_client(http_client);
142
143 app.run(move |cx| {
144 let app_state = Arc::new(headless::init(cx));
145 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
146
147 cx.spawn(async move |cx| {
148 match &command {
149 Command::Predict(args) => predict::sync_batches(&args.provider).await,
150 _ => (),
151 };
152
153 let chunks = examples.chunks_mut(args.max_parallelism);
154 let total_chunks = chunks.len();
155 for (batch_ix, data) in chunks.enumerate() {
156 let mut futures = Vec::new();
157 eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
158
159 for example in data.iter_mut() {
160 let cx = cx.clone();
161 let app_state = app_state.clone();
162 futures.push(async {
163 match &command {
164 Command::ParseExample => {}
165 Command::LoadProject => {
166 run_load_project(example, app_state.clone(), cx).await;
167 }
168 Command::Context => {
169 run_context_retrieval(example, app_state, cx).await;
170 }
171 Command::FormatPrompt(args) => {
172 run_format_prompt(example, args.prompt_format, app_state, cx).await;
173 }
174 Command::Predict(args) => {
175 run_prediction(
176 example,
177 Some(args.provider),
178 args.repetitions,
179 app_state.clone(),
180 cx,
181 )
182 .await;
183 }
184 Command::Distill => {
185 run_distill(example).await;
186 }
187 Command::Score(args) | Command::Eval(args) => {
188 run_scoring(example, &args, app_state, cx).await;
189 }
190 Command::Clean => {
191 unreachable!()
192 }
193 }
194 });
195 }
196 futures::future::join_all(futures).await;
197 }
198
199 if args.output.is_some() || !matches!(command, Command::Eval(_)) {
200 write_examples(&examples, output.as_ref());
201 }
202
203 match &command {
204 Command::Predict(args) => predict::sync_batches(&args.provider).await,
205 Command::Eval(_) => score::print_report(&examples),
206 _ => (),
207 };
208
209 let _ = cx.update(|cx| cx.quit());
210 })
211 .detach();
212 });
213}