1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::HashSet;
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gpui::{AppContext as _, BackgroundExecutor, Task};
35use zeta_prompt::ZetaFormat;
36
37use reqwest_client::ReqwestClient;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39use std::fmt::Display;
40use std::fs::{File, OpenOptions};
41use std::hash::{Hash, Hasher};
42use std::io::{BufRead, BufReader, BufWriter, Write};
43use std::sync::Mutex;
use std::path::{Path, PathBuf};
use std::sync::Arc;
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI definition for `ep`. Flags marked `global = true` apply to
// every subcommand; the subcommand itself is optional (no subcommand prints
// help and exits). `///` doc comments below double as clap help text, so the
// extra notes here use plain `//` comments to avoid changing CLI output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (debug aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Size of the worker pool: number of repo groups processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; consumed by file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `*-after:{timestamp}` Snowflake
    // specifiers (documented in INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output destination; interpreted as a directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; `output_path` panics if more
    // than one input is given with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the main write path in `main` treats `Skip` and `SkipNoFiles`
// identically (both drop failed examples from the primary output); the
// "no files" behavior is presumably implemented in the failure-handling path —
// confirm in `handle_error`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
// Long-form help text attached to the positional `inputs` argument of
// `EpArgs`. This is user-facing CLI output printed verbatim by clap — keep
// the formatting stable.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// Subcommands of `ep`. The `///` doc comments double as clap help text, so
// they are left untouched. Commands handled inline (without the headless GPUI
// app) in `main`: ImportBatch, Clean, PrintZetaFormats, Synthesize,
// SplitCommit, TruncatePatch, Split, FilterLanguages; all others run through
// the example-processing worker pool.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
214
215impl Display for Command {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self {
218 Command::Read => write!(f, "read"),
219 Command::LoadProject => write!(f, "load-project"),
220 Command::Context => write!(f, "context"),
221 Command::FormatPrompt(args) => {
222 write!(f, "format-prompt --provider={}", args.provider)
223 }
224 Command::Predict(args) => match &args.provider {
225 Some(provider) => write!(f, "predict --provider={}", provider),
226 None => write!(f, "predict"),
227 },
228 Command::ParseOutput => write!(f, "parse-output"),
229 Command::Score(args) => match &args.provider {
230 Some(provider) => write!(f, "score --provider={}", provider),
231 None => write!(f, "score"),
232 },
233 Command::Distill => write!(f, "distill"),
234 Command::Eval(args) => match &args.predict.provider {
235 Some(provider) => write!(f, "eval --provider={}", provider),
236 None => write!(f, "eval"),
237 },
238 Command::Synthesize(args) => {
239 write!(f, "synthesize --repos {}", args.repos.join(" "))
240 }
241 Command::Clean => write!(f, "clean"),
242 Command::SplitCommit(_) => write!(f, "split-commit"),
243 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
244 Command::Split(_) => write!(f, "split"),
245 Command::FilterLanguages(_) => write!(f, "filter-languages"),
246 Command::ImportBatch(args) => {
247 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
248 }
249 Command::Qa(_) => {
250 write!(f, "qa")
251 }
252 Command::Repair(_) => {
253 write!(f, "repair")
254 }
255 Command::PrintZetaFormats => {
256 write!(f, "print-zeta-formats")
257 }
258 }
259 }
260}
261
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default zeta format (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
267
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; `None` presumably falls back to a default chosen
    // by `run_prediction` — confirm there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to produce per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
278
// Arguments for `ep eval`: all of `PredictArgs` plus report output options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Eval accepts every prediction flag via clap flattening.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
287
/// LLM backends usable as the "teacher" prediction provider
/// (see `PredictionProvider::Teacher`). The API model identifier for each
/// variant is given by `model_name`.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
295
296impl std::fmt::Display for TeacherBackend {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 match self {
299 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
300 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
301 TeacherBackend::Gpt52 => write!(f, "gpt52"),
302 }
303 }
304}
305
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend name case-insensitively, accepting aliases:
    /// `sonnet`/`claude` → Sonnet45 and `gpt`/`openai` → Gpt52.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "sonnet46" => Ok(TeacherBackend::Sonnet46),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Legacy alias — presumably an old format identifier kept for
            // backward compatibility with existing data; note it is not
            // mentioned in the error message below. TODO confirm it is still
            // needed.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!(
                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
            ),
        }
    }
}
321
322impl TeacherBackend {
323 pub fn model_name(&self) -> &'static str {
324 match self {
325 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
326 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
327 TeacherBackend::Gpt52 => "gpt-5.2",
328 }
329 }
330}
331
/// Which system produces edit predictions. Parsed from / rendered to strings
/// of the form `name` or `name:arg` (see the `FromStr` and `Display` impls).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt/output format variant to use.
    Zeta2(ZetaFormat),
    // Teacher LLM backend; presumably issues requests through the batch API —
    // confirm in the predict implementation.
    Teacher(TeacherBackend),
    // Same backends as `Teacher`, but non-batching per the name — confirm in
    // the predict implementation.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
342
impl Default for PredictionProvider {
    // The default provider is zeta2 with the default zeta format.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
348
349impl std::fmt::Display for PredictionProvider {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 match self {
352 PredictionProvider::Sweep => write!(f, "sweep"),
353 PredictionProvider::Mercury => write!(f, "mercury"),
354 PredictionProvider::Zeta1 => write!(f, "zeta1"),
355 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
356 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
357 PredictionProvider::TeacherNonBatching(backend) => {
358 write!(f, "teacher-non-batching:{backend}")
359 }
360 PredictionProvider::Repair => write!(f, "repair"),
361 }
362 }
363}
364
365impl std::str::FromStr for PredictionProvider {
366 type Err = anyhow::Error;
367
368 fn from_str(s: &str) -> Result<Self, Self::Err> {
369 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
370
371 let provider_lower = provider.to_lowercase();
372 match provider_lower.as_str() {
373 "sweep" => Ok(PredictionProvider::Sweep),
374 "mercury" => Ok(PredictionProvider::Mercury),
375 "zeta1" => Ok(PredictionProvider::Zeta1),
376 "zeta2" => {
377 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
378 Ok(PredictionProvider::Zeta2(format))
379 }
380 "teacher" => {
381 let backend = arg
382 .map(|a| a.parse())
383 .transpose()?
384 .unwrap_or(TeacherBackend::default());
385 Ok(PredictionProvider::Teacher(backend))
386 }
387 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
388 let backend = arg
389 .map(|a| a.parse())
390 .transpose()?
391 .unwrap_or(TeacherBackend::default());
392 Ok(PredictionProvider::TeacherNonBatching(backend))
393 }
394 "repair" => Ok(PredictionProvider::Repair),
395 _ => {
396 anyhow::bail!(
397 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
398 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
399 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
400 Available zeta versions:\n{}",
401 ZetaFormat::options_as_string()
402 )
403 }
404 }
405 }
406}
407
408impl Serialize for PredictionProvider {
409 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
410 where
411 S: Serializer,
412 {
413 serializer.serialize_str(&self.to_string())
414 }
415}
416
impl<'de> Deserialize<'de> for PredictionProvider {
    // Deserializes from the string form written by `Serialize`, delegating to
    // the `FromStr` impl for parsing and error reporting.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
426
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in
// `main` (output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
445
// Arguments for `ep import-batch` (handled inline in `main` without the
// headless app).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
455
// Which LLM batch API the imported batch IDs belong to (see `ImportBatchArgs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
461
462impl EpArgs {
463 fn output_path(&self) -> Option<PathBuf> {
464 if self.in_place {
465 if self.inputs.len() == 1 {
466 self.inputs.first().cloned()
467 } else {
468 panic!("--in-place requires exactly one input file")
469 }
470 } else {
471 self.output.clone()
472 }
473 }
474}
475
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): presumably encodes Zed release 0.224.1 (minor/patch only) —
// confirm the field semantics against `pull_examples::MinCaptureVersion`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
483
/// Resolves the CLI inputs into the list of `Example`s to process.
///
/// Inputs are partitioned into local example files and Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`).
/// File examples are loaded first and consume `--offset`; any remaining
/// offset and limit are forwarded to the Snowflake fetches. Afterwards the
/// `--name` / `--repo` filters, the resume-from-output logic (skipped for
/// `--in-place`), and the final `--limit` truncation are applied, and global
/// progress totals are updated.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Classify each input as one of the Snowflake specifiers; anything that
    // doesn't parse as a specifier is treated as a local file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total so progress can render early; the real total is set
    // again below once Snowflake results and filters have been applied.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default when pulling from Snowflake
        // (see the `--limit` help text on `EpArgs`).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `remaining_offset` is passed to each source kind
        // independently, so when multiple source kinds are mixed the offset
        // is applied to each of them — confirm this is intended.
        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals, now that the example set is settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
636
/// Identity hash of an example spec, used to match examples between the
/// current input set and a previous run's output (see `resume_from_output`).
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
642
643fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
644 let file = match File::open(path) {
645 Ok(f) => f,
646 Err(_) => return,
647 };
648
649 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
650
651 let reader = BufReader::new(file);
652 let mut kept_lines = Vec::new();
653 let mut kept_hashes = HashSet::default();
654
655 for line in reader.lines() {
656 let line = match line {
657 Ok(l) => l,
658 Err(_) => continue,
659 };
660
661 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
662 let hash = spec_hash(&output_example.spec);
663 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
664 let is_complete = match command {
665 Command::Qa(_) => output_example
666 .qa
667 .first()
668 .and_then(|q| q.as_ref())
669 .and_then(|q| q.confidence)
670 .is_some(),
671 Command::Repair(_) => output_example.predictions.iter().any(|p| {
672 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
673 }),
674 _ => true,
675 };
676 if is_complete {
677 kept_hashes.insert(hash);
678 kept_lines.push(line);
679 }
680 }
681 }
682 }
683
684 let total = examples.len();
685 let already_processed = kept_hashes.len();
686
687 eprintln!(
688 "Resuming: {}/{} examples already processed",
689 already_processed, total
690 );
691
692 let file = OpenOptions::new()
693 .write(true)
694 .truncate(true)
695 .open(path)
696 .expect("Failed to open output file for rewriting");
697 let mut writer = BufWriter::new(file);
698 for line in &kept_lines {
699 writeln!(writer, "{}", line).expect("Failed to write to output file");
700 }
701 writer.flush().expect("Failed to flush output file");
702
703 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
704}
705
/// CLI entry point. Parses arguments, handles subcommands that need no GPUI
/// app inline (import-batch, clean, print-zeta-formats, synthesize,
/// split-commit, truncate-patch, split, filter-languages), then runs the
/// remaining example-pipeline commands inside a headless GPUI app with a
/// bounded worker pool (`--max-parallelism` concurrent repo groups).
fn main() {
    let args = EpArgs::parse();

    // Debug aid: dump the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless GPUI app are handled here and
    // return early; everything else falls through to the pipeline below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all repositories/worktrees under the tool's data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Synthesize writes one file per example, so -o is a directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Remaining commands process examples inside a headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending batched LLM requests before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background writer task consumes lines from a
                    // channel so workers never contend on the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are processed per repo group — presumably so each
                // repo's worktree/project setup is reused across its examples;
                // confirm in `group_examples_by_repo`.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Bounded worker pool: each task repeatedly claims the next
                // repo group until the queue is drained.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the
                                // current subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These were all handled (with an
                                        // early return) before the app
                                        // started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are written only with
                                // --failed=keep; note Skip and SkipNoFiles
                                // behave the same at this point.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout
                                        // (except eval, which prints a report
                                        // at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Shut down this repo group's language servers and
                            // detach the project before moving on.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Release per-example project state before
                            // archiving the results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batched requests again now that all examples have been
                // processed.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Per-command summary reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1143
1144async fn handle_error(
1145 error: anyhow::Error,
1146 args: &EpArgs,
1147 command: &Command,
1148 app_state: &Arc<headless::EpAppState>,
1149 failfast_on_single_example: bool,
1150 example: &Example,
1151) {
1152 Progress::global().increment_failed();
1153
1154 let msg;
1155 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1156 let example_name = example.spec.filename();
1157
1158 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1159 app_state
1160 .fs
1161 .write(
1162 &failed_example_path,
1163 &serde_json::to_vec_pretty(&example).unwrap(),
1164 )
1165 .await
1166 .unwrap();
1167 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1168 app_state
1169 .fs
1170 .write(&err_path, format!("{error:?}").as_bytes())
1171 .await
1172 .unwrap();
1173
1174 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1175 let mut file = OpenOptions::new()
1176 .create(true)
1177 .append(true)
1178 .open(&failed_jsonl_path)
1179 .expect("Failed to open failed.jsonl");
1180 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1181 .expect("Failed to write to failed.jsonl");
1182
1183 let cursor_path = match example.repo_name() {
1184 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1185 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1186 };
1187 msg = format!(
1188 indoc::indoc! {"
1189 While processing \"{}\":
1190
1191 \x1b[31m{:?}\x1b[0m
1192
1193 Example: \x1b[36m{}\x1b[0m
1194 Error file: \x1b[36m{}\x1b[0m
1195 Cursor file: \x1b[36m{}\x1b[0m
1196 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1197 "},
1198 example.spec.name,
1199 error,
1200 failed_example_path.display(),
1201 err_path.display(),
1202 cursor_path.display(),
1203 command,
1204 failed_example_path.display(),
1205 );
1206 } else {
1207 msg = format!(
1208 indoc::indoc! {"
1209 While processing \"{}\":
1210
1211 \x1b[31m{:?}\x1b[0m
1212 "},
1213 example.spec.name, error
1214 );
1215 }
1216
1217 if args.failfast || failfast_on_single_example {
1218 Progress::global().finalize();
1219 panic!("{}", msg);
1220 } else {
1221 log::error!("{}", msg);
1222 }
1223}