//! main.rs — command-line entry point for the `ep` edit-prediction tool.

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod pull_examples;
  16mod qa;
  17mod reorder_patch;
  18mod repair;
  19mod retrieve_context;
  20mod reversal_tracking;
  21mod score;
  22mod split_commit;
  23mod split_dataset;
  24mod synthesize;
  25mod word_diff;
  26use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  27use collections::HashSet;
  28use edit_prediction::EditPredictionStore;
  29use futures::channel::mpsc;
  30use futures::{SinkExt as _, StreamExt as _};
  31use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
  32use zeta_prompt::ZetaVersion;
  33
  34use reqwest_client::ReqwestClient;
  35use serde::{Deserialize, Deserializer, Serialize, Serializer};
  36use std::fmt::Display;
  37use std::fs::{File, OpenOptions};
  38use std::hash::{Hash, Hasher};
  39use std::io::{BufRead, BufReader, BufWriter, Write};
  40use std::sync::Mutex;
  41use std::{path::PathBuf, sync::Arc};
  42
  43use crate::distill::run_distill;
  44use crate::example::{Example, group_examples_by_repo, read_example_files};
  45use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  46use crate::format_prompt::run_format_prompt;
  47use crate::load_project::run_load_project;
  48use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  49use crate::predict::run_prediction;
  50use crate::progress::Progress;
  51use crate::retrieve_context::run_context_retrieval;
  52use crate::score::run_scoring;
  53use crate::split_commit::SplitCommitArgs;
  54use crate::split_dataset::SplitArgs;
  55use crate::synthesize::{SynthesizeConfig, run_synthesize};
  56
// Top-level CLI arguments. Flags marked `global = true` apply to every
// subcommand. NOTE: clap renders `///` doc comments as --help text, so
// reviewer notes here deliberately use plain `//` comments.
// NOTE(review): attribute style mixes `#[arg(...)]` and `#[clap(...)]`
// (aliases in clap 4) — consider standardizing on `#[arg]`.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (handled first in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of examples processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (also caps Snowflake row fetches —
    // see load_examples()).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples (applied after name/repo filtering).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: example files or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; takes precedence over --output
    // (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop on the first failure.
    // NOTE(review): not read anywhere in this file's visible code — confirm it
    // is consumed by a submodule.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
  93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the `///` comments on the variants double as clap's --help text for
// this ValueEnum, so they are left untouched.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 106
// Help text attached to the positional `inputs` argument of EpArgs; the raw
// string below is shown verbatim by --help.
// NOTE(review): load_examples() also accepts a `requested-after:{timestamp}`
// specifier (see pull_examples::parse_requested_after_input) that is not
// documented here — consider adding it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 164
 165#[derive(Subcommand, Debug, Clone)]
 166enum Command {
 167    /// Read examples from files or fetch from Snowflake, output as .jsonl
 168    Read,
 169    /// Create git worktrees for each example and load file contents
 170    LoadProject,
 171    /// Retrieve context for input examples.
 172    Context,
 173    /// Generate a prompt string for a specific model
 174    FormatPrompt(FormatPromptArgs),
 175    /// Runs edit prediction
 176    Predict(PredictArgs),
 177    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
 178    /// Requires format-prompt to have been run first. Uses provider from prompt.
 179    ParseOutput,
 180    /// Computes a score based on actual and expected patches
 181    Score(PredictArgs),
 182    /// Prepares a distillation dataset by copying expected outputs to
 183    /// predicted outputs and removing actual outputs and prompts.
 184    Distill,
 185    /// Print aggregated scores
 186    Eval(EvalArgs),
 187    /// Generate eval examples by analyzing git commits from a repository
 188    Synthesize(SynthesizeArgs),
 189    /// Remove git repositories and worktrees
 190    Clean,
 191    /// Generate an evaluation example by splitting a chronologically-ordered commit
 192    SplitCommit(SplitCommitArgs),
 193    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
 194    Split(SplitArgs),
 195    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
 196    FilterLanguages(FilterLanguagesArgs),
 197    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
 198    ImportBatch(ImportBatchArgs),
 199    /// Assess the quality of predictions using LLM-as-a-judge
 200    Qa(qa::QaArgs),
 201    /// Repair predictions that received poor QA scores by generating improved predictions
 202    Repair(repair::RepairArgs),
 203}
 204
 205impl Display for Command {
 206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 207        match self {
 208            Command::Read => write!(f, "read"),
 209            Command::LoadProject => write!(f, "load-project"),
 210            Command::Context => write!(f, "context"),
 211            Command::FormatPrompt(args) => {
 212                write!(f, "format-prompt --provider={}", args.provider)
 213            }
 214            Command::Predict(args) => match &args.provider {
 215                Some(provider) => write!(f, "predict --provider={}", provider),
 216                None => write!(f, "predict"),
 217            },
 218            Command::ParseOutput => write!(f, "parse-output"),
 219            Command::Score(args) => match &args.provider {
 220                Some(provider) => write!(f, "score --provider={}", provider),
 221                None => write!(f, "score"),
 222            },
 223            Command::Distill => write!(f, "distill"),
 224            Command::Eval(args) => match &args.predict.provider {
 225                Some(provider) => write!(f, "eval --provider={}", provider),
 226                None => write!(f, "eval"),
 227            },
 228            Command::Synthesize(args) => {
 229                write!(f, "synthesize --repos {}", args.repos.join(" "))
 230            }
 231            Command::Clean => write!(f, "clean"),
 232            Command::SplitCommit(_) => write!(f, "split-commit"),
 233            Command::Split(_) => write!(f, "split"),
 234            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 235            Command::ImportBatch(args) => {
 236                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 237            }
 238            Command::Qa(_) => {
 239                write!(f, "qa")
 240            }
 241            Command::Repair(_) => {
 242                write!(f, "repair")
 243            }
 244        }
 245    }
 246}
 247
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Target provider; defaults to zeta2 at ZetaVersion::default().
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 253
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider. Unlike FormatPromptArgs, no default is supplied here;
    // the None case is rendered without a --provider flag (see Display for
    // Command) and passed through to predict::sync_batches as-is.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
 264
// Arguments for `ep eval`: prediction settings plus an optional JSON summary.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All PredictArgs flags are accepted directly on `eval` via flatten.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 273
 274#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 275pub enum TeacherBackend {
 276    Sonnet45,
 277    Gpt52,
 278}
 279
 280impl std::fmt::Display for TeacherBackend {
 281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 282        match self {
 283            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 284            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 285        }
 286    }
 287}
 288
 289impl std::str::FromStr for TeacherBackend {
 290    type Err = anyhow::Error;
 291
 292    fn from_str(s: &str) -> Result<Self, Self::Err> {
 293        match s.to_lowercase().as_str() {
 294            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 295            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 296            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 297            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
 298        }
 299    }
 300}
 301
impl TeacherBackend {
    /// The model identifier string sent to the backing provider for this
    /// teacher backend.
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}
 310
/// Which engine produces edit predictions. The string form is `name` or
/// `name:arg` (e.g. `zeta2:ordered`, `teacher:gpt52`) — see the `FromStr`
/// and `Display` impls below for the accepted spellings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// zeta2 with a specific `ZetaVersion` (defaults when no `:arg` given).
    Zeta2(ZetaVersion),
    /// Teacher model (displayed as `teacher:<backend>`).
    Teacher(TeacherBackend),
    /// Teacher model, non-batching variant (`teacher-non-batching:<backend>`).
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 321
 322impl Default for PredictionProvider {
 323    fn default() -> Self {
 324        PredictionProvider::Zeta2(ZetaVersion::default())
 325    }
 326}
 327
 328impl std::fmt::Display for PredictionProvider {
 329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 330        match self {
 331            PredictionProvider::Sweep => write!(f, "sweep"),
 332            PredictionProvider::Mercury => write!(f, "mercury"),
 333            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 334            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
 335            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 336            PredictionProvider::TeacherNonBatching(backend) => {
 337                write!(f, "teacher-non-batching:{backend}")
 338            }
 339            PredictionProvider::Repair => write!(f, "repair"),
 340        }
 341    }
 342}
 343
 344impl std::str::FromStr for PredictionProvider {
 345    type Err = anyhow::Error;
 346
 347    fn from_str(s: &str) -> Result<Self, Self::Err> {
 348        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 349
 350        let provider_lower = provider.to_lowercase();
 351        match provider_lower.as_str() {
 352            "sweep" => Ok(PredictionProvider::Sweep),
 353            "mercury" => Ok(PredictionProvider::Mercury),
 354            "zeta1" => Ok(PredictionProvider::Zeta1),
 355            "zeta2" => {
 356                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
 357                Ok(PredictionProvider::Zeta2(version))
 358            }
 359            "teacher" => {
 360                let backend = arg
 361                    .map(|a| a.parse())
 362                    .transpose()?
 363                    .unwrap_or(TeacherBackend::Sonnet45);
 364                Ok(PredictionProvider::Teacher(backend))
 365            }
 366            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 367                let backend = arg
 368                    .map(|a| a.parse())
 369                    .transpose()?
 370                    .unwrap_or(TeacherBackend::Sonnet45);
 371                Ok(PredictionProvider::TeacherNonBatching(backend))
 372            }
 373            "repair" => Ok(PredictionProvider::Repair),
 374            _ => {
 375                anyhow::bail!(
 376                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 377                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 378                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
 379                 Available zeta versions:\n{}",
 380                    ZetaVersion::options_as_string()
 381                )
 382            }
 383        }
 384    }
 385}
 386
 387impl Serialize for PredictionProvider {
 388    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 389    where
 390        S: Serializer,
 391    {
 392        serializer.serialize_str(&self.to_string())
 393    }
 394}
 395
 396impl<'de> Deserialize<'de> for PredictionProvider {
 397    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 398    where
 399        D: Deserializer<'de>,
 400    {
 401        let s = String::deserialize(deserializer)?;
 402        s.parse().map_err(serde::de::Error::custom)
 403    }
 404}
 405
// Arguments for `ep synthesize` (the `///` field docs below are clap help).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 424
// Arguments for `ep import-batch` (the `///` field docs below are clap help).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 434
// Which LLM batch API to import from; parsed by clap's ValueEnum, so the
// CLI values are the lowercased variant names ("anthropic" / "openai",
// matching the default_value above).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 440
 441impl EpArgs {
 442    fn output_path(&self) -> Option<PathBuf> {
 443        if self.in_place {
 444            if self.inputs.len() == 1 {
 445                self.inputs.first().cloned()
 446            } else {
 447                panic!("--in-place requires exactly one input file")
 448            }
 449        } else {
 450            self.output.clone()
 451        }
 452    }
 453}
 454
 455async fn load_examples(
 456    http_client: Arc<dyn http_client::HttpClient>,
 457    args: &EpArgs,
 458    output_path: Option<&PathBuf>,
 459    background_executor: BackgroundExecutor,
 460) -> anyhow::Result<Vec<Example>> {
 461    let mut captured_after_timestamps = Vec::new();
 462    let mut rejected_after_timestamps = Vec::new();
 463    let mut requested_after_timestamps = Vec::new();
 464    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 465        Vec::new();
 466    let mut file_inputs = Vec::new();
 467
 468    for input in &args.inputs {
 469        let input_string = input.to_string_lossy();
 470        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 471            captured_after_timestamps.push(timestamp.to_string());
 472        } else if let Some(timestamp) =
 473            pull_examples::parse_rejected_after_input(input_string.as_ref())
 474        {
 475            rejected_after_timestamps.push(timestamp.to_string());
 476        } else if let Some(timestamp) =
 477            pull_examples::parse_requested_after_input(input_string.as_ref())
 478        {
 479            requested_after_timestamps.push(timestamp.to_string());
 480        } else if let Some((timestamp, rating_filter)) =
 481            pull_examples::parse_rated_after_input(input_string.as_ref())
 482        {
 483            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 484        } else {
 485            file_inputs.push(input.clone());
 486        }
 487    }
 488
 489    let mut examples = read_example_files(&file_inputs);
 490
 491    Progress::global().set_total_examples(examples.len());
 492
 493    let remaining_limit_for_snowflake =
 494        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 495
 496    if let Some(0) = remaining_limit_for_snowflake {
 497        log::info!(
 498            "skipping Snowflake inputs because --limit is already satisfied by example files"
 499        );
 500    } else {
 501        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
 502
 503        if !captured_after_timestamps.is_empty() {
 504            captured_after_timestamps.sort();
 505
 506            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 507                http_client.clone(),
 508                &captured_after_timestamps,
 509                max_rows_per_timestamp,
 510                background_executor.clone(),
 511            )
 512            .await?;
 513            examples.append(&mut captured_examples);
 514        }
 515
 516        if !rejected_after_timestamps.is_empty() {
 517            rejected_after_timestamps.sort();
 518
 519            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 520                http_client.clone(),
 521                &rejected_after_timestamps,
 522                max_rows_per_timestamp,
 523                background_executor.clone(),
 524            )
 525            .await?;
 526            examples.append(&mut rejected_examples);
 527        }
 528
 529        if !requested_after_timestamps.is_empty() {
 530            requested_after_timestamps.sort();
 531
 532            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 533                http_client.clone(),
 534                &requested_after_timestamps,
 535                max_rows_per_timestamp,
 536                background_executor.clone(),
 537            )
 538            .await?;
 539            examples.append(&mut requested_examples);
 540        }
 541
 542        if !rated_after_inputs.is_empty() {
 543            rated_after_inputs.sort();
 544
 545            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 546                http_client,
 547                &rated_after_inputs,
 548                max_rows_per_timestamp,
 549                background_executor,
 550            )
 551            .await?;
 552            examples.append(&mut rated_examples);
 553        }
 554    }
 555
 556    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 557
 558    if let Some(name_filter) = &args.name {
 559        examples.retain(|example| example.spec.name.contains(name_filter));
 560    }
 561    if let Some(repo_filter) = &args.repo {
 562        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 563    }
 564
 565    // Skip resume logic for --in-place since input and output are the same file,
 566    // which would incorrectly treat all input examples as already processed.
 567    if !args.in_place {
 568        if let Some(path) = output_path {
 569            resume_from_output(path, &mut examples);
 570        }
 571    }
 572
 573    if let Some(offset) = args.offset {
 574        examples.splice(0..offset, []);
 575    }
 576
 577    if let Some(limit) = args.limit {
 578        examples.truncate(limit);
 579    }
 580
 581    let progress = Progress::global();
 582    progress.set_total_examples(examples.len());
 583    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 584
 585    Ok(examples)
 586}
 587
 588fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 589    let mut hasher = collections::FxHasher::default();
 590    spec.hash(&mut hasher);
 591    hasher.finish()
 592}
 593
 594fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
 595    let file = match File::open(path) {
 596        Ok(f) => f,
 597        Err(_) => return,
 598    };
 599
 600    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 601
 602    let reader = BufReader::new(file);
 603    let mut kept_lines = Vec::new();
 604    let mut kept_hashes = HashSet::default();
 605
 606    for line in reader.lines() {
 607        let line = match line {
 608            Ok(l) => l,
 609            Err(_) => continue,
 610        };
 611
 612        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 613            let hash = spec_hash(&output_example.spec);
 614            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 615                kept_hashes.insert(hash);
 616                kept_lines.push(line);
 617            }
 618        }
 619    }
 620
 621    let total = examples.len();
 622    let already_processed = kept_hashes.len();
 623
 624    eprintln!(
 625        "Resuming: {}/{} examples already processed",
 626        already_processed, total
 627    );
 628
 629    let file = OpenOptions::new()
 630        .write(true)
 631        .truncate(true)
 632        .open(path)
 633        .expect("Failed to open output file for rewriting");
 634    let mut writer = BufWriter::new(file);
 635    for line in &kept_lines {
 636        writeln!(writer, "{}", line).expect("Failed to write to output file");
 637    }
 638    writer.flush().expect("Failed to flush output file");
 639
 640    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 641}
 642
 643fn main() {
 644    let args = EpArgs::parse();
 645
 646    if args.printenv {
 647        ::util::shell_env::print_env();
 648        return;
 649    }
 650
 651    let output = args.output_path();
 652
 653    if args.markdown && output.is_none() {
 654        eprintln!("--markdown requires -o to specify the output directory");
 655        std::process::exit(1);
 656    }
 657
 658    let command = match &args.command {
 659        Some(cmd) => cmd.clone(),
 660        None => {
 661            EpArgs::command().print_help().unwrap();
 662            return;
 663        }
 664    };
 665
 666    match &command {
 667        Command::ImportBatch(import_args) => {
 668            smol::block_on(async {
 669                match import_args.provider {
 670                    BatchProvider::Anthropic => {
 671                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 672                            .expect("Failed to create Anthropic client");
 673                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 674                            eprintln!("Error importing Anthropic batches: {:?}", e);
 675                            std::process::exit(1);
 676                        }
 677                    }
 678                    BatchProvider::Openai => {
 679                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 680                            .expect("Failed to create OpenAI client");
 681                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 682                            eprintln!("Error importing OpenAI batches: {:?}", e);
 683                            std::process::exit(1);
 684                        }
 685                    }
 686                }
 687                println!(
 688                    "Successfully imported {} batch(es)",
 689                    import_args.batch_ids.len()
 690                );
 691            });
 692            return;
 693        }
 694        Command::Clean => {
 695            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 696            return;
 697        }
 698        Command::Synthesize(synth_args) => {
 699            let Some(output_dir) = args.output else {
 700                panic!("output dir is required");
 701            };
 702            let config = SynthesizeConfig {
 703                repo_urls: synth_args.repos.clone(),
 704                count: synth_args.count,
 705                max_commits: synth_args.max_commits,
 706                output_dir,
 707                fresh: synth_args.fresh,
 708            };
 709            smol::block_on(async {
 710                if let Err(e) = run_synthesize(config).await {
 711                    eprintln!("Error: {:?}", e);
 712                    std::process::exit(1);
 713                }
 714            });
 715            return;
 716        }
 717        Command::SplitCommit(split_commit_args) => {
 718            if let Err(error) = split_commit::run_split_commit(
 719                split_commit_args,
 720                &args.inputs,
 721                output.as_ref(),
 722                args.failed,
 723            ) {
 724                eprintln!("{error:#}");
 725                std::process::exit(1);
 726            }
 727            return;
 728        }
 729        Command::Split(split_args) => {
 730            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
 731                eprintln!("{error:#}");
 732                std::process::exit(1);
 733            }
 734            return;
 735        }
 736        Command::FilterLanguages(filter_args) => {
 737            if let Err(error) =
 738                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
 739            {
 740                eprintln!("{error:#}");
 741                std::process::exit(1);
 742            }
 743            return;
 744        }
 745        Command::Qa(qa_args) => {
 746            // Read examples from input files
 747            let mut examples = example::read_example_files(&args.inputs);
 748
 749            // Apply filters
 750            if let Some(name_filter) = &args.name {
 751                examples.retain(|e| e.spec.name.contains(name_filter));
 752            }
 753            if let Some(repo_filter) = &args.repo {
 754                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 755            }
 756            if let Some(offset) = args.offset {
 757                examples.splice(0..offset, []);
 758            }
 759            if let Some(limit) = args.limit {
 760                examples.truncate(limit);
 761            }
 762
 763            smol::block_on(async {
 764                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
 765                    eprintln!("Error: {:?}", e);
 766                    std::process::exit(1);
 767                }
 768            });
 769            return;
 770        }
 771        Command::Repair(repair_args) => {
 772            // Read examples from input files
 773            let mut examples = example::read_example_files(&args.inputs);
 774
 775            // Apply filters
 776            if let Some(name_filter) = &args.name {
 777                examples.retain(|e| e.spec.name.contains(name_filter));
 778            }
 779            if let Some(repo_filter) = &args.repo {
 780                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
 781            }
 782            if let Some(offset) = args.offset {
 783                examples.splice(0..offset, []);
 784            }
 785            if let Some(limit) = args.limit {
 786                examples.truncate(limit);
 787            }
 788
 789            smol::block_on(async {
 790                if let Err(e) =
 791                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
 792                {
 793                    eprintln!("Error: {:?}", e);
 794                    std::process::exit(1);
 795                }
 796            });
 797            return;
 798        }
 799        _ => {}
 800    }
 801
 802    let http_client = Arc::new(ReqwestClient::new());
 803    let app = Application::headless().with_http_client(http_client);
 804
 805    app.run(move |cx| {
 806        let app_state = Arc::new(headless::init(cx));
 807        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 808
 809        cx.spawn(async move |cx| {
 810            let result = async {
 811                let examples = load_examples(
 812                    app_state.client.http_client(),
 813                    &args,
 814                    output.as_ref(),
 815                    cx.background_executor().clone(),
 816                )
 817                .await?;
 818
 819                match &command {
 820                    Command::Predict(args) | Command::Score(args) => {
 821                        predict::sync_batches(args.provider.as_ref()).await?;
 822                    }
 823                    Command::Eval(args) => {
 824                        predict::sync_batches(args.predict.provider.as_ref()).await?;
 825                    }
 826                    _ => (),
 827                }
 828
 829                let failfast_on_single_example = examples.len() == 1;
 830
 831                // For --markdown mode, create the output directory if it doesn't exist
 832                if args.markdown {
 833                    let dir = output.as_ref().expect("--markdown requires -o");
 834                    if !dir.exists() {
 835                        std::fs::create_dir_all(dir)
 836                            .expect("Failed to create markdown output directory");
 837                    }
 838                }
 839
 840                // Set up JSONL output writer (not used in markdown mode)
 841                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
 842                let mut in_place_temp_path: Option<PathBuf> = None;
 843                if !args.markdown
 844                    && let Some(output_path) = output.as_ref()
 845                {
 846                    let write_path = if args.in_place {
 847                        let temp = output_path.with_extension("jsonl.tmp");
 848                        in_place_temp_path = Some(temp.clone());
 849                        temp
 850                    } else {
 851                        output_path.clone()
 852                    };
 853
 854                    let file = OpenOptions::new()
 855                        .create(true)
 856                        .write(true)
 857                        .truncate(args.in_place)
 858                        .append(!args.in_place)
 859                        .open(&write_path)
 860                        .expect("Failed to open output file");
 861
 862                    let mut writer = BufWriter::new(file);
 863                    let (sender, mut receiver) = mpsc::unbounded::<String>();
 864                    cx.background_spawn(async move {
 865                        while let Some(line) = receiver.next().await {
 866                            writeln!(writer, "{}", line).expect("Failed to write example");
 867                            writer.flush().expect("Failed to flush output");
 868                        }
 869                    })
 870                    .detach();
 871                    output_sender = Some(sender);
 872                }
 873
 874                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
 875                let finished_examples = Mutex::new(Vec::new());
 876
 877                let mut tasks = Vec::new();
 878                for _ in 0..args.max_parallelism {
 879                    tasks.push(async {
 880                        loop {
 881                            let Some(mut repo_examples) =
 882                                grouped_examples.lock().unwrap().pop_front()
 883                            else {
 884                                break;
 885                            };
 886                            for example in &mut repo_examples {
 887                                let example_progress =
 888                                    Progress::global().start_group(&example.spec.name);
 889
 890                                let result = async {
 891                                    match &command {
 892                                        Command::Read => {}
 893                                        Command::LoadProject => {
 894                                            run_load_project(
 895                                                example,
 896                                                app_state.clone(),
 897                                                &example_progress,
 898                                                cx.clone(),
 899                                            )
 900                                            .await?;
 901                                        }
 902                                        Command::Context => {
 903                                            run_context_retrieval(
 904                                                example,
 905                                                app_state.clone(),
 906                                                &example_progress,
 907                                                cx.clone(),
 908                                            )
 909                                            .await?;
 910                                        }
 911                                        Command::FormatPrompt(args) => {
 912                                            run_format_prompt(
 913                                                example,
 914                                                args,
 915                                                app_state.clone(),
 916                                                &example_progress,
 917                                                cx.clone(),
 918                                            )
 919                                            .await?;
 920                                        }
 921                                        Command::Predict(args) => {
 922                                            run_prediction(
 923                                                example,
 924                                                args,
 925                                                app_state.clone(),
 926                                                &example_progress,
 927                                                cx.clone(),
 928                                            )
 929                                            .await?;
 930                                        }
 931                                        Command::ParseOutput => {
 932                                            parse_output::run_parse_output(example)?;
 933                                        }
 934                                        Command::Distill => {
 935                                            run_distill(example).await?;
 936                                        }
 937                                        Command::Score(args) => {
 938                                            run_scoring(
 939                                                example,
 940                                                args,
 941                                                app_state.clone(),
 942                                                &example_progress,
 943                                                cx.clone(),
 944                                            )
 945                                            .await?;
 946                                        }
 947                                        Command::Eval(args) => {
 948                                            run_scoring(
 949                                                example,
 950                                                &args.predict,
 951                                                app_state.clone(),
 952                                                &example_progress,
 953                                                cx.clone(),
 954                                            )
 955                                            .await?;
 956                                        }
 957                                        Command::Clean
 958                                        | Command::Synthesize(_)
 959                                        | Command::SplitCommit(_)
 960                                        | Command::Split(_)
 961                                        | Command::FilterLanguages(_)
 962                                        | Command::ImportBatch(_)
 963                                        | Command::Qa(_)
 964                                        | Command::Repair(_) => {
 965                                            unreachable!()
 966                                        }
 967                                    }
 968                                    anyhow::Ok(())
 969                                }
 970                                .await;
 971
 972                                let failed = if let Err(error) = result {
 973                                    handle_error(
 974                                        error,
 975                                        &args,
 976                                        &command,
 977                                        &app_state,
 978                                        failfast_on_single_example,
 979                                        &example,
 980                                    )
 981                                    .await;
 982                                    true
 983                                } else {
 984                                    false
 985                                };
 986
 987                                let should_write = !failed || args.failed == FailedHandling::Keep;
 988                                if should_write {
 989                                    if args.markdown {
 990                                        let markdown_dir =
 991                                            output.as_ref().expect("--markdown requires -o");
 992                                        let filename = format!("{}.md", example.spec.filename());
 993                                        let path = markdown_dir.join(&filename);
 994                                        let markdown = example.spec.to_markdown();
 995                                        std::fs::write(&path, &markdown)
 996                                            .expect("Failed to write markdown file");
 997                                    } else if let Some(ref mut sender) = output_sender.clone() {
 998                                        let line = serde_json::to_string(&example).unwrap();
 999                                        sender
1000                                            .send(line)
1001                                            .await
1002                                            .expect("Failed to send to output writer");
1003                                    } else if args.output.is_none()
1004                                        && !matches!(command, Command::Eval(_))
1005                                    {
1006                                        let line = serde_json::to_string(&example).unwrap();
1007                                        println!("{}", line);
1008                                    }
1009                                }
1010                            }
1011
1012                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
1013                            let project = repo_examples
1014                                .iter()
1015                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
1016                                .or_else(|| app_state.project_cache.get(repo_url));
1017
1018                            if let Some(project) = project {
1019                                let mut cx = cx.clone();
1020
1021                                let shutdown_task: Task<()> =
1022                                    project.update(&mut cx, |project, cx| {
1023                                        let lsp_store = project.lsp_store();
1024                                        lsp_store.update(cx, |lsp_store, cx| {
1025                                            lsp_store.shutdown_all_language_servers(cx)
1026                                        })
1027                                    });
1028
1029                                shutdown_task.await;
1030
1031                                if let Some(ep_store) =
1032                                    cx.update(|cx| EditPredictionStore::try_global(cx))
1033                                {
1034                                    ep_store.update(&mut cx, |store, _| {
1035                                        store.remove_project(&project);
1036                                    });
1037                                }
1038                            }
1039
1040                            app_state.project_cache.remove(repo_url);
1041                            for example in &mut repo_examples {
1042                                example.state.take();
1043                            }
1044                            finished_examples
1045                                .lock()
1046                                .unwrap()
1047                                .extend_from_slice(&repo_examples);
1048                        }
1049                    });
1050                }
1051                futures::future::join_all(tasks).await;
1052
1053                Progress::global().finalize();
1054
1055                match &command {
1056                    Command::Predict(args) | Command::Score(args) => {
1057                        predict::sync_batches(args.provider.as_ref()).await?;
1058                    }
1059                    Command::Eval(args) => {
1060                        predict::sync_batches(args.predict.provider.as_ref()).await?;
1061                    }
1062                    _ => (),
1063                }
1064
1065                match &command {
1066                    Command::Eval(args) => {
1067                        let examples = finished_examples.lock().unwrap();
1068                        score::print_report(&examples);
1069                        if let Some(summary_path) = &args.summary_json {
1070                            score::write_summary_json(&examples, summary_path)?;
1071                        }
1072                    }
1073                    _ => (),
1074                };
1075
1076                // For --in-place, atomically rename temp file to original
1077                if let Some(temp_path) = &in_place_temp_path {
1078                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
1079                    std::fs::rename(temp_path, final_path)
1080                        .expect("Failed to rename temp file to final output");
1081                }
1082
1083                anyhow::Ok(())
1084            }
1085            .await;
1086
1087            if let Err(e) = result {
1088                panic!("Fatal error: {:?}", e);
1089            }
1090
1091            let _ = cx.update(|cx| cx.quit());
1092        })
1093        .detach();
1094    });
1095}
1096
1097async fn handle_error(
1098    error: anyhow::Error,
1099    args: &EpArgs,
1100    command: &Command,
1101    app_state: &Arc<headless::EpAppState>,
1102    failfast_on_single_example: bool,
1103    example: &Example,
1104) {
1105    Progress::global().increment_failed();
1106
1107    let msg;
1108    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1109        let example_name = example.spec.filename();
1110
1111        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1112        app_state
1113            .fs
1114            .write(
1115                &failed_example_path,
1116                &serde_json::to_vec_pretty(&example).unwrap(),
1117            )
1118            .await
1119            .unwrap();
1120        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1121        app_state
1122            .fs
1123            .write(&err_path, format!("{error:?}").as_bytes())
1124            .await
1125            .unwrap();
1126
1127        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1128        let mut file = OpenOptions::new()
1129            .create(true)
1130            .append(true)
1131            .open(&failed_jsonl_path)
1132            .expect("Failed to open failed.jsonl");
1133        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1134            .expect("Failed to write to failed.jsonl");
1135
1136        let cursor_path = match example.repo_name() {
1137            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1138            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1139        };
1140        msg = format!(
1141            indoc::indoc! {"
1142                While processing \"{}\":
1143
1144                \x1b[31m{:?}\x1b[0m
1145
1146                Example:        \x1b[36m{}\x1b[0m
1147                Error file:     \x1b[36m{}\x1b[0m
1148                Cursor file:    \x1b[36m{}\x1b[0m
1149                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1150            "},
1151            example.spec.name,
1152            error,
1153            failed_example_path.display(),
1154            err_path.display(),
1155            cursor_path.display(),
1156            command,
1157            failed_example_path.display(),
1158        );
1159    } else {
1160        msg = format!(
1161            indoc::indoc! {"
1162            While processing \"{}\":
1163
1164                \x1b[31m{:?}\x1b[0m
1165            "},
1166            example.spec.name, error
1167        );
1168    }
1169
1170    if args.failfast || failfast_on_single_example {
1171        Progress::global().finalize();
1172        panic!("{}", msg);
1173    } else {
1174        log::error!("{}", msg);
1175    }
1176}