1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod headless;
6mod load_project;
7mod metrics;
8mod paths;
9mod predict;
10mod retrieve_context;
11mod score;
12
13use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
14use edit_prediction::EditPredictionStore;
15use gpui::Application;
16use reqwest_client::ReqwestClient;
17use serde::{Deserialize, Serialize};
18use std::{path::PathBuf, sync::Arc};
19
20use crate::distill::run_distill;
21use crate::example::{read_examples, write_examples};
22use crate::format_prompt::run_format_prompt;
23use crate::load_project::run_load_project;
24use crate::predict::run_prediction;
25use crate::retrieve_context::run_context_retrieval;
26use crate::score::run_scoring;
27
28#[derive(Parser, Debug)]
29#[command(name = "ep")]
30struct EpArgs {
31 #[arg(long, default_value_t = false)]
32 printenv: bool,
33 #[clap(long, default_value_t = 10)]
34 max_parallelism: usize,
35 #[command(subcommand)]
36 command: Option<Command>,
37 #[clap(global = true)]
38 inputs: Vec<PathBuf>,
39 #[arg(long, short, global = true)]
40 output: Option<PathBuf>,
41 #[arg(long, short, global = true)]
42 in_place: bool,
43}
44
// Pipeline stages of the evaluation tool. The `///` variant docs below double
// as clap help text, so they are left untouched. Most commands transform the
// example files given as `inputs`; `Clean` is handled specially in `main`
// (it runs before any examples are loaded).
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}
67
68#[derive(Debug, Args)]
69struct FormatPromptArgs {
70 #[clap(long)]
71 prompt_format: PromptFormat,
72}
73
// Prompt layout selected via `--prompt-format`. Parsed from the CLI
// (ValueEnum) and round-tripped through example data (Serialize/Deserialize).
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    // NOTE(review): presumably the prompt format for the teacher model used
    // by `Command::Distill` — confirm in format_prompt.rs.
    Teacher,
    Zeta2,
}
79
80#[derive(Debug, Args)]
81struct PredictArgs {
82 #[clap(long)]
83 provider: PredictionProvider,
84 #[clap(long, default_value_t = 1)]
85 repetitions: usize,
86}
87
// Prediction backend selected via `--provider`. Parsed from the CLI
// (ValueEnum) and stored in example data (Serialize/Deserialize).
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    // NOTE(review): `Teacher` appears to use the batching path reconciled by
    // `predict::sync_batches`, while `TeacherNonBatching` does not — confirm
    // in predict.rs.
    Teacher,
    TeacherNonBatching,
}
97
98impl EpArgs {
99 fn output_path(&self) -> Option<PathBuf> {
100 if self.in_place {
101 if self.inputs.len() == 1 {
102 self.inputs.first().cloned()
103 } else {
104 panic!("--in-place requires exactly one input file")
105 }
106 } else {
107 self.output.clone()
108 }
109 }
110}
111
112fn main() {
113 zlog::init();
114 zlog::init_output_stderr();
115 let args = EpArgs::parse();
116
117 if args.printenv {
118 ::util::shell_env::print_env();
119 return;
120 }
121
122 let output = args.output_path();
123 let command = match args.command {
124 Some(cmd) => cmd,
125 None => {
126 EpArgs::command().print_help().unwrap();
127 return;
128 }
129 };
130
131 match &command {
132 Command::Clean => {
133 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
134 return;
135 }
136 _ => {}
137 }
138
139 let mut examples = read_examples(&args.inputs);
140 let http_client = Arc::new(ReqwestClient::new());
141 let app = Application::headless().with_http_client(http_client);
142
143 app.run(move |cx| {
144 let app_state = Arc::new(headless::init(cx));
145 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
146
147 cx.spawn(async move |cx| {
148 match &command {
149 Command::Predict(args) => predict::sync_batches(&args.provider).await,
150 _ => (),
151 };
152
153 let chunks = examples.chunks_mut(args.max_parallelism);
154 let total_chunks = chunks.len();
155 for (batch_ix, data) in chunks.enumerate() {
156 let mut futures = Vec::new();
157 eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
158
159 for example in data.iter_mut() {
160 let cx = cx.clone();
161 let app_state = app_state.clone();
162 futures.push(async {
163 match &command {
164 Command::ParseExample => {}
165 Command::LoadProject => {
166 run_load_project(example, app_state.clone(), cx).await;
167 }
168 Command::Context => {
169 run_context_retrieval(example, app_state, cx).await;
170 }
171 Command::FormatPrompt(args) => {
172 run_format_prompt(example, args.prompt_format, app_state, cx).await;
173 }
174 Command::Predict(args) => {
175 run_prediction(
176 example,
177 Some(args.provider),
178 args.repetitions,
179 app_state.clone(),
180 cx,
181 )
182 .await;
183 }
184 Command::Distill => {
185 run_distill(example).await;
186 }
187 Command::Score(args) | Command::Eval(args) => {
188 run_scoring(example, &args, app_state, cx).await;
189 }
190 Command::Clean => {
191 unreachable!()
192 }
193 }
194 });
195 }
196 futures::future::join_all(futures).await;
197 }
198
199 if args.output.is_some() || !matches!(command, Command::Eval(_)) {
200 write_examples(&examples, output.as_ref());
201 }
202
203 match &command {
204 Command::Predict(args) => predict::sync_batches(&args.provider).await,
205 Command::Eval(_) => score::print_report(&examples),
206 _ => (),
207 };
208
209 let _ = cx.update(|cx| cx.quit());
210 })
211 .detach();
212 });
213}