1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod reorder_patch;
16mod retrieve_context;
17mod score;
18mod split_commit;
19mod split_dataset;
20mod synthesize;
21use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
22use collections::HashSet;
23use edit_prediction::EditPredictionStore;
24use futures::channel::mpsc;
25use futures::{SinkExt as _, StreamExt as _};
26use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
27use zeta_prompt::ZetaVersion;
28
29use reqwest_client::ReqwestClient;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use std::fmt::Display;
32use std::fs::{File, OpenOptions};
33use std::hash::{Hash, Hasher};
34use std::io::{BufRead, BufReader, BufWriter, Write};
35use std::sync::Mutex;
36use std::{path::PathBuf, sync::Arc};
37
38use crate::distill::run_distill;
39use crate::example::{Example, group_examples_by_repo, read_example_files};
40use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
41use crate::format_prompt::run_format_prompt;
42use crate::load_project::run_load_project;
43use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
44use crate::predict::run_prediction;
45use crate::progress::Progress;
46use crate::retrieve_context::run_context_retrieval;
47use crate::score::run_scoring;
48use crate::split_commit::SplitCommitArgs;
49use crate::split_dataset::SplitArgs;
50use crate::synthesize::{SynthesizeConfig, run_synthesize};
51
#[derive(Parser, Debug)]
#[command(name = "ep")]
// Top-level CLI arguments for the `ep` tool. Fields marked `global = true`
// are accepted by every subcommand. NOTE: `///` doc comments on clap fields
// become user-visible help text, so review notes here use `//` instead.
struct EpArgs {
    // Print the captured shell environment and exit (debugging aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker claims one repository
    // group at a time (see the worker loop in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on examples processed; also bounds the Snowflake fetch in
    // `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied after filtering/resume).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path; for some commands this is required (see `main`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (written via temp file + rename).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run when any example fails (see `handle_error`).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
84
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
// NOTE: variant `///` docs double as the `--failed` value help via `ValueEnum`,
// so additional review notes below use `//` comments.
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (json/err files and the msg paths) written by `handle_error`.
    SkipNoFiles,
}
97
// Long help text attached to the positional `inputs` argument of `EpArgs`.
// Raw string literal: clap renders it verbatim, so the content must not be
// reflowed or reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
127
#[derive(Subcommand, Debug, Clone)]
// Subcommands of `ep`. Variant `///` docs are surfaced by clap as subcommand
// help, so they are left untouched. ImportBatch, Clean, Synthesize,
// SplitCommit, Split, and FilterLanguages are handled synchronously in `main`
// before the headless app starts; all other variants run through the
// per-example worker pipeline.
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
}
163
164impl Display for Command {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Command::ParseExample => write!(f, "parse-example"),
168 Command::LoadProject => write!(f, "load-project"),
169 Command::Context => write!(f, "context"),
170 Command::FormatPrompt(args) => {
171 write!(f, "format-prompt --provider={}", args.provider)
172 }
173 Command::Predict(args) => match &args.provider {
174 Some(provider) => write!(f, "predict --provider={}", provider),
175 None => write!(f, "predict"),
176 },
177 Command::ParseOutput => write!(f, "parse-output"),
178 Command::Score(args) => match &args.provider {
179 Some(provider) => write!(f, "score --provider={}", provider),
180 None => write!(f, "score"),
181 },
182 Command::Distill => write!(f, "distill"),
183 Command::Eval(args) => match &args.provider {
184 Some(provider) => write!(f, "eval --provider={}", provider),
185 None => write!(f, "eval"),
186 },
187 Command::Synthesize(args) => {
188 write!(f, "synthesize --repos {}", args.repos.join(" "))
189 }
190 Command::Clean => write!(f, "clean"),
191 Command::SplitCommit(_) => write!(f, "split-commit"),
192 Command::Split(_) => write!(f, "split"),
193 Command::FilterLanguages(_) => write!(f, "filter-languages"),
194 Command::ImportBatch(args) => {
195 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
196 }
197 }
198 }
199}
200
#[derive(Debug, Args, Clone)]
// Arguments for `format-prompt`. Unlike `PredictArgs`, the provider is
// mandatory here and defaults to `PredictionProvider::default()` (zeta2 at
// the default ZetaVersion).
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
206
#[derive(Debug, Args, Clone)]
// Shared by the `predict`, `score`, and `eval` subcommands. A `None` provider
// is passed through to `predict::sync_batches` / the runners as-is.
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm against `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
214
// Prediction backends selectable via `--provider`. Values are parsed from
// strings like `zeta2` or `teacher:<version>` (see the `FromStr` impl) and
// rendered back in the same `name[:version]` form by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // The default provider (see the `Default` impl below).
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably `Teacher` without the Anthropic batch API —
    // confirm against the predict module.
    TeacherNonBatching(ZetaVersion),
}
224
225impl Default for PredictionProvider {
226 fn default() -> Self {
227 PredictionProvider::Zeta2(ZetaVersion::default())
228 }
229}
230
231impl std::fmt::Display for PredictionProvider {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 PredictionProvider::Sweep => write!(f, "sweep"),
235 PredictionProvider::Mercury => write!(f, "mercury"),
236 PredictionProvider::Zeta1 => write!(f, "zeta1"),
237 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
238 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
239 PredictionProvider::TeacherNonBatching(version) => {
240 write!(f, "teacher-non-batching:{version}")
241 }
242 }
243 }
244}
245
246impl std::str::FromStr for PredictionProvider {
247 type Err = anyhow::Error;
248
249 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
250 let mut version = ZetaVersion::default();
251 if let Some((first, second)) = s.split_once(':') {
252 version = ZetaVersion::parse(second)?;
253 s = first;
254 }
255
256 let s_lower = s.to_lowercase();
257 match s_lower.as_str() {
258 "sweep" => Ok(PredictionProvider::Sweep),
259 "mercury" => Ok(PredictionProvider::Mercury),
260 "zeta1" => Ok(PredictionProvider::Zeta1),
261 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
262 "teacher" => Ok(PredictionProvider::Teacher(version)),
263 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
264 Ok(PredictionProvider::TeacherNonBatching(version))
265 }
266 _ => {
267 anyhow::bail!(
268 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
269 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
270 Available zeta versions:\n{}",
271 ZetaVersion::options_as_string()
272 )
273 }
274 }
275 }
276}
277
278impl Serialize for PredictionProvider {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: Serializer,
282 {
283 serializer.serialize_str(&self.to_string())
284 }
285}
286
287impl<'de> Deserialize<'de> for PredictionProvider {
288 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 let s = String::deserialize(deserializer)?;
293 s.parse().map_err(serde::de::Error::custom)
294 }
295}
296
#[derive(Debug, Args, Clone)]
// Arguments for `synthesize` (see `run_synthesize`). The output directory is
// taken from the global `--output` flag, which `main` requires for this
// command. Field `///` docs are clap help text and left untouched.
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
315
#[derive(Debug, Args, Clone)]
// Arguments for `import-batch`, which runs standalone in `main` via
// `AnthropicClient::import_batches` against the LLM cache database.
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
322
323impl EpArgs {
324 fn output_path(&self) -> Option<PathBuf> {
325 if self.in_place {
326 if self.inputs.len() == 1 {
327 self.inputs.first().cloned()
328 } else {
329 panic!("--in-place requires exactly one input file")
330 }
331 } else {
332 self.output.clone()
333 }
334 }
335}
336
337async fn load_examples(
338 http_client: Arc<dyn http_client::HttpClient>,
339 args: &EpArgs,
340 output_path: Option<&PathBuf>,
341 background_executor: BackgroundExecutor,
342) -> anyhow::Result<Vec<Example>> {
343 let mut captured_after_timestamps = Vec::new();
344 let mut file_inputs = Vec::new();
345
346 for input in &args.inputs {
347 let input_string = input.to_string_lossy();
348 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
349 captured_after_timestamps.push(timestamp.to_string());
350 } else {
351 file_inputs.push(input.clone());
352 }
353 }
354
355 let mut examples = read_example_files(&file_inputs);
356
357 Progress::global().set_total_examples(examples.len());
358
359 let remaining_limit_for_snowflake =
360 args.limit.map(|limit| limit.saturating_sub(examples.len()));
361
362 if let Some(0) = remaining_limit_for_snowflake {
363 log::info!(
364 "skipping captured-after inputs because --limit is already satisfied by example files"
365 );
366 } else if !captured_after_timestamps.is_empty() {
367 captured_after_timestamps.sort();
368
369 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
370
371 let mut captured_examples = pull_examples::fetch_captured_examples_after(
372 http_client,
373 &captured_after_timestamps,
374 max_rows_per_timestamp,
375 background_executor,
376 )
377 .await?;
378 examples.append(&mut captured_examples);
379 }
380
381 crate::example::sort_examples_by_repo_and_rev(&mut examples);
382
383 if let Some(name_filter) = &args.name {
384 examples.retain(|example| example.spec.name.contains(name_filter));
385 }
386 if let Some(repo_filter) = &args.repo {
387 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
388 }
389
390 // Skip resume logic for --in-place since input and output are the same file,
391 // which would incorrectly treat all input examples as already processed.
392 if !args.in_place {
393 if let Some(path) = output_path {
394 resume_from_output(path, &mut examples);
395 }
396 }
397
398 if let Some(offset) = args.offset {
399 examples.splice(0..offset, []);
400 }
401
402 if let Some(limit) = args.limit {
403 examples.truncate(limit);
404 }
405
406 let progress = Progress::global();
407 progress.set_total_examples(examples.len());
408 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
409
410 Ok(examples)
411}
412
413fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
414 let mut hasher = collections::FxHasher::default();
415 spec.hash(&mut hasher);
416 hasher.finish()
417}
418
419fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
420 let file = match File::open(path) {
421 Ok(f) => f,
422 Err(_) => return,
423 };
424
425 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
426
427 let reader = BufReader::new(file);
428 let mut kept_lines = Vec::new();
429 let mut kept_hashes = HashSet::default();
430
431 for line in reader.lines() {
432 let line = match line {
433 Ok(l) => l,
434 Err(_) => continue,
435 };
436
437 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
438 let hash = spec_hash(&output_example.spec);
439 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
440 kept_hashes.insert(hash);
441 kept_lines.push(line);
442 }
443 }
444 }
445
446 let total = examples.len();
447 let already_processed = kept_hashes.len();
448
449 eprintln!(
450 "Resuming: {}/{} examples already processed",
451 already_processed, total
452 );
453
454 let file = OpenOptions::new()
455 .write(true)
456 .truncate(true)
457 .open(path)
458 .expect("Failed to open output file for rewriting");
459 let mut writer = BufWriter::new(file);
460 for line in &kept_lines {
461 writeln!(writer, "{}", line).expect("Failed to write to output file");
462 }
463 writer.flush().expect("Failed to flush output file");
464
465 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
466}
467
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output path up front (panics if --in-place has != 1 input).
    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the per-example pipeline
    // are handled synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes everything under the data dir (repos, worktrees, runs).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes per-example files, so --output is mandatory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so that
    // project loading, LSP, and the prediction store are available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // NOTE: in these match arms `args` shadows the outer `EpArgs`
                // binding with the subcommand's `PredictArgs`.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast (see handle_error),
                // regardless of --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer fed by all workers over a channel.
                // Created when a write path resolves, unless this is `eval`
                // without an explicit --output.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so completed work survives interruption.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue: examples grouped by repository so a worker owns
                // one repo's examples (and its project state) at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo group; exit once drained.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        // No per-example work: inputs were
                                        // parsed during load_examples; the
                                        // shared write path emits the JSONL.
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were handled (and returned)
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Send the example to the writer task, or print
                                // JSONL to stdout when no writer was created.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language
                            // servers and drop cached project state before
                            // picking up the next group.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync once all workers are done.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
819
820async fn handle_error(
821 error: anyhow::Error,
822 args: &EpArgs,
823 command: &Command,
824 app_state: &Arc<headless::EpAppState>,
825 failfast_on_single_example: bool,
826 example: &Example,
827) {
828 Progress::global().increment_failed();
829
830 let msg;
831 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
832 let example_name = example.spec.filename();
833
834 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
835 app_state
836 .fs
837 .write(
838 &failed_example_path,
839 &serde_json::to_vec_pretty(&example).unwrap(),
840 )
841 .await
842 .unwrap();
843 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
844 app_state
845 .fs
846 .write(&err_path, format!("{error:?}").as_bytes())
847 .await
848 .unwrap();
849
850 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
851 let mut file = OpenOptions::new()
852 .create(true)
853 .append(true)
854 .open(&failed_jsonl_path)
855 .expect("Failed to open failed.jsonl");
856 writeln!(file, "{}", serde_json::to_string(example).unwrap())
857 .expect("Failed to write to failed.jsonl");
858
859 let cursor_path = example
860 .repo_name()
861 .unwrap()
862 .worktree_path()
863 .join(&example.spec.cursor_path);
864 msg = format!(
865 indoc::indoc! {"
866 While processing \"{}\":
867
868 \x1b[31m{:?}\x1b[0m
869
870 Example: \x1b[36m{}\x1b[0m
871 Error file: \x1b[36m{}\x1b[0m
872 Cursor file: \x1b[36m{}\x1b[0m
873 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
874 "},
875 example.spec.name,
876 error,
877 failed_example_path.display(),
878 err_path.display(),
879 cursor_path.display(),
880 command,
881 failed_example_path.display(),
882 );
883 } else {
884 msg = format!(
885 indoc::indoc! {"
886 While processing \"{}\":
887
888 \x1b[31m{:?}\x1b[0m
889 "},
890 example.spec.name, error
891 );
892 }
893
894 if args.failfast || failfast_on_single_example {
895 Progress::global().finalize();
896 panic!("{}", msg);
897 } else {
898 log::error!("{}", msg);
899 }
900}