1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Command-line arguments for the `ep` edit-prediction tool.
//
// NOTE: plain `//` comments are used for added documentation (not `///`)
// because clap's derive turns doc comments into `--help` text and this is a
// documentation-only change.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on repository groups processed concurrently (see `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; when absent, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Input files or Snowflake specifiers; `INPUTS_HELP` documents the syntax.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (or directory when `--markdown` is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
96
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike the type-level doc above: this variant also suppresses the
    // per-example files under failed/ and failed.jsonl (see `handle_error`),
    // not just the main output.
    SkipNoFiles,
}
109
// Long-form help text attached to the `inputs` positional argument of
// `EpArgs`; rendered verbatim by clap.
//
// NOTE(review): `load_examples` also accepts a `requested-after:{timestamp}`
// specifier (via `pull_examples::parse_requested_after_input`) that is not
// documented in this text — consider adding it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
167
// All `ep` subcommands. Variant doc comments double as clap's per-subcommand
// help text, so added notes below use plain `//` comments.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Reuses PredictArgs since scoring shares the provider/repetition options.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Runs scoring per example (see `main`) and then prints a report / writes
    // an optional summary JSON.
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
207
208impl Display for Command {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Command::Read => write!(f, "read"),
212 Command::LoadProject => write!(f, "load-project"),
213 Command::Context => write!(f, "context"),
214 Command::FormatPrompt(args) => {
215 write!(f, "format-prompt --provider={}", args.provider)
216 }
217 Command::Predict(args) => match &args.provider {
218 Some(provider) => write!(f, "predict --provider={}", provider),
219 None => write!(f, "predict"),
220 },
221 Command::ParseOutput => write!(f, "parse-output"),
222 Command::Score(args) => match &args.provider {
223 Some(provider) => write!(f, "score --provider={}", provider),
224 None => write!(f, "score"),
225 },
226 Command::Distill => write!(f, "distill"),
227 Command::Eval(args) => match &args.predict.provider {
228 Some(provider) => write!(f, "eval --provider={}", provider),
229 None => write!(f, "eval"),
230 },
231 Command::Synthesize(args) => {
232 write!(f, "synthesize --repos {}", args.repos.join(" "))
233 }
234 Command::Clean => write!(f, "clean"),
235 Command::SplitCommit(_) => write!(f, "split-commit"),
236 Command::Split(_) => write!(f, "split"),
237 Command::FilterLanguages(_) => write!(f, "filter-languages"),
238 Command::ImportBatch(args) => {
239 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
240 }
241 Command::Qa(_) => {
242 write!(f, "qa")
243 }
244 Command::Repair(_) => {
245 write!(f, "repair")
246 }
247 }
248 }
249}
250
// Arguments for `format-prompt`. Unlike `PredictArgs`, the provider here is
// always present, falling back to `PredictionProvider::default()` (zeta2).
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
256
// Shared arguments for `predict`, `score`, and (via `EvalArgs`) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when absent, `Display` prints the bare command name
    // and `predict::sync_batches` receives `None`.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of repetitions per example — presumably repeated prediction
    // runs; confirm against `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
267
// Arguments for `eval`: the shared prediction/scoring options plus an
// optional path for a machine-readable summary.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
276
/// Which LLM acts as the "teacher" model for the teacher prediction
/// providers. See [`TeacherBackend::model_name`] for the concrete model
/// identifiers sent to the APIs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`claude-sonnet-4-5`).
    Sonnet45,
    /// OpenAI GPT 5.2 (`gpt-5.2`).
    Gpt52,
}
282
283impl std::fmt::Display for TeacherBackend {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 match self {
286 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
287 TeacherBackend::Gpt52 => write!(f, "gpt52"),
288 }
289 }
290}
291
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend tag case-insensitively, accepting several aliases
    /// per backend.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Extra alias mapped to Sonnet45 — presumably a legacy tag kept
            // for compatibility with existing data; deliberately not listed
            // in the error message below. TODO confirm it is still needed.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}
304
305impl TeacherBackend {
306 pub fn model_name(&self) -> &'static str {
307 match self {
308 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
309 TeacherBackend::Gpt52 => "gpt-5.2",
310 }
311 }
312}
313
/// The backend used to produce (or repair) edit predictions.
///
/// Parsed from strings like `zeta2:<version>` or `teacher:<backend>` — see
/// the `FromStr` impl for the accepted syntax; serialized through `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    /// Sweep model.
    Sweep,
    /// Mercury model.
    Mercury,
    /// First-generation zeta model.
    Zeta1,
    /// Zeta 2 at a specific [`ZetaVersion`].
    Zeta2(ZetaVersion),
    /// Teacher LLM backend (default request path).
    Teacher(TeacherBackend),
    /// Teacher LLM backend; the name suggests requests bypass the batching
    /// client — confirm in the predict module.
    TeacherNonBatching(TeacherBackend),
    /// Repair provider (see the `repair` module).
    Repair,
}
324
325impl Default for PredictionProvider {
326 fn default() -> Self {
327 PredictionProvider::Zeta2(ZetaVersion::default())
328 }
329}
330
331impl std::fmt::Display for PredictionProvider {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 match self {
334 PredictionProvider::Sweep => write!(f, "sweep"),
335 PredictionProvider::Mercury => write!(f, "mercury"),
336 PredictionProvider::Zeta1 => write!(f, "zeta1"),
337 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
338 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
339 PredictionProvider::TeacherNonBatching(backend) => {
340 write!(f, "teacher-non-batching:{backend}")
341 }
342 PredictionProvider::Repair => write!(f, "repair"),
343 }
344 }
345}
346
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    /// Parses a provider spec of the form `name` or `name:arg`, where `arg`
    /// is a `ZetaVersion` for `zeta2` and a `TeacherBackend` for the teacher
    /// variants. The name is matched case-insensitively; the arg keeps its
    /// original casing and is parsed by the respective type.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the first ':' into (provider, optional argument).
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                // A missing version falls back to ZetaVersion::default().
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                // A missing backend defaults to Sonnet 4.5.
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}
389
390impl Serialize for PredictionProvider {
391 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
392 where
393 S: Serializer,
394 {
395 serializer.serialize_str(&self.to_string())
396 }
397}
398
399impl<'de> Deserialize<'de> for PredictionProvider {
400 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
401 where
402 D: Deserializer<'de>,
403 {
404 let s = String::deserialize(deserializer)?;
405 s.parse().map_err(serde::de::Error::custom)
406 }
407}
408
// Arguments for `synthesize`; converted into a `SynthesizeConfig` in `main`
// (the output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
427
// Arguments for `import-batch`; handled entirely in `main`'s `ImportBatch`
// arm, before the headless app is started.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
437
// Which LLM batch API `import-batch` talks to; clap parses the kebab-case
// variant names on the command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
443
444impl EpArgs {
445 fn output_path(&self) -> Option<PathBuf> {
446 if self.in_place {
447 if self.inputs.len() == 1 {
448 self.inputs.first().cloned()
449 } else {
450 panic!("--in-place requires exactly one input file")
451 }
452 } else {
453 self.output.clone()
454 }
455 }
456}
457
458async fn load_examples(
459 http_client: Arc<dyn http_client::HttpClient>,
460 args: &EpArgs,
461 output_path: Option<&PathBuf>,
462 background_executor: BackgroundExecutor,
463) -> anyhow::Result<Vec<Example>> {
464 let mut captured_after_timestamps = Vec::new();
465 let mut rejected_after_timestamps = Vec::new();
466 let mut requested_after_timestamps = Vec::new();
467 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
468 Vec::new();
469 let mut file_inputs = Vec::new();
470
471 for input in &args.inputs {
472 let input_string = input.to_string_lossy();
473 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
474 captured_after_timestamps.push(timestamp.to_string());
475 } else if let Some(timestamp) =
476 pull_examples::parse_rejected_after_input(input_string.as_ref())
477 {
478 rejected_after_timestamps.push(timestamp.to_string());
479 } else if let Some(timestamp) =
480 pull_examples::parse_requested_after_input(input_string.as_ref())
481 {
482 requested_after_timestamps.push(timestamp.to_string());
483 } else if let Some((timestamp, rating_filter)) =
484 pull_examples::parse_rated_after_input(input_string.as_ref())
485 {
486 rated_after_inputs.push((timestamp.to_string(), rating_filter));
487 } else {
488 file_inputs.push(input.clone());
489 }
490 }
491
492 let mut examples = read_example_files(&file_inputs);
493
494 Progress::global().set_total_examples(examples.len());
495
496 let remaining_limit_for_snowflake =
497 args.limit.map(|limit| limit.saturating_sub(examples.len()));
498
499 if let Some(0) = remaining_limit_for_snowflake {
500 log::info!(
501 "skipping Snowflake inputs because --limit is already satisfied by example files"
502 );
503 } else {
504 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
505
506 if !captured_after_timestamps.is_empty() {
507 captured_after_timestamps.sort();
508
509 let mut captured_examples = pull_examples::fetch_captured_examples_after(
510 http_client.clone(),
511 &captured_after_timestamps,
512 max_rows_per_timestamp,
513 background_executor.clone(),
514 )
515 .await?;
516 examples.append(&mut captured_examples);
517 }
518
519 if !rejected_after_timestamps.is_empty() {
520 rejected_after_timestamps.sort();
521
522 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
523 http_client.clone(),
524 &rejected_after_timestamps,
525 max_rows_per_timestamp,
526 background_executor.clone(),
527 )
528 .await?;
529 examples.append(&mut rejected_examples);
530 }
531
532 if !requested_after_timestamps.is_empty() {
533 requested_after_timestamps.sort();
534
535 let mut requested_examples = pull_examples::fetch_requested_examples_after(
536 http_client.clone(),
537 &requested_after_timestamps,
538 max_rows_per_timestamp,
539 background_executor.clone(),
540 )
541 .await?;
542 examples.append(&mut requested_examples);
543 }
544
545 if !rated_after_inputs.is_empty() {
546 rated_after_inputs.sort();
547
548 let mut rated_examples = pull_examples::fetch_rated_examples_after(
549 http_client,
550 &rated_after_inputs,
551 max_rows_per_timestamp,
552 background_executor,
553 )
554 .await?;
555 examples.append(&mut rated_examples);
556 }
557 }
558
559 crate::example::sort_examples_by_repo_and_rev(&mut examples);
560
561 if let Some(name_filter) = &args.name {
562 examples.retain(|example| example.spec.name.contains(name_filter));
563 }
564 if let Some(repo_filter) = &args.repo {
565 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
566 }
567
568 // Skip resume logic for --in-place since input and output are the same file,
569 // which would incorrectly treat all input examples as already processed.
570 if !args.in_place {
571 if let Some(path) = output_path {
572 resume_from_output(path, &mut examples);
573 }
574 }
575
576 if let Some(offset) = args.offset {
577 examples.splice(0..offset, []);
578 }
579
580 if let Some(limit) = args.limit {
581 examples.truncate(limit);
582 }
583
584 let progress = Progress::global();
585 progress.set_total_examples(examples.len());
586 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
587
588 Ok(examples)
589}
590
591fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
592 let mut hasher = collections::FxHasher::default();
593 spec.hash(&mut hasher);
594 hasher.finish()
595}
596
597fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
598 let file = match File::open(path) {
599 Ok(f) => f,
600 Err(_) => return,
601 };
602
603 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
604
605 let reader = BufReader::new(file);
606 let mut kept_lines = Vec::new();
607 let mut kept_hashes = HashSet::default();
608
609 for line in reader.lines() {
610 let line = match line {
611 Ok(l) => l,
612 Err(_) => continue,
613 };
614
615 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
616 let hash = spec_hash(&output_example.spec);
617 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
618 kept_hashes.insert(hash);
619 kept_lines.push(line);
620 }
621 }
622 }
623
624 let total = examples.len();
625 let already_processed = kept_hashes.len();
626
627 eprintln!(
628 "Resuming: {}/{} examples already processed",
629 already_processed, total
630 );
631
632 let file = OpenOptions::new()
633 .write(true)
634 .truncate(true)
635 .open(path)
636 .expect("Failed to open output file for rewriting");
637 let mut writer = BufWriter::new(file);
638 for line in &kept_lines {
639 writeln!(writer, "{}", line).expect("Failed to write to output file");
640 }
641 writer.flush().expect("Failed to flush output file");
642
643 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
644}
645
/// Entry point.
///
/// Parses CLI args, short-circuits self-contained subcommands (import-batch,
/// clean, synthesize, split-commit, split, filter-languages, qa, repair) on a
/// `smol` executor, and runs everything else inside a headless GPUI
/// application: examples are loaded, grouped by repository, and processed by
/// up to `--max-parallelism` concurrent workers.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that manage their own I/O and don't need the example
    // pipeline below are handled here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cloned repositories/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining subcommands run inside a headless GPUI application so they
    // can use projects, language servers, and background executors.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending batch state before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so its error surfaces
                // immediately (see `handle_error`).
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is renamed over
                    // the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Dedicated writer task: workers send serialized lines
                    // over a channel, serializing access to the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: each entry is all examples for one repository.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly pops
                // a repository group and processes its examples sequentially.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // These were dispatched before the
                                            // headless app was started.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        // One .md file per example, named
                                        // after the example.
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: stream JSONL to stdout
                                        // (except eval, which prints a report).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down the repository's project state before
                            // moving on: stop language servers, detach the
                            // project from the prediction store, drop caches.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batch state again after all workers finish.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints its aggregated report at the end of the run.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1099
/// Records a failed example and either panics (fail-fast) or logs and
/// continues.
///
/// Unless `--failed=skip-no-files`, the example and its error are written to
/// the run's failed/ directory (`<name>.json` + `<name>_err.txt`), appended
/// to the run-wide `failed.jsonl`, and the message includes a ready-to-paste
/// re-run command. Fail-fast triggers when `--failfast` was passed or when
/// the run had exactly one example.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Persist the failing example itself so it can be re-run directly.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // And the error's debug representation alongside it.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failed.jsonl for batch inspection.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Best-effort path to the cursor file inside the example's worktree,
        // falling back to the raw spec path when the repo name is unavailable.
        let cursor_path = match example.repo_name() {
            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
        };
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // skip-no-files: no artifacts on disk, just the message.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}