1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod truncate_expected_patch;
27mod word_diff;
28use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
29use collections::HashSet;
30use edit_prediction::EditPredictionStore;
31use futures::channel::mpsc;
32use futures::{SinkExt as _, StreamExt as _};
33use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
34use zeta_prompt::ZetaVersion;
35
36use reqwest_client::ReqwestClient;
37use serde::{Deserialize, Deserializer, Serialize, Serializer};
38use std::fmt::Display;
39use std::fs::{File, OpenOptions};
40use std::hash::{Hash, Hasher};
41use std::io::{BufRead, BufReader, BufWriter, Write};
42use std::sync::Mutex;
43use std::{path::PathBuf, sync::Arc};
44
45use crate::distill::run_distill;
46use crate::example::{Example, group_examples_by_repo, read_example_files};
47use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
48use crate::format_prompt::run_format_prompt;
49use crate::load_project::run_load_project;
50use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
51use crate::predict::run_prediction;
52use crate::progress::Progress;
53use crate::retrieve_context::run_context_retrieval;
54use crate::score::run_scoring;
55use crate::split_commit::SplitCommitArgs;
56use crate::split_dataset::SplitArgs;
57use crate::synthesize::{SynthesizeConfig, run_synthesize};
58use crate::truncate_expected_patch::TruncatePatchArgs;
59
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may appear before or after the subcommand. Notes added below use plain `//`
// comments (not `///`) so the clap-derived `--help` output is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled at the top of `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks processing per-repository example groups.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted/filtered list
    // (applied in `load_examples` before `--limit`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path — or output directory when `--markdown` is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run as soon as one example fails.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
98
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (example JSON, error text, failed.jsonl line) written by `handle_error`.
    SkipNoFiles,
}
111
// User-visible long help text attached to the positional `inputs` argument of
// `EpArgs`. The string content is rendered verbatim by clap; do not reflow it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
169
// All `ep` subcommands. Variant `///` doc comments double as clap's per-command
// help text, so they are left as-is; the `Display` impl below renders each
// variant back into its CLI invocation form for re-run hints.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
211
212impl Display for Command {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 Command::Read => write!(f, "read"),
216 Command::LoadProject => write!(f, "load-project"),
217 Command::Context => write!(f, "context"),
218 Command::FormatPrompt(args) => {
219 write!(f, "format-prompt --provider={}", args.provider)
220 }
221 Command::Predict(args) => match &args.provider {
222 Some(provider) => write!(f, "predict --provider={}", provider),
223 None => write!(f, "predict"),
224 },
225 Command::ParseOutput => write!(f, "parse-output"),
226 Command::Score(args) => match &args.provider {
227 Some(provider) => write!(f, "score --provider={}", provider),
228 None => write!(f, "score"),
229 },
230 Command::Distill => write!(f, "distill"),
231 Command::Eval(args) => match &args.predict.provider {
232 Some(provider) => write!(f, "eval --provider={}", provider),
233 None => write!(f, "eval"),
234 },
235 Command::Synthesize(args) => {
236 write!(f, "synthesize --repos {}", args.repos.join(" "))
237 }
238 Command::Clean => write!(f, "clean"),
239 Command::SplitCommit(_) => write!(f, "split-commit"),
240 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
241 Command::Split(_) => write!(f, "split"),
242 Command::FilterLanguages(_) => write!(f, "filter-languages"),
243 Command::ImportBatch(args) => {
244 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
245 }
246 Command::Qa(_) => {
247 write!(f, "qa")
248 }
249 Command::Repair(_) => {
250 write!(f, "repair")
251 }
252 }
253 }
254}
255
// Arguments for `ep format-prompt`. (Plain `//` comments so clap help is unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 at its default version.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
261
// Arguments shared by `ep predict` and `ep score`, and embedded in `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when omitted, the behavior is decided by the predict
    // implementation (not visible in this file).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count per example — consumed by the predict/score code; TODO
    // confirm exact semantics in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
272
// Arguments for `ep eval`: a full predict/score configuration plus an optional
// JSON summary destination (written by `score::write_summary_json`).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
281
/// Backend model used by the `teacher` / `teacher-non-batching` prediction
/// providers. See [`TeacherBackend::model_name`] for the concrete API model ids.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude (`claude-sonnet-4-5`).
    Sonnet45,
    /// OpenAI GPT (`gpt-5.2`).
    Gpt52,
}
287
288impl std::fmt::Display for TeacherBackend {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 match self {
291 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
292 TeacherBackend::Gpt52 => write!(f, "gpt52"),
293 }
294 }
295}
296
297impl std::str::FromStr for TeacherBackend {
298 type Err = anyhow::Error;
299
300 fn from_str(s: &str) -> Result<Self, Self::Err> {
301 match s.to_lowercase().as_str() {
302 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
303 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
304 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
305 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
306 }
307 }
308}
309
310impl TeacherBackend {
311 pub fn model_name(&self) -> &'static str {
312 match self {
313 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
314 TeacherBackend::Gpt52 => "gpt-5.2",
315 }
316 }
317}
318
/// Prediction providers selectable via `--provider`. Parsed from strings of
/// the form `name[:arg]` (see the `FromStr` impl below) and displayed back in
/// the same syntax.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with an explicit prompt-version selection (`zeta2:<version>`).
    Zeta2(ZetaVersion),
    /// Teacher-model provider (see [`TeacherBackend`]).
    Teacher(TeacherBackend),
    /// Teacher-model provider that, per its name, avoids request batching;
    /// the actual request handling lives outside this file.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
329
330impl Default for PredictionProvider {
331 fn default() -> Self {
332 PredictionProvider::Zeta2(ZetaVersion::default())
333 }
334}
335
336impl std::fmt::Display for PredictionProvider {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 match self {
339 PredictionProvider::Sweep => write!(f, "sweep"),
340 PredictionProvider::Mercury => write!(f, "mercury"),
341 PredictionProvider::Zeta1 => write!(f, "zeta1"),
342 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
343 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
344 PredictionProvider::TeacherNonBatching(backend) => {
345 write!(f, "teacher-non-batching:{backend}")
346 }
347 PredictionProvider::Repair => write!(f, "repair"),
348 }
349 }
350}
351
352impl std::str::FromStr for PredictionProvider {
353 type Err = anyhow::Error;
354
355 fn from_str(s: &str) -> Result<Self, Self::Err> {
356 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
357
358 let provider_lower = provider.to_lowercase();
359 match provider_lower.as_str() {
360 "sweep" => Ok(PredictionProvider::Sweep),
361 "mercury" => Ok(PredictionProvider::Mercury),
362 "zeta1" => Ok(PredictionProvider::Zeta1),
363 "zeta2" => {
364 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
365 Ok(PredictionProvider::Zeta2(version))
366 }
367 "teacher" => {
368 let backend = arg
369 .map(|a| a.parse())
370 .transpose()?
371 .unwrap_or(TeacherBackend::Sonnet45);
372 Ok(PredictionProvider::Teacher(backend))
373 }
374 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
375 let backend = arg
376 .map(|a| a.parse())
377 .transpose()?
378 .unwrap_or(TeacherBackend::Sonnet45);
379 Ok(PredictionProvider::TeacherNonBatching(backend))
380 }
381 "repair" => Ok(PredictionProvider::Repair),
382 _ => {
383 anyhow::bail!(
384 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
385 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
386 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
387 Available zeta versions:\n{}",
388 ZetaVersion::options_as_string()
389 )
390 }
391 }
392 }
393}
394
395impl Serialize for PredictionProvider {
396 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
397 where
398 S: Serializer,
399 {
400 serializer.serialize_str(&self.to_string())
401 }
402}
403
404impl<'de> Deserialize<'de> for PredictionProvider {
405 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
406 where
407 D: Deserializer<'de>,
408 {
409 let s = String::deserialize(deserializer)?;
410 s.parse().map_err(serde::de::Error::custom)
411 }
412}
413
// Arguments for `ep synthesize` (mirrors the fields of `SynthesizeConfig`,
// which `main` builds from these values plus the `-o` output directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
432
// Arguments for `ep import-batch`, which re-imports already-submitted LLM
// batch results into the local cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
442
// Which LLM batch API `ep import-batch` talks to; selects between the
// Anthropic and OpenAI client implementations in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
448
449impl EpArgs {
450 fn output_path(&self) -> Option<PathBuf> {
451 if self.in_place {
452 if self.inputs.len() == 1 {
453 self.inputs.first().cloned()
454 } else {
455 panic!("--in-place requires exactly one input file")
456 }
457 } else {
458 self.output.clone()
459 }
460 }
461}
462
/// Builds the example list for a run.
///
/// Inputs are classified into local files and Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`
/// variants — see `INPUTS_HELP`). Files are read first; Snowflake sources are
/// then fetched up to the remaining `--limit` budget. The merged set is
/// sorted, filtered by `--name`/`--repo`, trimmed by resume state and
/// `--offset`/`--limit`, and registered with the global progress reporter.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Classify each input: Snowflake specifiers are matched first; anything
    // that doesn't parse as one is treated as a local example file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    // Budget left for Snowflake rows after counting locally-read examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default when pulling from Snowflake
        // without an explicit --limit (see the `EpArgs::limit` doc comment).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Sort so examples from the same repository/revision end up adjacent;
    // `main` later groups them per repo for processing.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    // NOTE(review): `splice(0..offset, [])` panics if offset > examples.len()
    // — TODO confirm --offset is expected to stay in range.
    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-register totals now that filtering/offset/limit have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
597
/// Content hash of an example spec, used by `resume_from_output` to match
/// examples between this run's input and a previously written output file.
/// FxHasher is unseeded, so hashes are comparable across separate runs of the
/// same build.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
603
/// Implements run resumption: scans the existing output file at `path`, keeps
/// the lines whose spec matches an example in this run's input AND that look
/// "complete" for the given command, rewrites the output file with only those
/// lines, and removes the corresponding examples from `examples` so they are
/// not reprocessed.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // A missing/unreadable output file just means there is nothing to resume.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable or non-JSON lines are silently dropped; the rewrite
        // below purges them from the file.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep only the first occurrence of each spec, and only specs
            // that are part of this run's input set.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Command-specific notion of "already processed": the output
                // must actually contain this command's result, not just the
                // example itself.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file to contain exactly the kept lines (this drops
    // duplicates, incomplete results, and lines not in this run's input).
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Finally, drop the already-processed examples from this run's workload.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
666
/// CLI entry point.
///
/// Parses `EpArgs`, handles the subcommands that need no example pipeline
/// synchronously (import-batch, clean, synthesize, split-commit,
/// truncate-patch, split, filter-languages), then runs the remaining commands
/// over the loaded examples inside a headless GPUI application — processing
/// examples per-repository with bounded parallelism and streaming results to
/// stdout, a JSONL file, or per-example markdown files.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't operate on the example pipeline are handled
    // synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Everything remaining runs inside a headless GPUI application so that
    // project loading, context retrieval, and prediction can use app state.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending LLM batches for the command's provider
                // before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast, even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // For --in-place, write to a temp file and rename at the end;
                    // otherwise append directly to the output path.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A dedicated background task owns the writer; workers send
                    // finished lines over a channel and each line is flushed.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue of per-repository example groups, shared by workers.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Worker pool: each task repeatedly claims one repo's group and
                // processes its examples sequentially.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the example to the selected command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands were all dispatched
                                        // synchronously and returned early above.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Write the example out unless it failed and
                                // --failed was set to skip failures.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output destination: emit JSONL to stdout
                                        // (except for eval, which prints a report below).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language servers
                            // and drop cached project state before moving on.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batches again so requests queued during processing are
                // submitted/collected.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary)
                // instead of per-example output.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1094
/// Records one failed example: bumps the failure counter, writes the failure
/// artifacts (example JSON, error text, a line in the run's failed.jsonl)
/// unless `--failed skip-no-files` was given, and then either panics (with
/// `--failfast`, or when the run had exactly one example) or logs the error
/// and lets the run continue.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    // SkipNoFiles suppresses all on-disk failure artifacts.
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Pretty-printed copy of the failing example, for re-running it.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Companion file with the full debug-formatted error.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Also append the example to this run's cumulative failed.jsonl.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Best-effort path to the file under the cursor, for the hint below;
        // falls back to the raw spec path if the repo name can't be resolved.
        let cursor_path = match example.repo_name() {
            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
        };
        // The \x1b[..m sequences are ANSI colors (red error, cyan paths).
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}