mod anthropic_client;
mod distill;
mod example;
mod filter_languages;
mod format_prompt;
mod git;
mod headless;
mod llm_client;
mod load_project;
mod metrics;
mod openai_client;
mod parse_output;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod qa;
mod reorder_patch;
mod repair;
mod retrieve_context;
mod reversal_tracking;
mod score;
mod split_commit;
mod split_dataset;
mod synthesize;
mod word_diff;

use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::sync::Mutex;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
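
/// `ep`: command-line harness for the edit prediction pipeline (reading
/// examples, loading projects, retrieving context, formatting prompts,
/// predicting, and scoring).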
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    /// Print the shell environment and exit.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    /// Maximum number of repositories to process in parallel.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// Process at most this many examples.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Skip this many examples before processing.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    /// Path to write output to (a directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    /// Rewrite the single input file in place (via a temp file renamed at the end).
    #[arg(long, short, global = true)]
    in_place: bool,
    /// Abort on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in the output: keep them or skip them.
    /// Unless `skip-no-files` is used, failed examples are also logged to the
    /// run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}

/// Controls whether failed examples are included in the main output.
/// Except with `skip-no-files`, failed examples are also logged to the run's
/// failed/ directory.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Exclude failed examples from the main output and skip writing the
    /// failure artifacts (failed/ files and failed.jsonl)
    SkipNoFiles,
}

const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    requested-after:{timestamp}
        Fetch edit prediction requests from Snowflake after the given RFC3339 timestamp.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
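
/// Subcommands, listed roughly in pipeline order.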
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Run edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Compute a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepare a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}

impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Command::Read => write!(f, "read"),
            Command::LoadProject => write!(f, "load-project"),
            Command::Context => write!(f, "context"),
            Command::FormatPrompt(args) => {
                write!(f, "format-prompt --provider={}", args.provider)
            }
            Command::Predict(args) => match &args.provider {
                Some(provider) => write!(f, "predict --provider={}", provider),
                None => write!(f, "predict"),
            },
            Command::ParseOutput => write!(f, "parse-output"),
            Command::Score(args) => match &args.provider {
                Some(provider) => write!(f, "score --provider={}", provider),
                None => write!(f, "score"),
            },
            Command::Distill => write!(f, "distill"),
            Command::Eval(args) => match &args.predict.provider {
                Some(provider) => write!(f, "eval --provider={}", provider),
                None => write!(f, "eval"),
            },
            Command::Synthesize(args) => {
                write!(f, "synthesize --repos {}", args.repos.join(" "))
            }
            Command::Clean => write!(f, "clean"),
            Command::SplitCommit(_) => write!(f, "split-commit"),
            Command::Split(_) => write!(f, "split"),
            Command::FilterLanguages(_) => write!(f, "filter-languages"),
            Command::ImportBatch(args) => {
                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
            }
            Command::Qa(_) => write!(f, "qa"),
            Command::Repair(_) => write!(f, "repair"),
        }
    }
}

#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}

#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
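
/// Backend model used by the `teacher` family of prediction providers.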
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet45,
    Gpt52,
}

impl std::fmt::Display for TeacherBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
            TeacherBackend::Gpt52 => write!(f, "gpt52"),
        }
    }
}

impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Zeta-version-style name accepted as an alias for the Sonnet teacher.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}

impl TeacherBackend {
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}
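
/// Prediction provider selected with `--provider`/`-p`.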
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    RepairedTeacher(TeacherBackend),
    Repair,
}

impl Default for PredictionProvider {
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}

impl std::fmt::Display for PredictionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictionProvider::Sweep => write!(f, "sweep"),
            PredictionProvider::Mercury => write!(f, "mercury"),
            PredictionProvider::Zeta1 => write!(f, "zeta1"),
            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
            PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
            PredictionProvider::TeacherNonBatching(backend) => {
                write!(f, "teacher-non-batching:{backend}")
            }
            PredictionProvider::RepairedTeacher(backend) => {
                write!(f, "repaired-teacher:{backend}")
            }
            PredictionProvider::Repair => write!(f, "repair"),
        }
    }
}

impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repaired-teacher" | "repaired_teacher" | "repairedteacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::RepairedTeacher(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repaired-teacher, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}

impl Serialize for PredictionProvider {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for PredictionProvider {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}

#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}

impl EpArgs {
    fn output_path(&self) -> Option<PathBuf> {
        if self.in_place {
            if self.inputs.len() == 1 {
                self.inputs.first().cloned()
            } else {
                panic!("--in-place requires exactly one input file")
            }
        } else {
            self.output.clone()
        }
    }
}
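
/// Gather examples from file inputs and Snowflake specifiers, apply the global
/// name/repo/offset/limit filters, and drop examples already present in the
/// output file so interrupted runs can resume.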
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client,
                &requested_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut requested_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    if let Some(offset) = args.offset {
        examples.drain(..offset);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
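
/// Hash of an example's spec, used to match input examples against lines
/// already written to the output file.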
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
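
/// Rewrite the output file so it keeps only (deduplicated) lines whose spec
/// matches a pending input example, then remove those examples from the
/// pending set so they are not processed again.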
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                kept_hashes.insert(hash);
                kept_lines.push(line);
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}

fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.drain(..offset);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.drain(..offset);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                let markdown_output_dir = if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                    Some(dir.clone())
                } else {
                    None
                };

                // For --in-place, write to a temp file and rename it at the end
                // to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
                {
                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                    write_path.map(|path| {
                        let file = if args.in_place {
                            // For --in-place, write to the temp file (truncating if it exists)
                            OpenOptions::new()
                                .create(true)
                                .write(true)
                                .truncate(true)
                                .open(path)
                                .expect("Failed to open temp output file")
                        } else {
                            // For regular output, append to support resuming
                            OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                        };
                        let mut writer = BufWriter::new(file);
                        let (sender, mut receiver) = mpsc::unbounded::<String>();
                        cx.background_spawn(async move {
                            while let Some(line) = receiver.next().await {
                                writeln!(writer, "{}", line).expect("Failed to write example");
                                writer.flush().expect("Failed to flush output");
                            }
                        })
                        .detach();
                        sender
                    })
                } else {
                    None
                };

                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            // Handled earlier in main, before the app starts.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref markdown_dir) = markdown_output_dir {
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }

                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                if let Command::Eval(args) = &command {
                    let examples = finished_examples.lock().unwrap();
                    score::print_report(&examples);
                    if let Some(summary_path) = &args.summary_json {
                        score::write_summary_json(&examples, summary_path)?;
                    }
                }

                // For --in-place, atomically rename the temp file over the original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
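
/// Record a failed example: unless `--failed skip-no-files` is set, write the
/// example and its error to the run's failed/ directory and append the example
/// to failed.jsonl; then either panic (failfast) or log the error.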
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg = if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        )
    } else {
        format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        )
    };

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}
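
// A minimal test sketch for the parsers above, exercising only code defined in
// this file: `FromStr` aliases for `TeacherBackend`, and `Display`/`FromStr`
// round-trips for `PredictionProvider`. The `zeta2` variant is left out because
// its round-trip depends on the external `ZetaVersion` type. Assuming the
// package name from the re-run hint in `handle_error`, run these with
// `cargo test -p edit_prediction_cli`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn teacher_backend_aliases_parse() {
        assert_eq!(
            "sonnet".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Sonnet45
        );
        assert_eq!(
            "openai".parse::<TeacherBackend>().unwrap(),
            TeacherBackend::Gpt52
        );
        assert!("unknown-backend".parse::<TeacherBackend>().is_err());
    }

    #[test]
    fn prediction_provider_display_round_trips() {
        let providers = [
            PredictionProvider::Sweep,
            PredictionProvider::Teacher(TeacherBackend::Gpt52),
            PredictionProvider::TeacherNonBatching(TeacherBackend::Sonnet45),
            PredictionProvider::RepairedTeacher(TeacherBackend::Sonnet45),
            PredictionProvider::Repair,
        ];
        for provider in providers {
            let round_tripped: PredictionProvider = provider
                .to_string()
                .parse()
                .expect("Display output should parse back");
            assert_eq!(round_tripped, provider);
        }
    }
}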