main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod format_prompt;
  5mod git;
  6mod headless;
  7mod load_project;
  8mod metrics;
  9mod paths;
 10mod predict;
 11mod progress;
 12mod pull_examples;
 13mod reorder_patch;
 14mod retrieve_context;
 15mod score;
 16mod split_commit;
 17mod synthesize;
 18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 19use edit_prediction::EditPredictionStore;
 20use gpui::Application;
 21use reqwest_client::ReqwestClient;
 22use serde::{Deserialize, Serialize};
 23use std::fmt::Display;
 24use std::{path::PathBuf, sync::Arc};
 25
 26use crate::distill::run_distill;
 27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
 28use crate::format_prompt::run_format_prompt;
 29use crate::load_project::run_load_project;
 30use crate::paths::FAILED_EXAMPLES_DIR;
 31use crate::predict::run_prediction;
 32use crate::progress::Progress;
 33use crate::retrieve_context::run_context_retrieval;
 34use crate::score::run_scoring;
 35use crate::split_commit::SplitCommitArgs;
 36use crate::synthesize::{SynthesizeConfig, run_synthesize};
 37
 38#[derive(Parser, Debug)]
 39#[command(name = "ep")]
 40struct EpArgs {
 41    #[arg(long, default_value_t = false)]
 42    printenv: bool,
 43    #[clap(long, default_value_t = 10, global = true)]
 44    max_parallelism: usize,
 45    #[clap(long, global = true)]
 46    limit: Option<usize>,
 47    #[command(subcommand)]
 48    command: Option<Command>,
 49    #[clap(global = true, help = INPUTS_HELP)]
 50    inputs: Vec<PathBuf>,
 51    #[arg(long, short, global = true)]
 52    output: Option<PathBuf>,
 53    #[arg(long, short, global = true)]
 54    in_place: bool,
 55    #[arg(long, short, global = true)]
 56    failfast: bool,
 57}
 58
// Help text attached to the positional `inputs` argument of `EpArgs` through
// `help = INPUTS_HELP`. The string is emitted verbatim at runtime, so any edit
// to its wording or layout changes what users see in the CLI help.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 88
// Subcommands accepted by `ep`.
//
// The `///` comments on the variants are consumed by clap's derive macro as the
// help text of each subcommand, so editing them changes CLI output.
//
// `Clean`, `Synthesize` and `SplitCommit` are dispatched by `main` before any
// examples are loaded; every other variant is matched once per loaded example.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
115
116impl Display for Command {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            Command::ParseExample => write!(f, "parse-example"),
120            Command::LoadProject => write!(f, "load-project"),
121            Command::Context => write!(f, "context"),
122            Command::FormatPrompt(format_prompt_args) => write!(
123                f,
124                "format-prompt --prompt-format={}",
125                format_prompt_args
126                    .prompt_format
127                    .to_possible_value()
128                    .unwrap()
129                    .get_name()
130            ),
131            Command::Predict(predict_args) => {
132                write!(
133                    f,
134                    "predict --provider={:?}",
135                    predict_args
136                        .provider
137                        .to_possible_value()
138                        .unwrap()
139                        .get_name()
140                )
141            }
142            Command::Score(predict_args) => {
143                write!(
144                    f,
145                    "score --provider={:?}",
146                    predict_args
147                        .provider
148                        .to_possible_value()
149                        .unwrap()
150                        .get_name()
151                )
152            }
153            Command::Distill => write!(f, "distill"),
154            Command::Eval(predict_args) => write!(
155                f,
156                "eval --provider={:?}",
157                predict_args
158                    .provider
159                    .to_possible_value()
160                    .unwrap()
161                    .get_name()
162            ),
163            Command::Synthesize(args) => {
164                write!(f, "synthesize --repo={}", args.repo)
165            }
166            Command::Clean => write!(f, "clean"),
167            Command::SplitCommit(_) => write!(f, "split-commit"),
168        }
169    }
170}
171
172#[derive(Debug, Args, Clone)]
173struct FormatPromptArgs {
174    #[clap(long)]
175    prompt_format: PromptFormat,
176}
177
// Prompt layouts selectable with `format-prompt --prompt-format`.
//
// The variant names feed both the `ValueEnum` derive (CLI values) and the serde
// derives (serialized form), so renaming a variant is a behavior change.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
183
184#[derive(Debug, Args, Clone)]
185struct PredictArgs {
186    #[clap(long)]
187    provider: PredictionProvider,
188    #[clap(long, default_value_t = 1)]
189    repetitions: usize,
190}
191
// Prediction backends selectable with `--provider`.
//
// The variant names feed both the `ValueEnum` derive (CLI values) and the serde
// derives (serialized form), so renaming a variant is a behavior change.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
201
202#[derive(Debug, Args, Clone)]
203struct SynthesizeArgs {
204    /// Repository URL (git@github.com:owner/repo or https://...)
205    #[clap(long)]
206    repo: String,
207
208    /// Number of examples to generate
209    #[clap(long, default_value_t = 5)]
210    count: usize,
211
212    /// Maximum commits to scan before giving up
213    #[clap(long, default_value_t = 100)]
214    max_commits: usize,
215
216    /// Ignore state file and reprocess all commits
217    #[clap(long)]
218    fresh: bool,
219}
220
221impl EpArgs {
222    fn output_path(&self) -> Option<PathBuf> {
223        if self.in_place {
224            if self.inputs.len() == 1 {
225                self.inputs.first().cloned()
226            } else {
227                panic!("--in-place requires exactly one input file")
228            }
229        } else {
230            self.output.clone()
231        }
232    }
233}
234
235async fn load_examples(
236    http_client: Arc<dyn http_client::HttpClient>,
237    args: &EpArgs,
238) -> anyhow::Result<Vec<Example>> {
239    let mut captured_after_timestamps = Vec::new();
240    let mut file_inputs = Vec::new();
241
242    for input in &args.inputs {
243        let input_string = input.to_string_lossy();
244        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
245            captured_after_timestamps.push(timestamp.to_string());
246        } else {
247            file_inputs.push(input.clone());
248        }
249    }
250
251    let mut examples = read_example_files(&file_inputs);
252    let total_steps = examples.len() + captured_after_timestamps.len();
253    Progress::global().set_total_steps(total_steps);
254
255    let remaining_limit_for_snowflake =
256        args.limit.map(|limit| limit.saturating_sub(examples.len()));
257
258    if let Some(0) = remaining_limit_for_snowflake {
259        log::info!(
260            "skipping captured-after inputs because --limit is already satisfied by example files"
261        );
262    } else if !captured_after_timestamps.is_empty() {
263        captured_after_timestamps.sort();
264
265        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
266
267        let mut captured_examples = pull_examples::fetch_captured_examples_after(
268            http_client,
269            &captured_after_timestamps,
270            max_rows_per_timestamp,
271        )
272        .await?;
273        examples.append(&mut captured_examples);
274    }
275
276    crate::example::sort_examples_by_repo_and_rev(&mut examples);
277
278    if let Some(limit) = args.limit {
279        if examples.len() > limit {
280            examples.truncate(limit);
281        }
282    }
283
284    Progress::global().set_total_steps(examples.len() + captured_after_timestamps.len());
285
286    Ok(examples)
287}
288
289fn main() {
290    let args = EpArgs::parse();
291
292    if args.printenv {
293        ::util::shell_env::print_env();
294        return;
295    }
296
297    let output = args.output_path();
298    let command = match &args.command {
299        Some(cmd) => cmd.clone(),
300        None => {
301            EpArgs::command().print_help().unwrap();
302            return;
303        }
304    };
305
306    match &command {
307        Command::Clean => {
308            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
309            return;
310        }
311        Command::Synthesize(synth_args) => {
312            let Some(output_dir) = args.output else {
313                panic!("output dir is required");
314            };
315            let config = SynthesizeConfig {
316                repo_url: synth_args.repo.clone(),
317                count: synth_args.count,
318                max_commits: synth_args.max_commits,
319                output_dir,
320                fresh: synth_args.fresh,
321            };
322            smol::block_on(async {
323                if let Err(e) = run_synthesize(config).await {
324                    eprintln!("Error: {:?}", e);
325                    std::process::exit(1);
326                }
327            });
328            return;
329        }
330        Command::SplitCommit(split_commit_args) => {
331            if let Err(error) = split_commit::run_split_commit(split_commit_args) {
332                eprintln!("{error:#}");
333                std::process::exit(1);
334            }
335            return;
336        }
337        _ => {}
338    }
339
340    let http_client = Arc::new(ReqwestClient::new());
341    let app = Application::headless().with_http_client(http_client);
342
343    app.run(move |cx| {
344        let app_state = Arc::new(headless::init(cx));
345        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
346
347        cx.spawn(async move |cx| {
348            let result = async {
349                let mut examples = load_examples(app_state.client.http_client(), &args).await?;
350
351                if let Command::Predict(args) = &command {
352                    predict::sync_batches(&args.provider).await?;
353                }
354
355                let failfast_on_single_example = examples.len() == 1;
356
357                let mut grouped_examples = group_examples_by_repo(&mut examples);
358                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
359
360                for example_batch in example_batches {
361                    let futures = example_batch.into_iter().map(|repo_examples| async {
362                        for example in repo_examples.iter_mut() {
363                            let result = async {
364                                match &command {
365                                    Command::ParseExample => {}
366                                    Command::LoadProject => {
367                                        run_load_project(example, app_state.clone(), cx.clone())
368                                            .await?;
369                                    }
370                                    Command::Context => {
371                                        run_context_retrieval(
372                                            example,
373                                            app_state.clone(),
374                                            cx.clone(),
375                                        )
376                                        .await?;
377                                    }
378                                    Command::FormatPrompt(args) => {
379                                        run_format_prompt(
380                                            example,
381                                            args.prompt_format,
382                                            app_state.clone(),
383                                            cx.clone(),
384                                        )
385                                        .await?;
386                                    }
387                                    Command::Predict(args) => {
388                                        run_prediction(
389                                            example,
390                                            Some(args.provider),
391                                            args.repetitions,
392                                            app_state.clone(),
393                                            cx.clone(),
394                                        )
395                                        .await?;
396                                    }
397                                    Command::Distill => {
398                                        run_distill(example).await?;
399                                    }
400                                    Command::Score(args) | Command::Eval(args) => {
401                                        run_scoring(example, &args, app_state.clone(), cx.clone())
402                                            .await?;
403                                    }
404                                    Command::Clean
405                                    | Command::Synthesize(_)
406                                    | Command::SplitCommit(_) => {
407                                        unreachable!()
408                                    }
409                                }
410                                anyhow::Ok(())
411                            }
412                            .await;
413
414                            if let Err(e) = result {
415                                Progress::global().increment_failed();
416                                let failed_example_path =
417                                    FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name));
418                                app_state
419                                    .fs
420                                    .write(
421                                        &failed_example_path,
422                                        &serde_json::to_vec_pretty(&example).unwrap(),
423                                    )
424                                    .await
425                                    .unwrap();
426                                let err_path = FAILED_EXAMPLES_DIR
427                                    .join(format!("{}_err.txt", example.spec.name));
428                                app_state
429                                    .fs
430                                    .write(&err_path, e.to_string().as_bytes())
431                                    .await
432                                    .unwrap();
433
434                                let msg = format!(
435                                    indoc::indoc! {"
436                                        While processing \"{}\":
437
438                                        {:?}
439
440                                        Written to: \x1b[36m{}\x1b[0m
441
442                                        Explore this example data with:
443                                            fx \x1b[36m{}\x1b[0m
444
445                                        Re-run this example with:
446                                            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
447                                    "},
448                                    example.spec.name,
449                                    e,
450                                    err_path.display(),
451                                    failed_example_path.display(),
452                                    command,
453                                    failed_example_path.display(),
454                                );
455                                if args.failfast || failfast_on_single_example {
456                                    Progress::global().finalize();
457                                    panic!("{}", msg);
458                                } else {
459                                    log::error!("{}", msg);
460                                }
461                            }
462                        }
463                    });
464                    futures::future::join_all(futures).await;
465                }
466                Progress::global().finalize();
467
468                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
469                    write_examples(&examples, output.as_ref());
470                }
471
472                match &command {
473                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
474                    Command::Eval(_) => score::print_report(&examples),
475                    _ => (),
476                };
477
478                anyhow::Ok(())
479            }
480            .await;
481
482            if let Err(e) = result {
483                panic!("Fatal error: {:?}", e);
484            }
485
486            let _ = cx.update(|cx| cx.quit());
487        })
488        .detach();
489    });
490}