// main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8
   9mod load_project;
  10mod metrics;
  11mod openai_client;
  12mod parse_output;
  13mod paths;
  14mod predict;
  15mod progress;
  16mod prompt_assets;
  17mod pull_examples;
  18mod qa;
  19mod reorder_patch;
  20mod repair;
  21mod retrieve_context;
  22mod reversal_tracking;
  23mod score;
  24mod split_commit;
  25mod split_dataset;
  26
  27mod synthesize;
  28mod truncate_expected_patch;
  29mod word_diff;
  30use anyhow::Context as _;
  31use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  32use collections::{HashMap, HashSet};
  33use edit_prediction::EditPredictionStore;
  34use futures::channel::mpsc;
  35use futures::{SinkExt as _, StreamExt as _};
  36use gaoya::minhash::{
  37    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
  38};
  39use gpui::{AppContext as _, BackgroundExecutor, Task};
  40use zeta_prompt::ZetaFormat;
  41
  42use reqwest_client::ReqwestClient;
  43use serde::{Deserialize, Deserializer, Serialize, Serializer};
  44use std::env;
  45use std::fmt::Display;
  46use std::fs::{File, OpenOptions};
  47use std::hash::{Hash, Hasher};
  48use std::io::{BufRead, BufReader, BufWriter, Write};
  49use std::str::FromStr;
  50use std::sync::Mutex;
  51use std::{path::PathBuf, sync::Arc};
  52
  53use crate::distill::run_distill;
  54use crate::example::{Example, group_examples_by_repo, read_example_files};
  55use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  56use crate::format_prompt::run_format_prompt;
  57use crate::load_project::run_load_project;
  58use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  59use crate::predict::run_prediction;
  60use crate::progress::Progress;
  61use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
  62use crate::retrieve_context::run_context_retrieval;
  63use crate::score::run_scoring;
  64use crate::split_commit::SplitCommitArgs;
  65use crate::split_dataset::SplitArgs;
  66use crate::synthesize::{SynthesizeConfig, run_synthesize};
  67use crate::truncate_expected_patch::TruncatePatchArgs;
  68
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear before or after the subcommand. Note: `//` comments are used for the
// undocumented fields on purpose — `///` doc comments would become clap help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // When set, print environment info instead of running a subcommand.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of examples processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip before processing (applied to file inputs
    // first, then to Snowflake fetches — see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    // The subcommand to run; None falls through to default handling.
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path (`-o`); see also --in-place and --markdown.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write output back to the single input file (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
 111
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// CLI values come from clap's ValueEnum derive (kebab-case by default, so the
// variants are `keep`, `skip`, `skip-no-files` — confirm against clap's rename rules).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 124
// Extended help text (shown via `after_help` on ReadArgs) documenting the input
// specifiers accepted in place of file paths, including the Snowflake
// `*-after:{timestamp}` forms parsed in load_examples.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
      These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 189
// All `ep` subcommands. The `///` comments double as clap help text; the
// Display impl below renders each variant as its CLI spelling for run labels.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 233
 234impl Display for Command {
 235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 236        match self {
 237            Command::Read(_) => write!(f, "read"),
 238            Command::LoadProject => write!(f, "load-project"),
 239            Command::Context => write!(f, "context"),
 240            Command::FormatPrompt(args) => {
 241                write!(f, "format-prompt --provider={}", args.provider)
 242            }
 243            Command::Predict(args) => match &args.provider {
 244                Some(provider) => write!(f, "predict --provider={}", provider),
 245                None => write!(f, "predict"),
 246            },
 247            Command::ParseOutput => write!(f, "parse-output"),
 248            Command::Score(args) => match &args.provider {
 249                Some(provider) => write!(f, "score --provider={}", provider),
 250                None => write!(f, "score"),
 251            },
 252            Command::Distill => write!(f, "distill"),
 253            Command::Eval(args) => match &args.predict.provider {
 254                Some(provider) => write!(f, "eval --provider={}", provider),
 255                None => write!(f, "eval"),
 256            },
 257            Command::Synthesize(args) => {
 258                write!(f, "synthesize --repos {}", args.repos.join(" "))
 259            }
 260            Command::Clean => write!(f, "clean"),
 261            Command::SplitCommit(_) => write!(f, "split-commit"),
 262            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 263            Command::Split(_) => write!(f, "split"),
 264            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 265            Command::ImportBatch(args) => {
 266                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 267            }
 268            Command::Qa(_) => {
 269                write!(f, "qa")
 270            }
 271            Command::Repair(_) => {
 272                write!(f, "repair")
 273            }
 274            Command::PrintZetaFormats => {
 275                write!(f, "print-zeta-formats")
 276            }
 277        }
 278    }
 279}
 280
// `read` has no flags of its own beyond the globals on EpArgs; it exists to
// attach the INPUTS_HELP text to the subcommand.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
 284
// Arguments for `format-prompt`: the provider whose prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 290
// Arguments shared by `predict` and `score`. `provider: None` defers the
// provider choice to downstream logic (see the Display impl for Command).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction attempts per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
 304
// Arguments for `eval`: everything `predict` takes, plus reporting options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Flattened so all PredictArgs flags are accepted directly by `eval`.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
 316
/// Which large model backs the `teacher` prediction providers.
/// String forms and aliases are defined by the FromStr impl below;
/// the concrete API model id comes from [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    /// Default backend.
    #[default]
    Sonnet45,
    Gpt52,
}
 324
 325impl std::fmt::Display for TeacherBackend {
 326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 327        match self {
 328            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 329            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 330            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 331        }
 332    }
 333}
 334
 335impl std::str::FromStr for TeacherBackend {
 336    type Err = anyhow::Error;
 337
 338    fn from_str(s: &str) -> Result<Self, Self::Err> {
 339        match s.to_lowercase().as_str() {
 340            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 341            "sonnet46" => Ok(TeacherBackend::Sonnet46),
 342            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 343            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 344            _ => anyhow::bail!(
 345                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
 346            ),
 347        }
 348    }
 349}
 350
 351impl TeacherBackend {
 352    pub fn model_name(&self) -> &'static str {
 353        match self {
 354            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 355            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 356            TeacherBackend::Gpt52 => "gpt-5.2",
 357        }
 358    }
 359}
 360
/// The prediction backend to run an example against. String forms are defined
/// by the Display/FromStr impls below and round-trip through serde.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Baseten-hosted model with a specific prompt format.
    Baseten(ZetaFormat),
    /// Batched teacher model (backend + prompt format).
    Teacher(TeacherBackend, ZetaFormat),
    TeacherMultiRegion(TeacherBackend),
    /// Non-batching teacher; the format is currently always the default
    /// (see the FromStr impl).
    TeacherNonBatching(TeacherBackend, ZetaFormat),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
 373
 374impl Default for PredictionProvider {
 375    fn default() -> Self {
 376        PredictionProvider::Zeta2(ZetaFormat::default())
 377    }
 378}
 379
// Canonical string form of a provider; used for CLI display and for serde
// round-tripping via the Serialize/Deserialize impls below.
impl std::fmt::Display for PredictionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
            // NOTE(review): the teacher variants print the format with `{format:?}`
            // (Debug) while Zeta2/Baseten use `{format}` (Display). Confirm that
            // `ZetaFormat::parse` accepts the Debug spelling — otherwise serde
            // round-trips of teacher providers would fail to re-parse.
            PredictionProvider::Teacher(backend, format) => {
                write!(f, "teacher:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegion(backend) => {
                write!(f, "teacher-multi-region:{backend}")
            }
            PredictionProvider::TeacherNonBatching(backend, format) => {
                write!(f, "teacher-non-batching:{backend}:{format:?}")
            }
            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
                write!(f, "teacher-multi-region-non-batching:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}
 403
 404impl std::str::FromStr for PredictionProvider {
 405    type Err = anyhow::Error;
 406
 407    fn from_str(s: &str) -> Result<Self, Self::Err> {
 408        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 409
 410        let provider_lower = provider.to_lowercase();
 411        match provider_lower.as_str() {
 412            "mercury" => Ok(PredictionProvider::Mercury),
 413            "zeta1" => Ok(PredictionProvider::Zeta1),
 414            "zeta2" => {
 415                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 416                Ok(PredictionProvider::Zeta2(format))
 417            }
 418            "teacher" => parse_teacher_args(arg),
 419            "teacher-non-batching" | "teacher_non_batching" => {
 420                let backend = arg
 421                    .map(|a| a.parse())
 422                    .transpose()?
 423                    .unwrap_or(TeacherBackend::default());
 424                Ok(PredictionProvider::TeacherNonBatching(
 425                    backend,
 426                    ZetaFormat::default(),
 427                ))
 428            }
 429            "teacher-multi-region" | "teacher_multi_region" => {
 430                let backend = arg
 431                    .map(|a| a.parse())
 432                    .transpose()?
 433                    .unwrap_or(TeacherBackend::default());
 434                Ok(PredictionProvider::TeacherMultiRegion(backend))
 435            }
 436            "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
 437                let backend = arg
 438                    .map(|a| a.parse())
 439                    .transpose()?
 440                    .unwrap_or(TeacherBackend::default());
 441                Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
 442            }
 443            "repair" => Ok(PredictionProvider::Repair),
 444            "baseten" => {
 445                let format = arg
 446                    .map(ZetaFormat::parse)
 447                    .transpose()?
 448                    .unwrap_or(ZetaFormat::default());
 449                Ok(PredictionProvider::Baseten(format))
 450            }
 451            _ => {
 452                anyhow::bail!(
 453                    "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
 454                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 455                 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
 456                 Available zeta versions:\n{}",
 457                    ZetaFormat::options_as_string()
 458                )
 459            }
 460        }
 461    }
 462}
 463
 464fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::Error> {
 465    let mut backend = TeacherBackend::default();
 466    let mut format = ZetaFormat::default();
 467
 468    for arg in arg.unwrap_or_default().split(':') {
 469        if arg.is_empty() {
 470            continue;
 471        }
 472
 473        if let Ok(parsed_backend) = TeacherBackend::from_str(arg) {
 474            backend = parsed_backend;
 475        } else if let Ok(parsed_format) = ZetaFormat::parse(arg) {
 476            format = parsed_format;
 477        } else {
 478            anyhow::bail!("unknown teacher backend or zeta format `{arg}`");
 479        }
 480    }
 481
 482    Ok(PredictionProvider::Teacher(backend, format))
 483}
 484
 485impl Serialize for PredictionProvider {
 486    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 487    where
 488        S: Serializer,
 489    {
 490        serializer.serialize_str(&self.to_string())
 491    }
 492}
 493
 494impl<'de> Deserialize<'de> for PredictionProvider {
 495    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 496    where
 497        D: Deserializer<'de>,
 498    {
 499        let s = String::deserialize(deserializer)?;
 500        s.parse().map_err(serde::de::Error::custom)
 501    }
 502}
 503
// Arguments for `synthesize`: mine git history of the given repos to produce
// eval examples (see run_synthesize).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 522
// Arguments for `import-batch`: recover previously-submitted batch results
// directly from the provider by batch id.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 532
// Batch API provider selector for `import-batch`; clap's ValueEnum derive
// supplies the lowercase CLI values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 538
#[cfg(test)]
mod tests {
    use super::*;

    // Display must always emit the hyphenated (primary) spelling, regardless
    // of which accepted spelling was parsed.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscored alias parses to the same variant and still displays
    // with hyphens.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
 572
 573impl EpArgs {
 574    fn output_path(&self) -> Option<PathBuf> {
 575        if self.in_place {
 576            if self.inputs.len() == 1 {
 577                self.inputs.first().cloned()
 578            } else {
 579                panic!("--in-place requires exactly one input file")
 580            }
 581        } else {
 582            self.output.clone()
 583        }
 584    }
 585}
 586
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224, // i.e. requires at least Zed 0.224.1
    patch: 1,
};
 594
 595fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
 596    let total_before_exact = examples.len();
 597    let mut seen_positions = HashSet::default();
 598    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
 599    log::info!(
 600        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
 601        examples.len(),
 602        total_before_exact - examples.len(),
 603    );
 604
 605    const JACCARD_THRESHOLD: f64 = 0.5;
 606    const NUM_HASHES: usize = 128;
 607    const TOKEN_NGRAM_SIZE: usize = 5;
 608
 609    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
 610    let num_hashes = num_bands * band_width;
 611    let minhasher = MinHasher32::new(num_hashes);
 612    let mut index: MinHashIndex<u32, usize> =
 613        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);
 614
 615    let signatures: Vec<Vec<u32>> = examples
 616        .iter()
 617        .map(|example| {
 618            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
 619            minhasher.create_signature(shingles.iter())
 620        })
 621        .collect();
 622
 623    for (id, signature) in signatures.iter().enumerate() {
 624        index.insert(id, signature.clone());
 625    }
 626
 627    // Build clusters via union-find on LSH candidate pairs.
 628    let mut parent: Vec<usize> = (0..examples.len()).collect();
 629
 630    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
 631        while parent[x] != x {
 632            parent[x] = parent[parent[x]];
 633            x = parent[x];
 634        }
 635        x
 636    }
 637
 638    for (id, signature) in signatures.iter().enumerate() {
 639        for candidate in index.query_owned(signature) {
 640            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
 641            if a != b {
 642                parent[a] = b;
 643            }
 644        }
 645    }
 646
 647    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
 648    for id in 0..examples.len() {
 649        clusters.entry(find(&mut parent, id)).or_default().push(id);
 650    }
 651
 652    let mut keep: HashSet<usize> = HashSet::default();
 653    for members in clusters.values() {
 654        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
 655        keep.extend(selected);
 656    }
 657
 658    let total = examples.len();
 659    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
 660    kept_indices.sort();
 661
 662    let mut retained = Vec::with_capacity(kept_indices.len());
 663    for index in kept_indices.into_iter().rev() {
 664        retained.push(examples.swap_remove(index));
 665    }
 666    retained.reverse();
 667
 668    *examples = retained;
 669    log::info!(
 670        "near-duplicate filter: {total} examples → {} examples ({} removed)",
 671        examples.len(),
 672        total - examples.len(),
 673    );
 674}
 675
 676fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
 677    if members.len() <= k {
 678        return members.to_vec();
 679    }
 680
 681    let mut selected = vec![members[0]];
 682    let mut min_dist: HashMap<usize, f64> = HashMap::default();
 683    for &member in &members[1..] {
 684        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
 685        min_dist.insert(member, dist);
 686    }
 687
 688    while selected.len() < k {
 689        let &best = min_dist
 690            .iter()
 691            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
 692            .map(|(id, _)| id)
 693            .expect("min_dist should not be empty when selected.len() < k");
 694        selected.push(best);
 695        min_dist.remove(&best);
 696
 697        let best_sig = &signatures[best];
 698        for (member, current_min) in min_dist.iter_mut() {
 699            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
 700            if dist < *current_min {
 701                *current_min = dist;
 702            }
 703        }
 704    }
 705
 706    selected
 707}
 708
 709fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
 710    let tokens: Vec<&str> = word_diff::tokenize(code)
 711        .into_iter()
 712        .filter(|t| !t.trim().is_empty())
 713        .collect();
 714
 715    if tokens.len() < ngram_size {
 716        return vec![tokens.join("\0")];
 717    }
 718
 719    tokens
 720        .windows(ngram_size)
 721        .map(|window| window.join("\0"))
 722        .collect()
 723}
 724
 725async fn load_examples(
 726    http_client: Arc<dyn http_client::HttpClient>,
 727    args: &EpArgs,
 728    output_path: Option<&PathBuf>,
 729    background_executor: BackgroundExecutor,
 730) -> anyhow::Result<Vec<Example>> {
 731    let mut captured_after_timestamps = Vec::new();
 732    let mut rejected_after_timestamps = Vec::new();
 733    let mut requested_after_timestamps = Vec::new();
 734    let mut settled_after_timestamps = Vec::new();
 735    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 736        Vec::new();
 737    let mut file_inputs = Vec::new();
 738
 739    for input in &args.inputs {
 740        let input_string = input.to_string_lossy();
 741        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 742            captured_after_timestamps.push(timestamp.to_string());
 743        } else if let Some(timestamp) =
 744            pull_examples::parse_rejected_after_input(input_string.as_ref())
 745        {
 746            rejected_after_timestamps.push(timestamp.to_string());
 747        } else if let Some(timestamp) =
 748            pull_examples::parse_requested_after_input(input_string.as_ref())
 749        {
 750            requested_after_timestamps.push(timestamp.to_string());
 751        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
 752            settled_after_timestamps.push(timestamp.to_string());
 753        } else if let Some((timestamp, rating_filter)) =
 754            pull_examples::parse_rated_after_input(input_string.as_ref())
 755        {
 756            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 757        } else {
 758            file_inputs.push(input.clone());
 759        }
 760    }
 761
 762    let mut examples = read_example_files(&file_inputs);
 763
 764    // Apply offset to file examples first, then pass remaining offset to Snowflake.
 765    let file_example_count = examples.len();
 766    let remaining_offset = if let Some(offset) = args.offset {
 767        if offset >= file_example_count {
 768            examples.clear();
 769            offset - file_example_count
 770        } else {
 771            examples.splice(0..offset, []);
 772            0
 773        }
 774    } else {
 775        0
 776    };
 777
 778    Progress::global().set_total_examples(examples.len());
 779
 780    let remaining_limit_for_snowflake =
 781        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 782
 783    if let Some(0) = remaining_limit_for_snowflake {
 784        log::info!(
 785            "skipping Snowflake inputs because --limit is already satisfied by example files"
 786        );
 787    } else {
 788        let max_rows_per_timestamp = remaining_limit_for_snowflake;
 789
 790        if !rejected_after_timestamps.is_empty() {
 791            rejected_after_timestamps.sort();
 792
 793            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 794                http_client.clone(),
 795                &rejected_after_timestamps,
 796                max_rows_per_timestamp,
 797                remaining_offset,
 798                background_executor.clone(),
 799                Some(MIN_CAPTURE_VERSION),
 800            )
 801            .await?;
 802            examples.append(&mut rejected_examples);
 803        }
 804
 805        if !requested_after_timestamps.is_empty() {
 806            requested_after_timestamps.sort();
 807
 808            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 809                http_client.clone(),
 810                &requested_after_timestamps,
 811                max_rows_per_timestamp,
 812                remaining_offset,
 813                background_executor.clone(),
 814                Some(MIN_CAPTURE_VERSION),
 815            )
 816            .await?;
 817            examples.append(&mut requested_examples);
 818        }
 819
 820        if !captured_after_timestamps.is_empty() {
 821            captured_after_timestamps.sort();
 822
 823            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 824                http_client.clone(),
 825                &captured_after_timestamps,
 826                max_rows_per_timestamp,
 827                remaining_offset,
 828                background_executor.clone(),
 829                Some(MIN_CAPTURE_VERSION),
 830            )
 831            .await?;
 832            examples.append(&mut captured_examples);
 833        }
 834
 835        if !settled_after_timestamps.is_empty() {
 836            settled_after_timestamps.sort();
 837
 838            let mut settled_examples = fetch_settled_examples_after(
 839                http_client.clone(),
 840                &settled_after_timestamps,
 841                max_rows_per_timestamp,
 842                remaining_offset,
 843                background_executor.clone(),
 844                Some(MIN_CAPTURE_VERSION),
 845            )
 846            .await?;
 847            examples.append(&mut settled_examples);
 848        }
 849
 850        if !rated_after_inputs.is_empty() {
 851            rated_after_inputs.sort();
 852
 853            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 854                http_client,
 855                &rated_after_inputs,
 856                max_rows_per_timestamp,
 857                remaining_offset,
 858                background_executor,
 859                Some(MIN_CAPTURE_VERSION),
 860            )
 861            .await?;
 862            examples.append(&mut rated_examples);
 863        }
 864    }
 865
 866    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 867
 868    if let Some(name_filter) = &args.name {
 869        examples.retain(|example| example.spec.name.contains(name_filter));
 870    }
 871    if let Some(repo_filter) = &args.repo {
 872        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 873    }
 874
 875    // Skip resume logic for --in-place since input and output are the same file,
 876    // which would incorrectly treat all input examples as already processed.
 877    if !args.in_place {
 878        if let Some(path) = output_path
 879            && let Some(command) = &args.command
 880        {
 881            resume_from_output(path, &mut examples, command);
 882        }
 883    }
 884
 885    if let Some(max_duplicates) = args.max_duplicates {
 886        deduplicate_examples(&mut examples, max_duplicates);
 887    }
 888
 889    if let Some(limit) = args.limit {
 890        examples.truncate(limit);
 891    }
 892
 893    let progress = Progress::global();
 894    progress.set_total_examples(examples.len());
 895    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 896
 897    Ok(examples)
 898}
 899
 900fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 901    let mut hasher = collections::FxHasher::default();
 902    spec.hash(&mut hasher);
 903    hasher.finish()
 904}
 905
 906fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 907    let file = match File::open(path) {
 908        Ok(f) => f,
 909        Err(_) => return,
 910    };
 911
 912    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 913
 914    let reader = BufReader::new(file);
 915    let mut kept_lines = Vec::new();
 916    let mut kept_hashes = HashSet::default();
 917
 918    for line in reader.lines() {
 919        let line = match line {
 920            Ok(l) => l,
 921            Err(_) => continue,
 922        };
 923
 924        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 925            let hash = spec_hash(&output_example.spec);
 926            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 927                let is_complete = match command {
 928                    Command::Qa(_) => output_example
 929                        .qa
 930                        .first()
 931                        .and_then(|q| q.as_ref())
 932                        .and_then(|q| q.confidence)
 933                        .is_some(),
 934                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 935                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 936                    }),
 937                    _ => true,
 938                };
 939                if is_complete {
 940                    kept_hashes.insert(hash);
 941                    kept_lines.push(line);
 942                }
 943            }
 944        }
 945    }
 946
 947    let total = examples.len();
 948    let already_processed = kept_hashes.len();
 949
 950    eprintln!(
 951        "Resuming: {}/{} examples already processed",
 952        already_processed, total
 953    );
 954
 955    let file = OpenOptions::new()
 956        .write(true)
 957        .truncate(true)
 958        .open(path)
 959        .expect("Failed to open output file for rewriting");
 960    let mut writer = BufWriter::new(file);
 961    for line in &kept_lines {
 962        writeln!(writer, "{}", line).expect("Failed to write to output file");
 963    }
 964    writer.flush().expect("Failed to flush output file");
 965
 966    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 967}
 968
/// CLI entry point.
///
/// Commands fall into two groups: standalone commands (batch import, clean,
/// synthesis, dataset splitting, ...) that are handled synchronously up front
/// and return early, and per-example commands (predict, score, qa, repair,
/// ...) that run inside a headless GPUI app so they can load projects and
/// language servers. Per-example commands stream results as JSONL to stdout
/// or to `-o`, or as one markdown file per example with `--markdown`.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` prints the shell environment and exits; handle it before
    // doing any other work.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    // Markdown mode writes one file per example, so an output directory is
    // mandatory.
    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Standalone commands: these don't need the headless app or the example
    // pipeline, so they run here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipes the tool's entire local data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Without -o, fall back to a conventional in-repo directory, but
            // only when its parent exists (i.e. we're inside the repo).
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Per-example commands need the headless GPUI app: project loading,
    // language servers, and HTTP access all hang off of it.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                // Gather examples from files and/or Snowflake, applying
                // offset/limit/filters and resume logic.
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Upload any locally queued provider batch requests before
                // processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single example implies the user cares about that one run;
                // fail loudly instead of logging and continuing.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the input at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A background task owns the writer; workers send completed
                    // example lines over a channel so writes are serialized.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are processed repo-by-repo so project state can be
                // shared within a group; `max_parallelism` workers each pull
                // whole repo groups off the shared queue.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example command; each arm
                                // mutates `example` in place with its results.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // Standalone commands already returned
                                        // before this pipeline started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only emitted when the
                                // user asked to keep them.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o and not an eval: stream JSONL
                                        // to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After finishing a repo group, shut down its
                            // language servers and unregister the project so
                            // resources are reclaimed before the next group.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example project state before archiving
                            // the group's results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // With --wait, block until provider batches complete, then
                // reprocess the examples with the batch results and rewrite
                // the output from scratch.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Per-command summary reports, printed after all work is done.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1443
1444fn rewrite_output(
1445    examples: &[Example],
1446    output_path: Option<&PathBuf>,
1447    markdown: bool,
1448) -> anyhow::Result<()> {
1449    if markdown {
1450        let dir = output_path.context("--markdown requires -o")?;
1451        for example in examples {
1452            let filename = format!("{}.md", example.spec.filename());
1453            let path = dir.join(&filename);
1454            let markdown = example.spec.to_markdown();
1455            std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1456        }
1457    } else if let Some(path) = output_path {
1458        let file = OpenOptions::new()
1459            .create(true)
1460            .write(true)
1461            .truncate(true)
1462            .open(path)
1463            .context("Failed to open output file for rewriting")?;
1464        let mut writer = BufWriter::new(file);
1465        for example in examples {
1466            let line = serde_json::to_string(example)?;
1467            writeln!(writer, "{}", line)?;
1468        }
1469        writer.flush()?;
1470    } else {
1471        for example in examples {
1472            let line = serde_json::to_string(example)?;
1473            println!("{}", line);
1474        }
1475    }
1476    Ok(())
1477}
1478
1479async fn handle_error(
1480    error: anyhow::Error,
1481    args: &EpArgs,
1482    command: &Command,
1483    app_state: &Arc<headless::EpAppState>,
1484    failfast_on_single_example: bool,
1485    example: &Example,
1486) {
1487    Progress::global().increment_failed();
1488
1489    let msg;
1490    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1491        let example_name = example.spec.filename();
1492
1493        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1494        app_state
1495            .fs
1496            .write(
1497                &failed_example_path,
1498                &serde_json::to_vec_pretty(&example).unwrap(),
1499            )
1500            .await
1501            .unwrap();
1502        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1503        app_state
1504            .fs
1505            .write(&err_path, format!("{error:?}").as_bytes())
1506            .await
1507            .unwrap();
1508
1509        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1510        let mut file = OpenOptions::new()
1511            .create(true)
1512            .append(true)
1513            .open(&failed_jsonl_path)
1514            .expect("Failed to open failed.jsonl");
1515        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1516            .expect("Failed to write to failed.jsonl");
1517
1518        let cursor_path = match example.repo_name() {
1519            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1520            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1521        };
1522        msg = format!(
1523            indoc::indoc! {"
1524                While processing \"{}\":
1525
1526                \x1b[31m{:?}\x1b[0m
1527
1528                Example:        \x1b[36m{}\x1b[0m
1529                Error file:     \x1b[36m{}\x1b[0m
1530                Cursor file:    \x1b[36m{}\x1b[0m
1531                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1532            "},
1533            example.spec.name,
1534            error,
1535            failed_example_path.display(),
1536            err_path.display(),
1537            cursor_path.display(),
1538            command,
1539            failed_example_path.display(),
1540        );
1541    } else {
1542        msg = format!(
1543            indoc::indoc! {"
1544            While processing \"{}\":
1545
1546                \x1b[31m{:?}\x1b[0m
1547            "},
1548            example.spec.name, error
1549        );
1550    }
1551
1552    if args.failfast || failfast_on_single_example {
1553        Progress::global().finalize();
1554        panic!("{}", msg);
1555    } else {
1556        log::error!("{}", msg);
1557    }
1558}