// main.rs

   1mod anthropic_client;
   2mod distill;
   3mod example;
   4mod filter_languages;
   5mod format_prompt;
   6mod git;
   7mod headless;
   8mod load_project;
   9mod metrics;
  10mod openai_client;
  11mod parse_output;
  12mod paths;
  13mod predict;
  14mod progress;
  15mod prompt_assets;
  16mod pull_examples;
  17mod qa;
  18mod reorder_patch;
  19mod repair;
  20mod retrieve_context;
  21mod reversal_tracking;
  22mod score;
  23mod split_commit;
  24mod split_dataset;
  25mod sync_deployments;
  26mod synthesize;
  27mod truncate_expected_patch;
  28mod word_diff;
  29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
  30use collections::HashSet;
  31use edit_prediction::EditPredictionStore;
  32use futures::channel::mpsc;
  33use futures::{SinkExt as _, StreamExt as _};
  34use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
  35use zeta_prompt::ZetaFormat;
  36
  37use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
  45
  46use crate::distill::run_distill;
  47use crate::example::{Example, group_examples_by_repo, read_example_files};
  48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
  49use crate::format_prompt::run_format_prompt;
  50use crate::load_project::run_load_project;
  51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
  52use crate::predict::run_prediction;
  53use crate::progress::Progress;
  54use crate::retrieve_context::run_context_retrieval;
  55use crate::score::run_scoring;
  56use crate::split_commit::SplitCommitArgs;
  57use crate::split_dataset::SplitArgs;
  58use crate::synthesize::{SynthesizeConfig, run_synthesize};
  59use crate::truncate_expected_patch::TruncatePatchArgs;
  60
// Top-level CLI arguments for `ep`. Reviewer notes below use plain `//`
// comments on purpose: `///` doc comments on clap-derived fields become the
// generated --help text, and we must not change user-visible output here.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit early (handled in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrent example processing — TODO confirm against the
    // (not visible here) code that consumes it.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, with
    // the remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: local files or special `*-after:{timestamp}`
    // Snowflake specifiers (documented in INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; output_path() panics if more
    // than one input is given together with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Presumably aborts the run on the first failure — not used in the
    // visible portion of the file; verify before relying on this.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
  99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the `///` comments on the variants double as clap's value
// help (ValueEnum derive), so rewording them changes --help output.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
 112
/// Long-form help for the positional `inputs` argument, wired up via
/// `#[clap(help = INPUTS_HELP)]` on `EpArgs::inputs`. The raw string is
/// user-visible help text; edit with care.
// NOTE(review): load_examples also accepts `requested-after:{timestamp}`
// (see pull_examples::parse_requested_after_input) but it is not documented
// in this help text — consider adding it.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
 170
// Subcommands of `ep`. The `///` comments double as clap's --help text and
// are therefore left untouched; see the Display impl for the string form
// each variant renders to.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
 216
 217impl Display for Command {
 218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 219        match self {
 220            Command::Read => write!(f, "read"),
 221            Command::LoadProject => write!(f, "load-project"),
 222            Command::Context => write!(f, "context"),
 223            Command::FormatPrompt(args) => {
 224                write!(f, "format-prompt --provider={}", args.provider)
 225            }
 226            Command::Predict(args) => match &args.provider {
 227                Some(provider) => write!(f, "predict --provider={}", provider),
 228                None => write!(f, "predict"),
 229            },
 230            Command::ParseOutput => write!(f, "parse-output"),
 231            Command::Score(args) => match &args.provider {
 232                Some(provider) => write!(f, "score --provider={}", provider),
 233                None => write!(f, "score"),
 234            },
 235            Command::Distill => write!(f, "distill"),
 236            Command::Eval(args) => match &args.predict.provider {
 237                Some(provider) => write!(f, "eval --provider={}", provider),
 238                None => write!(f, "eval"),
 239            },
 240            Command::Synthesize(args) => {
 241                write!(f, "synthesize --repos {}", args.repos.join(" "))
 242            }
 243            Command::Clean => write!(f, "clean"),
 244            Command::SplitCommit(_) => write!(f, "split-commit"),
 245            Command::TruncatePatch(_) => write!(f, "truncate-patch"),
 246            Command::Split(_) => write!(f, "split"),
 247            Command::FilterLanguages(_) => write!(f, "filter-languages"),
 248            Command::ImportBatch(args) => {
 249                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
 250            }
 251            Command::Qa(_) => {
 252                write!(f, "qa")
 253            }
 254            Command::Repair(_) => {
 255                write!(f, "repair")
 256            }
 257            Command::PrintZetaFormats => {
 258                write!(f, "print-zeta-formats")
 259            }
 260            Command::SyncDeployments(_) => {
 261                write!(f, "sync-deployments")
 262            }
 263        }
 264    }
 265}
 266
// Arguments for `ep sync-deployments` (see sync_deployments::run_sync_deployments).
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    #[arg(long)]
    model: Option<String>,
}
 273
// Arguments for `ep format-prompt`; unlike PredictArgs, the provider here is
// mandatory (defaulted rather than optional).
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
 279
// Arguments shared by `predict`, `score`, and (flattened into EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when absent the Display impl renders the bare
    // subcommand name and downstream code picks a default (not visible here).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably the number of prediction attempts per example — not used in
    // the visible portion of the file; confirm against predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
 290
// Arguments for `ep eval`: the shared prediction settings plus an optional
// machine-readable summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
 299
/// Frontier model used as the "teacher" prediction provider
/// (see `PredictionProvider::Teacher` / `TeacherNonBatching`).
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Maps to model "claude-sonnet-4-6" (see `model_name`).
    Sonnet46,
    /// Maps to model "claude-sonnet-4-5"; the default backend.
    #[default]
    Sonnet45,
    /// Maps to model "gpt-5.2".
    Gpt52,
}
 307
 308impl std::fmt::Display for TeacherBackend {
 309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 310        match self {
 311            TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
 312            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
 313            TeacherBackend::Gpt52 => write!(f, "gpt52"),
 314        }
 315    }
 316}
 317
 318impl std::str::FromStr for TeacherBackend {
 319    type Err = anyhow::Error;
 320
 321    fn from_str(s: &str) -> Result<Self, Self::Err> {
 322        match s.to_lowercase().as_str() {
 323            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
 324            "sonnet46" => Ok(TeacherBackend::Sonnet46),
 325            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
 326            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
 327            _ => anyhow::bail!(
 328                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
 329            ),
 330        }
 331    }
 332}
 333
 334impl TeacherBackend {
 335    pub fn model_name(&self) -> &'static str {
 336        match self {
 337            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
 338            TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
 339            TeacherBackend::Gpt52 => "gpt-5.2",
 340        }
 341    }
 342}
 343
/// Which system produces edit predictions. Parsed from `name[:argument]`
/// strings (see the `FromStr` impl) and rendered back by `Display`;
/// serde round-trips through that string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta 2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// A frontier "teacher" model; presumably uses the batch API (contrast
    /// with the variant below) — confirm in predict.rs.
    Teacher(TeacherBackend),
    /// The same teacher backends without batching.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
 354
impl Default for PredictionProvider {
    /// Defaults to Zeta 2 with the default prompt format.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
 360
 361impl std::fmt::Display for PredictionProvider {
 362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 363        match self {
 364            PredictionProvider::Sweep => write!(f, "sweep"),
 365            PredictionProvider::Mercury => write!(f, "mercury"),
 366            PredictionProvider::Zeta1 => write!(f, "zeta1"),
 367            PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
 368            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
 369            PredictionProvider::TeacherNonBatching(backend) => {
 370                write!(f, "teacher-non-batching:{backend}")
 371            }
 372            PredictionProvider::Repair => write!(f, "repair"),
 373        }
 374    }
 375}
 376
 377impl std::str::FromStr for PredictionProvider {
 378    type Err = anyhow::Error;
 379
 380    fn from_str(s: &str) -> Result<Self, Self::Err> {
 381        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
 382
 383        let provider_lower = provider.to_lowercase();
 384        match provider_lower.as_str() {
 385            "sweep" => Ok(PredictionProvider::Sweep),
 386            "mercury" => Ok(PredictionProvider::Mercury),
 387            "zeta1" => Ok(PredictionProvider::Zeta1),
 388            "zeta2" => {
 389                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
 390                Ok(PredictionProvider::Zeta2(format))
 391            }
 392            "teacher" => {
 393                let backend = arg
 394                    .map(|a| a.parse())
 395                    .transpose()?
 396                    .unwrap_or(TeacherBackend::default());
 397                Ok(PredictionProvider::Teacher(backend))
 398            }
 399            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
 400                let backend = arg
 401                    .map(|a| a.parse())
 402                    .transpose()?
 403                    .unwrap_or(TeacherBackend::default());
 404                Ok(PredictionProvider::TeacherNonBatching(backend))
 405            }
 406            "repair" => Ok(PredictionProvider::Repair),
 407            _ => {
 408                anyhow::bail!(
 409                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
 410                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
 411                 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
 412                 Available zeta versions:\n{}",
 413                    ZetaFormat::options_as_string()
 414                )
 415            }
 416        }
 417    }
 418}
 419
 420impl Serialize for PredictionProvider {
 421    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 422    where
 423        S: Serializer,
 424    {
 425        serializer.serialize_str(&self.to_string())
 426    }
 427}
 428
 429impl<'de> Deserialize<'de> for PredictionProvider {
 430    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 431    where
 432        D: Deserializer<'de>,
 433    {
 434        let s = String::deserialize(deserializer)?;
 435        s.parse().map_err(serde::de::Error::custom)
 436    }
 437}
 438
// Arguments for `ep synthesize`; converted into SynthesizeConfig in main()
// before calling run_synthesize.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
 457
// Arguments for `ep import-batch`; handled entirely in main() before the
// gpui app starts.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
 467
// Which vendor's batch API `import-batch` targets; selects between the
// Anthropic and OpenAI clients in main(). `//` comments only: `///` here
// would alter clap's ValueEnum help output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
 473
 474impl EpArgs {
 475    fn output_path(&self) -> Option<PathBuf> {
 476        if self.in_place {
 477            if self.inputs.len() == 1 {
 478                self.inputs.first().cloned()
 479            } else {
 480                panic!("--in-place requires exactly one input file")
 481            }
 482        } else {
 483            self.output.clone()
 484        }
 485    }
 486}
 487
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
/// Passed as `Some(MIN_CAPTURE_VERSION)` to every `pull_examples::fetch_*`
/// call in `load_examples`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
 495
/// Collects the run's input examples from local files and Snowflake
/// specifiers, then applies the global CLI filters.
///
/// Pipeline, in order: partition `args.inputs` → read local files → apply
/// `--offset` (local files first, remainder forwarded to Snowflake) → fetch
/// each Snowflake category → sort by repo/rev → `--name`/`--repo` substring
/// filters → drop already-processed examples via `resume_from_output`
/// (skipped for `--in-place`) → truncate to `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs: each `*-after:{timestamp}` specifier goes into its
    // Snowflake bucket; anything unrecognized is treated as a file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total (file examples only); set again after fetching and
    // filtering at the end of this function.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // With no --limit, each timestamp query is capped at 5000 rows
        // (matching the default documented on EpArgs::limit).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Final fetch: `http_client` and `background_executor` are moved
            // rather than cloned here.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals now that the example set is fixed.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
 648
 649fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
 650    let mut hasher = collections::FxHasher::default();
 651    spec.hash(&mut hasher);
 652    hasher.finish()
 653}
 654
 655fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
 656    let file = match File::open(path) {
 657        Ok(f) => f,
 658        Err(_) => return,
 659    };
 660
 661    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
 662
 663    let reader = BufReader::new(file);
 664    let mut kept_lines = Vec::new();
 665    let mut kept_hashes = HashSet::default();
 666
 667    for line in reader.lines() {
 668        let line = match line {
 669            Ok(l) => l,
 670            Err(_) => continue,
 671        };
 672
 673        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
 674            let hash = spec_hash(&output_example.spec);
 675            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
 676                let is_complete = match command {
 677                    Command::Qa(_) => output_example
 678                        .qa
 679                        .first()
 680                        .and_then(|q| q.as_ref())
 681                        .and_then(|q| q.confidence)
 682                        .is_some(),
 683                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
 684                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
 685                    }),
 686                    _ => true,
 687                };
 688                if is_complete {
 689                    kept_hashes.insert(hash);
 690                    kept_lines.push(line);
 691                }
 692            }
 693        }
 694    }
 695
 696    let total = examples.len();
 697    let already_processed = kept_hashes.len();
 698
 699    eprintln!(
 700        "Resuming: {}/{} examples already processed",
 701        already_processed, total
 702    );
 703
 704    let file = OpenOptions::new()
 705        .write(true)
 706        .truncate(true)
 707        .open(path)
 708        .expect("Failed to open output file for rewriting");
 709    let mut writer = BufWriter::new(file);
 710    for line in &kept_lines {
 711        writeln!(writer, "{}", line).expect("Failed to write to output file");
 712    }
 713    writer.flush().expect("Failed to flush output file");
 714
 715    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
 716}
 717
 718fn main() {
 719    let args = EpArgs::parse();
 720
 721    if args.printenv {
 722        ::util::shell_env::print_env();
 723        return;
 724    }
 725
 726    let output = args.output_path();
 727
 728    if args.markdown && output.is_none() {
 729        eprintln!("--markdown requires -o to specify the output directory");
 730        std::process::exit(1);
 731    }
 732
 733    let command = match &args.command {
 734        Some(cmd) => cmd.clone(),
 735        None => {
 736            EpArgs::command().print_help().unwrap();
 737            return;
 738        }
 739    };
 740
 741    match &command {
 742        Command::ImportBatch(import_args) => {
 743            smol::block_on(async {
 744                match import_args.provider {
 745                    BatchProvider::Anthropic => {
 746                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
 747                            .expect("Failed to create Anthropic client");
 748                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 749                            eprintln!("Error importing Anthropic batches: {:?}", e);
 750                            std::process::exit(1);
 751                        }
 752                    }
 753                    BatchProvider::Openai => {
 754                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
 755                            .expect("Failed to create OpenAI client");
 756                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
 757                            eprintln!("Error importing OpenAI batches: {:?}", e);
 758                            std::process::exit(1);
 759                        }
 760                    }
 761                }
 762                println!(
 763                    "Successfully imported {} batch(es)",
 764                    import_args.batch_ids.len()
 765                );
 766            });
 767            return;
 768        }
 769        Command::Clean => {
 770            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
 771            return;
 772        }
        // The commands below do not need the headless GPUI application:
        // each runs to completion (synchronously or on smol) and `return`s
        // before the app is constructed further down.
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            // Print every ZetaFormat variant, lowercased for CLI use.
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }
        Command::SyncDeployments(sync_args) => {
            let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
            // Block on the async sync; a failure exits with status 1.
            smol::block_on(async {
                if let Err(e) =
                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
                        .await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesis writes generated examples to disk, so an output
            // directory is mandatory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Every other command falls through to the headless-app pipeline.
        _ => {}
    }
 852
    // Remaining commands run inside a headless GPUI application with an HTTP
    // client and the global EditPredictionStore installed.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                // Load the examples referenced by the CLI inputs.
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync outstanding provider batches before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, any failure aborts the run.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a `.jsonl.tmp` sibling which is
                    // renamed over the final path after the run succeeds.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // In-place runs truncate; otherwise append to the file.
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // All workers send lines through one channel consumed by
                    // a single background writer task, so writes never
                    // interleave. Each line is flushed immediately.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are grouped by repo so a repo's project state can
                // be torn down once its whole group is processed.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn up to --max-parallelism workers; each repeatedly
                // pops a repo group off the shared queue until it is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the active command's per-example
                                // handler; errors are collected below.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands returned before the
                                        // app was started (see the earlier
                                        // match), so they can't reach here.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats
                                        | Command::SyncDeployments(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record the failure (and possibly panic,
                                // under --failfast) but keep processing.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only written when the
                                // user asked to keep them.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: emit JSONL to stdout,
                                        // except for Eval which prints a
                                        // report instead.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // The repo group is done: shut down its language
                            // servers and drop it from the prediction store.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example state before retaining the
                            // examples for final reporting.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync provider batches again now that all work is queued.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Command-specific end-of-run reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so the process can exit.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1168
1169async fn handle_error(
1170    error: anyhow::Error,
1171    args: &EpArgs,
1172    command: &Command,
1173    app_state: &Arc<headless::EpAppState>,
1174    failfast_on_single_example: bool,
1175    example: &Example,
1176) {
1177    Progress::global().increment_failed();
1178
1179    let msg;
1180    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1181        let example_name = example.spec.filename();
1182
1183        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1184        app_state
1185            .fs
1186            .write(
1187                &failed_example_path,
1188                &serde_json::to_vec_pretty(&example).unwrap(),
1189            )
1190            .await
1191            .unwrap();
1192        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1193        app_state
1194            .fs
1195            .write(&err_path, format!("{error:?}").as_bytes())
1196            .await
1197            .unwrap();
1198
1199        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1200        let mut file = OpenOptions::new()
1201            .create(true)
1202            .append(true)
1203            .open(&failed_jsonl_path)
1204            .expect("Failed to open failed.jsonl");
1205        writeln!(file, "{}", serde_json::to_string(example).unwrap())
1206            .expect("Failed to write to failed.jsonl");
1207
1208        let cursor_path = match example.repo_name() {
1209            Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1210            Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1211        };
1212        msg = format!(
1213            indoc::indoc! {"
1214                While processing \"{}\":
1215
1216                \x1b[31m{:?}\x1b[0m
1217
1218                Example:        \x1b[36m{}\x1b[0m
1219                Error file:     \x1b[36m{}\x1b[0m
1220                Cursor file:    \x1b[36m{}\x1b[0m
1221                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1222            "},
1223            example.spec.name,
1224            error,
1225            failed_example_path.display(),
1226            err_path.display(),
1227            cursor_path.display(),
1228            command,
1229            failed_example_path.display(),
1230        );
1231    } else {
1232        msg = format!(
1233            indoc::indoc! {"
1234            While processing \"{}\":
1235
1236                \x1b[31m{:?}\x1b[0m
1237            "},
1238            example.spec.name, error
1239        );
1240    }
1241
1242    if args.failfast || failfast_on_single_example {
1243        Progress::global().finalize();
1244        panic!("{}", msg);
1245    } else {
1246        log::error!("{}", msg);
1247    }
1248}