1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaVersion;
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may appear before or after the subcommand.
// NOTE: clap turns `///` doc comments on this struct and its fields into help
// text, so reviewer documentation is kept as `//` comments to leave the
// generated `--help` output unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (handled first in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent workers processing per-repository example groups.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on the number of examples processed; also caps Snowflake fetches.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied after filtering/resume).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (JSONL) or directory (with --markdown).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
94
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the `///` variant docs below double as clap's `--failed` help text;
// reviewer notes are `//` comments so the help output is unchanged.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (failed example JSON, error text, failed.jsonl) written by handle_error.
    SkipNoFiles,
}
107
// Long-form help for the positional `inputs` argument, attached via
// `help = INPUTS_HELP` on `EpArgs::inputs`. This is a runtime string shown to
// users by clap — its content and indentation are user-visible, edit with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
165
// Subcommands of `ep`. The `///` doc comments feed clap's generated help, so
// reviewer notes are kept as `//` comments. The `Display` impl below renders
// these as re-run hints and must stay in sync when variants are added.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
205
206impl Display for Command {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 match self {
209 Command::Read => write!(f, "read"),
210 Command::LoadProject => write!(f, "load-project"),
211 Command::Context => write!(f, "context"),
212 Command::FormatPrompt(args) => {
213 write!(f, "format-prompt --provider={}", args.provider)
214 }
215 Command::Predict(args) => match &args.provider {
216 Some(provider) => write!(f, "predict --provider={}", provider),
217 None => write!(f, "predict"),
218 },
219 Command::ParseOutput => write!(f, "parse-output"),
220 Command::Score(args) => match &args.provider {
221 Some(provider) => write!(f, "score --provider={}", provider),
222 None => write!(f, "score"),
223 },
224 Command::Distill => write!(f, "distill"),
225 Command::Eval(args) => match &args.predict.provider {
226 Some(provider) => write!(f, "eval --provider={}", provider),
227 None => write!(f, "eval"),
228 },
229 Command::Synthesize(args) => {
230 write!(f, "synthesize --repos {}", args.repos.join(" "))
231 }
232 Command::Clean => write!(f, "clean"),
233 Command::SplitCommit(_) => write!(f, "split-commit"),
234 Command::Split(_) => write!(f, "split"),
235 Command::FilterLanguages(_) => write!(f, "filter-languages"),
236 Command::ImportBatch(args) => {
237 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
238 }
239 Command::Qa(_) => {
240 write!(f, "qa")
241 }
242 Command::Repair(_) => {
243 write!(f, "repair")
244 }
245 }
246 }
247}
248
// Arguments for `ep format-prompt`. (`//` comments only — `///` would alter
// clap's generated help.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to zeta2 at the default
    // ZetaVersion (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
254
// Arguments shared by `ep predict` and `ep score` (and, flattened, `ep eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; when omitted, selection is left to the
    // prediction pipeline. NOTE(review): exact fallback behavior lives in
    // predict::run_prediction — not visible here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
265
// Arguments for `ep eval`: the full prediction option set plus an optional
// JSON summary destination consumed after scoring (see main's Eval handling).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
274
/// Teacher model used by the `teacher` / `teacher-non-batching` providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`claude-sonnet-4-5`, see `model_name`).
    Sonnet45,
    /// OpenAI GPT-5.2 (`gpt-5.2`, see `model_name`).
    Gpt52,
}
280
281impl std::fmt::Display for TeacherBackend {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 match self {
284 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
285 TeacherBackend::Gpt52 => write!(f, "gpt52"),
286 }
287 }
288}
289
290impl std::str::FromStr for TeacherBackend {
291 type Err = anyhow::Error;
292
293 fn from_str(s: &str) -> Result<Self, Self::Err> {
294 match s.to_lowercase().as_str() {
295 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
296 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
297 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
298 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
299 }
300 }
301}
302
303impl TeacherBackend {
304 pub fn model_name(&self) -> &'static str {
305 match self {
306 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
307 TeacherBackend::Gpt52 => "gpt-5.2",
308 }
309 }
310}
311
/// Which backend produces edit predictions. Serialized/parsed as a string of
/// the form `provider` or `provider:arg` (see the `Display`/`FromStr` impls).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta2 carries a prompt-format version (`zeta2:<version>`).
    Zeta2(ZetaVersion),
    // Teacher model used via the batching API path.
    Teacher(TeacherBackend),
    // Teacher model used with direct (non-batched) requests.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
322
323impl Default for PredictionProvider {
324 fn default() -> Self {
325 PredictionProvider::Zeta2(ZetaVersion::default())
326 }
327}
328
329impl std::fmt::Display for PredictionProvider {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 match self {
332 PredictionProvider::Sweep => write!(f, "sweep"),
333 PredictionProvider::Mercury => write!(f, "mercury"),
334 PredictionProvider::Zeta1 => write!(f, "zeta1"),
335 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
336 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
337 PredictionProvider::TeacherNonBatching(backend) => {
338 write!(f, "teacher-non-batching:{backend}")
339 }
340 PredictionProvider::Repair => write!(f, "repair"),
341 }
342 }
343}
344
345impl std::str::FromStr for PredictionProvider {
346 type Err = anyhow::Error;
347
348 fn from_str(s: &str) -> Result<Self, Self::Err> {
349 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
350
351 let provider_lower = provider.to_lowercase();
352 match provider_lower.as_str() {
353 "sweep" => Ok(PredictionProvider::Sweep),
354 "mercury" => Ok(PredictionProvider::Mercury),
355 "zeta1" => Ok(PredictionProvider::Zeta1),
356 "zeta2" => {
357 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
358 Ok(PredictionProvider::Zeta2(version))
359 }
360 "teacher" => {
361 let backend = arg
362 .map(|a| a.parse())
363 .transpose()?
364 .unwrap_or(TeacherBackend::Sonnet45);
365 Ok(PredictionProvider::Teacher(backend))
366 }
367 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
368 let backend = arg
369 .map(|a| a.parse())
370 .transpose()?
371 .unwrap_or(TeacherBackend::Sonnet45);
372 Ok(PredictionProvider::TeacherNonBatching(backend))
373 }
374 "repair" => Ok(PredictionProvider::Repair),
375 _ => {
376 anyhow::bail!(
377 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
378 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
379 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
380 Available zeta versions:\n{}",
381 ZetaVersion::options_as_string()
382 )
383 }
384 }
385 }
386}
387
388impl Serialize for PredictionProvider {
389 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
390 where
391 S: Serializer,
392 {
393 serializer.serialize_str(&self.to_string())
394 }
395}
396
397impl<'de> Deserialize<'de> for PredictionProvider {
398 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
399 where
400 D: Deserializer<'de>,
401 {
402 let s = String::deserialize(deserializer)?;
403 s.parse().map_err(serde::de::Error::custom)
404 }
405}
406
// Arguments for `ep synthesize` (converted into a SynthesizeConfig in main;
// the output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
425
// Arguments for `ep import-batch`, which re-imports completed LLM batch
// results into the local cache (see the ImportBatch arm in main).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
435
// Which vendor's batch API `import-batch` targets. ValueEnum derives the clap
// values from the variant names; `//` comments keep the help text unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
441
442impl EpArgs {
443 fn output_path(&self) -> Option<PathBuf> {
444 if self.in_place {
445 if self.inputs.len() == 1 {
446 self.inputs.first().cloned()
447 } else {
448 panic!("--in-place requires exactly one input file")
449 }
450 } else {
451 self.output.clone()
452 }
453 }
454}
455
456async fn load_examples(
457 http_client: Arc<dyn http_client::HttpClient>,
458 args: &EpArgs,
459 output_path: Option<&PathBuf>,
460 background_executor: BackgroundExecutor,
461) -> anyhow::Result<Vec<Example>> {
462 let mut captured_after_timestamps = Vec::new();
463 let mut rejected_after_timestamps = Vec::new();
464 let mut requested_after_timestamps = Vec::new();
465 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
466 Vec::new();
467 let mut file_inputs = Vec::new();
468
469 for input in &args.inputs {
470 let input_string = input.to_string_lossy();
471 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
472 captured_after_timestamps.push(timestamp.to_string());
473 } else if let Some(timestamp) =
474 pull_examples::parse_rejected_after_input(input_string.as_ref())
475 {
476 rejected_after_timestamps.push(timestamp.to_string());
477 } else if let Some(timestamp) =
478 pull_examples::parse_requested_after_input(input_string.as_ref())
479 {
480 requested_after_timestamps.push(timestamp.to_string());
481 } else if let Some((timestamp, rating_filter)) =
482 pull_examples::parse_rated_after_input(input_string.as_ref())
483 {
484 rated_after_inputs.push((timestamp.to_string(), rating_filter));
485 } else {
486 file_inputs.push(input.clone());
487 }
488 }
489
490 let mut examples = read_example_files(&file_inputs);
491
492 Progress::global().set_total_examples(examples.len());
493
494 let remaining_limit_for_snowflake =
495 args.limit.map(|limit| limit.saturating_sub(examples.len()));
496
497 if let Some(0) = remaining_limit_for_snowflake {
498 log::info!(
499 "skipping Snowflake inputs because --limit is already satisfied by example files"
500 );
501 } else {
502 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
503
504 if !captured_after_timestamps.is_empty() {
505 captured_after_timestamps.sort();
506
507 let mut captured_examples = pull_examples::fetch_captured_examples_after(
508 http_client.clone(),
509 &captured_after_timestamps,
510 max_rows_per_timestamp,
511 background_executor.clone(),
512 )
513 .await?;
514 examples.append(&mut captured_examples);
515 }
516
517 if !rejected_after_timestamps.is_empty() {
518 rejected_after_timestamps.sort();
519
520 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
521 http_client.clone(),
522 &rejected_after_timestamps,
523 max_rows_per_timestamp,
524 background_executor.clone(),
525 )
526 .await?;
527 examples.append(&mut rejected_examples);
528 }
529
530 if !requested_after_timestamps.is_empty() {
531 requested_after_timestamps.sort();
532
533 let mut requested_examples = pull_examples::fetch_requested_examples_after(
534 http_client.clone(),
535 &requested_after_timestamps,
536 max_rows_per_timestamp,
537 background_executor.clone(),
538 )
539 .await?;
540 examples.append(&mut requested_examples);
541 }
542
543 if !rated_after_inputs.is_empty() {
544 rated_after_inputs.sort();
545
546 let mut rated_examples = pull_examples::fetch_rated_examples_after(
547 http_client,
548 &rated_after_inputs,
549 max_rows_per_timestamp,
550 background_executor,
551 )
552 .await?;
553 examples.append(&mut rated_examples);
554 }
555 }
556
557 crate::example::sort_examples_by_repo_and_rev(&mut examples);
558
559 if let Some(name_filter) = &args.name {
560 examples.retain(|example| example.spec.name.contains(name_filter));
561 }
562 if let Some(repo_filter) = &args.repo {
563 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
564 }
565
566 // Skip resume logic for --in-place since input and output are the same file,
567 // which would incorrectly treat all input examples as already processed.
568 if !args.in_place {
569 if let Some(path) = output_path {
570 resume_from_output(path, &mut examples);
571 }
572 }
573
574 if let Some(offset) = args.offset {
575 examples.splice(0..offset, []);
576 }
577
578 if let Some(limit) = args.limit {
579 examples.truncate(limit);
580 }
581
582 let progress = Progress::global();
583 progress.set_total_examples(examples.len());
584 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
585
586 Ok(examples)
587}
588
// Fingerprint of an example spec, used as the de-duplication / resume key by
// `resume_from_output`.
// NOTE(review): FxHasher output is not guaranteed stable across toolchain or
// crate versions — acceptable here because hashes are only compared against
// data produced within compatible runs; confirm if persisted longer-term.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
594
595fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
596 let file = match File::open(path) {
597 Ok(f) => f,
598 Err(_) => return,
599 };
600
601 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
602
603 let reader = BufReader::new(file);
604 let mut kept_lines = Vec::new();
605 let mut kept_hashes = HashSet::default();
606
607 for line in reader.lines() {
608 let line = match line {
609 Ok(l) => l,
610 Err(_) => continue,
611 };
612
613 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
614 let hash = spec_hash(&output_example.spec);
615 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
616 kept_hashes.insert(hash);
617 kept_lines.push(line);
618 }
619 }
620 }
621
622 let total = examples.len();
623 let already_processed = kept_hashes.len();
624
625 eprintln!(
626 "Resuming: {}/{} examples already processed",
627 already_processed, total
628 );
629
630 let file = OpenOptions::new()
631 .write(true)
632 .truncate(true)
633 .open(path)
634 .expect("Failed to open output file for rewriting");
635 let mut writer = BufWriter::new(file);
636 for line in &kept_lines {
637 writeln!(writer, "{}", line).expect("Failed to write to output file");
638 }
639 writer.flush().expect("Failed to flush output file");
640
641 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
642}
643
/// CLI entry point.
///
/// Self-contained subcommands (`import-batch`, `clean`, `synthesize`,
/// `split-commit`, `split`, `filter-languages`, `qa`, `repair`) are handled
/// synchronously below and return early. All remaining pipeline commands run
/// inside a headless GPUI application: examples are grouped by repository and
/// processed by up to `--max-parallelism` concurrent workers.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolves -o / --in-place into a single output target (panics when
    // --in-place is used without exactly one input).
    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that need neither the headless app nor project worktrees are
    // dispatched here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Pipeline commands run inside a headless GPUI app so projects, language
    // servers, and the edit-prediction store are available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any completed LLM batch results before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example gets failfast behavior so its error surfaces.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is renamed over the
                    // original only after the run completes successfully.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; workers send it
                    // serialized lines over an unbounded channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue keyed by repository: each worker takes a whole
                // repo group so a repo's examples are processed sequentially.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // These returned early in the dispatch above.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples still reach the main output
                                // unless --failed=skip / skip-no-files.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except
                                        // eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language
                            // servers and drop cached project state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush batch state accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1097
1098async fn handle_error(
1099 error: anyhow::Error,
1100 args: &EpArgs,
1101 command: &Command,
1102 app_state: &Arc<headless::EpAppState>,
1103 failfast_on_single_example: bool,
1104 example: &Example,
1105) {
1106 Progress::global().increment_failed();
1107
1108 let msg;
1109 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1110 let example_name = example.spec.filename();
1111
1112 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1113 app_state
1114 .fs
1115 .write(
1116 &failed_example_path,
1117 &serde_json::to_vec_pretty(&example).unwrap(),
1118 )
1119 .await
1120 .unwrap();
1121 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1122 app_state
1123 .fs
1124 .write(&err_path, format!("{error:?}").as_bytes())
1125 .await
1126 .unwrap();
1127
1128 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1129 let mut file = OpenOptions::new()
1130 .create(true)
1131 .append(true)
1132 .open(&failed_jsonl_path)
1133 .expect("Failed to open failed.jsonl");
1134 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1135 .expect("Failed to write to failed.jsonl");
1136
1137 let cursor_path = match example.repo_name() {
1138 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1139 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1140 };
1141 msg = format!(
1142 indoc::indoc! {"
1143 While processing \"{}\":
1144
1145 \x1b[31m{:?}\x1b[0m
1146
1147 Example: \x1b[36m{}\x1b[0m
1148 Error file: \x1b[36m{}\x1b[0m
1149 Cursor file: \x1b[36m{}\x1b[0m
1150 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1151 "},
1152 example.spec.name,
1153 error,
1154 failed_example_path.display(),
1155 err_path.display(),
1156 cursor_path.display(),
1157 command,
1158 failed_example_path.display(),
1159 );
1160 } else {
1161 msg = format!(
1162 indoc::indoc! {"
1163 While processing \"{}\":
1164
1165 \x1b[31m{:?}\x1b[0m
1166 "},
1167 example.spec.name, error
1168 );
1169 }
1170
1171 if args.failfast || failfast_on_single_example {
1172 Progress::global().finalize();
1173 panic!("{}", msg);
1174 } else {
1175 log::error!("{}", msg);
1176 }
1177}