1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod reorder_patch;
16mod retrieve_context;
17mod score;
18mod split_commit;
19mod split_dataset;
20mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
51
// Top-level CLI arguments for the `ep` binary. Flags marked `global` may be
// given before or after the subcommand. Comments here are deliberately `//`
// (not `///`) so they do not leak into clap's generated help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (debug aid).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repositories concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (also caps Snowflake fetches).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the loaded set.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; when omitted, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:<timestamp>` specs.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; when omitted, most commands write JSONL to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (requires exactly one input;
    // see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
84
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike the other variants, this also suppresses the per-example
    // failure artifacts (failed/ JSON + failed.jsonl) — see `handle_error`.
    SkipNoFiles,
}
97
/// Long-form help text attached to the positional `inputs` argument of
/// `EpArgs`. This is rendered verbatim by clap, so the raw string below is
/// user-facing output and must not be reworded casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
127
// All `ep` subcommands. Most flow through the example pipeline driven by
// `main`; Clean, Synthesize, SplitCommit, Split, FilterLanguages, and
// ImportBatch are handled synchronously before the headless app starts.
// Variant doc comments below are clap help text — keep them user-facing.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
}
163
164impl Display for Command {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Command::ParseExample => write!(f, "parse-example"),
168 Command::LoadProject => write!(f, "load-project"),
169 Command::Context => write!(f, "context"),
170 Command::FormatPrompt(args) => {
171 write!(f, "format-prompt --provider={}", args.provider)
172 }
173 Command::Predict(args) => match &args.provider {
174 Some(provider) => write!(f, "predict --provider={}", provider),
175 None => write!(f, "predict"),
176 },
177 Command::ParseOutput => write!(f, "parse-output"),
178 Command::Score(args) => match &args.provider {
179 Some(provider) => write!(f, "score --provider={}", provider),
180 None => write!(f, "score"),
181 },
182 Command::Distill => write!(f, "distill"),
183 Command::Eval(args) => match &args.provider {
184 Some(provider) => write!(f, "eval --provider={}", provider),
185 None => write!(f, "eval"),
186 },
187 Command::Synthesize(args) => {
188 write!(f, "synthesize --repos {}", args.repos.join(" "))
189 }
190 Command::Clean => write!(f, "clean"),
191 Command::SplitCommit(_) => write!(f, "split-commit"),
192 Command::Split(_) => write!(f, "split"),
193 Command::FilterLanguages(_) => write!(f, "filter-languages"),
194 Command::ImportBatch(args) => {
195 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
196 }
197 }
198 }
199}
200
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Target provider; defaults to zeta2 with the default prompt version
    // (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
206
// Arguments shared by `ep predict`, `ep score`, and `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; how `None` is resolved is decided by the
    // predict/score implementations — TODO confirm, not visible here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count forwarded to the prediction run (default 1).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
214
/// Edit-prediction backends this CLI can target. Zeta2/Teacher variants carry
/// a [`ZetaVersion`] selecting the prompt format. Parsed from strings like
/// `zeta2:ordered` (see the `FromStr` impl) and serialized via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // Same as `Teacher` but bypassing the Anthropic batch API.
    TeacherNonBatching(ZetaVersion),
}
224
225impl Default for PredictionProvider {
226 fn default() -> Self {
227 PredictionProvider::Zeta2(ZetaVersion::default())
228 }
229}
230
231impl std::fmt::Display for PredictionProvider {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 PredictionProvider::Sweep => write!(f, "sweep"),
235 PredictionProvider::Mercury => write!(f, "mercury"),
236 PredictionProvider::Zeta1 => write!(f, "zeta1"),
237 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
238 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
239 PredictionProvider::TeacherNonBatching(version) => {
240 write!(f, "teacher-non-batching:{version}")
241 }
242 }
243 }
244}
245
246impl std::str::FromStr for PredictionProvider {
247 type Err = anyhow::Error;
248
249 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
250 let mut version = ZetaVersion::default();
251 if let Some((first, second)) = s.split_once(':') {
252 version = ZetaVersion::parse(second)?;
253 s = first;
254 }
255
256 let s_lower = s.to_lowercase();
257 match s_lower.as_str() {
258 "sweep" => Ok(PredictionProvider::Sweep),
259 "mercury" => Ok(PredictionProvider::Mercury),
260 "zeta1" => Ok(PredictionProvider::Zeta1),
261 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
262 "teacher" => Ok(PredictionProvider::Teacher(version)),
263 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
264 Ok(PredictionProvider::TeacherNonBatching(version))
265 }
266 _ => {
267 anyhow::bail!(
268 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
269 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
270 Available zeta versions:\n{}",
271 ZetaVersion::options_as_string()
272 )
273 }
274 }
275 }
276}
277
278impl Serialize for PredictionProvider {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: Serializer,
282 {
283 serializer.serialize_str(&self.to_string())
284 }
285}
286
287impl<'de> Deserialize<'de> for PredictionProvider {
288 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 let s = String::deserialize(deserializer)?;
293 s.parse().map_err(serde::de::Error::custom)
294 }
295}
296
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in
// `main` (which also requires --output as the destination directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
315
// Arguments for `ep import-batch` (handled synchronously in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
322
323impl EpArgs {
324 fn output_path(&self) -> Option<PathBuf> {
325 if self.in_place {
326 if self.inputs.len() == 1 {
327 self.inputs.first().cloned()
328 } else {
329 panic!("--in-place requires exactly one input file")
330 }
331 } else {
332 self.output.clone()
333 }
334 }
335}
336
337async fn load_examples(
338 http_client: Arc<dyn http_client::HttpClient>,
339 args: &EpArgs,
340 output_path: Option<&PathBuf>,
341 background_executor: BackgroundExecutor,
342) -> anyhow::Result<Vec<Example>> {
343 let mut captured_after_timestamps = Vec::new();
344 let mut file_inputs = Vec::new();
345
346 for input in &args.inputs {
347 let input_string = input.to_string_lossy();
348 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
349 captured_after_timestamps.push(timestamp.to_string());
350 } else {
351 file_inputs.push(input.clone());
352 }
353 }
354
355 let mut examples = read_example_files(&file_inputs);
356
357 Progress::global().set_total_examples(examples.len());
358
359 let remaining_limit_for_snowflake =
360 args.limit.map(|limit| limit.saturating_sub(examples.len()));
361
362 if let Some(0) = remaining_limit_for_snowflake {
363 log::info!(
364 "skipping captured-after inputs because --limit is already satisfied by example files"
365 );
366 } else if !captured_after_timestamps.is_empty() {
367 captured_after_timestamps.sort();
368
369 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
370
371 let mut captured_examples = pull_examples::fetch_captured_examples_after(
372 http_client,
373 &captured_after_timestamps,
374 max_rows_per_timestamp,
375 background_executor,
376 )
377 .await?;
378 examples.append(&mut captured_examples);
379 }
380
381 crate::example::sort_examples_by_repo_and_rev(&mut examples);
382
383 if let Some(name_filter) = &args.name {
384 examples.retain(|example| example.spec.name.contains(name_filter));
385 }
386 if let Some(repo_filter) = &args.repo {
387 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
388 }
389
390 // Skip resume logic for --in-place since input and output are the same file,
391 // which would incorrectly treat all input examples as already processed.
392 if !args.in_place {
393 if let Some(path) = output_path {
394 resume_from_output(path, &mut examples);
395 }
396 }
397
398 if let Some(offset) = args.offset {
399 examples.splice(0..offset, []);
400 }
401
402 if let Some(limit) = args.limit {
403 examples.truncate(limit);
404 }
405
406 Progress::global().set_total_examples(examples.len());
407
408 Ok(examples)
409}
410
411fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
412 let mut hasher = collections::FxHasher::default();
413 spec.hash(&mut hasher);
414 hasher.finish()
415}
416
417fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
418 let file = match File::open(path) {
419 Ok(f) => f,
420 Err(_) => return,
421 };
422
423 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
424
425 let reader = BufReader::new(file);
426 let mut kept_lines = Vec::new();
427 let mut kept_hashes = HashSet::default();
428
429 for line in reader.lines() {
430 let line = match line {
431 Ok(l) => l,
432 Err(_) => continue,
433 };
434
435 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
436 let hash = spec_hash(&output_example.spec);
437 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
438 kept_hashes.insert(hash);
439 kept_lines.push(line);
440 }
441 }
442 }
443
444 let total = examples.len();
445 let already_processed = kept_hashes.len();
446
447 eprintln!(
448 "Resuming: {}/{} examples already processed",
449 already_processed, total
450 );
451
452 let file = OpenOptions::new()
453 .write(true)
454 .truncate(true)
455 .open(path)
456 .expect("Failed to open output file for rewriting");
457 let mut writer = BufWriter::new(file);
458 for line in &kept_lines {
459 writeln!(writer, "{}", line).expect("Failed to write to output file");
460 }
461 writer.flush().expect("Failed to flush output file");
462
463 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
464}
465
/// CLI entry point. Parses arguments, handles the simple synchronous
/// subcommands inline, and drives the example-processing subcommands through
/// a headless GPUI application with `--max-parallelism` worker tasks.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: print help and exit successfully.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the tool's entire data directory (repos + worktrees).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into --output as a dir.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless GPUI app so they can use the
    // project/LSP machinery.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so the error is surfaced
                // immediately (see handle_error).
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer task fed over a channel, so the
                // parallel workers never interleave partial lines in the
                // output file. `None` means "print to stdout or nothing".
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so progress survives interruption.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue keyed by repository, so each repo's worktree is
                // only ever used by one worker at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo's batch, or stop when the
                            // queue is drained.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for the
                                // active subcommand.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched synchronously
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures (may panic under --failfast).
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example to the writer task, or to
                                // stdout when no output file is configured
                                // (except for `eval`, which only reports).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // This repo's batch is done: shut down its language
                            // servers, drop it from the caches, and release the
                            // per-example loaded state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync after all workers have finished.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
817
/// Records a failed example: bumps the global failure counter, writes the
/// failure artifacts (unless `--failed skip-no-files`), and then either
/// panics (under `--failfast` or when only one example is being processed)
/// or logs the error and lets the run continue.
///
/// Artifacts written per failure: a pretty-printed `<name>.json` and a
/// `<name>_err.txt` in the run's failed/ directory, plus a line appended to
/// the run's `failed.jsonl`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Full example snapshot, for re-running this exact case later.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Companion file holding the debug-formatted error.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failed.jsonl (created on first failure).
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Absolute path of the cursor file inside the example's worktree,
        // included in the message for quick inspection.
        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        // ANSI colors: red for the error, cyan for paths / re-run command.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // --failed skip-no-files: no artifacts, just the inline error text.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}