1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
48
49use crate::distill::run_distill;
50use crate::example::{Example, group_examples_by_repo, read_example_files};
51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
52use crate::format_prompt::run_format_prompt;
53use crate::load_project::run_load_project;
54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
55use crate::predict::run_prediction;
56use crate::progress::Progress;
57use crate::retrieve_context::run_context_retrieval;
58use crate::score::run_scoring;
59use crate::split_commit::SplitCommitArgs;
60use crate::split_dataset::SplitArgs;
61use crate::synthesize::{SynthesizeConfig, run_synthesize};
62use crate::truncate_expected_patch::TruncatePatchArgs;
63
64#[derive(Parser, Debug)]
65#[command(name = "ep")]
66struct EpArgs {
67 #[arg(long, default_value_t = false)]
68 printenv: bool,
69 #[clap(long, default_value_t = 10, global = true)]
70 max_parallelism: usize,
71 /// The limit for the number of examples to process
72 /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
73 #[clap(long, global = true)]
74 limit: Option<usize>,
75 #[clap(long, global = true)]
76 offset: Option<usize>,
77 /// Filter examples by name
78 #[clap(long, global = true)]
79 name: Option<String>,
80 /// Filter examples by repository
81 #[clap(long, global = true)]
82 repo: Option<String>,
83 /// Deduplicate by cursor position and keep at most this many examples per cluster
84 #[clap(long, global = true)]
85 max_duplicates: Option<usize>,
86 #[command(subcommand)]
87 command: Option<Command>,
88 /// Input file paths
89 #[clap(global = true)]
90 inputs: Vec<PathBuf>,
91 #[arg(long, short, global = true)]
92 output: Option<PathBuf>,
93 #[arg(long, short, global = true)]
94 in_place: bool,
95 #[arg(long, short, global = true)]
96 failfast: bool,
97 /// How to handle failed examples in output: keep them or skip them.
98 /// Failed examples are always logged to the run's failed directory.
99 #[arg(long, global = true, default_value = "keep")]
100 failed: FailedHandling,
101 /// Output as markdown files instead of JSONL. When set, -o specifies a directory
102 /// where one .md file per example will be written (named after each example).
103 #[arg(long, short, global = true)]
104 markdown: bool,
105}
106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed by clap as a `ValueEnum` for `--failed`; the `///` variant docs below
// double as the `--help` text for each value, so keep them user-facing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
119
/// `after_help` text describing accepted positional inputs. Kept in sync with
/// the specifier parsing in `load_examples` (which also accepts
/// `requested-after:` — previously missing from this help text).
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  requested-after:{timestamp}
    Fetch edit prediction requests from Snowflake after the given RFC3339 timestamp.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
177
// Subcommands of the `ep` pipeline. The `///` doc comments on each variant are
// surfaced by clap as the subcommand's --help summary, so keep them user-facing.
// Variants without payload read everything they need from the global `EpArgs`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
221
222impl Display for Command {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Command::Read(_) => write!(f, "read"),
226 Command::LoadProject => write!(f, "load-project"),
227 Command::Context => write!(f, "context"),
228 Command::FormatPrompt(args) => {
229 write!(f, "format-prompt --provider={}", args.provider)
230 }
231 Command::Predict(args) => match &args.provider {
232 Some(provider) => write!(f, "predict --provider={}", provider),
233 None => write!(f, "predict"),
234 },
235 Command::ParseOutput => write!(f, "parse-output"),
236 Command::Score(args) => match &args.provider {
237 Some(provider) => write!(f, "score --provider={}", provider),
238 None => write!(f, "score"),
239 },
240 Command::Distill => write!(f, "distill"),
241 Command::Eval(args) => match &args.predict.provider {
242 Some(provider) => write!(f, "eval --provider={}", provider),
243 None => write!(f, "eval"),
244 },
245 Command::Synthesize(args) => {
246 write!(f, "synthesize --repos {}", args.repos.join(" "))
247 }
248 Command::Clean => write!(f, "clean"),
249 Command::SplitCommit(_) => write!(f, "split-commit"),
250 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
251 Command::Split(_) => write!(f, "split"),
252 Command::FilterLanguages(_) => write!(f, "filter-languages"),
253 Command::ImportBatch(args) => {
254 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
255 }
256 Command::Qa(_) => {
257 write!(f, "qa")
258 }
259 Command::Repair(_) => {
260 write!(f, "repair")
261 }
262 Command::PrintZetaFormats => {
263 write!(f, "print-zeta-formats")
264 }
265 }
266 }
267}
268
// `read` has no flags of its own — inputs and output come from the global
// `EpArgs`. The struct exists so the subcommand can attach INPUTS_HELP.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
272
// Arguments for `format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default `ZetaFormat` (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
278
// Arguments shared by `predict` and `score`, and flattened into `EvalArgs`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when omitted the provider flag is not rendered in the
    // run label (see `Display for Command`).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions (default 1) — presumably per example;
    // confirm against run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
289
// Arguments for `eval` (prints aggregated scores).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction options shared with `predict`/`score`, flattened into this command.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
301
/// Teacher models used by the `teacher`/`teacher-non-batching` prediction
/// providers. Parsed from strings via `FromStr` (e.g. "sonnet46"); see
/// `model_name()` for the concrete API model identifier each maps to.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Claude Sonnet 4.6 ("claude-sonnet-4-6").
    Sonnet46,
    /// Claude Sonnet 4.5 ("claude-sonnet-4-5") — the default teacher.
    #[default]
    Sonnet45,
    /// GPT-5.2 ("gpt-5.2").
    Gpt52,
}
309
310impl std::fmt::Display for TeacherBackend {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 match self {
313 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
314 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
315 TeacherBackend::Gpt52 => write!(f, "gpt52"),
316 }
317 }
318}
319
320impl std::str::FromStr for TeacherBackend {
321 type Err = anyhow::Error;
322
323 fn from_str(s: &str) -> Result<Self, Self::Err> {
324 match s.to_lowercase().as_str() {
325 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
326 "sonnet46" => Ok(TeacherBackend::Sonnet46),
327 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
328 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
329 _ => anyhow::bail!(
330 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
331 ),
332 }
333 }
334}
335
336impl TeacherBackend {
337 pub fn model_name(&self) -> &'static str {
338 match self {
339 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
340 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
341 TeacherBackend::Gpt52 => "gpt-5.2",
342 }
343 }
344}
345
/// Which backend produces edit predictions. The string form is `name` or
/// `name:arg` (see the `FromStr`/`Display` impls), and the same form is used
/// for serde (de)serialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Teacher model; `TeacherNonBatching` below is the same backend without batching.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
356
357impl Default for PredictionProvider {
358 fn default() -> Self {
359 PredictionProvider::Zeta2(ZetaFormat::default())
360 }
361}
362
363impl std::fmt::Display for PredictionProvider {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 PredictionProvider::Sweep => write!(f, "sweep"),
367 PredictionProvider::Mercury => write!(f, "mercury"),
368 PredictionProvider::Zeta1 => write!(f, "zeta1"),
369 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
370 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
371 PredictionProvider::TeacherNonBatching(backend) => {
372 write!(f, "teacher-non-batching:{backend}")
373 }
374 PredictionProvider::Repair => write!(f, "repair"),
375 }
376 }
377}
378
379impl std::str::FromStr for PredictionProvider {
380 type Err = anyhow::Error;
381
382 fn from_str(s: &str) -> Result<Self, Self::Err> {
383 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
384
385 let provider_lower = provider.to_lowercase();
386 match provider_lower.as_str() {
387 "sweep" => Ok(PredictionProvider::Sweep),
388 "mercury" => Ok(PredictionProvider::Mercury),
389 "zeta1" => Ok(PredictionProvider::Zeta1),
390 "zeta2" => {
391 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
392 Ok(PredictionProvider::Zeta2(format))
393 }
394 "teacher" => {
395 let backend = arg
396 .map(|a| a.parse())
397 .transpose()?
398 .unwrap_or(TeacherBackend::default());
399 Ok(PredictionProvider::Teacher(backend))
400 }
401 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
402 let backend = arg
403 .map(|a| a.parse())
404 .transpose()?
405 .unwrap_or(TeacherBackend::default());
406 Ok(PredictionProvider::TeacherNonBatching(backend))
407 }
408 "repair" => Ok(PredictionProvider::Repair),
409 _ => {
410 anyhow::bail!(
411 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
412 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
413 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
414 Available zeta versions:\n{}",
415 ZetaFormat::options_as_string()
416 )
417 }
418 }
419 }
420}
421
422impl Serialize for PredictionProvider {
423 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
424 where
425 S: Serializer,
426 {
427 serializer.serialize_str(&self.to_string())
428 }
429}
430
431impl<'de> Deserialize<'de> for PredictionProvider {
432 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
433 where
434 D: Deserializer<'de>,
435 {
436 let s = String::deserialize(deserializer)?;
437 s.parse().map_err(serde::de::Error::custom)
438 }
439}
440
// Arguments for `synthesize`. The output directory comes from the global
// `-o`/`--output` flag, which `main` requires for this subcommand.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
459
// Arguments for `import-batch`, which re-imports provider batch results into
// the local LLM cache database (see the handler in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
469
// Batch API vendors supported by `import-batch --provider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
475
476impl EpArgs {
477 fn output_path(&self) -> Option<PathBuf> {
478 if self.in_place {
479 if self.inputs.len() == 1 {
480 self.inputs.first().cloned()
481 } else {
482 panic!("--in-place requires exactly one input file")
483 }
484 } else {
485 self.output.clone()
486 }
487 }
488}
489
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Only minor/patch are modeled — corresponds to 0.224.1 assuming a 0.x major.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
497
/// Two-stage deduplication of `examples`, keeping at most `max_per_cluster`
/// examples per near-duplicate cluster:
/// 1. Exact: drop examples whose `cursor_position` text was already seen.
/// 2. Near: MinHash/LSH over token n-grams of `cursor_position`, union-find
///    over the LSH candidate pairs, then a diverse subset per cluster.
/// Relative order of the surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // `insert` returns false on repeats, so `retain` keeps first occurrences only.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // LSH tuning: pairs with estimated Jaccard similarity above the threshold
    // are likely to share a band and become merge candidates below.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Effective signature length after rounding to whole bands.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, built from cursor-context shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Within each cluster keep up to `max_per_cluster` mutually-dissimilar examples.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove kept elements back-to-front with swap_remove (O(1) each); since
    // all higher kept indices are removed first, positions below each removal
    // are undisturbed. Reversing afterwards restores original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
578
/// Greedy max-min (farthest-point) selection: picks up to `k` members that are
/// mutually dissimilar. Seeds with the first member, then repeatedly adds the
/// candidate whose minimum distance to the already-selected set is largest.
/// Distance is `1 - estimated Jaccard similarity` from MinHash signatures.
/// Returns all members unchanged when there are `k` or fewer.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    if members.len() <= k {
        return members.to_vec();
    }

    // Track, for each unselected candidate, its distance to the closest
    // selected member so far.
    let mut selected = vec![members[0]];
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Farthest-first: take the candidate with the largest min-distance.
        // Ties (and NaN, treated as Equal) resolve by map iteration order.
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // The newly selected member may now be the closest selected member for
        // some remaining candidates — refresh their min-distances.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
611
612fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
613 let tokens: Vec<&str> = word_diff::tokenize(code)
614 .into_iter()
615 .filter(|t| !t.trim().is_empty())
616 .collect();
617
618 if tokens.len() < ngram_size {
619 return vec![tokens.join("\0")];
620 }
621
622 tokens
623 .windows(ngram_size)
624 .map(|window| window.join("\0"))
625 .collect()
626}
627
/// Collects the input examples for a run.
///
/// Inputs are either local example files or Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`
/// and its positive/negative variants). Processing order: `--offset` (files
/// first, remainder forwarded to Snowflake), Snowflake fetches, sort by
/// repo/rev, `--name`/`--repo` filters, resume from existing output (skipped
/// for `--in-place`), `--max-duplicates` deduplication, then `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into Snowflake specifiers vs. local files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that is not a recognized specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    // NOTE(review): `captured_after_timestamps` is collected above but never
    // fetched below, so `captured-after:` inputs appear to be silently
    // dropped — confirm and wire up the corresponding fetch if unintended.

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            // Drop the first `offset` file examples; nothing left over for Snowflake.
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Early estimate for progress display; refined again at the end.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // No explicit --limit: cap Snowflake pulls at 5000 rows per timestamp.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group examples so per-repo worktrees are processed together.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for the progress display now that filtering is done.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
769
770fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
771 let mut hasher = collections::FxHasher::default();
772 spec.hash(&mut hasher);
773 hasher.finish()
774}
775
776fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
777 let file = match File::open(path) {
778 Ok(f) => f,
779 Err(_) => return,
780 };
781
782 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
783
784 let reader = BufReader::new(file);
785 let mut kept_lines = Vec::new();
786 let mut kept_hashes = HashSet::default();
787
788 for line in reader.lines() {
789 let line = match line {
790 Ok(l) => l,
791 Err(_) => continue,
792 };
793
794 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
795 let hash = spec_hash(&output_example.spec);
796 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
797 let is_complete = match command {
798 Command::Qa(_) => output_example
799 .qa
800 .first()
801 .and_then(|q| q.as_ref())
802 .and_then(|q| q.confidence)
803 .is_some(),
804 Command::Repair(_) => output_example.predictions.iter().any(|p| {
805 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
806 }),
807 _ => true,
808 };
809 if is_complete {
810 kept_hashes.insert(hash);
811 kept_lines.push(line);
812 }
813 }
814 }
815 }
816
817 let total = examples.len();
818 let already_processed = kept_hashes.len();
819
820 eprintln!(
821 "Resuming: {}/{} examples already processed",
822 already_processed, total
823 );
824
825 let file = OpenOptions::new()
826 .write(true)
827 .truncate(true)
828 .open(path)
829 .expect("Failed to open output file for rewriting");
830 let mut writer = BufWriter::new(file);
831 for line in &kept_lines {
832 writeln!(writer, "{}", line).expect("Failed to write to output file");
833 }
834 writer.flush().expect("Failed to flush output file");
835
836 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
837}
838
839fn main() {
840 let args = EpArgs::parse();
841
842 if args.printenv {
843 ::util::shell_env::print_env();
844 return;
845 }
846
847 let output = args.output_path();
848
849 if args.markdown && output.is_none() {
850 eprintln!("--markdown requires -o to specify the output directory");
851 std::process::exit(1);
852 }
853
854 let command = match &args.command {
855 Some(cmd) => cmd.clone(),
856 None => {
857 EpArgs::command().print_help().unwrap();
858 return;
859 }
860 };
861
862 match &command {
863 Command::ImportBatch(import_args) => {
864 smol::block_on(async {
865 match import_args.provider {
866 BatchProvider::Anthropic => {
867 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
868 .expect("Failed to create Anthropic client");
869 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
870 eprintln!("Error importing Anthropic batches: {:?}", e);
871 std::process::exit(1);
872 }
873 }
874 BatchProvider::Openai => {
875 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
876 .expect("Failed to create OpenAI client");
877 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
878 eprintln!("Error importing OpenAI batches: {:?}", e);
879 std::process::exit(1);
880 }
881 }
882 }
883 println!(
884 "Successfully imported {} batch(es)",
885 import_args.batch_ids.len()
886 );
887 });
888 return;
889 }
890 Command::Clean => {
891 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
892 return;
893 }
894 Command::PrintZetaFormats => {
895 use strum::IntoEnumIterator as _;
896 for format in ZetaFormat::iter() {
897 println!("{}", format.to_string().to_lowercase());
898 }
899 return;
900 }
901
902 Command::Synthesize(synth_args) => {
903 let Some(output_dir) = args.output else {
904 panic!("output dir is required");
905 };
906 let config = SynthesizeConfig {
907 repo_urls: synth_args.repos.clone(),
908 count: synth_args.count,
909 max_commits: synth_args.max_commits,
910 output_dir,
911 fresh: synth_args.fresh,
912 };
913 smol::block_on(async {
914 if let Err(e) = run_synthesize(config).await {
915 eprintln!("Error: {:?}", e);
916 std::process::exit(1);
917 }
918 });
919 return;
920 }
921 Command::SplitCommit(split_commit_args) => {
922 if let Err(error) = split_commit::run_split_commit(
923 split_commit_args,
924 &args.inputs,
925 output.as_ref(),
926 args.failed,
927 ) {
928 eprintln!("{error:#}");
929 std::process::exit(1);
930 }
931 return;
932 }
933 Command::TruncatePatch(truncate_args) => {
934 if let Err(error) =
935 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
936 {
937 eprintln!("{error:#}");
938 std::process::exit(1);
939 }
940 return;
941 }
942 Command::Split(split_args) => {
943 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
944 eprintln!("{error:#}");
945 std::process::exit(1);
946 }
947 return;
948 }
949 Command::FilterLanguages(filter_args) => {
950 if let Err(error) =
951 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
952 {
953 eprintln!("{error:#}");
954 std::process::exit(1);
955 }
956 return;
957 }
958
959 _ => {}
960 }
961
962 let http_client = Arc::new(ReqwestClient::new());
963 let app = gpui_platform::headless().with_http_client(http_client);
964
965 app.run(move |cx| {
966 let app_state = Arc::new(headless::init(cx));
967 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
968
969 cx.spawn(async move |cx| {
970 let result = async {
971 let examples = load_examples(
972 app_state.client.http_client(),
973 &args,
974 output.as_ref(),
975 cx.background_executor().clone(),
976 )
977 .await?;
978
979 match &command {
980 Command::Predict(args) | Command::Score(args) => {
981 predict::sync_batches(args.provider.as_ref()).await?;
982 }
983 Command::Eval(args) => {
984 predict::sync_batches(args.predict.provider.as_ref()).await?;
985 }
986 Command::Qa(args) => {
987 qa::sync_batches(args).await?;
988 }
989 Command::Repair(args) => {
990 repair::sync_batches(args).await?;
991 }
992 _ => (),
993 }
994
995 let failfast_on_single_example = examples.len() == 1;
996
997 // For --markdown mode, create the output directory if it doesn't exist
998 if args.markdown {
999 let dir = output.as_ref().expect("--markdown requires -o");
1000 if !dir.exists() {
1001 std::fs::create_dir_all(dir)
1002 .expect("Failed to create markdown output directory");
1003 }
1004 }
1005
1006 // Set up JSONL output writer (not used in markdown mode)
1007 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1008 let mut in_place_temp_path: Option<PathBuf> = None;
1009 if !args.markdown
1010 && let Some(output_path) = output.as_ref()
1011 {
1012 let write_path = if args.in_place {
1013 let temp = output_path.with_extension("jsonl.tmp");
1014 in_place_temp_path = Some(temp.clone());
1015 temp
1016 } else {
1017 output_path.clone()
1018 };
1019
1020 let file = OpenOptions::new()
1021 .create(true)
1022 .write(true)
1023 .truncate(args.in_place)
1024 .append(!args.in_place)
1025 .open(&write_path)
1026 .expect("Failed to open output file");
1027
1028 let mut writer = BufWriter::new(file);
1029 let (sender, mut receiver) = mpsc::unbounded::<String>();
1030 cx.background_spawn(async move {
1031 while let Some(line) = receiver.next().await {
1032 writeln!(writer, "{}", line).expect("Failed to write example");
1033 writer.flush().expect("Failed to flush output");
1034 }
1035 })
1036 .detach();
1037 output_sender = Some(sender);
1038 }
1039
1040 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1041 let finished_examples = Mutex::new(Vec::new());
1042
1043 let mut tasks = Vec::new();
1044 for _ in 0..args.max_parallelism {
1045 tasks.push(async {
1046 loop {
1047 let Some(mut repo_examples) =
1048 grouped_examples.lock().unwrap().pop_front()
1049 else {
1050 break;
1051 };
1052 for example in &mut repo_examples {
1053 let example_progress =
1054 Progress::global().start_group(&example.spec.name);
1055
1056 let result = async {
1057 match &command {
1058 Command::Read(_) => {}
1059 Command::LoadProject => {
1060 run_load_project(
1061 example,
1062 app_state.clone(),
1063 &example_progress,
1064 cx.clone(),
1065 )
1066 .await?;
1067 }
1068 Command::Context => {
1069 run_context_retrieval(
1070 example,
1071 app_state.clone(),
1072 &example_progress,
1073 cx.clone(),
1074 )
1075 .await?;
1076 }
1077 Command::FormatPrompt(args) => {
1078 run_format_prompt(
1079 example,
1080 args,
1081 app_state.clone(),
1082 &example_progress,
1083 cx.clone(),
1084 )
1085 .await?;
1086 }
1087 Command::Predict(args) => {
1088 run_prediction(
1089 example,
1090 args,
1091 app_state.clone(),
1092 &example_progress,
1093 cx.clone(),
1094 )
1095 .await?;
1096 }
1097 Command::ParseOutput => {
1098 parse_output::run_parse_output(example)?;
1099 }
1100 Command::Distill => {
1101 run_distill(example).await?;
1102 }
1103 Command::Score(args) => {
1104 run_scoring(
1105 example,
1106 args,
1107 app_state.clone(),
1108 &example_progress,
1109 cx.clone(),
1110 )
1111 .await?;
1112 }
1113 Command::Eval(args) => {
1114 run_scoring(
1115 example,
1116 &args.predict,
1117 app_state.clone(),
1118 &example_progress,
1119 cx.clone(),
1120 )
1121 .await?;
1122 }
1123 Command::Qa(args) => {
1124 qa::run_qa(example, args, &example_progress).await?;
1125 }
1126 Command::Repair(args) => {
1127 repair::run_repair(example, args, &example_progress)
1128 .await?;
1129 }
1130 Command::Clean
1131 | Command::Synthesize(_)
1132 | Command::SplitCommit(_)
1133 | Command::Split(_)
1134 | Command::TruncatePatch(_)
1135 | Command::FilterLanguages(_)
1136 | Command::ImportBatch(_)
1137 | Command::PrintZetaFormats => {
1138 unreachable!()
1139 }
1140 }
1141 anyhow::Ok(())
1142 }
1143 .await;
1144
1145 let failed = if let Err(error) = result {
1146 handle_error(
1147 error,
1148 &args,
1149 &command,
1150 &app_state,
1151 failfast_on_single_example,
1152 &example,
1153 )
1154 .await;
1155 true
1156 } else {
1157 false
1158 };
1159
1160 let should_write = !failed || args.failed == FailedHandling::Keep;
1161 if should_write {
1162 if args.markdown {
1163 let markdown_dir =
1164 output.as_ref().expect("--markdown requires -o");
1165 let filename = format!("{}.md", example.spec.filename());
1166 let path = markdown_dir.join(&filename);
1167 let markdown = example.spec.to_markdown();
1168 std::fs::write(&path, &markdown)
1169 .expect("Failed to write markdown file");
1170 } else if let Some(ref mut sender) = output_sender.clone() {
1171 let line = serde_json::to_string(&example).unwrap();
1172 sender
1173 .send(line)
1174 .await
1175 .expect("Failed to send to output writer");
1176 } else if args.output.is_none()
1177 && !matches!(command, Command::Eval(_))
1178 {
1179 let line = serde_json::to_string(&example).unwrap();
1180 println!("{}", line);
1181 }
1182 }
1183 }
1184
1185 let project = repo_examples
1186 .iter()
1187 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1188
1189 if let Some(project) = project {
1190 let mut cx = cx.clone();
1191
1192 let shutdown_task: Task<()> =
1193 project.update(&mut cx, |project, cx| {
1194 let lsp_store = project.lsp_store();
1195 lsp_store.update(cx, |lsp_store, cx| {
1196 lsp_store.shutdown_all_language_servers(cx)
1197 })
1198 });
1199
1200 shutdown_task.await;
1201
1202 if let Some(ep_store) =
1203 cx.update(|cx| EditPredictionStore::try_global(cx))
1204 {
1205 ep_store.update(&mut cx, |store, _| {
1206 store.remove_project(&project);
1207 });
1208 }
1209 }
1210
1211 for example in &mut repo_examples {
1212 example.state.take();
1213 }
1214 finished_examples
1215 .lock()
1216 .unwrap()
1217 .extend_from_slice(&repo_examples);
1218 }
1219 });
1220 }
1221 futures::future::join_all(tasks).await;
1222
1223 Progress::global().finalize();
1224
1225 match &command {
1226 Command::Predict(args) | Command::Score(args) => {
1227 predict::sync_batches(args.provider.as_ref()).await?;
1228 }
1229 Command::Eval(args) => {
1230 predict::sync_batches(args.predict.provider.as_ref()).await?;
1231 }
1232 Command::Qa(args) => {
1233 qa::sync_batches(args).await?;
1234 }
1235 Command::Repair(args) => {
1236 repair::sync_batches(args).await?;
1237 }
1238 _ => (),
1239 }
1240
1241 match &command {
1242 Command::Eval(args) => {
1243 let examples = finished_examples.lock().unwrap();
1244 score::print_report(&examples, args.verbose);
1245 if let Some(summary_path) = &args.summary_json {
1246 score::write_summary_json(&examples, summary_path)?;
1247 }
1248 }
1249 Command::Repair(args) => {
1250 let examples = finished_examples.lock().unwrap();
1251 repair::print_report(&examples, args.confidence_threshold);
1252 }
1253 _ => (),
1254 };
1255
1256 // For --in-place, atomically rename temp file to original
1257 if let Some(temp_path) = &in_place_temp_path {
1258 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1259 std::fs::rename(temp_path, final_path)
1260 .expect("Failed to rename temp file to final output");
1261 }
1262
1263 anyhow::Ok(())
1264 }
1265 .await;
1266
1267 if let Err(e) = result {
1268 panic!("Fatal error: {:?}", e);
1269 }
1270
1271 let _ = cx.update(|cx| cx.quit());
1272 })
1273 .detach();
1274 });
1275}
1276
1277async fn handle_error(
1278 error: anyhow::Error,
1279 args: &EpArgs,
1280 command: &Command,
1281 app_state: &Arc<headless::EpAppState>,
1282 failfast_on_single_example: bool,
1283 example: &Example,
1284) {
1285 Progress::global().increment_failed();
1286
1287 let msg;
1288 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1289 let example_name = example.spec.filename();
1290
1291 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1292 app_state
1293 .fs
1294 .write(
1295 &failed_example_path,
1296 &serde_json::to_vec_pretty(&example).unwrap(),
1297 )
1298 .await
1299 .unwrap();
1300 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1301 app_state
1302 .fs
1303 .write(&err_path, format!("{error:?}").as_bytes())
1304 .await
1305 .unwrap();
1306
1307 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1308 let mut file = OpenOptions::new()
1309 .create(true)
1310 .append(true)
1311 .open(&failed_jsonl_path)
1312 .expect("Failed to open failed.jsonl");
1313 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1314 .expect("Failed to write to failed.jsonl");
1315
1316 let cursor_path = match example.repo_name() {
1317 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1318 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1319 };
1320 msg = format!(
1321 indoc::indoc! {"
1322 While processing \"{}\":
1323
1324 \x1b[31m{:?}\x1b[0m
1325
1326 Example: \x1b[36m{}\x1b[0m
1327 Error file: \x1b[36m{}\x1b[0m
1328 Cursor file: \x1b[36m{}\x1b[0m
1329 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1330 "},
1331 example.spec.name,
1332 error,
1333 failed_example_path.display(),
1334 err_path.display(),
1335 cursor_path.display(),
1336 command,
1337 failed_example_path.display(),
1338 );
1339 } else {
1340 msg = format!(
1341 indoc::indoc! {"
1342 While processing \"{}\":
1343
1344 \x1b[31m{:?}\x1b[0m
1345 "},
1346 example.spec.name, error
1347 );
1348 }
1349
1350 if args.failfast || failfast_on_single_example {
1351 Progress::global().finalize();
1352 panic!("{}", msg);
1353 } else {
1354 log::error!("{}", msg);
1355 }
1356}