1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod pull_examples;
16mod qa;
17mod reorder_patch;
18mod repair;
19mod retrieve_context;
20mod reversal_tracking;
21mod score;
22mod split_commit;
23mod split_dataset;
24mod synthesize;
25mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaVersion;

use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
56
// Top-level CLI arguments for the `ep` binary. Flags marked `global = true`
// apply to every subcommand; positional `inputs` accept file paths or
// Snowflake specifiers (see INPUTS_HELP). Comments here use `//` rather than
// `///` so clap's generated help text is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (checked first in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Maximum number of examples to process (applied after filtering/resume).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip, applied after filtering.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Also suppresses the per-example failure artifacts written by `handle_error`
    // (failed example JSON, error text, and the failed.jsonl entry).
    SkipNoFiles,
}
106
// Long-form help for the positional `inputs` argument of `EpArgs`
// (attached via `#[clap(help = INPUTS_HELP)]`). This is user-facing runtime
// text: edit with care.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
164
// All `ep` subcommands. The `///` doc comments double as clap help text, so
// they are left untouched; the Display impl below renders each variant as the
// CLI string used in "Re-run" hints.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
204
205impl Display for Command {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 match self {
208 Command::Read => write!(f, "read"),
209 Command::LoadProject => write!(f, "load-project"),
210 Command::Context => write!(f, "context"),
211 Command::FormatPrompt(args) => {
212 write!(f, "format-prompt --provider={}", args.provider)
213 }
214 Command::Predict(args) => match &args.provider {
215 Some(provider) => write!(f, "predict --provider={}", provider),
216 None => write!(f, "predict"),
217 },
218 Command::ParseOutput => write!(f, "parse-output"),
219 Command::Score(args) => match &args.provider {
220 Some(provider) => write!(f, "score --provider={}", provider),
221 None => write!(f, "score"),
222 },
223 Command::Distill => write!(f, "distill"),
224 Command::Eval(args) => match &args.predict.provider {
225 Some(provider) => write!(f, "eval --provider={}", provider),
226 None => write!(f, "eval"),
227 },
228 Command::Synthesize(args) => {
229 write!(f, "synthesize --repos {}", args.repos.join(" "))
230 }
231 Command::Clean => write!(f, "clean"),
232 Command::SplitCommit(_) => write!(f, "split-commit"),
233 Command::Split(_) => write!(f, "split"),
234 Command::FilterLanguages(_) => write!(f, "filter-languages"),
235 Command::ImportBatch(args) => {
236 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
237 }
238 Command::Qa(_) => {
239 write!(f, "qa")
240 }
241 Command::Repair(_) => {
242 write!(f, "repair")
243 }
244 }
245 }
246}
247
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to the default
    // zeta2 version (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
253
// Arguments shared by `ep predict`, `ep score`, and (via `EvalArgs`) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run predictions with; `None` means "use whatever the
    // example/prompt already specifies" (see the predict/score pipelines).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of times to repeat each prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
261
// Arguments for `ep eval`: the predict/score options plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
270
/// Teacher model backing the `teacher` / `teacher-non-batching` providers.
/// `model_name` maps each variant to the upstream model identifier.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 ("claude-sonnet-4-5").
    Sonnet45,
    /// OpenAI GPT-5.2 ("gpt-5.2").
    Gpt52,
}
276
277impl std::fmt::Display for TeacherBackend {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 match self {
280 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
281 TeacherBackend::Gpt52 => write!(f, "gpt52"),
282 }
283 }
284}
285
286impl std::str::FromStr for TeacherBackend {
287 type Err = anyhow::Error;
288
289 fn from_str(s: &str) -> Result<Self, Self::Err> {
290 match s.to_lowercase().as_str() {
291 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
292 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
293 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
294 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
295 }
296 }
297}
298
299impl TeacherBackend {
300 pub fn model_name(&self) -> &'static str {
301 match self {
302 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
303 TeacherBackend::Gpt52 => "gpt-5.2",
304 }
305 }
306}
307
/// Prediction backend selected on the command line. Parsed from strings like
/// `zeta2:<version>` or `teacher:<backend>` (see the `FromStr` impl) and
/// serialized back via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt-format version to use.
    Zeta2(ZetaVersion),
    // Teacher variants carry the backing model. NOTE(review): presumably the
    // non-batching variant bypasses the provider's batch API — confirm in predict.rs.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
318
// Default provider: zeta2 at the default prompt version.
impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}
324
325impl std::fmt::Display for PredictionProvider {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 match self {
328 PredictionProvider::Sweep => write!(f, "sweep"),
329 PredictionProvider::Mercury => write!(f, "mercury"),
330 PredictionProvider::Zeta1 => write!(f, "zeta1"),
331 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
332 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
333 PredictionProvider::TeacherNonBatching(backend) => {
334 write!(f, "teacher-non-batching:{backend}")
335 }
336 PredictionProvider::Repair => write!(f, "repair"),
337 }
338 }
339}
340
341impl std::str::FromStr for PredictionProvider {
342 type Err = anyhow::Error;
343
344 fn from_str(s: &str) -> Result<Self, Self::Err> {
345 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
346
347 let provider_lower = provider.to_lowercase();
348 match provider_lower.as_str() {
349 "sweep" => Ok(PredictionProvider::Sweep),
350 "mercury" => Ok(PredictionProvider::Mercury),
351 "zeta1" => Ok(PredictionProvider::Zeta1),
352 "zeta2" => {
353 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
354 Ok(PredictionProvider::Zeta2(version))
355 }
356 "teacher" => {
357 let backend = arg
358 .map(|a| a.parse())
359 .transpose()?
360 .unwrap_or(TeacherBackend::Sonnet45);
361 Ok(PredictionProvider::Teacher(backend))
362 }
363 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
364 let backend = arg
365 .map(|a| a.parse())
366 .transpose()?
367 .unwrap_or(TeacherBackend::Sonnet45);
368 Ok(PredictionProvider::TeacherNonBatching(backend))
369 }
370 "repair" => Ok(PredictionProvider::Repair),
371 _ => {
372 anyhow::bail!(
373 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
374 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
375 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
376 Available zeta versions:\n{}",
377 ZetaVersion::options_as_string()
378 )
379 }
380 }
381 }
382}
383
384impl Serialize for PredictionProvider {
385 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
386 where
387 S: Serializer,
388 {
389 serializer.serialize_str(&self.to_string())
390 }
391}
392
393impl<'de> Deserialize<'de> for PredictionProvider {
394 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
395 where
396 D: Deserializer<'de>,
397 {
398 let s = String::deserialize(deserializer)?;
399 s.parse().map_err(serde::de::Error::custom)
400 }
401}
402
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in `main`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
421
// Arguments for `ep import-batch`.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
431
// Provider whose batch API `ep import-batch` should talk to
// (clap value enum: `anthropic` / `openai`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
437
438impl EpArgs {
439 fn output_path(&self) -> Option<PathBuf> {
440 if self.in_place {
441 if self.inputs.len() == 1 {
442 self.inputs.first().cloned()
443 } else {
444 panic!("--in-place requires exactly one input file")
445 }
446 } else {
447 self.output.clone()
448 }
449 }
450}
451
452async fn load_examples(
453 http_client: Arc<dyn http_client::HttpClient>,
454 args: &EpArgs,
455 output_path: Option<&PathBuf>,
456 background_executor: BackgroundExecutor,
457) -> anyhow::Result<Vec<Example>> {
458 let mut captured_after_timestamps = Vec::new();
459 let mut rejected_after_timestamps = Vec::new();
460 let mut requested_after_timestamps = Vec::new();
461 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
462 Vec::new();
463 let mut file_inputs = Vec::new();
464
465 for input in &args.inputs {
466 let input_string = input.to_string_lossy();
467 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
468 captured_after_timestamps.push(timestamp.to_string());
469 } else if let Some(timestamp) =
470 pull_examples::parse_rejected_after_input(input_string.as_ref())
471 {
472 rejected_after_timestamps.push(timestamp.to_string());
473 } else if let Some(timestamp) =
474 pull_examples::parse_requested_after_input(input_string.as_ref())
475 {
476 requested_after_timestamps.push(timestamp.to_string());
477 } else if let Some((timestamp, rating_filter)) =
478 pull_examples::parse_rated_after_input(input_string.as_ref())
479 {
480 rated_after_inputs.push((timestamp.to_string(), rating_filter));
481 } else {
482 file_inputs.push(input.clone());
483 }
484 }
485
486 let mut examples = read_example_files(&file_inputs);
487
488 Progress::global().set_total_examples(examples.len());
489
490 let remaining_limit_for_snowflake =
491 args.limit.map(|limit| limit.saturating_sub(examples.len()));
492
493 if let Some(0) = remaining_limit_for_snowflake {
494 log::info!(
495 "skipping Snowflake inputs because --limit is already satisfied by example files"
496 );
497 } else {
498 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
499
500 if !captured_after_timestamps.is_empty() {
501 captured_after_timestamps.sort();
502
503 let mut captured_examples = pull_examples::fetch_captured_examples_after(
504 http_client.clone(),
505 &captured_after_timestamps,
506 max_rows_per_timestamp,
507 background_executor.clone(),
508 )
509 .await?;
510 examples.append(&mut captured_examples);
511 }
512
513 if !rejected_after_timestamps.is_empty() {
514 rejected_after_timestamps.sort();
515
516 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
517 http_client.clone(),
518 &rejected_after_timestamps,
519 max_rows_per_timestamp,
520 background_executor.clone(),
521 )
522 .await?;
523 examples.append(&mut rejected_examples);
524 }
525
526 if !requested_after_timestamps.is_empty() {
527 requested_after_timestamps.sort();
528
529 let mut requested_examples = pull_examples::fetch_requested_examples_after(
530 http_client.clone(),
531 &requested_after_timestamps,
532 max_rows_per_timestamp,
533 background_executor.clone(),
534 )
535 .await?;
536 examples.append(&mut requested_examples);
537 }
538
539 if !rated_after_inputs.is_empty() {
540 rated_after_inputs.sort();
541
542 let mut rated_examples = pull_examples::fetch_rated_examples_after(
543 http_client,
544 &rated_after_inputs,
545 max_rows_per_timestamp,
546 background_executor,
547 )
548 .await?;
549 examples.append(&mut rated_examples);
550 }
551 }
552
553 crate::example::sort_examples_by_repo_and_rev(&mut examples);
554
555 if let Some(name_filter) = &args.name {
556 examples.retain(|example| example.spec.name.contains(name_filter));
557 }
558 if let Some(repo_filter) = &args.repo {
559 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
560 }
561
562 // Skip resume logic for --in-place since input and output are the same file,
563 // which would incorrectly treat all input examples as already processed.
564 if !args.in_place {
565 if let Some(path) = output_path {
566 resume_from_output(path, &mut examples);
567 }
568 }
569
570 if let Some(offset) = args.offset {
571 examples.splice(0..offset, []);
572 }
573
574 if let Some(limit) = args.limit {
575 examples.truncate(limit);
576 }
577
578 let progress = Progress::global();
579 progress.set_total_examples(examples.len());
580 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
581
582 Ok(examples)
583}
584
585fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
586 let mut hasher = collections::FxHasher::default();
587 spec.hash(&mut hasher);
588 hasher.finish()
589}
590
591fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
592 let file = match File::open(path) {
593 Ok(f) => f,
594 Err(_) => return,
595 };
596
597 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
598
599 let reader = BufReader::new(file);
600 let mut kept_lines = Vec::new();
601 let mut kept_hashes = HashSet::default();
602
603 for line in reader.lines() {
604 let line = match line {
605 Ok(l) => l,
606 Err(_) => continue,
607 };
608
609 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
610 let hash = spec_hash(&output_example.spec);
611 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
612 kept_hashes.insert(hash);
613 kept_lines.push(line);
614 }
615 }
616 }
617
618 let total = examples.len();
619 let already_processed = kept_hashes.len();
620
621 eprintln!(
622 "Resuming: {}/{} examples already processed",
623 already_processed, total
624 );
625
626 let file = OpenOptions::new()
627 .write(true)
628 .truncate(true)
629 .open(path)
630 .expect("Failed to open output file for rewriting");
631 let mut writer = BufWriter::new(file);
632 for line in &kept_lines {
633 writeln!(writer, "{}", line).expect("Failed to write to output file");
634 }
635 writer.flush().expect("Failed to flush output file");
636
637 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
638}
639
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless GPUI application are handled
    // synchronously here and return early; everything else falls through to
    // the example-processing pipeline in `app.run` below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all repos/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): `splice` panics if offset > examples.len().
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): `splice` panics if offset > examples.len().
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands run inside the headless GPUI app so they can use
    // projects, language servers, and the edit-prediction store.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any finished provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // With a single example, always fail fast so errors surface.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file which is renamed over
                    // the original after all tasks finish (see bottom).
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background writer task; workers send finished
                    // JSONL lines over the channel so writes are serialized.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repository so each worker
                // owns one repo (and its project state) at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch to the per-command pipeline stage.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // Handled synchronously before app.run.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written unless
                                // --failed=skip/skip-no-files was given.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: emit JSONL to stdout (except for
                                        // eval, which prints a report at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down language servers and cached project state
                            // for this repo before moving to the next group.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            // Drop per-example project state before archiving.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush provider batches accumulated during this run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1093
1094async fn handle_error(
1095 error: anyhow::Error,
1096 args: &EpArgs,
1097 command: &Command,
1098 app_state: &Arc<headless::EpAppState>,
1099 failfast_on_single_example: bool,
1100 example: &Example,
1101) {
1102 Progress::global().increment_failed();
1103
1104 let msg;
1105 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1106 let example_name = example.spec.filename();
1107
1108 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1109 app_state
1110 .fs
1111 .write(
1112 &failed_example_path,
1113 &serde_json::to_vec_pretty(&example).unwrap(),
1114 )
1115 .await
1116 .unwrap();
1117 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1118 app_state
1119 .fs
1120 .write(&err_path, format!("{error:?}").as_bytes())
1121 .await
1122 .unwrap();
1123
1124 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1125 let mut file = OpenOptions::new()
1126 .create(true)
1127 .append(true)
1128 .open(&failed_jsonl_path)
1129 .expect("Failed to open failed.jsonl");
1130 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1131 .expect("Failed to write to failed.jsonl");
1132
1133 let cursor_path = example
1134 .repo_name()
1135 .unwrap()
1136 .worktree_path()
1137 .join(&example.spec.cursor_path);
1138 msg = format!(
1139 indoc::indoc! {"
1140 While processing \"{}\":
1141
1142 \x1b[31m{:?}\x1b[0m
1143
1144 Example: \x1b[36m{}\x1b[0m
1145 Error file: \x1b[36m{}\x1b[0m
1146 Cursor file: \x1b[36m{}\x1b[0m
1147 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1148 "},
1149 example.spec.name,
1150 error,
1151 failed_example_path.display(),
1152 err_path.display(),
1153 cursor_path.display(),
1154 command,
1155 failed_example_path.display(),
1156 );
1157 } else {
1158 msg = format!(
1159 indoc::indoc! {"
1160 While processing \"{}\":
1161
1162 \x1b[31m{:?}\x1b[0m
1163 "},
1164 example.spec.name, error
1165 );
1166 }
1167
1168 if args.failfast || failfast_on_single_example {
1169 Progress::global().finalize();
1170 panic!("{}", msg);
1171 } else {
1172 log::error!("{}", msg);
1173 }
1174}