1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application};
25use zeta_prompt::ZetaVersion;
26
27use reqwest_client::ReqwestClient;
28use serde::{Deserialize, Serialize};
29use std::fmt::Display;
30use std::fs::{File, OpenOptions};
31use std::hash::{Hash, Hasher};
32use std::io::{BufRead, BufReader, BufWriter, Write};
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
34
35use crate::distill::run_distill;
36use crate::example::{Example, group_examples_by_repo, read_example_files};
37use crate::format_prompt::run_format_prompt;
38use crate::load_project::run_load_project;
39use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
40use crate::predict::run_prediction;
41use crate::progress::Progress;
42use crate::retrieve_context::run_context_retrieval;
43use crate::score::run_scoring;
44use crate::split_commit::SplitCommitArgs;
45use crate::split_dataset::SplitArgs;
46use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
// Top-level CLI for the `ep` edit-prediction tool. Most flags are `global`
// so they may appear before or after the subcommand.
// NOTE: plain `//` comments are used on fields because `///` doc comments on
// clap-derive fields become help text and would change `--help` output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (diagnostic aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of repository groups processed concurrently
    // (used as the chunk size for `chunks_mut` in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on the total number of examples to process (see `load_examples`).
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // File paths or special specifiers such as `captured-after:{timestamp}`;
    // see INPUTS_HELP for the full syntax.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; results are appended as JSONL. When omitted, results go to
    // stdout (except for `eval`, which prints a report).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed from the global `--failed` flag on `EpArgs` (default: keep).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
/// Long-form help text attached to the positional `inputs` argument of
/// `EpArgs` (shown by `ep --help`). The string itself is user-facing output
/// and must not be edited casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of `ep`. Each variant is rendered back into a runnable
// command-line string by the `Display` impl below (used for "Re-run" hints
// when an example fails), so variant/flag naming must stay in sync with it.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(format_prompt_args) => write!(
156 f,
157 "format-prompt --prompt-format={}",
158 format_prompt_args
159 .provider
160 .to_possible_value()
161 .unwrap()
162 .get_name()
163 ),
164 Command::Predict(predict_args) => {
165 write!(
166 f,
167 "predict --provider={:?}",
168 predict_args
169 .provider
170 .to_possible_value()
171 .unwrap()
172 .get_name()
173 )
174 }
175 Command::Score(predict_args) => {
176 write!(
177 f,
178 "score --provider={:?}",
179 predict_args
180 .provider
181 .to_possible_value()
182 .unwrap()
183 .get_name()
184 )
185 }
186 Command::Distill => write!(f, "distill"),
187 Command::Eval(predict_args) => write!(
188 f,
189 "eval --provider={:?}",
190 predict_args
191 .provider
192 .to_possible_value()
193 .unwrap()
194 .get_name()
195 ),
196 Command::Synthesize(args) => {
197 write!(f, "synthesize --repo={}", args.repo)
198 }
199 Command::Clean => write!(f, "clean"),
200 Command::SplitCommit(_) => write!(f, "split-commit"),
201 Command::Split(_) => write!(f, "split"),
202 }
203 }
204}
205
// Flags for `ep format-prompt`.
// NOTE: `//` comments are used so clap help output is unchanged.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which model/provider's prompt format to generate.
    #[clap(long, short)]
    provider: PredictionProvider,
    // Defaults to `ZetaVersion::default()`, so the flag is optional here.
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
        default_value_t = ZetaVersion::default(),
    )]
    version: ZetaVersion,
}
219
// Flags shared by `ep predict`, `ep score`, and `ep eval`.
// NOTE: `//` comments are used so clap help output is unchanged.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use.
    #[clap(long, short)]
    provider: PredictionProvider,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    // NOTE(review): unlike `FormatPromptArgs`, there is no `default_value_t`
    // here, which makes `--version` required even for non-zeta2 providers —
    // confirm this is intentional.
    #[clap(
        long,
        short,
        help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
        value_parser = ZetaVersion::parse,
    )]
    version: ZetaVersion,
}
234
// Prediction backends selectable via `--provider`. The Serialize/Deserialize
// derives suggest the value is persisted or sent over the wire somewhere
// outside this file — usage not visible here.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}
244
// Flags for `ep synthesize`; copied into a `SynthesizeConfig` in `main`
// before calling `run_synthesize`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
263
264impl EpArgs {
265 fn output_path(&self) -> Option<PathBuf> {
266 if self.in_place {
267 if self.inputs.len() == 1 {
268 self.inputs.first().cloned()
269 } else {
270 panic!("--in-place requires exactly one input file")
271 }
272 } else {
273 self.output.clone()
274 }
275 }
276}
277
/// Loads the run's examples from the CLI inputs: plain example files plus
/// `captured-after:{timestamp}` specifiers fetched from Snowflake. Applies
/// the `--name`/`--repo` substring filters and `--limit`, then drops any
/// examples already recorded in the output file (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
) -> anyhow::Result<Vec<Example>> {
    // Split inputs into captured-after timestamps and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress shows early; set again after fetching,
    // filtering, and resume pruning below.
    Progress::global().set_total_examples(examples.len());

    // Only fetch from Snowflake for whatever --limit leaves uncovered by the
    // file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name/--repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples already present in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    // Final total after all pruning.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
343
344fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
345 let mut hasher = collections::FxHasher::default();
346 spec.hash(&mut hasher);
347 hasher.finish()
348}
349
350fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
351 let file = match File::open(path) {
352 Ok(f) => f,
353 Err(_) => return,
354 };
355
356 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
357
358 let reader = BufReader::new(file);
359 let mut kept_lines = Vec::new();
360 let mut kept_hashes = HashSet::default();
361
362 for line in reader.lines() {
363 let line = match line {
364 Ok(l) => l,
365 Err(_) => continue,
366 };
367
368 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
369 let hash = spec_hash(&output_example.spec);
370 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
371 kept_hashes.insert(hash);
372 kept_lines.push(line);
373 }
374 }
375 }
376
377 let total = examples.len();
378 let already_processed = kept_hashes.len();
379
380 eprintln!(
381 "Resuming: {}/{} examples already processed",
382 already_processed, total
383 );
384
385 let file = OpenOptions::new()
386 .write(true)
387 .truncate(true)
388 .open(path)
389 .expect("Failed to open output file for rewriting");
390 let mut writer = BufWriter::new(file);
391 for line in &kept_lines {
392 writeln!(writer, "{}", line).expect("Failed to write to output file");
393 }
394 writer.flush().expect("Failed to flush output file");
395
396 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
397}
398
/// CLI entry point: parses arguments, dispatches the simple synchronous
/// subcommands directly, and runs the per-example pipeline inside a headless
/// GPUI application for the remaining commands.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the example pipeline are
    // handled synchronously here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless app with an HTTP client.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples =
                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;

                // Sync any pending provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and receives serialized examples over an unbounded channel,
                // so the parallel workers below never touch the file directly.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            // Append mode: resume_from_output may have left
                            // already-processed lines in the file.
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples from the same repository run sequentially within a
                // group (presumably because they share per-repo state such as
                // a worktree — TODO confirm); up to --max-parallelism repo
                // groups run concurrently per batch.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the per-example stage for the current command.
                            let result = async {
                                match &command {
                                    // No per-example work: parsing happened in
                                    // load_examples; output is written below.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously above, before app.run.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; handle_error panics under
                            // --failfast (or when there is a single example).
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Failed examples are only emitted with --failed=keep.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file: stream JSONL to stdout
                                    // (except `eval`, which prints a report).
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush provider batches created during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` aggregates scores across all processed examples.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Processing finished: quit the headless app's event loop.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
622
/// Records a failed example: bumps the failure counter, writes the example
/// and its error to the run's failed-examples directory and to the run-level
/// `failed.jsonl`, then either panics (failfast mode) or logs the error and
/// lets processing continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Persist the full example as pretty JSON so it can be re-run later
    // (the "Re-run" hint below points at this file).
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Store the debug-formatted error next to the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Also append the example to a run-level JSONL of all failures.
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` will panic for an example without
    // a resolvable repo name, turning an example failure into a crash of the
    // error handler itself — confirm this invariant always holds here.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // Human-readable summary with ANSI colors; `command` renders via its
    // Display impl so the "Re-run" line can be pasted into a shell.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        // In failfast mode, finalize progress output before panicking so the
        // message is not interleaved with the progress display.
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}