1mod anthropic_client;
2mod example;
3mod format_prompt;
4mod headless;
5mod load_project;
6mod metrics;
7mod paths;
8mod predict;
9mod retrieve_context;
10mod score;
11
12use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
13use edit_prediction::EditPredictionStore;
14use gpui::Application;
15use reqwest_client::ReqwestClient;
16use serde::{Deserialize, Serialize};
17use std::{path::PathBuf, sync::Arc};
18
19use crate::example::{read_examples, write_examples};
20use crate::format_prompt::run_format_prompt;
21use crate::load_project::run_load_project;
22use crate::predict::run_prediction;
23use crate::retrieve_context::run_context_retrieval;
24use crate::score::run_scoring;
25
26#[derive(Parser, Debug)]
27#[command(name = "ep")]
28struct EpArgs {
29 #[arg(long, default_value_t = false)]
30 printenv: bool,
31 #[clap(long, default_value_t = 10)]
32 max_parallelism: usize,
33 #[command(subcommand)]
34 command: Option<Command>,
35 #[clap(global = true)]
36 inputs: Vec<PathBuf>,
37 #[arg(long, short, global = true)]
38 output: Option<PathBuf>,
39 #[arg(long, short, global = true)]
40 in_place: bool,
41}
42
// Subcommands of `ep`. The `///` doc comments below double as the CLI help
// text rendered by clap, so they are user-facing strings.
#[derive(Subcommand, Debug)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    // Score and Eval deliberately share PredictArgs: Eval runs the same
    // scoring pass and additionally prints an aggregated report.
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Remove git repositories and worktrees
    Clean,
}
62
63#[derive(Debug, Args)]
64struct FormatPromptArgs {
65 #[clap(long)]
66 prompt_format: PromptFormat,
67}
68
// Prompt layouts selectable via `--prompt-format`. Derives `ValueEnum` for
// clap parsing plus serde traits for (de)serialization; `//` comments are
// used here so as not to change the clap-generated help text.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    // Layout presumably targeting the teacher model — confirm in format_prompt.
    Teacher,
    // Layout presumably targeting the Zeta2 model — confirm in format_prompt.
    Zeta2,
}
74
75#[derive(Debug, Args)]
76struct PredictArgs {
77 #[clap(long)]
78 provider: PredictionProvider,
79 #[clap(long, default_value_t = 1)]
80 repetitions: usize,
81}
82
// Prediction backends selectable via `--provider`. Derives `ValueEnum` for
// clap parsing plus serde traits for (de)serialization; `//` comments are
// used here so as not to change the clap-generated help text.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
}
91
92impl EpArgs {
93 fn output_path(&self) -> Option<PathBuf> {
94 if self.in_place {
95 if self.inputs.len() == 1 {
96 self.inputs.first().cloned()
97 } else {
98 panic!("--in-place requires exactly one input file")
99 }
100 } else {
101 self.output.clone()
102 }
103 }
104}
105
// Entry point: parses the CLI, handles the commands that need no app
// machinery (`printenv`, missing subcommand, `clean`), then runs the
// remaining commands over the input examples inside a headless gpui app.
fn main() {
    // Logging is routed to stderr so stdout stays clean for command output.
    zlog::init();
    zlog::init_output_stderr();
    let args = EpArgs::parse();

    // Debug escape hatch: dump the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output destination up front; this panics if --in-place is
    // combined with anything other than exactly one input.
    let output = args.output_path();
    // No subcommand means there is nothing to do: print help and exit.
    let command = match args.command {
        Some(cmd) => cmd,
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // `clean` needs no examples, HTTP client, or app — handle it before any
    // of that machinery is constructed.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        _ => {}
    }

    let mut examples = read_examples(&args.inputs);
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        // Ensure the global edit-prediction store exists before tasks use it.
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // For predictions, sync batches before processing — presumably to
            // pick up results submitted by earlier runs; confirm in `predict`.
            match &command {
                Command::Predict(args) => predict::sync_batches(&args.provider).await,
                _ => (),
            };

            // Examples are processed in chunks of `max_parallelism`: examples
            // within a chunk run concurrently, chunks run sequentially.
            let chunks = examples.chunks_mut(args.max_parallelism);
            let total_chunks = chunks.len();
            for (batch_ix, data) in chunks.enumerate() {
                let mut futures = Vec::new();
                eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);

                for example in data.iter_mut() {
                    // Each per-example future gets its own handles to clone into.
                    let cx = cx.clone();
                    let app_state = app_state.clone();
                    futures.push(async {
                        match &command {
                            // Nothing per-example to do: inputs were already
                            // parsed by `read_examples` and are written out below.
                            Command::ParseExample => {}
                            Command::LoadProject => {
                                run_load_project(example, app_state.clone(), cx).await;
                            }
                            Command::Context => {
                                run_context_retrieval(example, app_state, cx).await;
                            }
                            Command::FormatPrompt(args) => {
                                run_format_prompt(example, args.prompt_format, app_state, cx).await;
                            }
                            Command::Predict(args) => {
                                run_prediction(
                                    example,
                                    Some(args.provider),
                                    args.repetitions,
                                    app_state.clone(),
                                    cx,
                                )
                                .await;
                            }
                            // Eval runs the same scoring pass as Score and
                            // additionally prints a report afterwards.
                            Command::Score(args) | Command::Eval(args) => {
                                run_scoring(example, &args, app_state, cx).await;
                            }
                            // Clean returned before the app was started.
                            Command::Clean => {
                                unreachable!()
                            }
                        }
                    });
                }
                // Await the whole chunk concurrently before starting the next.
                futures::future::join_all(futures).await;
            }

            // Eval is read-only unless the user explicitly asked for an output;
            // every other command persists the (possibly updated) examples.
            if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                write_examples(&examples, output.as_ref());
            }

            match &command {
                // Sync any outstanding prediction batches after the run.
                Command::Predict(args) => predict::sync_batches(&args.provider).await,
                // Print the aggregated score report for Eval.
                Command::Eval(_) => score::print_report(&examples),
                _ => (),
            };

            // Quit the headless app so the process can exit.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}