1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
42use std::env;
43use std::fmt::Display;
44use std::fs::{File, OpenOptions};
45use std::hash::{Hash, Hasher};
46use std::io::{BufRead, BufReader, BufWriter, Write};
47use std::sync::Mutex;
48use std::{path::PathBuf, sync::Arc};
49
50use crate::distill::run_distill;
51use crate::example::{Example, group_examples_by_repo, read_example_files};
52use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
53use crate::format_prompt::run_format_prompt;
54use crate::load_project::run_load_project;
55use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
56use crate::predict::run_prediction;
57use crate::progress::Progress;
58use crate::retrieve_context::run_context_retrieval;
59use crate::score::run_scoring;
60use crate::split_commit::SplitCommitArgs;
61use crate::split_dataset::SplitArgs;
62use crate::synthesize::{SynthesizeConfig, run_synthesize};
63use crate::truncate_expected_patch::TruncatePatchArgs;
64
// Top-level CLI arguments. Note: `///` doc comments on fields become clap help
// text, so reviewer annotations below deliberately use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment via util::shell_env and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process per-repo example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, the
    // remainder is forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example — the
    // handling lives in the runner further down; confirm there.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
107
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `///` variant docs below double as clap help text for `--failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
120
// Long-form help shown via clap's `after_help` (see `ReadArgs`). Documents the
// special input specifiers that `load_examples` recognizes alongside plain
// file paths.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
    Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

 rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

 rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

 rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

 rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

 Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

 Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Read user-rated predictions
    ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

    # Read only positively rated predictions
    ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

    # Read only negatively rated predictions
    ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
178
// All `ep` subcommands. The `///` comments double as clap help text; the
// `Display` impl below renders each variant as its CLI spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
222
223impl Display for Command {
224 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 match self {
226 Command::Read(_) => write!(f, "read"),
227 Command::LoadProject => write!(f, "load-project"),
228 Command::Context => write!(f, "context"),
229 Command::FormatPrompt(args) => {
230 write!(f, "format-prompt --provider={}", args.provider)
231 }
232 Command::Predict(args) => match &args.provider {
233 Some(provider) => write!(f, "predict --provider={}", provider),
234 None => write!(f, "predict"),
235 },
236 Command::ParseOutput => write!(f, "parse-output"),
237 Command::Score(args) => match &args.provider {
238 Some(provider) => write!(f, "score --provider={}", provider),
239 None => write!(f, "score"),
240 },
241 Command::Distill => write!(f, "distill"),
242 Command::Eval(args) => match &args.predict.provider {
243 Some(provider) => write!(f, "eval --provider={}", provider),
244 None => write!(f, "eval"),
245 },
246 Command::Synthesize(args) => {
247 write!(f, "synthesize --repos {}", args.repos.join(" "))
248 }
249 Command::Clean => write!(f, "clean"),
250 Command::SplitCommit(_) => write!(f, "split-commit"),
251 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
252 Command::Split(_) => write!(f, "split"),
253 Command::FilterLanguages(_) => write!(f, "filter-languages"),
254 Command::ImportBatch(args) => {
255 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
256 }
257 Command::Qa(_) => {
258 write!(f, "qa")
259 }
260 Command::Repair(_) => {
261 write!(f, "repair")
262 }
263 Command::PrintZetaFormats => {
264 write!(f, "print-zeta-formats")
265 }
266 }
267 }
268}
269
// `read` has no flags of its own; its inputs come from the global positional
// arguments, and INPUTS_HELP documents the accepted specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
273
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to generate; defaults to zeta2 with the
    // default ZetaFormat (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; `None` is passed through to predict::sync_batches
    // and rendered as a bare command name by the Display impls.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm in the predict module.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction settings shared with the `predict`/`score` commands.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
302
/// Teacher models used for generating predictions.
/// CLI string forms are defined by the `Display`/`FromStr` impls below; the
/// upstream API identifier comes from `model_name`.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
310
311impl std::fmt::Display for TeacherBackend {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
315 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
316 TeacherBackend::Gpt52 => write!(f, "gpt52"),
317 }
318 }
319}
320
321impl std::str::FromStr for TeacherBackend {
322 type Err = anyhow::Error;
323
324 fn from_str(s: &str) -> Result<Self, Self::Err> {
325 match s.to_lowercase().as_str() {
326 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
327 "sonnet46" => Ok(TeacherBackend::Sonnet46),
328 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
329 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
330 _ => anyhow::bail!(
331 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
332 ),
333 }
334 }
335}
336
337impl TeacherBackend {
338 pub fn model_name(&self) -> &'static str {
339 match self {
340 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
341 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
342 TeacherBackend::Gpt52 => "gpt-5.2",
343 }
344 }
345}
346
/// Edit-prediction providers selectable via `--provider`.
/// Serialized to and parsed from `name` or `name:arg` strings — see the
/// `Display`/`FromStr` impls below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // zeta2 with an explicit prompt format, e.g. `zeta2:ordered`.
    Zeta2(ZetaFormat),
    Teacher(TeacherBackend),
    // Same teacher backends but, per the name, bypassing request batching —
    // the actual behavior lives in the predict module (not visible here).
    TeacherNonBatching(TeacherBackend),
    Repair,
}
357
358impl Default for PredictionProvider {
359 fn default() -> Self {
360 PredictionProvider::Zeta2(ZetaFormat::default())
361 }
362}
363
364impl std::fmt::Display for PredictionProvider {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 match self {
367 PredictionProvider::Sweep => write!(f, "sweep"),
368 PredictionProvider::Mercury => write!(f, "mercury"),
369 PredictionProvider::Zeta1 => write!(f, "zeta1"),
370 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
371 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
372 PredictionProvider::TeacherNonBatching(backend) => {
373 write!(f, "teacher-non-batching:{backend}")
374 }
375 PredictionProvider::Repair => write!(f, "repair"),
376 }
377 }
378}
379
380impl std::str::FromStr for PredictionProvider {
381 type Err = anyhow::Error;
382
383 fn from_str(s: &str) -> Result<Self, Self::Err> {
384 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
385
386 let provider_lower = provider.to_lowercase();
387 match provider_lower.as_str() {
388 "sweep" => Ok(PredictionProvider::Sweep),
389 "mercury" => Ok(PredictionProvider::Mercury),
390 "zeta1" => Ok(PredictionProvider::Zeta1),
391 "zeta2" => {
392 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
393 Ok(PredictionProvider::Zeta2(format))
394 }
395 "teacher" => {
396 let backend = arg
397 .map(|a| a.parse())
398 .transpose()?
399 .unwrap_or(TeacherBackend::default());
400 Ok(PredictionProvider::Teacher(backend))
401 }
402 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
403 let backend = arg
404 .map(|a| a.parse())
405 .transpose()?
406 .unwrap_or(TeacherBackend::default());
407 Ok(PredictionProvider::TeacherNonBatching(backend))
408 }
409 "repair" => Ok(PredictionProvider::Repair),
410 _ => {
411 anyhow::bail!(
412 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
413 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
414 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
415 Available zeta versions:\n{}",
416 ZetaFormat::options_as_string()
417 )
418 }
419 }
420 }
421}
422
423impl Serialize for PredictionProvider {
424 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
425 where
426 S: Serializer,
427 {
428 serializer.serialize_str(&self.to_string())
429 }
430}
431
432impl<'de> Deserialize<'de> for PredictionProvider {
433 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
434 where
435 D: Deserializer<'de>,
436 {
437 let s = String::deserialize(deserializer)?;
438 s.parse().map_err(serde::de::Error::custom)
439 }
440}
441
// Arguments for `synthesize`: mine git history of the given repos to generate
// eval examples (consumed via SynthesizeConfig / run_synthesize in main).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
460
// Arguments for `import-batch`: re-ingest provider batch results into the
// local cache (see the ImportBatch arm in main).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
470
// Which LLM batch API `import-batch` talks to; dispatches to
// anthropic_client or openai_client in main.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
476
477impl EpArgs {
478 fn output_path(&self) -> Option<PathBuf> {
479 if self.in_place {
480 if self.inputs.len() == 1 {
481 self.inputs.first().cloned()
482 } else {
483 panic!("--in-place requires exactly one input file")
484 }
485 } else {
486 self.output.clone()
487 }
488 }
489}
490
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): presumably these map to the minor/patch components of a
// 0.MINOR.PATCH Zed version — confirm in pull_examples::MinCaptureVersion.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
498
/// Deduplicates `examples` in place, keeping at most `max_per_cluster`
/// representatives of each cluster of near-identical cursor contexts.
///
/// Two passes:
/// 1. Exact dedup on `spec.cursor_position`.
/// 2. Near-duplicate detection via MinHash/LSH over token n-grams of the
///    cursor-position text, clustered with union-find, then a diverse subset
///    is kept from each cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: retain only the first example for each distinct cursor position.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2 parameters: similarity above the threshold makes two examples
    // near-duplicate candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Use the exact band product, since parameter calculation may round.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, computed from token n-grams of its
    // cursor-position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices under their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep a small, mutually-dissimilar subset.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back so earlier indices stay valid, then restore the
    // kept examples' original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
579
580fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
581 if members.len() <= k {
582 return members.to_vec();
583 }
584
585 let mut selected = vec![members[0]];
586 let mut min_dist: HashMap<usize, f64> = HashMap::default();
587 for &member in &members[1..] {
588 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
589 min_dist.insert(member, dist);
590 }
591
592 while selected.len() < k {
593 let &best = min_dist
594 .iter()
595 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
596 .map(|(id, _)| id)
597 .expect("min_dist should not be empty when selected.len() < k");
598 selected.push(best);
599 min_dist.remove(&best);
600
601 let best_sig = &signatures[best];
602 for (member, current_min) in min_dist.iter_mut() {
603 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
604 if dist < *current_min {
605 *current_min = dist;
606 }
607 }
608 }
609
610 selected
611}
612
613fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
614 let tokens: Vec<&str> = word_diff::tokenize(code)
615 .into_iter()
616 .filter(|t| !t.trim().is_empty())
617 .collect();
618
619 if tokens.len() < ngram_size {
620 return vec![tokens.join("\0")];
621 }
622
623 tokens
624 .windows(ngram_size)
625 .map(|window| window.join("\0"))
626 .collect()
627}
628
/// Gathers the examples a run will process: reads local files, fetches the
/// requested Snowflake datasets, then applies the global flags (--offset,
/// --limit, --name, --repo, --max-duplicates) and the resume-from-output
/// logic.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into Snowflake specifiers vs. local example files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Row cap per timestamp; defaults to 5000 when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `captured_after_timestamps` is collected above but no
        // fetch below consumes it, so `captured-after:` inputs appear to be
        // silently dropped — confirm against pull_examples.
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Apply the global --name / --repo substring filters.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that filtering/dedup/limit have settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
770
/// Fingerprints an example spec with `FxHasher`; `resume_from_output` uses
/// these hashes to match input examples against previously-written output.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
776
777fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
778 let file = match File::open(path) {
779 Ok(f) => f,
780 Err(_) => return,
781 };
782
783 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
784
785 let reader = BufReader::new(file);
786 let mut kept_lines = Vec::new();
787 let mut kept_hashes = HashSet::default();
788
789 for line in reader.lines() {
790 let line = match line {
791 Ok(l) => l,
792 Err(_) => continue,
793 };
794
795 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
796 let hash = spec_hash(&output_example.spec);
797 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
798 let is_complete = match command {
799 Command::Qa(_) => output_example
800 .qa
801 .first()
802 .and_then(|q| q.as_ref())
803 .and_then(|q| q.confidence)
804 .is_some(),
805 Command::Repair(_) => output_example.predictions.iter().any(|p| {
806 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
807 }),
808 _ => true,
809 };
810 if is_complete {
811 kept_hashes.insert(hash);
812 kept_lines.push(line);
813 }
814 }
815 }
816 }
817
818 let total = examples.len();
819 let already_processed = kept_hashes.len();
820
821 eprintln!(
822 "Resuming: {}/{} examples already processed",
823 already_processed, total
824 );
825
826 let file = OpenOptions::new()
827 .write(true)
828 .truncate(true)
829 .open(path)
830 .expect("Failed to open output file for rewriting");
831 let mut writer = BufWriter::new(file);
832 for line in &kept_lines {
833 writeln!(writer, "{}", line).expect("Failed to write to output file");
834 }
835 writer.flush().expect("Failed to flush output file");
836
837 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
838}
839
840fn main() {
841 let args = EpArgs::parse();
842
843 if args.printenv {
844 ::util::shell_env::print_env();
845 return;
846 }
847
848 let output = args.output_path();
849
850 if args.markdown && output.is_none() {
851 eprintln!("--markdown requires -o to specify the output directory");
852 std::process::exit(1);
853 }
854
855 let command = match &args.command {
856 Some(cmd) => cmd.clone(),
857 None => {
858 EpArgs::command().print_help().unwrap();
859 return;
860 }
861 };
862
863 match &command {
864 Command::ImportBatch(import_args) => {
865 smol::block_on(async {
866 match import_args.provider {
867 BatchProvider::Anthropic => {
868 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
869 .expect("Failed to create Anthropic client");
870 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
871 eprintln!("Error importing Anthropic batches: {:?}", e);
872 std::process::exit(1);
873 }
874 }
875 BatchProvider::Openai => {
876 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
877 .expect("Failed to create OpenAI client");
878 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
879 eprintln!("Error importing OpenAI batches: {:?}", e);
880 std::process::exit(1);
881 }
882 }
883 }
884 println!(
885 "Successfully imported {} batch(es)",
886 import_args.batch_ids.len()
887 );
888 });
889 return;
890 }
891 Command::Clean => {
892 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
893 return;
894 }
895 Command::PrintZetaFormats => {
896 use strum::IntoEnumIterator as _;
897 for format in ZetaFormat::iter() {
898 println!("{}", format.to_string().to_lowercase());
899 }
900 return;
901 }
902
903 Command::Synthesize(synth_args) => {
904 let output_dir = if let Some(output_dir) = args.output {
905 output_dir
906 } else {
907 let default_output_dir = env::current_dir()
908 .unwrap()
909 .join("crates/edit_prediction_cli/evals-generated");
910 if default_output_dir.parent().unwrap().exists() {
911 std::fs::create_dir(&default_output_dir).ok();
912 default_output_dir
913 } else {
914 panic!("output dir is required");
915 }
916 };
917 let config = SynthesizeConfig {
918 repo_urls: synth_args.repos.clone(),
919 count: synth_args.count,
920 max_commits: synth_args.max_commits,
921 output_dir,
922 fresh: synth_args.fresh,
923 };
924 smol::block_on(async {
925 if let Err(e) = run_synthesize(config).await {
926 eprintln!("Error: {:?}", e);
927 std::process::exit(1);
928 }
929 });
930 return;
931 }
932 Command::SplitCommit(split_commit_args) => {
933 if let Err(error) = split_commit::run_split_commit(
934 split_commit_args,
935 &args.inputs,
936 output.as_ref(),
937 args.failed,
938 ) {
939 eprintln!("{error:#}");
940 std::process::exit(1);
941 }
942 return;
943 }
944 Command::TruncatePatch(truncate_args) => {
945 if let Err(error) =
946 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
947 {
948 eprintln!("{error:#}");
949 std::process::exit(1);
950 }
951 return;
952 }
953 Command::Split(split_args) => {
954 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
955 eprintln!("{error:#}");
956 std::process::exit(1);
957 }
958 return;
959 }
960 Command::FilterLanguages(filter_args) => {
961 if let Err(error) =
962 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
963 {
964 eprintln!("{error:#}");
965 std::process::exit(1);
966 }
967 return;
968 }
969
970 _ => {}
971 }
972
973 let http_client = Arc::new(ReqwestClient::new());
974 let app = gpui_platform::headless().with_http_client(http_client);
975
976 app.run(move |cx| {
977 let app_state = Arc::new(headless::init(cx));
978 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
979
980 cx.spawn(async move |cx| {
981 let result = async {
982 let examples = load_examples(
983 app_state.client.http_client(),
984 &args,
985 output.as_ref(),
986 cx.background_executor().clone(),
987 )
988 .await?;
989
990 match &command {
991 Command::Predict(args) | Command::Score(args) => {
992 predict::sync_batches(args.provider.as_ref()).await?;
993 }
994 Command::Eval(args) => {
995 predict::sync_batches(args.predict.provider.as_ref()).await?;
996 }
997 Command::Qa(args) => {
998 qa::sync_batches(args).await?;
999 }
1000 Command::Repair(args) => {
1001 repair::sync_batches(args).await?;
1002 }
1003 _ => (),
1004 }
1005
1006 let failfast_on_single_example = examples.len() == 1;
1007
1008 // For --markdown mode, create the output directory if it doesn't exist
1009 if args.markdown {
1010 let dir = output.as_ref().expect("--markdown requires -o");
1011 if !dir.exists() {
1012 std::fs::create_dir_all(dir)
1013 .expect("Failed to create markdown output directory");
1014 }
1015 }
1016
1017 // Set up JSONL output writer (not used in markdown mode)
1018 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1019 let mut in_place_temp_path: Option<PathBuf> = None;
1020 if !args.markdown
1021 && let Some(output_path) = output.as_ref()
1022 {
1023 let write_path = if args.in_place {
1024 let temp = output_path.with_extension("jsonl.tmp");
1025 in_place_temp_path = Some(temp.clone());
1026 temp
1027 } else {
1028 output_path.clone()
1029 };
1030
1031 let file = OpenOptions::new()
1032 .create(true)
1033 .write(true)
1034 .truncate(args.in_place)
1035 .append(!args.in_place)
1036 .open(&write_path)
1037 .expect("Failed to open output file");
1038
1039 let mut writer = BufWriter::new(file);
1040 let (sender, mut receiver) = mpsc::unbounded::<String>();
1041 cx.background_spawn(async move {
1042 while let Some(line) = receiver.next().await {
1043 writeln!(writer, "{}", line).expect("Failed to write example");
1044 writer.flush().expect("Failed to flush output");
1045 }
1046 })
1047 .detach();
1048 output_sender = Some(sender);
1049 }
1050
1051 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1052 let finished_examples = Mutex::new(Vec::new());
1053
1054 let mut tasks = Vec::new();
1055 for _ in 0..args.max_parallelism {
1056 tasks.push(async {
1057 loop {
1058 let Some(mut repo_examples) =
1059 grouped_examples.lock().unwrap().pop_front()
1060 else {
1061 break;
1062 };
1063 for example in &mut repo_examples {
1064 let example_progress =
1065 Progress::global().start_group(&example.spec.name);
1066
1067 let result = async {
1068 match &command {
1069 Command::Read(_) => {}
1070 Command::LoadProject => {
1071 run_load_project(
1072 example,
1073 app_state.clone(),
1074 &example_progress,
1075 cx.clone(),
1076 )
1077 .await?;
1078 }
1079 Command::Context => {
1080 run_context_retrieval(
1081 example,
1082 app_state.clone(),
1083 &example_progress,
1084 cx.clone(),
1085 )
1086 .await?;
1087 }
1088 Command::FormatPrompt(args) => {
1089 run_format_prompt(
1090 example,
1091 args,
1092 app_state.clone(),
1093 &example_progress,
1094 cx.clone(),
1095 )
1096 .await?;
1097 }
1098 Command::Predict(args) => {
1099 run_prediction(
1100 example,
1101 args,
1102 app_state.clone(),
1103 &example_progress,
1104 cx.clone(),
1105 )
1106 .await?;
1107 }
1108 Command::ParseOutput => {
1109 parse_output::run_parse_output(example)?;
1110 }
1111 Command::Distill => {
1112 run_distill(example).await?;
1113 }
1114 Command::Score(args) => {
1115 run_scoring(
1116 example,
1117 args,
1118 app_state.clone(),
1119 &example_progress,
1120 cx.clone(),
1121 )
1122 .await?;
1123 }
1124 Command::Eval(args) => {
1125 run_scoring(
1126 example,
1127 &args.predict,
1128 app_state.clone(),
1129 &example_progress,
1130 cx.clone(),
1131 )
1132 .await?;
1133 }
1134 Command::Qa(args) => {
1135 qa::run_qa(example, args, &example_progress).await?;
1136 }
1137 Command::Repair(args) => {
1138 repair::run_repair(example, args, &example_progress)
1139 .await?;
1140 }
1141 Command::Clean
1142 | Command::Synthesize(_)
1143 | Command::SplitCommit(_)
1144 | Command::Split(_)
1145 | Command::TruncatePatch(_)
1146 | Command::FilterLanguages(_)
1147 | Command::ImportBatch(_)
1148 | Command::PrintZetaFormats => {
1149 unreachable!()
1150 }
1151 }
1152 anyhow::Ok(())
1153 }
1154 .await;
1155
1156 let failed = if let Err(error) = result {
1157 handle_error(
1158 error,
1159 &args,
1160 &command,
1161 &app_state,
1162 failfast_on_single_example,
1163 &example,
1164 )
1165 .await;
1166 true
1167 } else {
1168 false
1169 };
1170
1171 let should_write = !failed || args.failed == FailedHandling::Keep;
1172 if should_write {
1173 if args.markdown {
1174 let markdown_dir =
1175 output.as_ref().expect("--markdown requires -o");
1176 let filename = format!("{}.md", example.spec.filename());
1177 let path = markdown_dir.join(&filename);
1178 let markdown = example.spec.to_markdown();
1179 std::fs::write(&path, &markdown)
1180 .expect("Failed to write markdown file");
1181 } else if let Some(ref mut sender) = output_sender.clone() {
1182 let line = serde_json::to_string(&example).unwrap();
1183 sender
1184 .send(line)
1185 .await
1186 .expect("Failed to send to output writer");
1187 } else if args.output.is_none()
1188 && !matches!(command, Command::Eval(_))
1189 {
1190 let line = serde_json::to_string(&example).unwrap();
1191 println!("{}", line);
1192 }
1193 }
1194 }
1195
1196 let project = repo_examples
1197 .iter()
1198 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1199
1200 if let Some(project) = project {
1201 let mut cx = cx.clone();
1202
1203 let shutdown_task: Task<()> =
1204 project.update(&mut cx, |project, cx| {
1205 let lsp_store = project.lsp_store();
1206 lsp_store.update(cx, |lsp_store, cx| {
1207 lsp_store.shutdown_all_language_servers(cx)
1208 })
1209 });
1210
1211 shutdown_task.await;
1212
1213 if let Some(ep_store) =
1214 cx.update(|cx| EditPredictionStore::try_global(cx))
1215 {
1216 ep_store.update(&mut cx, |store, _| {
1217 store.remove_project(&project);
1218 });
1219 }
1220 }
1221
1222 for example in &mut repo_examples {
1223 example.state.take();
1224 }
1225 finished_examples
1226 .lock()
1227 .unwrap()
1228 .extend_from_slice(&repo_examples);
1229 }
1230 });
1231 }
1232 futures::future::join_all(tasks).await;
1233
1234 Progress::global().finalize();
1235
1236 match &command {
1237 Command::Predict(args) | Command::Score(args) => {
1238 predict::sync_batches(args.provider.as_ref()).await?;
1239 }
1240 Command::Eval(args) => {
1241 predict::sync_batches(args.predict.provider.as_ref()).await?;
1242 }
1243 Command::Qa(args) => {
1244 qa::sync_batches(args).await?;
1245 }
1246 Command::Repair(args) => {
1247 repair::sync_batches(args).await?;
1248 }
1249 _ => (),
1250 }
1251
1252 match &command {
1253 Command::Eval(args) => {
1254 let examples = finished_examples.lock().unwrap();
1255 score::print_report(&examples, args.verbose);
1256 if let Some(summary_path) = &args.summary_json {
1257 score::write_summary_json(&examples, summary_path)?;
1258 }
1259 }
1260 Command::Repair(args) => {
1261 let examples = finished_examples.lock().unwrap();
1262 repair::print_report(&examples, args.confidence_threshold);
1263 }
1264 _ => (),
1265 };
1266
1267 // For --in-place, atomically rename temp file to original
1268 if let Some(temp_path) = &in_place_temp_path {
1269 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1270 std::fs::rename(temp_path, final_path)
1271 .expect("Failed to rename temp file to final output");
1272 }
1273
1274 anyhow::Ok(())
1275 }
1276 .await;
1277
1278 if let Err(e) = result {
1279 panic!("Fatal error: {:?}", e);
1280 }
1281
1282 let _ = cx.update(|cx| cx.quit());
1283 })
1284 .detach();
1285 });
1286}
1287
1288async fn handle_error(
1289 error: anyhow::Error,
1290 args: &EpArgs,
1291 command: &Command,
1292 app_state: &Arc<headless::EpAppState>,
1293 failfast_on_single_example: bool,
1294 example: &Example,
1295) {
1296 Progress::global().increment_failed();
1297
1298 let msg;
1299 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1300 let example_name = example.spec.filename();
1301
1302 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1303 app_state
1304 .fs
1305 .write(
1306 &failed_example_path,
1307 &serde_json::to_vec_pretty(&example).unwrap(),
1308 )
1309 .await
1310 .unwrap();
1311 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1312 app_state
1313 .fs
1314 .write(&err_path, format!("{error:?}").as_bytes())
1315 .await
1316 .unwrap();
1317
1318 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1319 let mut file = OpenOptions::new()
1320 .create(true)
1321 .append(true)
1322 .open(&failed_jsonl_path)
1323 .expect("Failed to open failed.jsonl");
1324 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1325 .expect("Failed to write to failed.jsonl");
1326
1327 let cursor_path = match example.repo_name() {
1328 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1329 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1330 };
1331 msg = format!(
1332 indoc::indoc! {"
1333 While processing \"{}\":
1334
1335 \x1b[31m{:?}\x1b[0m
1336
1337 Example: \x1b[36m{}\x1b[0m
1338 Error file: \x1b[36m{}\x1b[0m
1339 Cursor file: \x1b[36m{}\x1b[0m
1340 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1341 "},
1342 example.spec.name,
1343 error,
1344 failed_example_path.display(),
1345 err_path.display(),
1346 cursor_path.display(),
1347 command,
1348 failed_example_path.display(),
1349 );
1350 } else {
1351 msg = format!(
1352 indoc::indoc! {"
1353 While processing \"{}\":
1354
1355 \x1b[31m{:?}\x1b[0m
1356 "},
1357 example.spec.name, error
1358 );
1359 }
1360
1361 if args.failfast || failfast_on_single_example {
1362 Progress::global().finalize();
1363 panic!("{}", msg);
1364 } else {
1365 log::error!("{}", msg);
1366 }
1367}