// main.rs — entry point for the `ep` CLI.

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod pull_examples;
  16mod qa;
  17mod reorder_patch;
  18mod repair;
  19mod retrieve_context;
  20mod reversal_tracking;
  21mod score;
  22mod split_commit;
  23mod split_dataset;
  24mod synthesize;
  25mod word_diff;
  26use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  27use collections::HashSet;
  28use edit_prediction::EditPredictionStore;
  29use futures::channel::mpsc;
  30use futures::{SinkExt as _, StreamExt as _};
  31use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
  32use zeta_prompt::ZetaVersion;
  33
  34use reqwest_client::ReqwestClient;
  35use serde::{Deserialize, Deserializer, Serialize, Serializer};
  36use std::fmt::Display;
  37use std::fs::{File, OpenOptions};
  38use std::hash::{Hash, Hasher};
  39use std::io::{BufRead, BufReader, BufWriter, Write};
  40use std::sync::Mutex;
  41use std::{path::PathBuf, sync::Arc};
  42
  43use crate::distill::run_distill;
  44use crate::example::{Example, group_examples_by_repo, read_example_files};
  45use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  46use crate::format_prompt::run_format_prompt;
  47use crate::load_project::run_load_project;
  48use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  49use crate::predict::run_prediction;
  50use crate::progress::Progress;
  51use crate::retrieve_context::run_context_retrieval;
  52use crate::score::run_scoring;
  53use crate::split_commit::SplitCommitArgs;
  54use crate::split_dataset::SplitArgs;
  55use crate::synthesize::{SynthesizeConfig, run_synthesize};
  56
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear after any subcommand. Plain `//` comments are used here on purpose:
// doc comments (`///`) on clap-derived fields become --help text and would
// change the CLI's observable output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit immediately (see main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (also caps Snowflake row fetches).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples (applied after filtering, before --limit).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; a directory when --markdown is set, a JSONL file otherwise.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (mutually exclusive with -o).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
  93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Selected by the global `--failed` flag on EpArgs; the variant doc comments
// below double as clap's per-value help text, so they are left untouched.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 106
/// Help text attached to the positional `inputs` argument of [`EpArgs`],
/// describing both plain file inputs and the Snowflake `captured-after:` /
/// `rejected-after:` specifiers parsed by the `pull_examples` module.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 142
// Each variant is one `ep` subcommand. The `///` comments double as clap's
// subcommand help text, so they must not be edited casually. Variants that
// carry an Args struct take additional subcommand-specific flags.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
 182
 183impl Display for Command {
 184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 185        match self {
 186            Command::Read => write!(f, "read"),
 187            Command::LoadProject => write!(f, "load-project"),
 188            Command::Context => write!(f, "context"),
 189            Command::FormatPrompt(args) => {
 190                write!(f, "format-prompt --provider={}", args.provider)
 191            }
 192            Command::Predict(args) => match &args.provider {
 193                Some(provider) => write!(f, "predict --provider={}", provider),
 194                None => write!(f, "predict"),
 195            },
 196            Command::ParseOutput => write!(f, "parse-output"),
 197            Command::Score(args) => match &args.provider {
 198                Some(provider) => write!(f, "score --provider={}", provider),
 199                None => write!(f, "score"),
 200            },
 201            Command::Distill => write!(f, "distill"),
 202            Command::Eval(args) => match &args.predict.provider {
 203                Some(provider) => write!(f, "eval --provider={}", provider),
 204                None => write!(f, "eval"),
 205            },
 206            Command::Synthesize(args) => {
 207                write!(f, "synthesize --repos {}", args.repos.join(" "))
 208            }
 209            Command::Clean => write!(f, "clean"),
 210            Command::SplitCommit(_) => write!(f, "split-commit"),
 211            Command::Split(_) => write!(f, "split"),
 212            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 213            Command::ImportBatch(args) => {
 214                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 215            }
 216            Command::Qa(_) => {
 217                write!(f, "qa")
 218            }
 219            Command::Repair(_) => {
 220                write!(f, "repair")
 221            }
 222        }
 223    }
 224}
 225
// Flags for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to produce; defaults to zeta2 (see
    // PredictionProvider::default). `//` used so clap help is unchanged.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 231
// Flags shared by the `predict` and `score` subcommands (and embedded in
// EvalArgs for `eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider to use, when given on the command line.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions per example (default 1) — exact
    // semantics live in the predict module.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
 239
// Flags for the `eval` subcommand: the shared prediction flags plus an
// optional JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 248
/// Backend model for the `teacher` / `teacher-non-batching` prediction
/// providers. The concrete API model identifier is given by [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    // claude-sonnet-4-5
    Sonnet45,
    // gpt-5.2
    Gpt52,
}
 254
 255impl std::fmt::Display for TeacherBackend {
 256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 257        match self {
 258            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 259            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 260        }
 261    }
 262}
 263
 264impl std::str::FromStr for TeacherBackend {
 265    type Err = anyhow::Error;
 266
 267    fn from_str(s: &str) -> Result<Self, Self::Err> {
 268        match s.to_lowercase().as_str() {
 269            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 270            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 271            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 272            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
 273        }
 274    }
 275}
 276
 277impl TeacherBackend {
 278    pub fn model_name(&self) -> &'static str {
 279        match self {
 280            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 281            TeacherBackend::Gpt52 => "gpt-5.2",
 282        }
 283    }
 284}
 285
/// The prediction backend selected with `--provider`. Parsed from strings
/// like `zeta2:<version>` or `teacher:<backend>` (see the `FromStr` impl);
/// `Display`/`Serialize` produce the same string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt version to use.
    Zeta2(ZetaVersion),
    // Teacher model providers; the non-batching variant presumably bypasses
    // the providers' batch APIs — confirm in the predict module.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 296
 297impl Default for PredictionProvider {
 298    fn default() -> Self {
 299        PredictionProvider::Zeta2(ZetaVersion::default())
 300    }
 301}
 302
 303impl std::fmt::Display for PredictionProvider {
 304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 305        match self {
 306            PredictionProvider::Sweep => write!(f, "sweep"),
 307            PredictionProvider::Mercury => write!(f, "mercury"),
 308            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 309            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
 310            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 311            PredictionProvider::TeacherNonBatching(backend) => {
 312                write!(f, "teacher-non-batching:{backend}")
 313            }
 314            PredictionProvider::Repair => write!(f, "repair"),
 315        }
 316    }
 317}
 318
 319impl std::str::FromStr for PredictionProvider {
 320    type Err = anyhow::Error;
 321
 322    fn from_str(s: &str) -> Result<Self, Self::Err> {
 323        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 324
 325        let provider_lower = provider.to_lowercase();
 326        match provider_lower.as_str() {
 327            "sweep" => Ok(PredictionProvider::Sweep),
 328            "mercury" => Ok(PredictionProvider::Mercury),
 329            "zeta1" => Ok(PredictionProvider::Zeta1),
 330            "zeta2" => {
 331                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
 332                Ok(PredictionProvider::Zeta2(version))
 333            }
 334            "teacher" => {
 335                let backend = arg
 336                    .map(|a| a.parse())
 337                    .transpose()?
 338                    .unwrap_or(TeacherBackend::Sonnet45);
 339                Ok(PredictionProvider::Teacher(backend))
 340            }
 341            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 342                let backend = arg
 343                    .map(|a| a.parse())
 344                    .transpose()?
 345                    .unwrap_or(TeacherBackend::Sonnet45);
 346                Ok(PredictionProvider::TeacherNonBatching(backend))
 347            }
 348            "repair" => Ok(PredictionProvider::Repair),
 349            _ => {
 350                anyhow::bail!(
 351                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 352                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 353                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
 354                 Available zeta versions:\n{}",
 355                    ZetaVersion::options_as_string()
 356                )
 357            }
 358        }
 359    }
 360}
 361
 362impl Serialize for PredictionProvider {
 363    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 364    where
 365        S: Serializer,
 366    {
 367        serializer.serialize_str(&self.to_string())
 368    }
 369}
 370
 371impl<'de> Deserialize<'de> for PredictionProvider {
 372    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 373    where
 374        D: Deserializer<'de>,
 375    {
 376        let s = String::deserialize(deserializer)?;
 377        s.parse().map_err(serde::de::Error::custom)
 378    }
 379}
 380
// Flags for the `synthesize` subcommand; mirrors the fields of
// SynthesizeConfig built in main(). The `///` comments are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 399
// Flags for the `import-batch` subcommand, handled entirely in main()
// before the gpui application starts.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 409
// LLM batch-API provider selector for `import-batch`; dispatches to
// anthropic_client or openai_client in main().
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 415
 416impl EpArgs {
 417    fn output_path(&self) -> Option<PathBuf> {
 418        if self.in_place {
 419            if self.inputs.len() == 1 {
 420                self.inputs.first().cloned()
 421            } else {
 422                panic!("--in-place requires exactly one input file")
 423            }
 424        } else {
 425            self.output.clone()
 426        }
 427    }
 428}
 429
 430async fn load_examples(
 431    http_client: Arc<dyn http_client::HttpClient>,
 432    args: &EpArgs,
 433    output_path: Option<&PathBuf>,
 434    background_executor: BackgroundExecutor,
 435) -> anyhow::Result<Vec<Example>> {
 436    let mut captured_after_timestamps = Vec::new();
 437    let mut rejected_after_timestamps = Vec::new();
 438    let mut file_inputs = Vec::new();
 439
 440    for input in &args.inputs {
 441        let input_string = input.to_string_lossy();
 442        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 443            captured_after_timestamps.push(timestamp.to_string());
 444        } else if let Some(timestamp) =
 445            pull_examples::parse_rejected_after_input(input_string.as_ref())
 446        {
 447            rejected_after_timestamps.push(timestamp.to_string());
 448        } else {
 449            file_inputs.push(input.clone());
 450        }
 451    }
 452
 453    let mut examples = read_example_files(&file_inputs);
 454
 455    Progress::global().set_total_examples(examples.len());
 456
 457    let remaining_limit_for_snowflake =
 458        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 459
 460    if let Some(0) = remaining_limit_for_snowflake {
 461        log::info!(
 462            "skipping Snowflake inputs because --limit is already satisfied by example files"
 463        );
 464    } else {
 465        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
 466
 467        if !captured_after_timestamps.is_empty() {
 468            captured_after_timestamps.sort();
 469
 470            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 471                http_client.clone(),
 472                &captured_after_timestamps,
 473                max_rows_per_timestamp,
 474                background_executor.clone(),
 475            )
 476            .await?;
 477            examples.append(&mut captured_examples);
 478        }
 479
 480        if !rejected_after_timestamps.is_empty() {
 481            rejected_after_timestamps.sort();
 482
 483            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 484                http_client,
 485                &rejected_after_timestamps,
 486                max_rows_per_timestamp,
 487                background_executor,
 488            )
 489            .await?;
 490            examples.append(&mut rejected_examples);
 491        }
 492    }
 493
 494    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 495
 496    if let Some(name_filter) = &args.name {
 497        examples.retain(|example| example.spec.name.contains(name_filter));
 498    }
 499    if let Some(repo_filter) = &args.repo {
 500        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 501    }
 502
 503    // Skip resume logic for --in-place since input and output are the same file,
 504    // which would incorrectly treat all input examples as already processed.
 505    if !args.in_place {
 506        if let Some(path) = output_path {
 507            resume_from_output(path, &mut examples);
 508        }
 509    }
 510
 511    if let Some(offset) = args.offset {
 512        examples.splice(0..offset, []);
 513    }
 514
 515    if let Some(limit) = args.limit {
 516        examples.truncate(limit);
 517    }
 518
 519    let progress = Progress::global();
 520    progress.set_total_examples(examples.len());
 521    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 522
 523    Ok(examples)
 524}
 525
 526fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 527    let mut hasher = collections::FxHasher::default();
 528    spec.hash(&mut hasher);
 529    hasher.finish()
 530}
 531
 532fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
 533    let file = match File::open(path) {
 534        Ok(f) => f,
 535        Err(_) => return,
 536    };
 537
 538    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 539
 540    let reader = BufReader::new(file);
 541    let mut kept_lines = Vec::new();
 542    let mut kept_hashes = HashSet::default();
 543
 544    for line in reader.lines() {
 545        let line = match line {
 546            Ok(l) => l,
 547            Err(_) => continue,
 548        };
 549
 550        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 551            let hash = spec_hash(&output_example.spec);
 552            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 553                kept_hashes.insert(hash);
 554                kept_lines.push(line);
 555            }
 556        }
 557    }
 558
 559    let total = examples.len();
 560    let already_processed = kept_hashes.len();
 561
 562    eprintln!(
 563        "Resuming: {}/{} examples already processed",
 564        already_processed, total
 565    );
 566
 567    let file = OpenOptions::new()
 568        .write(true)
 569        .truncate(true)
 570        .open(path)
 571        .expect("Failed to open output file for rewriting");
 572    let mut writer = BufWriter::new(file);
 573    for line in &kept_lines {
 574        writeln!(writer, "{}", line).expect("Failed to write to output file");
 575    }
 576    writer.flush().expect("Failed to flush output file");
 577
 578    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 579}
 580
 581fn main() {
 582    let args = EpArgs::parse();
 583
 584    if args.printenv {
 585        ::util::shell_env::print_env();
 586        return;
 587    }
 588
 589    let output = args.output_path();
 590
 591    if args.markdown && output.is_none() {
 592        eprintln!("--markdown requires -o to specify the output directory");
 593        std::process::exit(1);
 594    }
 595
 596    let command = match &args.command {
 597        Some(cmd) => cmd.clone(),
 598        None => {
 599            EpArgs::command().print_help().unwrap();
 600            return;
 601        }
 602    };
 603
 604    match &command {
 605        Command::ImportBatch(import_args) => {
 606            smol::block_on(async {
 607                match import_args.provider {
 608                    BatchProvider::Anthropic => {
 609                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 610                            .expect("Failed to create Anthropic client");
 611                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 612                            eprintln!("Error importing Anthropic batches: {:?}", e);
 613                            std::process::exit(1);
 614                        }
 615                    }
 616                    BatchProvider::Openai => {
 617                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 618                            .expect("Failed to create OpenAI client");
 619                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 620                            eprintln!("Error importing OpenAI batches: {:?}", e);
 621                            std::process::exit(1);
 622                        }
 623                    }
 624                }
 625                println!(
 626                    "Successfully imported {} batch(es)",
 627                    import_args.batch_ids.len()
 628                );
 629            });
 630            return;
 631        }
 632        Command::Clean => {
 633            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 634            return;
 635        }
 636        Command::Synthesize(synth_args) => {
 637            let Some(output_dir) = args.output else {
 638                panic!("output dir is required");
 639            };
 640            let config = SynthesizeConfig {
 641                repo_urls: synth_args.repos.clone(),
 642                count: synth_args.count,
 643                max_commits: synth_args.max_commits,
 644                output_dir,
 645                fresh: synth_args.fresh,
 646            };
 647            smol::block_on(async {
 648                if let Err(e) = run_synthesize(config).await {
 649                    eprintln!("Error: {:?}", e);
 650                    std::process::exit(1);
 651                }
 652            });
 653            return;
 654        }
 655        Command::SplitCommit(split_commit_args) => {
 656            if let Err(error) = split_commit::run_split_commit(
 657                split_commit_args,
 658                &args.inputs,
 659                output.as_ref(),
 660                args.failed,
 661            ) {
 662                eprintln!("{error:#}");
 663                std::process::exit(1);
 664            }
 665            return;
 666        }
 667        Command::Split(split_args) => {
 668            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
 669                eprintln!("{error:#}");
 670                std::process::exit(1);
 671            }
 672            return;
 673        }
 674        Command::FilterLanguages(filter_args) => {
 675            if let Err(error) =
 676                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
 677            {
 678                eprintln!("{error:#}");
 679                std::process::exit(1);
 680            }
 681            return;
 682        }
 683        Command::Qa(qa_args) => {
 684            // Read examples from input files
 685            let mut examples = example::read_example_files(&args.inputs);
 686
 687            // Apply filters
 688            if let Some(name_filter) = &args.name {
 689                examples.retain(|e| e.spec.name.contains(name_filter));
 690            }
 691            if let Some(repo_filter) = &args.repo {
 692                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 693            }
 694            if let Some(offset) = args.offset {
 695                examples.splice(0..offset, []);
 696            }
 697            if let Some(limit) = args.limit {
 698                examples.truncate(limit);
 699            }
 700
 701            smol::block_on(async {
 702                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
 703                    eprintln!("Error: {:?}", e);
 704                    std::process::exit(1);
 705                }
 706            });
 707            return;
 708        }
 709        Command::Repair(repair_args) => {
 710            // Read examples from input files
 711            let mut examples = example::read_example_files(&args.inputs);
 712
 713            // Apply filters
 714            if let Some(name_filter) = &args.name {
 715                examples.retain(|e| e.spec.name.contains(name_filter));
 716            }
 717            if let Some(repo_filter) = &args.repo {
 718                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 719            }
 720            if let Some(offset) = args.offset {
 721                examples.splice(0..offset, []);
 722            }
 723            if let Some(limit) = args.limit {
 724                examples.truncate(limit);
 725            }
 726
 727            smol::block_on(async {
 728                if let Err(e) =
 729                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
 730                {
 731                    eprintln!("Error: {:?}", e);
 732                    std::process::exit(1);
 733                }
 734            });
 735            return;
 736        }
 737        _ => {}
 738    }
 739
 740    let http_client = Arc::new(ReqwestClient::new());
 741    let app = Application::headless().with_http_client(http_client);
 742
 743    app.run(move |cx| {
 744        let app_state = Arc::new(headless::init(cx));
 745        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 746
 747        cx.spawn(async move |cx| {
 748            let result = async {
 749                let examples = load_examples(
 750                    app_state.client.http_client(),
 751                    &args,
 752                    output.as_ref(),
 753                    cx.background_executor().clone(),
 754                )
 755                .await?;
 756
 757                match &command {
 758                    Command::Predict(args) | Command::Score(args) => {
 759                        predict::sync_batches(args.provider.as_ref()).await?;
 760                    }
 761                    Command::Eval(args) => {
 762                        predict::sync_batches(args.predict.provider.as_ref()).await?;
 763                    }
 764                    _ => (),
 765                }
 766
 767                let failfast_on_single_example = examples.len() == 1;
 768
 769                // For --markdown mode, create the output directory if it doesn't exist
 770                let markdown_output_dir = if args.markdown {
 771                    let dir = output.as_ref().expect("--markdown requires -o");
 772                    if !dir.exists() {
 773                        std::fs::create_dir_all(dir)
 774                            .expect("Failed to create markdown output directory");
 775                    }
 776                    Some(dir.clone())
 777                } else {
 778                    None
 779                };
 780
 781                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
 782                let in_place_temp_path = if args.in_place {
 783                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // JSONL output plumbing: unless we're emitting markdown (or
                // running without --output on an Eval), spawn one background
                // task that owns the output file and receives serialized
                // examples over an unbounded channel. Workers only send
                // strings, so file access is serialized in a single place.
                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
                {
                    // Prefer the --in-place temp file when one was set up above.
                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                    write_path.map(|path| {
                        let file = if args.in_place {
                            // For --in-place, write to temp file (truncate if exists)
                            OpenOptions::new()
                                .create(true)
                                .write(true)
                                .truncate(true)
                                .open(path)
                                .expect("Failed to open temp output file")
                        } else {
                            // For regular output, append to support resuming
                            OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                        };
                        let mut writer = BufWriter::new(file);
                        let (sender, mut receiver) = mpsc::unbounded::<String>();
                        cx.background_spawn(async move {
                            while let Some(line) = receiver.next().await {
                                writeln!(writer, "{}", line).expect("Failed to write example");
                                // Flush per line so completed examples survive a crash.
                                writer.flush().expect("Failed to flush output");
                            }
                        })
                        .detach();
                        sender
                    })
                } else {
                    None
                };

                // Examples are grouped per repository so one repo's project
                // state can be reused across its examples and torn down once.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Worker pool: up to max_parallelism concurrent tasks, each
                // repeatedly claiming the next repo group from the shared queue
                // until it is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the requested
                                // subcommand; errors are collected, not thrown.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // Eval reuses the scoring stage with its
                                        // embedded predict arguments.
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // NOTE(review): these subcommands are
                                        // presumably dispatched before this
                                        // per-example loop is reached — this arm
                                        // is expected to be unreachable here.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record/report the failure; handle_error panics
                                // when failfast is in effect.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only emitted when
                                // --failed=keep was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref markdown_dir) = markdown_output_dir {
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file configured: stream JSONL
                                        // to stdout instead.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Per-repo teardown: shut down this repo's language
                            // servers and drop its project so long runs don't
                            // accumulate resources.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                // Detach the project from the global edit
                                // prediction store, if one exists.
                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop cached project state and per-example state,
                            // then hand the finished examples to the collector.
                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                // Wait for every worker to drain the queue.
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any provider-side batch requests accumulated by the
                // prediction/scoring stages.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report and optionally a JSON summary.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            // Any error propagated out of the pipeline is fatal.
            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1047
/// Handles one failed example: bumps the global failure counter, optionally
/// persists the failure to disk for inspection and re-runs, and then either
/// aborts (failfast) or logs and lets the run continue.
///
/// Unless `args.failed` is [`FailedHandling::SkipNoFiles`], three artifacts
/// are written:
/// - `FAILED_EXAMPLES_DIR/<name>.json` — the full example, pretty-printed
/// - `FAILED_EXAMPLES_DIR/<name>_err.txt` — the debug-formatted error
/// - `RUN_DIR/failed.jsonl` — the example appended as a single JSON line
///
/// # Panics
/// Panics (after finalizing progress output) when `args.failfast` or
/// `failfast_on_single_example` is set, and on any I/O failure while
/// persisting the failure artifacts.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    // Build the user-facing message as an expression rather than via a
    // deferred `let msg;` assigned in each branch.
    let msg = if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Persist the example itself so it can be re-run directly.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .expect("Failed to write failed example file");
        // Persist the error text alongside it.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .expect("Failed to write error file");

        // Also append to the run-wide failed.jsonl so every failure of this
        // run is collected in one place.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .expect("failed example has a repository name")
            .worktree_path()
            .join(&example.spec.cursor_path);
        format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example:        \x1b[36m{}\x1b[0m
                Error file:     \x1b[36m{}\x1b[0m
                Cursor file:    \x1b[36m{}\x1b[0m
                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        )
    } else {
        // No files requested: report the failure inline only.
        format!(
            indoc::indoc! {"
            While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        )
    };

    if args.failfast || failfast_on_single_example {
        // Flush progress output before aborting so the message is readable.
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
1128}