1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod sync_deployments;
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaFormat;
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI arguments for the `ep` tool. Most flags are `global = true`
// so they may appear before or after the subcommand.
//
// NOTE: new comments here are deliberately `//` rather than `///` — clap
// turns doc comments into `--help` text, and this block must not change the
// CLI's user-visible help output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repository groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the loaded dataset.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; `output_path()` then returns the
    // input path and ignores -o.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): consumed by error handling outside this view (handle_error
    // takes the full args) — presumably aborts on first failure; confirm there.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
// Long-form help text attached to the positional `inputs` argument of EpArgs.
// Documents both plain file inputs and the Snowflake "…-after:{timestamp}"
// specifiers handled in `load_examples` / `pull_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// All `ep` subcommands. Variants carrying an Args struct take extra flags;
// the rest rely solely on the global EpArgs flags. (`//` is used for this
// header so clap's generated help is unchanged.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
216
217impl Display for Command {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Command::Read => write!(f, "read"),
221 Command::LoadProject => write!(f, "load-project"),
222 Command::Context => write!(f, "context"),
223 Command::FormatPrompt(args) => {
224 write!(f, "format-prompt --provider={}", args.provider)
225 }
226 Command::Predict(args) => match &args.provider {
227 Some(provider) => write!(f, "predict --provider={}", provider),
228 None => write!(f, "predict"),
229 },
230 Command::ParseOutput => write!(f, "parse-output"),
231 Command::Score(args) => match &args.provider {
232 Some(provider) => write!(f, "score --provider={}", provider),
233 None => write!(f, "score"),
234 },
235 Command::Distill => write!(f, "distill"),
236 Command::Eval(args) => match &args.predict.provider {
237 Some(provider) => write!(f, "eval --provider={}", provider),
238 None => write!(f, "eval"),
239 },
240 Command::Synthesize(args) => {
241 write!(f, "synthesize --repos {}", args.repos.join(" "))
242 }
243 Command::Clean => write!(f, "clean"),
244 Command::SplitCommit(_) => write!(f, "split-commit"),
245 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
246 Command::Split(_) => write!(f, "split"),
247 Command::FilterLanguages(_) => write!(f, "filter-languages"),
248 Command::ImportBatch(args) => {
249 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
250 }
251 Command::Qa(_) => {
252 write!(f, "qa")
253 }
254 Command::Repair(_) => {
255 write!(f, "repair")
256 }
257 Command::PrintZetaFormats => {
258 write!(f, "print-zeta-formats")
259 }
260 Command::SyncDeployments(_) => {
261 write!(f, "sync-deployments")
262 }
263 }
264 }
265}
266
// Flags for the `sync-deployments` subcommand.
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    #[arg(long)]
    model: Option<String>,
}
273
// Flags for the `format-prompt` subcommand; the provider picks which model's
// prompt format is produced.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
// Flags shared by the `predict` and `score` subcommands (and embedded in
// `eval` via EvalArgs).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction backend; None lets downstream code pick (see Display impl,
    // which omits --provider when unset).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
// Flags for the `eval` subcommand: prediction flags plus an optional JSON
// summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
299
/// LLM backends usable as the "teacher" prediction provider
/// (see `PredictionProvider::Teacher` / `TeacherNonBatching`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 — model id given by `model_name`.
    Sonnet45,
    /// OpenAI GPT-5.2 — model id given by `model_name`.
    Gpt52,
}
305
306impl std::fmt::Display for TeacherBackend {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 match self {
309 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
310 TeacherBackend::Gpt52 => write!(f, "gpt52"),
311 }
312 }
313}
314
315impl std::str::FromStr for TeacherBackend {
316 type Err = anyhow::Error;
317
318 fn from_str(s: &str) -> Result<Self, Self::Err> {
319 match s.to_lowercase().as_str() {
320 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
321 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
322 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
323 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
324 }
325 }
326}
327
328impl TeacherBackend {
329 pub fn model_name(&self) -> &'static str {
330 match self {
331 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
332 TeacherBackend::Gpt52 => "gpt-5.2",
333 }
334 }
335}
336
/// The prediction backend an example's prompt/prediction/score is tied to.
/// Parsed from strings like `zeta2:ordered` or `teacher:gpt52` (see `FromStr`)
/// and serialized back to the same form (see `Display`/`Serialize`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta 2 with an explicit prompt format.
    Zeta2(ZetaFormat),
    /// Large "teacher" model, using the providers' batch APIs.
    Teacher(TeacherBackend),
    /// Teacher model via direct (non-batch) requests.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the `repair` subcommand.
    Repair,
}
347
impl Default for PredictionProvider {
    /// Defaults to Zeta 2 with the default prompt format.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
353
354impl std::fmt::Display for PredictionProvider {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 match self {
357 PredictionProvider::Sweep => write!(f, "sweep"),
358 PredictionProvider::Mercury => write!(f, "mercury"),
359 PredictionProvider::Zeta1 => write!(f, "zeta1"),
360 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
361 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
362 PredictionProvider::TeacherNonBatching(backend) => {
363 write!(f, "teacher-non-batching:{backend}")
364 }
365 PredictionProvider::Repair => write!(f, "repair"),
366 }
367 }
368}
369
370impl std::str::FromStr for PredictionProvider {
371 type Err = anyhow::Error;
372
373 fn from_str(s: &str) -> Result<Self, Self::Err> {
374 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
375
376 let provider_lower = provider.to_lowercase();
377 match provider_lower.as_str() {
378 "sweep" => Ok(PredictionProvider::Sweep),
379 "mercury" => Ok(PredictionProvider::Mercury),
380 "zeta1" => Ok(PredictionProvider::Zeta1),
381 "zeta2" => {
382 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
383 Ok(PredictionProvider::Zeta2(format))
384 }
385 "teacher" => {
386 let backend = arg
387 .map(|a| a.parse())
388 .transpose()?
389 .unwrap_or(TeacherBackend::Sonnet45);
390 Ok(PredictionProvider::Teacher(backend))
391 }
392 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
393 let backend = arg
394 .map(|a| a.parse())
395 .transpose()?
396 .unwrap_or(TeacherBackend::Sonnet45);
397 Ok(PredictionProvider::TeacherNonBatching(backend))
398 }
399 "repair" => Ok(PredictionProvider::Repair),
400 _ => {
401 anyhow::bail!(
402 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
403 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
404 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
405 Available zeta versions:\n{}",
406 ZetaFormat::options_as_string()
407 )
408 }
409 }
410 }
411}
412
413impl Serialize for PredictionProvider {
414 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
415 where
416 S: Serializer,
417 {
418 serializer.serialize_str(&self.to_string())
419 }
420}
421
422impl<'de> Deserialize<'de> for PredictionProvider {
423 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
424 where
425 D: Deserializer<'de>,
426 {
427 let s = String::deserialize(deserializer)?;
428 s.parse().map_err(serde::de::Error::custom)
429 }
430}
431
// Flags for the `synthesize` subcommand (example generation from git history).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
450
// Flags for the `import-batch` subcommand.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
460
// Which LLM provider's batch API `import-batch` talks to.
// (`//` comments only: `///` would alter clap's value help output.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
466
467impl EpArgs {
468 fn output_path(&self) -> Option<PathBuf> {
469 if self.in_place {
470 if self.inputs.len() == 1 {
471 self.inputs.first().cloned()
472 } else {
473 panic!("--in-place requires exactly one input file")
474 }
475 } else {
476 self.output.clone()
477 }
478 }
479}
480
481async fn load_examples(
482 http_client: Arc<dyn http_client::HttpClient>,
483 args: &EpArgs,
484 output_path: Option<&PathBuf>,
485 background_executor: BackgroundExecutor,
486) -> anyhow::Result<Vec<Example>> {
487 let mut captured_after_timestamps = Vec::new();
488 let mut rejected_after_timestamps = Vec::new();
489 let mut requested_after_timestamps = Vec::new();
490 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
491 Vec::new();
492 let mut file_inputs = Vec::new();
493
494 for input in &args.inputs {
495 let input_string = input.to_string_lossy();
496 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
497 captured_after_timestamps.push(timestamp.to_string());
498 } else if let Some(timestamp) =
499 pull_examples::parse_rejected_after_input(input_string.as_ref())
500 {
501 rejected_after_timestamps.push(timestamp.to_string());
502 } else if let Some(timestamp) =
503 pull_examples::parse_requested_after_input(input_string.as_ref())
504 {
505 requested_after_timestamps.push(timestamp.to_string());
506 } else if let Some((timestamp, rating_filter)) =
507 pull_examples::parse_rated_after_input(input_string.as_ref())
508 {
509 rated_after_inputs.push((timestamp.to_string(), rating_filter));
510 } else {
511 file_inputs.push(input.clone());
512 }
513 }
514
515 let mut examples = read_example_files(&file_inputs);
516
517 Progress::global().set_total_examples(examples.len());
518
519 let remaining_limit_for_snowflake =
520 args.limit.map(|limit| limit.saturating_sub(examples.len()));
521
522 if let Some(0) = remaining_limit_for_snowflake {
523 log::info!(
524 "skipping Snowflake inputs because --limit is already satisfied by example files"
525 );
526 } else {
527 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
528
529 if !captured_after_timestamps.is_empty() {
530 captured_after_timestamps.sort();
531
532 let mut captured_examples = pull_examples::fetch_captured_examples_after(
533 http_client.clone(),
534 &captured_after_timestamps,
535 max_rows_per_timestamp,
536 background_executor.clone(),
537 )
538 .await?;
539 examples.append(&mut captured_examples);
540 }
541
542 if !rejected_after_timestamps.is_empty() {
543 rejected_after_timestamps.sort();
544
545 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
546 http_client.clone(),
547 &rejected_after_timestamps,
548 max_rows_per_timestamp,
549 background_executor.clone(),
550 )
551 .await?;
552 examples.append(&mut rejected_examples);
553 }
554
555 if !requested_after_timestamps.is_empty() {
556 requested_after_timestamps.sort();
557
558 let mut requested_examples = pull_examples::fetch_requested_examples_after(
559 http_client.clone(),
560 &requested_after_timestamps,
561 max_rows_per_timestamp,
562 background_executor.clone(),
563 )
564 .await?;
565 examples.append(&mut requested_examples);
566 }
567
568 if !rated_after_inputs.is_empty() {
569 rated_after_inputs.sort();
570
571 let mut rated_examples = pull_examples::fetch_rated_examples_after(
572 http_client,
573 &rated_after_inputs,
574 max_rows_per_timestamp,
575 background_executor,
576 )
577 .await?;
578 examples.append(&mut rated_examples);
579 }
580 }
581
582 crate::example::sort_examples_by_repo_and_rev(&mut examples);
583
584 if let Some(name_filter) = &args.name {
585 examples.retain(|example| example.spec.name.contains(name_filter));
586 }
587 if let Some(repo_filter) = &args.repo {
588 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
589 }
590
591 // Skip resume logic for --in-place since input and output are the same file,
592 // which would incorrectly treat all input examples as already processed.
593 if !args.in_place {
594 if let Some(path) = output_path
595 && let Some(command) = &args.command
596 {
597 resume_from_output(path, &mut examples, command);
598 }
599 }
600
601 if let Some(offset) = args.offset {
602 examples.splice(0..offset, []);
603 }
604
605 if let Some(limit) = args.limit {
606 examples.truncate(limit);
607 }
608
609 let progress = Progress::global();
610 progress.set_total_examples(examples.len());
611 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
612
613 Ok(examples)
614}
615
616fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
617 let mut hasher = collections::FxHasher::default();
618 spec.hash(&mut hasher);
619 hasher.finish()
620}
621
622fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
623 let file = match File::open(path) {
624 Ok(f) => f,
625 Err(_) => return,
626 };
627
628 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
629
630 let reader = BufReader::new(file);
631 let mut kept_lines = Vec::new();
632 let mut kept_hashes = HashSet::default();
633
634 for line in reader.lines() {
635 let line = match line {
636 Ok(l) => l,
637 Err(_) => continue,
638 };
639
640 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
641 let hash = spec_hash(&output_example.spec);
642 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
643 let is_complete = match command {
644 Command::Qa(_) => output_example
645 .qa
646 .first()
647 .and_then(|q| q.as_ref())
648 .and_then(|q| q.confidence)
649 .is_some(),
650 Command::Repair(_) => output_example.predictions.iter().any(|p| {
651 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
652 }),
653 _ => true,
654 };
655 if is_complete {
656 kept_hashes.insert(hash);
657 kept_lines.push(line);
658 }
659 }
660 }
661 }
662
663 let total = examples.len();
664 let already_processed = kept_hashes.len();
665
666 eprintln!(
667 "Resuming: {}/{} examples already processed",
668 already_processed, total
669 );
670
671 let file = OpenOptions::new()
672 .write(true)
673 .truncate(true)
674 .open(path)
675 .expect("Failed to open output file for rewriting");
676 let mut writer = BufWriter::new(file);
677 for line in &kept_lines {
678 writeln!(writer, "{}", line).expect("Failed to write to output file");
679 }
680 writer.flush().expect("Failed to flush output file");
681
682 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
683}
684
/// CLI entry point.
///
/// Flow:
/// 1. Parse args; handle `--printenv` and the no-subcommand case.
/// 2. Dispatch "standalone" subcommands (batch import, clean, format listing,
///    deployment sync, synthesis, dataset tools) synchronously and return —
///    these never start the headless app.
/// 3. For everything else, boot a headless GPUI `Application`, load the
///    examples, and process them with `max_parallelism` worker tasks that
///    each pull whole repository groups off a shared queue, streaming results
///    to stdout, a JSONL file, or per-example markdown files.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Standalone subcommands — each arm runs to completion and returns
    // without touching the example-processing machinery below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes the whole data dir (cloned repos, worktrees, caches).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }
        Command::SyncDeployments(sync_args) => {
            let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
            smol::block_on(async {
                if let Err(e) =
                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
                        .await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesis writes example files into a directory, so -o is
            // mandatory here (unlike the streaming commands below).
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands fall through to the headless-app pipeline.
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any pending batch results before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the input once the run succeeds.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background task serializes all writes so the
                    // worker tasks never contend on the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: whole repository groups, so each worker owns a
                // repo's worktree/project for the duration of its examples.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the current subcommand for this
                                // example; standalone commands were handled
                                // earlier, hence the unreachable!() arm.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats
                                        | Command::SyncDeployments(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples only reach the main output
                                // when --failed=keep (the default).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout
                                        // (except eval, which prints a report).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group done: shut down its language servers,
                            // drop it from the prediction store and project
                            // cache, and release per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any batch requests queued during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Post-run reports for the commands that aggregate results.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1138
1139async fn handle_error(
1140 error: anyhow::Error,
1141 args: &EpArgs,
1142 command: &Command,
1143 app_state: &Arc<headless::EpAppState>,
1144 failfast_on_single_example: bool,
1145 example: &Example,
1146) {
1147 Progress::global().increment_failed();
1148
1149 let msg;
1150 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1151 let example_name = example.spec.filename();
1152
1153 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1154 app_state
1155 .fs
1156 .write(
1157 &failed_example_path,
1158 &serde_json::to_vec_pretty(&example).unwrap(),
1159 )
1160 .await
1161 .unwrap();
1162 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1163 app_state
1164 .fs
1165 .write(&err_path, format!("{error:?}").as_bytes())
1166 .await
1167 .unwrap();
1168
1169 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1170 let mut file = OpenOptions::new()
1171 .create(true)
1172 .append(true)
1173 .open(&failed_jsonl_path)
1174 .expect("Failed to open failed.jsonl");
1175 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1176 .expect("Failed to write to failed.jsonl");
1177
1178 let cursor_path = match example.repo_name() {
1179 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1180 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1181 };
1182 msg = format!(
1183 indoc::indoc! {"
1184 While processing \"{}\":
1185
1186 \x1b[31m{:?}\x1b[0m
1187
1188 Example: \x1b[36m{}\x1b[0m
1189 Error file: \x1b[36m{}\x1b[0m
1190 Cursor file: \x1b[36m{}\x1b[0m
1191 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1192 "},
1193 example.spec.name,
1194 error,
1195 failed_example_path.display(),
1196 err_path.display(),
1197 cursor_path.display(),
1198 command,
1199 failed_example_path.display(),
1200 );
1201 } else {
1202 msg = format!(
1203 indoc::indoc! {"
1204 While processing \"{}\":
1205
1206 \x1b[31m{:?}\x1b[0m
1207 "},
1208 example.spec.name, error
1209 );
1210 }
1211
1212 if args.failfast || failfast_on_single_example {
1213 Progress::global().finalize();
1214 panic!("{}", msg);
1215 } else {
1216 log::error!("{}", msg);
1217 }
1218}