1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod pull_examples;
16mod qa;
17mod reorder_patch;
18mod repair;
19mod retrieve_context;
20mod reversal_tracking;
21mod score;
22mod split_commit;
23mod split_dataset;
24mod synthesize;
25mod word_diff;
26use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
27use collections::HashSet;
28use edit_prediction::EditPredictionStore;
29use futures::channel::mpsc;
30use futures::{SinkExt as _, StreamExt as _};
31use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
32use zeta_prompt::ZetaVersion;
33
34use reqwest_client::ReqwestClient;
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use std::fmt::Display;
37use std::fs::{File, OpenOptions};
38use std::hash::{Hash, Hasher};
39use std::io::{BufRead, BufReader, BufWriter, Write};
40use std::sync::Mutex;
41use std::{path::PathBuf, sync::Arc};
42
43use crate::distill::run_distill;
44use crate::example::{Example, group_examples_by_repo, read_example_files};
45use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
46use crate::format_prompt::run_format_prompt;
47use crate::load_project::run_load_project;
48use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
49use crate::predict::run_prediction;
50use crate::progress::Progress;
51use crate::retrieve_context::run_context_retrieval;
52use crate::score::run_scoring;
53use crate::split_commit::SplitCommitArgs;
54use crate::split_dataset::SplitArgs;
55use crate::synthesize::{SynthesizeConfig, run_synthesize};
56
// Top-level CLI arguments for the `ep` tool. Most flags are `global` so they
// can appear after any subcommand. NB: `//` comments are used here on purpose —
// `///` doc comments on clap-derive fields become user-visible help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment (via util::shell_env) and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker processes one repo group
    // at a time (see main()).
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Process at most this many examples (also caps Snowflake fetches).
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples before processing.
    // NOTE(review): some call sites splice 0..offset without clamping and will
    // panic if offset exceeds the example count.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Optional subcommand; when absent, main() prints help and exits.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (or directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (writes via a temp file).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
93
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the sentence above is not quite accurate — with `SkipNoFiles`,
// handle_error() skips the failed/-directory and failed.jsonl writes entirely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
106
// Long-form help for the positional `inputs` argument (wired up via
// `#[clap(help = INPUTS_HELP)]` on EpArgs::inputs). This is a runtime string
// shown by --help; the special specifiers documented here are parsed by
// pull_examples::parse_captured_after_input / parse_rejected_after_input.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
142
// The `ep` subcommands. Commands handled by the early-return match in main()
// (import-batch, clean, synthesize, split-commit, split, filter-languages, qa,
// repair) run standalone; all others run through the shared example pipeline.
// (`//` comments here intentionally — variant `///` docs are clap help text.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
}
182
183impl Display for Command {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 match self {
186 Command::Read => write!(f, "read"),
187 Command::LoadProject => write!(f, "load-project"),
188 Command::Context => write!(f, "context"),
189 Command::FormatPrompt(args) => {
190 write!(f, "format-prompt --provider={}", args.provider)
191 }
192 Command::Predict(args) => match &args.provider {
193 Some(provider) => write!(f, "predict --provider={}", provider),
194 None => write!(f, "predict"),
195 },
196 Command::ParseOutput => write!(f, "parse-output"),
197 Command::Score(args) => match &args.provider {
198 Some(provider) => write!(f, "score --provider={}", provider),
199 None => write!(f, "score"),
200 },
201 Command::Distill => write!(f, "distill"),
202 Command::Eval(args) => match &args.predict.provider {
203 Some(provider) => write!(f, "eval --provider={}", provider),
204 None => write!(f, "eval"),
205 },
206 Command::Synthesize(args) => {
207 write!(f, "synthesize --repos {}", args.repos.join(" "))
208 }
209 Command::Clean => write!(f, "clean"),
210 Command::SplitCommit(_) => write!(f, "split-commit"),
211 Command::Split(_) => write!(f, "split"),
212 Command::FilterLanguages(_) => write!(f, "filter-languages"),
213 Command::ImportBatch(args) => {
214 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
215 }
216 Command::Qa(_) => {
217 write!(f, "qa")
218 }
219 Command::Repair(_) => {
220 write!(f, "repair")
221 }
222 }
223 }
224}
225
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to the default
    // Zeta2 version (see `impl Default for PredictionProvider`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
231
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to use; None leaves the choice to the prediction/scoring code
    // (behavior defined in predict.rs / score.rs — not visible here).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}
239
// Arguments for `ep eval`: prediction/scoring options plus report output.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Eval reuses the predict/score flags verbatim.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
248
/// Frontier-model families usable as the "teacher" prediction provider.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (API model "claude-sonnet-4-5").
    Sonnet45,
    /// OpenAI GPT-5.2 (API model "gpt-5.2").
    Gpt52,
}

impl std::fmt::Display for TeacherBackend {
    /// Writes the short lowercase identifier, matching what `FromStr` accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TeacherBackend::Sonnet45 => "sonnet45",
            TeacherBackend::Gpt52 => "gpt52",
        };
        f.write_str(name)
    }
}
263
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend name case-insensitively, accepting several aliases.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Legacy alias — presumably an old provider/version label that
            // mapped to the Sonnet teacher; kept for backward compatibility
            // and deliberately omitted from the error message. TODO confirm.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
        }
    }
}
276
impl TeacherBackend {
    /// The concrete model identifier sent to the provider's API.
    pub fn model_name(&self) -> &'static str {
        match self {
            TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
            TeacherBackend::Gpt52 => "gpt-5.2",
        }
    }
}
285
/// Which backend produces (or scored) predictions. Parsed from strings like
/// `zeta2:<version>` or `teacher:sonnet45` (see `FromStr`) and rendered back
/// in the same syntax by `Display`; serde round-trips through that string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with an explicit prompt version.
    Zeta2(ZetaVersion),
    /// Frontier-model "teacher" (see [`TeacherBackend`]).
    Teacher(TeacherBackend),
    /// Teacher variant — the name suggests it bypasses the batch API;
    /// exact semantics live in predict.rs (TODO confirm).
    TeacherNonBatching(TeacherBackend),
    Repair,
}
296
impl Default for PredictionProvider {
    /// Default provider: Zeta2 at the default prompt version.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaVersion::default())
    }
}
302
303impl std::fmt::Display for PredictionProvider {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 match self {
306 PredictionProvider::Sweep => write!(f, "sweep"),
307 PredictionProvider::Mercury => write!(f, "mercury"),
308 PredictionProvider::Zeta1 => write!(f, "zeta1"),
309 PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
310 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
311 PredictionProvider::TeacherNonBatching(backend) => {
312 write!(f, "teacher-non-batching:{backend}")
313 }
314 PredictionProvider::Repair => write!(f, "repair"),
315 }
316 }
317}
318
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    /// Parses `name` or `name:arg` forms (e.g. `zeta2:ordered`,
    /// `teacher:gpt52`). The provider name is case-insensitive; the optional
    /// arg is forwarded to the version/backend parser.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split an optional `:arg` suffix off the provider name.
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                // Missing arg falls back to the default Zeta version.
                let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(version))
            }
            "teacher" => {
                // Backend defaults to Sonnet when no `:backend` arg is given.
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaVersion::options_as_string()
                )
            }
        }
    }
}
361
362impl Serialize for PredictionProvider {
363 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
364 where
365 S: Serializer,
366 {
367 serializer.serialize_str(&self.to_string())
368 }
369}
370
371impl<'de> Deserialize<'de> for PredictionProvider {
372 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
373 where
374 D: Deserializer<'de>,
375 {
376 let s = String::deserialize(deserializer)?;
377 s.parse().map_err(serde::de::Error::custom)
378 }
379}
380
// Arguments for `ep synthesize` (consumed by run_synthesize via
// SynthesizeConfig in main()).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
399
// Arguments for `ep import-batch` (handled standalone in main()).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
409
// Which LLM batch API an import-batch command targets; selects between the
// Anthropic and OpenAI clients in main().
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
415
416impl EpArgs {
417 fn output_path(&self) -> Option<PathBuf> {
418 if self.in_place {
419 if self.inputs.len() == 1 {
420 self.inputs.first().cloned()
421 } else {
422 panic!("--in-place requires exactly one input file")
423 }
424 } else {
425 self.output.clone()
426 }
427 }
428}
429
430async fn load_examples(
431 http_client: Arc<dyn http_client::HttpClient>,
432 args: &EpArgs,
433 output_path: Option<&PathBuf>,
434 background_executor: BackgroundExecutor,
435) -> anyhow::Result<Vec<Example>> {
436 let mut captured_after_timestamps = Vec::new();
437 let mut rejected_after_timestamps = Vec::new();
438 let mut file_inputs = Vec::new();
439
440 for input in &args.inputs {
441 let input_string = input.to_string_lossy();
442 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
443 captured_after_timestamps.push(timestamp.to_string());
444 } else if let Some(timestamp) =
445 pull_examples::parse_rejected_after_input(input_string.as_ref())
446 {
447 rejected_after_timestamps.push(timestamp.to_string());
448 } else {
449 file_inputs.push(input.clone());
450 }
451 }
452
453 let mut examples = read_example_files(&file_inputs);
454
455 Progress::global().set_total_examples(examples.len());
456
457 let remaining_limit_for_snowflake =
458 args.limit.map(|limit| limit.saturating_sub(examples.len()));
459
460 if let Some(0) = remaining_limit_for_snowflake {
461 log::info!(
462 "skipping Snowflake inputs because --limit is already satisfied by example files"
463 );
464 } else {
465 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
466
467 if !captured_after_timestamps.is_empty() {
468 captured_after_timestamps.sort();
469
470 let mut captured_examples = pull_examples::fetch_captured_examples_after(
471 http_client.clone(),
472 &captured_after_timestamps,
473 max_rows_per_timestamp,
474 background_executor.clone(),
475 )
476 .await?;
477 examples.append(&mut captured_examples);
478 }
479
480 if !rejected_after_timestamps.is_empty() {
481 rejected_after_timestamps.sort();
482
483 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
484 http_client,
485 &rejected_after_timestamps,
486 max_rows_per_timestamp,
487 background_executor,
488 )
489 .await?;
490 examples.append(&mut rejected_examples);
491 }
492 }
493
494 crate::example::sort_examples_by_repo_and_rev(&mut examples);
495
496 if let Some(name_filter) = &args.name {
497 examples.retain(|example| example.spec.name.contains(name_filter));
498 }
499 if let Some(repo_filter) = &args.repo {
500 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
501 }
502
503 // Skip resume logic for --in-place since input and output are the same file,
504 // which would incorrectly treat all input examples as already processed.
505 if !args.in_place {
506 if let Some(path) = output_path {
507 resume_from_output(path, &mut examples);
508 }
509 }
510
511 if let Some(offset) = args.offset {
512 examples.splice(0..offset, []);
513 }
514
515 if let Some(limit) = args.limit {
516 examples.truncate(limit);
517 }
518
519 let progress = Progress::global();
520 progress.set_total_examples(examples.len());
521 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
522
523 Ok(examples)
524}
525
/// Stable identity hash of an example spec, used to match input examples
/// against previously written output lines when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
531
532fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
533 let file = match File::open(path) {
534 Ok(f) => f,
535 Err(_) => return,
536 };
537
538 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
539
540 let reader = BufReader::new(file);
541 let mut kept_lines = Vec::new();
542 let mut kept_hashes = HashSet::default();
543
544 for line in reader.lines() {
545 let line = match line {
546 Ok(l) => l,
547 Err(_) => continue,
548 };
549
550 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
551 let hash = spec_hash(&output_example.spec);
552 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
553 kept_hashes.insert(hash);
554 kept_lines.push(line);
555 }
556 }
557 }
558
559 let total = examples.len();
560 let already_processed = kept_hashes.len();
561
562 eprintln!(
563 "Resuming: {}/{} examples already processed",
564 already_processed, total
565 );
566
567 let file = OpenOptions::new()
568 .write(true)
569 .truncate(true)
570 .open(path)
571 .expect("Failed to open output file for rewriting");
572 let mut writer = BufWriter::new(file);
573 for line in &kept_lines {
574 writeln!(writer, "{}", line).expect("Failed to write to output file");
575 }
576 writer.flush().expect("Failed to flush output file");
577
578 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
579}
580
fn main() {
    let args = EpArgs::parse();

    // --printenv: dump the shell environment and exit immediately.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Effective output target (honors --in-place; see EpArgs::output_path).
    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // Without a subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Standalone commands that don't need the headless app / example pipeline
    // are dispatched here and return early; everything else falls through to
    // the gpui Application below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all cached repos/worktrees under the data directory.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Qa(qa_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters (duplicates the filtering in load_examples, since
            // qa doesn't go through the shared pipeline)
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): panics if --offset exceeds the example count
                // (splice range out of bounds); consider clamping.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Repair(repair_args) => {
            // Read examples from input files
            let mut examples = example::read_example_files(&args.inputs);

            // Apply filters (same duplicated logic as the Qa branch above)
            if let Some(name_filter) = &args.name {
                examples.retain(|e| e.spec.name.contains(name_filter));
            }
            if let Some(repo_filter) = &args.repo {
                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
            }
            if let Some(offset) = args.offset {
                // NOTE(review): same potential out-of-bounds panic as in Qa.
                examples.splice(0..offset, []);
            }
            if let Some(limit) = args.limit {
                examples.truncate(limit);
            }

            smol::block_on(async {
                if let Err(e) =
                    repair::run_repair(&mut examples, repair_args, output.as_ref()).await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        _ => {}
    }

    // Pipeline commands run inside a headless gpui application.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                // Gather inputs, applying filters, resume logic and windowing.
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync outstanding provider batches before processing starts.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // A single example implies failfast so its error surfaces loudly.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                let markdown_output_dir = if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                    Some(dir.clone())
                } else {
                    None
                };

                // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                let in_place_temp_path = if args.in_place {
                    output.as_ref().map(|path| {
                        let mut temp_path = path.clone();
                        temp_path.set_extension("jsonl.tmp");
                        temp_path
                    })
                } else {
                    None
                };

                // Writer task: finished examples are streamed as JSONL lines
                // through a channel to a single background writer.
                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
                {
                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                    write_path.map(|path| {
                        let file = if args.in_place {
                            // For --in-place, write to temp file (truncate if exists)
                            OpenOptions::new()
                                .create(true)
                                .write(true)
                                .truncate(true)
                                .open(path)
                                .expect("Failed to open temp output file")
                        } else {
                            // For regular output, append to support resuming
                            OpenOptions::new()
                                .create(true)
                                .append(true)
                                .open(path)
                                .expect("Failed to open output file")
                        };
                        let mut writer = BufWriter::new(file);
                        let (sender, mut receiver) = mpsc::unbounded::<String>();
                        cx.background_spawn(async move {
                            // Flush per line so partial progress survives a crash.
                            while let Some(line) = receiver.next().await {
                                writeln!(writer, "{}", line).expect("Failed to write example");
                                writer.flush().expect("Failed to flush output");
                            }
                        })
                        .detach();
                        sender
                    })
                } else {
                    None
                };

                // Work queue: examples are grouped by repo so each worker
                // holds one repo's project open for all of its examples.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            // Claim the next repo group, or stop when drained.
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Run the per-example stage for this command.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        // These are handled by the early-return
                                        // match at the top of main().
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::Qa(_)
                                        | Command::Repair(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record/report failures; may panic under --failfast.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Emit the example unless failed output is skipped.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if let Some(ref markdown_dir) = markdown_output_dir {
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: stream JSONL to stdout.
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group finished: shut down its language servers
                            // and drop cached project state to bound memory use.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any provider batches queued during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    _ => (),
                }

                // Eval prints (and optionally writes) the aggregate report.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let (Some(temp_path), Some(final_path)) = (&in_place_temp_path, &output) {
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1047
1048async fn handle_error(
1049 error: anyhow::Error,
1050 args: &EpArgs,
1051 command: &Command,
1052 app_state: &Arc<headless::EpAppState>,
1053 failfast_on_single_example: bool,
1054 example: &Example,
1055) {
1056 Progress::global().increment_failed();
1057
1058 let msg;
1059 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1060 let example_name = example.spec.filename();
1061
1062 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1063 app_state
1064 .fs
1065 .write(
1066 &failed_example_path,
1067 &serde_json::to_vec_pretty(&example).unwrap(),
1068 )
1069 .await
1070 .unwrap();
1071 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1072 app_state
1073 .fs
1074 .write(&err_path, format!("{error:?}").as_bytes())
1075 .await
1076 .unwrap();
1077
1078 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1079 let mut file = OpenOptions::new()
1080 .create(true)
1081 .append(true)
1082 .open(&failed_jsonl_path)
1083 .expect("Failed to open failed.jsonl");
1084 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1085 .expect("Failed to write to failed.jsonl");
1086
1087 let cursor_path = example
1088 .repo_name()
1089 .unwrap()
1090 .worktree_path()
1091 .join(&example.spec.cursor_path);
1092 msg = format!(
1093 indoc::indoc! {"
1094 While processing \"{}\":
1095
1096 \x1b[31m{:?}\x1b[0m
1097
1098 Example: \x1b[36m{}\x1b[0m
1099 Error file: \x1b[36m{}\x1b[0m
1100 Cursor file: \x1b[36m{}\x1b[0m
1101 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1102 "},
1103 example.spec.name,
1104 error,
1105 failed_example_path.display(),
1106 err_path.display(),
1107 cursor_path.display(),
1108 command,
1109 failed_example_path.display(),
1110 );
1111 } else {
1112 msg = format!(
1113 indoc::indoc! {"
1114 While processing \"{}\":
1115
1116 \x1b[31m{:?}\x1b[0m
1117 "},
1118 example.spec.name, error
1119 );
1120 }
1121
1122 if args.failfast || failfast_on_single_example {
1123 Progress::global().finalize();
1124 panic!("{}", msg);
1125 } else {
1126 log::error!("{}", msg);
1127 }
1128}