1mod anthropic_client;
2mod distill;
3mod example;
4mod format_prompt;
5mod git;
6mod headless;
7mod load_project;
8mod metrics;
9mod paths;
10mod predict;
11mod progress;
12mod pull_examples;
13mod reorder_patch;
14mod retrieve_context;
15mod score;
16mod split_commit;
17mod split_dataset;
18mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
47
// Top-level CLI arguments for the `ep` tool.
//
// NOTE: clap turns `///` doc comments into --help text, so reviewer notes
// below use `//` to leave the generated help unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on how many repo groups are processed concurrently
    // (used as the chunk size in `main`'s batching loop).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on total examples; also budgets Snowflake fetches in `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP for the full grammar).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path; superseded by --in-place (see `output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; panics unless exactly one
    // input is given (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example (see `handle_error`).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
78
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `ValueEnum` derive supplies the CLI value strings for the `--failed`
// flag on `EpArgs` (its default is spelled "keep" there).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
}
89
// Long help text for the positional `inputs` argument of `EpArgs`,
// attached via `#[clap(help = INPUTS_HELP)]`. The raw string is user-facing
// output; do not edit casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
119
// Subcommands of the `ep` CLI. The `///` doc comments double as clap help
// text, so additional notes use `//`.
//
// `Clean`, `Synthesize`, `SplitCommit`, and `Split` are handled up front in
// `main`; the remaining variants run per-example inside the headless app.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    // Predict, Score, and Eval all share `PredictArgs`.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
}
148
149impl Display for Command {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Command::ParseExample => write!(f, "parse-example"),
153 Command::LoadProject => write!(f, "load-project"),
154 Command::Context => write!(f, "context"),
155 Command::FormatPrompt(args) => {
156 write!(f, "format-prompt --provider={}", args.provider)
157 }
158 Command::Predict(args) => {
159 write!(f, "predict --provider={}", args.provider)
160 }
161 Command::Score(args) => {
162 write!(f, "score --provider={}", args.provider)
163 }
164 Command::Distill => write!(f, "distill"),
165 Command::Eval(args) => {
166 write!(f, "eval --provider={}", args.provider)
167 }
168 Command::Synthesize(args) => {
169 write!(f, "synthesize --repos {}", args.repos.join(" "))
170 }
171 Command::Clean => write!(f, "clean"),
172 Command::SplitCommit(_) => write!(f, "split-commit"),
173 Command::Split(_) => write!(f, "split"),
174 }
175 }
176}
177
// Arguments for the `format-prompt` subcommand; only selects which
// provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Parsed through `PredictionProvider::from_str` via `default_value_t`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
183
// Shared arguments for the `predict`, `score`, and `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Which prediction backend to use; parsed via `PredictionProvider::from_str`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
    // Predictions per example — consumed by `run_prediction`; exact
    // semantics live in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
191
/// Backend used to produce (or score) edit predictions.
///
/// Round-trips through `Display`/`FromStr`, which is also how clap parses
/// `--provider` and how the serde impls below (de)serialize it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries a `ZetaVersion` (from the `zeta_prompt` crate) selecting the
    // zeta2 variant; spelled `zeta2:<version>` on the CLI.
    Zeta2(ZetaVersion),
    Teacher,
    // NOTE(review): presumably `Teacher` without request batching — the name
    // suggests it; confirm against predict.rs.
    TeacherNonBatching,
}
201
202impl Default for PredictionProvider {
203 fn default() -> Self {
204 PredictionProvider::Zeta2(ZetaVersion::default())
205 }
206}
207
208impl std::fmt::Display for PredictionProvider {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PredictionProvider::Sweep => write!(f, "sweep"),
212 PredictionProvider::Mercury => write!(f, "mercury"),
213 PredictionProvider::Zeta1 => write!(f, "zeta1"),
214 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
215 PredictionProvider::Teacher => write!(f, "teacher"),
216 PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
217 }
218 }
219}
220
221impl std::str::FromStr for PredictionProvider {
222 type Err = anyhow::Error;
223
224 fn from_str(s: &str) -> Result<Self, Self::Err> {
225 let s_lower = s.to_lowercase();
226 match s_lower.as_str() {
227 "sweep" => Ok(PredictionProvider::Sweep),
228 "mercury" => Ok(PredictionProvider::Mercury),
229 "zeta1" => Ok(PredictionProvider::Zeta1),
230 // Handle both old format "zeta2" and new format with version
231 "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
232 "teacher" => Ok(PredictionProvider::Teacher),
233 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
234 Ok(PredictionProvider::TeacherNonBatching)
235 }
236 _ if s_lower.starts_with("zeta2:") => {
237 let version_str = &s[6..];
238 let version = ZetaVersion::parse(version_str)?;
239 Ok(PredictionProvider::Zeta2(version))
240 }
241 _ => anyhow::bail!(
242 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
243 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
244 Available zeta versions:\n{}",
245 ZetaVersion::options_as_string()
246 ),
247 }
248 }
249}
250
251impl Serialize for PredictionProvider {
252 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
253 where
254 S: Serializer,
255 {
256 serializer.serialize_str(&self.to_string())
257 }
258}
259
260impl<'de> Deserialize<'de> for PredictionProvider {
261 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
262 where
263 D: Deserializer<'de>,
264 {
265 let s = String::deserialize(deserializer)?;
266 s.parse().map_err(serde::de::Error::custom)
267 }
268}
269
// Arguments for the `synthesize` subcommand; converted into a
// `SynthesizeConfig` in `main` (which also requires --output as the
// destination directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
288
289impl EpArgs {
290 fn output_path(&self) -> Option<PathBuf> {
291 if self.in_place {
292 if self.inputs.len() == 1 {
293 self.inputs.first().cloned()
294 } else {
295 panic!("--in-place requires exactly one input file")
296 }
297 } else {
298 self.output.clone()
299 }
300 }
301}
302
/// Collects and orders the examples for this run.
///
/// Inputs are partitioned into plain file paths and
/// `captured-after:{timestamp}` specifiers; the latter are fetched from
/// Snowflake, budgeted by whatever portion of `--limit` the file-based
/// examples did not already consume. The combined set is sorted, filtered
/// by `--name`/`--repo`, truncated to `--limit`, and finally reduced to the
/// examples not already present in `output_path` (resume support).
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs: `captured-after:…` specifiers vs. example files.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total; set again below once fetching/filtering is done.
    Progress::global().set_total_examples(examples.len());

    // Remaining --limit budget for Snowflake after counting file examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // Without --limit, fall back to a fixed per-timestamp row cap.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    // Keep examples from the same repository/revision adjacent, so
    // `group_examples_by_repo` in `main` forms contiguous groups.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    if let Some(limit) = args.limit {
        if examples.len() > limit {
            examples.truncate(limit);
        }
    }

    // Resume: drop examples already recorded in the output file.
    if let Some(path) = output_path {
        resume_from_output(path, &mut examples);
    }

    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}
370
371fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
372 let mut hasher = collections::FxHasher::default();
373 spec.hash(&mut hasher);
374 hasher.finish()
375}
376
377fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
378 let file = match File::open(path) {
379 Ok(f) => f,
380 Err(_) => return,
381 };
382
383 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
384
385 let reader = BufReader::new(file);
386 let mut kept_lines = Vec::new();
387 let mut kept_hashes = HashSet::default();
388
389 for line in reader.lines() {
390 let line = match line {
391 Ok(l) => l,
392 Err(_) => continue,
393 };
394
395 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
396 let hash = spec_hash(&output_example.spec);
397 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
398 kept_hashes.insert(hash);
399 kept_lines.push(line);
400 }
401 }
402 }
403
404 let total = examples.len();
405 let already_processed = kept_hashes.len();
406
407 eprintln!(
408 "Resuming: {}/{} examples already processed",
409 already_processed, total
410 );
411
412 let file = OpenOptions::new()
413 .write(true)
414 .truncate(true)
415 .open(path)
416 .expect("Failed to open output file for rewriting");
417 let mut writer = BufWriter::new(file);
418 for line in &kept_lines {
419 writeln!(writer, "{}", line).expect("Failed to write to output file");
420 }
421 writer.flush().expect("Failed to flush output file");
422
423 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
424}
425
/// Entry point: parses arguments, dispatches the "simple" subcommands
/// synchronously, and runs the per-example pipeline inside a headless gpui
/// application for everything else.
fn main() {
    let args = EpArgs::parse();

    // --printenv dumps the shell environment and exits before any subcommand.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that need neither the headless app nor loaded examples are
    // handled here and return early.
    match &command {
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // --output is repurposed as the destination directory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync provider batches before the run — NOTE(review):
                // presumably flushes pending batch requests; confirm in
                // predict.rs.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast so the error report
                // surfaces immediately (see `handle_error`).
                let failfast_on_single_example = examples.len() == 1;

                // Results are streamed as JSONL lines through a channel to a
                // background task that appends to the output file. `eval`
                // writes only when --output was given explicitly.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        output.as_ref().map(|path| {
                            let file = OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file");
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so partial output survives
                                    // a crash (enables resume_from_output).
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples are grouped per repository; groups within a chunk
                // run concurrently (at most --max-parallelism at a time), and
                // examples inside a group run sequentially.
                let mut grouped_examples = group_examples_by_repo(&mut examples);
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);

                for example_batch in example_batches {
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    // No per-example work: parsing happened in
                                    // `load_examples`; the example is emitted below.
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            args,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // Handled (with early return) before the
                                    // app started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_)
                                    | Command::Split(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            // Record failures; `handle_error` may panic when
                            // failfast applies.
                            let failed = if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                                true
                            } else {
                                false
                            };

                            // Emit the example unless it failed and --failed=skip.
                            let should_write = !failed || args.failed == FailedHandling::Keep;
                            if should_write {
                                if let Some(ref mut sender) = output_sender.clone() {
                                    let line = serde_json::to_string(example).unwrap();
                                    sender
                                        .send(line)
                                        .await
                                        .expect("Failed to send to output writer");
                                } else if args.output.is_none()
                                    && !matches!(command, Command::Eval(_))
                                {
                                    // No output file: stream JSONL to stdout.
                                    let line = serde_json::to_string(example).unwrap();
                                    println!("{}", line);
                                }
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }

                Progress::global().finalize();

                // Final batch sync after all examples completed.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(&args.provider).await?;
                    }
                    _ => (),
                }

                // `eval` prints the aggregated score report to the terminal.
                match &command {
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
654
/// Records a failed example and either aborts the run or logs and continues.
///
/// Per failure this persists three artifacts: `<name>.json` (the example)
/// and `<name>_err.txt` (the error) under `FAILED_EXAMPLES_DIR`, plus an
/// appended line in the run directory's `failed.jsonl`. It then builds a
/// colored report with a re-run hint; with `--failfast` — or when the run
/// had only a single example — the process panics with that report,
/// otherwise it is logged at error level.
///
/// Panics if any of the failure artifacts cannot be written, or when
/// failfast applies (see above).
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    // Snapshot of the failing example, usable directly as a re-run input.
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    // Debug-formatted error, written alongside the example.
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    // Append to the run-wide failure log (one JSONL line per failure).
    let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&failed_jsonl_path)
        .expect("Failed to open failed.jsonl");
    writeln!(file, "{}", serde_json::to_string(example).unwrap())
        .expect("Failed to write to failed.jsonl");

    // Location of the cursor file inside the example's worktree, for the report.
    let cursor_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    // ANSI-colored report: red error, cyan paths, plus a ready-to-paste
    // re-run command (the `{}` for `command` uses its Display impl).
    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            \x1b[31m{:?}\x1b[0m

            Example: \x1b[36m{}\x1b[0m
            Error file: \x1b[36m{}\x1b[0m
            Cursor file: \x1b[36m{}\x1b[0m
            Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        failed_example_path.display(),
        err_path.display(),
        cursor_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}