main.rs

  1mod anthropic_client;
  2mod distill;
  3mod example;
  4mod filter_languages;
  5mod format_prompt;
  6mod git;
  7mod headless;
  8mod load_project;
  9mod metrics;
 10mod parse_output;
 11mod paths;
 12mod predict;
 13mod progress;
 14mod pull_examples;
 15mod qa;
 16mod reorder_patch;
 17mod retrieve_context;
 18mod score;
 19mod split_commit;
 20mod split_dataset;
 21mod synthesize;
 22mod word_diff;
 23use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 24use collections::HashSet;
 25use edit_prediction::EditPredictionStore;
 26use futures::channel::mpsc;
 27use futures::{SinkExt as _, StreamExt as _};
 28use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
 29use zeta_prompt::ZetaVersion;
 30
 31use reqwest_client::ReqwestClient;
 32use serde::{Deserialize, Deserializer, Serialize, Serializer};
 33use std::fmt::Display;
 34use std::fs::{File, OpenOptions};
 35use std::hash::{Hash, Hasher};
 36use std::io::{BufRead, BufReader, BufWriter, Write};
 37use std::sync::Mutex;
 38use std::{path::PathBuf, sync::Arc};
 39
 40use crate::distill::run_distill;
 41use crate::example::{Example, group_examples_by_repo, read_example_files};
 42use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
 43use crate::format_prompt::run_format_prompt;
 44use crate::load_project::run_load_project;
 45use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
 46use crate::predict::run_prediction;
 47use crate::progress::Progress;
 48use crate::retrieve_context::run_context_retrieval;
 49use crate::score::run_scoring;
 50use crate::split_commit::SplitCommitArgs;
 51use crate::split_dataset::SplitArgs;
 52use crate::synthesize::{SynthesizeConfig, run_synthesize};
 53
 54#[derive(Parser, Debug)]
 55#[command(name = "ep")]
 56struct EpArgs {
 57    #[arg(long, default_value_t = false)]
 58    printenv: bool,
 59    #[clap(long, default_value_t = 10, global = true)]
 60    max_parallelism: usize,
 61    #[clap(long, global = true)]
 62    limit: Option<usize>,
 63    #[clap(long, global = true)]
 64    offset: Option<usize>,
 65    /// Filter examples by name
 66    #[clap(long, global = true)]
 67    name: Option<String>,
 68    /// Filter examples by repository
 69    #[clap(long, global = true)]
 70    repo: Option<String>,
 71    #[command(subcommand)]
 72    command: Option<Command>,
 73    #[clap(global = true, help = INPUTS_HELP)]
 74    inputs: Vec<PathBuf>,
 75    #[arg(long, short, global = true)]
 76    output: Option<PathBuf>,
 77    #[arg(long, short, global = true)]
 78    in_place: bool,
 79    #[arg(long, short, global = true)]
 80    failfast: bool,
 81    /// How to handle failed examples in output: keep them or skip them.
 82    /// Failed examples are always logged to the run's failed directory.
 83    #[arg(long, global = true, default_value = "keep")]
 84    failed: FailedHandling,
 85}
 86
 /// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Exposed on the command line through `ValueEnum` as `keep`, `skip`, and `skip-no-files`
// (see `--failed`, whose default is "keep").
// NOTE(review): `SkipNoFiles` ("Skip writing files") suggests the failed/ directory logging
// can be turned off, which contradicts the sentence above; confirm against `handle_error`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Failed examples are also left out of the main output: `main` writes a failed example
    // only when the setting is `Keep`.
    SkipNoFiles,
}
 99
// Help text for the positional `inputs` argument. It documents both input forms accepted by
// `load_examples`: plain file paths and `captured-after:{timestamp}` specifiers.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to a file of examples (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      By default at most 5000 rows are fetched per timestamp. With --limit, the cap
      is reduced to the number of examples still needed after reading file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
129
// Subcommands of `ep`. Each variant's `///` doc is its `--help` text, and clap derives the
// kebab-case subcommand name from the variant name (matching `Command`'s `Display` output).
//
// `Clean`, `Synthesize`, `SplitCommit`, `Split`, `FilterLanguages`, `ImportBatch`, and `Qa`
// are handled up front in `main` and never reach the per-example pipeline (they are
// `unreachable!()` there); the remaining variants run once per loaded example.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    // No per-example work: examples are loaded and re-serialized to the output.
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Shares `PredictArgs` with `Predict`; its provider is also used to sync pending
    // batches before the examples are processed.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    // Per example this runs the same scoring as `Score`; without `--output`, examples are
    // not echoed to stdout for this command.
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    // Implemented by deleting the whole data directory.
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
}
167
168impl Display for Command {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        match self {
171            Command::ParseExample => write!(f, "parse-example"),
172            Command::LoadProject => write!(f, "load-project"),
173            Command::Context => write!(f, "context"),
174            Command::FormatPrompt(args) => {
175                write!(f, "format-prompt --provider={}", args.provider)
176            }
177            Command::Predict(args) => match &args.provider {
178                Some(provider) => write!(f, "predict --provider={}", provider),
179                None => write!(f, "predict"),
180            },
181            Command::ParseOutput => write!(f, "parse-output"),
182            Command::Score(args) => match &args.provider {
183                Some(provider) => write!(f, "score --provider={}", provider),
184                None => write!(f, "score"),
185            },
186            Command::Distill => write!(f, "distill"),
187            Command::Eval(args) => match &args.predict.provider {
188                Some(provider) => write!(f, "eval --provider={}", provider),
189                None => write!(f, "eval"),
190            },
191            Command::Synthesize(args) => {
192                write!(f, "synthesize --repos {}", args.repos.join(" "))
193            }
194            Command::Clean => write!(f, "clean"),
195            Command::SplitCommit(_) => write!(f, "split-commit"),
196            Command::Split(_) => write!(f, "split"),
197            Command::FilterLanguages(_) => write!(f, "filter-languages"),
198            Command::ImportBatch(args) => {
199                write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
200            }
201            Command::Qa(_) => {
202                write!(f, "qa")
203            }
204        }
205    }
206}
207
// Arguments for `ep format-prompt`. Unlike `PredictArgs`, the provider is always set: it
// defaults to `PredictionProvider::default()` (Zeta2 with the default `ZetaVersion`) and can be
// overridden with `--provider`/`-p`, parsed through `PredictionProvider`'s `FromStr` impl.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
213
// Arguments used by `ep predict` and `ep score`, and flattened into `ep eval`.
// The provider is optional; when it is omitted, `Command`'s `Display` prints just the
// subcommand name, and `predict::sync_batches` receives `None`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of predictions made per example; confirm in predict.rs.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
221
// Arguments for `ep eval`: the shared prediction/scoring options plus an optional JSON
// summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Flattened, so `--provider`/`--repetitions` are accepted directly by `ep eval`.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
230
/// The backend used to format prompts and produce predictions for an example.
///
/// Written and parsed as strings such as `sweep` or `zeta2:<version>` through this type's
/// `Display` and `FromStr` impls; serde uses the same string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    /// `sweep`
    Sweep,
    /// `mercury`
    Mercury,
    /// `zeta1`
    Zeta1,
    /// `zeta2[:version]`; the default provider.
    Zeta2(ZetaVersion),
    /// `teacher[:version]`
    Teacher(ZetaVersion),
    /// `teacher-non-batching[:version]` (parsing also accepts `teacher_non_batching` and
    /// `teachernonbatching`).
    TeacherNonBatching(ZetaVersion),
}
240
241impl Default for PredictionProvider {
242    fn default() -> Self {
243        PredictionProvider::Zeta2(ZetaVersion::default())
244    }
245}
246
247impl std::fmt::Display for PredictionProvider {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        match self {
250            PredictionProvider::Sweep => write!(f, "sweep"),
251            PredictionProvider::Mercury => write!(f, "mercury"),
252            PredictionProvider::Zeta1 => write!(f, "zeta1"),
253            PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
254            PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
255            PredictionProvider::TeacherNonBatching(version) => {
256                write!(f, "teacher-non-batching:{version}")
257            }
258        }
259    }
260}
261
262impl std::str::FromStr for PredictionProvider {
263    type Err = anyhow::Error;
264
265    fn from_str(mut s: &str) -> Result<Self, Self::Err> {
266        let mut version = ZetaVersion::default();
267        if let Some((first, second)) = s.split_once(':') {
268            version = ZetaVersion::parse(second)?;
269            s = first;
270        }
271
272        let s_lower = s.to_lowercase();
273        match s_lower.as_str() {
274            "sweep" => Ok(PredictionProvider::Sweep),
275            "mercury" => Ok(PredictionProvider::Mercury),
276            "zeta1" => Ok(PredictionProvider::Zeta1),
277            "zeta2" => Ok(PredictionProvider::Zeta2(version)),
278            "teacher" => Ok(PredictionProvider::Teacher(version)),
279            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
280                Ok(PredictionProvider::TeacherNonBatching(version))
281            }
282            _ => {
283                anyhow::bail!(
284                    "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
285                 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
286                 Available zeta versions:\n{}",
287                    ZetaVersion::options_as_string()
288                )
289            }
290        }
291    }
292}
293
294impl Serialize for PredictionProvider {
295    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296    where
297        S: Serializer,
298    {
299        serializer.serialize_str(&self.to_string())
300    }
301}
302
303impl<'de> Deserialize<'de> for PredictionProvider {
304    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
305    where
306        D: Deserializer<'de>,
307    {
308        let s = String::deserialize(deserializer)?;
309        s.parse().map_err(serde::de::Error::custom)
310    }
311}
312
// Arguments for `ep synthesize`. `main` copies them into a `SynthesizeConfig` and also
// requires the global `--output`, which is used as the output directory.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
331
// Arguments for `ep import-batch`. `main` handles this command before the app starts,
// importing the given batches through an Anthropic batch client backed by the LLM cache DB.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
338
339impl EpArgs {
340    fn output_path(&self) -> Option<PathBuf> {
341        if self.in_place {
342            if self.inputs.len() == 1 {
343                self.inputs.first().cloned()
344            } else {
345                panic!("--in-place requires exactly one input file")
346            }
347        } else {
348            self.output.clone()
349        }
350    }
351}
352
353async fn load_examples(
354    http_client: Arc<dyn http_client::HttpClient>,
355    args: &EpArgs,
356    output_path: Option<&PathBuf>,
357    background_executor: BackgroundExecutor,
358) -> anyhow::Result<Vec<Example>> {
359    let mut captured_after_timestamps = Vec::new();
360    let mut file_inputs = Vec::new();
361
362    for input in &args.inputs {
363        let input_string = input.to_string_lossy();
364        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
365            captured_after_timestamps.push(timestamp.to_string());
366        } else {
367            file_inputs.push(input.clone());
368        }
369    }
370
371    let mut examples = read_example_files(&file_inputs);
372
373    Progress::global().set_total_examples(examples.len());
374
375    let remaining_limit_for_snowflake =
376        args.limit.map(|limit| limit.saturating_sub(examples.len()));
377
378    if let Some(0) = remaining_limit_for_snowflake {
379        log::info!(
380            "skipping captured-after inputs because --limit is already satisfied by example files"
381        );
382    } else if !captured_after_timestamps.is_empty() {
383        captured_after_timestamps.sort();
384
385        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
386
387        let mut captured_examples = pull_examples::fetch_captured_examples_after(
388            http_client,
389            &captured_after_timestamps,
390            max_rows_per_timestamp,
391            background_executor,
392        )
393        .await?;
394        examples.append(&mut captured_examples);
395    }
396
397    crate::example::sort_examples_by_repo_and_rev(&mut examples);
398
399    if let Some(name_filter) = &args.name {
400        examples.retain(|example| example.spec.name.contains(name_filter));
401    }
402    if let Some(repo_filter) = &args.repo {
403        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
404    }
405
406    // Skip resume logic for --in-place since input and output are the same file,
407    // which would incorrectly treat all input examples as already processed.
408    if !args.in_place {
409        if let Some(path) = output_path {
410            resume_from_output(path, &mut examples);
411        }
412    }
413
414    if let Some(offset) = args.offset {
415        examples.splice(0..offset, []);
416    }
417
418    if let Some(limit) = args.limit {
419        examples.truncate(limit);
420    }
421
422    let progress = Progress::global();
423    progress.set_total_examples(examples.len());
424    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
425
426    Ok(examples)
427}
428
429fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
430    let mut hasher = collections::FxHasher::default();
431    spec.hash(&mut hasher);
432    hasher.finish()
433}
434
435fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
436    let file = match File::open(path) {
437        Ok(f) => f,
438        Err(_) => return,
439    };
440
441    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
442
443    let reader = BufReader::new(file);
444    let mut kept_lines = Vec::new();
445    let mut kept_hashes = HashSet::default();
446
447    for line in reader.lines() {
448        let line = match line {
449            Ok(l) => l,
450            Err(_) => continue,
451        };
452
453        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
454            let hash = spec_hash(&output_example.spec);
455            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
456                kept_hashes.insert(hash);
457                kept_lines.push(line);
458            }
459        }
460    }
461
462    let total = examples.len();
463    let already_processed = kept_hashes.len();
464
465    eprintln!(
466        "Resuming: {}/{} examples already processed",
467        already_processed, total
468    );
469
470    let file = OpenOptions::new()
471        .write(true)
472        .truncate(true)
473        .open(path)
474        .expect("Failed to open output file for rewriting");
475    let mut writer = BufWriter::new(file);
476    for line in &kept_lines {
477        writeln!(writer, "{}", line).expect("Failed to write to output file");
478    }
479    writer.flush().expect("Failed to flush output file");
480
481    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
482}
483
484fn main() {
485    let args = EpArgs::parse();
486
487    if args.printenv {
488        ::util::shell_env::print_env();
489        return;
490    }
491
492    let output = args.output_path();
493    let command = match &args.command {
494        Some(cmd) => cmd.clone(),
495        None => {
496            EpArgs::command().print_help().unwrap();
497            return;
498        }
499    };
500
501    match &command {
502        Command::ImportBatch(import_args) => {
503            smol::block_on(async {
504                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
505                    .expect("Failed to create Anthropic client");
506                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
507                    eprintln!("Error importing batches: {:?}", e);
508                    std::process::exit(1);
509                }
510                println!(
511                    "Successfully imported {} batch(es)",
512                    import_args.batch_ids.len()
513                );
514            });
515            return;
516        }
517        Command::Clean => {
518            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
519            return;
520        }
521        Command::Synthesize(synth_args) => {
522            let Some(output_dir) = args.output else {
523                panic!("output dir is required");
524            };
525            let config = SynthesizeConfig {
526                repo_urls: synth_args.repos.clone(),
527                count: synth_args.count,
528                max_commits: synth_args.max_commits,
529                output_dir,
530                fresh: synth_args.fresh,
531            };
532            smol::block_on(async {
533                if let Err(e) = run_synthesize(config).await {
534                    eprintln!("Error: {:?}", e);
535                    std::process::exit(1);
536                }
537            });
538            return;
539        }
540        Command::SplitCommit(split_commit_args) => {
541            if let Err(error) = split_commit::run_split_commit(
542                split_commit_args,
543                &args.inputs,
544                output.as_ref(),
545                args.failed,
546            ) {
547                eprintln!("{error:#}");
548                std::process::exit(1);
549            }
550            return;
551        }
552        Command::Split(split_args) => {
553            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
554                eprintln!("{error:#}");
555                std::process::exit(1);
556            }
557            return;
558        }
559        Command::FilterLanguages(filter_args) => {
560            if let Err(error) =
561                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
562            {
563                eprintln!("{error:#}");
564                std::process::exit(1);
565            }
566            return;
567        }
568        Command::Qa(qa_args) => {
569            // Read examples from input files
570            let mut examples = example::read_example_files(&args.inputs);
571
572            // Apply filters
573            if let Some(name_filter) = &args.name {
574                examples.retain(|e| e.spec.name.contains(name_filter));
575            }
576            if let Some(repo_filter) = &args.repo {
577                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
578            }
579            if let Some(offset) = args.offset {
580                examples.splice(0..offset, []);
581            }
582            if let Some(limit) = args.limit {
583                examples.truncate(limit);
584            }
585
586            smol::block_on(async {
587                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
588                    eprintln!("Error: {:?}", e);
589                    std::process::exit(1);
590                }
591            });
592            return;
593        }
594        _ => {}
595    }
596
597    let http_client = Arc::new(ReqwestClient::new());
598    let app = Application::headless().with_http_client(http_client);
599
600    app.run(move |cx| {
601        let app_state = Arc::new(headless::init(cx));
602        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
603
604        cx.spawn(async move |cx| {
605            let result = async {
606                let examples = load_examples(
607                    app_state.client.http_client(),
608                    &args,
609                    output.as_ref(),
610                    cx.background_executor().clone(),
611                )
612                .await?;
613
614                match &command {
615                    Command::Predict(args) | Command::Score(args) => {
616                        predict::sync_batches(args.provider.as_ref()).await?;
617                    }
618                    Command::Eval(args) => {
619                        predict::sync_batches(args.predict.provider.as_ref()).await?;
620                    }
621                    _ => (),
622                }
623
624                let failfast_on_single_example = examples.len() == 1;
625
626                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
627                let in_place_temp_path = if args.in_place {
628                    output.as_ref().map(|path| {
629                        let mut temp_path = path.clone();
630                        temp_path.set_extension("jsonl.tmp");
631                        temp_path
632                    })
633                } else {
634                    None
635                };
636
637                let output_sender: Option<mpsc::UnboundedSender<String>> =
638                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
639                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
640                        write_path.map(|path| {
641                            let file = if args.in_place {
642                                // For --in-place, write to temp file (truncate if exists)
643                                OpenOptions::new()
644                                    .create(true)
645                                    .write(true)
646                                    .truncate(true)
647                                    .open(path)
648                                    .expect("Failed to open temp output file")
649                            } else {
650                                // For regular output, append to support resuming
651                                OpenOptions::new()
652                                    .create(true)
653                                    .append(true)
654                                    .open(path)
655                                    .expect("Failed to open output file")
656                            };
657                            let mut writer = BufWriter::new(file);
658                            let (sender, mut receiver) = mpsc::unbounded::<String>();
659                            cx.background_spawn(async move {
660                                while let Some(line) = receiver.next().await {
661                                    writeln!(writer, "{}", line).expect("Failed to write example");
662                                    writer.flush().expect("Failed to flush output");
663                                }
664                            })
665                            .detach();
666                            sender
667                        })
668                    } else {
669                        None
670                    };
671
672                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
673                let finished_examples = Mutex::new(Vec::new());
674
675                let mut tasks = Vec::new();
676                for _ in 0..args.max_parallelism {
677                    tasks.push(async {
678                        loop {
679                            let Some(mut repo_examples) =
680                                grouped_examples.lock().unwrap().pop_front()
681                            else {
682                                break;
683                            };
684                            for example in &mut repo_examples {
685                                let example_progress =
686                                    Progress::global().start_group(&example.spec.name);
687
688                                let result = async {
689                                    match &command {
690                                        Command::ParseExample => {}
691                                        Command::LoadProject => {
692                                            run_load_project(
693                                                example,
694                                                app_state.clone(),
695                                                &example_progress,
696                                                cx.clone(),
697                                            )
698                                            .await?;
699                                        }
700                                        Command::Context => {
701                                            run_context_retrieval(
702                                                example,
703                                                app_state.clone(),
704                                                &example_progress,
705                                                cx.clone(),
706                                            )
707                                            .await?;
708                                        }
709                                        Command::FormatPrompt(args) => {
710                                            run_format_prompt(
711                                                example,
712                                                args,
713                                                app_state.clone(),
714                                                &example_progress,
715                                                cx.clone(),
716                                            )
717                                            .await?;
718                                        }
719                                        Command::Predict(args) => {
720                                            run_prediction(
721                                                example,
722                                                args,
723                                                app_state.clone(),
724                                                &example_progress,
725                                                cx.clone(),
726                                            )
727                                            .await?;
728                                        }
729                                        Command::ParseOutput => {
730                                            parse_output::run_parse_output(example)?;
731                                        }
732                                        Command::Distill => {
733                                            run_distill(example).await?;
734                                        }
735                                        Command::Score(args) => {
736                                            run_scoring(
737                                                example,
738                                                args,
739                                                app_state.clone(),
740                                                &example_progress,
741                                                cx.clone(),
742                                            )
743                                            .await?;
744                                        }
745                                        Command::Eval(args) => {
746                                            run_scoring(
747                                                example,
748                                                &args.predict,
749                                                app_state.clone(),
750                                                &example_progress,
751                                                cx.clone(),
752                                            )
753                                            .await?;
754                                        }
755                                        Command::Clean
756                                        | Command::Synthesize(_)
757                                        | Command::SplitCommit(_)
758                                        | Command::Split(_)
759                                        | Command::FilterLanguages(_)
760                                        | Command::ImportBatch(_)
761                                        | Command::Qa(_) => {
762                                            unreachable!()
763                                        }
764                                    }
765                                    anyhow::Ok(())
766                                }
767                                .await;
768
769                                let failed = if let Err(error) = result {
770                                    handle_error(
771                                        error,
772                                        &args,
773                                        &command,
774                                        &app_state,
775                                        failfast_on_single_example,
776                                        &example,
777                                    )
778                                    .await;
779                                    true
780                                } else {
781                                    false
782                                };
783
784                                let should_write = !failed || args.failed == FailedHandling::Keep;
785                                if should_write {
786                                    if let Some(ref mut sender) = output_sender.clone() {
787                                        let line = serde_json::to_string(&example).unwrap();
788                                        sender
789                                            .send(line)
790                                            .await
791                                            .expect("Failed to send to output writer");
792                                    } else if args.output.is_none()
793                                        && !matches!(command, Command::Eval(_))
794                                    {
795                                        let line = serde_json::to_string(&example).unwrap();
796                                        println!("{}", line);
797                                    }
798                                }
799                            }
800
801                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
802                            let project = repo_examples
803                                .iter()
804                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
805                                .or_else(|| app_state.project_cache.get(repo_url));
806
807                            if let Some(project) = project {
808                                let mut cx = cx.clone();
809
810                                let shutdown_task: Task<()> =
811                                    project.update(&mut cx, |project, cx| {
812                                        let lsp_store = project.lsp_store();
813                                        lsp_store.update(cx, |lsp_store, cx| {
814                                            lsp_store.shutdown_all_language_servers(cx)
815                                        })
816                                    });
817
818                                shutdown_task.await;
819
820                                if let Some(ep_store) =
821                                    cx.update(|cx| EditPredictionStore::try_global(cx))
822                                {
823                                    ep_store.update(&mut cx, |store, _| {
824                                        store.remove_project(&project);
825                                    });
826                                }
827                            }
828
829                            app_state.project_cache.remove(repo_url);
830                            for example in &mut repo_examples {
831                                example.state.take();
832                            }
833                            finished_examples
834                                .lock()
835                                .unwrap()
836                                .extend_from_slice(&repo_examples);
837                        }
838                    });
839                }
840                futures::future::join_all(tasks).await;
841
842                Progress::global().finalize();
843
844                match &command {
845                    Command::Predict(args) | Command::Score(args) => {
846                        predict::sync_batches(args.provider.as_ref()).await?;
847                    }
848                    Command::Eval(args) => {
849                        predict::sync_batches(args.predict.provider.as_ref()).await?;
850                    }
851                    _ => (),
852                }
853
854                match &command {
855                    Command::Eval(args) => {
856                        let examples = finished_examples.lock().unwrap();
857                        score::print_report(&examples);
858                        if let Some(summary_path) = &args.summary_json {
859                            score::write_summary_json(&examples, summary_path)?;
860                        }
861                    }
862                    _ => (),
863                };
864
865                // For --in-place, atomically rename temp file to original
866                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
867                    std::fs::rename(temp_path, final_path)
868                        .expect("Failed to rename temp file to final output");
869                }
870
871                anyhow::Ok(())
872            }
873            .await;
874
875            if let Err(e) = result {
876                panic!("Fatal error: {:?}", e);
877            }
878
879            let _ = cx.update(|cx| cx.quit());
880        })
881        .detach();
882    });
883}
884
885async fn handle_error(
886    error: anyhow::Error,
887    args: &EpArgs,
888    command: &Command,
889    app_state: &Arc<headless::EpAppState>,
890    failfast_on_single_example: bool,
891    example: &Example,
892) {
893    Progress::global().increment_failed();
894
895    let msg;
896    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
897        let example_name = example.spec.filename();
898
899        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
900        app_state
901            .fs
902            .write(
903                &failed_example_path,
904                &serde_json::to_vec_pretty(&example).unwrap(),
905            )
906            .await
907            .unwrap();
908        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
909        app_state
910            .fs
911            .write(&err_path, format!("{error:?}").as_bytes())
912            .await
913            .unwrap();
914
915        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
916        let mut file = OpenOptions::new()
917            .create(true)
918            .append(true)
919            .open(&failed_jsonl_path)
920            .expect("Failed to open failed.jsonl");
921        writeln!(file, "{}", serde_json::to_string(example).unwrap())
922            .expect("Failed to write to failed.jsonl");
923
924        let cursor_path = example
925            .repo_name()
926            .unwrap()
927            .worktree_path()
928            .join(&example.spec.cursor_path);
929        msg = format!(
930            indoc::indoc! {"
931                While processing \"{}\":
932
933                \x1b[31m{:?}\x1b[0m
934
935                Example:        \x1b[36m{}\x1b[0m
936                Error file:     \x1b[36m{}\x1b[0m
937                Cursor file:    \x1b[36m{}\x1b[0m
938                Re-run:         cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
939            "},
940            example.spec.name,
941            error,
942            failed_example_path.display(),
943            err_path.display(),
944            cursor_path.display(),
945            command,
946            failed_example_path.display(),
947        );
948    } else {
949        msg = format!(
950            indoc::indoc! {"
951            While processing \"{}\":
952
953                \x1b[31m{:?}\x1b[0m
954            "},
955            example.spec.name, error
956        );
957    }
958
959    if args.failfast || failfast_on_single_example {
960        Progress::global().finalize();
961        panic!("{}", msg);
962    } else {
963        log::error!("{}", msg);
964    }
965}