1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod parse_output;
11mod paths;
12mod predict;
13mod progress;
14mod pull_examples;
15mod qa;
16mod reorder_patch;
17mod repair;
18mod retrieve_context;
19mod score;
20mod split_commit;
21mod split_dataset;
22mod synthesize;
23mod word_diff;
24use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
25use collections::HashSet;
26use edit_prediction::EditPredictionStore;
27use futures::channel::mpsc;
28use futures::{SinkExt as _, StreamExt as _};
29use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
30use zeta_prompt::ZetaVersion;
31
32use reqwest_client::ReqwestClient;
33use serde::{Deserialize, Deserializer, Serialize, Serializer};
34use std::fmt::Display;
35use std::fs::{File, OpenOptions};
36use std::hash::{Hash, Hasher};
37use std::io::{BufRead, BufReader, BufWriter, Write};
38use std::sync::Mutex;
39use std::{path::PathBuf, sync::Arc};
40
41use crate::distill::run_distill;
42use crate::example::{Example, group_examples_by_repo, read_example_files};
43use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
44use crate::format_prompt::run_format_prompt;
45use crate::load_project::run_load_project;
46use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
47use crate::predict::run_prediction;
48use crate::progress::Progress;
49use crate::retrieve_context::run_context_retrieval;
50use crate::score::run_scoring;
51use crate::split_commit::SplitCommitArgs;
52use crate::split_dataset::SplitArgs;
53use crate::synthesize::{SynthesizeConfig, run_synthesize};
54
// Top-level CLI arguments for the `ep` tool. Most flags are `global`, so they
// may be written before or after the subcommand.
//
// NOTE: field doc comments double as clap help text, so explanatory comments
// here are deliberately `//` (invisible to clap) unless help text is intended.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (see the early return in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repository groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples; also caps Snowflake fetches.
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the sorted, filtered list.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output path; when absent, most commands print JSONL lines to stdout.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (via a temp file + rename).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
}
87
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Unlike Keep/Skip, this variant also suppresses the per-example failure
    // files and failed.jsonl entry written by `handle_error`.
    SkipNoFiles,
}
100
/// Help text for the positional `inputs` argument (surfaced by clap via
/// `help = INPUTS_HELP`). The `captured-after:` / `rejected-after:` specifiers
/// are recognized in `load_examples` through the `pull_examples` parsers.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
136
// Subcommands for `ep`. The `///` doc comments double as clap `--help` text.
//
// Clean, Synthesize, SplitCommit, Split, FilterLanguages, ImportBatch, Qa and
// Repair are dispatched standalone early in `main`; the remaining commands run
// inside the headless-app example pipeline (and hit `unreachable!()` there
// otherwise).
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
176
177impl Display for Command {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 match self {
180 Command::Read => write!(f, "read"),
181 Command::LoadProject => write!(f, "load-project"),
182 Command::Context => write!(f, "context"),
183 Command::FormatPrompt(args) => {
184 write!(f, "format-prompt --provider={}", args.provider)
185 }
186 Command::Predict(args) => match &args.provider {
187 Some(provider) => write!(f, "predict --provider={}", provider),
188 None => write!(f, "predict"),
189 },
190 Command::ParseOutput => write!(f, "parse-output"),
191 Command::Score(args) => match &args.provider {
192 Some(provider) => write!(f, "score --provider={}", provider),
193 None => write!(f, "score"),
194 },
195 Command::Distill => write!(f, "distill"),
196 Command::Eval(args) => match &args.predict.provider {
197 Some(provider) => write!(f, "eval --provider={}", provider),
198 None => write!(f, "eval"),
199 },
200 Command::Synthesize(args) => {
201 write!(f, "synthesize --repos {}", args.repos.join(" "))
202 }
203 Command::Clean => write!(f, "clean"),
204 Command::SplitCommit(_) => write!(f, "split-commit"),
205 Command::Split(_) => write!(f, "split"),
206 Command::FilterLanguages(_) => write!(f, "filter-languages"),
207 Command::ImportBatch(args) => {
208 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
209 }
210 Command::Qa(_) => {
211 write!(f, "qa")
212 }
213 Command::Repair(_) => {
214 write!(f, "repair")
215 }
216 }
217 }
218}
219
// Arguments for `format-prompt`: selects which provider's prompt format to emit.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Defaults to PredictionProvider::default() (Zeta2 at the default version).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
225
// Arguments shared by `predict`, `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; how `None` is resolved is up to the subcommand
    // implementation (see predict/score modules).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
233
// Arguments for `eval`: runs the same pipeline as `score`, then prints an
// aggregated report (and optionally writes a JSON summary).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
242
/// Which backend produces edit predictions.
///
/// Rendered and parsed via `Display`/`FromStr` as `name` or `name:version`
/// (e.g. `zeta2:<version>`); the serde impls round-trip through that string
/// form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2(ZetaVersion),
    Teacher(ZetaVersion),
    TeacherNonBatching(ZetaVersion),
    Repair,
}
253
254impl Default for PredictionProvider {
255 fn default() -> Self {
256 PredictionProvider::Zeta2(ZetaVersion::default())
257 }
258}
259
260impl std::fmt::Display for PredictionProvider {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 match self {
263 PredictionProvider::Sweep => write!(f, "sweep"),
264 PredictionProvider::Mercury => write!(f, "mercury"),
265 PredictionProvider::Zeta1 => write!(f, "zeta1"),
266 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
267 PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
268 PredictionProvider::TeacherNonBatching(version) => {
269 write!(f, "teacher-non-batching:{version}")
270 }
271 PredictionProvider::Repair => write!(f, "repair"),
272 }
273 }
274}
275
276impl std::str::FromStr for PredictionProvider {
277 type Err = anyhow::Error;
278
279 fn from_str(mut s: &str) -> Result<Self, Self::Err> {
280 let mut version = ZetaVersion::default();
281 if let Some((first, second)) = s.split_once(':') {
282 version = ZetaVersion::parse(second)?;
283 s = first;
284 }
285
286 let s_lower = s.to_lowercase();
287 match s_lower.as_str() {
288 "sweep" => Ok(PredictionProvider::Sweep),
289 "mercury" => Ok(PredictionProvider::Mercury),
290 "zeta1" => Ok(PredictionProvider::Zeta1),
291 "zeta2" => Ok(PredictionProvider::Zeta2(version)),
292 "teacher" => Ok(PredictionProvider::Teacher(version)),
293 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
294 Ok(PredictionProvider::TeacherNonBatching(version))
295 }
296 "repair" => Ok(PredictionProvider::Repair),
297 _ => {
298 anyhow::bail!(
299 "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching, repair\n\
300 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
301 Available zeta versions:\n{}",
302 ZetaVersion::options_as_string()
303 )
304 }
305 }
306 }
307}
308
309impl Serialize for PredictionProvider {
310 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
311 where
312 S: Serializer,
313 {
314 serializer.serialize_str(&self.to_string())
315 }
316}
317
318impl<'de> Deserialize<'de> for PredictionProvider {
319 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
320 where
321 D: Deserializer<'de>,
322 {
323 let s = String::deserialize(deserializer)?;
324 s.parse().map_err(serde::de::Error::custom)
325 }
326}
327
// Arguments for `synthesize`: generate eval examples from repository commits.
// The output directory comes from the global `--output` flag (see main()).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
346
// Arguments for `import-batch`: re-import completed Anthropic batches into the
// local LLM cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Anthropic batch IDs to import (e.g., msgbatch_xxx)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
}
353
354impl EpArgs {
355 fn output_path(&self) -> Option<PathBuf> {
356 if self.in_place {
357 if self.inputs.len() == 1 {
358 self.inputs.first().cloned()
359 } else {
360 panic!("--in-place requires exactly one input file")
361 }
362 } else {
363 self.output.clone()
364 }
365 }
366}
367
368async fn load_examples(
369 http_client: Arc<dyn http_client::HttpClient>,
370 args: &EpArgs,
371 output_path: Option<&PathBuf>,
372 background_executor: BackgroundExecutor,
373) -> anyhow::Result<Vec<Example>> {
374 let mut captured_after_timestamps = Vec::new();
375 let mut rejected_after_timestamps = Vec::new();
376 let mut file_inputs = Vec::new();
377
378 for input in &args.inputs {
379 let input_string = input.to_string_lossy();
380 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
381 captured_after_timestamps.push(timestamp.to_string());
382 } else if let Some(timestamp) =
383 pull_examples::parse_rejected_after_input(input_string.as_ref())
384 {
385 rejected_after_timestamps.push(timestamp.to_string());
386 } else {
387 file_inputs.push(input.clone());
388 }
389 }
390
391 let mut examples = read_example_files(&file_inputs);
392
393 Progress::global().set_total_examples(examples.len());
394
395 let remaining_limit_for_snowflake =
396 args.limit.map(|limit| limit.saturating_sub(examples.len()));
397
398 if let Some(0) = remaining_limit_for_snowflake {
399 log::info!(
400 "skipping Snowflake inputs because --limit is already satisfied by example files"
401 );
402 } else {
403 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
404
405 if !captured_after_timestamps.is_empty() {
406 captured_after_timestamps.sort();
407
408 let mut captured_examples = pull_examples::fetch_captured_examples_after(
409 http_client.clone(),
410 &captured_after_timestamps,
411 max_rows_per_timestamp,
412 background_executor.clone(),
413 )
414 .await?;
415 examples.append(&mut captured_examples);
416 }
417
418 if !rejected_after_timestamps.is_empty() {
419 rejected_after_timestamps.sort();
420
421 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
422 http_client,
423 &rejected_after_timestamps,
424 max_rows_per_timestamp,
425 background_executor,
426 )
427 .await?;
428 examples.append(&mut rejected_examples);
429 }
430 }
431
432 crate::example::sort_examples_by_repo_and_rev(&mut examples);
433
434 if let Some(name_filter) = &args.name {
435 examples.retain(|example| example.spec.name.contains(name_filter));
436 }
437 if let Some(repo_filter) = &args.repo {
438 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
439 }
440
441 // Skip resume logic for --in-place since input and output are the same file,
442 // which would incorrectly treat all input examples as already processed.
443 if !args.in_place {
444 if let Some(path) = output_path {
445 resume_from_output(path, &mut examples);
446 }
447 }
448
449 if let Some(offset) = args.offset {
450 examples.splice(0..offset, []);
451 }
452
453 if let Some(limit) = args.limit {
454 examples.truncate(limit);
455 }
456
457 let progress = Progress::global();
458 progress.set_total_examples(examples.len());
459 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
460
461 Ok(examples)
462}
463
464fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
465 let mut hasher = collections::FxHasher::default();
466 spec.hash(&mut hasher);
467 hasher.finish()
468}
469
470fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
471 let file = match File::open(path) {
472 Ok(f) => f,
473 Err(_) => return,
474 };
475
476 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
477
478 let reader = BufReader::new(file);
479 let mut kept_lines = Vec::new();
480 let mut kept_hashes = HashSet::default();
481
482 for line in reader.lines() {
483 let line = match line {
484 Ok(l) => l,
485 Err(_) => continue,
486 };
487
488 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
489 let hash = spec_hash(&output_example.spec);
490 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
491 kept_hashes.insert(hash);
492 kept_lines.push(line);
493 }
494 }
495 }
496
497 let total = examples.len();
498 let already_processed = kept_hashes.len();
499
500 eprintln!(
501 "Resuming: {}/{} examples already processed",
502 already_processed, total
503 );
504
505 let file = OpenOptions::new()
506 .write(true)
507 .truncate(true)
508 .open(path)
509 .expect("Failed to open output file for rewriting");
510 let mut writer = BufWriter::new(file);
511 for line in &kept_lines {
512 writeln!(writer, "{}", line).expect("Failed to write to output file");
513 }
514 writer.flush().expect("Failed to flush output file");
515
516 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
517}
518
/// CLI entry point: parses arguments, dispatches self-contained commands
/// directly, and runs the remaining commands as an example-processing pipeline
/// inside a headless GPUI application.
fn main() {
    let args = EpArgs::parse();

    // Utility mode: print the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();
    // No subcommand: print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the headless app or project loading are handled
    // here and return early; everything else falls through to `app.run` below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client");
                if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                    eprintln!("Error importing batches: {:?}", e);
                    std::process::exit(1);
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the entire data directory (repos, worktrees, run data).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize reuses the global --output flag as its output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): out-of-range splice panics if --offset exceeds
                // the example count; consider clamping as in load_examples.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): same out-of-range splice caveat as the Qa arm.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Remaining commands (read, load-project, context, format-prompt, predict,
    // parse-output, distill, score, eval) run in a headless GPUI application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending provider batches before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example implies failfast so its error surfaces loudly.
                let failfast_on_single_example = examples.len() == 1;

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Single background writer: workers send serialized JSONL lines
                // through this channel so file writes never interleave.
                let output_sender: Option<mpsc::UnboundedSender<String>> =
                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                        write_path.map(|path| {
                            let file = if args.in_place {
                                // For --in-place, write to temp file (truncate if exists)
                                OpenOptions::new()
                                    .create(true)
                                    .write(true)
                                    .truncate(true)
                                    .open(path)
                                    .expect("Failed to open temp output file")
                            } else {
                                // For regular output, append to support resuming
                                OpenOptions::new()
                                    .create(true)
                                    .append(true)
                                    .open(path)
                                    .expect("Failed to open output file")
                            };
                            let mut writer = BufWriter::new(file);
                            let (sender, mut receiver) = mpsc::unbounded::<String>();
                            cx.background_spawn(async move {
                                while let Some(line) = receiver.next().await {
                                    writeln!(writer, "{}", line).expect("Failed to write example");
                                    writer.flush().expect("Failed to flush output");
                                }
                            })
                            .detach();
                            sender
                        })
                    } else {
                        None
                    };

                // Work queue grouped by repository: each worker claims one repo
                // group at a time so a repo's project/worktree has one owner.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Pull the next repo group; exit when the queue drains.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example stage for this command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These were dispatched before app.run.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record the failure (files, progress, message) and
                                // note whether this example failed.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are emitted only with --failed=keep.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No output file: emit JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Tear down this repo's project: shut down language
                            // servers, drop it from the prediction store and the
                            // project cache, and release per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                // Run all workers to completion.
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Final batch sync after all examples have been processed.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints an aggregate report (and optional JSON summary).
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
948
949async fn handle_error(
950 error: anyhow::Error,
951 args: &EpArgs,
952 command: &Command,
953 app_state: &Arc<headless::EpAppState>,
954 failfast_on_single_example: bool,
955 example: &Example,
956) {
957 Progress::global().increment_failed();
958
959 let msg;
960 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
961 let example_name = example.spec.filename();
962
963 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
964 app_state
965 .fs
966 .write(
967 &failed_example_path,
968 &serde_json::to_vec_pretty(&example).unwrap(),
969 )
970 .await
971 .unwrap();
972 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
973 app_state
974 .fs
975 .write(&err_path, format!("{error:?}").as_bytes())
976 .await
977 .unwrap();
978
979 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
980 let mut file = OpenOptions::new()
981 .create(true)
982 .append(true)
983 .open(&failed_jsonl_path)
984 .expect("Failed to open failed.jsonl");
985 writeln!(file, "{}", serde_json::to_string(example).unwrap())
986 .expect("Failed to write to failed.jsonl");
987
988 let cursor_path = example
989 .repo_name()
990 .unwrap()
991 .worktree_path()
992 .join(&example.spec.cursor_path);
993 msg = format!(
994 indoc::indoc! {"
995 While processing \"{}\":
996
997 \x1b[31m{:?}\x1b[0m
998
999 Example: \x1b[36m{}\x1b[0m
1000 Error file: \x1b[36m{}\x1b[0m
1001 Cursor file: \x1b[36m{}\x1b[0m
1002 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1003 "},
1004 example.spec.name,
1005 error,
1006 failed_example_path.display(),
1007 err_path.display(),
1008 cursor_path.display(),
1009 command,
1010 failed_example_path.display(),
1011 );
1012 } else {
1013 msg = format!(
1014 indoc::indoc! {"
1015 While processing \"{}\":
1016
1017 \x1b[31m{:?}\x1b[0m
1018 "},
1019 example.spec.name, error
1020 );
1021 }
1022
1023 if args.failfast || failfast_on_single_example {
1024 Progress::global().finalize();
1025 panic!("{}", msg);
1026 } else {
1027 log::error!("{}", msg);
1028 }
1029}