1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod headless;
6mod load_project;
7mod metrics;
8mod paths;
9mod predict;
10mod progress;
11mod retrieve_context;
12mod score;
13
14use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
15use edit_prediction::EditPredictionStore;
16use gpui::Application;
17use reqwest_client::ReqwestClient;
18use serde::{Deserialize, Serialize};
19use std::{path::PathBuf, sync::Arc};
20
21use crate::distill::run_distill;
22use crate::example::{group_examples_by_repo, read_examples, write_examples};
23use crate::format_prompt::run_format_prompt;
24use crate::load_project::run_load_project;
25use crate::predict::run_prediction;
26use crate::progress::Progress;
27use crate::retrieve_context::run_context_retrieval;
28use crate::score::run_scoring;
29
30#[derive(Parser, Debug)]
31#[command(name = "ep")]
32struct EpArgs {
33 #[arg(long, default_value_t = false)]
34 printenv: bool,
35 #[clap(long, default_value_t = 10)]
36 max_parallelism: usize,
37 #[command(subcommand)]
38 command: Option<Command>,
39 #[clap(global = true)]
40 inputs: Vec<PathBuf>,
41 #[arg(long, short, global = true)]
42 output: Option<PathBuf>,
43 #[arg(long, short, global = true)]
44 in_place: bool,
45}
46
// Subcommands of `ep`. The `///` docs below double as the CLI help text
// emitted by clap, so they are user-facing strings — edit with care.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Shares PredictArgs with Predict so scoring can be keyed by provider.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    // Handled early in `main`, before the headless app is started.
    Clean,
}
69
70#[derive(Debug, Args)]
71struct FormatPromptArgs {
72 #[clap(long)]
73 prompt_format: PromptFormat,
74}
75
// Prompt layout selector. ValueEnum maps variants to kebab-case CLI values
// (`teacher`, `zeta2`). Also derives Serialize/Deserialize — presumably
// persisted alongside examples; confirm before renaming variants.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
81
82#[derive(Debug, Args)]
83struct PredictArgs {
84 #[clap(long)]
85 provider: PredictionProvider,
86 #[clap(long, default_value_t = 1)]
87 repetitions: usize,
88}
89
// Prediction backend selector. ValueEnum maps variants to kebab-case CLI
// values (e.g. `teacher-non-batching`). Also derives Serialize/Deserialize —
// presumably recorded with prediction results; confirm before renaming.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    // Teacher provider that bypasses the batching API (see `sync_batches`
    // usage in main — TODO confirm exact semantics in predict module).
    TeacherNonBatching,
}
99
100impl EpArgs {
101 fn output_path(&self) -> Option<PathBuf> {
102 if self.in_place {
103 if self.inputs.len() == 1 {
104 self.inputs.first().cloned()
105 } else {
106 panic!("--in-place requires exactly one input file")
107 }
108 } else {
109 self.output.clone()
110 }
111 }
112}
113
114fn main() {
115 let _ = zlog::try_init(Some("error".into()));
116 zlog::init_output_stderr();
117 let args = EpArgs::parse();
118
119 if args.printenv {
120 ::util::shell_env::print_env();
121 return;
122 }
123
124 let output = args.output_path();
125 let command = match args.command {
126 Some(cmd) => cmd,
127 None => {
128 EpArgs::command().print_help().unwrap();
129 return;
130 }
131 };
132
133 match &command {
134 Command::Clean => {
135 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
136 return;
137 }
138 _ => {}
139 }
140
141 let mut examples = read_examples(&args.inputs);
142 let http_client = Arc::new(ReqwestClient::new());
143 let app = Application::headless().with_http_client(http_client);
144
145 app.run(move |cx| {
146 let app_state = Arc::new(headless::init(cx));
147 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
148
149 cx.spawn(async move |cx| {
150 if let Command::Predict(args) = &command {
151 predict::sync_batches(&args.provider).await
152 };
153
154 let total_examples = examples.len();
155 let progress = Progress::new(total_examples);
156
157 let mut grouped_examples = group_examples_by_repo(&mut examples);
158 let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
159
160 for example_batch in example_batches {
161 let futures = example_batch.into_iter().map(|repo_examples| async {
162 for example in repo_examples.iter_mut() {
163 match &command {
164 Command::ParseExample => {}
165 Command::LoadProject => {
166 run_load_project(
167 example,
168 app_state.clone(),
169 progress.clone(),
170 cx.clone(),
171 )
172 .await;
173 }
174 Command::Context => {
175 run_context_retrieval(
176 example,
177 app_state.clone(),
178 progress.clone(),
179 cx.clone(),
180 )
181 .await;
182 }
183 Command::FormatPrompt(args) => {
184 run_format_prompt(
185 example,
186 args.prompt_format,
187 app_state.clone(),
188 progress.clone(),
189 cx.clone(),
190 )
191 .await;
192 }
193 Command::Predict(args) => {
194 run_prediction(
195 example,
196 Some(args.provider),
197 args.repetitions,
198 app_state.clone(),
199 progress.clone(),
200 cx.clone(),
201 )
202 .await;
203 }
204 Command::Distill => {
205 run_distill(example).await;
206 }
207 Command::Score(args) | Command::Eval(args) => {
208 run_scoring(
209 example,
210 &args,
211 app_state.clone(),
212 progress.clone(),
213 cx.clone(),
214 )
215 .await;
216 }
217 Command::Clean => {
218 unreachable!()
219 }
220 }
221 }
222 });
223 futures::future::join_all(futures).await;
224 }
225 progress.clear();
226
227 if args.output.is_some() || !matches!(command, Command::Eval(_)) {
228 write_examples(&examples, output.as_ref());
229 }
230
231 match &command {
232 Command::Predict(args) => predict::sync_batches(&args.provider).await,
233 Command::Eval(_) => score::print_report(&examples),
234 _ => (),
235 };
236
237 let _ = cx.update(|cx| cx.quit());
238 })
239 .detach();
240 });
241}