main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod prompt_assets;
  16mod pull_examples;
  17mod qa;
  18mod reorder_patch;
  19mod repair;
  20mod retrieve_context;
  21mod reversal_tracking;
  22mod score;
  23mod split_commit;
  24mod split_dataset;
  25mod sync_deployments;
  26mod synthesize;
  27mod truncate_expected_patch;
  28mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaFormat;

use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
  45
  46use crate::distill::run_distill;
  47use crate::example::{Example, group_examples_by_repo, read_example_files};
  48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  49use crate::format_prompt::run_format_prompt;
  50use crate::load_project::run_load_project;
  51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  52use crate::predict::run_prediction;
  53use crate::progress::Progress;
  54use crate::retrieve_context::run_context_retrieval;
  55use crate::score::run_scoring;
  56use crate::split_commit::SplitCommitArgs;
  57use crate::split_dataset::SplitArgs;
  58use crate::synthesize::{SynthesizeConfig, run_synthesize};
  59use crate::truncate_expected_patch::TruncatePatchArgs;
  60
// Top-level CLI arguments for `ep`.
//
// NOTE: reviewer comments below are plain `//` rather than `///` on purpose —
// clap turns field doc comments into CLI help text, and we must not change the
// program's help output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (see the early-return in
    // `main`, which calls `::util::shell_env::print_env()`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Maximum number of examples processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip. Applied to file inputs first; any
    // remainder is forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special Snowflake specifiers — see
    // INPUTS_HELP for the accepted forms.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; with --markdown this names a directory instead of a file.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place. Requires exactly one input;
    // enforced at runtime in `output_path` (panics), not by clap.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop at the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
  99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `///` comments on the variants double as clap `ValueEnum` help text, so
// reviewer notes here use `//` to leave CLI help output unchanged.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably suppresses the per-example failure files as
    // well as main-output inclusion — confirm against the consumer of this
    // variant (not visible in this chunk).
    SkipNoFiles,
}
 112
/// Long-form help text for the positional `inputs` argument, wired into clap
/// via `#[clap(help = INPUTS_HELP)]` on `EpArgs::inputs`.
// NOTE(review): `load_examples` also recognizes a `requested-after:{timestamp}`
// specifier (via `pull_examples::parse_requested_after_input`) that is not
// documented in this help text — consider adding it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 170
// All `ep` subcommands. The `///` doc comments double as clap help text, so
// they are left untouched; reviewer notes use `//` to keep CLI help unchanged.
// `Display` (below) renders each variant as the CLI invocation string for logs.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Shares PredictArgs with Predict so scoring can target one provider's output.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
 216
 217impl Display for Command {
 218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 219        match self {
 220            Command::Read => write!(f, "read"),
 221            Command::LoadProject => write!(f, "load-project"),
 222            Command::Context => write!(f, "context"),
 223            Command::FormatPrompt(args) => {
 224                write!(f, "format-prompt --provider={}", args.provider)
 225            }
 226            Command::Predict(args) => match &args.provider {
 227                Some(provider) => write!(f, "predict --provider={}", provider),
 228                None => write!(f, "predict"),
 229            },
 230            Command::ParseOutput => write!(f, "parse-output"),
 231            Command::Score(args) => match &args.provider {
 232                Some(provider) => write!(f, "score --provider={}", provider),
 233                None => write!(f, "score"),
 234            },
 235            Command::Distill => write!(f, "distill"),
 236            Command::Eval(args) => match &args.predict.provider {
 237                Some(provider) => write!(f, "eval --provider={}", provider),
 238                None => write!(f, "eval"),
 239            },
 240            Command::Synthesize(args) => {
 241                write!(f, "synthesize --repos {}", args.repos.join(" "))
 242            }
 243            Command::Clean => write!(f, "clean"),
 244            Command::SplitCommit(_) => write!(f, "split-commit"),
 245            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 246            Command::Split(_) => write!(f, "split"),
 247            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 248            Command::ImportBatch(args) => {
 249                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 250            }
 251            Command::Qa(_) => {
 252                write!(f, "qa")
 253            }
 254            Command::Repair(_) => {
 255                write!(f, "repair")
 256            }
 257            Command::PrintZetaFormats => {
 258                write!(f, "print-zeta-formats")
 259            }
 260            Command::SyncDeployments(_) => {
 261                write!(f, "sync-deployments")
 262            }
 263        }
 264    }
 265}
 266
// Arguments for the `sync-deployments` subcommand (`//` so clap help text is
// unchanged; the `///` below is the user-visible help for --model).
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    #[arg(long)]
    model: Option<String>,
}
 273
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to
    // `PredictionProvider::default()` (zeta2 with the default ZetaFormat).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 279
// Shared arguments for `predict` and `score` (and, flattened, for `eval`).
// Reviewer notes use `//` so the clap help output is unchanged.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run; `None` defers to the downstream default (see how
    // `Display for Command` renders the bare form when this is unset).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to produce per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
 290
// Arguments for the `eval` subcommand: the prediction options plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All of PredictArgs, exposed directly on `eval` via clap flattening.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 299
/// LLM backend used for "teacher" predictions (see `PredictionProvider::Teacher`
/// and `TeacherNonBatching`). Parsed from strings via `FromStr` below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude — maps to model name "claude-sonnet-4-5".
    Sonnet45,
    /// OpenAI — maps to model name "gpt-5.2".
    Gpt52,
}
 305
 306impl std::fmt::Display for TeacherBackend {
 307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 308        match self {
 309            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 310            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 311        }
 312    }
 313}
 314
 315impl std::str::FromStr for TeacherBackend {
 316    type Err = anyhow::Error;
 317
 318    fn from_str(s: &str) -> Result<Self, Self::Err> {
 319        match s.to_lowercase().as_str() {
 320            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 321            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 322            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 323            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
 324        }
 325    }
 326}
 327
 328impl TeacherBackend {
 329    pub fn model_name(&self) -> &'static str {
 330        match self {
 331            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 332            TeacherBackend::Gpt52 => "gpt-5.2",
 333        }
 334    }
 335}
 336
/// Which backend produces edit predictions.
///
/// Parsed from strings of the form `name` or `name:arg` (see the `FromStr`
/// impl below) and serialized back to the same form via `Display`; the serde
/// impls round-trip through that string representation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format (defaults to `ZetaFormat::default()`).
    Zeta2(ZetaFormat),
    /// Teacher-model predictions — presumably issued through the provider's
    /// batch API, given the non-batching variant below; confirm in `predict`.
    Teacher(TeacherBackend),
    /// Teacher-model predictions issued without batching.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the repair flow (see `resume_from_output`'s
    /// completion check for `Command::Repair`).
    Repair,
}
 347
 348impl Default for PredictionProvider {
 349    fn default() -> Self {
 350        PredictionProvider::Zeta2(ZetaFormat::default())
 351    }
 352}
 353
 354impl std::fmt::Display for PredictionProvider {
 355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 356        match self {
 357            PredictionProvider::Sweep => write!(f, "sweep"),
 358            PredictionProvider::Mercury => write!(f, "mercury"),
 359            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 360            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 361            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 362            PredictionProvider::TeacherNonBatching(backend) => {
 363                write!(f, "teacher-non-batching:{backend}")
 364            }
 365            PredictionProvider::Repair => write!(f, "repair"),
 366        }
 367    }
 368}
 369
 370impl std::str::FromStr for PredictionProvider {
 371    type Err = anyhow::Error;
 372
 373    fn from_str(s: &str) -> Result<Self, Self::Err> {
 374        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 375
 376        let provider_lower = provider.to_lowercase();
 377        match provider_lower.as_str() {
 378            "sweep" => Ok(PredictionProvider::Sweep),
 379            "mercury" => Ok(PredictionProvider::Mercury),
 380            "zeta1" => Ok(PredictionProvider::Zeta1),
 381            "zeta2" => {
 382                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 383                Ok(PredictionProvider::Zeta2(format))
 384            }
 385            "teacher" => {
 386                let backend = arg
 387                    .map(|a| a.parse())
 388                    .transpose()?
 389                    .unwrap_or(TeacherBackend::Sonnet45);
 390                Ok(PredictionProvider::Teacher(backend))
 391            }
 392            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 393                let backend = arg
 394                    .map(|a| a.parse())
 395                    .transpose()?
 396                    .unwrap_or(TeacherBackend::Sonnet45);
 397                Ok(PredictionProvider::TeacherNonBatching(backend))
 398            }
 399            "repair" => Ok(PredictionProvider::Repair),
 400            _ => {
 401                anyhow::bail!(
 402                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 403                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 404                 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
 405                 Available zeta versions:\n{}",
 406                    ZetaFormat::options_as_string()
 407                )
 408            }
 409        }
 410    }
 411}
 412
 413impl Serialize for PredictionProvider {
 414    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 415    where
 416        S: Serializer,
 417    {
 418        serializer.serialize_str(&self.to_string())
 419    }
 420}
 421
 422impl<'de> Deserialize<'de> for PredictionProvider {
 423    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 424    where
 425        D: Deserializer<'de>,
 426    {
 427        let s = String::deserialize(deserializer)?;
 428        s.parse().map_err(serde::de::Error::custom)
 429    }
 430}
 431
// Arguments for the `synthesize` subcommand (the `///` comments are clap help
// text and are left unchanged).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 450
// Arguments for the `import-batch` subcommand, handled entirely in `main`
// before the gpui app starts.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 460
// LLM provider whose batch results `import-batch` should pull. A clap
// `ValueEnum`, so `//` comments are used here to avoid adding help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    // Imports via `anthropic_client::AnthropicClient::batch`.
    Anthropic,
    // Imports via `openai_client::OpenAiClient::batch`.
    Openai,
}
 466
 467impl EpArgs {
 468    fn output_path(&self) -> Option<PathBuf> {
 469        if self.in_place {
 470            if self.inputs.len() == 1 {
 471                self.inputs.first().cloned()
 472            } else {
 473                panic!("--in-place requires exactly one input file")
 474            }
 475        } else {
 476            self.output.clone()
 477        }
 478    }
 479}
 480
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): only minor/patch are expressed here — presumably the major
// component is fixed by `pull_examples::MinCaptureVersion`; confirm there.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
 488
/// Gathers all examples to process: reads local file inputs, fetches any
/// Snowflake-backed inputs (`captured-after:`, `rejected-after:`,
/// `requested-after:`, `rated-after:` and variants), then applies the global
/// `--offset`/`--limit`/`--name`/`--repo` options and resume-from-output logic.
///
/// Returns the filtered, sorted list of examples still needing processing.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each input by which specifier (if any) it matches; anything that
    // is not a recognized specifier is treated as a file path.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            // Offset swallows all file examples; the remainder applies to
            // the Snowflake fetches below.
            examples.clear();
            offset - file_example_count
        } else {
            // Offset consumed entirely by file examples.
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total (file examples only) so progress can render while the
    // Snowflake fetches below are in flight; updated again at the end.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `remaining_offset` is handed to each source below
        // independently, so with multiple Snowflake sources the offset is
        // applied per source rather than once across the union — confirm that
        // is intended.
        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: consumes `http_client` and `background_executor`
            // by value (no further clones needed).
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group work by repo/revision so downstream worktree setup is batched.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo options.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    // Applied after resume so the limit counts examples still to be processed.
    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals now that all sources, filters, and resume have run.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
 641
 642fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 643    let mut hasher = collections::FxHasher::default();
 644    spec.hash(&mut hasher);
 645    hasher.finish()
 646}
 647
 648fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 649    let file = match File::open(path) {
 650        Ok(f) => f,
 651        Err(_) => return,
 652    };
 653
 654    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 655
 656    let reader = BufReader::new(file);
 657    let mut kept_lines = Vec::new();
 658    let mut kept_hashes = HashSet::default();
 659
 660    for line in reader.lines() {
 661        let line = match line {
 662            Ok(l) => l,
 663            Err(_) => continue,
 664        };
 665
 666        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 667            let hash = spec_hash(&output_example.spec);
 668            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 669                let is_complete = match command {
 670                    Command::Qa(_) => output_example
 671                        .qa
 672                        .first()
 673                        .and_then(|q| q.as_ref())
 674                        .and_then(|q| q.confidence)
 675                        .is_some(),
 676                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 677                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 678                    }),
 679                    _ => true,
 680                };
 681                if is_complete {
 682                    kept_hashes.insert(hash);
 683                    kept_lines.push(line);
 684                }
 685            }
 686        }
 687    }
 688
 689    let total = examples.len();
 690    let already_processed = kept_hashes.len();
 691
 692    eprintln!(
 693        "Resuming: {}/{} examples already processed",
 694        already_processed, total
 695    );
 696
 697    let file = OpenOptions::new()
 698        .write(true)
 699        .truncate(true)
 700        .open(path)
 701        .expect("Failed to open output file for rewriting");
 702    let mut writer = BufWriter::new(file);
 703    for line in &kept_lines {
 704        writeln!(writer, "{}", line).expect("Failed to write to output file");
 705    }
 706    writer.flush().expect("Failed to flush output file");
 707
 708    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 709}
 710
 711fn main() {
 712    let args = EpArgs::parse();
 713
 714    if args.printenv {
 715        ::util::shell_env::print_env();
 716        return;
 717    }
 718
 719    let output = args.output_path();
 720
 721    if args.markdown && output.is_none() {
 722        eprintln!("--markdown requires -o to specify the output directory");
 723        std::process::exit(1);
 724    }
 725
 726    let command = match &args.command {
 727        Some(cmd) => cmd.clone(),
 728        None => {
 729            EpArgs::command().print_help().unwrap();
 730            return;
 731        }
 732    };
 733
 734    match &command {
 735        Command::ImportBatch(import_args) => {
 736            smol::block_on(async {
 737                match import_args.provider {
 738                    BatchProvider::Anthropic => {
 739                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 740                            .expect("Failed to create Anthropic client");
 741                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 742                            eprintln!("Error importing Anthropic batches: {:?}", e);
 743                            std::process::exit(1);
 744                        }
 745                    }
 746                    BatchProvider::Openai => {
 747                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 748                            .expect("Failed to create OpenAI client");
 749                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 750                            eprintln!("Error importing OpenAI batches: {:?}", e);
 751                            std::process::exit(1);
 752                        }
 753                    }
 754                }
 755                println!(
 756                    "Successfully imported {} batch(es)",
 757                    import_args.batch_ids.len()
 758                );
 759            });
 760            return;
 761        }
 762        Command::Clean => {
 763            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 764            return;
 765        }
 766        Command::PrintZetaFormats => {
 767            use strum::IntoEnumIterator as _;
 768            for format in ZetaFormat::iter() {
 769                println!("{}", format.to_string().to_lowercase());
 770            }
 771            return;
 772        }
 773        Command::SyncDeployments(sync_args) => {
 774            let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
 775            smol::block_on(async {
 776                if let Err(e) =
 777                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
 778                        .await
 779                {
 780                    eprintln!("Error: {:?}", e);
 781                    std::process::exit(1);
 782                }
 783            });
 784            return;
 785        }
 786        Command::Synthesize(synth_args) => {
 787            let Some(output_dir) = args.output else {
 788                panic!("output dir is required");
 789            };
 790            let config = SynthesizeConfig {
 791                repo_urls: synth_args.repos.clone(),
 792                count: synth_args.count,
 793                max_commits: synth_args.max_commits,
 794                output_dir,
 795                fresh: synth_args.fresh,
 796            };
 797            smol::block_on(async {
 798                if let Err(e) = run_synthesize(config).await {
 799                    eprintln!("Error: {:?}", e);
 800                    std::process::exit(1);
 801                }
 802            });
 803            return;
 804        }
 805        Command::SplitCommit(split_commit_args) => {
 806            if let Err(error) = split_commit::run_split_commit(
 807                split_commit_args,
 808                &args.inputs,
 809                output.as_ref(),
 810                args.failed,
 811            ) {
 812                eprintln!("{error:#}");
 813                std::process::exit(1);
 814            }
 815            return;
 816        }
 817        Command::TruncatePatch(truncate_args) => {
 818            if let Err(error) =
 819                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
 820            {
 821                eprintln!("{error:#}");
 822                std::process::exit(1);
 823            }
 824            return;
 825        }
 826        Command::Split(split_args) => {
 827            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
 828                eprintln!("{error:#}");
 829                std::process::exit(1);
 830            }
 831            return;
 832        }
 833        Command::FilterLanguages(filter_args) => {
 834            if let Err(error) =
 835                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
 836            {
 837                eprintln!("{error:#}");
 838                std::process::exit(1);
 839            }
 840            return;
 841        }
 842
 843        _ => {}
 844    }
 845
 846    let http_client = Arc::new(ReqwestClient::new());
 847    let app = Application::headless().with_http_client(http_client);
 848
 849    app.run(move |cx| {
 850        let app_state = Arc::new(headless::init(cx));
 851        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 852
 853        cx.spawn(async move |cx| {
 854            let result = async {
 855                let examples = load_examples(
 856                    app_state.client.http_client(),
 857                    &args,
 858                    output.as_ref(),
 859                    cx.background_executor().clone(),
 860                )
 861                .await?;
 862
 863                match &command {
 864                    Command::Predict(args) | Command::Score(args) => {
 865                        predict::sync_batches(args.provider.as_ref()).await?;
 866                    }
 867                    Command::Eval(args) => {
 868                        predict::sync_batches(args.predict.provider.as_ref()).await?;
 869                    }
 870                    Command::Qa(args) => {
 871                        qa::sync_batches(args).await?;
 872                    }
 873                    Command::Repair(args) => {
 874                        repair::sync_batches(args).await?;
 875                    }
 876                    _ => (),
 877                }
 878
 879                let failfast_on_single_example = examples.len() == 1;
 880
 881                // For --markdown mode, create the output directory if it doesn't exist
 882                if args.markdown {
 883                    let dir = output.as_ref().expect("--markdown requires -o");
 884                    if !dir.exists() {
 885                        std::fs::create_dir_all(dir)
 886                            .expect("Failed to create markdown output directory");
 887                    }
 888                }
 889
 890                // Set up JSONL output writer (not used in markdown mode)
 891                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
 892                let mut in_place_temp_path: Option<PathBuf> = None;
 893                if !args.markdown
 894                    && let Some(output_path) = output.as_ref()
 895                {
 896                    let write_path = if args.in_place {
 897                        let temp = output_path.with_extension("jsonl.tmp");
 898                        in_place_temp_path = Some(temp.clone());
 899                        temp
 900                    } else {
 901                        output_path.clone()
 902                    };
 903
 904                    let file = OpenOptions::new()
 905                        .create(true)
 906                        .write(true)
 907                        .truncate(args.in_place)
 908                        .append(!args.in_place)
 909                        .open(&write_path)
 910                        .expect("Failed to open output file");
 911
 912                    let mut writer = BufWriter::new(file);
 913                    let (sender, mut receiver) = mpsc::unbounded::<String>();
 914                    cx.background_spawn(async move {
 915                        while let Some(line) = receiver.next().await {
 916                            writeln!(writer, "{}", line).expect("Failed to write example");
 917                            writer.flush().expect("Failed to flush output");
 918                        }
 919                    })
 920                    .detach();
 921                    output_sender = Some(sender);
 922                }
 923
 924                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
 925                let finished_examples = Mutex::new(Vec::new());
 926
 927                let mut tasks = Vec::new();
 928                for _ in 0..args.max_parallelism {
 929                    tasks.push(async {
 930                        loop {
 931                            let Some(mut repo_examples) =
 932                                grouped_examples.lock().unwrap().pop_front()
 933                            else {
 934                                break;
 935                            };
 936                            for example in &mut repo_examples {
 937                                let example_progress =
 938                                    Progress::global().start_group(&example.spec.name);
 939
 940                                let result = async {
 941                                    match &command {
 942                                        Command::Read => {}
 943                                        Command::LoadProject => {
 944                                            run_load_project(
 945                                                example,
 946                                                app_state.clone(),
 947                                                &example_progress,
 948                                                cx.clone(),
 949                                            )
 950                                            .await?;
 951                                        }
 952                                        Command::Context => {
 953                                            run_context_retrieval(
 954                                                example,
 955                                                app_state.clone(),
 956                                                &example_progress,
 957                                                cx.clone(),
 958                                            )
 959                                            .await?;
 960                                        }
 961                                        Command::FormatPrompt(args) => {
 962                                            run_format_prompt(
 963                                                example,
 964                                                args,
 965                                                app_state.clone(),
 966                                                &example_progress,
 967                                                cx.clone(),
 968                                            )
 969                                            .await?;
 970                                        }
 971                                        Command::Predict(args) => {
 972                                            run_prediction(
 973                                                example,
 974                                                args,
 975                                                app_state.clone(),
 976                                                &example_progress,
 977                                                cx.clone(),
 978                                            )
 979                                            .await?;
 980                                        }
 981                                        Command::ParseOutput => {
 982                                            parse_output::run_parse_output(example)?;
 983                                        }
 984                                        Command::Distill => {
 985                                            run_distill(example).await?;
 986                                        }
 987                                        Command::Score(args) => {
 988                                            run_scoring(
 989                                                example,
 990                                                args,
 991                                                app_state.clone(),
 992                                                &example_progress,
 993                                                cx.clone(),
 994                                            )
 995                                            .await?;
 996                                        }
 997                                        Command::Eval(args) => {
 998                                            run_scoring(
 999                                                example,
1000                                                &args.predict,
1001                                                app_state.clone(),
1002                                                &example_progress,
1003                                                cx.clone(),
1004                                            )
1005                                            .await?;
1006                                        }
1007                                        Command::Qa(args) => {
1008                                            qa::run_qa(example, args, &example_progress).await?;
1009                                        }
1010                                        Command::Repair(args) => {
1011                                            repair::run_repair(example, args, &example_progress)
1012                                                .await?;
1013                                        }
1014                                        Command::Clean
1015                                        | Command::Synthesize(_)
1016                                        | Command::SplitCommit(_)
1017                                        | Command::Split(_)
1018                                        | Command::TruncatePatch(_)
1019                                        | Command::FilterLanguages(_)
1020                                        | Command::ImportBatch(_)
1021                                        | Command::PrintZetaFormats
1022                                        | Command::SyncDeployments(_) => {
1023                                            unreachable!()
1024                                        }
1025                                    }
1026                                    anyhow::Ok(())
1027                                }
1028                                .await;
1029
1030                                let failed = if let Err(error) = result {
1031                                    handle_error(
1032                                        error,
1033                                        &args,
1034                                        &command,
1035                                        &app_state,
1036                                        failfast_on_single_example,
1037                                        &example,
1038                                    )
1039                                    .await;
1040                                    true
1041                                } else {
1042                                    false
1043                                };
1044
1045                                let should_write = !failed || args.failed == FailedHandling::Keep;
1046                                if should_write {
1047                                    if args.markdown {
1048                                        let markdown_dir =
1049                                            output.as_ref().expect("--markdown requires -o");
1050                                        let filename = format!("{}.md", example.spec.filename());
1051                                        let path = markdown_dir.join(&filename);
1052                                        let markdown = example.spec.to_markdown();
1053                                        std::fs::write(&path, &markdown)
1054                                            .expect("Failed to write markdown file");
1055                                    } else if let Some(ref mut sender) = output_sender.clone() {
1056                                        let line = serde_json::to_string(&example).unwrap();
1057                                        sender
1058                                            .send(line)
1059                                            .await
1060                                            .expect("Failed to send to output writer");
1061                                    } else if args.output.is_none()
1062                                        && !matches!(command, Command::Eval(_))
1063                                    {
1064                                        let line = serde_json::to_string(&example).unwrap();
1065                                        println!("{}", line);
1066                                    }
1067                                }
1068                            }
1069
1070                            let project = repo_examples
1071                                .iter()
1072                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1073
1074                            if let Some(project) = project {
1075                                let mut cx = cx.clone();
1076
1077                                let shutdown_task: Task<()> =
1078                                    project.update(&mut cx, |project, cx| {
1079                                        let lsp_store = project.lsp_store();
1080                                        lsp_store.update(cx, |lsp_store, cx| {
1081                                            lsp_store.shutdown_all_language_servers(cx)
1082                                        })
1083                                    });
1084
1085                                shutdown_task.await;
1086
1087                                if let Some(ep_store) =
1088                                    cx.update(|cx| EditPredictionStore::try_global(cx))
1089                                {
1090                                    ep_store.update(&mut cx, |store, _| {
1091                                        store.remove_project(&project);
1092                                    });
1093                                }
1094                            }
1095
1096                            for example in &mut repo_examples {
1097                                example.state.take();
1098                            }
1099                            finished_examples
1100                                .lock()
1101                                .unwrap()
1102                                .extend_from_slice(&repo_examples);
1103                        }
1104                    });
1105                }
1106                futures::future::join_all(tasks).await;
1107
1108                Progress::global().finalize();
1109
1110                match &command {
1111                    Command::Predict(args) | Command::Score(args) => {
1112                        predict::sync_batches(args.provider.as_ref()).await?;
1113                    }
1114                    Command::Eval(args) => {
1115                        predict::sync_batches(args.predict.provider.as_ref()).await?;
1116                    }
1117                    Command::Qa(args) => {
1118                        qa::sync_batches(args).await?;
1119                    }
1120                    Command::Repair(args) => {
1121                        repair::sync_batches(args).await?;
1122                    }
1123                    _ => (),
1124                }
1125
1126                match &command {
1127                    Command::Eval(args) => {
1128                        let examples = finished_examples.lock().unwrap();
1129                        score::print_report(&examples);
1130                        if let Some(summary_path) = &args.summary_json {
1131                            score::write_summary_json(&examples, summary_path)?;
1132                        }
1133                    }
1134                    Command::Repair(args) => {
1135                        let examples = finished_examples.lock().unwrap();
1136                        repair::print_report(&examples, args.confidence_threshold);
1137                    }
1138                    _ => (),
1139                };
1140
1141                // For --in-place, atomically rename temp file to original
1142                if let Some(temp_path) = &in_place_temp_path {
1143                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
1144                    std::fs::rename(temp_path, final_path)
1145                        .expect("Failed to rename temp file to final output");
1146                }
1147
1148                anyhow::Ok(())
1149            }
1150            .await;
1151
1152            if let Err(e) = result {
1153                panic!("Fatal error: {:?}", e);
1154            }
1155
1156            let _ = cx.update(|cx| cx.quit());
1157        })
1158        .detach();
1159    });
1160}
1161
1162async fn handle_error(
1163    error: anyhow::Error,
1164    args: &EpArgs,
1165    command: &Command,
1166    app_state: &Arc<headless::EpAppState>,
1167    failfast_on_single_example: bool,
1168    example: &Example,
1169) {
1170    Progress::global().increment_failed();
1171
1172    let msg;
1173    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1174        let example_name = example.spec.filename();
1175
1176        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1177        app_state
1178            .fs
1179            .write(
1180                &failed_example_path,
1181                &serde_json::to_vec_pretty(&example).unwrap(),
1182            )
1183            .await
1184            .unwrap();
1185        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1186        app_state
1187            .fs
1188            .write(&err_path, format!("{error:?}").as_bytes())
1189            .await
1190            .unwrap();
1191
1192        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1193        let mut file = OpenOptions::new()
1194            .create(true)
1195            .append(true)
1196            .open(&failed_jsonl_path)
1197            .expect("Failed to open failed.jsonl");
1198        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1199            .expect("Failed to write to failed.jsonl");
1200
1201        let cursor_path = match example.repo_name() {
1202            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1203            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1204        };
1205        msg = format!(
1206            indoc::indoc! {"
1207                While processing \"{}\":
1208
1209                \x1b[31m{:?}\x1b[0m
1210
1211                Example:        \x1b[36m{}\x1b[0m
1212                Error file:     \x1b[36m{}\x1b[0m
1213                Cursor file:    \x1b[36m{}\x1b[0m
1214                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1215            "},
1216            example.spec.name,
1217            error,
1218            failed_example_path.display(),
1219            err_path.display(),
1220            cursor_path.display(),
1221            command,
1222            failed_example_path.display(),
1223        );
1224    } else {
1225        msg = format!(
1226            indoc::indoc! {"
1227            While processing \"{}\":
1228
1229                \x1b[31m{:?}\x1b[0m
1230            "},
1231            example.spec.name, error
1232        );
1233    }
1234
1235    if args.failfast || failfast_on_single_example {
1236        Progress::global().finalize();
1237        panic!("{}", msg);
1238    } else {
1239        log::error!("{}", msg);
1240    }
1241}