1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod synthesize;
26mod truncate_expected_patch;
27mod word_diff;
28use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
29use collections::HashSet;
30use edit_prediction::EditPredictionStore;
31use futures::channel::mpsc;
32use futures::{SinkExt as _, StreamExt as _};
33use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
34use zeta_prompt::ZetaFormat;
35
36use reqwest_client::ReqwestClient;
37use serde::{Deserialize, Deserializer, Serialize, Serializer};
38use std::fmt::Display;
39use std::fs::{File, OpenOptions};
40use std::hash::{Hash, Hasher};
41use std::io::{BufRead, BufReader, BufWriter, Write};
42use std::sync::Mutex;
43use std::{path::PathBuf, sync::Arc};
44
45use crate::distill::run_distill;
46use crate::example::{Example, group_examples_by_repo, read_example_files};
47use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
48use crate::format_prompt::run_format_prompt;
49use crate::load_project::run_load_project;
50use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
51use crate::predict::run_prediction;
52use crate::progress::Progress;
53use crate::retrieve_context::run_context_retrieval;
54use crate::score::run_scoring;
55use crate::split_commit::SplitCommitArgs;
56use crate::split_dataset::SplitArgs;
57use crate::synthesize::{SynthesizeConfig, run_synthesize};
58use crate::truncate_expected_patch::TruncatePatchArgs;
59
// Top-level CLI definition for the `ep` tool. Most options are marked
// `global` so they can appear before or after the subcommand.
// NOTE(review): plain `//` comments are used deliberately for the new notes —
// `///` doc comments on clap-derive fields become user-visible help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Dump the shell environment and exit immediately (see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker processes one
    // repository group of examples at a time.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied after filtering/resume).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special Snowflake specifiers
    // (see INPUTS_HELP for the full grammar).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file (or output directory when --markdown is set).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; see EpArgs::output_path.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Abort the whole run on the first failed example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
98
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The variant doc comments below double as `--help` text via clap's ValueEnum derive.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // Also suppresses the per-example failure artifacts in handle_error.
    SkipNoFiles,
}
111
/// Help text attached to the positional `inputs` argument of [`EpArgs`].
/// Documents both plain file paths and the `*-after:{timestamp}` Snowflake
/// specifiers that `load_examples` recognizes.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
     Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
     Fetch captured examples from Snowflake after the given RFC3339 timestamp.
     These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
     Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
     These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
     Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
     These are predictions that users explicitly rated as positive or negative via the
     rate completions modal. Only zeta2 predictions are included.
     - Positive ratings: output becomes expected_patches
     - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
     Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
     Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
169
// All `ep` subcommands. The variant doc comments are clap help text.
// Commands up through FilterLanguages plus ImportBatch/Clean/PrintZetaFormats
// are handled before the gpui app starts (see the early `match` in `main`);
// the rest run through the per-example pipeline.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
213
214impl Display for Command {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 match self {
217 Command::Read => write!(f, "read"),
218 Command::LoadProject => write!(f, "load-project"),
219 Command::Context => write!(f, "context"),
220 Command::FormatPrompt(args) => {
221 write!(f, "format-prompt --provider={}", args.provider)
222 }
223 Command::Predict(args) => match &args.provider {
224 Some(provider) => write!(f, "predict --provider={}", provider),
225 None => write!(f, "predict"),
226 },
227 Command::ParseOutput => write!(f, "parse-output"),
228 Command::Score(args) => match &args.provider {
229 Some(provider) => write!(f, "score --provider={}", provider),
230 None => write!(f, "score"),
231 },
232 Command::Distill => write!(f, "distill"),
233 Command::Eval(args) => match &args.predict.provider {
234 Some(provider) => write!(f, "eval --provider={}", provider),
235 None => write!(f, "eval"),
236 },
237 Command::Synthesize(args) => {
238 write!(f, "synthesize --repos {}", args.repos.join(" "))
239 }
240 Command::Clean => write!(f, "clean"),
241 Command::SplitCommit(_) => write!(f, "split-commit"),
242 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
243 Command::Split(_) => write!(f, "split"),
244 Command::FilterLanguages(_) => write!(f, "filter-languages"),
245 Command::ImportBatch(args) => {
246 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
247 }
248 Command::Qa(_) => {
249 write!(f, "qa")
250 }
251 Command::Repair(_) => {
252 write!(f, "repair")
253 }
254 Command::PrintZetaFormats => {
255 write!(f, "print-zeta-formats")
256 }
257 }
258 }
259}
260
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to generate; defaults to
    // PredictionProvider::default() (zeta2 with the default ZetaFormat).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
266
// Arguments shared by the `predict` and `score` subcommands (and embedded in
// `eval` via EvalArgs).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider override; None lets downstream code decide.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many times to run each prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
277
// Arguments for the `eval` subcommand: prediction options plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // `eval` accepts all of `predict`'s flags via flattening.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
286
/// Which LLM backs the "teacher" prediction provider.
/// Parsed from strings via `FromStr` (e.g. `teacher:sonnet45`) rather than
/// clap's ValueEnum.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (see `model_name`).
    Sonnet45,
    /// OpenAI GPT-5.2 (see `model_name`).
    Gpt52,
}
292
293impl std::fmt::Display for TeacherBackend {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match self {
296 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
297 TeacherBackend::Gpt52 => write!(f, "gpt52"),
298 }
299 }
300}
301
302impl std::str::FromStr for TeacherBackend {
303 type Err = anyhow::Error;
304
305 fn from_str(s: &str) -> Result<Self, Self::Err> {
306 match s.to_lowercase().as_str() {
307 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
308 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
309 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
310 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
311 }
312 }
313}
314
315impl TeacherBackend {
316 pub fn model_name(&self) -> &'static str {
317 match self {
318 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
319 TeacherBackend::Gpt52 => "gpt-5.2",
320 }
321 }
322}
323
/// The engine used to produce (or post-process) edit predictions.
/// Serialized/deserialized through its Display/FromStr string form
/// (see the Serialize/Deserialize impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta2 carries a prompt format variant (e.g. `zeta2:ordered`).
    Zeta2(ZetaFormat),
    // Batched teacher-model predictions.
    Teacher(TeacherBackend),
    // Same teacher models but issuing direct (non-batched) requests.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
334
335impl Default for PredictionProvider {
336 fn default() -> Self {
337 PredictionProvider::Zeta2(ZetaFormat::default())
338 }
339}
340
341impl std::fmt::Display for PredictionProvider {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 match self {
344 PredictionProvider::Sweep => write!(f, "sweep"),
345 PredictionProvider::Mercury => write!(f, "mercury"),
346 PredictionProvider::Zeta1 => write!(f, "zeta1"),
347 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
348 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
349 PredictionProvider::TeacherNonBatching(backend) => {
350 write!(f, "teacher-non-batching:{backend}")
351 }
352 PredictionProvider::Repair => write!(f, "repair"),
353 }
354 }
355}
356
impl std::str::FromStr for PredictionProvider {
    type Err = anyhow::Error;

    /// Parses `provider` or `provider:arg` strings.
    ///
    /// The provider name is matched case-insensitively; the part after `:`
    /// (zeta format or teacher backend) is forwarded to the respective
    /// parser unchanged. Missing args fall back to defaults
    /// (`ZetaFormat::default()` / `TeacherBackend::Sonnet45`).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split only on the first ':' so the arg itself may be arbitrary.
        let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));

        let provider_lower = provider.to_lowercase();
        match provider_lower.as_str() {
            "sweep" => Ok(PredictionProvider::Sweep),
            "mercury" => Ok(PredictionProvider::Mercury),
            "zeta1" => Ok(PredictionProvider::Zeta1),
            "zeta2" => {
                let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
                Ok(PredictionProvider::Zeta2(format))
            }
            "teacher" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::Teacher(backend))
            }
            // Several spellings accepted for the non-batching variant.
            "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
                let backend = arg
                    .map(|a| a.parse())
                    .transpose()?
                    .unwrap_or(TeacherBackend::Sonnet45);
                Ok(PredictionProvider::TeacherNonBatching(backend))
            }
            "repair" => Ok(PredictionProvider::Repair),
            _ => {
                anyhow::bail!(
                    "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
                    For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
                    For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
                    Available zeta versions:\n{}",
                    ZetaFormat::options_as_string()
                )
            }
        }
    }
}
399
400impl Serialize for PredictionProvider {
401 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
402 where
403 S: Serializer,
404 {
405 serializer.serialize_str(&self.to_string())
406 }
407}
408
409impl<'de> Deserialize<'de> for PredictionProvider {
410 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
411 where
412 D: Deserializer<'de>,
413 {
414 let s = String::deserialize(deserializer)?;
415 s.parse().map_err(serde::de::Error::custom)
416 }
417}
418
// Arguments for the `synthesize` subcommand; converted into a
// SynthesizeConfig in `main` (output_dir comes from the global -o flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
437
// Arguments for the `import-batch` subcommand (handled entirely in `main`
// before the app starts).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
447
// LLM API whose batch results `import-batch` should fetch; selects between
// the Anthropic and OpenAI clients in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
453
454impl EpArgs {
455 fn output_path(&self) -> Option<PathBuf> {
456 if self.in_place {
457 if self.inputs.len() == 1 {
458 self.inputs.first().cloned()
459 } else {
460 panic!("--in-place requires exactly one input file")
461 }
462 } else {
463 self.output.clone()
464 }
465 }
466}
467
468async fn load_examples(
469 http_client: Arc<dyn http_client::HttpClient>,
470 args: &EpArgs,
471 output_path: Option<&PathBuf>,
472 background_executor: BackgroundExecutor,
473) -> anyhow::Result<Vec<Example>> {
474 let mut captured_after_timestamps = Vec::new();
475 let mut rejected_after_timestamps = Vec::new();
476 let mut requested_after_timestamps = Vec::new();
477 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
478 Vec::new();
479 let mut file_inputs = Vec::new();
480
481 for input in &args.inputs {
482 let input_string = input.to_string_lossy();
483 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
484 captured_after_timestamps.push(timestamp.to_string());
485 } else if let Some(timestamp) =
486 pull_examples::parse_rejected_after_input(input_string.as_ref())
487 {
488 rejected_after_timestamps.push(timestamp.to_string());
489 } else if let Some(timestamp) =
490 pull_examples::parse_requested_after_input(input_string.as_ref())
491 {
492 requested_after_timestamps.push(timestamp.to_string());
493 } else if let Some((timestamp, rating_filter)) =
494 pull_examples::parse_rated_after_input(input_string.as_ref())
495 {
496 rated_after_inputs.push((timestamp.to_string(), rating_filter));
497 } else {
498 file_inputs.push(input.clone());
499 }
500 }
501
502 let mut examples = read_example_files(&file_inputs);
503
504 Progress::global().set_total_examples(examples.len());
505
506 let remaining_limit_for_snowflake =
507 args.limit.map(|limit| limit.saturating_sub(examples.len()));
508
509 if let Some(0) = remaining_limit_for_snowflake {
510 log::info!(
511 "skipping Snowflake inputs because --limit is already satisfied by example files"
512 );
513 } else {
514 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
515
516 if !captured_after_timestamps.is_empty() {
517 captured_after_timestamps.sort();
518
519 let mut captured_examples = pull_examples::fetch_captured_examples_after(
520 http_client.clone(),
521 &captured_after_timestamps,
522 max_rows_per_timestamp,
523 background_executor.clone(),
524 )
525 .await?;
526 examples.append(&mut captured_examples);
527 }
528
529 if !rejected_after_timestamps.is_empty() {
530 rejected_after_timestamps.sort();
531
532 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
533 http_client.clone(),
534 &rejected_after_timestamps,
535 max_rows_per_timestamp,
536 background_executor.clone(),
537 )
538 .await?;
539 examples.append(&mut rejected_examples);
540 }
541
542 if !requested_after_timestamps.is_empty() {
543 requested_after_timestamps.sort();
544
545 let mut requested_examples = pull_examples::fetch_requested_examples_after(
546 http_client.clone(),
547 &requested_after_timestamps,
548 max_rows_per_timestamp,
549 background_executor.clone(),
550 )
551 .await?;
552 examples.append(&mut requested_examples);
553 }
554
555 if !rated_after_inputs.is_empty() {
556 rated_after_inputs.sort();
557
558 let mut rated_examples = pull_examples::fetch_rated_examples_after(
559 http_client,
560 &rated_after_inputs,
561 max_rows_per_timestamp,
562 background_executor,
563 )
564 .await?;
565 examples.append(&mut rated_examples);
566 }
567 }
568
569 crate::example::sort_examples_by_repo_and_rev(&mut examples);
570
571 if let Some(name_filter) = &args.name {
572 examples.retain(|example| example.spec.name.contains(name_filter));
573 }
574 if let Some(repo_filter) = &args.repo {
575 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
576 }
577
578 // Skip resume logic for --in-place since input and output are the same file,
579 // which would incorrectly treat all input examples as already processed.
580 if !args.in_place {
581 if let Some(path) = output_path
582 && let Some(command) = &args.command
583 {
584 resume_from_output(path, &mut examples, command);
585 }
586 }
587
588 if let Some(offset) = args.offset {
589 examples.splice(0..offset, []);
590 }
591
592 if let Some(limit) = args.limit {
593 examples.truncate(limit);
594 }
595
596 let progress = Progress::global();
597 progress.set_total_examples(examples.len());
598 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
599
600 Ok(examples)
601}
602
603fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
604 let mut hasher = collections::FxHasher::default();
605 spec.hash(&mut hasher);
606 hasher.finish()
607}
608
609fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
610 let file = match File::open(path) {
611 Ok(f) => f,
612 Err(_) => return,
613 };
614
615 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
616
617 let reader = BufReader::new(file);
618 let mut kept_lines = Vec::new();
619 let mut kept_hashes = HashSet::default();
620
621 for line in reader.lines() {
622 let line = match line {
623 Ok(l) => l,
624 Err(_) => continue,
625 };
626
627 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
628 let hash = spec_hash(&output_example.spec);
629 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
630 let is_complete = match command {
631 Command::Qa(_) => output_example
632 .qa
633 .first()
634 .and_then(|q| q.as_ref())
635 .and_then(|q| q.confidence)
636 .is_some(),
637 Command::Repair(_) => output_example.predictions.iter().any(|p| {
638 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
639 }),
640 _ => true,
641 };
642 if is_complete {
643 kept_hashes.insert(hash);
644 kept_lines.push(line);
645 }
646 }
647 }
648 }
649
650 let total = examples.len();
651 let already_processed = kept_hashes.len();
652
653 eprintln!(
654 "Resuming: {}/{} examples already processed",
655 already_processed, total
656 );
657
658 let file = OpenOptions::new()
659 .write(true)
660 .truncate(true)
661 .open(path)
662 .expect("Failed to open output file for rewriting");
663 let mut writer = BufWriter::new(file);
664 for line in &kept_lines {
665 writeln!(writer, "{}", line).expect("Failed to write to output file");
666 }
667 writer.flush().expect("Failed to flush output file");
668
669 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
670}
671
fn main() {
    let args = EpArgs::parse();

    // Debug helper: dump the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Commands that don't need the gpui app / example pipeline are handled
    // here and return early; everything else falls through to `app.run`.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes the entire data dir (repos, worktrees, caches).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesize reuses the global -o flag as its output directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Remaining commands run inside a headless gpui application so they can
    // use projects, language servers, and background executors.
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pre-pass: pull down any already-finished batch results
                // before processing begins.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // A single-example run fails fast even without --failfast.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the input at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background writer task serializes all example
                    // lines arriving from the worker tasks via a channel.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: examples grouped by repository so each repo's
                // worktree is only touched by one worker at a time.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Spawn `max_parallelism` cooperative workers; each pops a
                // repo group, processes its examples sequentially, then moves
                // on to the next group until the queue is empty.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the current command against this
                                // example; any error is captured below.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These were all handled (and returned
                                        // from) before the app started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only emitted when
                                // --failed=keep (the default).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o given: stream JSONL to stdout
                                        // (except for eval, which prints a
                                        // report at the end instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Done with this repository: shut down its
                            // language servers and drop cached project state
                            // so resources are freed before the next group.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Post-pass: flush/collect batch state after all work is done.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Summary reports for commands that aggregate over the run.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1111
1112async fn handle_error(
1113 error: anyhow::Error,
1114 args: &EpArgs,
1115 command: &Command,
1116 app_state: &Arc<headless::EpAppState>,
1117 failfast_on_single_example: bool,
1118 example: &Example,
1119) {
1120 Progress::global().increment_failed();
1121
1122 let msg;
1123 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1124 let example_name = example.spec.filename();
1125
1126 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1127 app_state
1128 .fs
1129 .write(
1130 &failed_example_path,
1131 &serde_json::to_vec_pretty(&example).unwrap(),
1132 )
1133 .await
1134 .unwrap();
1135 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1136 app_state
1137 .fs
1138 .write(&err_path, format!("{error:?}").as_bytes())
1139 .await
1140 .unwrap();
1141
1142 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1143 let mut file = OpenOptions::new()
1144 .create(true)
1145 .append(true)
1146 .open(&failed_jsonl_path)
1147 .expect("Failed to open failed.jsonl");
1148 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1149 .expect("Failed to write to failed.jsonl");
1150
1151 let cursor_path = match example.repo_name() {
1152 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1153 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1154 };
1155 msg = format!(
1156 indoc::indoc! {"
1157 While processing \"{}\":
1158
1159 \x1b[31m{:?}\x1b[0m
1160
1161 Example: \x1b[36m{}\x1b[0m
1162 Error file: \x1b[36m{}\x1b[0m
1163 Cursor file: \x1b[36m{}\x1b[0m
1164 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1165 "},
1166 example.spec.name,
1167 error,
1168 failed_example_path.display(),
1169 err_path.display(),
1170 cursor_path.display(),
1171 command,
1172 failed_example_path.display(),
1173 );
1174 } else {
1175 msg = format!(
1176 indoc::indoc! {"
1177 While processing \"{}\":
1178
1179 \x1b[31m{:?}\x1b[0m
1180 "},
1181 example.spec.name, error
1182 );
1183 }
1184
1185 if args.failfast || failfast_on_single_example {
1186 Progress::global().finalize();
1187 panic!("{}", msg);
1188 } else {
1189 log::error!("{}", msg);
1190 }
1191}