1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod reorder_patch;
16mod retrieve_context;
17mod score;
18mod split_commit;
19mod split_dataset;
20mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use zeta_prompt::ZetaVersion;

use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
51
// Top-level CLI arguments for the `ep` binary. Flags marked `global` may be
// given before or after the subcommand. Comments here are deliberately `//`
// (not `///`) so they do not leak into clap's generated help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (debug aid).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repositories concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (also caps Snowflake fetches).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the loaded set.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; when omitted, `main` prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:<timestamp>` specs.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file; when omitted, most commands write JSONL to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (requires exactly one input;
    // see `EpArgs::output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
84
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike the other variants, this also suppresses the per-example
    // failure artifacts (failed/ JSON + failed.jsonl) — see `handle_error`.
    SkipNoFiles,
}
97
/// Long-form help text attached to the positional `inputs` argument of
/// `EpArgs`. This is rendered verbatim by clap, so the raw string below is
/// user-facing output and must not be reworded casually.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
127
// All `ep` subcommands. Most flow through the example pipeline driven by
// `main`; Clean, Synthesize, SplitCommit, Split, FilterLanguages, and
// ImportBatch are handled synchronously before the headless app starts.
// Variant doc comments below are clap help text — keep them user-facing.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
}
163
164impl Display for Command {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Command::ParseExample => write!(f, "parse-example"),
168 Command::LoadProject => write!(f, "load-project"),
169 Command::Context => write!(f, "context"),
170 Command::FormatPrompt(args) => {
171 write!(f, "format-prompt --provider={}", args.provider)
172 }
173 Command::Predict(args) => match &args.provider {
174 Some(provider) => write!(f, "predict --provider={}", provider),
175 None => write!(f, "predict"),
176 },
177 Command::ParseOutput => write!(f, "parse-output"),
178 Command::Score(args) => match &args.provider {
179 Some(provider) => write!(f, "score --provider={}", provider),
180 None => write!(f, "score"),
181 },
182 Command::Distill => write!(f, "distill"),
183 Command::Eval(args) => match &args.provider {
184 Some(provider) => write!(f, "eval --provider={}", provider),
185 None => write!(f, "eval"),
186 },
187 Command::Synthesize(args) => {
188 write!(f, "synthesize --repos {}", args.repos.join(" "))
189 }
190 Command::Clean => write!(f, "clean"),
191 Command::SplitCommit(_) => write!(f, "split-commit"),
192 Command::Split(_) => write!(f, "split"),
193 Command::FilterLanguages(_) => write!(f, "filter-languages"),
194 Command::ImportBatch(args) => {
195 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
196 }
197 }
198 }
199}
200
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Target provider; defaults to zeta2 with the default prompt version
    // (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
206
// Arguments shared by `ep predict`, `ep score`, and `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; how `None` is resolved is decided by the
    // predict/score implementations — TODO confirm, not visible here.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count forwarded to the prediction run (default 1).
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
214
/// Edit-prediction backends this CLI can target. Zeta2/Teacher variants carry
/// a [`ZetaVersion`] selecting the prompt format. Parsed from strings like
/// `zeta2:ordered` (see the `FromStr` impl) and serialized via `Display`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    // Same as `Teacher` but bypassing the Anthropic batch API.
    TeacherNonBatching(ZetaVersion),
}
224
225impl Default for PredictionProvider {
226 fn default() -> Self {
227 PredictionProvider::Zeta2(ZetaVersion::default())
228 }
229}
230
231impl std::fmt::Display for PredictionProvider {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 PredictionProvider::Sweep => write!(f, "sweep"),
235 PredictionProvider::Mercury => write!(f, "mercury"),
236 PredictionProvider::Zeta1 => write!(f, "zeta1"),
237 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
238 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
239 PredictionProvider::TeacherNonBatching(version) => {
240 write!(f, "teacher-non-batching:{version}")
241 }
242 }
243 }
244}
245
246impl std::str::FromStr for PredictionProvider {
247 type Err = anyhow::Error;
248
249 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
250 let mut version = ZetaVersion::default();
251 if let Some((first, second)) = s.split_once(':') {
252 version = ZetaVersion::parse(second)?;
253 s = first;
254 }
255
256 let s_lower = s.to_lowercase();
257 match s_lower.as_str() {
258 "sweep" => Ok(PredictionProvider::Sweep),
259 "mercury" => Ok(PredictionProvider::Mercury),
260 "zeta1" => Ok(PredictionProvider::Zeta1),
261 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
262 "teacher" => Ok(PredictionProvider::Teacher(version)),
263 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
264 Ok(PredictionProvider::TeacherNonBatching(version))
265 }
266 _ => {
267 anyhow::bail!(
268 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
269 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
270 Available zeta versions:\n{}",
271 ZetaVersion::options_as_string()
272 )
273 }
274 }
275 }
276}
277
278impl Serialize for PredictionProvider {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: Serializer,
282 {
283 serializer.serialize_str(&self.to_string())
284 }
285}
286
287impl<'de> Deserialize<'de> for PredictionProvider {
288 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 let s = String::deserialize(deserializer)?;
293 s.parse().map_err(serde::de::Error::custom)
294 }
295}
296
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in
// `main` (which also requires --output as the destination directory).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
315
// Arguments for `ep import-batch` (handled synchronously in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
322
323impl EpArgs {
324 fn output_path(&self) -> Option<PathBuf> {
325 if self.in_place {
326 if self.inputs.len() == 1 {
327 self.inputs.first().cloned()
328 } else {
329 panic!("--in-place requires exactly one input file")
330 }
331 } else {
332 self.output.clone()
333 }
334 }
335}
336
337async fn load_examples(
338 http_client: Arc<dyn http_client::HttpClient>,
339 args: &EpArgs,
340 output_path: Option<&PathBuf>,
341 background_executor: BackgroundExecutor,
342) -> anyhow::Result<Vec<Example>> {
343 let mut captured_after_timestamps = Vec::new();
344 let mut file_inputs = Vec::new();
345
346 for input in &args.inputs {
347 let input_string = input.to_string_lossy();
348 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
349 captured_after_timestamps.push(timestamp.to_string());
350 } else {
351 file_inputs.push(input.clone());
352 }
353 }
354
355 let mut examples = read_example_files(&file_inputs);
356
357 Progress::global().set_total_examples(examples.len());
358
359 let remaining_limit_for_snowflake =
360 args.limit.map(|limit| limit.saturating_sub(examples.len()));
361
362 if let Some(0) = remaining_limit_for_snowflake {
363 log::info!(
364 "skipping captured-after inputs because --limit is already satisfied by example files"
365 );
366 } else if !captured_after_timestamps.is_empty() {
367 captured_after_timestamps.sort();
368
369 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
370
371 let mut captured_examples = pull_examples::fetch_captured_examples_after(
372 http_client,
373 &captured_after_timestamps,
374 max_rows_per_timestamp,
375 background_executor,
376 )
377 .await?;
378 examples.append(&mut captured_examples);
379 }
380
381 crate::example::sort_examples_by_repo_and_rev(&mut examples);
382
383 if let Some(name_filter) = &args.name {
384 examples.retain(|example| example.spec.name.contains(name_filter));
385 }
386 if let Some(repo_filter) = &args.repo {
387 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
388 }
389
390 // Skip resume logic for --in-place since input and output are the same file,
391 // which would incorrectly treat all input examples as already processed.
392 if !args.in_place {
393 if let Some(path) = output_path {
394 resume_from_output(path, &mut examples);
395 }
396 }
397
398 if let Some(offset) = args.offset {
399 examples.splice(0..offset, []);
400 }
401
402 if let Some(limit) = args.limit {
403 examples.truncate(limit);
404 }
405
406 Progress::global().set_total_examples(examples.len());
407
408 Ok(examples)
409}
410
411fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
412 let mut hasher = collections::FxHasher::default();
413 spec.hash(&mut hasher);
414 hasher.finish()
415}
416
417fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
418 let file = match File::open(path) {
419 Ok(f) => f,
420 Err(_) => return,
421 };
422
423 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
424
425 let reader = BufReader::new(file);
426 let mut kept_lines = Vec::new();
427 let mut kept_hashes = HashSet::default();
428
429 for line in reader.lines() {
430 let line = match line {
431 Ok(l) => l,
432 Err(_) => continue,
433 };
434
435 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
436 let hash = spec_hash(&output_example.spec);
437 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
438 kept_hashes.insert(hash);
439 kept_lines.push(line);
440 }
441 }
442 }
443
444 let total = examples.len();
445 let already_processed = kept_hashes.len();
446
447 eprintln!(
448 "Resuming: {}/{} examples already processed",
449 already_processed, total
450 );
451
452 let file = OpenOptions::new()
453 .write(true)
454 .truncate(true)
455 .open(path)
456 .expect("Failed to open output file for rewriting");
457 let mut writer = BufWriter::new(file);
458 for line in &kept_lines {
459 writeln!(writer, "{}", line).expect("Failed to write to output file");
460 }
461 writer.flush().expect("Failed to flush output file");
462
463 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
464}
465
/// CLI entry point. Parses arguments, handles the simple synchronous
/// subcommands inline, and drives the example-processing subcommands through
/// a headless GPUI application with `--max-parallelism` worker tasks.
fn main() {
    let args = EpArgs::parse();

    // --printenv short-circuits everything: dump the shell env and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand: print help and exit successfully.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline
    // are handled synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the tool's entire data directory (repos + worktrees).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes generated examples into --output as a dir.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Remaining commands run inside a headless GPUI app so they can use the
    // project/LSP machinery.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast so the error is surfaced
                // immediately (see handle_error).
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer task fed over a channel, so the
                // parallel workers never interleave partial lines in the
                // output file. `None` means "print to stdout or nothing".
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    // Flush per line so progress survives interruption.
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue keyed by repository, so each repo's worktree is
                // only ever used by one worker at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo's batch, or stop when the
                            // queue is drained.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for the
                                // active subcommand.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) | Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched synchronously
                                        // before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures (may panic under --failfast).
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example to the writer task, or to
                                // stdout when no output file is configured
                                // (except for `eval`, which only reports).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // This repo's batch is done: shut down its language
                            // servers, drop it from the caches, and release the
                            // per-example loaded state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync after all workers have finished.
                match &command {
                    Command::Predict(args) | Command::Score(args) | Command::Eval(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
817
/// Records a failed example: bumps the global failure counter, writes the
/// failure artifacts (unless `--failed skip-no-files`), and then either
/// panics (under `--failfast` or when only one example is being processed)
/// or logs the error and lets the run continue.
///
/// Artifacts written per failure: a pretty-printed `<name>.json` and a
/// `<name>_err.txt` in the run's failed/ directory, plus a line appended to
/// the run's `failed.jsonl`.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Full example snapshot, for re-running this exact case later.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Companion file holding the debug-formatted error.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failed.jsonl (created on first failure).
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // Absolute path of the cursor file inside the example's worktree,
        // included in the message for quick inspection.
        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        // ANSI colors: red for the error, cyan for paths / re-run command.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // --failed skip-no-files: no artifacts, just the inline error text.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}