1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod sync_deployments;
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::HashSet;
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use zeta_prompt::ZetaFormat;
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI arguments for the `ep` tool. Most flags are `global = true`
// so they may appear before or after the subcommand.
//
// NOTE: new comments here are deliberately `//` rather than `///` — clap
// turns doc comments into `--help` text, and this block must not change the
// CLI's user-visible help output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repository groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples from the front of the loaded dataset.
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; `output_path()` then returns the
    // input path and ignores -o.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): consumed by error handling outside this view (handle_error
    // takes the full args) — presumably aborts on first failure; confirm there.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
// Long-form help text attached to the positional `inputs` argument of EpArgs.
// Documents both plain file inputs and the Snowflake "…-after:{timestamp}"
// specifiers handled in `load_examples` / `pull_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// All `ep` subcommands. Variants carrying an Args struct take extra flags;
// the rest rely solely on the global EpArgs flags. (`//` is used for this
// header so clap's generated help is unchanged.)
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
216
217impl Display for Command {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Command::Read => write!(f, "read"),
221 Command::LoadProject => write!(f, "load-project"),
222 Command::Context => write!(f, "context"),
223 Command::FormatPrompt(args) => {
224 write!(f, "format-prompt --provider={}", args.provider)
225 }
226 Command::Predict(args) => match &args.provider {
227 Some(provider) => write!(f, "predict --provider={}", provider),
228 None => write!(f, "predict"),
229 },
230 Command::ParseOutput => write!(f, "parse-output"),
231 Command::Score(args) => match &args.provider {
232 Some(provider) => write!(f, "score --provider={}", provider),
233 None => write!(f, "score"),
234 },
235 Command::Distill => write!(f, "distill"),
236 Command::Eval(args) => match &args.predict.provider {
237 Some(provider) => write!(f, "eval --provider={}", provider),
238 None => write!(f, "eval"),
239 },
240 Command::Synthesize(args) => {
241 write!(f, "synthesize --repos {}", args.repos.join(" "))
242 }
243 Command::Clean => write!(f, "clean"),
244 Command::SplitCommit(_) => write!(f, "split-commit"),
245 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
246 Command::Split(_) => write!(f, "split"),
247 Command::FilterLanguages(_) => write!(f, "filter-languages"),
248 Command::ImportBatch(args) => {
249 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
250 }
251 Command::Qa(_) => {
252 write!(f, "qa")
253 }
254 Command::Repair(_) => {
255 write!(f, "repair")
256 }
257 Command::PrintZetaFormats => {
258 write!(f, "print-zeta-formats")
259 }
260 Command::SyncDeployments(_) => {
261 write!(f, "sync-deployments")
262 }
263 }
264 }
265}
266
// Flags for the `sync-deployments` subcommand.
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    #[arg(long)]
    model: Option<String>,
}
273
// Flags for the `format-prompt` subcommand; the provider picks which model's
// prompt format is produced.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
// Flags shared by the `predict` and `score` subcommands (and embedded in
// `eval` via EvalArgs).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction backend; None lets downstream code pick (see Display impl,
    // which omits --provider when unset).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to request per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
// Flags for the `eval` subcommand: prediction flags plus an optional JSON
// summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
299
/// LLM backends usable as the "teacher" prediction provider
/// (see `PredictionProvider::Teacher` / `TeacherNonBatching`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 — model id given by `model_name`.
    Sonnet45,
    /// OpenAI GPT-5.2 — model id given by `model_name`.
    Gpt52,
}
305
306impl std::fmt::Display for TeacherBackend {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 match self {
309 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
310 TeacherBackend::Gpt52 => write!(f, "gpt52"),
311 }
312 }
313}
314
315impl std::str::FromStr for TeacherBackend {
316 type Err = anyhow::Error;
317
318 fn from_str(s: &str) -> Result<Self, Self::Err> {
319 match s.to_lowercase().as_str() {
320 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
321 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
322 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
323 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
324 }
325 }
326}
327
328impl TeacherBackend {
329 pub fn model_name(&self) -> &'static str {
330 match self {
331 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
332 TeacherBackend::Gpt52 => "gpt-5.2",
333 }
334 }
335}
336
/// The prediction backend an example's prompt/prediction/score is tied to.
/// Parsed from strings like `zeta2:ordered` or `teacher:gpt52` (see `FromStr`)
/// and serialized back to the same form (see `Display`/`Serialize`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta 2 with an explicit prompt format.
    Zeta2(ZetaFormat),
    /// Large "teacher" model, using the providers' batch APIs.
    Teacher(TeacherBackend),
    /// Teacher model via direct (non-batch) requests.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the `repair` subcommand.
    Repair,
}
347
impl Default for PredictionProvider {
    /// Defaults to Zeta 2 with the default prompt format.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
353
354impl std::fmt::Display for PredictionProvider {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 match self {
357 PredictionProvider::Sweep => write!(f, "sweep"),
358 PredictionProvider::Mercury => write!(f, "mercury"),
359 PredictionProvider::Zeta1 => write!(f, "zeta1"),
360 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
361 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
362 PredictionProvider::TeacherNonBatching(backend) => {
363 write!(f, "teacher-non-batching:{backend}")
364 }
365 PredictionProvider::Repair => write!(f, "repair"),
366 }
367 }
368}
369
370impl std::str::FromStr for PredictionProvider {
371 type Err = anyhow::Error;
372
373 fn from_str(s: &str) -> Result<Self, Self::Err> {
374 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
375
376 let provider_lower = provider.to_lowercase();
377 match provider_lower.as_str() {
378 "sweep" => Ok(PredictionProvider::Sweep),
379 "mercury" => Ok(PredictionProvider::Mercury),
380 "zeta1" => Ok(PredictionProvider::Zeta1),
381 "zeta2" => {
382 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
383 Ok(PredictionProvider::Zeta2(format))
384 }
385 "teacher" => {
386 let backend = arg
387 .map(|a| a.parse())
388 .transpose()?
389 .unwrap_or(TeacherBackend::Sonnet45);
390 Ok(PredictionProvider::Teacher(backend))
391 }
392 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
393 let backend = arg
394 .map(|a| a.parse())
395 .transpose()?
396 .unwrap_or(TeacherBackend::Sonnet45);
397 Ok(PredictionProvider::TeacherNonBatching(backend))
398 }
399 "repair" => Ok(PredictionProvider::Repair),
400 _ => {
401 anyhow::bail!(
402 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
403 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
404 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
405 Available zeta versions:\n{}",
406 ZetaFormat::options_as_string()
407 )
408 }
409 }
410 }
411}
412
413impl Serialize for PredictionProvider {
414 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
415 where
416 S: Serializer,
417 {
418 serializer.serialize_str(&self.to_string())
419 }
420}
421
422impl<'de> Deserialize<'de> for PredictionProvider {
423 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
424 where
425 D: Deserializer<'de>,
426 {
427 let s = String::deserialize(deserializer)?;
428 s.parse().map_err(serde::de::Error::custom)
429 }
430}
431
// Flags for the `synthesize` subcommand (example generation from git history).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
450
// Flags for the `import-batch` subcommand.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
460
// Which LLM provider's batch API `import-batch` talks to.
// (`//` comments only: `///` would alter clap's value help output.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
466
467impl EpArgs {
468 fn output_path(&self) -> Option<PathBuf> {
469 if self.in_place {
470 if self.inputs.len() == 1 {
471 self.inputs.first().cloned()
472 } else {
473 panic!("--in-place requires exactly one input file")
474 }
475 } else {
476 self.output.clone()
477 }
478 }
479}
480
481async fn load_examples(
482 http_client: Arc<dyn http_client::HttpClient>,
483 args: &EpArgs,
484 output_path: Option<&PathBuf>,
485 background_executor: BackgroundExecutor,
486) -> anyhow::Result<Vec<Example>> {
487 let mut captured_after_timestamps = Vec::new();
488 let mut rejected_after_timestamps = Vec::new();
489 let mut requested_after_timestamps = Vec::new();
490 let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
491 Vec::new();
492 let mut file_inputs = Vec::new();
493
494 for input in &args.inputs {
495 let input_string = input.to_string_lossy();
496 if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
497 captured_after_timestamps.push(timestamp.to_string());
498 } else if let Some(timestamp) =
499 pull_examples::parse_rejected_after_input(input_string.as_ref())
500 {
501 rejected_after_timestamps.push(timestamp.to_string());
502 } else if let Some(timestamp) =
503 pull_examples::parse_requested_after_input(input_string.as_ref())
504 {
505 requested_after_timestamps.push(timestamp.to_string());
506 } else if let Some((timestamp, rating_filter)) =
507 pull_examples::parse_rated_after_input(input_string.as_ref())
508 {
509 rated_after_inputs.push((timestamp.to_string(), rating_filter));
510 } else {
511 file_inputs.push(input.clone());
512 }
513 }
514
515 let mut examples = read_example_files(&file_inputs);
516
517 Progress::global().set_total_examples(examples.len());
518
519 let remaining_limit_for_snowflake =
520 args.limit.map(|limit| limit.saturating_sub(examples.len()));
521
522 if let Some(0) = remaining_limit_for_snowflake {
523 log::info!(
524 "skipping Snowflake inputs because --limit is already satisfied by example files"
525 );
526 } else {
527 let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
528
529 if !captured_after_timestamps.is_empty() {
530 captured_after_timestamps.sort();
531
532 let mut captured_examples = pull_examples::fetch_captured_examples_after(
533 http_client.clone(),
534 &captured_after_timestamps,
535 max_rows_per_timestamp,
536 background_executor.clone(),
537 )
538 .await?;
539 examples.append(&mut captured_examples);
540 }
541
542 if !rejected_after_timestamps.is_empty() {
543 rejected_after_timestamps.sort();
544
545 let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
546 http_client.clone(),
547 &rejected_after_timestamps,
548 max_rows_per_timestamp,
549 background_executor.clone(),
550 )
551 .await?;
552 examples.append(&mut rejected_examples);
553 }
554
555 if !requested_after_timestamps.is_empty() {
556 requested_after_timestamps.sort();
557
558 let mut requested_examples = pull_examples::fetch_requested_examples_after(
559 http_client.clone(),
560 &requested_after_timestamps,
561 max_rows_per_timestamp,
562 background_executor.clone(),
563 )
564 .await?;
565 examples.append(&mut requested_examples);
566 }
567
568 if !rated_after_inputs.is_empty() {
569 rated_after_inputs.sort();
570
571 let mut rated_examples = pull_examples::fetch_rated_examples_after(
572 http_client,
573 &rated_after_inputs,
574 max_rows_per_timestamp,
575 background_executor,
576 )
577 .await?;
578 examples.append(&mut rated_examples);
579 }
580 }
581
582 crate::example::sort_examples_by_repo_and_rev(&mut examples);
583
584 if let Some(name_filter) = &args.name {
585 examples.retain(|example| example.spec.name.contains(name_filter));
586 }
587 if let Some(repo_filter) = &args.repo {
588 examples.retain(|example| example.spec.repository_url.contains(repo_filter));
589 }
590
591 // Skip resume logic for --in-place since input and output are the same file,
592 // which would incorrectly treat all input examples as already processed.
593 if !args.in_place {
594 if let Some(path) = output_path
595 && let Some(command) = &args.command
596 {
597 resume_from_output(path, &mut examples, command);
598 }
599 }
600
601 if let Some(offset) = args.offset {
602 examples.splice(0..offset, []);
603 }
604
605 if let Some(limit) = args.limit {
606 examples.truncate(limit);
607 }
608
609 let progress = Progress::global();
610 progress.set_total_examples(examples.len());
611 progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));
612
613 Ok(examples)
614}
615
616fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
617 let mut hasher = collections::FxHasher::default();
618 spec.hash(&mut hasher);
619 hasher.finish()
620}
621
622fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
623 let file = match File::open(path) {
624 Ok(f) => f,
625 Err(_) => return,
626 };
627
628 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
629
630 let reader = BufReader::new(file);
631 let mut kept_lines = Vec::new();
632 let mut kept_hashes = HashSet::default();
633
634 for line in reader.lines() {
635 let line = match line {
636 Ok(l) => l,
637 Err(_) => continue,
638 };
639
640 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
641 let hash = spec_hash(&output_example.spec);
642 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
643 let is_complete = match command {
644 Command::Qa(_) => output_example
645 .qa
646 .first()
647 .and_then(|q| q.as_ref())
648 .and_then(|q| q.confidence)
649 .is_some(),
650 Command::Repair(_) => output_example.predictions.iter().any(|p| {
651 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
652 }),
653 _ => true,
654 };
655 if is_complete {
656 kept_hashes.insert(hash);
657 kept_lines.push(line);
658 }
659 }
660 }
661 }
662
663 let total = examples.len();
664 let already_processed = kept_hashes.len();
665
666 eprintln!(
667 "Resuming: {}/{} examples already processed",
668 already_processed, total
669 );
670
671 let file = OpenOptions::new()
672 .write(true)
673 .truncate(true)
674 .open(path)
675 .expect("Failed to open output file for rewriting");
676 let mut writer = BufWriter::new(file);
677 for line in &kept_lines {
678 writeln!(writer, "{}", line).expect("Failed to write to output file");
679 }
680 writer.flush().expect("Failed to flush output file");
681
682 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
683}
684
/// CLI entry point.
///
/// Flow:
/// 1. Parse args; handle `--printenv` and the no-subcommand case.
/// 2. Dispatch "standalone" subcommands (batch import, clean, format listing,
///    deployment sync, synthesis, dataset tools) synchronously and return —
///    these never start the headless app.
/// 3. For everything else, boot a headless GPUI `Application`, load the
///    examples, and process them with `max_parallelism` worker tasks that
///    each pull whole repository groups off a shared queue, streaming results
///    to stdout, a JSONL file, or per-example markdown files.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // No subcommand given: print help and exit successfully.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Standalone subcommands — each arm runs to completion and returns
    // without touching the example-processing machinery below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Deletes the whole data dir (cloned repos, worktrees, caches).
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }
        Command::SyncDeployments(sync_args) => {
            let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
            smol::block_on(async {
                if let Err(e) =
                    sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
                        .await
                {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::Synthesize(synth_args) => {
            // Synthesis writes example files into a directory, so -o is
            // mandatory here (unlike the streaming commands below).
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        // Remaining commands fall through to the headless-app pipeline.
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Pull down any pending batch results before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the input once the run succeeds.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background task serializes all writes so the
                    // worker tasks never contend on the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue: whole repository groups, so each worker owns a
                // repo's worktree/project for the duration of its examples.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the current subcommand for this
                                // example; standalone commands were handled
                                // earlier, hence the unreachable!() arm.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats
                                        | Command::SyncDeployments(_) => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples only reach the main output
                                // when --failed=keep (the default).
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout
                                        // (except eval, which prints a report).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Repo group done: shut down its language servers,
                            // drop it from the prediction store and project
                            // cache, and release per-example state.
                            let repo_url = &repo_examples.first().unwrap().spec.repository_url;
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()))
                                .or_else(|| app_state.project_cache.get(repo_url));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            app_state.project_cache.remove(repo_url);
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush any batch requests queued during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Post-run reports for the commands that aggregate results.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1138
1139async fn handle_error(
1140 error: anyhow::Error,
1141 args: &EpArgs,
1142 command: &Command,
1143 app_state: &Arc<headless::EpAppState>,
1144 failfast_on_single_example: bool,
1145 example: &Example,
1146) {
1147 Progress::global().increment_failed();
1148
1149 let msg;
1150 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1151 let example_name = example.spec.filename();
1152
1153 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1154 app_state
1155 .fs
1156 .write(
1157 &failed_example_path,
1158 &serde_json::to_vec_pretty(&example).unwrap(),
1159 )
1160 .await
1161 .unwrap();
1162 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1163 app_state
1164 .fs
1165 .write(&err_path, format!("{error:?}").as_bytes())
1166 .await
1167 .unwrap();
1168
1169 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1170 let mut file = OpenOptions::new()
1171 .create(true)
1172 .append(true)
1173 .open(&failed_jsonl_path)
1174 .expect("Failed to open failed.jsonl");
1175 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1176 .expect("Failed to write to failed.jsonl");
1177
1178 let cursor_path = match example.repo_name() {
1179 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1180 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1181 };
1182 msg = format!(
1183 indoc::indoc! {"
1184 While processing \"{}\":
1185
1186 \x1b[31m{:?}\x1b[0m
1187
1188 Example: \x1b[36m{}\x1b[0m
1189 Error file: \x1b[36m{}\x1b[0m
1190 Cursor file: \x1b[36m{}\x1b[0m
1191 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1192 "},
1193 example.spec.name,
1194 error,
1195 failed_example_path.display(),
1196 err_path.display(),
1197 cursor_path.display(),
1198 command,
1199 failed_example_path.display(),
1200 );
1201 } else {
1202 msg = format!(
1203 indoc::indoc! {"
1204 While processing \"{}\":
1205
1206 \x1b[31m{:?}\x1b[0m
1207 "},
1208 example.spec.name, error
1209 );
1210 }
1211
1212 if args.failfast || failfast_on_single_example {
1213 Progress::global().finalize();
1214 panic!("{}", msg);
1215 } else {
1216 log::error!("{}", msg);
1217 }
1218}