1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod repair;
18mod retrieve_context;
19mod reversal_tracking;
20mod score;
21mod split_commit;
22mod split_dataset;
23mod synthesize;
24mod word_diff;
25use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
26use collections::HashSet;
27use edit_prediction::EditPredictionStore;
28use futures::channel::mpsc;
29use futures::{SinkExt as _, StreamExt as _};
30use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
31use zeta_prompt::ZetaVersion;
32
33use reqwest_client::ReqwestClient;
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35use std::fmt::Display;
36use std::fs::{File, OpenOptions};
37use std::hash::{Hash, Hasher};
38use std::io::{BufRead, BufReader, BufWriter, Write};
39use std::sync::Mutex;
40use std::{path::PathBuf, sync::Arc};
41
42use crate::distill::run_distill;
43use crate::example::{Example, group_examples_by_repo, read_example_files};
44use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
45use crate::format_prompt::run_format_prompt;
46use crate::load_project::run_load_project;
47use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
48use crate::predict::run_prediction;
49use crate::progress::Progress;
50use crate::retrieve_context::run_context_retrieval;
51use crate::score::run_scoring;
52use crate::split_commit::SplitCommitArgs;
53use crate::split_dataset::SplitArgs;
54use crate::synthesize::{SynthesizeConfig, run_synthesize};
55
// Top-level CLI arguments for the `ep` tool. Flags marked `global` apply to
// every subcommand; `inputs` accepts file paths or special Snowflake
// specifiers (see INPUTS_HELP). Plain `//` comments are used here so clap's
// generated --help text is unchanged.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (debug helper; see main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks processing repo groups (see main()).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (applied after filtering and offset).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or captured-after:/rejected-after: specifiers.
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; JSONL lines are appended unless --in-place is used.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (written via a temp file, then renamed).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
88
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Note: SkipNoFiles is the exception to the doc above — handle_error() checks
// for it and suppresses the per-example failure files and failed.jsonl entry.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
101
// Long-form help text for the positional `inputs` argument, rendered verbatim
// by clap (wired up via `help = INPUTS_HELP` on EpArgs::inputs). This is a
// user-facing runtime string — keep its formatting intact.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
137
// All `ep` subcommands. In main(), ImportBatch, Clean, Synthesize, SplitCommit,
// Split, FilterLanguages, Qa, and Repair are handled directly and return early;
// the remaining commands run per-example inside the headless app's worker pool.
// (The `///` docs below are clap help text — behavior-bearing, left unchanged.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
177
178impl Display for Command {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 match self {
181 Command::Read => write!(f, "read"),
182 Command::LoadProject => write!(f, "load-project"),
183 Command::Context => write!(f, "context"),
184 Command::FormatPrompt(args) => {
185 write!(f, "format-prompt --provider={}", args.provider)
186 }
187 Command::Predict(args) => match &args.provider {
188 Some(provider) => write!(f, "predict --provider={}", provider),
189 None => write!(f, "predict"),
190 },
191 Command::ParseOutput => write!(f, "parse-output"),
192 Command::Score(args) => match &args.provider {
193 Some(provider) => write!(f, "score --provider={}", provider),
194 None => write!(f, "score"),
195 },
196 Command::Distill => write!(f, "distill"),
197 Command::Eval(args) => match &args.predict.provider {
198 Some(provider) => write!(f, "eval --provider={}", provider),
199 None => write!(f, "eval"),
200 },
201 Command::Synthesize(args) => {
202 write!(f, "synthesize --repos {}", args.repos.join(" "))
203 }
204 Command::Clean => write!(f, "clean"),
205 Command::SplitCommit(_) => write!(f, "split-commit"),
206 Command::Split(_) => write!(f, "split"),
207 Command::FilterLanguages(_) => write!(f, "filter-languages"),
208 Command::ImportBatch(args) => {
209 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
210 }
211 Command::Qa(_) => {
212 write!(f, "qa")
213 }
214 Command::Repair(_) => {
215 write!(f, "repair")
216 }
217 }
218 }
219}
220
// Arguments for `format-prompt`: which provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Defaults to PredictionProvider::default() (zeta2 with default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
226
// Arguments shared by `predict`, `score`, and (flattened) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider, e.g. "zeta2:<version>". Optional; when omitted the
    // Display impl renders the bare command and sync_batches receives None.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction runs per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
234
// Arguments for `eval`: scoring args plus an optional JSON summary output.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Shares the predict/score flags (provider, repetitions).
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
243
// Identifies which model/backend produces predictions. Zeta-family variants
// carry a ZetaVersion selecting the prompt format; parsed from strings like
// "zeta2:<version>" (see FromStr below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
    Repair,
}
254
255impl Default for PredictionProvider {
256 fn default() -> Self {
257 PredictionProvider::Zeta2(ZetaVersion::default())
258 }
259}
260
261impl std::fmt::Display for PredictionProvider {
262 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263 match self {
264 PredictionProvider::Sweep => write!(f, "sweep"),
265 PredictionProvider::Mercury => write!(f, "mercury"),
266 PredictionProvider::Zeta1 => write!(f, "zeta1"),
267 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
268 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
269 PredictionProvider::TeacherNonBatching(version) => {
270 write!(f, "teacher-non-batching:{version}")
271 }
272 PredictionProvider::Repair => write!(f, "repair"),
273 }
274 }
275}
276
277impl std::str::FromStr for PredictionProvider {
278 type Err = anyhow::Error;
279
280 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
281 let mut version = ZetaVersion::default();
282 if let Some((first, second)) = s.split_once(':') {
283 version = ZetaVersion::parse(second)?;
284 s = first;
285 }
286
287 let s_lower = s.to_lowercase();
288 match s_lower.as_str() {
289 "sweep" => Ok(PredictionProvider::Sweep),
290 "mercury" => Ok(PredictionProvider::Mercury),
291 "zeta1" => Ok(PredictionProvider::Zeta1),
292 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
293 "teacher" => Ok(PredictionProvider::Teacher(version)),
294 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
295 Ok(PredictionProvider::TeacherNonBatching(version))
296 }
297 "repair" => Ok(PredictionProvider::Repair),
298 _ => {
299 anyhow::bail!(
300 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
301 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
302 Available zeta versions:\n{}",
303 ZetaVersion::options_as_string()
304 )
305 }
306 }
307 }
308}
309
310impl Serialize for PredictionProvider {
311 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
312 where
313 S: Serializer,
314 {
315 serializer.serialize_str(&self.to_string())
316 }
317}
318
319impl<'de> Deserialize<'de> for PredictionProvider {
320 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
321 where
322 D: Deserializer<'de>,
323 {
324 let s = String::deserialize(deserializer)?;
325 s.parse().map_err(serde::de::Error::custom)
326 }
327}
328
// Arguments for `synthesize`. The output directory comes from the global
// --output flag (main() panics if it is missing for this command).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
347
// Arguments for `import-batch`: the Anthropic batch IDs to fetch and import
// into the local LLM cache database (see main()'s ImportBatch arm).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
354
355impl EpArgs {
356 fn output_path(&self) -> Option<PathBuf> {
357 if self.in_place {
358 if self.inputs.len() == 1 {
359 self.inputs.first().cloned()
360 } else {
361 panic!("--in-place requires exactly one input file")
362 }
363 } else {
364 self.output.clone()
365 }
366 }
367}
368
369async fn load_examples(
370 http_client: Arc<dyn http_client::HttpClient>,
371 args: &EpArgs,
372 output_path: Option<&PathBuf>,
373 background_executor: BackgroundExecutor,
374) -> anyhow::Result<Vec<Example>> {
375 let mut captured_after_timestamps = Vec::new();
376 let mut rejected_after_timestamps = Vec::new();
377 let mut file_inputs = Vec::new();
378
379 for input in &args.inputs {
380 let input_string = input.to_string_lossy();
381 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
382 captured_after_timestamps.push(timestamp.to_string());
383 } else if let Some(timestamp) =
384 pull_examples::parse_rejected_after_input(input_string.as_ref())
385 {
386 rejected_after_timestamps.push(timestamp.to_string());
387 } else {
388 file_inputs.push(input.clone());
389 }
390 }
391
392 let mut examples = read_example_files(&file_inputs);
393
394 Progress::global().set_total_examples(examples.len());
395
396 let remaining_limit_for_snowflake =
397 args.limit.map(|limit| limit.saturating_sub(examples.len()));
398
399 if let Some(0) = remaining_limit_for_snowflake {
400 log::info!(
401 "skipping Snowflake inputs because --limit is already satisfied by example files"
402 );
403 } else {
404 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
405
406 if !captured_after_timestamps.is_empty() {
407 captured_after_timestamps.sort();
408
409 let mut captured_examples = pull_examples::fetch_captured_examples_after(
410 http_client.clone(),
411 &captured_after_timestamps,
412 max_rows_per_timestamp,
413 background_executor.clone(),
414 )
415 .await?;
416 examples.append(&mut captured_examples);
417 }
418
419 if !rejected_after_timestamps.is_empty() {
420 rejected_after_timestamps.sort();
421
422 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
423 http_client,
424 &rejected_after_timestamps,
425 max_rows_per_timestamp,
426 background_executor,
427 )
428 .await?;
429 examples.append(&mut rejected_examples);
430 }
431 }
432
433 crate::example::sort_examples_by_repo_and_rev(&mut examples);
434
435 if let Some(name_filter) = &args.name {
436 examples.retain(|example| example.spec.name.contains(name_filter));
437 }
438 if let Some(repo_filter) = &args.repo {
439 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
440 }
441
442 // Skip resume logic for --in-place since input and output are the same file,
443 // which would incorrectly treat all input examples as already processed.
444 if !args.in_place {
445 if let Some(path) = output_path {
446 resume_from_output(path, &mut examples);
447 }
448 }
449
450 if let Some(offset) = args.offset {
451 examples.splice(0..offset, []);
452 }
453
454 if let Some(limit) = args.limit {
455 examples.truncate(limit);
456 }
457
458 let progress = Progress::global();
459 progress.set_total_examples(examples.len());
460 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
461
462 Ok(examples)
463}
464
465fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
466 let mut hasher = collections::FxHasher::default();
467 spec.hash(&mut hasher);
468 hasher.finish()
469}
470
471fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
472 let file = match File::open(path) {
473 Ok(f) => f,
474 Err(_) => return,
475 };
476
477 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
478
479 let reader = BufReader::new(file);
480 let mut kept_lines = Vec::new();
481 let mut kept_hashes = HashSet::default();
482
483 for line in reader.lines() {
484 let line = match line {
485 Ok(l) => l,
486 Err(_) => continue,
487 };
488
489 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
490 let hash = spec_hash(&output_example.spec);
491 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
492 kept_hashes.insert(hash);
493 kept_lines.push(line);
494 }
495 }
496 }
497
498 let total = examples.len();
499 let already_processed = kept_hashes.len();
500
501 eprintln!(
502 "Resuming: {}/{} examples already processed",
503 already_processed, total
504 );
505
506 let file = OpenOptions::new()
507 .write(true)
508 .truncate(true)
509 .open(path)
510 .expect("Failed to open output file for rewriting");
511 let mut writer = BufWriter::new(file);
512 for line in &kept_lines {
513 writeln!(writer, "{}", line).expect("Failed to write to output file");
514 }
515 writer.flush().expect("Failed to flush output file");
516
517 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
518}
519
/// CLI entry point.
///
/// Flow: parse args; handle `--printenv` and missing-subcommand help; run
/// "simple" subcommands (ImportBatch, Clean, Synthesize, SplitCommit, Split,
/// FilterLanguages, Qa, Repair) directly and return; otherwise start a
/// headless gpui Application and process examples with `max_parallelism`
/// worker tasks, each pulling whole repo groups from a shared queue.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        // Debug helper: dump the captured shell environment and exit.
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            // No subcommand given: print help and exit.
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless app or the example pipeline;
    // each handles its own I/O and returns early.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes the entire data directory (repos, worktrees, caches).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize writes into a directory supplied via the global --output.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            // (same --name/--repo/--offset/--limit handling as load_examples,
            // but without Snowflake inputs or resume logic)
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // All remaining commands run inside the headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example implies the user wants the error surfaced
                // immediately, as if --failfast were set.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Output writer: workers serialize examples and send them over a
                // channel; a single background task owns the file handle.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue: examples grouped by repo so each worker finishes a
                // whole repo (and tears its project down) before moving on.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for the current command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // Handled by the early-return match above.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are still written when --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file configured: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project before taking the next
                            // group: shut down its language servers, drop it from
                            // the prediction store and project cache, and clear
                            // per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush provider batches accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval additionally prints an aggregate report (and optional
                // JSON summary) over all finished examples.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app once the pipeline completes.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
949
/// Records a failed example and either panics (failfast) or logs the error.
///
/// Unless `--failed=skip-no-files`, the failing example and its error are
/// written under FAILED_EXAMPLES_DIR, appended to the run's failed.jsonl, and
/// the message includes the file paths plus a ready-to-paste re-run command.
/// Panics when `--failfast` is set or the run had exactly one example.
async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();

    let msg;
    if !matches!(args.failed, FailedHandling::SkipNoFiles) {
        let example_name = example.spec.filename();

        // Full example JSON, usable directly as a re-run input.
        let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
        app_state
            .fs
            .write(
                &failed_example_path,
                &serde_json::to_vec_pretty(&example).unwrap(),
            )
            .await
            .unwrap();
        // Debug-formatted error text alongside the example.
        let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
        app_state
            .fs
            .write(&err_path, format!("{error:?}").as_bytes())
            .await
            .unwrap();

        // Append to the run-wide failure log (one JSON example per line).
        let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&failed_jsonl_path)
            .expect("Failed to open failed.jsonl");
        writeln!(file, "{}", serde_json::to_string(example).unwrap())
            .expect("Failed to write to failed.jsonl");

        // NOTE(review): repo_name().unwrap() panics if the repo name cannot be
        // derived for this example — confirm that cannot happen here.
        let cursor_path = example
            .repo_name()
            .unwrap()
            .worktree_path()
            .join(&example.spec.cursor_path);
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m

                Example: \x1b[36m{}\x1b[0m
                Error file: \x1b[36m{}\x1b[0m
                Cursor file: \x1b[36m{}\x1b[0m
                Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
            "},
            example.spec.name,
            error,
            failed_example_path.display(),
            err_path.display(),
            cursor_path.display(),
            command,
            failed_example_path.display(),
        );
    } else {
        // skip-no-files: report the error without writing anything to disk.
        msg = format!(
            indoc::indoc! {"
                While processing \"{}\":

                \x1b[31m{:?}\x1b[0m
            "},
            example.spec.name, error
        );
    }

    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}