// main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8
   9mod load_project;
  10mod metrics;
  11mod openai_client;
  12mod parse_output;
  13mod paths;
  14mod predict;
  15mod progress;
  16mod prompt_assets;
  17mod pull_examples;
  18mod qa;
  19mod reorder_patch;
  20mod repair;
  21mod retrieve_context;
  22mod reversal_tracking;
  23mod score;
  24mod split_commit;
  25mod split_dataset;
  26
  27mod synthesize;
  28mod truncate_expected_patch;
  29mod word_diff;
  30use anyhow::Context as _;
  31use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  32use collections::{HashMap, HashSet};
  33use edit_prediction::EditPredictionStore;
  34use futures::channel::mpsc;
  35use futures::{SinkExt as _, StreamExt as _};
  36use gaoya::minhash::{
  37    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
  38};
  39use gpui::{AppContext as _, BackgroundExecutor, Task};
  40use zeta_prompt::ZetaFormat;
  41
  42use reqwest_client::ReqwestClient;
  43use serde::{Deserialize, Deserializer, Serialize, Serializer};
  44use std::env;
  45use std::fmt::Display;
  46use std::fs::{File, OpenOptions};
  47use std::hash::{Hash, Hasher};
  48use std::io::{BufRead, BufReader, BufWriter, Write};
  49use std::str::FromStr;
  50use std::sync::Mutex;
  51use std::{path::PathBuf, sync::Arc};
  52
  53use crate::distill::run_distill;
  54use crate::example::{Example, group_examples_by_repo, read_example_files};
  55use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  56use crate::format_prompt::run_format_prompt;
  57use crate::load_project::run_load_project;
  58use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  59use crate::predict::run_prediction;
  60use crate::progress::Progress;
  61use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
  62use crate::retrieve_context::run_context_retrieval;
  63use crate::score::run_scoring;
  64use crate::split_commit::SplitCommitArgs;
  65use crate::split_dataset::SplitArgs;
  66use crate::synthesize::{SynthesizeConfig, run_synthesize};
  67use crate::truncate_expected_patch::TruncatePatchArgs;
  68
// Top-level CLI arguments for the `ep` binary. `///` doc comments on fields
// double as clap help text, so review notes below use `//` to leave the
// --help output unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print environment info instead of running a subcommand —
    // NOTE(review): assumed from the name; confirm against the dispatch in main.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrently-processed examples.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; applied to file inputs first, with any
    // remainder forwarded to the Snowflake queries (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output destination; interacts with --in-place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file; output_path panics at
    // runtime if more than one input is given with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop on the first failure — NOTE(review): assumed from the name;
    // confirm against the processing loop.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
 111
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Note: the variant `///` comments are surfaced by clap as help text for the
// corresponding `--failed` values, so they are left untouched.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 124
// Extended help text appended to the `ep read` --help output via
// `after_help = INPUTS_HELP` on ReadArgs. Kept as a raw string so formatting
// is preserved verbatim; it is user-facing runtime text — do not reflow.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
      These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 189
// Subcommands of the `ep` pipeline. The `///` comments are clap help text and
// are left untouched; the Display impl below renders the logging spelling of
// each command.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 233
 234impl Display for Command {
 235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 236        match self {
 237            Command::Read(_) => write!(f, "read"),
 238            Command::LoadProject => write!(f, "load-project"),
 239            Command::Context => write!(f, "context"),
 240            Command::FormatPrompt(args) => {
 241                write!(f, "format-prompt --provider={}", args.provider)
 242            }
 243            Command::Predict(args) => match &args.provider {
 244                Some(provider) => write!(f, "predict --provider={}", provider),
 245                None => write!(f, "predict"),
 246            },
 247            Command::ParseOutput => write!(f, "parse-output"),
 248            Command::Score(args) => match &args.provider {
 249                Some(provider) => write!(f, "score --provider={}", provider),
 250                None => write!(f, "score"),
 251            },
 252            Command::Distill => write!(f, "distill"),
 253            Command::Eval(args) => match &args.predict.provider {
 254                Some(provider) => write!(f, "eval --provider={}", provider),
 255                None => write!(f, "eval"),
 256            },
 257            Command::Synthesize(args) => {
 258                write!(f, "synthesize --repos {}", args.repos.join(" "))
 259            }
 260            Command::Clean => write!(f, "clean"),
 261            Command::SplitCommit(_) => write!(f, "split-commit"),
 262            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 263            Command::Split(_) => write!(f, "split"),
 264            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 265            Command::ImportBatch(args) => {
 266                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 267            }
 268            Command::Qa(_) => {
 269                write!(f, "qa")
 270            }
 271            Command::Repair(_) => {
 272                write!(f, "repair")
 273            }
 274            Command::PrintZetaFormats => {
 275                write!(f, "print-zeta-formats")
 276            }
 277        }
 278    }
 279}
 280
// `ep read` has no flags of its own — inputs and outputs come from the global
// EpArgs options. The struct exists to attach the INPUTS_HELP after-help text.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
 284
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default zeta format (see `impl Default for PredictionProvider`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 290
// Arguments shared by `ep predict`, `ep score`, and (via EvalArgs) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; when None the command renders without a
    // --provider flag (see `Display for Command`).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to run per example — NOTE(review): assumed from
    // the name; confirm in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
 304
// Arguments for `ep eval`; embeds the prediction flags via clap's flatten.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
 316
// Which large model backs the "teacher" prediction providers. See
// `model_name` for the concrete API identifiers and the FromStr impl for the
// accepted command-line spellings and aliases.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    // Sonnet 4.5 is the default teacher backend.
    #[default]
    Sonnet45,
    Gpt52,
}
 324
 325impl std::fmt::Display for TeacherBackend {
 326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 327        match self {
 328            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 329            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 330            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 331        }
 332    }
 333}
 334
 335impl std::str::FromStr for TeacherBackend {
 336    type Err = anyhow::Error;
 337
 338    fn from_str(s: &str) -> Result<Self, Self::Err> {
 339        match s.to_lowercase().as_str() {
 340            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 341            "sonnet46" => Ok(TeacherBackend::Sonnet46),
 342            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 343            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 344            _ => anyhow::bail!(
 345                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
 346            ),
 347        }
 348    }
 349}
 350
 351impl TeacherBackend {
 352    pub fn model_name(&self) -> &'static str {
 353        match self {
 354            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 355            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 356            TeacherBackend::Gpt52 => "gpt-5.2",
 357        }
 358    }
 359}
 360
// Identifies which backend produces edit predictions. The string form
// produced by Display and accepted by FromStr is used both on the command
// line and for serde (de)serialization, so those two impls must stay in sync.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    // Zeta v2 with a specific prompt/output format.
    Zeta2(ZetaFormat),
    // Zeta served via Baseten, with a specific format.
    Baseten(ZetaFormat),
    // Teacher model (batched requests) with backend and zeta format.
    Teacher(TeacherBackend, ZetaFormat),
    TeacherMultiRegion(TeacherBackend),
    // Same as Teacher but without request batching.
    TeacherNonBatching(TeacherBackend, ZetaFormat),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
 373
impl Default for PredictionProvider {
    // Zeta2 with the default zeta format is the provider used when none is
    // specified (e.g. the default for `format-prompt --provider`).
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
 379
 380impl std::fmt::Display for PredictionProvider {
 381    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 382        match self {
 383            PredictionProvider::Mercury => write!(f, "mercury"),
 384            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 385            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 386            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
 387            PredictionProvider::Teacher(backend, format) => {
 388                write!(f, "teacher:{backend}:{format:?}")
 389            }
 390            PredictionProvider::TeacherMultiRegion(backend) => {
 391                write!(f, "teacher-multi-region:{backend}")
 392            }
 393            PredictionProvider::TeacherNonBatching(backend, format) => {
 394                write!(f, "teacher-non-batching:{backend}:{format:?}")
 395            }
 396            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
 397                write!(f, "teacher-multi-region-non-batching:{backend}")
 398            }
 399            PredictionProvider::Repair => write!(f, "repair"),
 400        }
 401    }
 402}
 403
 404impl std::str::FromStr for PredictionProvider {
 405    type Err = anyhow::Error;
 406
 407    fn from_str(s: &str) -> Result<Self, Self::Err> {
 408        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 409
 410        let provider_lower = provider.to_lowercase();
 411        match provider_lower.as_str() {
 412            "mercury" => Ok(PredictionProvider::Mercury),
 413            "zeta1" => Ok(PredictionProvider::Zeta1),
 414            "zeta2" => {
 415                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 416                Ok(PredictionProvider::Zeta2(format))
 417            }
 418            "teacher" => {
 419                let (backend, format) = parse_teacher_args(arg)?;
 420                Ok(PredictionProvider::Teacher(backend, format))
 421            }
 422            "teacher-non-batching" | "teacher_non_batching" => {
 423                let (backend, format) = parse_teacher_args(arg)?;
 424                Ok(PredictionProvider::TeacherNonBatching(backend, format))
 425            }
 426            "teacher-multi-region" | "teacher_multi_region" => {
 427                let backend = arg
 428                    .map(|a| a.parse())
 429                    .transpose()?
 430                    .unwrap_or(TeacherBackend::default());
 431                Ok(PredictionProvider::TeacherMultiRegion(backend))
 432            }
 433            "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
 434                let backend = arg
 435                    .map(|a| a.parse())
 436                    .transpose()?
 437                    .unwrap_or(TeacherBackend::default());
 438                Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
 439            }
 440            "repair" => Ok(PredictionProvider::Repair),
 441            "baseten" => {
 442                let format = arg
 443                    .map(ZetaFormat::parse)
 444                    .transpose()?
 445                    .unwrap_or(ZetaFormat::default());
 446                Ok(PredictionProvider::Baseten(format))
 447            }
 448            _ => {
 449                anyhow::bail!(
 450                    "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
 451                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 452                 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
 453                 Available zeta versions:\n{}",
 454                    ZetaFormat::options_as_string()
 455                )
 456            }
 457        }
 458    }
 459}
 460
 461fn parse_teacher_args(arg: Option<&str>) -> Result<(TeacherBackend, ZetaFormat), anyhow::Error> {
 462    let mut backend = TeacherBackend::default();
 463    let mut format = ZetaFormat::default();
 464
 465    for arg in arg.unwrap_or_default().split(':') {
 466        if arg.is_empty() {
 467            continue;
 468        }
 469
 470        if let Ok(parsed_backend) = TeacherBackend::from_str(arg) {
 471            backend = parsed_backend;
 472        } else if let Ok(parsed_format) = ZetaFormat::parse(arg) {
 473            format = parsed_format;
 474        } else {
 475            anyhow::bail!("unknown teacher backend or zeta format `{arg}`");
 476        }
 477    }
 478
 479    Ok((backend, format))
 480}
 481
 482impl Serialize for PredictionProvider {
 483    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 484    where
 485        S: Serializer,
 486    {
 487        serializer.serialize_str(&self.to_string())
 488    }
 489}
 490
 491impl<'de> Deserialize<'de> for PredictionProvider {
 492    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 493    where
 494        D: Deserializer<'de>,
 495    {
 496        let s = String::deserialize(deserializer)?;
 497        s.parse().map_err(serde::de::Error::custom)
 498    }
 499}
 500
// Arguments for `ep synthesize` (see run_synthesize). The `///` comments are
// clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 519
// Arguments for `ep import-batch`.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 529
// Which batch API `import-batch` should pull results from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 535
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip checks: Display must emit the canonical hyphenated spelling
    // regardless of which accepted alias was parsed.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    // The underscore alias parses to the same variant but still renders with
    // hyphens.
    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
 569
 570impl EpArgs {
 571    fn output_path(&self) -> Option<PathBuf> {
 572        if self.in_place {
 573            if self.inputs.len() == 1 {
 574                self.inputs.first().cloned()
 575            } else {
 576                panic!("--in-place requires exactly one input file")
 577            }
 578        } else {
 579            self.output.clone()
 580        }
 581    }
 582}
 583
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Passed as `Some(MIN_CAPTURE_VERSION)` to every pull_examples fetch in
// load_examples; presumably rows captured by older clients are excluded.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
 591
/// Deduplicates `examples` in two passes: first drops exact cursor-position
/// duplicates, then clusters near-duplicates with MinHash/LSH and keeps at
/// most `max_per_cluster` mutually-diverse representatives per cluster.
/// Relative order of the surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor_position is identical to one
    // already seen.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: MinHash/LSH near-duplicate clustering.
    // Jaccard similarity above which two examples count as near-duplicates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    // Shingle width in code tokens for the MinHash signatures.
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Use the exact bands×width product so signatures match the index layout
    // (it may differ slightly from NUM_HASHES).
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One signature per example, computed over token n-grams of its
    // cursor_position string.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union every example with each candidate the LSH index returns for it.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep up to `max_per_cluster` diverse members.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // swap_remove from the back so smaller kept indices stay valid; the
    // collected results are then reversed to restore the original order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
 672
 673fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
 674    if members.len() <= k {
 675        return members.to_vec();
 676    }
 677
 678    let mut selected = vec![members[0]];
 679    let mut min_dist: HashMap<usize, f64> = HashMap::default();
 680    for &member in &members[1..] {
 681        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
 682        min_dist.insert(member, dist);
 683    }
 684
 685    while selected.len() < k {
 686        let &best = min_dist
 687            .iter()
 688            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
 689            .map(|(id, _)| id)
 690            .expect("min_dist should not be empty when selected.len() < k");
 691        selected.push(best);
 692        min_dist.remove(&best);
 693
 694        let best_sig = &signatures[best];
 695        for (member, current_min) in min_dist.iter_mut() {
 696            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
 697            if dist < *current_min {
 698                *current_min = dist;
 699            }
 700        }
 701    }
 702
 703    selected
 704}
 705
 706fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
 707    let tokens: Vec<&str> = word_diff::tokenize(code)
 708        .into_iter()
 709        .filter(|t| !t.trim().is_empty())
 710        .collect();
 711
 712    if tokens.len() < ngram_size {
 713        return vec![tokens.join("\0")];
 714    }
 715
 716    tokens
 717        .windows(ngram_size)
 718        .map(|window| window.join("\0"))
 719        .collect()
 720}
 721
 722async fn load_examples(
 723    http_client: Arc<dyn http_client::HttpClient>,
 724    args: &EpArgs,
 725    output_path: Option<&PathBuf>,
 726    background_executor: BackgroundExecutor,
 727) -> anyhow::Result<Vec<Example>> {
 728    let mut captured_after_timestamps = Vec::new();
 729    let mut rejected_after_timestamps = Vec::new();
 730    let mut requested_after_timestamps = Vec::new();
 731    let mut settled_after_timestamps = Vec::new();
 732    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 733        Vec::new();
 734    let mut file_inputs = Vec::new();
 735
 736    for input in &args.inputs {
 737        let input_string = input.to_string_lossy();
 738        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 739            captured_after_timestamps.push(timestamp.to_string());
 740        } else if let Some(timestamp) =
 741            pull_examples::parse_rejected_after_input(input_string.as_ref())
 742        {
 743            rejected_after_timestamps.push(timestamp.to_string());
 744        } else if let Some(timestamp) =
 745            pull_examples::parse_requested_after_input(input_string.as_ref())
 746        {
 747            requested_after_timestamps.push(timestamp.to_string());
 748        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
 749            settled_after_timestamps.push(timestamp.to_string());
 750        } else if let Some((timestamp, rating_filter)) =
 751            pull_examples::parse_rated_after_input(input_string.as_ref())
 752        {
 753            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 754        } else {
 755            file_inputs.push(input.clone());
 756        }
 757    }
 758
 759    let mut examples = read_example_files(&file_inputs);
 760
 761    // Apply offset to file examples first, then pass remaining offset to Snowflake.
 762    let file_example_count = examples.len();
 763    let remaining_offset = if let Some(offset) = args.offset {
 764        if offset >= file_example_count {
 765            examples.clear();
 766            offset - file_example_count
 767        } else {
 768            examples.splice(0..offset, []);
 769            0
 770        }
 771    } else {
 772        0
 773    };
 774
 775    Progress::global().set_total_examples(examples.len());
 776
 777    let remaining_limit_for_snowflake =
 778        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 779
 780    if let Some(0) = remaining_limit_for_snowflake {
 781        log::info!(
 782            "skipping Snowflake inputs because --limit is already satisfied by example files"
 783        );
 784    } else {
 785        let max_rows_per_timestamp = remaining_limit_for_snowflake;
 786
 787        if !rejected_after_timestamps.is_empty() {
 788            rejected_after_timestamps.sort();
 789
 790            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 791                http_client.clone(),
 792                &rejected_after_timestamps,
 793                max_rows_per_timestamp,
 794                remaining_offset,
 795                background_executor.clone(),
 796                Some(MIN_CAPTURE_VERSION),
 797            )
 798            .await?;
 799            examples.append(&mut rejected_examples);
 800        }
 801
 802        if !requested_after_timestamps.is_empty() {
 803            requested_after_timestamps.sort();
 804
 805            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 806                http_client.clone(),
 807                &requested_after_timestamps,
 808                max_rows_per_timestamp,
 809                remaining_offset,
 810                background_executor.clone(),
 811                Some(MIN_CAPTURE_VERSION),
 812            )
 813            .await?;
 814            examples.append(&mut requested_examples);
 815        }
 816
 817        if !captured_after_timestamps.is_empty() {
 818            captured_after_timestamps.sort();
 819
 820            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 821                http_client.clone(),
 822                &captured_after_timestamps,
 823                max_rows_per_timestamp,
 824                remaining_offset,
 825                background_executor.clone(),
 826                Some(MIN_CAPTURE_VERSION),
 827            )
 828            .await?;
 829            examples.append(&mut captured_examples);
 830        }
 831
 832        if !settled_after_timestamps.is_empty() {
 833            settled_after_timestamps.sort();
 834
 835            let mut settled_examples = fetch_settled_examples_after(
 836                http_client.clone(),
 837                &settled_after_timestamps,
 838                max_rows_per_timestamp,
 839                remaining_offset,
 840                background_executor.clone(),
 841                Some(MIN_CAPTURE_VERSION),
 842            )
 843            .await?;
 844            examples.append(&mut settled_examples);
 845        }
 846
 847        if !rated_after_inputs.is_empty() {
 848            rated_after_inputs.sort();
 849
 850            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 851                http_client,
 852                &rated_after_inputs,
 853                max_rows_per_timestamp,
 854                remaining_offset,
 855                background_executor,
 856                Some(MIN_CAPTURE_VERSION),
 857            )
 858            .await?;
 859            examples.append(&mut rated_examples);
 860        }
 861    }
 862
 863    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 864
 865    if let Some(name_filter) = &args.name {
 866        examples.retain(|example| example.spec.name.contains(name_filter));
 867    }
 868    if let Some(repo_filter) = &args.repo {
 869        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 870    }
 871
 872    // Skip resume logic for --in-place since input and output are the same file,
 873    // which would incorrectly treat all input examples as already processed.
 874    if !args.in_place {
 875        if let Some(path) = output_path
 876            && let Some(command) = &args.command
 877        {
 878            resume_from_output(path, &mut examples, command);
 879        }
 880    }
 881
 882    if let Some(max_duplicates) = args.max_duplicates {
 883        deduplicate_examples(&mut examples, max_duplicates);
 884    }
 885
 886    if let Some(limit) = args.limit {
 887        examples.truncate(limit);
 888    }
 889
 890    let progress = Progress::global();
 891    progress.set_total_examples(examples.len());
 892    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 893
 894    Ok(examples)
 895}
 896
 897fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 898    let mut hasher = collections::FxHasher::default();
 899    spec.hash(&mut hasher);
 900    hasher.finish()
 901}
 902
 903fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 904    let file = match File::open(path) {
 905        Ok(f) => f,
 906        Err(_) => return,
 907    };
 908
 909    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 910
 911    let reader = BufReader::new(file);
 912    let mut kept_lines = Vec::new();
 913    let mut kept_hashes = HashSet::default();
 914
 915    for line in reader.lines() {
 916        let line = match line {
 917            Ok(l) => l,
 918            Err(_) => continue,
 919        };
 920
 921        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 922            let hash = spec_hash(&output_example.spec);
 923            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 924                let is_complete = match command {
 925                    Command::Qa(_) => output_example
 926                        .qa
 927                        .first()
 928                        .and_then(|q| q.as_ref())
 929                        .and_then(|q| q.confidence)
 930                        .is_some(),
 931                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 932                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 933                    }),
 934                    _ => true,
 935                };
 936                if is_complete {
 937                    kept_hashes.insert(hash);
 938                    kept_lines.push(line);
 939                }
 940            }
 941        }
 942    }
 943
 944    let total = examples.len();
 945    let already_processed = kept_hashes.len();
 946
 947    eprintln!(
 948        "Resuming: {}/{} examples already processed",
 949        already_processed, total
 950    );
 951
 952    let file = OpenOptions::new()
 953        .write(true)
 954        .truncate(true)
 955        .open(path)
 956        .expect("Failed to open output file for rewriting");
 957    let mut writer = BufWriter::new(file);
 958    for line in &kept_lines {
 959        writeln!(writer, "{}", line).expect("Failed to write to output file");
 960    }
 961    writer.flush().expect("Failed to flush output file");
 962
 963    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 964}
 965
/// CLI entry point.
///
/// Handles fast paths first (`--printenv`, missing subcommand, and commands
/// that complete without the headless app), then starts a headless GPUI
/// application, loads examples, and processes them in parallel worker tasks
/// grouped by repository. Results stream to a JSONL file, markdown files, or
/// stdout depending on flags.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // These commands run to completion right here and return without ever
    // starting the headless GPUI app below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Wipes the entire data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Default the output directory to a checked-in location when -o
            // is omitted, but only if its parent exists (i.e. we appear to be
            // running from the repo root).
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // All remaining commands need the headless GPUI app and an HTTP client.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // Collect the overall outcome so a single fatal panic happens at
            // the end rather than scattering exits through the async body.
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Synchronize any pending provider batches for batch-based
                // commands before per-example processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so its error surfaces
                // immediately instead of being logged and skipped.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // For --in-place, write to a temp file that is atomically
                    // renamed over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A dedicated background task serializes all writes so
                    // the parallel workers only send strings over a channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repo; each worker claims one
                // repo's batch at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism cooperative worker futures.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the example to the handler for the
                                // active subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands returned earlier in
                                        // main and can never reach this loop.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures; handle_error decides whether
                                // to panic (failfast) or just log.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only emitted when
                                // --failed keep is set.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After finishing a repo, shut down its language
                            // servers and detach the project from the
                            // prediction store before moving on.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example project state before archiving
                            // the finished examples.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // After all workers finish: sync provider batches again and,
                // with --wait, block for batch completion, reprocess, and
                // rewrite the output from scratch.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Command-specific end-of-run reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1440
1441fn rewrite_output(
1442    examples: &[Example],
1443    output_path: Option<&PathBuf>,
1444    markdown: bool,
1445) -> anyhow::Result<()> {
1446    if markdown {
1447        let dir = output_path.context("--markdown requires -o")?;
1448        for example in examples {
1449            let filename = format!("{}.md", example.spec.filename());
1450            let path = dir.join(&filename);
1451            let markdown = example.spec.to_markdown();
1452            std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1453        }
1454    } else if let Some(path) = output_path {
1455        let file = OpenOptions::new()
1456            .create(true)
1457            .write(true)
1458            .truncate(true)
1459            .open(path)
1460            .context("Failed to open output file for rewriting")?;
1461        let mut writer = BufWriter::new(file);
1462        for example in examples {
1463            let line = serde_json::to_string(example)?;
1464            writeln!(writer, "{}", line)?;
1465        }
1466        writer.flush()?;
1467    } else {
1468        for example in examples {
1469            let line = serde_json::to_string(example)?;
1470            println!("{}", line);
1471        }
1472    }
1473    Ok(())
1474}
1475
1476async fn handle_error(
1477    error: anyhow::Error,
1478    args: &EpArgs,
1479    command: &Command,
1480    app_state: &Arc<headless::EpAppState>,
1481    failfast_on_single_example: bool,
1482    example: &Example,
1483) {
1484    Progress::global().increment_failed();
1485
1486    let msg;
1487    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1488        let example_name = example.spec.filename();
1489
1490        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1491        app_state
1492            .fs
1493            .write(
1494                &failed_example_path,
1495                &serde_json::to_vec_pretty(&example).unwrap(),
1496            )
1497            .await
1498            .unwrap();
1499        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1500        app_state
1501            .fs
1502            .write(&err_path, format!("{error:?}").as_bytes())
1503            .await
1504            .unwrap();
1505
1506        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1507        let mut file = OpenOptions::new()
1508            .create(true)
1509            .append(true)
1510            .open(&failed_jsonl_path)
1511            .expect("Failed to open failed.jsonl");
1512        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1513            .expect("Failed to write to failed.jsonl");
1514
1515        let cursor_path = match example.repo_name() {
1516            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1517            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1518        };
1519        msg = format!(
1520            indoc::indoc! {"
1521                While processing \"{}\":
1522
1523                \x1b[31m{:?}\x1b[0m
1524
1525                Example:        \x1b[36m{}\x1b[0m
1526                Error file:     \x1b[36m{}\x1b[0m
1527                Cursor file:    \x1b[36m{}\x1b[0m
1528                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1529            "},
1530            example.spec.name,
1531            error,
1532            failed_example_path.display(),
1533            err_path.display(),
1534            cursor_path.display(),
1535            command,
1536            failed_example_path.display(),
1537        );
1538    } else {
1539        msg = format!(
1540            indoc::indoc! {"
1541            While processing \"{}\":
1542
1543                \x1b[31m{:?}\x1b[0m
1544            "},
1545            example.spec.name, error
1546        );
1547    }
1548
1549    if args.failfast || failfast_on_single_example {
1550        Progress::global().finalize();
1551        panic!("{}", msg);
1552    } else {
1553        log::error!("{}", msg);
1554    }
1555}