main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod prompt_assets;
  16mod pull_examples;
  17mod qa;
  18mod reorder_patch;
  19mod repair;
  20mod retrieve_context;
  21mod reversal_tracking;
  22mod score;
  23mod split_commit;
  24mod split_dataset;
  25
  26mod synthesize;
  27mod truncate_expected_patch;
  28mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::{HashMap, HashSet};
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gaoya::minhash::{
    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
};
use gpui::{AppContext as _, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaFormat;
  48
  49use crate::distill::run_distill;
  50use crate::example::{Example, group_examples_by_repo, read_example_files};
  51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  52use crate::format_prompt::run_format_prompt;
  53use crate::load_project::run_load_project;
  54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  55use crate::predict::run_prediction;
  56use crate::progress::Progress;
  57use crate::retrieve_context::run_context_retrieval;
  58use crate::score::run_scoring;
  59use crate::split_commit::SplitCommitArgs;
  60use crate::split_dataset::SplitArgs;
  61use crate::synthesize::{SynthesizeConfig, run_synthesize};
  62use crate::truncate_expected_patch::TruncatePatchArgs;
  63
// Global CLI arguments for `ep`. Most flags are `global = true` so they can
// appear after any subcommand; `inputs` is the shared positional list of
// file paths / Snowflake specifiers (see INPUTS_HELP).
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit (handled first in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // NOTE(review): presumably caps concurrent example processing — confirm in the subcommand runners.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, with
    // any remainder passed on to the Snowflake queries (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output file path (or output directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back to the single input file; see output_path().
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably stops at the first failed example — usage not visible in this chunk.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
 106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): name suggests no output files are written at all for failures — confirm in the writers.
    SkipNoFiles,
}
 119
// Extended help text describing the positional `inputs` syntax; attached to
// `ep read` via `#[command(after_help = INPUTS_HELP)]` on ReadArgs.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 177
// Top-level `ep` subcommands. The `Display` impl below renders each variant
// as its CLI-style invocation (including significant flags).
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
 221
 222impl Display for Command {
 223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 224        match self {
 225            Command::Read(_) => write!(f, "read"),
 226            Command::LoadProject => write!(f, "load-project"),
 227            Command::Context => write!(f, "context"),
 228            Command::FormatPrompt(args) => {
 229                write!(f, "format-prompt --provider={}", args.provider)
 230            }
 231            Command::Predict(args) => match &args.provider {
 232                Some(provider) => write!(f, "predict --provider={}", provider),
 233                None => write!(f, "predict"),
 234            },
 235            Command::ParseOutput => write!(f, "parse-output"),
 236            Command::Score(args) => match &args.provider {
 237                Some(provider) => write!(f, "score --provider={}", provider),
 238                None => write!(f, "score"),
 239            },
 240            Command::Distill => write!(f, "distill"),
 241            Command::Eval(args) => match &args.predict.provider {
 242                Some(provider) => write!(f, "eval --provider={}", provider),
 243                None => write!(f, "eval"),
 244            },
 245            Command::Synthesize(args) => {
 246                write!(f, "synthesize --repos {}", args.repos.join(" "))
 247            }
 248            Command::Clean => write!(f, "clean"),
 249            Command::SplitCommit(_) => write!(f, "split-commit"),
 250            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 251            Command::Split(_) => write!(f, "split"),
 252            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 253            Command::ImportBatch(args) => {
 254                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 255            }
 256            Command::Qa(_) => {
 257                write!(f, "qa")
 258            }
 259            Command::Repair(_) => {
 260                write!(f, "repair")
 261            }
 262            Command::PrintZetaFormats => {
 263                write!(f, "print-zeta-formats")
 264            }
 265        }
 266    }
 267}
 268
// `ep read` has no flags of its own; its inputs come from the global
// positional `inputs` list, documented by INPUTS_HELP.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
 272
// Options for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with its
    // default ZetaFormat (see PredictionProvider::default).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 278
// Options shared by `ep predict` and `ep score`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; rendered bare (no flag) in Display when absent.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example — confirm in run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
 289
// Options for `ep eval`: runs prediction (flattened PredictArgs) and prints
// aggregated scores.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 298
/// Teacher models available for teacher-based prediction providers.
/// `model_name()` maps each variant to its concrete model identifier.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
 306
 307impl std::fmt::Display for TeacherBackend {
 308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 309        match self {
 310            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 311            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 312            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 313        }
 314    }
 315}
 316
 317impl std::str::FromStr for TeacherBackend {
 318    type Err = anyhow::Error;
 319
 320    fn from_str(s: &str) -> Result<Self, Self::Err> {
 321        match s.to_lowercase().as_str() {
 322            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 323            "sonnet46" => Ok(TeacherBackend::Sonnet46),
 324            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 325            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 326            _ => anyhow::bail!(
 327                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
 328            ),
 329        }
 330    }
 331}
 332
 333impl TeacherBackend {
 334    pub fn model_name(&self) -> &'static str {
 335        match self {
 336            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 337            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 338            TeacherBackend::Gpt52 => "gpt-5.2",
 339        }
 340    }
 341}
 342
/// Which backend produces edit predictions. Parsed from strings like
/// `zeta2:ordered` or `teacher:gpt52` (see the `FromStr` impl below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the prompt format used for zeta2.
    Zeta2(ZetaFormat),
    Teacher(TeacherBackend),
    // Same teacher backends — presumably bypassing request batching (name-based; confirm in predict).
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 353
 354impl Default for PredictionProvider {
 355    fn default() -> Self {
 356        PredictionProvider::Zeta2(ZetaFormat::default())
 357    }
 358}
 359
 360impl std::fmt::Display for PredictionProvider {
 361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 362        match self {
 363            PredictionProvider::Sweep => write!(f, "sweep"),
 364            PredictionProvider::Mercury => write!(f, "mercury"),
 365            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 366            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 367            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 368            PredictionProvider::TeacherNonBatching(backend) => {
 369                write!(f, "teacher-non-batching:{backend}")
 370            }
 371            PredictionProvider::Repair => write!(f, "repair"),
 372        }
 373    }
 374}
 375
 376impl std::str::FromStr for PredictionProvider {
 377    type Err = anyhow::Error;
 378
 379    fn from_str(s: &str) -> Result<Self, Self::Err> {
 380        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 381
 382        let provider_lower = provider.to_lowercase();
 383        match provider_lower.as_str() {
 384            "sweep" => Ok(PredictionProvider::Sweep),
 385            "mercury" => Ok(PredictionProvider::Mercury),
 386            "zeta1" => Ok(PredictionProvider::Zeta1),
 387            "zeta2" => {
 388                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 389                Ok(PredictionProvider::Zeta2(format))
 390            }
 391            "teacher" => {
 392                let backend = arg
 393                    .map(|a| a.parse())
 394                    .transpose()?
 395                    .unwrap_or(TeacherBackend::default());
 396                Ok(PredictionProvider::Teacher(backend))
 397            }
 398            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 399                let backend = arg
 400                    .map(|a| a.parse())
 401                    .transpose()?
 402                    .unwrap_or(TeacherBackend::default());
 403                Ok(PredictionProvider::TeacherNonBatching(backend))
 404            }
 405            "repair" => Ok(PredictionProvider::Repair),
 406            _ => {
 407                anyhow::bail!(
 408                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 409                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 410                 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
 411                 Available zeta versions:\n{}",
 412                    ZetaFormat::options_as_string()
 413                )
 414            }
 415        }
 416    }
 417}
 418
 419impl Serialize for PredictionProvider {
 420    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 421    where
 422        S: Serializer,
 423    {
 424        serializer.serialize_str(&self.to_string())
 425    }
 426}
 427
 428impl<'de> Deserialize<'de> for PredictionProvider {
 429    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 430    where
 431        D: Deserializer<'de>,
 432    {
 433        let s = String::deserialize(deserializer)?;
 434        s.parse().map_err(serde::de::Error::custom)
 435    }
 436}
 437
// Options for `ep synthesize` (consumed via SynthesizeConfig / run_synthesize).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 456
// Options for `ep import-batch`; handled directly in main()'s ImportBatch arm.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 466
// Which batch API to import from; main() maps Anthropic to anthropic_client
// and Openai to openai_client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 472
 473impl EpArgs {
 474    fn output_path(&self) -> Option<PathBuf> {
 475        if self.in_place {
 476            if self.inputs.len() == 1 {
 477                self.inputs.first().cloned()
 478            } else {
 479                panic!("--in-place requires exactly one input file")
 480            }
 481        } else {
 482            self.output.clone()
 483        }
 484    }
 485}
 486
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    // minor.patch of the minimum supported capture version.
    minor: 224,
    patch: 1,
};
 494
/// Removes duplicate examples in two passes: an exact filter on cursor
/// position, then a MinHash/LSH near-duplicate filter that keeps at most
/// `max_per_cluster` diverse representatives of each near-duplicate cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor_position is byte-identical to one
    // already seen (first occurrence wins).
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate clustering over token-ngram MinHash signatures.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // gaoya derives (bands, band width) from the threshold; the effective
    // hash count is their product, which may differ slightly from NUM_HASHES.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One signature per example, computed over the token n-grams of its
    // cursor_position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find with path halving: keeps trees shallow without tracking ranks.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union each example with every candidate the LSH index reports for it.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Keep up to max_per_cluster maximally-diverse members per cluster.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back so swap_remove never disturbs a smaller kept
    // index; the descending pushes are reversed to restore ascending order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
 575
 576fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
 577    if members.len() <= k {
 578        return members.to_vec();
 579    }
 580
 581    let mut selected = vec![members[0]];
 582    let mut min_dist: HashMap<usize, f64> = HashMap::default();
 583    for &member in &members[1..] {
 584        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
 585        min_dist.insert(member, dist);
 586    }
 587
 588    while selected.len() < k {
 589        let &best = min_dist
 590            .iter()
 591            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
 592            .map(|(id, _)| id)
 593            .expect("min_dist should not be empty when selected.len() < k");
 594        selected.push(best);
 595        min_dist.remove(&best);
 596
 597        let best_sig = &signatures[best];
 598        for (member, current_min) in min_dist.iter_mut() {
 599            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
 600            if dist < *current_min {
 601                *current_min = dist;
 602            }
 603        }
 604    }
 605
 606    selected
 607}
 608
 609fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
 610    let tokens: Vec<&str> = word_diff::tokenize(code)
 611        .into_iter()
 612        .filter(|t| !t.trim().is_empty())
 613        .collect();
 614
 615    if tokens.len() < ngram_size {
 616        return vec![tokens.join("\0")];
 617    }
 618
 619    tokens
 620        .windows(ngram_size)
 621        .map(|window| window.join("\0"))
 622        .collect()
 623}
 624
/// Gathers the input examples for a run from local files and Snowflake
/// specifiers (`captured-after:`, `rejected-after:`, `requested-after:`,
/// `rated-after:` variants), then applies the global filters from `args`:
/// offset, name/repo substring filters, resume-from-output, deduplication,
/// and limit.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into Snowflake specifiers vs. plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            // Offset swallows every file example; the remainder is forwarded
            // to the Snowflake fetches below.
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // How many more examples Snowflake may contribute before --limit is hit.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 is the default per-timestamp cap when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: http_client and background_executor moved, not cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Totals may have changed above; refresh the progress state.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
 781
 782fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 783    let mut hasher = collections::FxHasher::default();
 784    spec.hash(&mut hasher);
 785    hasher.finish()
 786}
 787
 788fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 789    let file = match File::open(path) {
 790        Ok(f) => f,
 791        Err(_) => return,
 792    };
 793
 794    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 795
 796    let reader = BufReader::new(file);
 797    let mut kept_lines = Vec::new();
 798    let mut kept_hashes = HashSet::default();
 799
 800    for line in reader.lines() {
 801        let line = match line {
 802            Ok(l) => l,
 803            Err(_) => continue,
 804        };
 805
 806        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 807            let hash = spec_hash(&output_example.spec);
 808            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 809                let is_complete = match command {
 810                    Command::Qa(_) => output_example
 811                        .qa
 812                        .first()
 813                        .and_then(|q| q.as_ref())
 814                        .and_then(|q| q.confidence)
 815                        .is_some(),
 816                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 817                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 818                    }),
 819                    _ => true,
 820                };
 821                if is_complete {
 822                    kept_hashes.insert(hash);
 823                    kept_lines.push(line);
 824                }
 825            }
 826        }
 827    }
 828
 829    let total = examples.len();
 830    let already_processed = kept_hashes.len();
 831
 832    eprintln!(
 833        "Resuming: {}/{} examples already processed",
 834        already_processed, total
 835    );
 836
 837    let file = OpenOptions::new()
 838        .write(true)
 839        .truncate(true)
 840        .open(path)
 841        .expect("Failed to open output file for rewriting");
 842    let mut writer = BufWriter::new(file);
 843    for line in &kept_lines {
 844        writeln!(writer, "{}", line).expect("Failed to write to output file");
 845    }
 846    writer.flush().expect("Failed to flush output file");
 847
 848    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 849}
 850
/// CLI entry point.
///
/// Flow: parse args; handle the simple, non-project subcommands synchronously
/// (import-batch, clean, print-zeta-formats, synthesize, split-commit,
/// truncate-patch, split, filter-languages) and return early; otherwise boot a
/// headless gpui app and run the example-processing pipeline inside it:
/// load examples, sync any pending provider batches, process examples in
/// parallel grouped by repo, then print reports and finalize output files.
fn main() {
    let args = EpArgs::parse();

    // `--printenv` short-circuits everything else.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app are dispatched here and
    // return before the gpui event loop is started.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the entire on-disk data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands fall through to the headless-app pipeline below.
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            // The whole pipeline runs inside one async block so a single `?`
            // chain can surface any error; the result is checked at the end.
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run panics on the first failure (see
                // handle_error), which makes one-off debugging loud.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode).
                // Lines are sent over a channel to a detached background task
                // that writes and flushes each one as it arrives.
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file first; it is renamed
                    // over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // in-place: truncate the temp file; otherwise append to
                    // the existing output (supports resumed runs).
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repo so each repo's project
                // state is loaded/torn down by a single worker at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly pops
                // a repo group off the shared queue until it is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the active
                                // command; errors are collected, not thrown.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands returned early above
                                        // and can never reach the pipeline.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written through when
                                // --failed=keep was requested.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: stream JSONL to stdout
                                        // (except for eval, which prints a
                                        // report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // After a repo group finishes, shut down its
                            // language servers and drop its project from the
                            // EditPredictionStore before releasing state.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example project state before archiving
                            // the group into the finished list.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Second batch sync after processing, mirroring the one above.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Command-specific end-of-run reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1288
1289async fn handle_error(
1290    error: anyhow::Error,
1291    args: &EpArgs,
1292    command: &Command,
1293    app_state: &Arc<headless::EpAppState>,
1294    failfast_on_single_example: bool,
1295    example: &Example,
1296) {
1297    Progress::global().increment_failed();
1298
1299    let msg;
1300    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1301        let example_name = example.spec.filename();
1302
1303        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1304        app_state
1305            .fs
1306            .write(
1307                &failed_example_path,
1308                &serde_json::to_vec_pretty(&example).unwrap(),
1309            )
1310            .await
1311            .unwrap();
1312        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1313        app_state
1314            .fs
1315            .write(&err_path, format!("{error:?}").as_bytes())
1316            .await
1317            .unwrap();
1318
1319        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1320        let mut file = OpenOptions::new()
1321            .create(true)
1322            .append(true)
1323            .open(&failed_jsonl_path)
1324            .expect("Failed to open failed.jsonl");
1325        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1326            .expect("Failed to write to failed.jsonl");
1327
1328        let cursor_path = match example.repo_name() {
1329            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1330            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1331        };
1332        msg = format!(
1333            indoc::indoc! {"
1334                While processing \"{}\":
1335
1336                \x1b[31m{:?}\x1b[0m
1337
1338                Example:        \x1b[36m{}\x1b[0m
1339                Error file:     \x1b[36m{}\x1b[0m
1340                Cursor file:    \x1b[36m{}\x1b[0m
1341                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1342            "},
1343            example.spec.name,
1344            error,
1345            failed_example_path.display(),
1346            err_path.display(),
1347            cursor_path.display(),
1348            command,
1349            failed_example_path.display(),
1350        );
1351    } else {
1352        msg = format!(
1353            indoc::indoc! {"
1354            While processing \"{}\":
1355
1356                \x1b[31m{:?}\x1b[0m
1357            "},
1358            example.spec.name, error
1359        );
1360    }
1361
1362    if args.failfast || failfast_on_single_example {
1363        Progress::global().finalize();
1364        panic!("{}", msg);
1365    } else {
1366        log::error!("{}", msg);
1367    }
1368}