main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod prompt_assets;
  16mod pull_examples;
  17mod qa;
  18mod reorder_patch;
  19mod repair;
  20mod retrieve_context;
  21mod reversal_tracking;
  22mod score;
  23mod split_commit;
  24mod split_dataset;
  25mod synthesize;
  26mod truncate_expected_patch;
  27mod word_diff;
  28use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  29use collections::HashSet;
  30use edit_prediction::EditPredictionStore;
  31use futures::channel::mpsc;
  32use futures::{SinkExt as _, StreamExt as _};
  33use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
  34use zeta_prompt::ZetaFormat;
  35
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
  44
  45use crate::distill::run_distill;
  46use crate::example::{Example, group_examples_by_repo, read_example_files};
  47use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  48use crate::format_prompt::run_format_prompt;
  49use crate::load_project::run_load_project;
  50use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  51use crate::predict::run_prediction;
  52use crate::progress::Progress;
  53use crate::retrieve_context::run_context_retrieval;
  54use crate::score::run_scoring;
  55use crate::split_commit::SplitCommitArgs;
  56use crate::split_dataset::SplitArgs;
  57use crate::synthesize::{SynthesizeConfig, run_synthesize};
  58use crate::truncate_expected_patch::TruncatePatchArgs;
  59
  60#[derive(Parser, Debug)]
  61#[command(name = "ep")]
  62struct EpArgs {
  63    #[arg(long, default_value_t = false)]
  64    printenv: bool,
  65    #[clap(long, default_value_t = 10, global = true)]
  66    max_parallelism: usize,
  67    /// The limit for the number of examples to process
  68    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
  69    #[clap(long, global = true)]
  70    limit: Option<usize>,
  71    #[clap(long, global = true)]
  72    offset: Option<usize>,
  73    /// Filter examples by name
  74    #[clap(long, global = true)]
  75    name: Option<String>,
  76    /// Filter examples by repository
  77    #[clap(long, global = true)]
  78    repo: Option<String>,
  79    #[command(subcommand)]
  80    command: Option<Command>,
  81    #[clap(global = true, help = INPUTS_HELP)]
  82    inputs: Vec<PathBuf>,
  83    #[arg(long, short, global = true)]
  84    output: Option<PathBuf>,
  85    #[arg(long, short, global = true)]
  86    in_place: bool,
  87    #[arg(long, short, global = true)]
  88    failfast: bool,
  89    /// How to handle failed examples in output: keep them or skip them.
  90    /// Failed examples are always logged to the run's failed directory.
  91    #[arg(long, global = true, default_value = "keep")]
  92    failed: FailedHandling,
  93    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
  94    /// where one .md file per example will be written (named after each example).
  95    #[arg(long, short, global = true)]
  96    markdown: bool,
  97}
  98
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably this also suppresses the per-example failure
    // files, unlike `Skip` — confirm against the code that consumes this enum.
    SkipNoFiles,
}
 111
// Long-form help text attached to the positional `inputs` argument of `EpArgs`.
// This is runtime CLI output — edit with care.
// NOTE(review): `load_examples` also accepts a `requested-after:{timestamp}`
// specifier (see `parse_requested_after_input`) which is not documented here.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 169
// All `ep` subcommands. The `///` comments double as clap help text; the
// `Display` impl below renders each variant back to a CLI-invocation string.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 213
 214impl Display for Command {
 215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 216        match self {
 217            Command::Read => write!(f, "read"),
 218            Command::LoadProject => write!(f, "load-project"),
 219            Command::Context => write!(f, "context"),
 220            Command::FormatPrompt(args) => {
 221                write!(f, "format-prompt --provider={}", args.provider)
 222            }
 223            Command::Predict(args) => match &args.provider {
 224                Some(provider) => write!(f, "predict --provider={}", provider),
 225                None => write!(f, "predict"),
 226            },
 227            Command::ParseOutput => write!(f, "parse-output"),
 228            Command::Score(args) => match &args.provider {
 229                Some(provider) => write!(f, "score --provider={}", provider),
 230                None => write!(f, "score"),
 231            },
 232            Command::Distill => write!(f, "distill"),
 233            Command::Eval(args) => match &args.predict.provider {
 234                Some(provider) => write!(f, "eval --provider={}", provider),
 235                None => write!(f, "eval"),
 236            },
 237            Command::Synthesize(args) => {
 238                write!(f, "synthesize --repos {}", args.repos.join(" "))
 239            }
 240            Command::Clean => write!(f, "clean"),
 241            Command::SplitCommit(_) => write!(f, "split-commit"),
 242            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 243            Command::Split(_) => write!(f, "split"),
 244            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 245            Command::ImportBatch(args) => {
 246                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 247            }
 248            Command::Qa(_) => {
 249                write!(f, "qa")
 250            }
 251            Command::Repair(_) => {
 252                write!(f, "repair")
 253            }
 254            Command::PrintZetaFormats => {
 255                write!(f, "print-zeta-formats")
 256            }
 257        }
 258    }
 259}
 260
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to the default
    // `PredictionProvider` (zeta2 with the default zeta format).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 266
// Arguments shared by `ep predict`, `ep score`, and (flattened into
// `EvalArgs`) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run predictions with. Optional — the fallback when omitted
    // is decided by the predict/score implementation, not here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably how many prediction attempts to make per example — confirm
    // against `predict::run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
 277
// Arguments for `ep eval`: the prediction settings plus an optional JSON
// summary output path.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 286
 287#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 288pub enum TeacherBackend {
 289    Sonnet45,
 290    Gpt52,
 291}
 292
 293impl std::fmt::Display for TeacherBackend {
 294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 295        match self {
 296            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 297            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 298        }
 299    }
 300}
 301
 302impl std::str::FromStr for TeacherBackend {
 303    type Err = anyhow::Error;
 304
 305    fn from_str(s: &str) -> Result<Self, Self::Err> {
 306        match s.to_lowercase().as_str() {
 307            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 308            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 309            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 310            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
 311        }
 312    }
 313}
 314
 315impl TeacherBackend {
 316    pub fn model_name(&self) -> &'static str {
 317        match self {
 318            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 319            TeacherBackend::Gpt52 => "gpt-5.2",
 320        }
 321    }
 322}
 323
/// The backend used to produce edit predictions.
///
/// Parsed from strings like `zeta2:ordered` or `teacher:gpt52` (see the
/// `FromStr` impl below) and serialized back through `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Teacher-model predictions (see `TeacherBackend` for the models).
    Teacher(TeacherBackend),
    /// Teacher-model predictions issued without the batching path.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 334
 335impl Default for PredictionProvider {
 336    fn default() -> Self {
 337        PredictionProvider::Zeta2(ZetaFormat::default())
 338    }
 339}
 340
 341impl std::fmt::Display for PredictionProvider {
 342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 343        match self {
 344            PredictionProvider::Sweep => write!(f, "sweep"),
 345            PredictionProvider::Mercury => write!(f, "mercury"),
 346            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 347            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 348            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 349            PredictionProvider::TeacherNonBatching(backend) => {
 350                write!(f, "teacher-non-batching:{backend}")
 351            }
 352            PredictionProvider::Repair => write!(f, "repair"),
 353        }
 354    }
 355}
 356
 357impl std::str::FromStr for PredictionProvider {
 358    type Err = anyhow::Error;
 359
 360    fn from_str(s: &str) -> Result<Self, Self::Err> {
 361        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 362
 363        let provider_lower = provider.to_lowercase();
 364        match provider_lower.as_str() {
 365            "sweep" => Ok(PredictionProvider::Sweep),
 366            "mercury" => Ok(PredictionProvider::Mercury),
 367            "zeta1" => Ok(PredictionProvider::Zeta1),
 368            "zeta2" => {
 369                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 370                Ok(PredictionProvider::Zeta2(format))
 371            }
 372            "teacher" => {
 373                let backend = arg
 374                    .map(|a| a.parse())
 375                    .transpose()?
 376                    .unwrap_or(TeacherBackend::Sonnet45);
 377                Ok(PredictionProvider::Teacher(backend))
 378            }
 379            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 380                let backend = arg
 381                    .map(|a| a.parse())
 382                    .transpose()?
 383                    .unwrap_or(TeacherBackend::Sonnet45);
 384                Ok(PredictionProvider::TeacherNonBatching(backend))
 385            }
 386            "repair" => Ok(PredictionProvider::Repair),
 387            _ => {
 388                anyhow::bail!(
 389                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 390                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 391                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
 392                 Available zeta versions:\n{}",
 393                    ZetaFormat::options_as_string()
 394                )
 395            }
 396        }
 397    }
 398}
 399
 400impl Serialize for PredictionProvider {
 401    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 402    where
 403        S: Serializer,
 404    {
 405        serializer.serialize_str(&self.to_string())
 406    }
 407}
 408
impl<'de> Deserialize<'de> for PredictionProvider {
    /// Deserializes from the string form produced by `Serialize`/`Display`,
    /// delegating parsing and validation to the `FromStr` implementation.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
 418
// Arguments for `ep synthesize` (see `run_synthesize`); the output directory
// comes from the global `-o` flag, not from this struct.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 437
// Arguments for `ep import-batch`, which recovers provider batch results into
// the local LLM cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 447
// Which LLM provider's batch API to import results from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 453
 454impl EpArgs {
 455    fn output_path(&self) -> Option<PathBuf> {
 456        if self.in_place {
 457            if self.inputs.len() == 1 {
 458                self.inputs.first().cloned()
 459            } else {
 460                panic!("--in-place requires exactly one input file")
 461            }
 462        } else {
 463            self.output.clone()
 464        }
 465    }
 466}
 467
 468async fn load_examples(
 469    http_client: Arc<dyn http_client::HttpClient>,
 470    args: &EpArgs,
 471    output_path: Option<&PathBuf>,
 472    background_executor: BackgroundExecutor,
 473) -> anyhow::Result<Vec<Example>> {
 474    let mut captured_after_timestamps = Vec::new();
 475    let mut rejected_after_timestamps = Vec::new();
 476    let mut requested_after_timestamps = Vec::new();
 477    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 478        Vec::new();
 479    let mut file_inputs = Vec::new();
 480
 481    for input in &args.inputs {
 482        let input_string = input.to_string_lossy();
 483        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 484            captured_after_timestamps.push(timestamp.to_string());
 485        } else if let Some(timestamp) =
 486            pull_examples::parse_rejected_after_input(input_string.as_ref())
 487        {
 488            rejected_after_timestamps.push(timestamp.to_string());
 489        } else if let Some(timestamp) =
 490            pull_examples::parse_requested_after_input(input_string.as_ref())
 491        {
 492            requested_after_timestamps.push(timestamp.to_string());
 493        } else if let Some((timestamp, rating_filter)) =
 494            pull_examples::parse_rated_after_input(input_string.as_ref())
 495        {
 496            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 497        } else {
 498            file_inputs.push(input.clone());
 499        }
 500    }
 501
 502    let mut examples = read_example_files(&file_inputs);
 503
 504    Progress::global().set_total_examples(examples.len());
 505
 506    let remaining_limit_for_snowflake =
 507        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 508
 509    if let Some(0) = remaining_limit_for_snowflake {
 510        log::info!(
 511            "skipping Snowflake inputs because --limit is already satisfied by example files"
 512        );
 513    } else {
 514        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
 515
 516        if !captured_after_timestamps.is_empty() {
 517            captured_after_timestamps.sort();
 518
 519            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 520                http_client.clone(),
 521                &captured_after_timestamps,
 522                max_rows_per_timestamp,
 523                background_executor.clone(),
 524            )
 525            .await?;
 526            examples.append(&mut captured_examples);
 527        }
 528
 529        if !rejected_after_timestamps.is_empty() {
 530            rejected_after_timestamps.sort();
 531
 532            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 533                http_client.clone(),
 534                &rejected_after_timestamps,
 535                max_rows_per_timestamp,
 536                background_executor.clone(),
 537            )
 538            .await?;
 539            examples.append(&mut rejected_examples);
 540        }
 541
 542        if !requested_after_timestamps.is_empty() {
 543            requested_after_timestamps.sort();
 544
 545            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 546                http_client.clone(),
 547                &requested_after_timestamps,
 548                max_rows_per_timestamp,
 549                background_executor.clone(),
 550            )
 551            .await?;
 552            examples.append(&mut requested_examples);
 553        }
 554
 555        if !rated_after_inputs.is_empty() {
 556            rated_after_inputs.sort();
 557
 558            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 559                http_client,
 560                &rated_after_inputs,
 561                max_rows_per_timestamp,
 562                background_executor,
 563            )
 564            .await?;
 565            examples.append(&mut rated_examples);
 566        }
 567    }
 568
 569    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 570
 571    if let Some(name_filter) = &args.name {
 572        examples.retain(|example| example.spec.name.contains(name_filter));
 573    }
 574    if let Some(repo_filter) = &args.repo {
 575        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 576    }
 577
 578    // Skip resume logic for --in-place since input and output are the same file,
 579    // which would incorrectly treat all input examples as already processed.
 580    if !args.in_place {
 581        if let Some(path) = output_path
 582            && let Some(command) = &args.command
 583        {
 584            resume_from_output(path, &mut examples, command);
 585        }
 586    }
 587
 588    if let Some(offset) = args.offset {
 589        examples.splice(0..offset, []);
 590    }
 591
 592    if let Some(limit) = args.limit {
 593        examples.truncate(limit);
 594    }
 595
 596    let progress = Progress::global();
 597    progress.set_total_examples(examples.len());
 598    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 599
 600    Ok(examples)
 601}
 602
 603fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 604    let mut hasher = collections::FxHasher::default();
 605    spec.hash(&mut hasher);
 606    hasher.finish()
 607}
 608
 609fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 610    let file = match File::open(path) {
 611        Ok(f) => f,
 612        Err(_) => return,
 613    };
 614
 615    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 616
 617    let reader = BufReader::new(file);
 618    let mut kept_lines = Vec::new();
 619    let mut kept_hashes = HashSet::default();
 620
 621    for line in reader.lines() {
 622        let line = match line {
 623            Ok(l) => l,
 624            Err(_) => continue,
 625        };
 626
 627        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 628            let hash = spec_hash(&output_example.spec);
 629            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 630                let is_complete = match command {
 631                    Command::Qa(_) => output_example
 632                        .qa
 633                        .first()
 634                        .and_then(|q| q.as_ref())
 635                        .and_then(|q| q.confidence)
 636                        .is_some(),
 637                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 638                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 639                    }),
 640                    _ => true,
 641                };
 642                if is_complete {
 643                    kept_hashes.insert(hash);
 644                    kept_lines.push(line);
 645                }
 646            }
 647        }
 648    }
 649
 650    let total = examples.len();
 651    let already_processed = kept_hashes.len();
 652
 653    eprintln!(
 654        "Resuming: {}/{} examples already processed",
 655        already_processed, total
 656    );
 657
 658    let file = OpenOptions::new()
 659        .write(true)
 660        .truncate(true)
 661        .open(path)
 662        .expect("Failed to open output file for rewriting");
 663    let mut writer = BufWriter::new(file);
 664    for line in &kept_lines {
 665        writeln!(writer, "{}", line).expect("Failed to write to output file");
 666    }
 667    writer.flush().expect("Failed to flush output file");
 668
 669    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 670}
 671
 672fn main() {
 673    let args = EpArgs::parse();
 674
 675    if args.printenv {
 676        ::util::shell_env::print_env();
 677        return;
 678    }
 679
 680    let output = args.output_path();
 681
 682    if args.markdown && output.is_none() {
 683        eprintln!("--markdown requires -o to specify the output directory");
 684        std::process::exit(1);
 685    }
 686
 687    let command = match &args.command {
 688        Some(cmd) => cmd.clone(),
 689        None => {
 690            EpArgs::command().print_help().unwrap();
 691            return;
 692        }
 693    };
 694
 695    match &command {
 696        Command::ImportBatch(import_args) => {
 697            smol::block_on(async {
 698                match import_args.provider {
 699                    BatchProvider::Anthropic => {
 700                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 701                            .expect("Failed to create Anthropic client");
 702                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 703                            eprintln!("Error importing Anthropic batches: {:?}", e);
 704                            std::process::exit(1);
 705                        }
 706                    }
 707                    BatchProvider::Openai => {
 708                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 709                            .expect("Failed to create OpenAI client");
 710                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 711                            eprintln!("Error importing OpenAI batches: {:?}", e);
 712                            std::process::exit(1);
 713                        }
 714                    }
 715                }
 716                println!(
 717                    "Successfully imported {} batch(es)",
 718                    import_args.batch_ids.len()
 719                );
 720            });
 721            return;
 722        }
 723        Command::Clean => {
 724            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 725            return;
 726        }
 727        Command::PrintZetaFormats => {
 728            use strum::IntoEnumIterator as _;
 729            for format in ZetaFormat::iter() {
 730                println!("{}", format.to_string().to_lowercase());
 731            }
 732            return;
 733        }
 734        Command::Synthesize(synth_args) => {
 735            let Some(output_dir) = args.output else {
 736                panic!("output dir is required");
 737            };
 738            let config = SynthesizeConfig {
 739                repo_urls: synth_args.repos.clone(),
 740                count: synth_args.count,
 741                max_commits: synth_args.max_commits,
 742                output_dir,
 743                fresh: synth_args.fresh,
 744            };
 745            smol::block_on(async {
 746                if let Err(e) = run_synthesize(config).await {
 747                    eprintln!("Error: {:?}", e);
 748                    std::process::exit(1);
 749                }
 750            });
 751            return;
 752        }
 753        Command::SplitCommit(split_commit_args) => {
 754            if let Err(error) = split_commit::run_split_commit(
 755                split_commit_args,
 756                &args.inputs,
 757                output.as_ref(),
 758                args.failed,
 759            ) {
 760                eprintln!("{error:#}");
 761                std::process::exit(1);
 762            }
 763            return;
 764        }
 765        Command::TruncatePatch(truncate_args) => {
 766            if let Err(error) =
 767                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
 768            {
 769                eprintln!("{error:#}");
 770                std::process::exit(1);
 771            }
 772            return;
 773        }
 774        Command::Split(split_args) => {
 775            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
 776                eprintln!("{error:#}");
 777                std::process::exit(1);
 778            }
 779            return;
 780        }
 781        Command::FilterLanguages(filter_args) => {
 782            if let Err(error) =
 783                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
 784            {
 785                eprintln!("{error:#}");
 786                std::process::exit(1);
 787            }
 788            return;
 789        }
 790
 791        _ => {}
 792    }
 793
 794    let http_client = Arc::new(ReqwestClient::new());
 795    let app = Application::headless().with_http_client(http_client);
 796
 797    app.run(move |cx| {
 798        let app_state = Arc::new(headless::init(cx));
 799        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 800
 801        cx.spawn(async move |cx| {
 802            let result = async {
 803                let examples = load_examples(
 804                    app_state.client.http_client(),
 805                    &args,
 806                    output.as_ref(),
 807                    cx.background_executor().clone(),
 808                )
 809                .await?;
 810
 811                match &command {
 812                    Command::Predict(args) | Command::Score(args) => {
 813                        predict::sync_batches(args.provider.as_ref()).await?;
 814                    }
 815                    Command::Eval(args) => {
 816                        predict::sync_batches(args.predict.provider.as_ref()).await?;
 817                    }
 818                    Command::Qa(args) => {
 819                        qa::sync_batches(args).await?;
 820                    }
 821                    Command::Repair(args) => {
 822                        repair::sync_batches(args).await?;
 823                    }
 824                    _ => (),
 825                }
 826
 827                let failfast_on_single_example = examples.len() == 1;
 828
 829                // For --markdown mode, create the output directory if it doesn't exist
 830                if args.markdown {
 831                    let dir = output.as_ref().expect("--markdown requires -o");
 832                    if !dir.exists() {
 833                        std::fs::create_dir_all(dir)
 834                            .expect("Failed to create markdown output directory");
 835                    }
 836                }
 837
 838                // Set up JSONL output writer (not used in markdown mode)
 839                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
 840                let mut in_place_temp_path: Option<PathBuf> = None;
 841                if !args.markdown
 842                    && let Some(output_path) = output.as_ref()
 843                {
 844                    let write_path = if args.in_place {
 845                        let temp = output_path.with_extension("jsonl.tmp");
 846                        in_place_temp_path = Some(temp.clone());
 847                        temp
 848                    } else {
 849                        output_path.clone()
 850                    };
 851
 852                    let file = OpenOptions::new()
 853                        .create(true)
 854                        .write(true)
 855                        .truncate(args.in_place)
 856                        .append(!args.in_place)
 857                        .open(&write_path)
 858                        .expect("Failed to open output file");
 859
 860                    let mut writer = BufWriter::new(file);
 861                    let (sender, mut receiver) = mpsc::unbounded::<String>();
 862                    cx.background_spawn(async move {
 863                        while let Some(line) = receiver.next().await {
 864                            writeln!(writer, "{}", line).expect("Failed to write example");
 865                            writer.flush().expect("Failed to flush output");
 866                        }
 867                    })
 868                    .detach();
 869                    output_sender = Some(sender);
 870                }
 871
 872                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
 873                let finished_examples = Mutex::new(Vec::new());
 874
 875                let mut tasks = Vec::new();
 876                for _ in 0..args.max_parallelism {
 877                    tasks.push(async {
 878                        loop {
 879                            let Some(mut repo_examples) =
 880                                grouped_examples.lock().unwrap().pop_front()
 881                            else {
 882                                break;
 883                            };
 884                            for example in &mut repo_examples {
 885                                let example_progress =
 886                                    Progress::global().start_group(&example.spec.name);
 887
 888                                let result = async {
 889                                    match &command {
 890                                        Command::Read => {}
 891                                        Command::LoadProject => {
 892                                            run_load_project(
 893                                                example,
 894                                                app_state.clone(),
 895                                                &example_progress,
 896                                                cx.clone(),
 897                                            )
 898                                            .await?;
 899                                        }
 900                                        Command::Context => {
 901                                            run_context_retrieval(
 902                                                example,
 903                                                app_state.clone(),
 904                                                &example_progress,
 905                                                cx.clone(),
 906                                            )
 907                                            .await?;
 908                                        }
 909                                        Command::FormatPrompt(args) => {
 910                                            run_format_prompt(
 911                                                example,
 912                                                args,
 913                                                app_state.clone(),
 914                                                &example_progress,
 915                                                cx.clone(),
 916                                            )
 917                                            .await?;
 918                                        }
 919                                        Command::Predict(args) => {
 920                                            run_prediction(
 921                                                example,
 922                                                args,
 923                                                app_state.clone(),
 924                                                &example_progress,
 925                                                cx.clone(),
 926                                            )
 927                                            .await?;
 928                                        }
 929                                        Command::ParseOutput => {
 930                                            parse_output::run_parse_output(example)?;
 931                                        }
 932                                        Command::Distill => {
 933                                            run_distill(example).await?;
 934                                        }
 935                                        Command::Score(args) => {
 936                                            run_scoring(
 937                                                example,
 938                                                args,
 939                                                app_state.clone(),
 940                                                &example_progress,
 941                                                cx.clone(),
 942                                            )
 943                                            .await?;
 944                                        }
 945                                        Command::Eval(args) => {
 946                                            run_scoring(
 947                                                example,
 948                                                &args.predict,
 949                                                app_state.clone(),
 950                                                &example_progress,
 951                                                cx.clone(),
 952                                            )
 953                                            .await?;
 954                                        }
 955                                        Command::Qa(args) => {
 956                                            qa::run_qa(example, args, &example_progress).await?;
 957                                        }
 958                                        Command::Repair(args) => {
 959                                            repair::run_repair(example, args, &example_progress)
 960                                                .await?;
 961                                        }
 962                                        Command::Clean
 963                                        | Command::Synthesize(_)
 964                                        | Command::SplitCommit(_)
 965                                        | Command::Split(_)
 966                                        | Command::TruncatePatch(_)
 967                                        | Command::FilterLanguages(_)
 968                                        | Command::ImportBatch(_)
 969                                        | Command::PrintZetaFormats => {
 970                                            unreachable!()
 971                                        }
 972                                    }
 973                                    anyhow::Ok(())
 974                                }
 975                                .await;
 976
 977                                let failed = if let Err(error) = result {
 978                                    handle_error(
 979                                        error,
 980                                        &args,
 981                                        &command,
 982                                        &app_state,
 983                                        failfast_on_single_example,
 984                                        &example,
 985                                    )
 986                                    .await;
 987                                    true
 988                                } else {
 989                                    false
 990                                };
 991
 992                                let should_write = !failed || args.failed == FailedHandling::Keep;
 993                                if should_write {
 994                                    if args.markdown {
 995                                        let markdown_dir =
 996                                            output.as_ref().expect("--markdown requires -o");
 997                                        let filename = format!("{}.md", example.spec.filename());
 998                                        let path = markdown_dir.join(&filename);
 999                                        let markdown = example.spec.to_markdown();
1000                                        std::fs::write(&path, &markdown)
1001                                            .expect("Failed to write markdown file");
1002                                    } else if let Some(ref mut sender) = output_sender.clone() {
1003                                        let line = serde_json::to_string(&example).unwrap();
1004                                        sender
1005                                            .send(line)
1006                                            .await
1007                                            .expect("Failed to send to output writer");
1008                                    } else if args.output.is_none()
1009                                        && !matches!(command, Command::Eval(_))
1010                                    {
1011                                        let line = serde_json::to_string(&example).unwrap();
1012                                        println!("{}", line);
1013                                    }
1014                                }
1015                            }
1016
1017                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
1018                            let project = repo_examples
1019                                .iter()
1020                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
1021                                .or_else(|| app_state.project_cache.get(repo_url));
1022
1023                            if let Some(project) = project {
1024                                let mut cx = cx.clone();
1025
1026                                let shutdown_task: Task<()> =
1027                                    project.update(&mut cx, |project, cx| {
1028                                        let lsp_store = project.lsp_store();
1029                                        lsp_store.update(cx, |lsp_store, cx| {
1030                                            lsp_store.shutdown_all_language_servers(cx)
1031                                        })
1032                                    });
1033
1034                                shutdown_task.await;
1035
1036                                if let Some(ep_store) =
1037                                    cx.update(|cx| EditPredictionStore::try_global(cx))
1038                                {
1039                                    ep_store.update(&mut cx, |store, _| {
1040                                        store.remove_project(&project);
1041                                    });
1042                                }
1043                            }
1044
1045                            app_state.project_cache.remove(repo_url);
1046                            for example in &mut repo_examples {
1047                                example.state.take();
1048                            }
1049                            finished_examples
1050                                .lock()
1051                                .unwrap()
1052                                .extend_from_slice(&repo_examples);
1053                        }
1054                    });
1055                }
1056                futures::future::join_all(tasks).await;
1057
1058                Progress::global().finalize();
1059
1060                match &command {
1061                    Command::Predict(args) | Command::Score(args) => {
1062                        predict::sync_batches(args.provider.as_ref()).await?;
1063                    }
1064                    Command::Eval(args) => {
1065                        predict::sync_batches(args.predict.provider.as_ref()).await?;
1066                    }
1067                    Command::Qa(args) => {
1068                        qa::sync_batches(args).await?;
1069                    }
1070                    Command::Repair(args) => {
1071                        repair::sync_batches(args).await?;
1072                    }
1073                    _ => (),
1074                }
1075
1076                match &command {
1077                    Command::Eval(args) => {
1078                        let examples = finished_examples.lock().unwrap();
1079                        score::print_report(&examples);
1080                        if let Some(summary_path) = &args.summary_json {
1081                            score::write_summary_json(&examples, summary_path)?;
1082                        }
1083                    }
1084                    Command::Repair(args) => {
1085                        let examples = finished_examples.lock().unwrap();
1086                        repair::print_report(&examples, args.confidence_threshold);
1087                    }
1088                    _ => (),
1089                };
1090
1091                // For --in-place, atomically rename temp file to original
1092                if let Some(temp_path) = &in_place_temp_path {
1093                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
1094                    std::fs::rename(temp_path, final_path)
1095                        .expect("Failed to rename temp file to final output");
1096                }
1097
1098                anyhow::Ok(())
1099            }
1100            .await;
1101
1102            if let Err(e) = result {
1103                panic!("Fatal error: {:?}", e);
1104            }
1105
1106            let _ = cx.update(|cx| cx.quit());
1107        })
1108        .detach();
1109    });
1110}
1111
1112async fn handle_error(
1113    error: anyhow::Error,
1114    args: &EpArgs,
1115    command: &Command,
1116    app_state: &Arc<headless::EpAppState>,
1117    failfast_on_single_example: bool,
1118    example: &Example,
1119) {
1120    Progress::global().increment_failed();
1121
1122    let msg;
1123    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1124        let example_name = example.spec.filename();
1125
1126        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1127        app_state
1128            .fs
1129            .write(
1130                &failed_example_path,
1131                &serde_json::to_vec_pretty(&example).unwrap(),
1132            )
1133            .await
1134            .unwrap();
1135        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1136        app_state
1137            .fs
1138            .write(&err_path, format!("{error:?}").as_bytes())
1139            .await
1140            .unwrap();
1141
1142        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1143        let mut file = OpenOptions::new()
1144            .create(true)
1145            .append(true)
1146            .open(&failed_jsonl_path)
1147            .expect("Failed to open failed.jsonl");
1148        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1149            .expect("Failed to write to failed.jsonl");
1150
1151        let cursor_path = match example.repo_name() {
1152            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1153            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1154        };
1155        msg = format!(
1156            indoc::indoc! {"
1157                While processing \"{}\":
1158
1159                \x1b[31m{:?}\x1b[0m
1160
1161                Example:        \x1b[36m{}\x1b[0m
1162                Error file:     \x1b[36m{}\x1b[0m
1163                Cursor file:    \x1b[36m{}\x1b[0m
1164                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1165            "},
1166            example.spec.name,
1167            error,
1168            failed_example_path.display(),
1169            err_path.display(),
1170            cursor_path.display(),
1171            command,
1172            failed_example_path.display(),
1173        );
1174    } else {
1175        msg = format!(
1176            indoc::indoc! {"
1177            While processing \"{}\":
1178
1179                \x1b[31m{:?}\x1b[0m
1180            "},
1181            example.spec.name, error
1182        );
1183    }
1184
1185    if args.failfast || failfast_on_single_example {
1186        Progress::global().finalize();
1187        panic!("{}", msg);
1188    } else {
1189        log::error!("{}", msg);
1190    }
1191}