// main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod kept_rate;
   9mod load_project;
  10mod metrics;
  11mod openai_client;
  12mod parse_output;
  13mod paths;
  14mod predict;
  15mod progress;
  16mod prompt_assets;
  17mod pull_examples;
  18mod qa;
  19mod reorder_patch;
  20mod repair;
  21mod retrieve_context;
  22mod reversal_tracking;
  23mod score;
  24mod split_commit;
  25mod split_dataset;
  26
  27mod synthesize;
  28mod truncate_expected_patch;
  29mod word_diff;
  30use anyhow::Context as _;
  31use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  32use collections::{HashMap, HashSet};
  33use edit_prediction::EditPredictionStore;
  34use futures::channel::mpsc;
  35use futures::{SinkExt as _, StreamExt as _};
  36use gaoya::minhash::{
  37    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
  38};
  39use gpui::{AppContext as _, BackgroundExecutor, Task};
  40use zeta_prompt::ZetaFormat;
  41
  42use reqwest_client::ReqwestClient;
  43use serde::{Deserialize, Deserializer, Serialize, Serializer};
  44use std::env;
  45use std::fmt::Display;
  46use std::fs::{File, OpenOptions};
  47use std::hash::{Hash, Hasher};
  48use std::io::{BufRead, BufReader, BufWriter, Write};
  49use std::sync::Mutex;
  50use std::{path::PathBuf, sync::Arc};
  51
  52use crate::distill::run_distill;
  53use crate::example::{Example, group_examples_by_repo, read_example_files};
  54use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  55use crate::format_prompt::run_format_prompt;
  56use crate::load_project::run_load_project;
  57use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  58use crate::predict::run_prediction;
  59use crate::progress::Progress;
  60use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
  61use crate::retrieve_context::run_context_retrieval;
  62use crate::score::run_scoring;
  63use crate::split_commit::SplitCommitArgs;
  64use crate::split_dataset::SplitArgs;
  65use crate::synthesize::{SynthesizeConfig, run_synthesize};
  66use crate::truncate_expected_patch::TruncatePatchArgs;
  67
// Top-level CLI for `ep`. Most options are `global = true` so they may appear
// before or after the subcommand. Field `///` docs double as clap help text,
// so review-only notes below use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print environment info instead of running a command — presumably
    // handled early in main(), which is outside this view; confirm.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on examples processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, with
    // the remainder passed to Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output destination; with --markdown this names a directory.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Overwrite the single input file with the output (see output_path()).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Presumably stops at the first failed example — confirm where consumed.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
 110
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE: variant `///` docs double as clap's --help text for `--failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Presumably suppresses the per-example failure files as well, unlike
    // `Skip` — confirm where FailedHandling is consumed.
    SkipNoFiles,
}
 123
/// Extended help text shown after `ep read --help` (attached via
/// `after_help` on `ReadArgs`); documents the special `*-after:{timestamp}`
/// input specifiers that fetch examples from Snowflake.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
      These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 188
// All `ep` subcommands. Variant `///` docs double as the clap help text, and
// the `Display` impl below renders each variant as its CLI spelling — keep
// both in sync when adding a variant.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 232
 233impl Display for Command {
 234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 235        match self {
 236            Command::Read(_) => write!(f, "read"),
 237            Command::LoadProject => write!(f, "load-project"),
 238            Command::Context => write!(f, "context"),
 239            Command::FormatPrompt(args) => {
 240                write!(f, "format-prompt --provider={}", args.provider)
 241            }
 242            Command::Predict(args) => match &args.provider {
 243                Some(provider) => write!(f, "predict --provider={}", provider),
 244                None => write!(f, "predict"),
 245            },
 246            Command::ParseOutput => write!(f, "parse-output"),
 247            Command::Score(args) => match &args.provider {
 248                Some(provider) => write!(f, "score --provider={}", provider),
 249                None => write!(f, "score"),
 250            },
 251            Command::Distill => write!(f, "distill"),
 252            Command::Eval(args) => match &args.predict.provider {
 253                Some(provider) => write!(f, "eval --provider={}", provider),
 254                None => write!(f, "eval"),
 255            },
 256            Command::Synthesize(args) => {
 257                write!(f, "synthesize --repos {}", args.repos.join(" "))
 258            }
 259            Command::Clean => write!(f, "clean"),
 260            Command::SplitCommit(_) => write!(f, "split-commit"),
 261            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 262            Command::Split(_) => write!(f, "split"),
 263            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 264            Command::ImportBatch(args) => {
 265                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 266            }
 267            Command::Qa(_) => {
 268                write!(f, "qa")
 269            }
 270            Command::Repair(_) => {
 271                write!(f, "repair")
 272            }
 273            Command::PrintZetaFormats => {
 274                write!(f, "print-zeta-formats")
 275            }
 276        }
 277    }
 278}
 279
// `ep read` has no flags of its own; it relies on the global input/output
// options on `EpArgs`. `INPUTS_HELP` documents the input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
 283
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to emit; defaults to zeta2 with the
    // default ZetaFormat (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 289
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to predict with; `None` presumably lets run_prediction choose
    // a default — confirm there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably how many predictions to generate per example — confirm in
    // run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
    #[clap(long)]
    wait: bool,
}
 303
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // `eval` accepts all of `predict`'s flags (flattened into this command).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
 315
/// Backing model for the `teacher*` prediction providers.
///
/// String forms are defined by the `Display`/`FromStr` impls below (and
/// round-trip); `model_name` maps each variant to the upstream API model id.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
 323
 324impl std::fmt::Display for TeacherBackend {
 325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 326        match self {
 327            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 328            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 329            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 330        }
 331    }
 332}
 333
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend name case-insensitively, accepting several aliases.
    ///
    /// # Errors
    /// Fails with a message listing the canonical names when `s` is unknown.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            // `sonnet`/`claude` are convenience aliases for the default Claude backend.
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "sonnet46" => Ok(TeacherBackend::Sonnet46),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Legacy spelling kept for backward compatibility — presumably an
            // old recorded format id that mapped to Sonnet 4.5; confirm with
            // stored data before removing.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!(
                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
            ),
        }
    }
}
 349
 350impl TeacherBackend {
 351    pub fn model_name(&self) -> &'static str {
 352        match self {
 353            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 354            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 355            TeacherBackend::Gpt52 => "gpt-5.2",
 356        }
 357    }
 358}
 359
/// Model/provider used to generate edit predictions.
///
/// The string forms (`Display`/`FromStr`) round-trip, are what users pass to
/// `--provider`, and are also the serde representation (see the manual
/// `Serialize`/`Deserialize` impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Mercury,
    Zeta1,
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    Teacher(TeacherBackend),
    TeacherMultiRegion(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    TeacherMultiRegionNonBatching(TeacherBackend),
    Repair,
}
 372
 373impl Default for PredictionProvider {
 374    fn default() -> Self {
 375        PredictionProvider::Zeta2(ZetaFormat::default())
 376    }
 377}
 378
 379impl std::fmt::Display for PredictionProvider {
 380    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 381        match self {
 382            PredictionProvider::Mercury => write!(f, "mercury"),
 383            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 384            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 385            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
 386            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 387            PredictionProvider::TeacherMultiRegion(backend) => {
 388                write!(f, "teacher-multi-region:{backend}")
 389            }
 390            PredictionProvider::TeacherNonBatching(backend) => {
 391                write!(f, "teacher-non-batching:{backend}")
 392            }
 393            PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
 394                write!(f, "teacher-multi-region-non-batching:{backend}")
 395            }
 396            PredictionProvider::Repair => write!(f, "repair"),
 397        }
 398    }
 399}
 400
 401impl std::str::FromStr for PredictionProvider {
 402    type Err = anyhow::Error;
 403
 404    fn from_str(s: &str) -> Result<Self, Self::Err> {
 405        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 406
 407        let provider_lower = provider.to_lowercase();
 408        match provider_lower.as_str() {
 409            "mercury" => Ok(PredictionProvider::Mercury),
 410            "zeta1" => Ok(PredictionProvider::Zeta1),
 411            "zeta2" => {
 412                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 413                Ok(PredictionProvider::Zeta2(format))
 414            }
 415            "teacher" => {
 416                let backend = arg
 417                    .map(|a| a.parse())
 418                    .transpose()?
 419                    .unwrap_or(TeacherBackend::default());
 420                Ok(PredictionProvider::Teacher(backend))
 421            }
 422            "teacher-multi-region" | "teacher_multi_region" => {
 423                let backend = arg
 424                    .map(|a| a.parse())
 425                    .transpose()?
 426                    .unwrap_or(TeacherBackend::default());
 427                Ok(PredictionProvider::TeacherMultiRegion(backend))
 428            }
 429            "teacher-non-batching" | "teacher_non_batching" => {
 430                let backend = arg
 431                    .map(|a| a.parse())
 432                    .transpose()?
 433                    .unwrap_or(TeacherBackend::default());
 434                Ok(PredictionProvider::TeacherNonBatching(backend))
 435            }
 436            "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
 437                let backend = arg
 438                    .map(|a| a.parse())
 439                    .transpose()?
 440                    .unwrap_or(TeacherBackend::default());
 441                Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
 442            }
 443            "repair" => Ok(PredictionProvider::Repair),
 444            "baseten" => {
 445                let format = arg
 446                    .map(ZetaFormat::parse)
 447                    .transpose()?
 448                    .unwrap_or(ZetaFormat::default());
 449                Ok(PredictionProvider::Baseten(format))
 450            }
 451            _ => {
 452                anyhow::bail!(
 453                    "unknown provider `{provider}`. Valid options: mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
 454                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 455                 For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
 456                 Available zeta versions:\n{}",
 457                    ZetaFormat::options_as_string()
 458                )
 459            }
 460        }
 461    }
 462}
 463
 464impl Serialize for PredictionProvider {
 465    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 466    where
 467        S: Serializer,
 468    {
 469        serializer.serialize_str(&self.to_string())
 470    }
 471}
 472
 473impl<'de> Deserialize<'de> for PredictionProvider {
 474    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 475    where
 476        D: Deserializer<'de>,
 477    {
 478        let s = String::deserialize(deserializer)?;
 479        s.parse().map_err(serde::de::Error::custom)
 480    }
 481}
 482
// Arguments for `ep synthesize`: generates eval examples by analyzing git
// commits of the given repositories (handled by `run_synthesize`).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 501
// Arguments for `ep import-batch`: re-imports completed batch results by id,
// e.g. to recover after local database loss.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 511
// Which batch API a recovered batch id belongs to; clap value enum backing
// `--provider` on `import-batch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 517
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip tests for `PredictionProvider`: both the hyphenated spelling
    // and the underscored alias must parse, and `Display` must always emit
    // the canonical hyphenated form.
    #[test]
    fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
        let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
            .parse()
            .unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:sonnet46"
        );
    }

    #[test]
    fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
        let provider: PredictionProvider =
            "teacher_multi_region_non_batching:gpt52".parse().unwrap();
        assert_eq!(
            provider,
            PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
        );
        assert_eq!(
            provider.to_string(),
            "teacher-multi-region-non-batching:gpt52"
        );
    }
}
 551
 552impl EpArgs {
 553    fn output_path(&self) -> Option<PathBuf> {
 554        if self.in_place {
 555            if self.inputs.len() == 1 {
 556                self.inputs.first().cloned()
 557            } else {
 558                panic!("--in-place requires exactly one input file")
 559            }
 560        } else {
 561            self.output.clone()
 562        }
 563    }
 564}
 565
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// i.e. presumably Zed v0.224.1 — `MinCaptureVersion` only carries minor/patch;
// confirm against pull_examples.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
 573
 574fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
 575    let total_before_exact = examples.len();
 576    let mut seen_positions = HashSet::default();
 577    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
 578    log::info!(
 579        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
 580        examples.len(),
 581        total_before_exact - examples.len(),
 582    );
 583
 584    const JACCARD_THRESHOLD: f64 = 0.5;
 585    const NUM_HASHES: usize = 128;
 586    const TOKEN_NGRAM_SIZE: usize = 5;
 587
 588    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
 589    let num_hashes = num_bands * band_width;
 590    let minhasher = MinHasher32::new(num_hashes);
 591    let mut index: MinHashIndex<u32, usize> =
 592        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);
 593
 594    let signatures: Vec<Vec<u32>> = examples
 595        .iter()
 596        .map(|example| {
 597            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
 598            minhasher.create_signature(shingles.iter())
 599        })
 600        .collect();
 601
 602    for (id, signature) in signatures.iter().enumerate() {
 603        index.insert(id, signature.clone());
 604    }
 605
 606    // Build clusters via union-find on LSH candidate pairs.
 607    let mut parent: Vec<usize> = (0..examples.len()).collect();
 608
 609    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
 610        while parent[x] != x {
 611            parent[x] = parent[parent[x]];
 612            x = parent[x];
 613        }
 614        x
 615    }
 616
 617    for (id, signature) in signatures.iter().enumerate() {
 618        for candidate in index.query_owned(signature) {
 619            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
 620            if a != b {
 621                parent[a] = b;
 622            }
 623        }
 624    }
 625
 626    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
 627    for id in 0..examples.len() {
 628        clusters.entry(find(&mut parent, id)).or_default().push(id);
 629    }
 630
 631    let mut keep: HashSet<usize> = HashSet::default();
 632    for members in clusters.values() {
 633        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
 634        keep.extend(selected);
 635    }
 636
 637    let total = examples.len();
 638    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
 639    kept_indices.sort();
 640
 641    let mut retained = Vec::with_capacity(kept_indices.len());
 642    for index in kept_indices.into_iter().rev() {
 643        retained.push(examples.swap_remove(index));
 644    }
 645    retained.reverse();
 646
 647    *examples = retained;
 648    log::info!(
 649        "near-duplicate filter: {total} examples → {} examples ({} removed)",
 650        examples.len(),
 651        total - examples.len(),
 652    );
 653}
 654
 655fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
 656    if members.len() <= k {
 657        return members.to_vec();
 658    }
 659
 660    let mut selected = vec![members[0]];
 661    let mut min_dist: HashMap<usize, f64> = HashMap::default();
 662    for &member in &members[1..] {
 663        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
 664        min_dist.insert(member, dist);
 665    }
 666
 667    while selected.len() < k {
 668        let &best = min_dist
 669            .iter()
 670            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
 671            .map(|(id, _)| id)
 672            .expect("min_dist should not be empty when selected.len() < k");
 673        selected.push(best);
 674        min_dist.remove(&best);
 675
 676        let best_sig = &signatures[best];
 677        for (member, current_min) in min_dist.iter_mut() {
 678            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
 679            if dist < *current_min {
 680                *current_min = dist;
 681            }
 682        }
 683    }
 684
 685    selected
 686}
 687
 688fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
 689    let tokens: Vec<&str> = word_diff::tokenize(code)
 690        .into_iter()
 691        .filter(|t| !t.trim().is_empty())
 692        .collect();
 693
 694    if tokens.len() < ngram_size {
 695        return vec![tokens.join("\0")];
 696    }
 697
 698    tokens
 699        .windows(ngram_size)
 700        .map(|window| window.join("\0"))
 701        .collect()
 702}
 703
 704async fn load_examples(
 705    http_client: Arc<dyn http_client::HttpClient>,
 706    args: &EpArgs,
 707    output_path: Option<&PathBuf>,
 708    background_executor: BackgroundExecutor,
 709) -> anyhow::Result<Vec<Example>> {
 710    let mut captured_after_timestamps = Vec::new();
 711    let mut rejected_after_timestamps = Vec::new();
 712    let mut requested_after_timestamps = Vec::new();
 713    let mut settled_after_timestamps = Vec::new();
 714    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
 715        Vec::new();
 716    let mut file_inputs = Vec::new();
 717
 718    for input in &args.inputs {
 719        let input_string = input.to_string_lossy();
 720        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
 721            captured_after_timestamps.push(timestamp.to_string());
 722        } else if let Some(timestamp) =
 723            pull_examples::parse_rejected_after_input(input_string.as_ref())
 724        {
 725            rejected_after_timestamps.push(timestamp.to_string());
 726        } else if let Some(timestamp) =
 727            pull_examples::parse_requested_after_input(input_string.as_ref())
 728        {
 729            requested_after_timestamps.push(timestamp.to_string());
 730        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
 731            settled_after_timestamps.push(timestamp.to_string());
 732        } else if let Some((timestamp, rating_filter)) =
 733            pull_examples::parse_rated_after_input(input_string.as_ref())
 734        {
 735            rated_after_inputs.push((timestamp.to_string(), rating_filter));
 736        } else {
 737            file_inputs.push(input.clone());
 738        }
 739    }
 740
 741    let mut examples = read_example_files(&file_inputs);
 742
 743    // Apply offset to file examples first, then pass remaining offset to Snowflake.
 744    let file_example_count = examples.len();
 745    let remaining_offset = if let Some(offset) = args.offset {
 746        if offset >= file_example_count {
 747            examples.clear();
 748            offset - file_example_count
 749        } else {
 750            examples.splice(0..offset, []);
 751            0
 752        }
 753    } else {
 754        0
 755    };
 756
 757    Progress::global().set_total_examples(examples.len());
 758
 759    let remaining_limit_for_snowflake =
 760        args.limit.map(|limit| limit.saturating_sub(examples.len()));
 761
 762    if let Some(0) = remaining_limit_for_snowflake {
 763        log::info!(
 764            "skipping Snowflake inputs because --limit is already satisfied by example files"
 765        );
 766    } else {
 767        let max_rows_per_timestamp = remaining_limit_for_snowflake;
 768
 769        if !rejected_after_timestamps.is_empty() {
 770            rejected_after_timestamps.sort();
 771
 772            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
 773                http_client.clone(),
 774                &rejected_after_timestamps,
 775                max_rows_per_timestamp,
 776                remaining_offset,
 777                background_executor.clone(),
 778                Some(MIN_CAPTURE_VERSION),
 779            )
 780            .await?;
 781            examples.append(&mut rejected_examples);
 782        }
 783
 784        if !requested_after_timestamps.is_empty() {
 785            requested_after_timestamps.sort();
 786
 787            let mut requested_examples = pull_examples::fetch_requested_examples_after(
 788                http_client.clone(),
 789                &requested_after_timestamps,
 790                max_rows_per_timestamp,
 791                remaining_offset,
 792                background_executor.clone(),
 793                Some(MIN_CAPTURE_VERSION),
 794            )
 795            .await?;
 796            examples.append(&mut requested_examples);
 797        }
 798
 799        if !captured_after_timestamps.is_empty() {
 800            captured_after_timestamps.sort();
 801
 802            let mut captured_examples = pull_examples::fetch_captured_examples_after(
 803                http_client.clone(),
 804                &captured_after_timestamps,
 805                max_rows_per_timestamp,
 806                remaining_offset,
 807                background_executor.clone(),
 808                Some(MIN_CAPTURE_VERSION),
 809            )
 810            .await?;
 811            examples.append(&mut captured_examples);
 812        }
 813
 814        if !settled_after_timestamps.is_empty() {
 815            settled_after_timestamps.sort();
 816
 817            let mut settled_examples = fetch_settled_examples_after(
 818                http_client.clone(),
 819                &settled_after_timestamps,
 820                max_rows_per_timestamp,
 821                remaining_offset,
 822                background_executor.clone(),
 823                Some(MIN_CAPTURE_VERSION),
 824            )
 825            .await?;
 826            examples.append(&mut settled_examples);
 827        }
 828
 829        if !rated_after_inputs.is_empty() {
 830            rated_after_inputs.sort();
 831
 832            let mut rated_examples = pull_examples::fetch_rated_examples_after(
 833                http_client,
 834                &rated_after_inputs,
 835                max_rows_per_timestamp,
 836                remaining_offset,
 837                background_executor,
 838                Some(MIN_CAPTURE_VERSION),
 839            )
 840            .await?;
 841            examples.append(&mut rated_examples);
 842        }
 843    }
 844
 845    crate::example::sort_examples_by_repo_and_rev(&mut examples);
 846
 847    if let Some(name_filter) = &args.name {
 848        examples.retain(|example| example.spec.name.contains(name_filter));
 849    }
 850    if let Some(repo_filter) = &args.repo {
 851        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
 852    }
 853
 854    // Skip resume logic for --in-place since input and output are the same file,
 855    // which would incorrectly treat all input examples as already processed.
 856    if !args.in_place {
 857        if let Some(path) = output_path
 858            && let Some(command) = &args.command
 859        {
 860            resume_from_output(path, &mut examples, command);
 861        }
 862    }
 863
 864    if let Some(max_duplicates) = args.max_duplicates {
 865        deduplicate_examples(&mut examples, max_duplicates);
 866    }
 867
 868    if let Some(limit) = args.limit {
 869        examples.truncate(limit);
 870    }
 871
 872    let progress = Progress::global();
 873    progress.set_total_examples(examples.len());
 874    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
 875
 876    Ok(examples)
 877}
 878
 879fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 880    let mut hasher = collections::FxHasher::default();
 881    spec.hash(&mut hasher);
 882    hasher.finish()
 883}
 884
 885fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 886    let file = match File::open(path) {
 887        Ok(f) => f,
 888        Err(_) => return,
 889    };
 890
 891    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 892
 893    let reader = BufReader::new(file);
 894    let mut kept_lines = Vec::new();
 895    let mut kept_hashes = HashSet::default();
 896
 897    for line in reader.lines() {
 898        let line = match line {
 899            Ok(l) => l,
 900            Err(_) => continue,
 901        };
 902
 903        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 904            let hash = spec_hash(&output_example.spec);
 905            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 906                let is_complete = match command {
 907                    Command::Qa(_) => output_example
 908                        .qa
 909                        .first()
 910                        .and_then(|q| q.as_ref())
 911                        .and_then(|q| q.confidence)
 912                        .is_some(),
 913                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 914                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 915                    }),
 916                    _ => true,
 917                };
 918                if is_complete {
 919                    kept_hashes.insert(hash);
 920                    kept_lines.push(line);
 921                }
 922            }
 923        }
 924    }
 925
 926    let total = examples.len();
 927    let already_processed = kept_hashes.len();
 928
 929    eprintln!(
 930        "Resuming: {}/{} examples already processed",
 931        already_processed, total
 932    );
 933
 934    let file = OpenOptions::new()
 935        .write(true)
 936        .truncate(true)
 937        .open(path)
 938        .expect("Failed to open output file for rewriting");
 939    let mut writer = BufWriter::new(file);
 940    for line in &kept_lines {
 941        writeln!(writer, "{}", line).expect("Failed to write to output file");
 942    }
 943    writer.flush().expect("Failed to flush output file");
 944
 945    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 946}
 947
/// CLI entry point.
///
/// Flow: parse args; handle the commands that run standalone (batch import,
/// clean, synthesize, dataset splitting, ...) and return early; then boot a
/// headless GPUI app for the commands that need app/project state, load the
/// examples, and process them with up to `--max-parallelism` worker tasks.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` only dumps the shell environment and exits.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that do not need the headless app are handled here and return
    // early; anything that falls through runs inside `app.run` below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the entire data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Without -o, fall back to a path under the current directory,
            // but only if its parent already exists.
            let output_dir = if let Some(output_dir) = args.output {
                output_dir
            } else {
                let default_output_dir = env::current_dir()
                    .unwrap()
                    .join("crates/edit_prediction_cli/evals-generated");
                if default_output_dir.parent().unwrap().exists() {
                    std::fs::create_dir(&default_output_dir).ok();
                    default_output_dir
                } else {
                    panic!("output dir is required");
                }
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Everything past this point runs inside the headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull in any batch results that completed since the last run
                // before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so errors surface immediately.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is renamed over the
                    // original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A background task owns the writer; workers send finished
                    // lines over the channel, and each line is flushed as it
                    // arrives.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers. Each worker claims one
                // repo's examples at a time, so a given repo is only processed
                // by a single worker.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the active
                                // command.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands returned early above
                                        // and never reach the worker loop.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written when
                                // --failed keep is in effect.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: emit JSONL on stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Shut down this repo's language servers and drop
                            // its project from the prediction store before
                            // moving on to the next repo.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Release per-example project state, then record
                            // the finished examples.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // With --wait, block until remote batches complete, reprocess
                // the finished examples with the batch results, and rewrite
                // the output.
                let is_markdown = args.markdown;
                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                        if args.wait {
                            predict::wait_for_batches(args.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                        if args.predict.wait {
                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
                                .await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                        if args.wait {
                            repair::wait_for_batches(args).await?;
                            let mut examples =
                                std::mem::take(&mut *finished_examples.lock().unwrap());
                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
                            rewrite_output(&examples, write_path, is_markdown)?;
                            *finished_examples.lock().unwrap() = examples;
                        }
                    }
                    _ => (),
                }

                // Command-specific summary reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            // Any error that propagated out of the pipeline aborts the process.
            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1422
1423fn rewrite_output(
1424    examples: &[Example],
1425    output_path: Option<&PathBuf>,
1426    markdown: bool,
1427) -> anyhow::Result<()> {
1428    if markdown {
1429        let dir = output_path.context("--markdown requires -o")?;
1430        for example in examples {
1431            let filename = format!("{}.md", example.spec.filename());
1432            let path = dir.join(&filename);
1433            let markdown = example.spec.to_markdown();
1434            std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
1435        }
1436    } else if let Some(path) = output_path {
1437        let file = OpenOptions::new()
1438            .create(true)
1439            .write(true)
1440            .truncate(true)
1441            .open(path)
1442            .context("Failed to open output file for rewriting")?;
1443        let mut writer = BufWriter::new(file);
1444        for example in examples {
1445            let line = serde_json::to_string(example)?;
1446            writeln!(writer, "{}", line)?;
1447        }
1448        writer.flush()?;
1449    } else {
1450        for example in examples {
1451            let line = serde_json::to_string(example)?;
1452            println!("{}", line);
1453        }
1454    }
1455    Ok(())
1456}
1457
1458async fn handle_error(
1459    error: anyhow::Error,
1460    args: &EpArgs,
1461    command: &Command,
1462    app_state: &Arc<headless::EpAppState>,
1463    failfast_on_single_example: bool,
1464    example: &Example,
1465) {
1466    Progress::global().increment_failed();
1467
1468    let msg;
1469    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1470        let example_name = example.spec.filename();
1471
1472        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1473        app_state
1474            .fs
1475            .write(
1476                &failed_example_path,
1477                &serde_json::to_vec_pretty(&example).unwrap(),
1478            )
1479            .await
1480            .unwrap();
1481        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1482        app_state
1483            .fs
1484            .write(&err_path, format!("{error:?}").as_bytes())
1485            .await
1486            .unwrap();
1487
1488        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1489        let mut file = OpenOptions::new()
1490            .create(true)
1491            .append(true)
1492            .open(&failed_jsonl_path)
1493            .expect("Failed to open failed.jsonl");
1494        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1495            .expect("Failed to write to failed.jsonl");
1496
1497        let cursor_path = match example.repo_name() {
1498            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1499            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1500        };
1501        msg = format!(
1502            indoc::indoc! {"
1503                While processing \"{}\":
1504
1505                \x1b[31m{:?}\x1b[0m
1506
1507                Example:        \x1b[36m{}\x1b[0m
1508                Error file:     \x1b[36m{}\x1b[0m
1509                Cursor file:    \x1b[36m{}\x1b[0m
1510                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1511            "},
1512            example.spec.name,
1513            error,
1514            failed_example_path.display(),
1515            err_path.display(),
1516            cursor_path.display(),
1517            command,
1518            failed_example_path.display(),
1519        );
1520    } else {
1521        msg = format!(
1522            indoc::indoc! {"
1523            While processing \"{}\":
1524
1525                \x1b[31m{:?}\x1b[0m
1526            "},
1527            example.spec.name, error
1528        );
1529    }
1530
1531    if args.failfast || failfast_on_single_example {
1532        Progress::global().finalize();
1533        panic!("{}", msg);
1534    } else {
1535        log::error!("{}", msg);
1536    }
1537}