1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
19use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
20use collections::HashSet;
21use edit_prediction::EditPredictionStore;
22use futures::channel::mpsc;
23use futures::{SinkExt as _, StreamExt as _};
24use gpui::{AppContext as _, Application, BackgroundExecutor};
25use zeta_prompt::ZetaVersion;
26
27use reqwest_client::ReqwestClient;
28use serde::{Deserialize, Deserializer, Serialize, Serializer};
29use std::fmt::Display;
30use std::fs::{File, OpenOptions};
31use std::hash::{Hash, Hasher};
32use std::io::{BufRead, BufReader, BufWriter, Write};
33use std::{path::PathBuf, sync::Arc};
34
35use crate::distill::run_distill;
36use crate::example::{Example, group_examples_by_repo, read_example_files};
37use crate::format_prompt::run_format_prompt;
38use crate::load_project::run_load_project;
39use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
40use crate::predict::run_prediction;
41use crate::progress::Progress;
42use crate::retrieve_context::run_context_retrieval;
43use crate::score::run_scoring;
44use crate::split_commit::SplitCommitArgs;
45use crate::split_dataset::SplitArgs;
46use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
48#[derive(Parser, Debug)]
49#[command(name = "ep")]
50struct EpArgs {
51 #[arg(long, default_value_t = false)]
52 printenv: bool,
53 #[clap(long, default_value_t = 10, global = true)]
54 max_parallelism: usize,
55 #[clap(long, global = true)]
56 limit: Option<usize>,
57 /// Filter examples by name
58 #[clap(long, global = true)]
59 name: Option<String>,
60 /// Filter examples by repository
61 #[clap(long, global = true)]
62 repo: Option<String>,
63 #[command(subcommand)]
64 command: Option<Command>,
65 #[clap(global = true, help = INPUTS_HELP)]
66 inputs: Vec<PathBuf>,
67 #[arg(long, short, global = true)]
68 output: Option<PathBuf>,
69 #[arg(long, short, global = true)]
70 in_place: bool,
71 #[arg(long, short, global = true)]
72 failfast: bool,
73 /// How to handle failed examples in output: keep them or skip them.
74 /// Failed examples are always logged to the run's failed directory.
75 #[arg(long, global = true, default_value = "keep")]
76 failed: FailedHandling,
77}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The ValueEnum derive exposes the variants as the kebab-case CLI values
// `keep` and `skip`, matching `default_value = "keep"` on `EpArgs::failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
/// Long-form help text attached to the positional `inputs` argument on
/// `EpArgs` (via `help = INPUTS_HELP`). This string is user-facing CLI
/// output, so its contents must not be reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of `ep`. The `///` doc comments double as clap help text and
// are left untouched; extra notes use plain comments.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Runs the same per-example scoring as `Score` (see `main`), then prints
    // an aggregate report via `score::print_report`.
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    // Deletes the entire DATA_DIR; handled synchronously in `main`.
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(args) => {
156 write!(f, "format-prompt --provider={}", args.provider)
157 }
158 Command::Predict(args) => {
159 write!(f, "predict --provider={}", args.provider)
160 }
161 Command::Score(args) => {
162 write!(f, "score --provider={}", args.provider)
163 }
164 Command::Distill => write!(f, "distill"),
165 Command::Eval(args) => {
166 write!(f, "eval --provider={}", args.provider)
167 }
168 Command::Synthesize(args) => {
169 write!(f, "synthesize --repos {}", args.repos.join(" "))
170 }
171 Command::Clean => write!(f, "clean"),
172 Command::SplitCommit(_) => write!(f, "split-commit"),
173 Command::Split(_) => write!(f, "split"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; parsed via
    // `PredictionProvider::from_str`, e.g. `zeta2:<version>`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
183
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider used to produce predictions; parsed via
    // `PredictionProvider::from_str`, e.g. `zeta2:<version>`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // Number of times each example is predicted (defaults to a single pass).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
/// Backend used to produce edit predictions.
///
/// Parsed from strings like `zeta2` or `teacher:<version>` (see the `FromStr`
/// impl below) and rendered back into that same form by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt version to use.
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably the teacher model invoked without request
    // batching (cf. `predict::sync_batches`) — confirm in predict.rs.
    TeacherNonBatching(ZetaVersion),
}
201
202impl Default for PredictionProvider {
203 fn default() -> Self {
204 PredictionProvider::Zeta2(ZetaVersion::default())
205 }
206}
207
208impl std::fmt::Display for PredictionProvider {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PredictionProvider::Sweep => write!(f, "sweep"),
212 PredictionProvider::Mercury => write!(f, "mercury"),
213 PredictionProvider::Zeta1 => write!(f, "zeta1"),
214 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
215 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
216 PredictionProvider::TeacherNonBatching(version) => {
217 write!(f, "teacher-non-batching:{version}")
218 }
219 }
220 }
221}
222
223impl std::str::FromStr for PredictionProvider {
224 type Err = anyhow::Error;
225
226 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
227 let mut version = ZetaVersion::default();
228 if let Some((first, second)) = s.split_once(':') {
229 version = ZetaVersion::parse(second)?;
230 s = first;
231 }
232
233 let s_lower = s.to_lowercase();
234 match s_lower.as_str() {
235 "sweep" => Ok(PredictionProvider::Sweep),
236 "mercury" => Ok(PredictionProvider::Mercury),
237 "zeta1" => Ok(PredictionProvider::Zeta1),
238 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
239 "teacher" => Ok(PredictionProvider::Teacher(version)),
240 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
241 Ok(PredictionProvider::TeacherNonBatching(version))
242 }
243 _ => {
244 anyhow::bail!(
245 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
246 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
247 Available zeta versions:\n{}",
248 ZetaVersion::options_as_string()
249 )
250 }
251 }
252 }
253}
254
255impl Serialize for PredictionProvider {
256 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
257 where
258 S: Serializer,
259 {
260 serializer.serialize_str(&self.to_string())
261 }
262}
263
264impl<'de> Deserialize<'de> for PredictionProvider {
265 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
266 where
267 D: Deserializer<'de>,
268 {
269 let s = String::deserialize(deserializer)?;
270 s.parse().map_err(serde::de::Error::custom)
271 }
272}
273
// Arguments for the `synthesize` subcommand; copied field-by-field into
// `SynthesizeConfig` in `main`. The `///` field docs are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
292
293impl EpArgs {
294 fn output_path(&self) -> Option<PathBuf> {
295 if self.in_place {
296 if self.inputs.len() == 1 {
297 self.inputs.first().cloned()
298 } else {
299 panic!("--in-place requires exactly one input file")
300 }
301 } else {
302 self.output.clone()
303 }
304 }
305}
306
/// Gathers the examples to process: reads file inputs, optionally fetches
/// captured examples from Snowflake for `captured-after:` specifiers, applies
/// the `--name`/`--repo` filters and `--limit`, and drops examples already
/// present in the output file (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into `captured-after:` timestamps and ordinary files.
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress is meaningful during the (potentially
    // slow) Snowflake fetch; the final total is set again further down.
    Progress::global().set_total_examples(examples.len());

    // The Snowflake fetch only needs to fill whatever --limit leaves after
    // the file-based examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters; applied after fetching, so --limit above may
    // over-fetch relative to what survives filtering.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples whose results are already in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
374
375fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
376 let mut hasher = collections::FxHasher::default();
377 spec.hash(&mut hasher);
378 hasher.finish()
379}
380
381fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
382 let file = match File::open(path) {
383 Ok(f) => f,
384 Err(_) => return,
385 };
386
387 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
388
389 let reader = BufReader::new(file);
390 let mut kept_lines = Vec::new();
391 let mut kept_hashes = HashSet::default();
392
393 for line in reader.lines() {
394 let line = match line {
395 Ok(l) => l,
396 Err(_) => continue,
397 };
398
399 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
400 let hash = spec_hash(&output_example.spec);
401 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
402 kept_hashes.insert(hash);
403 kept_lines.push(line);
404 }
405 }
406 }
407
408 let total = examples.len();
409 let already_processed = kept_hashes.len();
410
411 eprintln!(
412 "Resuming: {}/{} examples already processed",
413 already_processed, total
414 );
415
416 let file = OpenOptions::new()
417 .write(true)
418 .truncate(true)
419 .open(path)
420 .expect("Failed to open output file for rewriting");
421 let mut writer = BufWriter::new(file);
422 for line in &kept_lines {
423 writeln!(writer, "{}", line).expect("Failed to write to output file");
424 }
425 writer.flush().expect("Failed to flush output file");
426
427 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
428}
429
/// Entry point. Parses arguments, handles the synchronous commands
/// (`clean`, `synthesize`, `split-commit`, `split`) directly, and runs the
/// remaining per-example commands inside a headless gpui application.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app / example pipeline are
    // handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui app so project loading,
    // context retrieval, and prediction can use the editor's infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // Wrapped in an inner async block so `?` failures fall through to
            // the single panic below after cleanup.
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any outstanding provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // With a single example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // When writing to a file, a background task owns the writer
                // and examples are streamed to it over a channel. For `eval`
                // without an explicit --output, nothing is written.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples grouped per repo; up to --max-parallelism repo
                // groups run concurrently, examples within a group serially.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            // Run the per-example stage for the chosen command.
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled synchronously before app startup.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Failures are logged (and may panic, per
                            // --failfast) inside handle_error.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Emit the example as JSONL — to the file writer
                            // when one exists, otherwise to stdout (except
                            // for `eval` with no --output).
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Flush provider batches created during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
658
/// Records a failed example: bumps the failure counter, writes the example
/// JSON and the error text into `FAILED_EXAMPLES_DIR`, appends the example to
/// the run's `failed.jsonl`, then either panics (with `--failfast`, or when
/// only a single example is being processed) or logs the error and continues.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // The example and its error are written to separate files so the example
    // JSON can be fed straight back into a re-run.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Also append to the run-wide failed.jsonl (one JSON example per line).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // NOTE(review): `repo_name().unwrap()` can itself panic inside this error
    // handler if the example's repo name can't be derived — confirm that is
    // acceptable here.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // User-facing failure summary with ANSI colors and a re-run hint that
    // uses the failed example file as input.
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}