1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod word_diff;
27use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
28use collections::HashSet;
29use edit_prediction::EditPredictionStore;
30use futures::channel::mpsc;
31use futures::{SinkExt as _, StreamExt as _};
32use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
33use zeta_prompt::ZetaVersion;
34
35use reqwest_client::ReqwestClient;
36use serde::{Deserialize, Deserializer, Serialize, Serializer};
37use std::fmt::Display;
38use std::fs::{File, OpenOptions};
39use std::hash::{Hash, Hasher};
40use std::io::{BufRead, BufReader, BufWriter, Write};
41use std::sync::Mutex;
42use std::{path::PathBuf, sync::Arc};
43
44use crate::distill::run_distill;
45use crate::example::{Example, group_examples_by_repo, read_example_files};
46use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
47use crate::format_prompt::run_format_prompt;
48use crate::load_project::run_load_project;
49use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
50use crate::predict::run_prediction;
51use crate::progress::Progress;
52use crate::retrieve_context::run_context_retrieval;
53use crate::score::run_scoring;
54use crate::split_commit::SplitCommitArgs;
55use crate::split_dataset::SplitArgs;
56use crate::synthesize::{SynthesizeConfig, run_synthesize};
57
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear before or after the subcommand. NOTE: clap turns `///` doc comments on
// fields into --help text, so review notes here use `//` to avoid changing it.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment as `ep` sees it, then exit (see main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each claims one repo's examples at a time.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to drop (applied after filtering/resume).
    // NOTE(review): load_examples uses `splice(0..offset, [])`, which panics if
    // offset exceeds the number of examples — confirm whether clamping was intended.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Optional subcommand; when absent, main() prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; a JSONL file normally, a directory with --markdown.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (via a temp file renamed at the end).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
96
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the doc line above overstates — with `skip-no-files`,
// handle_error() skips the failed/ directory snapshot and the failed.jsonl
// append entirely; only the error message is logged. Confirm intended wording.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
109
/// Long help text attached to the positional `inputs` argument of [`EpArgs`].
/// Documents both plain file-path inputs and the Snowflake `*-after:{timestamp}`
/// specifiers parsed in `load_examples` via the `pull_examples` module.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
167
// All `ep` subcommands. Clean, Synthesize, SplitCommit, Split, FilterLanguages,
// and ImportBatch run standalone in main() before the headless app starts; the
// rest are dispatched per-example inside the worker loop. The Display impl
// below renders each variant as the CLI string used in re-run hints.
// (Variant `///` docs become clap help text — left untouched.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
207
208impl Display for Command {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Command::Read => write!(f, "read"),
212 Command::LoadProject => write!(f, "load-project"),
213 Command::Context => write!(f, "context"),
214 Command::FormatPrompt(args) => {
215 write!(f, "format-prompt --provider={}", args.provider)
216 }
217 Command::Predict(args) => match &args.provider {
218 Some(provider) => write!(f, "predict --provider={}", provider),
219 None => write!(f, "predict"),
220 },
221 Command::ParseOutput => write!(f, "parse-output"),
222 Command::Score(args) => match &args.provider {
223 Some(provider) => write!(f, "score --provider={}", provider),
224 None => write!(f, "score"),
225 },
226 Command::Distill => write!(f, "distill"),
227 Command::Eval(args) => match &args.predict.provider {
228 Some(provider) => write!(f, "eval --provider={}", provider),
229 None => write!(f, "eval"),
230 },
231 Command::Synthesize(args) => {
232 write!(f, "synthesize --repos {}", args.repos.join(" "))
233 }
234 Command::Clean => write!(f, "clean"),
235 Command::SplitCommit(_) => write!(f, "split-commit"),
236 Command::Split(_) => write!(f, "split"),
237 Command::FilterLanguages(_) => write!(f, "filter-languages"),
238 Command::ImportBatch(args) => {
239 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
240 }
241 Command::Qa(_) => {
242 write!(f, "qa")
243 }
244 Command::Repair(_) => {
245 write!(f, "repair")
246 }
247 }
248 }
249}
250
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to the default
    // PredictionProvider (zeta2 at its default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
256
// Arguments shared by `predict` and `score` (and, flattened, by `eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to use; optional — presumably a default is chosen downstream
    // in predict::run_prediction when omitted (confirm there).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to produce per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
267
// Arguments for `eval`: prediction/scoring options plus an optional path for
// the aggregate JSON summary written by score::write_summary_json.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
276
/// Which large "teacher" model backs teacher-mode predictions.
/// Parsed from aliases via `FromStr`; rendered as `sonnet45`/`gpt52` via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude — maps to model name "claude-sonnet-4-5".
    Sonnet45,
    /// OpenAI GPT — maps to model name "gpt-5.2".
    Gpt52,
}
282
283impl std::fmt::Display for TeacherBackend {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 match self {
286 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
287 TeacherBackend::Gpt52 => write!(f, "gpt52"),
288 }
289 }
290}
291
292impl std::str::FromStr for TeacherBackend {
293 type Err = anyhow::Error;
294
295 fn from_str(s: &str) -> Result<Self, Self::Err> {
296 match s.to_lowercase().as_str() {
297 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
298 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
299 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
300 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
301 }
302 }
303}
304
305impl TeacherBackend {
306 pub fn model_name(&self) -> &'static str {
307 match self {
308 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
309 TeacherBackend::Gpt52 => "gpt-5.2",
310 }
311 }
312}
313
/// Backend used to produce edit predictions. Parsed from strings like
/// `zeta2:<version>` or `teacher:<backend>` (see `FromStr` below) and
/// serialized back through `Display` for JSONL round-tripping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 at a specific version (defaults to `ZetaVersion::default()`).
    Zeta2(ZetaVersion),
    /// Predictions from a large teacher model (see `TeacherBackend`).
    Teacher(TeacherBackend),
    /// Same as `Teacher` but — per the name — without request batching;
    /// confirm exact behavior in the predict module.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the `repair` command.
    Repair,
}
324
325impl Default for PredictionProvider {
326 fn default() -> Self {
327 PredictionProvider::Zeta2(ZetaVersion::default())
328 }
329}
330
331impl std::fmt::Display for PredictionProvider {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 match self {
334 PredictionProvider::Sweep => write!(f, "sweep"),
335 PredictionProvider::Mercury => write!(f, "mercury"),
336 PredictionProvider::Zeta1 => write!(f, "zeta1"),
337 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
338 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
339 PredictionProvider::TeacherNonBatching(backend) => {
340 write!(f, "teacher-non-batching:{backend}")
341 }
342 PredictionProvider::Repair => write!(f, "repair"),
343 }
344 }
345}
346
347impl std::str::FromStr for PredictionProvider {
348 type Err = anyhow::Error;
349
350 fn from_str(s: &str) -> Result<Self, Self::Err> {
351 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
352
353 let provider_lower = provider.to_lowercase();
354 match provider_lower.as_str() {
355 "sweep" => Ok(PredictionProvider::Sweep),
356 "mercury" => Ok(PredictionProvider::Mercury),
357 "zeta1" => Ok(PredictionProvider::Zeta1),
358 "zeta2" => {
359 let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
360 Ok(PredictionProvider::Zeta2(version))
361 }
362 "teacher" => {
363 let backend = arg
364 .map(|a| a.parse())
365 .transpose()?
366 .unwrap_or(TeacherBackend::Sonnet45);
367 Ok(PredictionProvider::Teacher(backend))
368 }
369 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
370 let backend = arg
371 .map(|a| a.parse())
372 .transpose()?
373 .unwrap_or(TeacherBackend::Sonnet45);
374 Ok(PredictionProvider::TeacherNonBatching(backend))
375 }
376 "repair" => Ok(PredictionProvider::Repair),
377 _ => {
378 anyhow::bail!(
379 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
380 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
381 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
382 Available zeta versions:\n{}",
383 ZetaVersion::options_as_string()
384 )
385 }
386 }
387 }
388}
389
390impl Serialize for PredictionProvider {
391 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
392 where
393 S: Serializer,
394 {
395 serializer.serialize_str(&self.to_string())
396 }
397}
398
399impl<'de> Deserialize<'de> for PredictionProvider {
400 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
401 where
402 D: Deserializer<'de>,
403 {
404 let s = String::deserialize(deserializer)?;
405 s.parse().map_err(serde::de::Error::custom)
406 }
407}
408
// Arguments for `synthesize`; converted into a SynthesizeConfig in main()
// together with the required -o output directory.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
427
// Arguments for `import-batch`; handled standalone in main() before the
// headless app starts.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
437
// Batch-API vendor selector for `import-batch`. clap::ValueEnum, so the CLI
// values are `anthropic` / `openai`; variant doc comments would become value
// help text, hence `//` comments here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
443
444impl EpArgs {
445 fn output_path(&self) -> Option<PathBuf> {
446 if self.in_place {
447 if self.inputs.len() == 1 {
448 self.inputs.first().cloned()
449 } else {
450 panic!("--in-place requires exactly one input file")
451 }
452 } else {
453 self.output.clone()
454 }
455 }
456}
457
/// Loads the run's examples from every CLI input — local files plus Snowflake
/// `*-after:{timestamp}` specifiers — then applies name/repo filters, resume
/// logic against an existing output file, and --offset/--limit.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each input by kind: the four Snowflake specifier shapes vs. files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    // Local files load first; Snowflake fetches below only fill the remainder
    // of --limit (if one was given).
    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default on --limit for Snowflake pulls.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group examples by repo/rev so the worker loop can reuse project state.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(offset) = args.offset {
        // NOTE(review): splice panics if offset > examples.len() — confirm
        // whether out-of-range offsets should be clamped instead.
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that filtering/resume/offset/limit have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
592
593fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
594 let mut hasher = collections::FxHasher::default();
595 spec.hash(&mut hasher);
596 hasher.finish()
597}
598
/// Resumes an interrupted run: reads the existing output file at `path`, keeps
/// only lines whose spec matches a current input example AND whose result is
/// "complete" for `command`, rewrites the file with just those lines, and
/// removes the corresponding examples from `examples` so they are not
/// processed again (the worker loop then appends new results after them).
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No existing output file (or unreadable) means nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    // Identity hashes of the current input set.
    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable or unparsable lines are silently dropped from the rewrite.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep at most one line per spec, and only for specs still in the input set.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // "Complete" is command-specific: qa needs a confidence value,
                // repair needs a repair prediction with an actual patch; every
                // other command counts any prior line as complete.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file in place with only the kept lines; main() opens
    // it in append mode afterwards for newly processed examples.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Drop already-processed examples from the in-memory input set.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
661
/// Entry point. Parses CLI arguments, runs standalone subcommands directly,
/// and for example-processing subcommands spins up a headless GPUI app with a
/// pool of worker tasks that process examples grouped by repository.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't process examples run here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // NOTE(review): unwrap panics if DATA_DIR doesn't exist — confirm
            // whether `clean` on a fresh checkout should be a no-op instead.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into a directory rather
            // than a single JSONL file, so -o is mandatory here.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Everything else needs the headless GPUI app and project infrastructure.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync provider batch state before processing (see the
                // per-module sync_batches for what this reconciles).
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast so errors surface immediately.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a fresh temp file (renamed over the
                    // input at the end); otherwise append so resumed runs keep
                    // the lines resume_from_output preserved.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // A single background task owns the writer; worker tasks
                    // stream finished example lines through this channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue of per-repo example groups; each worker claims a
                // whole repo at a time so project state is reused within it.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for the active command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These were all handled (with an early return)
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example unless it failed and --failed=skip*.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except eval,
                                        // which prints its report at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Done with this repo: shut down its language servers and
                            // drop cached project state before moving to the next one.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync provider batches again after processing (mirrors the pre-run sync).
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary) at the end.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the run completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1079
1080async fn handle_error(
1081 error: anyhow::Error,
1082 args: &EpArgs,
1083 command: &Command,
1084 app_state: &Arc<headless::EpAppState>,
1085 failfast_on_single_example: bool,
1086 example: &Example,
1087) {
1088 Progress::global().increment_failed();
1089
1090 let msg;
1091 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1092 let example_name = example.spec.filename();
1093
1094 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1095 app_state
1096 .fs
1097 .write(
1098 &failed_example_path,
1099 &serde_json::to_vec_pretty(&example).unwrap(),
1100 )
1101 .await
1102 .unwrap();
1103 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1104 app_state
1105 .fs
1106 .write(&err_path, format!("{error:?}").as_bytes())
1107 .await
1108 .unwrap();
1109
1110 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1111 let mut file = OpenOptions::new()
1112 .create(true)
1113 .append(true)
1114 .open(&failed_jsonl_path)
1115 .expect("Failed to open failed.jsonl");
1116 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1117 .expect("Failed to write to failed.jsonl");
1118
1119 let cursor_path = match example.repo_name() {
1120 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1121 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1122 };
1123 msg = format!(
1124 indoc::indoc! {"
1125 While processing \"{}\":
1126
1127 \x1b[31m{:?}\x1b[0m
1128
1129 Example: \x1b[36m{}\x1b[0m
1130 Error file: \x1b[36m{}\x1b[0m
1131 Cursor file: \x1b[36m{}\x1b[0m
1132 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1133 "},
1134 example.spec.name,
1135 error,
1136 failed_example_path.display(),
1137 err_path.display(),
1138 cursor_path.display(),
1139 command,
1140 failed_example_path.display(),
1141 );
1142 } else {
1143 msg = format!(
1144 indoc::indoc! {"
1145 While processing \"{}\":
1146
1147 \x1b[31m{:?}\x1b[0m
1148 "},
1149 example.spec.name, error
1150 );
1151 }
1152
1153 if args.failfast || failfast_on_single_example {
1154 Progress::global().finalize();
1155 panic!("{}", msg);
1156 } else {
1157 log::error!("{}", msg);
1158 }
1159}