1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod synthesize;
18use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
19use edit_prediction::EditPredictionStore;
20use gpui::Application;
21use reqwest_client::ReqwestClient;
22use serde::{Deserialize, Serialize};
23use std::fmt::Display;
24use std::{path::PathBuf, sync::Arc};
25
26use crate::distill::run_distill;
27use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
28use crate::format_prompt::run_format_prompt;
29use crate::load_project::run_load_project;
30use crate::paths::FAILED_EXAMPLES_DIR;
31use crate::predict::run_prediction;
32use crate::progress::Progress;
33use crate::retrieve_context::run_context_retrieval;
34use crate::score::run_scoring;
35use crate::split_commit::SplitCommitArgs;
36use crate::synthesize::{SynthesizeConfig, run_synthesize};
37
38#[derive(Parser, Debug)]
39#[command(name = "ep")]
40struct EpArgs {
41 #[arg(long, default_value_t = false)]
42 printenv: bool,
43 #[clap(long, default_value_t = 10, global = true)]
44 max_parallelism: usize,
45 #[clap(long, global = true)]
46 limit: Option<usize>,
47 /// Filter examples by name
48 #[clap(long, global = true)]
49 name: Option<String>,
50 /// Filter examples by repository
51 #[clap(long, global = true)]
52 repo: Option<String>,
53 #[command(subcommand)]
54 command: Option<Command>,
55 #[clap(global = true, help = INPUTS_HELP)]
56 inputs: Vec<PathBuf>,
57 #[arg(long, short, global = true)]
58 output: Option<PathBuf>,
59 #[arg(long, short, global = true)]
60 in_place: bool,
61 #[arg(long, short, global = true)]
62 failfast: bool,
63}
64
// Long-form help text attached to the positional `inputs` argument of
// `EpArgs` (via `help = INPUTS_HELP`). Emitted verbatim by clap, so the
// literal below must not be reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
94
// All `ep` subcommands. The `///` doc comments double as clap help text,
// so they are part of the CLI surface — edit with care.
//
// `Predict`, `Score`, and `Eval` share `PredictArgs`. `Clean`, `Synthesize`,
// and `SplitCommit` are handled early in `main` and never reach the
// per-example processing loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}
121
122impl Display for Command {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 Command::ParseExample => write!(f, "parse-example"),
126 Command::LoadProject => write!(f, "load-project"),
127 Command::Context => write!(f, "context"),
128 Command::FormatPrompt(format_prompt_args) => write!(
129 f,
130 "format-prompt --prompt-format={}",
131 format_prompt_args
132 .prompt_format
133 .to_possible_value()
134 .unwrap()
135 .get_name()
136 ),
137 Command::Predict(predict_args) => {
138 write!(
139 f,
140 "predict --provider={:?}",
141 predict_args
142 .provider
143 .to_possible_value()
144 .unwrap()
145 .get_name()
146 )
147 }
148 Command::Score(predict_args) => {
149 write!(
150 f,
151 "score --provider={:?}",
152 predict_args
153 .provider
154 .to_possible_value()
155 .unwrap()
156 .get_name()
157 )
158 }
159 Command::Distill => write!(f, "distill"),
160 Command::Eval(predict_args) => write!(
161 f,
162 "eval --provider={:?}",
163 predict_args
164 .provider
165 .to_possible_value()
166 .unwrap()
167 .get_name()
168 ),
169 Command::Synthesize(args) => {
170 write!(f, "synthesize --repo={}", args.repo)
171 }
172 Command::Clean => write!(f, "clean"),
173 Command::SplitCommit(_) => write!(f, "split-commit"),
174 }
175 }
176}
177
178#[derive(Debug, Args, Clone)]
179struct FormatPromptArgs {
180 #[clap(long)]
181 prompt_format: PromptFormat,
182}
183
// Prompt flavors accepted by `format-prompt --prompt-format`. The clap
// `ValueEnum` derive provides the CLI value names; serde impls suggest the
// choice is also persisted with example data — confirm in format_prompt.rs.
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}
189
190#[derive(Debug, Args, Clone)]
191struct PredictArgs {
192 #[clap(long)]
193 provider: PredictionProvider,
194 #[clap(long, default_value_t = 1)]
195 repetitions: usize,
196}
197
// Prediction backends selectable via `--provider`. Names are exposed on the
// CLI through the `ValueEnum` derive. `TeacherNonBatching` presumably runs
// the teacher model without the batch API — confirm in predict.rs.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
207
208#[derive(Debug, Args, Clone)]
209struct SynthesizeArgs {
210 /// Repository URL (git@github.com:owner/repo or https://...)
211 #[clap(long)]
212 repo: String,
213
214 /// Number of examples to generate
215 #[clap(long, default_value_t = 5)]
216 count: usize,
217
218 /// Maximum commits to scan before giving up
219 #[clap(long, default_value_t = 100)]
220 max_commits: usize,
221
222 /// Ignore state file and reprocess all commits
223 #[clap(long)]
224 fresh: bool,
225}
226
227impl EpArgs {
228 fn output_path(&self) -> Option<PathBuf> {
229 if self.in_place {
230 if self.inputs.len() == 1 {
231 self.inputs.first().cloned()
232 } else {
233 panic!("--in-place requires exactly one input file")
234 }
235 } else {
236 self.output.clone()
237 }
238 }
239}
240
/// Loads examples from every CLI input: local example files plus, for any
/// `captured-after:{timestamp}` specifiers, examples fetched from Snowflake.
/// Applies the global `--name`, `--repo`, and `--limit` filters and returns
/// the examples sorted by repository and revision.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into `captured-after:` specifiers and plain file paths.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total from file inputs only; updated again below once
    // Snowflake results and filters are applied.
    Progress::global().set_total_examples(examples.len());

    // How many more examples --limit allows after counting the file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        // File inputs already satisfied --limit; don't hit Snowflake at all.
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        // RFC3339 timestamps sort chronologically as strings (when uniformly
        // formatted), so a plain lexicographic sort suffices here.
        captured_after_timestamps.sort();

        // With no --limit, fall back to a 5000-row cap per timestamp.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Enforce --limit after filtering (the Snowflake cap above is per
    // timestamp, so the combined set can still exceed the limit).
    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Final total now that all sources and filters are accounted for.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
301
/// Entry point. Parses CLI arguments, services the commands that need no
/// example loading (`--printenv`, `clean`, `synthesize`, `split-commit`),
/// then runs the remaining commands inside a headless GPUI application:
/// loads examples, processes them grouped by repository with bounded
/// parallelism, and writes results / reports.
fn main() {
    let args = EpArgs::parse();

    // Print the shell environment and exit; used as a standalone probe.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve --output / --in-place before `args` is partially moved below.
    let output = args.output_path();
    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that complete without loading examples or starting GPUI.
    match &command {
        Command::Clean => {
            // Wipe the cached repositories and worktrees wholesale.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into a directory, so a
            // plain --output is mandatory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so project
    // loading and prediction have an app context to work with.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                // Sync any outstanding prediction batches before starting.
                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // Examples in the same repo group run sequentially (they
                // presumably share a checkout — see load_project); distinct
                // groups run concurrently, --max-parallelism groups per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                // Dispatch the per-example stage for this command.
                                match &command {
                                    // parse-example only needs the load + write
                                    // performed outside this match.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled by the early-exit match above.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record the failure; may panic when failing fast.
                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // Write results unless this is `eval` without an explicit
                // --output (eval is read-only reporting by default).
                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    // Sync again to flush batches created during this run.
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Shut the headless app down once processing completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
469
/// Records one failed example: bumps the failure counter, persists the full
/// example (JSON) and the error text under `FAILED_EXAMPLES_DIR`, and emits
/// a message with inspection and re-run instructions. Panics — aborting the
/// whole run — when `--failfast` is set or the run had only one example;
/// otherwise the message is merely logged and processing continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    // Persist the example state as pretty JSON so it can be inspected and
    // fed back in as an input file (see the re-run hint below).
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Store the debug-formatted error alongside the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Path of the file the cursor was in, inside the example's worktree.
    let file_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // \x1b[36m...\x1b[0m wraps the paths in cyan for terminal output.
    // `command` renders via its Display impl as a pasteable CLI fragment.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            {:?}

            Written to: \x1b[36m{}\x1b[0m

            Cursor File: \x1b[36m{}\x1b[0m

            Explore this example data with:
            fx \x1b[36m{}\x1b[0m

            Re-run this example with:
            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        err_path.display(),
        file_path.display(),
        failed_example_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}