1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod reorder_patch;
16mod retrieve_context;
17mod score;
18mod split_commit;
19mod split_dataset;
20mod synthesize;
21use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
22use collections::HashSet;
23use edit_prediction::EditPredictionStore;
24use futures::channel::mpsc;
25use futures::{SinkExt as _, StreamExt as _};
26use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
27use zeta_prompt::ZetaVersion;
28
29use reqwest_client::ReqwestClient;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use std::fmt::Display;
32use std::fs::{File, OpenOptions};
33use std::hash::{Hash, Hasher};
34use std::io::{BufRead, BufReader, BufWriter, Write};
35use std::sync::Mutex;
36use std::{path::PathBuf, sync::Arc};
37
38use crate::distill::run_distill;
39use crate::example::{Example, group_examples_by_repo, read_example_files};
40use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
41use crate::format_prompt::run_format_prompt;
42use crate::load_project::run_load_project;
43use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
44use crate::predict::run_prediction;
45use crate::progress::Progress;
46use crate::retrieve_context::run_context_retrieval;
47use crate::score::run_scoring;
48use crate::split_commit::SplitCommitArgs;
49use crate::split_dataset::SplitArgs;
50use crate::synthesize::{SynthesizeConfig, run_synthesize};
51
#[derive(Parser, Debug)]
#[command(name = "ep")]
// Top-level CLI arguments for the `ep` tool. Fields marked `global = true`
// are accepted by every subcommand. NOTE: `///` doc comments on clap fields
// become user-visible help text, so review notes here use `//` instead.
struct EpArgs {
    // Print the captured shell environment and exit (debugging aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker claims one repository
    // group at a time (see the worker loop in `main`).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on examples processed; also bounds the Snowflake fetch in
    // `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied after filtering/resume).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:{timestamp}` specifiers
    // (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file path; for some commands this is required (see `main`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (written via temp file + rename).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run when any example fails (see `handle_error`).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
84
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
// NOTE: variant `///` docs double as the `--failed` value help via `ValueEnum`,
// so additional review notes below use `//` comments.
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike `Skip`, this also suppresses the per-example failure artifacts
    // (json/err files and the msg paths) written by `handle_error`.
    SkipNoFiles,
}
97
// Long help text attached to the positional `inputs` argument of `EpArgs`.
// Raw string literal: clap renders it verbatim, so the content must not be
// reflowed or reformatted.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
127
#[derive(Subcommand, Debug, Clone)]
// Subcommands of `ep`. Variant `///` docs are surfaced by clap as subcommand
// help, so they are left untouched. ImportBatch, Clean, Synthesize,
// SplitCommit, Split, and FilterLanguages are handled synchronously in `main`
// before the headless app starts; all other variants run through the
// per-example worker pipeline.
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
}
163
164impl Display for Command {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Command::ParseExample => write!(f, "parse-example"),
168 Command::LoadProject => write!(f, "load-project"),
169 Command::Context => write!(f, "context"),
170 Command::FormatPrompt(args) => {
171 write!(f, "format-prompt --provider={}", args.provider)
172 }
173 Command::Predict(args) => match &args.provider {
174 Some(provider) => write!(f, "predict --provider={}", provider),
175 None => write!(f, "predict"),
176 },
177 Command::ParseOutput => write!(f, "parse-output"),
178 Command::Score(args) => match &args.provider {
179 Some(provider) => write!(f, "score --provider={}", provider),
180 None => write!(f, "score"),
181 },
182 Command::Distill => write!(f, "distill"),
183 Command::Eval(args) => match &args.provider {
184 Some(provider) => write!(f, "eval --provider={}", provider),
185 None => write!(f, "eval"),
186 },
187 Command::Synthesize(args) => {
188 write!(f, "synthesize --repos {}", args.repos.join(" "))
189 }
190 Command::Clean => write!(f, "clean"),
191 Command::SplitCommit(_) => write!(f, "split-commit"),
192 Command::Split(_) => write!(f, "split"),
193 Command::FilterLanguages(_) => write!(f, "filter-languages"),
194 Command::ImportBatch(args) => {
195 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
196 }
197 }
198 }
199}
200
#[derive(Debug, Args, Clone)]
// Arguments for `format-prompt`. Unlike `PredictArgs`, the provider is
// mandatory here and defaults to `PredictionProvider::default()` (zeta2 at
// the default ZetaVersion).
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
206
#[derive(Debug, Args, Clone)]
// Shared by the `predict`, `score`, and `eval` subcommands. A `None` provider
// is passed through to `predict::sync_batches` / the runners as-is.
struct PredictArgs {
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm against `run_prediction`.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
214
// Prediction backends selectable via `--provider`. Values are parsed from
// strings like `zeta2` or `teacher:<version>` (see the `FromStr` impl) and
// rendered back in the same `name[:version]` form by `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // The default provider (see the `Default` impl below).
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // NOTE(review): presumably `Teacher` without the Anthropic batch API —
    // confirm against the predict module.
    TeacherNonBatching(ZetaVersion),
}
224
225impl Default for PredictionProvider {
226 fn default() -> Self {
227 PredictionProvider::Zeta2(ZetaVersion::default())
228 }
229}
230
231impl std::fmt::Display for PredictionProvider {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 PredictionProvider::Sweep => write!(f, "sweep"),
235 PredictionProvider::Mercury => write!(f, "mercury"),
236 PredictionProvider::Zeta1 => write!(f, "zeta1"),
237 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
238 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
239 PredictionProvider::TeacherNonBatching(version) => {
240 write!(f, "teacher-non-batching:{version}")
241 }
242 }
243 }
244}
245
246impl std::str::FromStr for PredictionProvider {
247 type Err = anyhow::Error;
248
249 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
250 let mut version = ZetaVersion::default();
251 if let Some((first, second)) = s.split_once(':') {
252 version = ZetaVersion::parse(second)?;
253 s = first;
254 }
255
256 let s_lower = s.to_lowercase();
257 match s_lower.as_str() {
258 "sweep" => Ok(PredictionProvider::Sweep),
259 "mercury" => Ok(PredictionProvider::Mercury),
260 "zeta1" => Ok(PredictionProvider::Zeta1),
261 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
262 "teacher" => Ok(PredictionProvider::Teacher(version)),
263 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
264 Ok(PredictionProvider::TeacherNonBatching(version))
265 }
266 _ => {
267 anyhow::bail!(
268 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
269 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
270 Available zeta versions:\n{}",
271 ZetaVersion::options_as_string()
272 )
273 }
274 }
275 }
276}
277
278impl Serialize for PredictionProvider {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: Serializer,
282 {
283 serializer.serialize_str(&self.to_string())
284 }
285}
286
287impl<'de> Deserialize<'de> for PredictionProvider {
288 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 let s = String::deserialize(deserializer)?;
293 s.parse().map_err(serde::de::Error::custom)
294 }
295}
296
#[derive(Debug, Args, Clone)]
// Arguments for `synthesize` (see `run_synthesize`). The output directory is
// taken from the global `--output` flag, which `main` requires for this
// command. Field `///` docs are clap help text and left untouched.
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
315
#[derive(Debug, Args, Clone)]
// Arguments for `import-batch`, which runs standalone in `main` via
// `AnthropicClient::import_batches` against the LLM cache database.
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
322
323impl EpArgs {
324 fn output_path(&self) -> Option<PathBuf> {
325 if self.in_place {
326 if self.inputs.len() == 1 {
327 self.inputs.first().cloned()
328 } else {
329 panic!("--in-place requires exactly one input file")
330 }
331 } else {
332 self.output.clone()
333 }
334 }
335}
336
337async fn load_examples(
338 http_client: Arc<dyn http_client::HttpClient>,
339 args: &EpArgs,
340 output_path: Option<&PathBuf>,
341 background_executor: BackgroundExecutor,
342) -> anyhow::Result<Vec<Example>> {
343 let mut captured_after_timestamps = Vec::new();
344 let mut file_inputs = Vec::new();
345
346 for input in &args.inputs {
347 let input_string = input.to_string_lossy();
348 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
349 captured_after_timestamps.push(timestamp.to_string());
350 } else {
351 file_inputs.push(input.clone());
352 }
353 }
354
355 let mut examples = read_example_files(&file_inputs);
356
357 Progress::global().set_total_examples(examples.len());
358
359 let remaining_limit_for_snowflake =
360 args.limit.map(|limit| limit.saturating_sub(examples.len()));
361
362 if let Some(0) = remaining_limit_for_snowflake {
363 log::info!(
364 "skipping captured-after inputs because --limit is already satisfied by example files"
365 );
366 } else if !captured_after_timestamps.is_empty() {
367 captured_after_timestamps.sort();
368
369 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
370
371 let mut captured_examples = pull_examples::fetch_captured_examples_after(
372 http_client,
373 &captured_after_timestamps,
374 max_rows_per_timestamp,
375 background_executor,
376 )
377 .await?;
378 examples.append(&mut captured_examples);
379 }
380
381 crate::example::sort_examples_by_repo_and_rev(&mut examples);
382
383 if let Some(name_filter) = &args.name {
384 examples.retain(|example| example.spec.name.contains(name_filter));
385 }
386 if let Some(repo_filter) = &args.repo {
387 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
388 }
389
390 // Skip resume logic for --in-place since input and output are the same file,
391 // which would incorrectly treat all input examples as already processed.
392 if !args.in_place {
393 if let Some(path) = output_path {
394 resume_from_output(path, &mut examples);
395 }
396 }
397
398 if let Some(offset) = args.offset {
399 examples.splice(0..offset, []);
400 }
401
402 if let Some(limit) = args.limit {
403 examples.truncate(limit);
404 }
405
406 let progress = Progress::global();
407 progress.set_total_examples(examples.len());
408 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
409
410 Ok(examples)
411}
412
413fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
414 let mut hasher = collections::FxHasher::default();
415 spec.hash(&mut hasher);
416 hasher.finish()
417}
418
419fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
420 let file = match File::open(path) {
421 Ok(f) => f,
422 Err(_) => return,
423 };
424
425 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
426
427 let reader = BufReader::new(file);
428 let mut kept_lines = Vec::new();
429 let mut kept_hashes = HashSet::default();
430
431 for line in reader.lines() {
432 let line = match line {
433 Ok(l) => l,
434 Err(_) => continue,
435 };
436
437 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
438 let hash = spec_hash(&output_example.spec);
439 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
440 kept_hashes.insert(hash);
441 kept_lines.push(line);
442 }
443 }
444 }
445
446 let total = examples.len();
447 let already_processed = kept_hashes.len();
448
449 eprintln!(
450 "Resuming: {}/{} examples already processed",
451 already_processed, total
452 );
453
454 let file = OpenOptions::new()
455 .write(true)
456 .truncate(true)
457 .open(path)
458 .expect("Failed to open output file for rewriting");
459 let mut writer = BufWriter::new(file);
460 for line in &kept_lines {
461 writeln!(writer, "{}", line).expect("Failed to write to output file");
462 }
463 writer.flush().expect("Failed to flush output file");
464
465 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
466}
467
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Resolve the output path up front (panics if --in-place has != 1 input).
    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or the per-example pipeline
    // are handled synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes everything under the data dir (repos, worktrees, runs).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes per-example files, so --output is mandatory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Everything below runs inside a headless GPUI application so that
    // project loading, LSP, and the prediction store are available.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // NOTE: in these match arms `args` shadows the outer `EpArgs`
                // binding with the subcommand's `PredictArgs`.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run always fails fast (see handle_error),
                // regardless of --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer fed by all workers over a channel.
                // Created when a write path resolves, unless this is `eval`
                // without an explicit --output.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so completed work survives interruption.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue: examples grouped by repository so a worker owns
                // one repo's examples (and its project state) at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo group; exit once drained.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                let result = async {
                                    match &command {
                                        // No per-example work: inputs were
                                        // parsed during load_examples; the
                                        // shared write path emits the JSONL.
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were handled (and returned)
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Send the example to the writer task, or print
                                // JSONL to stdout when no writer was created.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language
                            // servers and drop cached project state before
                            // picking up the next group.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync once all workers are done.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
819
820async fn handle_error(
821 error: anyhow::Error,
822 args: &EpArgs,
823 command: &Command,
824 app_state: &Arc<headless::EpAppState>,
825 failfast_on_single_example: bool,
826 example: &Example,
827) {
828 Progress::global().increment_failed();
829
830 let msg;
831 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
832 let example_name = example.spec.filename();
833
834 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
835 app_state
836 .fs
837 .write(
838 &failed_example_path,
839 &serde_json::to_vec_pretty(&example).unwrap(),
840 )
841 .await
842 .unwrap();
843 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
844 app_state
845 .fs
846 .write(&err_path, format!("{error:?}").as_bytes())
847 .await
848 .unwrap();
849
850 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
851 let mut file = OpenOptions::new()
852 .create(true)
853 .append(true)
854 .open(&failed_jsonl_path)
855 .expect("Failed to open failed.jsonl");
856 writeln!(file, "{}", serde_json::to_string(example).unwrap())
857 .expect("Failed to write to failed.jsonl");
858
859 let cursor_path = example
860 .repo_name()
861 .unwrap()
862 .worktree_path()
863 .join(&example.spec.cursor_path);
864 msg = format!(
865 indoc::indoc! {"
866 While processing \"{}\":
867
868 \x1b[31m{:?}\x1b[0m
869
870 Example: \x1b[36m{}\x1b[0m
871 Error file: \x1b[36m{}\x1b[0m
872 Cursor file: \x1b[36m{}\x1b[0m
873 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
874 "},
875 example.spec.name,
876 error,
877 failed_example_path.display(),
878 err_path.display(),
879 cursor_path.display(),
880 command,
881 failed_example_path.display(),
882 );
883 } else {
884 msg = format!(
885 indoc::indoc! {"
886 While processing \"{}\":
887
888 \x1b[31m{:?}\x1b[0m
889 "},
890 example.spec.name, error
891 );
892 }
893
894 if args.failfast || failfast_on_single_example {
895 Progress::global().finalize();
896 panic!("{}", msg);
897 } else {
898 log::error!("{}", msg);
899 }
900}