1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod retrieve_context;
18mod score;
19mod split_commit;
20mod split_dataset;
21mod synthesize;
22mod word_diff;
23use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
24use collections::HashSet;
25use edit_prediction::EditPredictionStore;
26use futures::channel::mpsc;
27use futures::{SinkExt as _, StreamExt as _};
28use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
29use zeta_prompt::ZetaVersion;
30
31use reqwest_client::ReqwestClient;
32use serde::{Deserialize, Deserializer, Serialize, Serializer};
33use std::fmt::Display;
34use std::fs::{File, OpenOptions};
35use std::hash::{Hash, Hasher};
36use std::io::{BufRead, BufReader, BufWriter, Write};
37use std::sync::Mutex;
use std::{
    path::{Path, PathBuf},
    sync::Arc,
};
39
40use crate::distill::run_distill;
41use crate::example::{Example, group_examples_by_repo, read_example_files};
42use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
43use crate::format_prompt::run_format_prompt;
44use crate::load_project::run_load_project;
45use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
46use crate::predict::run_prediction;
47use crate::progress::Progress;
48use crate::retrieve_context::run_context_retrieval;
49use crate::score::run_scoring;
50use crate::split_commit::SplitCommitArgs;
51use crate::split_dataset::SplitArgs;
52use crate::synthesize::{SynthesizeConfig, run_synthesize};
53
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may appear before or after the subcommand.
//
// NOTE: plain `//` comments are used on fields (not `///`) so that clap's
// derived --help output is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent workers in the per-repo pipeline in `main`.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Cap on examples to process; applied after filtering in `load_examples`.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to drop; applied before `limit`.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `captured-after:` specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; with --in-place the single input file is used instead (see `output_path`).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (written via a temp file, then renamed).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Panic on the first failed example instead of logging and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
86
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Exception: `SkipNoFiles` also suppresses the failure artifacts themselves
// (see the `FailedHandling::SkipNoFiles` check in `handle_error`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
99
// Long-form help text attached to the positional `inputs` argument of `EpArgs`.
// Raw string rendered verbatim by clap — keep indentation and blank lines as-is.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.

    You can specify this multiple times and mix it with file inputs.

    Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

    Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Predict from a file
    ep predict examples.jsonl

    # Predict from captured examples after a timestamp
    ep predict captured-after:2025-01-01T00:00:00Z

    # Mix file inputs and captured-after in the same invocation
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
129
// Subcommands of `ep`. The `///` doc comments are clap help text — do not edit
// them casually. Some variants (Clean, Synthesize, SplitCommit, Split,
// FilterLanguages, ImportBatch, Qa) are handled synchronously in `main` before
// the headless app starts; the rest run through the per-example pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
}
167
168impl Display for Command {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 match self {
171 Command::ParseExample => write!(f, "parse-example"),
172 Command::LoadProject => write!(f, "load-project"),
173 Command::Context => write!(f, "context"),
174 Command::FormatPrompt(args) => {
175 write!(f, "format-prompt --provider={}", args.provider)
176 }
177 Command::Predict(args) => match &args.provider {
178 Some(provider) => write!(f, "predict --provider={}", provider),
179 None => write!(f, "predict"),
180 },
181 Command::ParseOutput => write!(f, "parse-output"),
182 Command::Score(args) => match &args.provider {
183 Some(provider) => write!(f, "score --provider={}", provider),
184 None => write!(f, "score"),
185 },
186 Command::Distill => write!(f, "distill"),
187 Command::Eval(args) => match &args.predict.provider {
188 Some(provider) => write!(f, "eval --provider={}", provider),
189 None => write!(f, "eval"),
190 },
191 Command::Synthesize(args) => {
192 write!(f, "synthesize --repos {}", args.repos.join(" "))
193 }
194 Command::Clean => write!(f, "clean"),
195 Command::SplitCommit(_) => write!(f, "split-commit"),
196 Command::Split(_) => write!(f, "split"),
197 Command::FilterLanguages(_) => write!(f, "filter-languages"),
198 Command::ImportBatch(args) => {
199 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
200 }
201 Command::Qa(_) => {
202 write!(f, "qa")
203 }
204 }
205 }
206}
207
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to the default zeta2 provider.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
213
// Arguments shared by the `predict` and `score` subcommands.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; `None` lets downstream code decide (see `run_prediction`).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many times to run prediction per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
221
// Arguments for the `eval` subcommand: prediction options plus report output.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All `predict`/`score` flags are accepted directly on `eval` too.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
230
/// Identifies which model/backend produces edit predictions.
///
/// Parsed from CLI strings of the form `name` or `name:version` (see the
/// `FromStr` impl); `Display` produces the inverse form. Variants carrying a
/// [`ZetaVersion`] accept an optional `:version` suffix.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
}
240
241impl Default for PredictionProvider {
242 fn default() -> Self {
243 PredictionProvider::Zeta2(ZetaVersion::default())
244 }
245}
246
247impl std::fmt::Display for PredictionProvider {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 match self {
250 PredictionProvider::Sweep => write!(f, "sweep"),
251 PredictionProvider::Mercury => write!(f, "mercury"),
252 PredictionProvider::Zeta1 => write!(f, "zeta1"),
253 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
254 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
255 PredictionProvider::TeacherNonBatching(version) => {
256 write!(f, "teacher-non-batching:{version}")
257 }
258 }
259 }
260}
261
262impl std::str::FromStr for PredictionProvider {
263 type Err = anyhow::Error;
264
265 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
266 let mut version = ZetaVersion::default();
267 if let Some((first, second)) = s.split_once(':') {
268 version = ZetaVersion::parse(second)?;
269 s = first;
270 }
271
272 let s_lower = s.to_lowercase();
273 match s_lower.as_str() {
274 "sweep" => Ok(PredictionProvider::Sweep),
275 "mercury" => Ok(PredictionProvider::Mercury),
276 "zeta1" => Ok(PredictionProvider::Zeta1),
277 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
278 "teacher" => Ok(PredictionProvider::Teacher(version)),
279 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
280 Ok(PredictionProvider::TeacherNonBatching(version))
281 }
282 _ => {
283 anyhow::bail!(
284 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
285 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
286 Available zeta versions:\n{}",
287 ZetaVersion::options_as_string()
288 )
289 }
290 }
291 }
292}
293
294impl Serialize for PredictionProvider {
295 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
296 where
297 S: Serializer,
298 {
299 serializer.serialize_str(&self.to_string())
300 }
301}
302
303impl<'de> Deserialize<'de> for PredictionProvider {
304 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
305 where
306 D: Deserializer<'de>,
307 {
308 let s = String::deserialize(deserializer)?;
309 s.parse().map_err(serde::de::Error::custom)
310 }
311}
312
// Arguments for the `synthesize` subcommand (see `run_synthesize`).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
331
// Arguments for the `import-batch` subcommand (handled synchronously in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
338
339impl EpArgs {
340 fn output_path(&self) -> Option<PathBuf> {
341 if self.in_place {
342 if self.inputs.len() == 1 {
343 self.inputs.first().cloned()
344 } else {
345 panic!("--in-place requires exactly one input file")
346 }
347 } else {
348 self.output.clone()
349 }
350 }
351}
352
/// Loads, filters, and orders the examples to process for this run.
///
/// Inputs are split into plain example files and `captured-after:` specifiers;
/// the latter are fetched from Snowflake (skipped entirely if `--limit` is
/// already satisfied by file inputs). Then `--name`/`--repo` filters apply,
/// already-processed examples are removed by comparing against the output file
/// (unless `--in-place`), and finally `--offset`/`--limit` trim the list.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut file_inputs = Vec::new();

    // Partition inputs: `captured-after:{ts}` specifiers vs. ordinary files.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // NOTE(review): the total is set again after filtering below; this early
    // call only reflects file-based examples.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !captured_after_timestamps.is_empty() {
        captured_after_timestamps.sort();

        // With no --limit, cap each timestamp's fetch at 5000 rows.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        let mut captured_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &captured_after_timestamps,
            max_rows_per_timestamp,
            background_executor,
        )
        .await?;
        examples.append(&mut captured_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path {
            resume_from_output(path, &mut examples);
        }
    }

    // NOTE(review): `splice` panics if --offset exceeds the remaining example
    // count — confirm that is acceptable for CLI misuse.
    if let Some(offset) = args.offset {
        examples.splice(0..offset, []);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
428
429fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
430 let mut hasher = collections::FxHasher::default();
431 spec.hash(&mut hasher);
432 hasher.finish()
433}
434
435fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
436 let file = match File::open(path) {
437 Ok(f) => f,
438 Err(_) => return,
439 };
440
441 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
442
443 let reader = BufReader::new(file);
444 let mut kept_lines = Vec::new();
445 let mut kept_hashes = HashSet::default();
446
447 for line in reader.lines() {
448 let line = match line {
449 Ok(l) => l,
450 Err(_) => continue,
451 };
452
453 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
454 let hash = spec_hash(&output_example.spec);
455 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
456 kept_hashes.insert(hash);
457 kept_lines.push(line);
458 }
459 }
460 }
461
462 let total = examples.len();
463 let already_processed = kept_hashes.len();
464
465 eprintln!(
466 "Resuming: {}/{} examples already processed",
467 already_processed, total
468 );
469
470 let file = OpenOptions::new()
471 .write(true)
472 .truncate(true)
473 .open(path)
474 .expect("Failed to open output file for rewriting");
475 let mut writer = BufWriter::new(file);
476 for line in &kept_lines {
477 writeln!(writer, "{}", line).expect("Failed to write to output file");
478 }
479 writer.flush().expect("Failed to flush output file");
480
481 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
482}
483
// Entry point. Parses CLI args, short-circuits the subcommands that need no
// project/app state, then runs the per-example pipeline inside a headless
// GPUI application for everything else.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline are
    // handled synchronously here and return early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // All remaining subcommands run the per-example pipeline below inside a
    // headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Writer task: examples are serialized and sent over a channel
                // so a single background task owns the output file.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Examples are grouped per repository; each worker claims one
                // repo group at a time from this shared queue.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the current subcommand for this example.
                                let result = async {
                                    match &command {
                                        Command::ParseExample => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These variants returned early before
                                        // the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Done with this repo group: shut down its language
                            // servers and drop cached project state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync provider batches again after processing completes.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
884
/// Records a failed example and either panics (failfast) or logs the error.
///
/// Unless `args.failed` is `FailedHandling::SkipNoFiles`, the failing example
/// and its error text are written under `FAILED_EXAMPLES_DIR`, the example is
/// appended to the run's `failed.jsonl`, and the message includes file paths
/// plus a "Re-run" hint built from `command`'s `Display` impl.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Persist the full example and the error for later inspection/re-runs.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Also append to the run-wide failed.jsonl for batch re-processing.
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        // Escape codes: \x1b[31m red (error), \x1b[36m cyan (paths), \x1b[0m reset.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}