// main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod pull_examples;
  16mod qa;
  17mod reorder_patch;
  18mod repair;
  19mod retrieve_context;
  20mod reversal_tracking;
  21mod score;
  22mod split_commit;
  23mod split_dataset;
  24mod synthesize;
  25mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};
  42
  43use crate::distill::run_distill;
  44use crate::example::{Example, group_examples_by_repo, read_example_files};
  45use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  46use crate::format_prompt::run_format_prompt;
  47use crate::load_project::run_load_project;
  48use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  49use crate::predict::run_prediction;
  50use crate::progress::Progress;
  51use crate::retrieve_context::run_context_retrieval;
  52use crate::score::run_scoring;
  53use crate::split_commit::SplitCommitArgs;
  54use crate::split_dataset::SplitArgs;
  55use crate::synthesize::{SynthesizeConfig, run_synthesize};
  56
  57#[derive(Parser, Debug)]
  58#[command(name = "ep")]
  59struct EpArgs {
  60    #[arg(long, default_value_t = false)]
  61    printenv: bool,
  62    #[clap(long, default_value_t = 10, global = true)]
  63    max_parallelism: usize,
  64    #[clap(long, global = true)]
  65    limit: Option<usize>,
  66    #[clap(long, global = true)]
  67    offset: Option<usize>,
  68    /// Filter examples by name
  69    #[clap(long, global = true)]
  70    name: Option<String>,
  71    /// Filter examples by repository
  72    #[clap(long, global = true)]
  73    repo: Option<String>,
  74    #[command(subcommand)]
  75    command: Option<Command>,
  76    #[clap(global = true, help = INPUTS_HELP)]
  77    inputs: Vec<PathBuf>,
  78    #[arg(long, short, global = true)]
  79    output: Option<PathBuf>,
  80    #[arg(long, short, global = true)]
  81    in_place: bool,
  82    #[arg(long, short, global = true)]
  83    failfast: bool,
  84    /// How to handle failed examples in output: keep them or skip them.
  85    /// Failed examples are always logged to the run's failed directory.
  86    #[arg(long, global = true, default_value = "keep")]
  87    failed: FailedHandling,
  88    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
  89    /// where one .md file per example will be written (named after each example).
  90    #[arg(long, short, global = true)]
  91    markdown: bool,
  92}
  93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the `///` variant docs below double as clap's `--failed` help text.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably also suppresses the per-example failure files,
    // not just main-output entries — confirm against the code that consumes this.
    SkipNoFiles,
}
 106
// Long-form help shown by clap for the positional `inputs` argument
// (`EpArgs::inputs`). This is runtime-visible CLI output, not a code comment,
// so the literal must not be reworded casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 142
// Subcommands of `ep`. The `///` docs below are also clap's per-subcommand help.
// Most variants fall through to the pipeline driver at the bottom of `main`;
// a few (ImportBatch, Clean, Synthesize, SplitCommit, Split, FilterLanguages,
// Qa, Repair) are handled early and return before the gpui app starts.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
 182
 183impl Display for Command {
 184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 185        match self {
 186            Command::Read => write!(f, "read"),
 187            Command::LoadProject => write!(f, "load-project"),
 188            Command::Context => write!(f, "context"),
 189            Command::FormatPrompt(args) => {
 190                write!(f, "format-prompt --provider={}", args.provider)
 191            }
 192            Command::Predict(args) => match &args.provider {
 193                Some(provider) => write!(f, "predict --provider={}", provider),
 194                None => write!(f, "predict"),
 195            },
 196            Command::ParseOutput => write!(f, "parse-output"),
 197            Command::Score(args) => match &args.provider {
 198                Some(provider) => write!(f, "score --provider={}", provider),
 199                None => write!(f, "score"),
 200            },
 201            Command::Distill => write!(f, "distill"),
 202            Command::Eval(args) => match &args.predict.provider {
 203                Some(provider) => write!(f, "eval --provider={}", provider),
 204                None => write!(f, "eval"),
 205            },
 206            Command::Synthesize(args) => {
 207                write!(f, "synthesize --repos {}", args.repos.join(" "))
 208            }
 209            Command::Clean => write!(f, "clean"),
 210            Command::SplitCommit(_) => write!(f, "split-commit"),
 211            Command::Split(_) => write!(f, "split"),
 212            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 213            Command::ImportBatch(args) => {
 214                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 215            }
 216            Command::Qa(_) => {
 217                write!(f, "qa")
 218            }
 219            Command::Repair(_) => {
 220                write!(f, "repair")
 221            }
 222        }
 223    }
 224}
 225
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to zeta2
    // (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 231
// Arguments shared by the `predict`, `score`, and (via EvalArgs) `eval` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; optional here — when omitted, downstream code
    // chooses the behavior (NOTE(review): confirm in predict.rs/score.rs).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions (default 1).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
 239
// Arguments for the `eval` subcommand: prediction settings plus an optional
// machine-readable summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Flattened so `eval` accepts the same --provider/--repetitions flags as `predict`.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 248
/// Teacher model backend used by the `teacher` / `teacher-non-batching` providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`claude-sonnet-4-5`, see `model_name`).
    Sonnet45,
    /// OpenAI GPT-5.2 (`gpt-5.2`, see `model_name`).
    Gpt52,
}
 254
 255impl std::fmt::Display for TeacherBackend {
 256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 257        match self {
 258            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 259            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 260        }
 261    }
 262}
 263
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    // Case-insensitive parse accepting several aliases per backend.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // NOTE(review): looks like a legacy/experimental alias mapped onto
            // Sonnet45 for backwards compatibility — confirm before removing.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}
 276
 277impl TeacherBackend {
 278    pub fn model_name(&self) -> &'static str {
 279        match self {
 280            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 281            TeacherBackend::Gpt52 => "gpt-5.2",
 282        }
 283    }
 284}
 285
/// Edit-prediction backends selectable via `--provider`.
/// Parsed from `name[:arg]` strings — see the `FromStr` impl below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // The version argument selects the zeta prompt format (e.g. `zeta2:ordered`).
    Zeta2(ZetaVersion),
    // NOTE(review): `Teacher` appears to use batched requests while
    // `TeacherNonBatching` sends direct ones — confirm in predict.rs.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 296
 297impl Default for PredictionProvider {
 298    fn default() -> Self {
 299        PredictionProvider::Zeta2(ZetaVersion::default())
 300    }
 301}
 302
 303impl std::fmt::Display for PredictionProvider {
 304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 305        match self {
 306            PredictionProvider::Sweep => write!(f, "sweep"),
 307            PredictionProvider::Mercury => write!(f, "mercury"),
 308            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 309            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
 310            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 311            PredictionProvider::TeacherNonBatching(backend) => {
 312                write!(f, "teacher-non-batching:{backend}")
 313            }
 314            PredictionProvider::Repair => write!(f, "repair"),
 315        }
 316    }
 317}
 318
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    // Parses `name` or `name:arg`; the name part is matched case-insensitively,
    // the arg (zeta version or teacher backend) is delegated to its own parser.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // split_once keeps any further ':' characters inside `arg`.
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                // Missing version falls back to ZetaVersion::default().
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                // Missing backend defaults to Sonnet45.
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                 Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}
 361
 362impl Serialize for PredictionProvider {
 363    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 364    where
 365        S: Serializer,
 366    {
 367        serializer.serialize_str(&self.to_string())
 368    }
 369}
 370
 371impl<'de> Deserialize<'de> for PredictionProvider {
 372    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 373    where
 374        D: Deserializer<'de>,
 375    {
 376        let s = String::deserialize(deserializer)?;
 377        s.parse().map_err(serde::de::Error::custom)
 378    }
 379}
 380
// Arguments for the `synthesize` subcommand; mapped into `SynthesizeConfig`
// (together with the global -o output dir) before calling run_synthesize.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 399
// Arguments for the `import-batch` subcommand (recovers batch results into the
// local LLM cache database).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 409
// Which LLM batch API a batch ID belongs to (selected via `import-batch --provider`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 415
 416impl EpArgs {
 417    fn output_path(&self) -> Option<PathBuf> {
 418        if self.in_place {
 419            if self.inputs.len() == 1 {
 420                self.inputs.first().cloned()
 421            } else {
 422                panic!("--in-place requires exactly one input file")
 423            }
 424        } else {
 425            self.output.clone()
 426        }
 427    }
 428}
 429
 430async fn load_examples(
 431    http_client: Arc<dyn http_client::HttpClient>,
 432    args: &EpArgs,
 433    output_path: Option<&PathBuf>,
 434    background_executor: BackgroundExecutor,
 435) -> anyhow::Result<Vec<Example>> {
 436    let mut captured_after_timestamps = Vec::new();
 437    let mut rejected_after_timestamps = Vec::new();
 438    let mut requested_after_timestamps = Vec::new();
 439    let mut file_inputs = Vec::new();
 440
 441    for input in &args.inputs {
 442        let input_string = input.to_string_lossy();
 443        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 444            captured_after_timestamps.push(timestamp.to_string());
 445        } else if let Some(timestamp) =
 446            pull_examples::parse_rejected_after_input(input_string.as_ref())
 447        {
 448            rejected_after_timestamps.push(timestamp.to_string());
 449        } else if let Some(timestamp) =
 450            pull_examples::parse_requested_after_input(input_string.as_ref())
 451        {
 452            requested_after_timestamps.push(timestamp.to_string());
 453        } else {
 454            file_inputs.push(input.clone());
 455        }
 456    }
 457
 458    let mut examples = read_example_files(&file_inputs);
 459
 460    Progress::global().set_total_examples(examples.len());
 461
 462    let remaining_limit_for_snowflake =
 463        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 464
 465    if let Some(0) = remaining_limit_for_snowflake {
 466        log::info!(
 467            "skipping Snowflake inputs because --limit is already satisfied by example files"
 468        );
 469    } else {
 470        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
 471
 472        if !captured_after_timestamps.is_empty() {
 473            captured_after_timestamps.sort();
 474
 475            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 476                http_client.clone(),
 477                &captured_after_timestamps,
 478                max_rows_per_timestamp,
 479                background_executor.clone(),
 480            )
 481            .await?;
 482            examples.append(&mut captured_examples);
 483        }
 484
 485        if !rejected_after_timestamps.is_empty() {
 486            rejected_after_timestamps.sort();
 487
 488            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 489                http_client.clone(),
 490                &rejected_after_timestamps,
 491                max_rows_per_timestamp,
 492                background_executor.clone(),
 493            )
 494            .await?;
 495            examples.append(&mut rejected_examples);
 496        }
 497
 498        if !requested_after_timestamps.is_empty() {
 499            requested_after_timestamps.sort();
 500
 501            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 502                http_client,
 503                &requested_after_timestamps,
 504                max_rows_per_timestamp,
 505                background_executor,
 506            )
 507            .await?;
 508            examples.append(&mut requested_examples);
 509        }
 510    }
 511
 512    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 513
 514    if let Some(name_filter) = &args.name {
 515        examples.retain(|example| example.spec.name.contains(name_filter));
 516    }
 517    if let Some(repo_filter) = &args.repo {
 518        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 519    }
 520
 521    // Skip resume logic for --in-place since input and output are the same file,
 522    // which would incorrectly treat all input examples as already processed.
 523    if !args.in_place {
 524        if let Some(path) = output_path {
 525            resume_from_output(path, &mut examples);
 526        }
 527    }
 528
 529    if let Some(offset) = args.offset {
 530        examples.splice(0..offset, []);
 531    }
 532
 533    if let Some(limit) = args.limit {
 534        examples.truncate(limit);
 535    }
 536
 537    let progress = Progress::global();
 538    progress.set_total_examples(examples.len());
 539    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 540
 541    Ok(examples)
 542}
 543
// Stable hash of an example spec, used to match input examples against
// previously-written output lines (see `resume_from_output`).
// NOTE(review): FxHasher is fast but non-cryptographic; hash collisions would
// silently conflate two specs — acceptable for resume bookkeeping, presumably.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
 549
 550fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
 551    let file = match File::open(path) {
 552        Ok(f) => f,
 553        Err(_) => return,
 554    };
 555
 556    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 557
 558    let reader = BufReader::new(file);
 559    let mut kept_lines = Vec::new();
 560    let mut kept_hashes = HashSet::default();
 561
 562    for line in reader.lines() {
 563        let line = match line {
 564            Ok(l) => l,
 565            Err(_) => continue,
 566        };
 567
 568        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 569            let hash = spec_hash(&output_example.spec);
 570            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 571                kept_hashes.insert(hash);
 572                kept_lines.push(line);
 573            }
 574        }
 575    }
 576
 577    let total = examples.len();
 578    let already_processed = kept_hashes.len();
 579
 580    eprintln!(
 581        "Resuming: {}/{} examples already processed",
 582        already_processed, total
 583    );
 584
 585    let file = OpenOptions::new()
 586        .write(true)
 587        .truncate(true)
 588        .open(path)
 589        .expect("Failed to open output file for rewriting");
 590    let mut writer = BufWriter::new(file);
 591    for line in &kept_lines {
 592        writeln!(writer, "{}", line).expect("Failed to write to output file");
 593    }
 594    writer.flush().expect("Failed to flush output file");
 595
 596    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 597}
 598
 599fn main() {
 600    let args = EpArgs::parse();
 601
 602    if args.printenv {
 603        ::util::shell_env::print_env();
 604        return;
 605    }
 606
 607    let output = args.output_path();
 608
 609    if args.markdown && output.is_none() {
 610        eprintln!("--markdown requires -o to specify the output directory");
 611        std::process::exit(1);
 612    }
 613
 614    let command = match &args.command {
 615        Some(cmd) => cmd.clone(),
 616        None => {
 617            EpArgs::command().print_help().unwrap();
 618            return;
 619        }
 620    };
 621
 622    match &command {
 623        Command::ImportBatch(import_args) => {
 624            smol::block_on(async {
 625                match import_args.provider {
 626                    BatchProvider::Anthropic => {
 627                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 628                            .expect("Failed to create Anthropic client");
 629                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 630                            eprintln!("Error importing Anthropic batches: {:?}", e);
 631                            std::process::exit(1);
 632                        }
 633                    }
 634                    BatchProvider::Openai => {
 635                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 636                            .expect("Failed to create OpenAI client");
 637                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 638                            eprintln!("Error importing OpenAI batches: {:?}", e);
 639                            std::process::exit(1);
 640                        }
 641                    }
 642                }
 643                println!(
 644                    "Successfully imported {} batch(es)",
 645                    import_args.batch_ids.len()
 646                );
 647            });
 648            return;
 649        }
 650        Command::Clean => {
 651            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 652            return;
 653        }
 654        Command::Synthesize(synth_args) => {
 655            let Some(output_dir) = args.output else {
 656                panic!("output dir is required");
 657            };
 658            let config = SynthesizeConfig {
 659                repo_urls: synth_args.repos.clone(),
 660                count: synth_args.count,
 661                max_commits: synth_args.max_commits,
 662                output_dir,
 663                fresh: synth_args.fresh,
 664            };
 665            smol::block_on(async {
 666                if let Err(e) = run_synthesize(config).await {
 667                    eprintln!("Error: {:?}", e);
 668                    std::process::exit(1);
 669                }
 670            });
 671            return;
 672        }
 673        Command::SplitCommit(split_commit_args) => {
 674            if let Err(error) = split_commit::run_split_commit(
 675                split_commit_args,
 676                &args.inputs,
 677                output.as_ref(),
 678                args.failed,
 679            ) {
 680                eprintln!("{error:#}");
 681                std::process::exit(1);
 682            }
 683            return;
 684        }
 685        Command::Split(split_args) => {
 686            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
 687                eprintln!("{error:#}");
 688                std::process::exit(1);
 689            }
 690            return;
 691        }
 692        Command::FilterLanguages(filter_args) => {
 693            if let Err(error) =
 694                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
 695            {
 696                eprintln!("{error:#}");
 697                std::process::exit(1);
 698            }
 699            return;
 700        }
 701        Command::Qa(qa_args) => {
 702            // Read examples from input files
 703            let mut examples = example::read_example_files(&args.inputs);
 704
 705            // Apply filters
 706            if let Some(name_filter) = &args.name {
 707                examples.retain(|e| e.spec.name.contains(name_filter));
 708            }
 709            if let Some(repo_filter) = &args.repo {
 710                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 711            }
 712            if let Some(offset) = args.offset {
 713                examples.splice(0..offset, []);
 714            }
 715            if let Some(limit) = args.limit {
 716                examples.truncate(limit);
 717            }
 718
 719            smol::block_on(async {
 720                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
 721                    eprintln!("Error: {:?}", e);
 722                    std::process::exit(1);
 723                }
 724            });
 725            return;
 726        }
 727        Command::Repair(repair_args) => {
 728            // Read examples from input files
 729            let mut examples = example::read_example_files(&args.inputs);
 730
 731            // Apply filters
 732            if let Some(name_filter) = &args.name {
 733                examples.retain(|e| e.spec.name.contains(name_filter));
 734            }
 735            if let Some(repo_filter) = &args.repo {
 736                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 737            }
 738            if let Some(offset) = args.offset {
 739                examples.splice(0..offset, []);
 740            }
 741            if let Some(limit) = args.limit {
 742                examples.truncate(limit);
 743            }
 744
 745            smol::block_on(async {
 746                if let Err(e) =
 747                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
 748                {
 749                    eprintln!("Error: {:?}", e);
 750                    std::process::exit(1);
 751                }
 752            });
 753            return;
 754        }
 755        _ => {}
 756    }
 757
 758    let http_client = Arc::new(ReqwestClient::new());
 759    let app = Application::headless().with_http_client(http_client);
 760
 761    app.run(move |cx| {
 762        let app_state = Arc::new(headless::init(cx));
 763        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 764
 765        cx.spawn(async move |cx| {
 766            let result = async {
 767                let examples = load_examples(
 768                    app_state.client.http_client(),
 769                    &args,
 770                    output.as_ref(),
 771                    cx.background_executor().clone(),
 772                )
 773                .await?;
 774
 775                match &command {
 776                    Command::Predict(args) | Command::Score(args) => {
 777                        predict::sync_batches(args.provider.as_ref()).await?;
 778                    }
 779                    Command::Eval(args) => {
 780                        predict::sync_batches(args.predict.provider.as_ref()).await?;
 781                    }
 782                    _ => (),
 783                }
 784
 785                let failfast_on_single_example = examples.len() == 1;
 786
 787                // For --markdown mode, create the output directory if it doesn't exist
 788                let markdown_output_dir = if args.markdown {
 789                    let dir = output.as_ref().expect("--markdown requires -o");
 790                    if !dir.exists() {
 791                        std::fs::create_dir_all(dir)
 792                            .expect("Failed to create markdown output directory");
 793                    }
 794                    Some(dir.clone())
 795                } else {
 796                    None
 797                };
 798
 799                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
 800                let in_place_temp_path = if args.in_place {
 801                    output.as_ref().map(|path| {
 802                        let mut temp_path = path.clone();
 803                        temp_path.set_extension("jsonl.tmp");
 804                        temp_path
 805                    })
 806                } else {
 807                    None
 808                };
 809
 810                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
 811                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
 812                {
 813                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
 814                    write_path.map(|path| {
 815                        let file = if args.in_place {
 816                            // For --in-place, write to temp file (truncate if exists)
 817                            OpenOptions::new()
 818                                .create(true)
 819                                .write(true)
 820                                .truncate(true)
 821                                .open(path)
 822                                .expect("Failed to open temp output file")
 823                        } else {
 824                            // For regular output, append to support resuming
 825                            OpenOptions::new()
 826                                .create(true)
 827                                .append(true)
 828                                .open(path)
 829                                .expect("Failed to open output file")
 830                        };
 831                        let mut writer = BufWriter::new(file);
 832                        let (sender, mut receiver) = mpsc::unbounded::<String>();
 833                        cx.background_spawn(async move {
 834                            while let Some(line) = receiver.next().await {
 835                                writeln!(writer, "{}", line).expect("Failed to write example");
 836                                writer.flush().expect("Failed to flush output");
 837                            }
 838                        })
 839                        .detach();
 840                        sender
 841                    })
 842                } else {
 843                    None
 844                };
 845
 846                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
 847                let finished_examples = Mutex::new(Vec::new());
 848
 849                let mut tasks = Vec::new();
 850                for _ in 0..args.max_parallelism {
 851                    tasks.push(async {
 852                        loop {
 853                            let Some(mut repo_examples) =
 854                                grouped_examples.lock().unwrap().pop_front()
 855                            else {
 856                                break;
 857                            };
 858                            for example in &mut repo_examples {
 859                                let example_progress =
 860                                    Progress::global().start_group(&example.spec.name);
 861
 862                                let result = async {
 863                                    match &command {
 864                                        Command::Read => {}
 865                                        Command::LoadProject => {
 866                                            run_load_project(
 867                                                example,
 868                                                app_state.clone(),
 869                                                &example_progress,
 870                                                cx.clone(),
 871                                            )
 872                                            .await?;
 873                                        }
 874                                        Command::Context => {
 875                                            run_context_retrieval(
 876                                                example,
 877                                                app_state.clone(),
 878                                                &example_progress,
 879                                                cx.clone(),
 880                                            )
 881                                            .await?;
 882                                        }
 883                                        Command::FormatPrompt(args) => {
 884                                            run_format_prompt(
 885                                                example,
 886                                                args,
 887                                                app_state.clone(),
 888                                                &example_progress,
 889                                                cx.clone(),
 890                                            )
 891                                            .await?;
 892                                        }
 893                                        Command::Predict(args) => {
 894                                            run_prediction(
 895                                                example,
 896                                                args,
 897                                                app_state.clone(),
 898                                                &example_progress,
 899                                                cx.clone(),
 900                                            )
 901                                            .await?;
 902                                        }
 903                                        Command::ParseOutput => {
 904                                            parse_output::run_parse_output(example)?;
 905                                        }
 906                                        Command::Distill => {
 907                                            run_distill(example).await?;
 908                                        }
 909                                        Command::Score(args) => {
 910                                            run_scoring(
 911                                                example,
 912                                                args,
 913                                                app_state.clone(),
 914                                                &example_progress,
 915                                                cx.clone(),
 916                                            )
 917                                            .await?;
 918                                        }
 919                                        Command::Eval(args) => {
 920                                            run_scoring(
 921                                                example,
 922                                                &args.predict,
 923                                                app_state.clone(),
 924                                                &example_progress,
 925                                                cx.clone(),
 926                                            )
 927                                            .await?;
 928                                        }
 929                                        Command::Clean
 930                                        | Command::Synthesize(_)
 931                                        | Command::SplitCommit(_)
 932                                        | Command::Split(_)
 933                                        | Command::FilterLanguages(_)
 934                                        | Command::ImportBatch(_)
 935                                        | Command::Qa(_)
 936                                        | Command::Repair(_) => {
 937                                            unreachable!()
 938                                        }
 939                                    }
 940                                    anyhow::Ok(())
 941                                }
 942                                .await;
 943
 944                                let failed = if let Err(error) = result {
 945                                    handle_error(
 946                                        error,
 947                                        &args,
 948                                        &command,
 949                                        &app_state,
 950                                        failfast_on_single_example,
 951                                        &example,
 952                                    )
 953                                    .await;
 954                                    true
 955                                } else {
 956                                    false
 957                                };
 958
 959                                let should_write = !failed || args.failed == FailedHandling::Keep;
 960                                if should_write {
 961                                    if let Some(ref markdown_dir) = markdown_output_dir {
 962                                        let filename = format!("{}.md", example.spec.filename());
 963                                        let path = markdown_dir.join(&filename);
 964                                        let markdown = example.spec.to_markdown();
 965                                        std::fs::write(&path, &markdown)
 966                                            .expect("Failed to write markdown file");
 967                                    } else if let Some(ref mut sender) = output_sender.clone() {
 968                                        let line = serde_json::to_string(&example).unwrap();
 969                                        sender
 970                                            .send(line)
 971                                            .await
 972                                            .expect("Failed to send to output writer");
 973                                    } else if args.output.is_none()
 974                                        && !matches!(command, Command::Eval(_))
 975                                    {
 976                                        let line = serde_json::to_string(&example).unwrap();
 977                                        println!("{}", line);
 978                                    }
 979                                }
 980                            }
 981
 982                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
 983                            let project = repo_examples
 984                                .iter()
 985                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
 986                                .or_else(|| app_state.project_cache.get(repo_url));
 987
 988                            if let Some(project) = project {
 989                                let mut cx = cx.clone();
 990
 991                                let shutdown_task: Task<()> =
 992                                    project.update(&mut cx, |project, cx| {
 993                                        let lsp_store = project.lsp_store();
 994                                        lsp_store.update(cx, |lsp_store, cx| {
 995                                            lsp_store.shutdown_all_language_servers(cx)
 996                                        })
 997                                    });
 998
 999                                shutdown_task.await;
1000
1001                                if let Some(ep_store) =
1002                                    cx.update(|cx| EditPredictionStore::try_global(cx))
1003                                {
1004                                    ep_store.update(&mut cx, |store, _| {
1005                                        store.remove_project(&project);
1006                                    });
1007                                }
1008                            }
1009
1010                            app_state.project_cache.remove(repo_url);
1011                            for example in &mut repo_examples {
1012                                example.state.take();
1013                            }
1014                            finished_examples
1015                                .lock()
1016                                .unwrap()
1017                                .extend_from_slice(&repo_examples);
1018                        }
1019                    });
1020                }
1021                futures::future::join_all(tasks).await;
1022
1023                Progress::global().finalize();
1024
1025                match &command {
1026                    Command::Predict(args) | Command::Score(args) => {
1027                        predict::sync_batches(args.provider.as_ref()).await?;
1028                    }
1029                    Command::Eval(args) => {
1030                        predict::sync_batches(args.predict.provider.as_ref()).await?;
1031                    }
1032                    _ => (),
1033                }
1034
1035                match &command {
1036                    Command::Eval(args) => {
1037                        let examples = finished_examples.lock().unwrap();
1038                        score::print_report(&examples);
1039                        if let Some(summary_path) = &args.summary_json {
1040                            score::write_summary_json(&examples, summary_path)?;
1041                        }
1042                    }
1043                    _ => (),
1044                };
1045
1046                // For --in-place, atomically rename temp file to original
1047                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
1048                    std::fs::rename(temp_path, final_path)
1049                        .expect("Failed to rename temp file to final output");
1050                }
1051
1052                anyhow::Ok(())
1053            }
1054            .await;
1055
1056            if let Err(e) = result {
1057                panic!("Fatal error: {:?}", e);
1058            }
1059
1060            let _ = cx.update(|cx| cx.quit());
1061        })
1062        .detach();
1063    });
1064}
1065
1066async fn handle_error(
1067    error: anyhow::Error,
1068    args: &EpArgs,
1069    command: &Command,
1070    app_state: &Arc<headless::EpAppState>,
1071    failfast_on_single_example: bool,
1072    example: &Example,
1073) {
1074    Progress::global().increment_failed();
1075
1076    let msg;
1077    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1078        let example_name = example.spec.filename();
1079
1080        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1081        app_state
1082            .fs
1083            .write(
1084                &failed_example_path,
1085                &serde_json::to_vec_pretty(&example).unwrap(),
1086            )
1087            .await
1088            .unwrap();
1089        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1090        app_state
1091            .fs
1092            .write(&err_path, format!("{error:?}").as_bytes())
1093            .await
1094            .unwrap();
1095
1096        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1097        let mut file = OpenOptions::new()
1098            .create(true)
1099            .append(true)
1100            .open(&failed_jsonl_path)
1101            .expect("Failed to open failed.jsonl");
1102        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1103            .expect("Failed to write to failed.jsonl");
1104
1105        let cursor_path = example
1106            .repo_name()
1107            .unwrap()
1108            .worktree_path()
1109            .join(&example.spec.cursor_path);
1110        msg = format!(
1111            indoc::indoc! {"
1112                While processing \"{}\":
1113
1114                \x1b[31m{:?}\x1b[0m
1115
1116                Example:        \x1b[36m{}\x1b[0m
1117                Error file:     \x1b[36m{}\x1b[0m
1118                Cursor file:    \x1b[36m{}\x1b[0m
1119                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1120            "},
1121            example.spec.name,
1122            error,
1123            failed_example_path.display(),
1124            err_path.display(),
1125            cursor_path.display(),
1126            command,
1127            failed_example_path.display(),
1128        );
1129    } else {
1130        msg = format!(
1131            indoc::indoc! {"
1132            While processing \"{}\":
1133
1134                \x1b[31m{:?}\x1b[0m
1135            "},
1136            example.spec.name, error
1137        );
1138    }
1139
1140    if args.failfast || failfast_on_single_example {
1141        Progress::global().finalize();
1142        panic!("{}", msg);
1143    } else {
1144        log::error!("{}", msg);
1145    }
1146}