1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod reorder_patch;
16mod retrieve_context;
17mod score;
18mod split_commit;
19mod split_dataset;
20mod synthesize;
21use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
22use collections::HashSet;
23use edit_prediction::EditPredictionStore;
24use futures::channel::mpsc;
25use futures::{SinkExt as _, StreamExt as _};
26use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
27use zeta_prompt::ZetaVersion;
28
29use reqwest_client::ReqwestClient;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
37
38use crate::distill::run_distill;
39use crate::example::{Example, group_examples_by_repo, read_example_files};
40use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
41use crate::format_prompt::run_format_prompt;
42use crate::load_project::run_load_project;
43use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
44use crate::predict::run_prediction;
45use crate::progress::Progress;
46use crate::retrieve_context::run_context_retrieval;
47use crate::score::run_scoring;
48use crate::split_commit::SplitCommitArgs;
49use crate::split_dataset::SplitArgs;
50use crate::synthesize::{SynthesizeConfig, run_synthesize};
51
52#[derive(Parser, Debug)]
53#[command(name = "ep")]
54struct EpArgs {
55 #[arg(long, default_value_t = false)]
56 printenv: bool,
57 #[clap(long, default_value_t = 10, global = true)]
58 max_parallelism: usize,
59 #[clap(long, global = true)]
60 limit: Option<usize>,
61 #[clap(long, global = true)]
62 offset: Option<usize>,
63 /// Filter examples by name
64 #[clap(long, global = true)]
65 name: Option<String>,
66 /// Filter examples by repository
67 #[clap(long, global = true)]
68 repo: Option<String>,
69 #[command(subcommand)]
70 command: Option<Command>,
71 #[clap(global = true, help = INPUTS_HELP)]
72 inputs: Vec<PathBuf>,
73 #[arg(long, short, global = true)]
74 output: Option<PathBuf>,
75 #[arg(long, short, global = true)]
76 in_place: bool,
77 #[arg(long, short, global = true)]
78 failfast: bool,
79 /// How to handle failed examples in output: keep them or skip them.
80 /// Failed examples are always logged to the run's failed directory.
81 #[arg(long, global = true, default_value = "keep")]
82 failed: FailedHandling,
83}
84
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the sentence above overstates "always" — `SkipNoFiles` also
// suppresses the failed/-directory and failed.jsonl writes (see `handle_error`).
// Variant doc comments double as `--help` value descriptions (ValueEnum).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
97
// Long-form help text attached to the positional `inputs` argument of `EpArgs`.
// This is user-facing runtime output — keep formatting and examples verbatim.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
 Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
 Fetch captured examples from Snowflake after the given RFC3339 timestamp.

 You can specify this multiple times and mix it with file inputs.

 Required environment variables to connect to Snowflake:
 EP_SNOWFLAKE_API_KEY
 EP_SNOWFLAKE_BASE_URL

 Optional:
 EP_SNOWFLAKE_ROLE

Examples:

 # Predict from a file
 ep predict examples.jsonl

 # Predict from captured examples after a timestamp
 ep predict captured-after:2025-01-01T00:00:00Z

 # Mix file inputs and captured-after in the same invocation
 ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
127
// All `ep` subcommands. Variant doc comments double as CLI help text — do not
// reword them casually. `Clean`, `Synthesize`, `SplitCommit`, `Split`,
// `FilterLanguages`, and `ImportBatch` run standalone (handled early in
// `main`); the rest go through the headless example-processing pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
}
163
164impl Display for Command {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Command::ParseExample => write!(f, "parse-example"),
168 Command::LoadProject => write!(f, "load-project"),
169 Command::Context => write!(f, "context"),
170 Command::FormatPrompt(args) => {
171 write!(f, "format-prompt --provider={}", args.provider)
172 }
173 Command::Predict(args) => match &args.provider {
174 Some(provider) => write!(f, "predict --provider={}", provider),
175 None => write!(f, "predict"),
176 },
177 Command::ParseOutput => write!(f, "parse-output"),
178 Command::Score(args) => match &args.provider {
179 Some(provider) => write!(f, "score --provider={}", provider),
180 None => write!(f, "score"),
181 },
182 Command::Distill => write!(f, "distill"),
183 Command::Eval(args) => match &args.predict.provider {
184 Some(provider) => write!(f, "eval --provider={}", provider),
185 None => write!(f, "eval"),
186 },
187 Command::Synthesize(args) => {
188 write!(f, "synthesize --repos {}", args.repos.join(" "))
189 }
190 Command::Clean => write!(f, "clean"),
191 Command::SplitCommit(_) => write!(f, "split-commit"),
192 Command::Split(_) => write!(f, "split"),
193 Command::FilterLanguages(_) => write!(f, "filter-languages"),
194 Command::ImportBatch(args) => {
195 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
196 }
197 }
198 }
199}
200
201#[derive(Debug, Args, Clone)]
202struct FormatPromptArgs {
203 #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
204 provider: PredictionProvider,
205}
206
207#[derive(Debug, Args, Clone)]
208struct PredictArgs {
209 #[clap(long, short('p'))]
210 provider: Option<PredictionProvider>,
211 #[clap(long, default_value_t = 1)]
212 repetitions: usize,
213}
214
215#[derive(Debug, Args, Clone)]
216struct EvalArgs {
217 #[clap(flatten)]
218 predict: PredictArgs,
219 /// Path to write summary scores as JSON
220 #[clap(long)]
221 summary_json: Option<PathBuf>,
222}
223
/// Prediction backends selectable via `--provider`.
///
/// Variants carrying a `ZetaVersion` accept a `name:<version>` suffix when
/// parsed (see the `FromStr` impl); `Display` renders the same syntax back.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
233
234impl Default for PredictionProvider {
235 fn default() -> Self {
236 PredictionProvider::Zeta2(ZetaVersion::default())
237 }
238}
239
240impl std::fmt::Display for PredictionProvider {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 match self {
243 PredictionProvider::Sweep => write!(f, "sweep"),
244 PredictionProvider::Mercury => write!(f, "mercury"),
245 PredictionProvider::Zeta1 => write!(f, "zeta1"),
246 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
247 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
248 PredictionProvider::TeacherNonBatching(version) => {
249 write!(f, "teacher-non-batching:{version}")
250 }
251 }
252 }
253}
254
255impl std::str::FromStr for PredictionProvider {
256 type Err = anyhow::Error;
257
258 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
259 let mut version = ZetaVersion::default();
260 if let Some((first, second)) = s.split_once(':') {
261 version = ZetaVersion::parse(second)?;
262 s = first;
263 }
264
265 let s_lower = s.to_lowercase();
266 match s_lower.as_str() {
267 "sweep" => Ok(PredictionProvider::Sweep),
268 "mercury" => Ok(PredictionProvider::Mercury),
269 "zeta1" => Ok(PredictionProvider::Zeta1),
270 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
271 "teacher" => Ok(PredictionProvider::Teacher(version)),
272 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
273 Ok(PredictionProvider::TeacherNonBatching(version))
274 }
275 _ => {
276 anyhow::bail!(
277 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
278 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
279 Available zeta versions:\n{}",
280 ZetaVersion::options_as_string()
281 )
282 }
283 }
284 }
285}
286
287impl Serialize for PredictionProvider {
288 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
289 where
290 S: Serializer,
291 {
292 serializer.serialize_str(&self.to_string())
293 }
294}
295
296impl<'de> Deserialize<'de> for PredictionProvider {
297 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
298 where
299 D: Deserializer<'de>,
300 {
301 let s = String::deserialize(deserializer)?;
302 s.parse().map_err(serde::de::Error::custom)
303 }
304}
305
306#[derive(Debug, Args, Clone)]
307struct SynthesizeArgs {
308 /// Repository URLs (git@github.com:owner/repo or https://...)
309 #[clap(long, required = true, num_args = 1..)]
310 repos: Vec<String>,
311
312 /// Number of examples to generate per repository
313 #[clap(long, default_value_t = 5)]
314 count: usize,
315
316 /// Maximum commits to scan per repository before giving up
317 #[clap(long, default_value_t = 100)]
318 max_commits: usize,
319
320 /// Ignore state file and reprocess all commits
321 #[clap(long)]
322 fresh: bool,
323}
324
325#[derive(Debug, Args, Clone)]
326struct ImportBatchArgs {
327 /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
328 #[clap(long, required = true, num_args = 1..)]
329 batch_ids: Vec<String>,
330}
331
332impl EpArgs {
333 fn output_path(&self) -> Option<PathBuf> {
334 if self.in_place {
335 if self.inputs.len() == 1 {
336 self.inputs.first().cloned()
337 } else {
338 panic!("--in-place requires exactly one input file")
339 }
340 } else {
341 self.output.clone()
342 }
343 }
344}
345
346async fn load_examples(
347 http_client: Arc<dyn http_client::HttpClient>,
348 args: &EpArgs,
349 output_path: Option<&PathBuf>,
350 background_executor: BackgroundExecutor,
351) -> anyhow::Result<Vec<Example>> {
352 let mut captured_after_timestamps = Vec::new();
353 let mut file_inputs = Vec::new();
354
355 for input in &args.inputs {
356 let input_string = input.to_string_lossy();
357 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
358 captured_after_timestamps.push(timestamp.to_string());
359 } else {
360 file_inputs.push(input.clone());
361 }
362 }
363
364 let mut examples = read_example_files(&file_inputs);
365
366 Progress::global().set_total_examples(examples.len());
367
368 let remaining_limit_for_snowflake =
369 args.limit.map(|limit| limit.saturating_sub(examples.len()));
370
371 if let Some(0) = remaining_limit_for_snowflake {
372 log::info!(
373 "skipping captured-after inputs because --limit is already satisfied by example files"
374 );
375 } else if !captured_after_timestamps.is_empty() {
376 captured_after_timestamps.sort();
377
378 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
379
380 let mut captured_examples = pull_examples::fetch_captured_examples_after(
381 http_client,
382 &captured_after_timestamps,
383 max_rows_per_timestamp,
384 background_executor,
385 )
386 .await?;
387 examples.append(&mut captured_examples);
388 }
389
390 crate::example::sort_examples_by_repo_and_rev(&mut examples);
391
392 if let Some(name_filter) = &args.name {
393 examples.retain(|example| example.spec.name.contains(name_filter));
394 }
395 if let Some(repo_filter) = &args.repo {
396 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
397 }
398
399 // Skip resume logic for --in-place since input and output are the same file,
400 // which would incorrectly treat all input examples as already processed.
401 if !args.in_place {
402 if let Some(path) = output_path {
403 resume_from_output(path, &mut examples);
404 }
405 }
406
407 if let Some(offset) = args.offset {
408 examples.splice(0..offset, []);
409 }
410
411 if let Some(limit) = args.limit {
412 examples.truncate(limit);
413 }
414
415 let progress = Progress::global();
416 progress.set_total_examples(examples.len());
417 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
418
419 Ok(examples)
420}
421
/// Stable fingerprint of an example spec, used to match input examples against
/// previously-written output lines when resuming (see `resume_from_output`).
// NOTE(review): FxHasher is not cryptographic; collisions would silently merge
// distinct specs — acceptable for resume bookkeeping, confirm if reused elsewhere.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
427
428fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
429 let file = match File::open(path) {
430 Ok(f) => f,
431 Err(_) => return,
432 };
433
434 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
435
436 let reader = BufReader::new(file);
437 let mut kept_lines = Vec::new();
438 let mut kept_hashes = HashSet::default();
439
440 for line in reader.lines() {
441 let line = match line {
442 Ok(l) => l,
443 Err(_) => continue,
444 };
445
446 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
447 let hash = spec_hash(&output_example.spec);
448 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
449 kept_hashes.insert(hash);
450 kept_lines.push(line);
451 }
452 }
453 }
454
455 let total = examples.len();
456 let already_processed = kept_hashes.len();
457
458 eprintln!(
459 "Resuming: {}/{} examples already processed",
460 already_processed, total
461 );
462
463 let file = OpenOptions::new()
464 .write(true)
465 .truncate(true)
466 .open(path)
467 .expect("Failed to open output file for rewriting");
468 let mut writer = BufWriter::new(file);
469 for line in &kept_lines {
470 writeln!(writer, "{}", line).expect("Failed to write to output file");
471 }
472 writer.flush().expect("Failed to flush output file");
473
474 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
475}
476
/// CLI entry point.
///
/// Dispatch happens in two tiers: subcommands that need no project state or
/// GPUI app (import-batch, clean, synthesize, split-commit, split,
/// filter-languages) run synchronously and return early; every other command
/// runs inside a headless GPUI `Application`, which loads examples and
/// processes them with a pool of `--max-parallelism` worker tasks, each
/// claiming one repository's examples at a time.
fn main() {
    let args = EpArgs::parse();

    // --printenv bypasses all command handling.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Tier 1: standalone commands that run without the headless app.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    // Tier 2: example-processing commands, run inside a headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer task; workers send serialized
                // example lines through this channel so file writes never race.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Group per repository so a worker holds one repo (and its
                // worktree/project) at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo's examples; stop when the queue is empty.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the active subcommand.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_) => {
                                            // Handled in tier 1 before the app starts.
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures; handle_error may panic under --failfast.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example unless failed output is being skipped.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo finished: shut down its language servers and
                            // drop cached project state before moving on.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync after all workers finish.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval reporting (stdout table plus optional JSON summary).
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
850
851async fn handle_error(
852 error: anyhow::Error,
853 args: &EpArgs,
854 command: &Command,
855 app_state: &Arc<headless::EpAppState>,
856 failfast_on_single_example: bool,
857 example: &Example,
858) {
859 Progress::global().increment_failed();
860
861 let msg;
862 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
863 let example_name = example.spec.filename();
864
865 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
866 app_state
867 .fs
868 .write(
869 &failed_example_path,
870 &serde_json::to_vec_pretty(&example).unwrap(),
871 )
872 .await
873 .unwrap();
874 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
875 app_state
876 .fs
877 .write(&err_path, format!("{error:?}").as_bytes())
878 .await
879 .unwrap();
880
881 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
882 let mut file = OpenOptions::new()
883 .create(true)
884 .append(true)
885 .open(&failed_jsonl_path)
886 .expect("Failed to open failed.jsonl");
887 writeln!(file, "{}", serde_json::to_string(example).unwrap())
888 .expect("Failed to write to failed.jsonl");
889
890 let cursor_path = example
891 .repo_name()
892 .unwrap()
893 .worktree_path()
894 .join(&example.spec.cursor_path);
895 msg = format!(
896 indoc::indoc! {"
897 While processing \"{}\":
898
899 \x1b[31m{:?}\x1b[0m
900
901 Example: \x1b[36m{}\x1b[0m
902 Error file: \x1b[36m{}\x1b[0m
903 Cursor file: \x1b[36m{}\x1b[0m
904 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
905 "},
906 example.spec.name,
907 error,
908 failed_example_path.display(),
909 err_path.display(),
910 cursor_path.display(),
911 command,
912 failed_example_path.display(),
913 );
914 } else {
915 msg = format!(
916 indoc::indoc! {"
917 While processing \"{}\":
918
919 \x1b[31m{:?}\x1b[0m
920 "},
921 example.spec.name, error
922 );
923 }
924
925 if args.failfast || failfast_on_single_example {
926 Progress::global().finalize();
927 panic!("{}", msg);
928 } else {
929 log::error!("{}", msg);
930 }
931}