1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod retrieve_context;
18mod score;
19mod split_commit;
20mod split_dataset;
21mod synthesize;
22mod word_diff;
23use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
24use collections::HashSet;
25use edit_prediction::EditPredictionStore;
26use futures::channel::mpsc;
27use futures::{SinkExt as _, StreamExt as _};
28use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
29use zeta_prompt::ZetaVersion;
30
31use reqwest_client::ReqwestClient;
32use serde::{Deserialize, Deserializer, Serialize, Serializer};
33use std::fmt::Display;
34use std::fs::{File, OpenOptions};
35use std::hash::{Hash, Hasher};
36use std::io::{BufRead, BufReader, BufWriter, Write};
37use std::sync::Mutex;
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
39
40use crate::distill::run_distill;
41use crate::example::{Example, group_examples_by_repo, read_example_files};
42use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
43use crate::format_prompt::run_format_prompt;
44use crate::load_project::run_load_project;
45use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
46use crate::predict::run_prediction;
47use crate::progress::Progress;
48use crate::retrieve_context::run_context_retrieval;
49use crate::score::run_scoring;
50use crate::split_commit::SplitCommitArgs;
51use crate::split_dataset::SplitArgs;
52use crate::synthesize::{SynthesizeConfig, run_synthesize};
53
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// appear before or after the subcommand. Reviewer notes below use `//` (not
// `///`) deliberately: clap turns doc comments on fields into --help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process repository groups concurrently
    // (see the task loop in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples; also caps Snowflake fetches
    // (see `load_examples`).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Drop this many examples from the front of the sorted/filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:`/`rejected-after:`
    // specifiers (full syntax documented in INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; when absent, most commands stream JSONL to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (writes a temp file, then renames;
    // see `main` and `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
86
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Variant doc comments double as --help text via ValueEnum; keep them short.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    // NOTE(review): unlike the enum-level doc above suggests, this variant also
    // suppresses the failed/ directory artifacts — see `handle_error`.
    /// Skip writing files
    SkipNoFiles,
}
99
/// Long help text attached to the positional `inputs` argument of [`EpArgs`].
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
135
// Subcommands of `ep`. The `Display` impl below renders the canonical CLI
// invocation for each variant (used in failure "Re-run" hints) — keep the two
// in sync when adding variants. Doc comments here become --help text.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
}
173
174impl Display for Command {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 match self {
177 Command::Read => write!(f, "read"),
178 Command::LoadProject => write!(f, "load-project"),
179 Command::Context => write!(f, "context"),
180 Command::FormatPrompt(args) => {
181 write!(f, "format-prompt --provider={}", args.provider)
182 }
183 Command::Predict(args) => match &args.provider {
184 Some(provider) => write!(f, "predict --provider={}", provider),
185 None => write!(f, "predict"),
186 },
187 Command::ParseOutput => write!(f, "parse-output"),
188 Command::Score(args) => match &args.provider {
189 Some(provider) => write!(f, "score --provider={}", provider),
190 None => write!(f, "score"),
191 },
192 Command::Distill => write!(f, "distill"),
193 Command::Eval(args) => match &args.predict.provider {
194 Some(provider) => write!(f, "eval --provider={}", provider),
195 None => write!(f, "eval"),
196 },
197 Command::Synthesize(args) => {
198 write!(f, "synthesize --repos {}", args.repos.join(" "))
199 }
200 Command::Clean => write!(f, "clean"),
201 Command::SplitCommit(_) => write!(f, "split-commit"),
202 Command::Split(_) => write!(f, "split"),
203 Command::FilterLanguages(_) => write!(f, "filter-languages"),
204 Command::ImportBatch(args) => {
205 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
206 }
207 Command::Qa(_) => {
208 write!(f, "qa")
209 }
210 }
211 }
212}
213
// Arguments for the `format-prompt` subcommand. (`//` comments on purpose:
// doc comments on clap fields would alter --help output.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to
    // `PredictionProvider::default()` (zeta2 at the default ZetaVersion).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
219
// Arguments shared by `predict`, `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; optional. NOTE(review): the semantics of `None`
    // are decided downstream in `predict`/`score` — confirm there before
    // documenting further.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many prediction runs to perform per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
227
// Arguments for the `eval` subcommand: prediction settings plus an optional
// JSON summary destination (written by `score::write_summary_json`).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
236
/// Supported edit-prediction backends.
///
/// Versioned variants carry a [`ZetaVersion`] parsed from the `name:version`
/// CLI syntax; see the `FromStr`/`Display` impls below for the exact grammar.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
246
247impl Default for PredictionProvider {
248 fn default() -> Self {
249 PredictionProvider::Zeta2(ZetaVersion::default())
250 }
251}
252
253impl std::fmt::Display for PredictionProvider {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 match self {
256 PredictionProvider::Sweep => write!(f, "sweep"),
257 PredictionProvider::Mercury => write!(f, "mercury"),
258 PredictionProvider::Zeta1 => write!(f, "zeta1"),
259 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
260 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
261 PredictionProvider::TeacherNonBatching(version) => {
262 write!(f, "teacher-non-batching:{version}")
263 }
264 }
265 }
266}
267
268impl std::str::FromStr for PredictionProvider {
269 type Err = anyhow::Error;
270
271 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
272 let mut version = ZetaVersion::default();
273 if let Some((first, second)) = s.split_once(':') {
274 version = ZetaVersion::parse(second)?;
275 s = first;
276 }
277
278 let s_lower = s.to_lowercase();
279 match s_lower.as_str() {
280 "sweep" => Ok(PredictionProvider::Sweep),
281 "mercury" => Ok(PredictionProvider::Mercury),
282 "zeta1" => Ok(PredictionProvider::Zeta1),
283 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
284 "teacher" => Ok(PredictionProvider::Teacher(version)),
285 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
286 Ok(PredictionProvider::TeacherNonBatching(version))
287 }
288 _ => {
289 anyhow::bail!(
290 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
291 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
292 Available zeta versions:\n{}",
293 ZetaVersion::options_as_string()
294 )
295 }
296 }
297 }
298}
299
300impl Serialize for PredictionProvider {
301 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
302 where
303 S: Serializer,
304 {
305 serializer.serialize_str(&self.to_string())
306 }
307}
308
309impl<'de> Deserialize<'de> for PredictionProvider {
310 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311 where
312 D: Deserializer<'de>,
313 {
314 let s = String::deserialize(deserializer)?;
315 s.parse().map_err(serde::de::Error::custom)
316 }
317}
318
// Arguments for the `synthesize` subcommand (see `run_synthesize`).
// Field doc comments are user-facing --help text; do not edit casually.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
337
// Arguments for the `import-batch` subcommand (see the ImportBatch arm in
// `main`, which forwards these IDs to `AnthropicClient::import_batches`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
344
345impl EpArgs {
346 fn output_path(&self) -> Option<PathBuf> {
347 if self.in_place {
348 if self.inputs.len() == 1 {
349 self.inputs.first().cloned()
350 } else {
351 panic!("--in-place requires exactly one input file")
352 }
353 } else {
354 self.output.clone()
355 }
356 }
357}
358
/// Gathers the examples to process from the CLI inputs.
///
/// Plain paths are read as example files; `captured-after:` /
/// `rejected-after:` specifiers are fetched from Snowflake. The merged list is
/// then sorted by repo/rev, filtered by `--name`/`--repo`, pruned of examples
/// already present in the output file (resume support), and finally windowed
/// by `--offset`/`--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs into Snowflake specifiers vs. plain file paths.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Provisional total so progress output is meaningful during the Snowflake
    // fetches; the final total is set again at the end of this function.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    // If the file inputs already satisfy --limit, skip the network fetches.
    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Per-timestamp row cap: the remaining --limit budget, or 5000.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                background_executor.clone(),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client,
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                background_executor,
            )
            .await?;
            examples.append(&mut rejected_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    // NOTE(review): `splice` panics if --offset exceeds the number of
    // remaining examples — confirm that is acceptable for a CLI tool.
    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
454
/// Stable fingerprint of an example spec, used by `resume_from_output` to
/// match in-memory examples against previously written output lines.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
460
461fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
462 let file = match File::open(path) {
463 Ok(f) => f,
464 Err(_) => return,
465 };
466
467 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
468
469 let reader = BufReader::new(file);
470 let mut kept_lines = Vec::new();
471 let mut kept_hashes = HashSet::default();
472
473 for line in reader.lines() {
474 let line = match line {
475 Ok(l) => l,
476 Err(_) => continue,
477 };
478
479 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
480 let hash = spec_hash(&output_example.spec);
481 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
482 kept_hashes.insert(hash);
483 kept_lines.push(line);
484 }
485 }
486 }
487
488 let total = examples.len();
489 let already_processed = kept_hashes.len();
490
491 eprintln!(
492 "Resuming: {}/{} examples already processed",
493 already_processed, total
494 );
495
496 let file = OpenOptions::new()
497 .write(true)
498 .truncate(true)
499 .open(path)
500 .expect("Failed to open output file for rewriting");
501 let mut writer = BufWriter::new(file);
502 for line in &kept_lines {
503 writeln!(writer, "{}", line).expect("Failed to write to output file");
504 }
505 writer.flush().expect("Failed to flush output file");
506
507 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
508}
509
/// Entry point. Self-contained subcommands are dispatched synchronously and
/// return early; the remaining commands run the shared per-example pipeline
/// inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that do not need the headless app or the example pipeline.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes all checked-out repos/worktrees under the data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands (Read / LoadProject / Context / FormatPrompt /
    // Predict / ParseOutput / Score / Distill / Eval) share the pipeline below.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before starting work.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // With exactly one example, fail loudly even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer task; workers send finished examples
                // as JSONL lines over this channel. Eval without --output gets
                // no writer (it only prints a report at the end).
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue grouped by repository, so each worker processes a
                // whole repo's examples before tearing its project down.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example step for the active
                                // command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were all dispatched (with an
                                        // early return) before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Per-repo teardown: shut down language servers,
                            // deregister the project from the prediction store,
                            // and drop cached project state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any provider batches that accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
910
911async fn handle_error(
912 error: anyhow::Error,
913 args: &EpArgs,
914 command: &Command,
915 app_state: &Arc<headless::EpAppState>,
916 failfast_on_single_example: bool,
917 example: &Example,
918) {
919 Progress::global().increment_failed();
920
921 let msg;
922 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
923 let example_name = example.spec.filename();
924
925 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
926 app_state
927 .fs
928 .write(
929 &failed_example_path,
930 &serde_json::to_vec_pretty(&example).unwrap(),
931 )
932 .await
933 .unwrap();
934 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
935 app_state
936 .fs
937 .write(&err_path, format!("{error:?}").as_bytes())
938 .await
939 .unwrap();
940
941 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
942 let mut file = OpenOptions::new()
943 .create(true)
944 .append(true)
945 .open(&failed_jsonl_path)
946 .expect("Failed to open failed.jsonl");
947 writeln!(file, "{}", serde_json::to_string(example).unwrap())
948 .expect("Failed to write to failed.jsonl");
949
950 let cursor_path = example
951 .repo_name()
952 .unwrap()
953 .worktree_path()
954 .join(&example.spec.cursor_path);
955 msg = format!(
956 indoc::indoc! {"
957 While processing \"{}\":
958
959 \x1b[31m{:?}\x1b[0m
960
961 Example: \x1b[36m{}\x1b[0m
962 Error file: \x1b[36m{}\x1b[0m
963 Cursor file: \x1b[36m{}\x1b[0m
964 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
965 "},
966 example.spec.name,
967 error,
968 failed_example_path.display(),
969 err_path.display(),
970 cursor_path.display(),
971 command,
972 failed_example_path.display(),
973 );
974 } else {
975 msg = format!(
976 indoc::indoc! {"
977 While processing \"{}\":
978
979 \x1b[31m{:?}\x1b[0m
980 "},
981 example.spec.name, error
982 );
983 }
984
985 if args.failfast || failfast_on_single_example {
986 Progress::global().finalize();
987 panic!("{}", msg);
988 } else {
989 log::error!("{}", msg);
990 }
991}