mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    #[arg(long, default_value_t = false)]
    printenv: bool,
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}

/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to a file containing one or more examples (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.

        You can specify this multiple times and mix it with file inputs.

        Required environment variables to connect to Snowflake:
            EP_SNOWFLAKE_API_KEY
            EP_SNOWFLAKE_BASE_URL

        Optional:
            EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;

#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::ParseExample => write!(f, "parse-example"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(format_prompt_args) => write!(
                f,
                "format-prompt --provider={}",
                format_prompt_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Predict(predict_args) => write!(
                f,
                "predict --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Score(predict_args) => write!(
                f,
                "score --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Distill => write!(f, "distill"),
            Command::Eval(predict_args) => write!(
                f,
                "eval --provider={}",
                predict_args
                    .provider
                    .to_possible_value()
                    .unwrap()
                    .get_name()
            ),
            Command::Synthesize(args) => write!(f, "synthesize --repo={}", args.repo),
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short)]
    provider: PredictionProvider,
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short)]
    provider: PredictionProvider,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}

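/// Backend used to produce edit predictions.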
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

impl EpArgs {
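    /// The effective output path: with `--in-place`, the single input file (panicking if
    /// more than one input was given); otherwise the value of `--output`, if any.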
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}

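/// Collect examples from file inputs and `captured-after:` specifiers, sort them by repo and
/// revision, apply the `--name`/`--repo` filters and `--limit`, and drop any examples that the
/// output file shows were already processed (see `resume_from_output`).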
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}

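/// Hash of an example's spec, used to match examples in the current input set against lines
/// already written to the output file.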
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}

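/// Resume support: keep only the lines of an existing output file whose spec hash matches one
/// of the current inputs (deduplicated), rewrite the file with just those lines, and remove the
/// corresponding examples from `examples` so they are not processed again.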
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

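// Commands that don't need per-example processing (clean, synthesize, split-commit, split)
// are dispatched synchronously below; everything else shares the load -> process -> write
// pipeline that runs inside the headless app further down.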
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

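    // The remaining commands operate on loaded examples, so they run inside a headless GPUI
    // application with an HTTP client (used to fetch captured examples, among other things).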
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

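                // Sync any pending provider batch state before the run; the same sync happens
                // again once all examples have been processed.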
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

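                // Output lines are funneled through a channel to a single background writer
                // task so concurrently processed examples never interleave in the output file.
                // `eval` without an explicit --output skips the writer and prints a report at
                // the end instead.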
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

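                // Group examples by repository so examples sharing a repo run sequentially,
                // then process up to --max-parallelism repo groups concurrently per batch.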
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

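                            // Failed examples are always recorded by `handle_error`; whether
                            // they also land in the main output depends on --failed.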
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}

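/// Record a failed example: bump the failure counter, write the example and its error to the
/// failed-examples directory, append it to the run's `failed.jsonl`, and build a message with
/// a re-run hint. Panics when `--failfast` is set or when only a single example was being
/// processed; otherwise the message is only logged.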
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}