1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
48
49use crate::distill::run_distill;
50use crate::example::{Example, group_examples_by_repo, read_example_files};
51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
52use crate::format_prompt::run_format_prompt;
53use crate::load_project::run_load_project;
54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
55use crate::predict::run_prediction;
56use crate::progress::Progress;
57use crate::retrieve_context::run_context_retrieval;
58use crate::score::run_scoring;
59use crate::split_commit::SplitCommitArgs;
60use crate::split_dataset::SplitArgs;
61use crate::synthesize::{SynthesizeConfig, run_synthesize};
62use crate::truncate_expected_patch::TruncatePatchArgs;
63
// Top-level CLI arguments for the `ep` binary. Most flags are `global` so
// they can appear before or after any subcommand.
// NOTE: `///` doc comments on clap-derive items become `--help` text, so
// review notes below use `//` to avoid changing CLI output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repo groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples; applied to file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; interacts with --in-place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Overwrite the single input file with the command's output.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably suppresses the per-example failure files as
    // well as main-output inclusion — confirm against the handling code.
    SkipNoFiles,
}
119
// Extended help shown via `after_help` on ReadArgs: documents the special
// input specifiers (captured-after:, rejected-after:, rated-after:, ...)
// and the Snowflake environment variables they require.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
177
// Subcommands of `ep`. The `///` comments double as clap help text, so
// they are kept user-facing; pipeline commands generally read the global
// `inputs` and write via `-o`/`--in-place`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
221
222impl Display for Command {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Command::Read(_) => write!(f, "read"),
226 Command::LoadProject => write!(f, "load-project"),
227 Command::Context => write!(f, "context"),
228 Command::FormatPrompt(args) => {
229 write!(f, "format-prompt --provider={}", args.provider)
230 }
231 Command::Predict(args) => match &args.provider {
232 Some(provider) => write!(f, "predict --provider={}", provider),
233 None => write!(f, "predict"),
234 },
235 Command::ParseOutput => write!(f, "parse-output"),
236 Command::Score(args) => match &args.provider {
237 Some(provider) => write!(f, "score --provider={}", provider),
238 None => write!(f, "score"),
239 },
240 Command::Distill => write!(f, "distill"),
241 Command::Eval(args) => match &args.predict.provider {
242 Some(provider) => write!(f, "eval --provider={}", provider),
243 None => write!(f, "eval"),
244 },
245 Command::Synthesize(args) => {
246 write!(f, "synthesize --repos {}", args.repos.join(" "))
247 }
248 Command::Clean => write!(f, "clean"),
249 Command::SplitCommit(_) => write!(f, "split-commit"),
250 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
251 Command::Split(_) => write!(f, "split"),
252 Command::FilterLanguages(_) => write!(f, "filter-languages"),
253 Command::ImportBatch(args) => {
254 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
255 }
256 Command::Qa(_) => {
257 write!(f, "qa")
258 }
259 Command::Repair(_) => {
260 write!(f, "repair")
261 }
262 Command::PrintZetaFormats => {
263 write!(f, "print-zeta-formats")
264 }
265 }
266 }
267}
268
// Arguments for `ep read`; currently empty — inputs come from the global
// positional `inputs` on EpArgs. The after_help documents the specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
272
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to emit; defaults to zeta2 with the
    // default ZetaFormat (see PredictionProvider::default).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
278
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; when omitted, the effective provider is decided
    // downstream — TODO(review): confirm the default in predict.rs.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably the number of prediction attempts per example — verify
    // against run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
289
// Arguments for `ep eval`: the full prediction settings plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
298
/// Teacher model used by the `teacher`/`teacher-non-batching` providers.
/// Parsed from short names like `sonnet45` (see `FromStr`) and mapped to
/// concrete API model identifiers via [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
306
307impl std::fmt::Display for TeacherBackend {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
311 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
312 TeacherBackend::Gpt52 => write!(f, "gpt52"),
313 }
314 }
315}
316
317impl std::str::FromStr for TeacherBackend {
318 type Err = anyhow::Error;
319
320 fn from_str(s: &str) -> Result<Self, Self::Err> {
321 match s.to_lowercase().as_str() {
322 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
323 "sonnet46" => Ok(TeacherBackend::Sonnet46),
324 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
325 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
326 _ => anyhow::bail!(
327 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
328 ),
329 }
330 }
331}
332
333impl TeacherBackend {
334 pub fn model_name(&self) -> &'static str {
335 match self {
336 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
337 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
338 TeacherBackend::Gpt52 => "gpt-5.2",
339 }
340 }
341}
342
/// Backend used to generate edit predictions. Round-trips through the
/// `name[:arg]` string form (see the `Display`/`FromStr` impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Teacher model; presumably routed through the batch client — the
    /// non-batching variant below bypasses it (confirm in predict.rs).
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
353
354impl Default for PredictionProvider {
355 fn default() -> Self {
356 PredictionProvider::Zeta2(ZetaFormat::default())
357 }
358}
359
360impl std::fmt::Display for PredictionProvider {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 match self {
363 PredictionProvider::Sweep => write!(f, "sweep"),
364 PredictionProvider::Mercury => write!(f, "mercury"),
365 PredictionProvider::Zeta1 => write!(f, "zeta1"),
366 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
367 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
368 PredictionProvider::TeacherNonBatching(backend) => {
369 write!(f, "teacher-non-batching:{backend}")
370 }
371 PredictionProvider::Repair => write!(f, "repair"),
372 }
373 }
374}
375
376impl std::str::FromStr for PredictionProvider {
377 type Err = anyhow::Error;
378
379 fn from_str(s: &str) -> Result<Self, Self::Err> {
380 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
381
382 let provider_lower = provider.to_lowercase();
383 match provider_lower.as_str() {
384 "sweep" => Ok(PredictionProvider::Sweep),
385 "mercury" => Ok(PredictionProvider::Mercury),
386 "zeta1" => Ok(PredictionProvider::Zeta1),
387 "zeta2" => {
388 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
389 Ok(PredictionProvider::Zeta2(format))
390 }
391 "teacher" => {
392 let backend = arg
393 .map(|a| a.parse())
394 .transpose()?
395 .unwrap_or(TeacherBackend::default());
396 Ok(PredictionProvider::Teacher(backend))
397 }
398 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
399 let backend = arg
400 .map(|a| a.parse())
401 .transpose()?
402 .unwrap_or(TeacherBackend::default());
403 Ok(PredictionProvider::TeacherNonBatching(backend))
404 }
405 "repair" => Ok(PredictionProvider::Repair),
406 _ => {
407 anyhow::bail!(
408 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
409 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
410 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
411 Available zeta versions:\n{}",
412 ZetaFormat::options_as_string()
413 )
414 }
415 }
416 }
417}
418
419impl Serialize for PredictionProvider {
420 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
421 where
422 S: Serializer,
423 {
424 serializer.serialize_str(&self.to_string())
425 }
426}
427
428impl<'de> Deserialize<'de> for PredictionProvider {
429 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
430 where
431 D: Deserializer<'de>,
432 {
433 let s = String::deserialize(deserializer)?;
434 s.parse().map_err(serde::de::Error::custom)
435 }
436}
437
// Arguments for `ep synthesize` (consumed by run_synthesize via
// SynthesizeConfig in main); the `///` docs are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
456
// Arguments for `ep import-batch`, which re-imports completed LLM batch
// results into the local cache (e.g. after database loss).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
466
// Which LLM batch API to import results from (see ImportBatchArgs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
472
473impl EpArgs {
474 fn output_path(&self) -> Option<PathBuf> {
475 if self.in_place {
476 if self.inputs.len() == 1 {
477 self.inputs.first().cloned()
478 } else {
479 panic!("--in-place requires exactly one input file")
480 }
481 } else {
482 self.output.clone()
483 }
484 }
485}
486
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    // i.e. 0.224.1 — the major component is presumably implied by
    // MinCaptureVersion; confirm in pull_examples.
    minor: 224,
    patch: 1,
};
494
/// Removes duplicate examples in place: first exact duplicates (identical
/// cursor position), then near-duplicates found via MinHash/LSH
/// clustering, keeping at most `max_per_cluster` diverse examples per
/// cluster. Relative order of the survivors is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor position is byte-identical.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate clustering. Cursor-position text is shingled
    // into token n-grams, MinHash-signed, and indexed with LSH banding
    // tuned for this Jaccard threshold.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // The banding parameters may round the hash count; use the product.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find with path halving: point each visited node at its grandparent.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep a maximally diverse subset.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from highest index to lowest: swap_remove only disturbs the
    // removed position and the tail, so lower kept indices stay valid.
    // The final reverse restores ascending (original) order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
575
576fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
577 if members.len() <= k {
578 return members.to_vec();
579 }
580
581 let mut selected = vec![members[0]];
582 let mut min_dist: HashMap<usize, f64> = HashMap::default();
583 for &member in &members[1..] {
584 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
585 min_dist.insert(member, dist);
586 }
587
588 while selected.len() < k {
589 let &best = min_dist
590 .iter()
591 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
592 .map(|(id, _)| id)
593 .expect("min_dist should not be empty when selected.len() < k");
594 selected.push(best);
595 min_dist.remove(&best);
596
597 let best_sig = &signatures[best];
598 for (member, current_min) in min_dist.iter_mut() {
599 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
600 if dist < *current_min {
601 *current_min = dist;
602 }
603 }
604 }
605
606 selected
607}
608
609fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
610 let tokens: Vec<&str> = word_diff::tokenize(code)
611 .into_iter()
612 .filter(|t| !t.trim().is_empty())
613 .collect();
614
615 if tokens.len() < ngram_size {
616 return vec![tokens.join("\0")];
617 }
618
619 tokens
620 .windows(ngram_size)
621 .map(|window| window.join("\0"))
622 .collect()
623}
624
/// Collects examples from file inputs and Snowflake specifiers, then
/// applies the global CLI filters in order: --offset, repo/name filters,
/// resume-from-output, --max-duplicates deduplication, and --limit.
/// Totals are reported to the global progress UI.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition raw inputs into Snowflake specifiers and plain file paths.
    // NOTE(review): captured_after_timestamps is populated below but never
    // fetched anywhere in this function — `captured-after:` inputs appear
    // to be silently dropped. Confirm whether a fetch call is missing.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is a file path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // How much of --limit is left for Snowflake once files are counted.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Without --limit, each timestamp fetch is capped at 5000 rows
        // (matching the default documented on EpArgs::limit).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group by repo/rev so downstream processing can batch per repository.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that all filters have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
766
/// Hashes an example spec with FxHasher; used to match input examples
/// against already-written output lines when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
772
773fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
774 let file = match File::open(path) {
775 Ok(f) => f,
776 Err(_) => return,
777 };
778
779 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
780
781 let reader = BufReader::new(file);
782 let mut kept_lines = Vec::new();
783 let mut kept_hashes = HashSet::default();
784
785 for line in reader.lines() {
786 let line = match line {
787 Ok(l) => l,
788 Err(_) => continue,
789 };
790
791 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
792 let hash = spec_hash(&output_example.spec);
793 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
794 let is_complete = match command {
795 Command::Qa(_) => output_example
796 .qa
797 .first()
798 .and_then(|q| q.as_ref())
799 .and_then(|q| q.confidence)
800 .is_some(),
801 Command::Repair(_) => output_example.predictions.iter().any(|p| {
802 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
803 }),
804 _ => true,
805 };
806 if is_complete {
807 kept_hashes.insert(hash);
808 kept_lines.push(line);
809 }
810 }
811 }
812 }
813
814 let total = examples.len();
815 let already_processed = kept_hashes.len();
816
817 eprintln!(
818 "Resuming: {}/{} examples already processed",
819 already_processed, total
820 );
821
822 let file = OpenOptions::new()
823 .write(true)
824 .truncate(true)
825 .open(path)
826 .expect("Failed to open output file for rewriting");
827 let mut writer = BufWriter::new(file);
828 for line in &kept_lines {
829 writeln!(writer, "{}", line).expect("Failed to write to output file");
830 }
831 writer.flush().expect("Failed to flush output file");
832
833 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
834}
835
836fn main() {
837 let args = EpArgs::parse();
838
839 if args.printenv {
840 ::util::shell_env::print_env();
841 return;
842 }
843
844 let output = args.output_path();
845
846 if args.markdown && output.is_none() {
847 eprintln!("--markdown requires -o to specify the output directory");
848 std::process::exit(1);
849 }
850
851 let command = match &args.command {
852 Some(cmd) => cmd.clone(),
853 None => {
854 EpArgs::command().print_help().unwrap();
855 return;
856 }
857 };
858
859 match &command {
860 Command::ImportBatch(import_args) => {
861 smol::block_on(async {
862 match import_args.provider {
863 BatchProvider::Anthropic => {
864 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
865 .expect("Failed to create Anthropic client");
866 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
867 eprintln!("Error importing Anthropic batches: {:?}", e);
868 std::process::exit(1);
869 }
870 }
871 BatchProvider::Openai => {
872 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
873 .expect("Failed to create OpenAI client");
874 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
875 eprintln!("Error importing OpenAI batches: {:?}", e);
876 std::process::exit(1);
877 }
878 }
879 }
880 println!(
881 "Successfully imported {} batch(es)",
882 import_args.batch_ids.len()
883 );
884 });
885 return;
886 }
887 Command::Clean => {
888 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
889 return;
890 }
891 Command::PrintZetaFormats => {
892 use strum::IntoEnumIterator as _;
893 for format in ZetaFormat::iter() {
894 println!("{}", format.to_string().to_lowercase());
895 }
896 return;
897 }
898
899 Command::Synthesize(synth_args) => {
900 let Some(output_dir) = args.output else {
901 panic!("output dir is required");
902 };
903 let config = SynthesizeConfig {
904 repo_urls: synth_args.repos.clone(),
905 count: synth_args.count,
906 max_commits: synth_args.max_commits,
907 output_dir,
908 fresh: synth_args.fresh,
909 };
910 smol::block_on(async {
911 if let Err(e) = run_synthesize(config).await {
912 eprintln!("Error: {:?}", e);
913 std::process::exit(1);
914 }
915 });
916 return;
917 }
918 Command::SplitCommit(split_commit_args) => {
919 if let Err(error) = split_commit::run_split_commit(
920 split_commit_args,
921 &args.inputs,
922 output.as_ref(),
923 args.failed,
924 ) {
925 eprintln!("{error:#}");
926 std::process::exit(1);
927 }
928 return;
929 }
930 Command::TruncatePatch(truncate_args) => {
931 if let Err(error) =
932 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
933 {
934 eprintln!("{error:#}");
935 std::process::exit(1);
936 }
937 return;
938 }
939 Command::Split(split_args) => {
940 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
941 eprintln!("{error:#}");
942 std::process::exit(1);
943 }
944 return;
945 }
946 Command::FilterLanguages(filter_args) => {
947 if let Err(error) =
948 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
949 {
950 eprintln!("{error:#}");
951 std::process::exit(1);
952 }
953 return;
954 }
955
956 _ => {}
957 }
958
959 let http_client = Arc::new(ReqwestClient::new());
960 let app = gpui_platform::headless().with_http_client(http_client);
961
962 app.run(move |cx| {
963 let app_state = Arc::new(headless::init(cx));
964 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
965
966 cx.spawn(async move |cx| {
967 let result = async {
968 let examples = load_examples(
969 app_state.client.http_client(),
970 &args,
971 output.as_ref(),
972 cx.background_executor().clone(),
973 )
974 .await?;
975
976 match &command {
977 Command::Predict(args) | Command::Score(args) => {
978 predict::sync_batches(args.provider.as_ref()).await?;
979 }
980 Command::Eval(args) => {
981 predict::sync_batches(args.predict.provider.as_ref()).await?;
982 }
983 Command::Qa(args) => {
984 qa::sync_batches(args).await?;
985 }
986 Command::Repair(args) => {
987 repair::sync_batches(args).await?;
988 }
989 _ => (),
990 }
991
992 let failfast_on_single_example = examples.len() == 1;
993
994 // For --markdown mode, create the output directory if it doesn't exist
995 if args.markdown {
996 let dir = output.as_ref().expect("--markdown requires -o");
997 if !dir.exists() {
998 std::fs::create_dir_all(dir)
999 .expect("Failed to create markdown output directory");
1000 }
1001 }
1002
1003 // Set up JSONL output writer (not used in markdown mode)
1004 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1005 let mut in_place_temp_path: Option<PathBuf> = None;
1006 if !args.markdown
1007 && let Some(output_path) = output.as_ref()
1008 {
1009 let write_path = if args.in_place {
1010 let temp = output_path.with_extension("jsonl.tmp");
1011 in_place_temp_path = Some(temp.clone());
1012 temp
1013 } else {
1014 output_path.clone()
1015 };
1016
1017 let file = OpenOptions::new()
1018 .create(true)
1019 .write(true)
1020 .truncate(args.in_place)
1021 .append(!args.in_place)
1022 .open(&write_path)
1023 .expect("Failed to open output file");
1024
1025 let mut writer = BufWriter::new(file);
1026 let (sender, mut receiver) = mpsc::unbounded::<String>();
1027 cx.background_spawn(async move {
1028 while let Some(line) = receiver.next().await {
1029 writeln!(writer, "{}", line).expect("Failed to write example");
1030 writer.flush().expect("Failed to flush output");
1031 }
1032 })
1033 .detach();
1034 output_sender = Some(sender);
1035 }
1036
1037 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1038 let finished_examples = Mutex::new(Vec::new());
1039
1040 let mut tasks = Vec::new();
1041 for _ in 0..args.max_parallelism {
1042 tasks.push(async {
1043 loop {
1044 let Some(mut repo_examples) =
1045 grouped_examples.lock().unwrap().pop_front()
1046 else {
1047 break;
1048 };
1049 for example in &mut repo_examples {
1050 let example_progress =
1051 Progress::global().start_group(&example.spec.name);
1052
1053 let result = async {
1054 match &command {
1055 Command::Read(_) => {}
1056 Command::LoadProject => {
1057 run_load_project(
1058 example,
1059 app_state.clone(),
1060 &example_progress,
1061 cx.clone(),
1062 )
1063 .await?;
1064 }
1065 Command::Context => {
1066 run_context_retrieval(
1067 example,
1068 app_state.clone(),
1069 &example_progress,
1070 cx.clone(),
1071 )
1072 .await?;
1073 }
1074 Command::FormatPrompt(args) => {
1075 run_format_prompt(
1076 example,
1077 args,
1078 app_state.clone(),
1079 &example_progress,
1080 cx.clone(),
1081 )
1082 .await?;
1083 }
1084 Command::Predict(args) => {
1085 run_prediction(
1086 example,
1087 args,
1088 app_state.clone(),
1089 &example_progress,
1090 cx.clone(),
1091 )
1092 .await?;
1093 }
1094 Command::ParseOutput => {
1095 parse_output::run_parse_output(example)?;
1096 }
1097 Command::Distill => {
1098 run_distill(example).await?;
1099 }
1100 Command::Score(args) => {
1101 run_scoring(
1102 example,
1103 args,
1104 app_state.clone(),
1105 &example_progress,
1106 cx.clone(),
1107 )
1108 .await?;
1109 }
1110 Command::Eval(args) => {
1111 run_scoring(
1112 example,
1113 &args.predict,
1114 app_state.clone(),
1115 &example_progress,
1116 cx.clone(),
1117 )
1118 .await?;
1119 }
1120 Command::Qa(args) => {
1121 qa::run_qa(example, args, &example_progress).await?;
1122 }
1123 Command::Repair(args) => {
1124 repair::run_repair(example, args, &example_progress)
1125 .await?;
1126 }
1127 Command::Clean
1128 | Command::Synthesize(_)
1129 | Command::SplitCommit(_)
1130 | Command::Split(_)
1131 | Command::TruncatePatch(_)
1132 | Command::FilterLanguages(_)
1133 | Command::ImportBatch(_)
1134 | Command::PrintZetaFormats => {
1135 unreachable!()
1136 }
1137 }
1138 anyhow::Ok(())
1139 }
1140 .await;
1141
1142 let failed = if let Err(error) = result {
1143 handle_error(
1144 error,
1145 &args,
1146 &command,
1147 &app_state,
1148 failfast_on_single_example,
1149 &example,
1150 )
1151 .await;
1152 true
1153 } else {
1154 false
1155 };
1156
1157 let should_write = !failed || args.failed == FailedHandling::Keep;
1158 if should_write {
1159 if args.markdown {
1160 let markdown_dir =
1161 output.as_ref().expect("--markdown requires -o");
1162 let filename = format!("{}.md", example.spec.filename());
1163 let path = markdown_dir.join(&filename);
1164 let markdown = example.spec.to_markdown();
1165 std::fs::write(&path, &markdown)
1166 .expect("Failed to write markdown file");
1167 } else if let Some(ref mut sender) = output_sender.clone() {
1168 let line = serde_json::to_string(&example).unwrap();
1169 sender
1170 .send(line)
1171 .await
1172 .expect("Failed to send to output writer");
1173 } else if args.output.is_none()
1174 && !matches!(command, Command::Eval(_))
1175 {
1176 let line = serde_json::to_string(&example).unwrap();
1177 println!("{}", line);
1178 }
1179 }
1180 }
1181
1182 let project = repo_examples
1183 .iter()
1184 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1185
1186 if let Some(project) = project {
1187 let mut cx = cx.clone();
1188
1189 let shutdown_task: Task<()> =
1190 project.update(&mut cx, |project, cx| {
1191 let lsp_store = project.lsp_store();
1192 lsp_store.update(cx, |lsp_store, cx| {
1193 lsp_store.shutdown_all_language_servers(cx)
1194 })
1195 });
1196
1197 shutdown_task.await;
1198
1199 if let Some(ep_store) =
1200 cx.update(|cx| EditPredictionStore::try_global(cx))
1201 {
1202 ep_store.update(&mut cx, |store, _| {
1203 store.remove_project(&project);
1204 });
1205 }
1206 }
1207
1208 for example in &mut repo_examples {
1209 example.state.take();
1210 }
1211 finished_examples
1212 .lock()
1213 .unwrap()
1214 .extend_from_slice(&repo_examples);
1215 }
1216 });
1217 }
1218 futures::future::join_all(tasks).await;
1219
1220 Progress::global().finalize();
1221
1222 match &command {
1223 Command::Predict(args) | Command::Score(args) => {
1224 predict::sync_batches(args.provider.as_ref()).await?;
1225 }
1226 Command::Eval(args) => {
1227 predict::sync_batches(args.predict.provider.as_ref()).await?;
1228 }
1229 Command::Qa(args) => {
1230 qa::sync_batches(args).await?;
1231 }
1232 Command::Repair(args) => {
1233 repair::sync_batches(args).await?;
1234 }
1235 _ => (),
1236 }
1237
1238 match &command {
1239 Command::Eval(args) => {
1240 let examples = finished_examples.lock().unwrap();
1241 score::print_report(&examples);
1242 if let Some(summary_path) = &args.summary_json {
1243 score::write_summary_json(&examples, summary_path)?;
1244 }
1245 }
1246 Command::Repair(args) => {
1247 let examples = finished_examples.lock().unwrap();
1248 repair::print_report(&examples, args.confidence_threshold);
1249 }
1250 _ => (),
1251 };
1252
1253 // For --in-place, atomically rename temp file to original
1254 if let Some(temp_path) = &in_place_temp_path {
1255 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1256 std::fs::rename(temp_path, final_path)
1257 .expect("Failed to rename temp file to final output");
1258 }
1259
1260 anyhow::Ok(())
1261 }
1262 .await;
1263
1264 if let Err(e) = result {
1265 panic!("Fatal error: {:?}", e);
1266 }
1267
1268 let _ = cx.update(|cx| cx.quit());
1269 })
1270 .detach();
1271 });
1272}
1273
1274async fn handle_error(
1275 error: anyhow::Error,
1276 args: &EpArgs,
1277 command: &Command,
1278 app_state: &Arc<headless::EpAppState>,
1279 failfast_on_single_example: bool,
1280 example: &Example,
1281) {
1282 Progress::global().increment_failed();
1283
1284 let msg;
1285 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1286 let example_name = example.spec.filename();
1287
1288 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1289 app_state
1290 .fs
1291 .write(
1292 &failed_example_path,
1293 &serde_json::to_vec_pretty(&example).unwrap(),
1294 )
1295 .await
1296 .unwrap();
1297 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1298 app_state
1299 .fs
1300 .write(&err_path, format!("{error:?}").as_bytes())
1301 .await
1302 .unwrap();
1303
1304 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1305 let mut file = OpenOptions::new()
1306 .create(true)
1307 .append(true)
1308 .open(&failed_jsonl_path)
1309 .expect("Failed to open failed.jsonl");
1310 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1311 .expect("Failed to write to failed.jsonl");
1312
1313 let cursor_path = match example.repo_name() {
1314 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1315 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1316 };
1317 msg = format!(
1318 indoc::indoc! {"
1319 While processing \"{}\":
1320
1321 \x1b[31m{:?}\x1b[0m
1322
1323 Example: \x1b[36m{}\x1b[0m
1324 Error file: \x1b[36m{}\x1b[0m
1325 Cursor file: \x1b[36m{}\x1b[0m
1326 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1327 "},
1328 example.spec.name,
1329 error,
1330 failed_example_path.display(),
1331 err_path.display(),
1332 cursor_path.display(),
1333 command,
1334 failed_example_path.display(),
1335 );
1336 } else {
1337 msg = format!(
1338 indoc::indoc! {"
1339 While processing \"{}\":
1340
1341 \x1b[31m{:?}\x1b[0m
1342 "},
1343 example.spec.name, error
1344 );
1345 }
1346
1347 if args.failfast || failfast_on_single_example {
1348 Progress::global().finalize();
1349 panic!("{}", msg);
1350 } else {
1351 log::error!("{}", msg);
1352 }
1353}