main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod prompt_assets;
  16mod pull_examples;
  17mod qa;
  18mod reorder_patch;
  19mod repair;
  20mod retrieve_context;
  21mod reversal_tracking;
  22mod score;
  23mod split_commit;
  24mod split_dataset;
  25
  26mod synthesize;
  27mod truncate_expected_patch;
  28mod word_diff;
  29use anyhow::Context as _;
  30use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  31use collections::{HashMap, HashSet};
  32use edit_prediction::EditPredictionStore;
  33use futures::channel::mpsc;
  34use futures::{SinkExt as _, StreamExt as _};
  35use gaoya::minhash::{
  36    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
  37};
  38use gpui::{AppContext as _, BackgroundExecutor, Task};
  39use zeta_prompt::ZetaFormat;
  40
  41use reqwest_client::ReqwestClient;
  42use serde::{Deserialize, Deserializer, Serialize, Serializer};
  43use std::env;
  44use std::fmt::Display;
  45use std::fs::{File, OpenOptions};
  46use std::hash::{Hash, Hasher};
  47use std::io::{BufRead, BufReader, BufWriter, Write};
  48use std::sync::Mutex;
  49use std::{path::PathBuf, sync::Arc};
  50
  51use crate::distill::run_distill;
  52use crate::example::{Example, group_examples_by_repo, read_example_files};
  53use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  54use crate::format_prompt::run_format_prompt;
  55use crate::load_project::run_load_project;
  56use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  57use crate::predict::run_prediction;
  58use crate::progress::Progress;
  59use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
  60use crate::retrieve_context::run_context_retrieval;
  61use crate::score::run_scoring;
  62use crate::split_commit::SplitCommitArgs;
  63use crate::split_dataset::SplitArgs;
  64use crate::synthesize::{SynthesizeConfig, run_synthesize};
  65use crate::truncate_expected_patch::TruncatePatchArgs;
  66
// Top-level CLI for the `ep` tool. Most flags are `global = true`, so they can
// appear before or after the subcommand. NOTE(review): clap turns `///` doc
// comments into --help text, so reviewer annotations here deliberately use
// plain `//` comments to leave the help output unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print environment info instead of running — presumably handled early in
    // main(); confirm against the dispatch code.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrently processed examples.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; applied to file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    // Subcommand to run; optional, so bare `ep` can fall back to help/printenv.
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; interacts with --in-place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file; output_path() panics if
    // more than one input is given together with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Presumably stops at the first failed example — confirm against usage.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
 109
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: the `///` variant docs below double as clap --help text (ValueEnum).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 122
// Extended help appended to `ep read`'s --help output via ReadArgs'
// `after_help`. Raw string: keep the text and indentation exactly as shown,
// since it is printed verbatim to users.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
      These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 187
// Subcommands of `ep`. The `///` doc comments below become the per-command
// --help descriptions (clap Subcommand derive), so they must stay user-facing.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 231
 232impl Display for Command {
 233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 234        match self {
 235            Command::Read(_) => write!(f, "read"),
 236            Command::LoadProject => write!(f, "load-project"),
 237            Command::Context => write!(f, "context"),
 238            Command::FormatPrompt(args) => {
 239                write!(f, "format-prompt --provider={}", args.provider)
 240            }
 241            Command::Predict(args) => match &args.provider {
 242                Some(provider) => write!(f, "predict --provider={}", provider),
 243                None => write!(f, "predict"),
 244            },
 245            Command::ParseOutput => write!(f, "parse-output"),
 246            Command::Score(args) => match &args.provider {
 247                Some(provider) => write!(f, "score --provider={}", provider),
 248                None => write!(f, "score"),
 249            },
 250            Command::Distill => write!(f, "distill"),
 251            Command::Eval(args) => match &args.predict.provider {
 252                Some(provider) => write!(f, "eval --provider={}", provider),
 253                None => write!(f, "eval"),
 254            },
 255            Command::Synthesize(args) => {
 256                write!(f, "synthesize --repos {}", args.repos.join(" "))
 257            }
 258            Command::Clean => write!(f, "clean"),
 259            Command::SplitCommit(_) => write!(f, "split-commit"),
 260            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 261            Command::Split(_) => write!(f, "split"),
 262            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 263            Command::ImportBatch(args) => {
 264                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 265            }
 266            Command::Qa(_) => {
 267                write!(f, "qa")
 268            }
 269            Command::Repair(_) => {
 270                write!(f, "repair")
 271            }
 272            Command::PrintZetaFormats => {
 273                write!(f, "print-zeta-formats")
 274            }
 275        }
 276    }
 277}
 278
// `ep read` takes no flags of its own: inputs come from the global EpArgs
// positionals, and INPUTS_HELP documents the accepted input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
 282
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to produce; defaults to zeta2 with the
    // default ZetaFormat (see `impl Default for PredictionProvider`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 288
// Shared flags for prediction-running commands (`predict`, `score`, `eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to use; optional — when omitted, downstream code decides
    // (confirm against run_prediction / run_scoring).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably how many predictions to run per example — confirm in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
 302
// Arguments for `ep eval`: prediction flags (flattened) plus score-reporting
// options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All of PredictArgs' flags are exposed directly on `eval` via flatten.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
 314
/// Teacher model used by the `teacher*` prediction providers.
/// String spellings are defined by the `Display`/`FromStr` impls below, and
/// [`TeacherBackend::model_name`] maps each variant to the upstream model id.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
 322
 323impl std::fmt::Display for TeacherBackend {
 324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 325        match self {
 326            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 327            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 328            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 329        }
 330    }
 331}
 332
 333impl std::str::FromStr for TeacherBackend {
 334    type Err = anyhow::Error;
 335
 336    fn from_str(s: &str) -> Result<Self, Self::Err> {
 337        match s.to_lowercase().as_str() {
 338            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 339            "sonnet46" => Ok(TeacherBackend::Sonnet46),
 340            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 341            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 342            _ => anyhow::bail!(
 343                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
 344            ),
 345        }
 346    }
 347}
 348
 349impl TeacherBackend {
 350    pub fn model_name(&self) -> &'static str {
 351        match self {
 352            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 353            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 354            TeacherBackend::Gpt52 => "gpt-5.2",
 355        }
 356    }
 357}
 358
/// Backend used to generate edit predictions. Spelled as
/// `provider[:argument]` in string form — see the `FromStr` impl for parsing
/// and `Display` for the canonical rendering; serde round-trips via those.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Baseten provider with a zeta prompt format.
    Baseten(ZetaFormat),
    /// Teacher model; batched (see PredictArgs' `--wait` note).
    Teacher(TeacherBackend),
    TeacherMultiRegion(TeacherBackend),
    /// Teacher model issuing direct (non-batched) requests.
    TeacherNonBatching(TeacherBackend),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
 372
 373impl Default for PredictionProvider {
 374    fn default() -> Self {
 375        PredictionProvider::Zeta2(ZetaFormat::default())
 376    }
 377}
 378
 379impl std::fmt::Display for PredictionProvider {
 380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 381        match self {
 382            PredictionProvider::Sweep => write!(f, "sweep"),
 383            PredictionProvider::Mercury => write!(f, "mercury"),
 384            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 385            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 386            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
 387            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 388            PredictionProvider::TeacherMultiRegion(backend) => {
 389                write!(f, "teacher-multi-region:{backend}")
 390            }
 391            PredictionProvider::TeacherNonBatching(backend) => {
 392                write!(f, "teacher-non-batching:{backend}")
 393            }
 394            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
 395                write!(f, "teacher-multi-region-non-batching:{backend}")
 396            }
 397            PredictionProvider::Repair => write!(f, "repair"),
 398        }
 399    }
 400}
 401
 402impl std::str::FromStr for PredictionProvider {
 403    type Err = anyhow::Error;
 404
 405    fn from_str(s: &str) -> Result<Self, Self::Err> {
 406        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 407
 408        let provider_lower = provider.to_lowercase();
 409        match provider_lower.as_str() {
 410            "sweep" => Ok(PredictionProvider::Sweep),
 411            "mercury" => Ok(PredictionProvider::Mercury),
 412            "zeta1" => Ok(PredictionProvider::Zeta1),
 413            "zeta2" => {
 414                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 415                Ok(PredictionProvider::Zeta2(format))
 416            }
 417            "teacher" => {
 418                let backend = arg
 419                    .map(|a| a.parse())
 420                    .transpose()?
 421                    .unwrap_or(TeacherBackend::default());
 422                Ok(PredictionProvider::Teacher(backend))
 423            }
 424            "teacher-multi-region" | "teacher_multi_region" => {
 425                let backend = arg
 426                    .map(|a| a.parse())
 427                    .transpose()?
 428                    .unwrap_or(TeacherBackend::default());
 429                Ok(PredictionProvider::TeacherMultiRegion(backend))
 430            }
 431            "teacher-non-batching" | "teacher_non_batching" => {
 432                let backend = arg
 433                    .map(|a| a.parse())
 434                    .transpose()?
 435                    .unwrap_or(TeacherBackend::default());
 436                Ok(PredictionProvider::TeacherNonBatching(backend))
 437            }
 438            "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
 439                let backend = arg
 440                    .map(|a| a.parse())
 441                    .transpose()?
 442                    .unwrap_or(TeacherBackend::default());
 443                Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
 444            }
 445            "repair" => Ok(PredictionProvider::Repair),
 446            "baseten" => {
 447                let format = arg
 448                    .map(ZetaFormat::parse)
 449                    .transpose()?
 450                    .unwrap_or(ZetaFormat::default());
 451                Ok(PredictionProvider::Baseten(format))
 452            }
 453            _ => {
 454                anyhow::bail!(
 455                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
 456                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 457                 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
 458                 Available zeta versions:\n{}",
 459                    ZetaFormat::options_as_string()
 460                )
 461            }
 462        }
 463    }
 464}
 465
 466impl Serialize for PredictionProvider {
 467    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 468    where
 469        S: Serializer,
 470    {
 471        serializer.serialize_str(&self.to_string())
 472    }
 473}
 474
 475impl<'de> Deserialize<'de> for PredictionProvider {
 476    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 477    where
 478        D: Deserializer<'de>,
 479    {
 480        let s = String::deserialize(deserializer)?;
 481        s.parse().map_err(serde::de::Error::custom)
 482    }
 483}
 484
// Arguments for `ep synthesize` (see synthesize::run_synthesize).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 503
// Arguments for `ep import-batch`: recovers batch results directly from the
// provider's batch API by id.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 513
// Batch API to import from (`ep import-batch --provider`). As a clap
// ValueEnum, the variant names become the accepted CLI values
// `anthropic` / `openai`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 519
#[cfg(test)]
mod tests {
    use super::*;

    // Round trip: the canonical hyphenated spelling parses to the expected
    // variant, and Display reproduces the identical string.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscore alias parses to the same variant, but Display always
    // renders the canonical hyphenated form.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
 553
 554impl EpArgs {
 555    fn output_path(&self) -> Option<PathBuf> {
 556        if self.in_place {
 557            if self.inputs.len() == 1 {
 558                self.inputs.first().cloned()
 559            } else {
 560                panic!("--in-place requires exactly one input file")
 561            }
 562        } else {
 563            self.output.clone()
 564        }
 565    }
 566}
 567
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Passed as `Some(MIN_CAPTURE_VERSION)` to every pull_examples fetch in
// load_examples; only minor/patch components are modeled by MinCaptureVersion.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
 575
/// Deduplicates `examples` in place in two passes:
/// 1. exact — drop any example whose `spec.cursor_position` was already seen;
/// 2. near-duplicate — cluster the survivors by MinHash/LSH similarity of
///    token n-grams of the cursor-position text, then keep at most
///    `max_per_cluster` diverse representatives per cluster.
/// The relative order of retained examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: exact duplicates. `HashSet::insert` returns false on repeats,
    // so `retain` keeps only the first example at each cursor position.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicates via MinHash LSH.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // gaoya chooses (bands, band_width) for the threshold; the effective hash
    // count is their product, which may differ slightly from NUM_HASHES.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One signature per example, computed over NUL-joined token n-grams of
    // the cursor-position text (see code_token_ngrams).
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving: each probed node is re-pointed at
    // its grandparent, flattening the tree as we walk up.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union each example with every LSH candidate the index returns for it.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep at most `max_per_cluster` mutually-diverse members.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from highest index to lowest so `swap_remove` never disturbs a
    // still-pending lower index; reversing afterwards restores original order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
 656
 657fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
 658    if members.len() <= k {
 659        return members.to_vec();
 660    }
 661
 662    let mut selected = vec![members[0]];
 663    let mut min_dist: HashMap<usize, f64> = HashMap::default();
 664    for &member in &members[1..] {
 665        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
 666        min_dist.insert(member, dist);
 667    }
 668
 669    while selected.len() < k {
 670        let &best = min_dist
 671            .iter()
 672            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
 673            .map(|(id, _)| id)
 674            .expect("min_dist should not be empty when selected.len() < k");
 675        selected.push(best);
 676        min_dist.remove(&best);
 677
 678        let best_sig = &signatures[best];
 679        for (member, current_min) in min_dist.iter_mut() {
 680            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
 681            if dist < *current_min {
 682                *current_min = dist;
 683            }
 684        }
 685    }
 686
 687    selected
 688}
 689
 690fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
 691    let tokens: Vec<&str> = word_diff::tokenize(code)
 692        .into_iter()
 693        .filter(|t| !t.trim().is_empty())
 694        .collect();
 695
 696    if tokens.len() < ngram_size {
 697        return vec![tokens.join("\0")];
 698    }
 699
 700    tokens
 701        .windows(ngram_size)
 702        .map(|window| window.join("\0"))
 703        .collect()
 704}
 705
 706async fn load_examples(
 707    http_client: Arc<dyn http_client::HttpClient>,
 708    args: &EpArgs,
 709    output_path: Option<&PathBuf>,
 710    background_executor: BackgroundExecutor,
 711) -> anyhow::Result<Vec<Example>> {
 712    let mut captured_after_timestamps = Vec::new();
 713    let mut rejected_after_timestamps = Vec::new();
 714    let mut requested_after_timestamps = Vec::new();
 715    let mut settled_after_timestamps = Vec::new();
 716    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 717        Vec::new();
 718    let mut file_inputs = Vec::new();
 719
 720    for input in &args.inputs {
 721        let input_string = input.to_string_lossy();
 722        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 723            captured_after_timestamps.push(timestamp.to_string());
 724        } else if let Some(timestamp) =
 725            pull_examples::parse_rejected_after_input(input_string.as_ref())
 726        {
 727            rejected_after_timestamps.push(timestamp.to_string());
 728        } else if let Some(timestamp) =
 729            pull_examples::parse_requested_after_input(input_string.as_ref())
 730        {
 731            requested_after_timestamps.push(timestamp.to_string());
 732        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
 733            settled_after_timestamps.push(timestamp.to_string());
 734        } else if let Some((timestamp, rating_filter)) =
 735            pull_examples::parse_rated_after_input(input_string.as_ref())
 736        {
 737            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 738        } else {
 739            file_inputs.push(input.clone());
 740        }
 741    }
 742
 743    let mut examples = read_example_files(&file_inputs);
 744
 745    // Apply offset to file examples first, then pass remaining offset to Snowflake.
 746    let file_example_count = examples.len();
 747    let remaining_offset = if let Some(offset) = args.offset {
 748        if offset >= file_example_count {
 749            examples.clear();
 750            offset - file_example_count
 751        } else {
 752            examples.splice(0..offset, []);
 753            0
 754        }
 755    } else {
 756        0
 757    };
 758
 759    Progress::global().set_total_examples(examples.len());
 760
 761    let remaining_limit_for_snowflake =
 762        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 763
 764    if let Some(0) = remaining_limit_for_snowflake {
 765        log::info!(
 766            "skipping Snowflake inputs because --limit is already satisfied by example files"
 767        );
 768    } else {
 769        let max_rows_per_timestamp = remaining_limit_for_snowflake;
 770
 771        if !rejected_after_timestamps.is_empty() {
 772            rejected_after_timestamps.sort();
 773
 774            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 775                http_client.clone(),
 776                &rejected_after_timestamps,
 777                max_rows_per_timestamp,
 778                remaining_offset,
 779                background_executor.clone(),
 780                Some(MIN_CAPTURE_VERSION),
 781            )
 782            .await?;
 783            examples.append(&mut rejected_examples);
 784        }
 785
 786        if !requested_after_timestamps.is_empty() {
 787            requested_after_timestamps.sort();
 788
 789            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 790                http_client.clone(),
 791                &requested_after_timestamps,
 792                max_rows_per_timestamp,
 793                remaining_offset,
 794                background_executor.clone(),
 795                Some(MIN_CAPTURE_VERSION),
 796            )
 797            .await?;
 798            examples.append(&mut requested_examples);
 799        }
 800
 801        if !captured_after_timestamps.is_empty() {
 802            captured_after_timestamps.sort();
 803
 804            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 805                http_client.clone(),
 806                &captured_after_timestamps,
 807                max_rows_per_timestamp,
 808                remaining_offset,
 809                background_executor.clone(),
 810                Some(MIN_CAPTURE_VERSION),
 811            )
 812            .await?;
 813            examples.append(&mut captured_examples);
 814        }
 815
 816        if !settled_after_timestamps.is_empty() {
 817            settled_after_timestamps.sort();
 818
 819            let mut settled_examples = fetch_settled_examples_after(
 820                http_client.clone(),
 821                &settled_after_timestamps,
 822                max_rows_per_timestamp,
 823                remaining_offset,
 824                background_executor.clone(),
 825                Some(MIN_CAPTURE_VERSION),
 826            )
 827            .await?;
 828            examples.append(&mut settled_examples);
 829        }
 830
 831        if !rated_after_inputs.is_empty() {
 832            rated_after_inputs.sort();
 833
 834            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 835                http_client,
 836                &rated_after_inputs,
 837                max_rows_per_timestamp,
 838                remaining_offset,
 839                background_executor,
 840                Some(MIN_CAPTURE_VERSION),
 841            )
 842            .await?;
 843            examples.append(&mut rated_examples);
 844        }
 845    }
 846
 847    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 848
 849    if let Some(name_filter) = &args.name {
 850        examples.retain(|example| example.spec.name.contains(name_filter));
 851    }
 852    if let Some(repo_filter) = &args.repo {
 853        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 854    }
 855
 856    // Skip resume logic for --in-place since input and output are the same file,
 857    // which would incorrectly treat all input examples as already processed.
 858    if !args.in_place {
 859        if let Some(path) = output_path
 860            && let Some(command) = &args.command
 861        {
 862            resume_from_output(path, &mut examples, command);
 863        }
 864    }
 865
 866    if let Some(max_duplicates) = args.max_duplicates {
 867        deduplicate_examples(&mut examples, max_duplicates);
 868    }
 869
 870    if let Some(limit) = args.limit {
 871        examples.truncate(limit);
 872    }
 873
 874    let progress = Progress::global();
 875    progress.set_total_examples(examples.len());
 876    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 877
 878    Ok(examples)
 879}
 880
 881fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 882    let mut hasher = collections::FxHasher::default();
 883    spec.hash(&mut hasher);
 884    hasher.finish()
 885}
 886
/// Best-effort resume support: scans a previous output file at `path`, keeps
/// the lines for examples that are both present in the current input set and
/// "complete" for the given command, rewrites the file in place so it contains
/// exactly those kept lines, and removes the corresponding examples from
/// `examples` so they are not processed again.
///
/// If the output file cannot be opened (e.g. it does not exist yet), this is
/// a no-op and the full input set remains queued.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No readable previous output means there is nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    // Fingerprints of every example we were asked to process; output lines
    // whose spec is not in this set belong to a different run and are dropped.
    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Skip unreadable lines rather than aborting the whole resume scan.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        // Lines that fail to parse as an `Example` are silently dropped from
        // the rewritten file.
        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep only the first occurrence of each spec (dedup), and only
            // if it matches one of the current inputs.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // An output line only counts as "already processed" if the
                // previous run actually finished the work this command does.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file so it contains exactly the kept lines,
    // discarding stale, incomplete, duplicate, or unparseable entries.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Drop already-processed examples from the work queue.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
 949
fn main() {
    let args = EpArgs::parse();

    // --printenv bypasses everything else: dump the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // --markdown writes one file per example, so it needs an output directory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless GPUI app or the example-loading
    // pipeline are handled synchronously here; each arm returns.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipe the entire on-disk data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Default output directory lives under the crate if -o is omitted;
            // it is only created when its parent already exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands fall through to the example-processing pipeline.
        _ => {}
    }

    // The remaining commands run inside a headless GPUI app so they can use
    // projects, language servers, and the edit-prediction store.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                // Load examples from files and/or remote sources (see
                // `load_examples`), applying filters, resume, dedup, limits.
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any previously submitted provider batches before
                // processing, so cached batch results are available.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, any failure aborts the run.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                // Lines are sent over a channel to a background task so the
                // worker tasks never block on file I/O.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // For --in-place, write to a temp file that is atomically
                    // renamed over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // In-place runs truncate the temp file; normal runs append
                    // (resume may have left completed lines in the file).
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repo so each repo's examples
                // are processed sequentially by one worker.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly pops
                // a repo group from the shared queue until it is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the
                                // selected command.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands were fully handled
                                        // (and returned) before the pipeline.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures (and possibly panic, per
                                // --failfast) via handle_error.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written through when
                                // --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output destination: emit JSONL to
                                        // stdout (except for eval, which prints
                                        // a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After finishing a repo group, shut down its
                            // language servers and release the project from
                            // the edit-prediction store.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example runtime state before retaining
                            // the examples for final reporting.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // If --wait was given, wait for provider batches to complete,
                // reprocess the finished examples with the batch results, and
                // rewrite the output to reflect them.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Final reports for commands that summarize results.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline finishes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1424
1425fn rewrite_output(
1426    examples: &[Example],
1427    output_path: Option<&PathBuf>,
1428    markdown: bool,
1429) -> anyhow::Result<()> {
1430    if markdown {
1431        let dir = output_path.context("--markdown requires -o")?;
1432        for example in examples {
1433            let filename = format!("{}.md", example.spec.filename());
1434            let path = dir.join(&filename);
1435            let markdown = example.spec.to_markdown();
1436            std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1437        }
1438    } else if let Some(path) = output_path {
1439        let file = OpenOptions::new()
1440            .create(true)
1441            .write(true)
1442            .truncate(true)
1443            .open(path)
1444            .context("Failed to open output file for rewriting")?;
1445        let mut writer = BufWriter::new(file);
1446        for example in examples {
1447            let line = serde_json::to_string(example)?;
1448            writeln!(writer, "{}", line)?;
1449        }
1450        writer.flush()?;
1451    } else {
1452        for example in examples {
1453            let line = serde_json::to_string(example)?;
1454            println!("{}", line);
1455        }
1456    }
1457    Ok(())
1458}
1459
1460async fn handle_error(
1461    error: anyhow::Error,
1462    args: &EpArgs,
1463    command: &Command,
1464    app_state: &Arc<headless::EpAppState>,
1465    failfast_on_single_example: bool,
1466    example: &Example,
1467) {
1468    Progress::global().increment_failed();
1469
1470    let msg;
1471    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1472        let example_name = example.spec.filename();
1473
1474        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1475        app_state
1476            .fs
1477            .write(
1478                &failed_example_path,
1479                &serde_json::to_vec_pretty(&example).unwrap(),
1480            )
1481            .await
1482            .unwrap();
1483        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1484        app_state
1485            .fs
1486            .write(&err_path, format!("{error:?}").as_bytes())
1487            .await
1488            .unwrap();
1489
1490        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1491        let mut file = OpenOptions::new()
1492            .create(true)
1493            .append(true)
1494            .open(&failed_jsonl_path)
1495            .expect("Failed to open failed.jsonl");
1496        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1497            .expect("Failed to write to failed.jsonl");
1498
1499        let cursor_path = match example.repo_name() {
1500            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1501            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1502        };
1503        msg = format!(
1504            indoc::indoc! {"
1505                While processing \"{}\":
1506
1507                \x1b[31m{:?}\x1b[0m
1508
1509                Example:        \x1b[36m{}\x1b[0m
1510                Error file:     \x1b[36m{}\x1b[0m
1511                Cursor file:    \x1b[36m{}\x1b[0m
1512                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1513            "},
1514            example.spec.name,
1515            error,
1516            failed_example_path.display(),
1517            err_path.display(),
1518            cursor_path.display(),
1519            command,
1520            failed_example_path.display(),
1521        );
1522    } else {
1523        msg = format!(
1524            indoc::indoc! {"
1525            While processing \"{}\":
1526
1527                \x1b[31m{:?}\x1b[0m
1528            "},
1529            example.spec.name, error
1530        );
1531    }
1532
1533    if args.failfast || failfast_on_single_example {
1534        Progress::global().finalize();
1535        panic!("{}", msg);
1536    } else {
1537        log::error!("{}", msg);
1538    }
1539}