1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod repair;
18mod retrieve_context;
19mod reversal_tracking;
20mod score;
21mod split_commit;
22mod split_dataset;
23mod synthesize;
24mod word_diff;
25use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
26use collections::HashSet;
27use edit_prediction::EditPredictionStore;
28use futures::channel::mpsc;
29use futures::{SinkExt as _, StreamExt as _};
30use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
31use zeta_prompt::ZetaVersion;
32
33use reqwest_client::ReqwestClient;
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35use std::fmt::Display;
36use std::fs::{File, OpenOptions};
37use std::hash::{Hash, Hasher};
38use std::io::{BufRead, BufReader, BufWriter, Write};
39use std::sync::Mutex;
40use std::{path::PathBuf, sync::Arc};
41
42use crate::distill::run_distill;
43use crate::example::{Example, group_examples_by_repo, read_example_files};
44use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
45use crate::format_prompt::run_format_prompt;
46use crate::load_project::run_load_project;
47use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
48use crate::predict::run_prediction;
49use crate::progress::Progress;
50use crate::retrieve_context::run_context_retrieval;
51use crate::score::run_scoring;
52use crate::split_commit::SplitCommitArgs;
53use crate::split_dataset::SplitArgs;
54use crate::synthesize::{SynthesizeConfig, run_synthesize};
55
// Top-level CLI arguments for the `ep` tool. Flags marked `global` apply to
// every subcommand; `inputs` accepts file paths or special Snowflake
// specifiers (see INPUTS_HELP). Plain `//` comments are used here so clap's
// generated --help text is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (debug helper; see main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks processing repo groups (see main()).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (applied after filtering and offset).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or captured-after:/rejected-after: specifiers.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; JSONL lines are appended unless --in-place is used.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (written via a temp file, then renamed).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
88
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Note: SkipNoFiles is the exception to the doc above — handle_error() checks
// for it and suppresses the per-example failure files and failed.jsonl entry.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
101
// Long-form help text for the positional `inputs` argument, rendered verbatim
// by clap (wired up via `help = INPUTS_HELP` on EpArgs::inputs). This is a
// user-facing runtime string — keep its formatting intact.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
137
// All `ep` subcommands. In main(), ImportBatch, Clean, Synthesize, SplitCommit,
// Split, FilterLanguages, Qa, and Repair are handled directly and return early;
// the remaining commands run per-example inside the headless app's worker pool.
// (The `///` docs below are clap help text — behavior-bearing, left unchanged.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
177
178impl Display for Command {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 match self {
181 Command::Read => write!(f, "read"),
182 Command::LoadProject => write!(f, "load-project"),
183 Command::Context => write!(f, "context"),
184 Command::FormatPrompt(args) => {
185 write!(f, "format-prompt --provider={}", args.provider)
186 }
187 Command::Predict(args) => match &args.provider {
188 Some(provider) => write!(f, "predict --provider={}", provider),
189 None => write!(f, "predict"),
190 },
191 Command::ParseOutput => write!(f, "parse-output"),
192 Command::Score(args) => match &args.provider {
193 Some(provider) => write!(f, "score --provider={}", provider),
194 None => write!(f, "score"),
195 },
196 Command::Distill => write!(f, "distill"),
197 Command::Eval(args) => match &args.predict.provider {
198 Some(provider) => write!(f, "eval --provider={}", provider),
199 None => write!(f, "eval"),
200 },
201 Command::Synthesize(args) => {
202 write!(f, "synthesize --repos {}", args.repos.join(" "))
203 }
204 Command::Clean => write!(f, "clean"),
205 Command::SplitCommit(_) => write!(f, "split-commit"),
206 Command::Split(_) => write!(f, "split"),
207 Command::FilterLanguages(_) => write!(f, "filter-languages"),
208 Command::ImportBatch(args) => {
209 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
210 }
211 Command::Qa(_) => {
212 write!(f, "qa")
213 }
214 Command::Repair(_) => {
215 write!(f, "repair")
216 }
217 }
218 }
219}
220
// Arguments for `format-prompt`: which provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Defaults to PredictionProvider::default() (zeta2 with default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
226
// Arguments shared by `predict`, `score`, and (flattened) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider, e.g. "zeta2:<version>". Optional; when omitted the
    // Display impl renders the bare command and sync_batches receives None.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction runs per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
234
// Arguments for `eval`: scoring args plus an optional JSON summary output.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Shares the predict/score flags (provider, repetitions).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
243
// Identifies which model/backend produces predictions. Zeta-family variants
// carry a ZetaVersion selecting the prompt format; parsed from strings like
// "zeta2:<version>" (see FromStr below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
    Repair,
}
254
255impl Default for PredictionProvider {
256 fn default() -> Self {
257 PredictionProvider::Zeta2(ZetaVersion::default())
258 }
259}
260
261impl std::fmt::Display for PredictionProvider {
262 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263 match self {
264 PredictionProvider::Sweep => write!(f, "sweep"),
265 PredictionProvider::Mercury => write!(f, "mercury"),
266 PredictionProvider::Zeta1 => write!(f, "zeta1"),
267 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
268 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
269 PredictionProvider::TeacherNonBatching(version) => {
270 write!(f, "teacher-non-batching:{version}")
271 }
272 PredictionProvider::Repair => write!(f, "repair"),
273 }
274 }
275}
276
277impl std::str::FromStr for PredictionProvider {
278 type Err = anyhow::Error;
279
280 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
281 let mut version = ZetaVersion::default();
282 if let Some((first, second)) = s.split_once(':') {
283 version = ZetaVersion::parse(second)?;
284 s = first;
285 }
286
287 let s_lower = s.to_lowercase();
288 match s_lower.as_str() {
289 "sweep" => Ok(PredictionProvider::Sweep),
290 "mercury" => Ok(PredictionProvider::Mercury),
291 "zeta1" => Ok(PredictionProvider::Zeta1),
292 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
293 "teacher" => Ok(PredictionProvider::Teacher(version)),
294 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
295 Ok(PredictionProvider::TeacherNonBatching(version))
296 }
297 "repair" => Ok(PredictionProvider::Repair),
298 _ => {
299 anyhow::bail!(
300 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
301 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
302 Available zeta versions:\n{}",
303 ZetaVersion::options_as_string()
304 )
305 }
306 }
307 }
308}
309
310impl Serialize for PredictionProvider {
311 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
312 where
313 S: Serializer,
314 {
315 serializer.serialize_str(&self.to_string())
316 }
317}
318
319impl<'de> Deserialize<'de> for PredictionProvider {
320 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
321 where
322 D: Deserializer<'de>,
323 {
324 let s = String::deserialize(deserializer)?;
325 s.parse().map_err(serde::de::Error::custom)
326 }
327}
328
// Arguments for `synthesize`. The output directory comes from the global
// --output flag (main() panics if it is missing for this command).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
347
// Arguments for `import-batch`: the Anthropic batch IDs to fetch and import
// into the local LLM cache database (see main()'s ImportBatch arm).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
354
355impl EpArgs {
356 fn output_path(&self) -> Option<PathBuf> {
357 if self.in_place {
358 if self.inputs.len() == 1 {
359 self.inputs.first().cloned()
360 } else {
361 panic!("--in-place requires exactly one input file")
362 }
363 } else {
364 self.output.clone()
365 }
366 }
367}
368
369async fn load_examples(
370 http_client: Arc<dyn http_client::HttpClient>,
371 args: &EpArgs,
372 output_path: Option<&PathBuf>,
373 background_executor: BackgroundExecutor,
374) -> anyhow::Result<Vec<Example>> {
375 let mut captured_after_timestamps = Vec::new();
376 let mut rejected_after_timestamps = Vec::new();
377 let mut file_inputs = Vec::new();
378
379 for input in &args.inputs {
380 let input_string = input.to_string_lossy();
381 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
382 captured_after_timestamps.push(timestamp.to_string());
383 } else if let Some(timestamp) =
384 pull_examples::parse_rejected_after_input(input_string.as_ref())
385 {
386 rejected_after_timestamps.push(timestamp.to_string());
387 } else {
388 file_inputs.push(input.clone());
389 }
390 }
391
392 let mut examples = read_example_files(&file_inputs);
393
394 Progress::global().set_total_examples(examples.len());
395
396 let remaining_limit_for_snowflake =
397 args.limit.map(|limit| limit.saturating_sub(examples.len()));
398
399 if let Some(0) = remaining_limit_for_snowflake {
400 log::info!(
401 "skipping Snowflake inputs because --limit is already satisfied by example files"
402 );
403 } else {
404 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
405
406 if !captured_after_timestamps.is_empty() {
407 captured_after_timestamps.sort();
408
409 let mut captured_examples = pull_examples::fetch_captured_examples_after(
410 http_client.clone(),
411 &captured_after_timestamps,
412 max_rows_per_timestamp,
413 background_executor.clone(),
414 )
415 .await?;
416 examples.append(&mut captured_examples);
417 }
418
419 if !rejected_after_timestamps.is_empty() {
420 rejected_after_timestamps.sort();
421
422 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
423 http_client,
424 &rejected_after_timestamps,
425 max_rows_per_timestamp,
426 background_executor,
427 )
428 .await?;
429 examples.append(&mut rejected_examples);
430 }
431 }
432
433 crate::example::sort_examples_by_repo_and_rev(&mut examples);
434
435 if let Some(name_filter) = &args.name {
436 examples.retain(|example| example.spec.name.contains(name_filter));
437 }
438 if let Some(repo_filter) = &args.repo {
439 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
440 }
441
442 // Skip resume logic for --in-place since input and output are the same file,
443 // which would incorrectly treat all input examples as already processed.
444 if !args.in_place {
445 if let Some(path) = output_path {
446 resume_from_output(path, &mut examples);
447 }
448 }
449
450 if let Some(offset) = args.offset {
451 examples.splice(0..offset, []);
452 }
453
454 if let Some(limit) = args.limit {
455 examples.truncate(limit);
456 }
457
458 let progress = Progress::global();
459 progress.set_total_examples(examples.len());
460 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
461
462 Ok(examples)
463}
464
465fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
466 let mut hasher = collections::FxHasher::default();
467 spec.hash(&mut hasher);
468 hasher.finish()
469}
470
471fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
472 let file = match File::open(path) {
473 Ok(f) => f,
474 Err(_) => return,
475 };
476
477 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
478
479 let reader = BufReader::new(file);
480 let mut kept_lines = Vec::new();
481 let mut kept_hashes = HashSet::default();
482
483 for line in reader.lines() {
484 let line = match line {
485 Ok(l) => l,
486 Err(_) => continue,
487 };
488
489 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
490 let hash = spec_hash(&output_example.spec);
491 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
492 kept_hashes.insert(hash);
493 kept_lines.push(line);
494 }
495 }
496 }
497
498 let total = examples.len();
499 let already_processed = kept_hashes.len();
500
501 eprintln!(
502 "Resuming: {}/{} examples already processed",
503 already_processed, total
504 );
505
506 let file = OpenOptions::new()
507 .write(true)
508 .truncate(true)
509 .open(path)
510 .expect("Failed to open output file for rewriting");
511 let mut writer = BufWriter::new(file);
512 for line in &kept_lines {
513 writeln!(writer, "{}", line).expect("Failed to write to output file");
514 }
515 writer.flush().expect("Failed to flush output file");
516
517 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
518}
519
/// CLI entry point.
///
/// Flow: parse args; handle `--printenv` and missing-subcommand help; run
/// "simple" subcommands (ImportBatch, Clean, Synthesize, SplitCommit, Split,
/// FilterLanguages, Qa, Repair) directly and return; otherwise start a
/// headless gpui Application and process examples with `max_parallelism`
/// worker tasks, each pulling whole repo groups from a shared queue.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        // Debug helper: dump the captured shell environment and exit.
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline;
    // each handles its own I/O and returns early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes the entire data directory (repos, worktrees, caches).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes into a directory supplied via the global --output.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            // (same --name/--repo/--offset/--limit handling as load_examples,
            // but without Snowflake inputs or resume logic)
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // All remaining commands run inside the headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example implies the user wants the error surfaced
                // immediately, as if --failfast were set.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Output writer: workers serialize examples and send them over a
                // channel; a single background task owns the file handle.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue: examples grouped by repo so each worker finishes a
                // whole repo (and tears its project down) before moving on.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the current command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // Handled by the early-return match above.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written when --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file configured: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project before taking the next
                            // group: shut down its language servers, drop it from
                            // the prediction store and project cache, and clear
                            // per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush provider batches accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval additionally prints an aggregate report (and optional
                // JSON summary) over all finished examples.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
949
/// Records a failed example and either panics (failfast) or logs the error.
///
/// Unless `--failed=skip-no-files`, the failing example and its error are
/// written under FAILED_EXAMPLES_DIR, appended to the run's failed.jsonl, and
/// the message includes the file paths plus a ready-to-paste re-run command.
/// Panics when `--failfast` is set or the run had exactly one example.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Full example JSON, usable directly as a re-run input.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Debug-formatted error text alongside the example.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failure log (one JSON example per line).
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // NOTE(review): repo_name().unwrap() panics if the repo name cannot be
        // derived for this example — confirm that cannot happen here.
        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // skip-no-files: report the error without writing anything to disk.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}