1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod repair;
18mod retrieve_context;
19mod score;
20mod split_commit;
21mod split_dataset;
22mod synthesize;
23mod word_diff;
24use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
25use collections::HashSet;
26use edit_prediction::EditPredictionStore;
27use futures::channel::mpsc;
28use futures::{SinkExt as _, StreamExt as _};
29use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
30use zeta_prompt::ZetaVersion;
31
32use reqwest_client::ReqwestClient;
33use serde::{Deserialize, Deserializer, Serialize, Serializer};
34use std::fmt::Display;
35use std::fs::{File, OpenOptions};
36use std::hash::{Hash, Hasher};
37use std::io::{BufRead, BufReader, BufWriter, Write};
38use std::sync::Mutex;
39use std::{path::PathBuf, sync::Arc};
40
41use crate::distill::run_distill;
42use crate::example::{Example, group_examples_by_repo, read_example_files};
43use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
44use crate::format_prompt::run_format_prompt;
45use crate::load_project::run_load_project;
46use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
47use crate::predict::run_prediction;
48use crate::progress::Progress;
49use crate::retrieve_context::run_context_retrieval;
50use crate::score::run_scoring;
51use crate::split_commit::SplitCommitArgs;
52use crate::split_dataset::SplitArgs;
53use crate::synthesize::{SynthesizeConfig, run_synthesize};
54
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may be written before or after the subcommand.
//
// NOTE: field doc comments double as clap help text, so explanatory comments
// here are deliberately `//` (invisible to clap) unless help text is intended.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (see the early return in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repository groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples; also caps Snowflake fetches.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; when absent, most commands print JSONL lines to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (via a temp file + rename).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
87
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike Keep/Skip, this variant also suppresses the per-example failure
    // files and failed.jsonl entry written by `handle_error`.
    SkipNoFiles,
}
100
/// Help text for the positional `inputs` argument (surfaced by clap via
/// `help = INPUTS_HELP`). The `captured-after:` / `rejected-after:` specifiers
/// are recognized in `load_examples` through the `pull_examples` parsers.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
136
// Subcommands for `ep`. The `///` doc comments double as clap `--help` text.
//
// Clean, Synthesize, SplitCommit, Split, FilterLanguages, ImportBatch, Qa and
// Repair are dispatched standalone early in `main`; the remaining commands run
// inside the headless-app example pipeline (and hit `unreachable!()` there
// otherwise).
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
176
177impl Display for Command {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Command::Read => write!(f, "read"),
181 Command::LoadProject => write!(f, "load-project"),
182 Command::Context => write!(f, "context"),
183 Command::FormatPrompt(args) => {
184 write!(f, "format-prompt --provider={}", args.provider)
185 }
186 Command::Predict(args) => match &args.provider {
187 Some(provider) => write!(f, "predict --provider={}", provider),
188 None => write!(f, "predict"),
189 },
190 Command::ParseOutput => write!(f, "parse-output"),
191 Command::Score(args) => match &args.provider {
192 Some(provider) => write!(f, "score --provider={}", provider),
193 None => write!(f, "score"),
194 },
195 Command::Distill => write!(f, "distill"),
196 Command::Eval(args) => match &args.predict.provider {
197 Some(provider) => write!(f, "eval --provider={}", provider),
198 None => write!(f, "eval"),
199 },
200 Command::Synthesize(args) => {
201 write!(f, "synthesize --repos {}", args.repos.join(" "))
202 }
203 Command::Clean => write!(f, "clean"),
204 Command::SplitCommit(_) => write!(f, "split-commit"),
205 Command::Split(_) => write!(f, "split"),
206 Command::FilterLanguages(_) => write!(f, "filter-languages"),
207 Command::ImportBatch(args) => {
208 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
209 }
210 Command::Qa(_) => {
211 write!(f, "qa")
212 }
213 Command::Repair(_) => {
214 write!(f, "repair")
215 }
216 }
217 }
218}
219
// Arguments for `format-prompt`: selects which provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Defaults to PredictionProvider::default() (Zeta2 at the default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
225
// Arguments shared by `predict`, `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; how `None` is resolved is up to the subcommand
    // implementation (see predict/score modules).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
233
// Arguments for `eval`: runs the same pipeline as `score`, then prints an
// aggregated report (and optionally writes a JSON summary).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
242
/// Which backend produces edit predictions.
///
/// Rendered and parsed via `Display`/`FromStr` as `name` or `name:version`
/// (e.g. `zeta2:<version>`); the serde impls round-trip through that string
/// form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
    Repair,
}
253
254impl Default for PredictionProvider {
255 fn default() -> Self {
256 PredictionProvider::Zeta2(ZetaVersion::default())
257 }
258}
259
260impl std::fmt::Display for PredictionProvider {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 match self {
263 PredictionProvider::Sweep => write!(f, "sweep"),
264 PredictionProvider::Mercury => write!(f, "mercury"),
265 PredictionProvider::Zeta1 => write!(f, "zeta1"),
266 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
267 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
268 PredictionProvider::TeacherNonBatching(version) => {
269 write!(f, "teacher-non-batching:{version}")
270 }
271 PredictionProvider::Repair => write!(f, "repair"),
272 }
273 }
274}
275
276impl std::str::FromStr for PredictionProvider {
277 type Err = anyhow::Error;
278
279 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
280 let mut version = ZetaVersion::default();
281 if let Some((first, second)) = s.split_once(':') {
282 version = ZetaVersion::parse(second)?;
283 s = first;
284 }
285
286 let s_lower = s.to_lowercase();
287 match s_lower.as_str() {
288 "sweep" => Ok(PredictionProvider::Sweep),
289 "mercury" => Ok(PredictionProvider::Mercury),
290 "zeta1" => Ok(PredictionProvider::Zeta1),
291 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
292 "teacher" => Ok(PredictionProvider::Teacher(version)),
293 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
294 Ok(PredictionProvider::TeacherNonBatching(version))
295 }
296 "repair" => Ok(PredictionProvider::Repair),
297 _ => {
298 anyhow::bail!(
299 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
300 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
301 Available zeta versions:\n{}",
302 ZetaVersion::options_as_string()
303 )
304 }
305 }
306 }
307}
308
309impl Serialize for PredictionProvider {
310 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
311 where
312 S: Serializer,
313 {
314 serializer.serialize_str(&self.to_string())
315 }
316}
317
318impl<'de> Deserialize<'de> for PredictionProvider {
319 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
320 where
321 D: Deserializer<'de>,
322 {
323 let s = String::deserialize(deserializer)?;
324 s.parse().map_err(serde::de::Error::custom)
325 }
326}
327
// Arguments for `synthesize`: generate eval examples from repository commits.
// The output directory comes from the global `--output` flag (see main()).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
346
// Arguments for `import-batch`: re-import completed Anthropic batches into the
// local LLM cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
353
354impl EpArgs {
355 fn output_path(&self) -> Option<PathBuf> {
356 if self.in_place {
357 if self.inputs.len() == 1 {
358 self.inputs.first().cloned()
359 } else {
360 panic!("--in-place requires exactly one input file")
361 }
362 } else {
363 self.output.clone()
364 }
365 }
366}
367
368async fn load_examples(
369 http_client: Arc<dyn http_client::HttpClient>,
370 args: &EpArgs,
371 output_path: Option<&PathBuf>,
372 background_executor: BackgroundExecutor,
373) -> anyhow::Result<Vec<Example>> {
374 let mut captured_after_timestamps = Vec::new();
375 let mut rejected_after_timestamps = Vec::new();
376 let mut file_inputs = Vec::new();
377
378 for input in &args.inputs {
379 let input_string = input.to_string_lossy();
380 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
381 captured_after_timestamps.push(timestamp.to_string());
382 } else if let Some(timestamp) =
383 pull_examples::parse_rejected_after_input(input_string.as_ref())
384 {
385 rejected_after_timestamps.push(timestamp.to_string());
386 } else {
387 file_inputs.push(input.clone());
388 }
389 }
390
391 let mut examples = read_example_files(&file_inputs);
392
393 Progress::global().set_total_examples(examples.len());
394
395 let remaining_limit_for_snowflake =
396 args.limit.map(|limit| limit.saturating_sub(examples.len()));
397
398 if let Some(0) = remaining_limit_for_snowflake {
399 log::info!(
400 "skipping Snowflake inputs because --limit is already satisfied by example files"
401 );
402 } else {
403 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
404
405 if !captured_after_timestamps.is_empty() {
406 captured_after_timestamps.sort();
407
408 let mut captured_examples = pull_examples::fetch_captured_examples_after(
409 http_client.clone(),
410 &captured_after_timestamps,
411 max_rows_per_timestamp,
412 background_executor.clone(),
413 )
414 .await?;
415 examples.append(&mut captured_examples);
416 }
417
418 if !rejected_after_timestamps.is_empty() {
419 rejected_after_timestamps.sort();
420
421 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
422 http_client,
423 &rejected_after_timestamps,
424 max_rows_per_timestamp,
425 background_executor,
426 )
427 .await?;
428 examples.append(&mut rejected_examples);
429 }
430 }
431
432 crate::example::sort_examples_by_repo_and_rev(&mut examples);
433
434 if let Some(name_filter) = &args.name {
435 examples.retain(|example| example.spec.name.contains(name_filter));
436 }
437 if let Some(repo_filter) = &args.repo {
438 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
439 }
440
441 // Skip resume logic for --in-place since input and output are the same file,
442 // which would incorrectly treat all input examples as already processed.
443 if !args.in_place {
444 if let Some(path) = output_path {
445 resume_from_output(path, &mut examples);
446 }
447 }
448
449 if let Some(offset) = args.offset {
450 examples.splice(0..offset, []);
451 }
452
453 if let Some(limit) = args.limit {
454 examples.truncate(limit);
455 }
456
457 let progress = Progress::global();
458 progress.set_total_examples(examples.len());
459 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
460
461 Ok(examples)
462}
463
464fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
465 let mut hasher = collections::FxHasher::default();
466 spec.hash(&mut hasher);
467 hasher.finish()
468}
469
470fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
471 let file = match File::open(path) {
472 Ok(f) => f,
473 Err(_) => return,
474 };
475
476 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
477
478 let reader = BufReader::new(file);
479 let mut kept_lines = Vec::new();
480 let mut kept_hashes = HashSet::default();
481
482 for line in reader.lines() {
483 let line = match line {
484 Ok(l) => l,
485 Err(_) => continue,
486 };
487
488 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
489 let hash = spec_hash(&output_example.spec);
490 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
491 kept_hashes.insert(hash);
492 kept_lines.push(line);
493 }
494 }
495 }
496
497 let total = examples.len();
498 let already_processed = kept_hashes.len();
499
500 eprintln!(
501 "Resuming: {}/{} examples already processed",
502 already_processed, total
503 );
504
505 let file = OpenOptions::new()
506 .write(true)
507 .truncate(true)
508 .open(path)
509 .expect("Failed to open output file for rewriting");
510 let mut writer = BufWriter::new(file);
511 for line in &kept_lines {
512 writeln!(writer, "{}", line).expect("Failed to write to output file");
513 }
514 writer.flush().expect("Failed to flush output file");
515
516 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
517}
518
/// CLI entry point: parses arguments, dispatches self-contained commands
/// directly, and runs the remaining commands as an example-processing pipeline
/// inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // Utility mode: print the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or project loading are handled
    // here and return early; everything else falls through to `app.run` below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the entire data directory (repos, worktrees, run data).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize reuses the global --output flag as its output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): out-of-range splice panics if --offset exceeds
                // the example count; consider clamping as in load_examples.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): same out-of-range splice caveat as the Qa arm.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands (read, load-project, context, format-prompt, predict,
    // parse-output, distill, score, eval) run in a headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example implies failfast so its error surfaces loudly.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer: workers send serialized JSONL lines
                // through this channel so file writes never interleave.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue grouped by repository: each worker claims one repo
                // group at a time so a repo's project/worktree has one owner.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Pull the next repo group; exit when the queue drains.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for this command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched before app.run.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record the failure (files, progress, message) and
                                // note whether this example failed.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are emitted only with --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: emit JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project: shut down language
                            // servers, drop it from the prediction store and the
                            // project cache, and release per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                // Run all workers to completion.
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync after all examples have been processed.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary).
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
948
949async fn handle_error(
950 error: anyhow::Error,
951 args: &EpArgs,
952 command: &Command,
953 app_state: &Arc<headless::EpAppState>,
954 failfast_on_single_example: bool,
955 example: &Example,
956) {
957 Progress::global().increment_failed();
958
959 let msg;
960 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
961 let example_name = example.spec.filename();
962
963 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
964 app_state
965 .fs
966 .write(
967 &failed_example_path,
968 &serde_json::to_vec_pretty(&example).unwrap(),
969 )
970 .await
971 .unwrap();
972 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
973 app_state
974 .fs
975 .write(&err_path, format!("{error:?}").as_bytes())
976 .await
977 .unwrap();
978
979 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
980 let mut file = OpenOptions::new()
981 .create(true)
982 .append(true)
983 .open(&failed_jsonl_path)
984 .expect("Failed to open failed.jsonl");
985 writeln!(file, "{}", serde_json::to_string(example).unwrap())
986 .expect("Failed to write to failed.jsonl");
987
988 let cursor_path = example
989 .repo_name()
990 .unwrap()
991 .worktree_path()
992 .join(&example.spec.cursor_path);
993 msg = format!(
994 indoc::indoc! {"
995 While processing \"{}\":
996
997 \x1b[31m{:?}\x1b[0m
998
999 Example: \x1b[36m{}\x1b[0m
1000 Error file: \x1b[36m{}\x1b[0m
1001 Cursor file: \x1b[36m{}\x1b[0m
1002 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1003 "},
1004 example.spec.name,
1005 error,
1006 failed_example_path.display(),
1007 err_path.display(),
1008 cursor_path.display(),
1009 command,
1010 failed_example_path.display(),
1011 );
1012 } else {
1013 msg = format!(
1014 indoc::indoc! {"
1015 While processing \"{}\":
1016
1017 \x1b[31m{:?}\x1b[0m
1018 "},
1019 example.spec.name, error
1020 );
1021 }
1022
1023 if args.failfast || failfast_on_single_example {
1024 Progress::global().finalize();
1025 panic!("{}", msg);
1026 } else {
1027 log::error!("{}", msg);
1028 }
1029}