1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use anyhow::Context as _;
30use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
31use collections::{HashMap, HashSet};
32use edit_prediction::EditPredictionStore;
33use futures::channel::mpsc;
34use futures::{SinkExt as _, StreamExt as _};
35use gaoya::minhash::{
36 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
37};
38use gpui::{AppContext as _, BackgroundExecutor, Task};
39use zeta_prompt::ZetaFormat;
40
41use reqwest_client::ReqwestClient;
42use serde::{Deserialize, Deserializer, Serialize, Serializer};
43use std::env;
44use std::fmt::Display;
45use std::fs::{File, OpenOptions};
46use std::hash::{Hash, Hasher};
47use std::io::{BufRead, BufReader, BufWriter, Write};
48use std::sync::Mutex;
49use std::{path::PathBuf, sync::Arc};
50
51use crate::distill::run_distill;
52use crate::example::{Example, group_examples_by_repo, read_example_files};
53use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
54use crate::format_prompt::run_format_prompt;
55use crate::load_project::run_load_project;
56use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
57use crate::predict::run_prediction;
58use crate::progress::Progress;
59use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
60use crate::retrieve_context::run_context_retrieval;
61use crate::score::run_scoring;
62use crate::split_commit::SplitCommitArgs;
63use crate::split_dataset::SplitArgs;
64use crate::synthesize::{SynthesizeConfig, run_synthesize};
65use crate::truncate_expected_patch::TruncatePatchArgs;
66
// Top-level CLI arguments for the `ep` tool. Most flags are `global = true`, so
// they may be given before or after the subcommand. Plain comments (not `///`)
// are used below where we don't want the text to become clap help output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled first in `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of examples processed concurrently — TODO confirm usage in command impls.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, with the
    // remainder forwarded to Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    // Subcommand to run; when omitted, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; interacts with `--in-place` (see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file instead of a separate output.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example — confirm in command impls.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
109
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Clap's ValueEnum derive exposes these variants as kebab-case CLI values:
// `keep`, `skip`, `skip-no-files`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
122
/// Extended help text describing the special input specifiers accepted in place
/// of file paths. Rendered after the generated options via `after_help` on
/// `ReadArgs`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
    Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
    These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
187
// All `ep` subcommands. The `///` variant docs double as clap help text, and
// the `Display` impl below renders each variant's command-line spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
231
232impl Display for Command {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 match self {
235 Command::Read(_) => write!(f, "read"),
236 Command::LoadProject => write!(f, "load-project"),
237 Command::Context => write!(f, "context"),
238 Command::FormatPrompt(args) => {
239 write!(f, "format-prompt --provider={}", args.provider)
240 }
241 Command::Predict(args) => match &args.provider {
242 Some(provider) => write!(f, "predict --provider={}", provider),
243 None => write!(f, "predict"),
244 },
245 Command::ParseOutput => write!(f, "parse-output"),
246 Command::Score(args) => match &args.provider {
247 Some(provider) => write!(f, "score --provider={}", provider),
248 None => write!(f, "score"),
249 },
250 Command::Distill => write!(f, "distill"),
251 Command::Eval(args) => match &args.predict.provider {
252 Some(provider) => write!(f, "eval --provider={}", provider),
253 None => write!(f, "eval"),
254 },
255 Command::Synthesize(args) => {
256 write!(f, "synthesize --repos {}", args.repos.join(" "))
257 }
258 Command::Clean => write!(f, "clean"),
259 Command::SplitCommit(_) => write!(f, "split-commit"),
260 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
261 Command::Split(_) => write!(f, "split"),
262 Command::FilterLanguages(_) => write!(f, "filter-languages"),
263 Command::ImportBatch(args) => {
264 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
265 }
266 Command::Qa(_) => {
267 write!(f, "qa")
268 }
269 Command::Repair(_) => {
270 write!(f, "repair")
271 }
272 Command::PrintZetaFormats => {
273 write!(f, "print-zeta-formats")
274 }
275 }
276 }
277}
278
// `ep read` has no flags of its own; its inputs come from the global positional
// arguments. `INPUTS_HELP` documents the accepted input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
282
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format is generated; defaults to zeta2 with its default format.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
288
// Shared arguments for `ep predict` / `ep score` (and, flattened, `ep eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; optional — commands decide what an unset value means.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example — confirm in `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
302
// Arguments for `ep eval`; embeds the prediction flags via `flatten`.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction settings reused by eval (provider, repetitions, caching, batching).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
314
/// The LLM backend used by the `teacher*` prediction providers.
/// See [`TeacherBackend::model_name`] for the provider-side model identifiers.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// `claude-sonnet-4-6`
    Sonnet46,
    /// `claude-sonnet-4-5` — the default backend.
    #[default]
    Sonnet45,
    /// `gpt-5.2`
    Gpt52,
}
322
323impl std::fmt::Display for TeacherBackend {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 match self {
326 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
327 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
328 TeacherBackend::Gpt52 => write!(f, "gpt52"),
329 }
330 }
331}
332
333impl std::str::FromStr for TeacherBackend {
334 type Err = anyhow::Error;
335
336 fn from_str(s: &str) -> Result<Self, Self::Err> {
337 match s.to_lowercase().as_str() {
338 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
339 "sonnet46" => Ok(TeacherBackend::Sonnet46),
340 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
341 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
342 _ => anyhow::bail!(
343 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
344 ),
345 }
346 }
347}
348
349impl TeacherBackend {
350 pub fn model_name(&self) -> &'static str {
351 match self {
352 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
353 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
354 TeacherBackend::Gpt52 => "gpt-5.2",
355 }
356 }
357}
358
/// Which model/provider produces edit predictions. Parsed from strings like
/// `zeta2:<format>` or `teacher:<backend>` — see the `FromStr` impl below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Baseten-hosted endpoint using a zeta prompt format — TODO confirm hosting details.
    Baseten(ZetaFormat),
    /// Teacher model (batched requests).
    Teacher(TeacherBackend),
    /// Teacher model, batched, multi-region.
    TeacherMultiRegion(TeacherBackend),
    /// Teacher model without request batching.
    TeacherNonBatching(TeacherBackend),
    /// Teacher model, multi-region, without request batching.
    TeacherMultiRegionNonBatching(TeacherBackend),
    /// Tags predictions produced by the `repair` command (see `resume_from_output`).
    Repair,
}
371
372impl Default for PredictionProvider {
373 fn default() -> Self {
374 PredictionProvider::Zeta2(ZetaFormat::default())
375 }
376}
377
378impl std::fmt::Display for PredictionProvider {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 match self {
381 PredictionProvider::Mercury => write!(f, "mercury"),
382 PredictionProvider::Zeta1 => write!(f, "zeta1"),
383 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
384 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
385 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
386 PredictionProvider::TeacherMultiRegion(backend) => {
387 write!(f, "teacher-multi-region:{backend}")
388 }
389 PredictionProvider::TeacherNonBatching(backend) => {
390 write!(f, "teacher-non-batching:{backend}")
391 }
392 PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
393 write!(f, "teacher-multi-region-non-batching:{backend}")
394 }
395 PredictionProvider::Repair => write!(f, "repair"),
396 }
397 }
398}
399
400impl std::str::FromStr for PredictionProvider {
401 type Err = anyhow::Error;
402
403 fn from_str(s: &str) -> Result<Self, Self::Err> {
404 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
405
406 let provider_lower = provider.to_lowercase();
407 match provider_lower.as_str() {
408 "mercury" => Ok(PredictionProvider::Mercury),
409 "zeta1" => Ok(PredictionProvider::Zeta1),
410 "zeta2" => {
411 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
412 Ok(PredictionProvider::Zeta2(format))
413 }
414 "teacher" => {
415 let backend = arg
416 .map(|a| a.parse())
417 .transpose()?
418 .unwrap_or(TeacherBackend::default());
419 Ok(PredictionProvider::Teacher(backend))
420 }
421 "teacher-multi-region" | "teacher_multi_region" => {
422 let backend = arg
423 .map(|a| a.parse())
424 .transpose()?
425 .unwrap_or(TeacherBackend::default());
426 Ok(PredictionProvider::TeacherMultiRegion(backend))
427 }
428 "teacher-non-batching" | "teacher_non_batching" => {
429 let backend = arg
430 .map(|a| a.parse())
431 .transpose()?
432 .unwrap_or(TeacherBackend::default());
433 Ok(PredictionProvider::TeacherNonBatching(backend))
434 }
435 "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
436 let backend = arg
437 .map(|a| a.parse())
438 .transpose()?
439 .unwrap_or(TeacherBackend::default());
440 Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
441 }
442 "repair" => Ok(PredictionProvider::Repair),
443 "baseten" => {
444 let format = arg
445 .map(ZetaFormat::parse)
446 .transpose()?
447 .unwrap_or(ZetaFormat::default());
448 Ok(PredictionProvider::Baseten(format))
449 }
450 _ => {
451 anyhow::bail!(
452 "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
453 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
454 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
455 Available zeta versions:\n{}",
456 ZetaFormat::options_as_string()
457 )
458 }
459 }
460 }
461}
462
463impl Serialize for PredictionProvider {
464 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
465 where
466 S: Serializer,
467 {
468 serializer.serialize_str(&self.to_string())
469 }
470}
471
472impl<'de> Deserialize<'de> for PredictionProvider {
473 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
474 where
475 D: Deserializer<'de>,
476 {
477 let s = String::deserialize(deserializer)?;
478 s.parse().map_err(serde::de::Error::custom)
479 }
480}
481
// Arguments for `ep synthesize`; copied into a `SynthesizeConfig` in `main`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
500
// Arguments for `ep import-batch` (recovering batch results by id).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
510
// Which LLM batch API an `import-batch` invocation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    // Anthropic message batches (msgbatch_xxx ids).
    Anthropic,
    // OpenAI batches (batch_xxx ids).
    Openai,
}
516
#[cfg(test)]
mod tests {
    use super::*;

    // `FromStr` accepts both dash and underscore spellings, while `Display`
    // always renders the dashed primary spelling — so round-trips normalize.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscore alias parses to the same variant and renders dashed.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
550
551impl EpArgs {
552 fn output_path(&self) -> Option<PathBuf> {
553 if self.in_place {
554 if self.inputs.len() == 1 {
555 self.inputs.first().cloned()
556 } else {
557 panic!("--in-place requires exactly one input file")
558 }
559 } else {
560 self.output.clone()
561 }
562 }
563}
564
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): minor/patch presumably correspond to a Zed `0.minor.patch`
// version — confirm against `pull_examples::MinCaptureVersion`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
572
/// Removes duplicate examples in two passes: an exact pass that keeps the first
/// example per distinct `cursor_position`, then a near-duplicate pass that
/// clusters examples by minhash/LSH similarity of the `cursor_position` text
/// and keeps at most `max_per_cluster` diverse representatives per cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor position is identical to one already seen.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate detection via minhash LSH over token n-gram shingles.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // Bands/width are derived from the threshold; the effective hash count is
    // their product, which may differ slightly from NUM_HASHES.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One minhash signature per example, computed from its cursor-position shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union each example with every LSH candidate it collides with.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Keep at most `max_per_cluster` maximally spread-out members per cluster.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the highest index down so `swap_remove` never disturbs an index
    // still pending removal; the final `reverse` restores the original order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
653
654fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
655 if members.len() <= k {
656 return members.to_vec();
657 }
658
659 let mut selected = vec![members[0]];
660 let mut min_dist: HashMap<usize, f64> = HashMap::default();
661 for &member in &members[1..] {
662 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
663 min_dist.insert(member, dist);
664 }
665
666 while selected.len() < k {
667 let &best = min_dist
668 .iter()
669 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
670 .map(|(id, _)| id)
671 .expect("min_dist should not be empty when selected.len() < k");
672 selected.push(best);
673 min_dist.remove(&best);
674
675 let best_sig = &signatures[best];
676 for (member, current_min) in min_dist.iter_mut() {
677 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
678 if dist < *current_min {
679 *current_min = dist;
680 }
681 }
682 }
683
684 selected
685}
686
687fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
688 let tokens: Vec<&str> = word_diff::tokenize(code)
689 .into_iter()
690 .filter(|t| !t.trim().is_empty())
691 .collect();
692
693 if tokens.len() < ngram_size {
694 return vec![tokens.join("\0")];
695 }
696
697 tokens
698 .windows(ngram_size)
699 .map(|window| window.join("\0"))
700 .collect()
701}
702
/// Loads and prepares the example set for a run.
///
/// File inputs are read directly; inputs matching a Snowflake specifier
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `settled-after:`,
/// or one of the `rated-*after:` forms) are fetched remotely. The combined set
/// is then sorted by repo/rev and reduced by — in order — the name/repo
/// filters, resume-skipping against existing output, near-duplicate removal,
/// and the `--limit` cap.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into Snowflake specifiers (by kind) and plain file paths.
    // Anything that doesn't parse as a specifier is treated as a file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Early total for progress reporting; updated again after filtering below.
    Progress::global().set_total_examples(examples.len());

    // Snowflake fetches only need to fill whatever --limit leaves after files.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake;

        // Each specifier kind is fetched separately; timestamps are sorted first,
        // presumably for stable request ordering — confirm in pull_examples.
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: `http_client`/`background_executor` are moved, not cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters on example name / repository URL.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for the progress display, after all filtering.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
877
878fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
879 let mut hasher = collections::FxHasher::default();
880 spec.hash(&mut hasher);
881 hasher.finish()
882}
883
884fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
885 let file = match File::open(path) {
886 Ok(f) => f,
887 Err(_) => return,
888 };
889
890 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
891
892 let reader = BufReader::new(file);
893 let mut kept_lines = Vec::new();
894 let mut kept_hashes = HashSet::default();
895
896 for line in reader.lines() {
897 let line = match line {
898 Ok(l) => l,
899 Err(_) => continue,
900 };
901
902 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
903 let hash = spec_hash(&output_example.spec);
904 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
905 let is_complete = match command {
906 Command::Qa(_) => output_example
907 .qa
908 .first()
909 .and_then(|q| q.as_ref())
910 .and_then(|q| q.confidence)
911 .is_some(),
912 Command::Repair(_) => output_example.predictions.iter().any(|p| {
913 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
914 }),
915 _ => true,
916 };
917 if is_complete {
918 kept_hashes.insert(hash);
919 kept_lines.push(line);
920 }
921 }
922 }
923 }
924
925 let total = examples.len();
926 let already_processed = kept_hashes.len();
927
928 eprintln!(
929 "Resuming: {}/{} examples already processed",
930 already_processed, total
931 );
932
933 let file = OpenOptions::new()
934 .write(true)
935 .truncate(true)
936 .open(path)
937 .expect("Failed to open output file for rewriting");
938 let mut writer = BufWriter::new(file);
939 for line in &kept_lines {
940 writeln!(writer, "{}", line).expect("Failed to write to output file");
941 }
942 writer.flush().expect("Failed to flush output file");
943
944 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
945}
946
/// CLI entry point.
///
/// Dispatch happens in two phases:
/// 1. Commands that need no app state (batch import, cleanup, dataset
///    splitting/filtering, synthesis, format listing) are handled
///    synchronously in the first `match` and return early.
/// 2. Every other command falls through to a headless GPUI app that loads
///    the input examples, processes them with up to `--max-parallelism`
///    concurrent workers (one repo group at a time per worker), and writes
///    results as JSONL or markdown.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Phase 1: commands that do not need the headless app. Each arm
    // returns early; everything else falls through to phase 2 below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Without -o, default to a directory inside the repo checkout;
            // only auto-create it when its parent exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Phase 2: remaining commands run inside a headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down results of previously submitted provider batches
                // before processing. NOTE: `args` in these arms shadows the
                // outer CLI args with the subcommand's own args struct.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail fast so its error surfaces
                // immediately instead of being logged and skipped.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // With --in-place, write to a `.jsonl.tmp` sibling that is
                // atomically renamed over the original at the end of the run.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Background task drains the channel, writing and
                    // flushing one JSONL line per finished example.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn `max_parallelism` workers. Each worker repeatedly
                // claims the next repo group and processes its examples
                // sequentially, so parallelism is across repos, not within.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the selected command against this example.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // All of these returned early in phase 1.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // `handle_error` logs/records the failure and
                                // may panic when failing fast.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written through when
                                // --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except for
                                        // eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After a repo group finishes, shut down its language
                            // servers and detach the project from the prediction
                            // store before dropping per-example state.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // If --wait was given, block on outstanding provider batches,
                // reprocess the finished examples with the batch results, and
                // rewrite the output from scratch.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Per-command summary reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1421
1422fn rewrite_output(
1423 examples: &[Example],
1424 output_path: Option<&PathBuf>,
1425 markdown: bool,
1426) -> anyhow::Result<()> {
1427 if markdown {
1428 let dir = output_path.context("--markdown requires -o")?;
1429 for example in examples {
1430 let filename = format!("{}.md", example.spec.filename());
1431 let path = dir.join(&filename);
1432 let markdown = example.spec.to_markdown();
1433 std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1434 }
1435 } else if let Some(path) = output_path {
1436 let file = OpenOptions::new()
1437 .create(true)
1438 .write(true)
1439 .truncate(true)
1440 .open(path)
1441 .context("Failed to open output file for rewriting")?;
1442 let mut writer = BufWriter::new(file);
1443 for example in examples {
1444 let line = serde_json::to_string(example)?;
1445 writeln!(writer, "{}", line)?;
1446 }
1447 writer.flush()?;
1448 } else {
1449 for example in examples {
1450 let line = serde_json::to_string(example)?;
1451 println!("{}", line);
1452 }
1453 }
1454 Ok(())
1455}
1456
1457async fn handle_error(
1458 error: anyhow::Error,
1459 args: &EpArgs,
1460 command: &Command,
1461 app_state: &Arc<headless::EpAppState>,
1462 failfast_on_single_example: bool,
1463 example: &Example,
1464) {
1465 Progress::global().increment_failed();
1466
1467 let msg;
1468 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1469 let example_name = example.spec.filename();
1470
1471 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1472 app_state
1473 .fs
1474 .write(
1475 &failed_example_path,
1476 &serde_json::to_vec_pretty(&example).unwrap(),
1477 )
1478 .await
1479 .unwrap();
1480 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1481 app_state
1482 .fs
1483 .write(&err_path, format!("{error:?}").as_bytes())
1484 .await
1485 .unwrap();
1486
1487 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1488 let mut file = OpenOptions::new()
1489 .create(true)
1490 .append(true)
1491 .open(&failed_jsonl_path)
1492 .expect("Failed to open failed.jsonl");
1493 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1494 .expect("Failed to write to failed.jsonl");
1495
1496 let cursor_path = match example.repo_name() {
1497 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1498 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1499 };
1500 msg = format!(
1501 indoc::indoc! {"
1502 While processing \"{}\":
1503
1504 \x1b[31m{:?}\x1b[0m
1505
1506 Example: \x1b[36m{}\x1b[0m
1507 Error file: \x1b[36m{}\x1b[0m
1508 Cursor file: \x1b[36m{}\x1b[0m
1509 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1510 "},
1511 example.spec.name,
1512 error,
1513 failed_example_path.display(),
1514 err_path.display(),
1515 cursor_path.display(),
1516 command,
1517 failed_example_path.display(),
1518 );
1519 } else {
1520 msg = format!(
1521 indoc::indoc! {"
1522 While processing \"{}\":
1523
1524 \x1b[31m{:?}\x1b[0m
1525 "},
1526 example.spec.name, error
1527 );
1528 }
1529
1530 if args.failfast || failfast_on_single_example {
1531 Progress::global().finalize();
1532 panic!("{}", msg);
1533 } else {
1534 log::error!("{}", msg);
1535 }
1536}