1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod sync_deployments;
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::HashSet;
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
35use zeta_prompt::ZetaFormat;
36
37use reqwest_client::ReqwestClient;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39use std::fmt::Display;
40use std::fs::{File, OpenOptions};
41use std::hash::{Hash, Hasher};
42use std::io::{BufRead, BufReader, BufWriter, Write};
43use std::sync::Mutex;
44use std::{path::PathBuf, sync::Arc};
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI arguments. Field doc comments (`///`) become clap help text,
// so fields annotated below with plain `//` comments are deliberately kept out
// of `--help` output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (handled before any subcommand runs).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; applied to file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or Snowflake specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output destination: a JSONL file, or a directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; mutually exclusive with multiple inputs.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): not read directly in this file — presumably consumed by the
    // error-handling path (`handle_error` takes `&args`); confirm before relying on it.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Variant doc comments double as `--failed` help text via the ValueEnum derive.
// NOTE(review): `SkipNoFiles` ("skip writing files") appears to contradict the
// "always logged" claim above — confirm which one holds and fix the doc.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
/// Long-form help text for the positional `inputs` argument, attached via
/// `#[clap(help = INPUTS_HELP)]` on `EpArgs::inputs`. This is runtime-visible
/// CLI output; keep the wording in sync with the specifier parsers in
/// `pull_examples`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// All `ep` subcommands. Variant doc comments double as clap subcommand help
// text, so keep them user-facing. Several commands (Clean, Synthesize,
// SplitCommit, Split, TruncatePatch, FilterLanguages, ImportBatch,
// PrintZetaFormats, SyncDeployments) are handled fully in `main` before the
// per-example pipeline runs.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
216
217impl Display for Command {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Command::Read => write!(f, "read"),
221 Command::LoadProject => write!(f, "load-project"),
222 Command::Context => write!(f, "context"),
223 Command::FormatPrompt(args) => {
224 write!(f, "format-prompt --provider={}", args.provider)
225 }
226 Command::Predict(args) => match &args.provider {
227 Some(provider) => write!(f, "predict --provider={}", provider),
228 None => write!(f, "predict"),
229 },
230 Command::ParseOutput => write!(f, "parse-output"),
231 Command::Score(args) => match &args.provider {
232 Some(provider) => write!(f, "score --provider={}", provider),
233 None => write!(f, "score"),
234 },
235 Command::Distill => write!(f, "distill"),
236 Command::Eval(args) => match &args.predict.provider {
237 Some(provider) => write!(f, "eval --provider={}", provider),
238 None => write!(f, "eval"),
239 },
240 Command::Synthesize(args) => {
241 write!(f, "synthesize --repos {}", args.repos.join(" "))
242 }
243 Command::Clean => write!(f, "clean"),
244 Command::SplitCommit(_) => write!(f, "split-commit"),
245 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
246 Command::Split(_) => write!(f, "split"),
247 Command::FilterLanguages(_) => write!(f, "filter-languages"),
248 Command::ImportBatch(args) => {
249 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
250 }
251 Command::Qa(_) => {
252 write!(f, "qa")
253 }
254 Command::Repair(_) => {
255 write!(f, "repair")
256 }
257 Command::PrintZetaFormats => {
258 write!(f, "print-zeta-formats")
259 }
260 Command::SyncDeployments(_) => {
261 write!(f, "sync-deployments")
262 }
263 }
264 }
265}
266
// Arguments for the `sync-deployments` subcommand.
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    // The default is applied downstream (the field stays `None` here) —
    // presumably inside `run_sync_deployments`; confirm against that function.
    #[arg(long)]
    model: Option<String>,
}
273
// Arguments for the `format-prompt` subcommand.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format should be produced; defaults to
    // `PredictionProvider::Zeta2` with the default `ZetaFormat`.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
// Shared arguments for `predict` and `score`, and (via EvalArgs) `eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction backend; `None` lets downstream code choose (callers pass
    // `provider.as_ref()` straight into `predict::sync_batches`).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm in `run_prediction` before documenting further.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
// Arguments for the `eval` subcommand: prediction options plus report output.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction options shared with `predict`/`score`, flattened into this command.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
299
/// Large "teacher" models used by the teacher prediction providers.
/// Parsed from strings via the `FromStr` impl below (not a clap ValueEnum).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Anthropic Claude Sonnet 4.5 (`model_name` = "claude-sonnet-4-5").
    Sonnet45,
    /// OpenAI GPT-5.2 (`model_name` = "gpt-5.2").
    Gpt52,
}
305
306impl std::fmt::Display for TeacherBackend {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 match self {
309 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
310 TeacherBackend::Gpt52 => write!(f, "gpt52"),
311 }
312 }
313}
314
315impl std::str::FromStr for TeacherBackend {
316 type Err = anyhow::Error;
317
318 fn from_str(s: &str) -> Result<Self, Self::Err> {
319 match s.to_lowercase().as_str() {
320 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
321 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
322 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
323 _ => anyhow::bail!("unknown teacher backend `{s}`. Valid options: sonnet45, gpt52"),
324 }
325 }
326}
327
328impl TeacherBackend {
329 pub fn model_name(&self) -> &'static str {
330 match self {
331 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
332 TeacherBackend::Gpt52 => "gpt-5.2",
333 }
334 }
335}
336
/// Backend used to produce edit predictions. Round-trips through its string
/// form (`Display`/`FromStr`), which is also how it is (de)serialized.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt format, e.g. parsed from `zeta2:ordered`.
    Zeta2(ZetaFormat),
    // Teacher model using the batching API.
    Teacher(TeacherBackend),
    // Teacher model issuing direct (non-batched) requests.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
347
348impl Default for PredictionProvider {
349 fn default() -> Self {
350 PredictionProvider::Zeta2(ZetaFormat::default())
351 }
352}
353
354impl std::fmt::Display for PredictionProvider {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 match self {
357 PredictionProvider::Sweep => write!(f, "sweep"),
358 PredictionProvider::Mercury => write!(f, "mercury"),
359 PredictionProvider::Zeta1 => write!(f, "zeta1"),
360 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
361 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
362 PredictionProvider::TeacherNonBatching(backend) => {
363 write!(f, "teacher-non-batching:{backend}")
364 }
365 PredictionProvider::Repair => write!(f, "repair"),
366 }
367 }
368}
369
370impl std::str::FromStr for PredictionProvider {
371 type Err = anyhow::Error;
372
373 fn from_str(s: &str) -> Result<Self, Self::Err> {
374 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
375
376 let provider_lower = provider.to_lowercase();
377 match provider_lower.as_str() {
378 "sweep" => Ok(PredictionProvider::Sweep),
379 "mercury" => Ok(PredictionProvider::Mercury),
380 "zeta1" => Ok(PredictionProvider::Zeta1),
381 "zeta2" => {
382 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
383 Ok(PredictionProvider::Zeta2(format))
384 }
385 "teacher" => {
386 let backend = arg
387 .map(|a| a.parse())
388 .transpose()?
389 .unwrap_or(TeacherBackend::Sonnet45);
390 Ok(PredictionProvider::Teacher(backend))
391 }
392 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
393 let backend = arg
394 .map(|a| a.parse())
395 .transpose()?
396 .unwrap_or(TeacherBackend::Sonnet45);
397 Ok(PredictionProvider::TeacherNonBatching(backend))
398 }
399 "repair" => Ok(PredictionProvider::Repair),
400 _ => {
401 anyhow::bail!(
402 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
403 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
404 For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
405 Available zeta versions:\n{}",
406 ZetaFormat::options_as_string()
407 )
408 }
409 }
410 }
411}
412
413impl Serialize for PredictionProvider {
414 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
415 where
416 S: Serializer,
417 {
418 serializer.serialize_str(&self.to_string())
419 }
420}
421
422impl<'de> Deserialize<'de> for PredictionProvider {
423 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
424 where
425 D: Deserializer<'de>,
426 {
427 let s = String::deserialize(deserializer)?;
428 s.parse().map_err(serde::de::Error::custom)
429 }
430}
431
// Arguments for the `synthesize` subcommand (commit-mining example generation).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
450
// Arguments for the `import-batch` subcommand (recovers batch results into the
// local LLM cache DB; handled entirely in `main` before the example pipeline).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
460
// LLM batch API to import from. Plain `//` comments are used because ValueEnum
// turns `///` variant docs into CLI help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    // Routes to anthropic_client::AnthropicClient::batch.
    Anthropic,
    // Routes to openai_client::OpenAiClient::batch.
    Openai,
}
466
467impl EpArgs {
468 fn output_path(&self) -> Option<PathBuf> {
469 if self.in_place {
470 if self.inputs.len() == 1 {
471 self.inputs.first().cloned()
472 } else {
473 panic!("--in-place requires exactly one input file")
474 }
475 } else {
476 self.output.clone()
477 }
478 }
479}
480
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): only minor/patch are expressed — presumably interpreted as
// Zed 0.224.1; confirm against `MinCaptureVersion`'s comparison logic in
// `pull_examples`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
488
/// Gathers the example set for a run from local files and/or Snowflake.
///
/// `args.inputs` entries are classified as Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`
/// and variants) or local file paths. File examples are read first; Snowflake
/// sources are fetched only if `--limit` is not already satisfied. The
/// combined set is then sorted, filtered by `--name`/`--repo`, pruned of
/// already-processed entries when resuming into `output_path`, and truncated
/// to `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // One bucket per supported Snowflake input specifier.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that doesn't match a specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            // Offset swallows all file examples; the remainder offsets Snowflake rows.
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Early progress total from file examples only; updated again at the end
    // once Snowflake results and filters have been applied.
    Progress::global().set_total_examples(examples.len());

    // How many more examples we may still fetch remotely, given what files provided.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Default cap of 5000 rows per timestamp when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `remaining_offset` is passed to each fetch kind below,
        // so when multiple specifier kinds are mixed the offset applies once
        // per kind — confirm this is intended.
        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: moves (rather than clones) http_client and the executor.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group by repo/rev so worktree setup can be shared across neighbors.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals now that all sources, filters, and limits are applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
641
642fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
643 let mut hasher = collections::FxHasher::default();
644 spec.hash(&mut hasher);
645 hasher.finish()
646}
647
/// Resumes a run against an existing output file at `path`.
///
/// Output lines whose spec hash matches a current input example (first
/// occurrence only) and that are "complete" for `command` are kept; the file
/// is then rewritten with only those lines, and the corresponding examples
/// are removed from `examples` so only unfinished work remains. Lines that
/// fail to parse, duplicates, and entries no longer in the input are dropped
/// from the rewritten file.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // A missing or unreadable output file simply means nothing to resume.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable lines are skipped (and thus dropped on rewrite).
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep only the first occurrence of each input example's hash.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Per-command completeness: qa needs a confidence score, repair
                // needs a parsed repair prediction; other commands accept any line.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file with only the kept lines (compacting away
    // duplicates, stale entries, and unparseable lines).
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Drop already-processed examples from the in-memory work list.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
710
711fn main() {
712 let args = EpArgs::parse();
713
714 if args.printenv {
715 ::util::shell_env::print_env();
716 return;
717 }
718
719 let output = args.output_path();
720
721 if args.markdown && output.is_none() {
722 eprintln!("--markdown requires -o to specify the output directory");
723 std::process::exit(1);
724 }
725
726 let command = match &args.command {
727 Some(cmd) => cmd.clone(),
728 None => {
729 EpArgs::command().print_help().unwrap();
730 return;
731 }
732 };
733
734 match &command {
735 Command::ImportBatch(import_args) => {
736 smol::block_on(async {
737 match import_args.provider {
738 BatchProvider::Anthropic => {
739 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
740 .expect("Failed to create Anthropic client");
741 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
742 eprintln!("Error importing Anthropic batches: {:?}", e);
743 std::process::exit(1);
744 }
745 }
746 BatchProvider::Openai => {
747 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
748 .expect("Failed to create OpenAI client");
749 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
750 eprintln!("Error importing OpenAI batches: {:?}", e);
751 std::process::exit(1);
752 }
753 }
754 }
755 println!(
756 "Successfully imported {} batch(es)",
757 import_args.batch_ids.len()
758 );
759 });
760 return;
761 }
762 Command::Clean => {
763 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
764 return;
765 }
766 Command::PrintZetaFormats => {
767 use strum::IntoEnumIterator as _;
768 for format in ZetaFormat::iter() {
769 println!("{}", format.to_string().to_lowercase());
770 }
771 return;
772 }
773 Command::SyncDeployments(sync_args) => {
774 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
775 smol::block_on(async {
776 if let Err(e) =
777 sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
778 .await
779 {
780 eprintln!("Error: {:?}", e);
781 std::process::exit(1);
782 }
783 });
784 return;
785 }
786 Command::Synthesize(synth_args) => {
787 let Some(output_dir) = args.output else {
788 panic!("output dir is required");
789 };
790 let config = SynthesizeConfig {
791 repo_urls: synth_args.repos.clone(),
792 count: synth_args.count,
793 max_commits: synth_args.max_commits,
794 output_dir,
795 fresh: synth_args.fresh,
796 };
797 smol::block_on(async {
798 if let Err(e) = run_synthesize(config).await {
799 eprintln!("Error: {:?}", e);
800 std::process::exit(1);
801 }
802 });
803 return;
804 }
805 Command::SplitCommit(split_commit_args) => {
806 if let Err(error) = split_commit::run_split_commit(
807 split_commit_args,
808 &args.inputs,
809 output.as_ref(),
810 args.failed,
811 ) {
812 eprintln!("{error:#}");
813 std::process::exit(1);
814 }
815 return;
816 }
817 Command::TruncatePatch(truncate_args) => {
818 if let Err(error) =
819 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
820 {
821 eprintln!("{error:#}");
822 std::process::exit(1);
823 }
824 return;
825 }
826 Command::Split(split_args) => {
827 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
828 eprintln!("{error:#}");
829 std::process::exit(1);
830 }
831 return;
832 }
833 Command::FilterLanguages(filter_args) => {
834 if let Err(error) =
835 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
836 {
837 eprintln!("{error:#}");
838 std::process::exit(1);
839 }
840 return;
841 }
842
843 _ => {}
844 }
845
846 let http_client = Arc::new(ReqwestClient::new());
847 let app = Application::headless().with_http_client(http_client);
848
849 app.run(move |cx| {
850 let app_state = Arc::new(headless::init(cx));
851 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
852
853 cx.spawn(async move |cx| {
854 let result = async {
855 let examples = load_examples(
856 app_state.client.http_client(),
857 &args,
858 output.as_ref(),
859 cx.background_executor().clone(),
860 )
861 .await?;
862
863 match &command {
864 Command::Predict(args) | Command::Score(args) => {
865 predict::sync_batches(args.provider.as_ref()).await?;
866 }
867 Command::Eval(args) => {
868 predict::sync_batches(args.predict.provider.as_ref()).await?;
869 }
870 Command::Qa(args) => {
871 qa::sync_batches(args).await?;
872 }
873 Command::Repair(args) => {
874 repair::sync_batches(args).await?;
875 }
876 _ => (),
877 }
878
879 let failfast_on_single_example = examples.len() == 1;
880
881 // For --markdown mode, create the output directory if it doesn't exist
882 if args.markdown {
883 let dir = output.as_ref().expect("--markdown requires -o");
884 if !dir.exists() {
885 std::fs::create_dir_all(dir)
886 .expect("Failed to create markdown output directory");
887 }
888 }
889
890 // Set up JSONL output writer (not used in markdown mode)
891 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
892 let mut in_place_temp_path: Option<PathBuf> = None;
893 if !args.markdown
894 && let Some(output_path) = output.as_ref()
895 {
896 let write_path = if args.in_place {
897 let temp = output_path.with_extension("jsonl.tmp");
898 in_place_temp_path = Some(temp.clone());
899 temp
900 } else {
901 output_path.clone()
902 };
903
904 let file = OpenOptions::new()
905 .create(true)
906 .write(true)
907 .truncate(args.in_place)
908 .append(!args.in_place)
909 .open(&write_path)
910 .expect("Failed to open output file");
911
912 let mut writer = BufWriter::new(file);
913 let (sender, mut receiver) = mpsc::unbounded::<String>();
914 cx.background_spawn(async move {
915 while let Some(line) = receiver.next().await {
916 writeln!(writer, "{}", line).expect("Failed to write example");
917 writer.flush().expect("Failed to flush output");
918 }
919 })
920 .detach();
921 output_sender = Some(sender);
922 }
923
924 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
925 let finished_examples = Mutex::new(Vec::new());
926
927 let mut tasks = Vec::new();
928 for _ in 0..args.max_parallelism {
929 tasks.push(async {
930 loop {
931 let Some(mut repo_examples) =
932 grouped_examples.lock().unwrap().pop_front()
933 else {
934 break;
935 };
936 for example in &mut repo_examples {
937 let example_progress =
938 Progress::global().start_group(&example.spec.name);
939
940 let result = async {
941 match &command {
942 Command::Read => {}
943 Command::LoadProject => {
944 run_load_project(
945 example,
946 app_state.clone(),
947 &example_progress,
948 cx.clone(),
949 )
950 .await?;
951 }
952 Command::Context => {
953 run_context_retrieval(
954 example,
955 app_state.clone(),
956 &example_progress,
957 cx.clone(),
958 )
959 .await?;
960 }
961 Command::FormatPrompt(args) => {
962 run_format_prompt(
963 example,
964 args,
965 app_state.clone(),
966 &example_progress,
967 cx.clone(),
968 )
969 .await?;
970 }
971 Command::Predict(args) => {
972 run_prediction(
973 example,
974 args,
975 app_state.clone(),
976 &example_progress,
977 cx.clone(),
978 )
979 .await?;
980 }
981 Command::ParseOutput => {
982 parse_output::run_parse_output(example)?;
983 }
984 Command::Distill => {
985 run_distill(example).await?;
986 }
987 Command::Score(args) => {
988 run_scoring(
989 example,
990 args,
991 app_state.clone(),
992 &example_progress,
993 cx.clone(),
994 )
995 .await?;
996 }
997 Command::Eval(args) => {
998 run_scoring(
999 example,
1000 &args.predict,
1001 app_state.clone(),
1002 &example_progress,
1003 cx.clone(),
1004 )
1005 .await?;
1006 }
1007 Command::Qa(args) => {
1008 qa::run_qa(example, args, &example_progress).await?;
1009 }
1010 Command::Repair(args) => {
1011 repair::run_repair(example, args, &example_progress)
1012 .await?;
1013 }
1014 Command::Clean
1015 | Command::Synthesize(_)
1016 | Command::SplitCommit(_)
1017 | Command::Split(_)
1018 | Command::TruncatePatch(_)
1019 | Command::FilterLanguages(_)
1020 | Command::ImportBatch(_)
1021 | Command::PrintZetaFormats
1022 | Command::SyncDeployments(_) => {
1023 unreachable!()
1024 }
1025 }
1026 anyhow::Ok(())
1027 }
1028 .await;
1029
1030 let failed = if let Err(error) = result {
1031 handle_error(
1032 error,
1033 &args,
1034 &command,
1035 &app_state,
1036 failfast_on_single_example,
1037 &example,
1038 )
1039 .await;
1040 true
1041 } else {
1042 false
1043 };
1044
1045 let should_write = !failed || args.failed == FailedHandling::Keep;
1046 if should_write {
1047 if args.markdown {
1048 let markdown_dir =
1049 output.as_ref().expect("--markdown requires -o");
1050 let filename = format!("{}.md", example.spec.filename());
1051 let path = markdown_dir.join(&filename);
1052 let markdown = example.spec.to_markdown();
1053 std::fs::write(&path, &markdown)
1054 .expect("Failed to write markdown file");
1055 } else if let Some(ref mut sender) = output_sender.clone() {
1056 let line = serde_json::to_string(&example).unwrap();
1057 sender
1058 .send(line)
1059 .await
1060 .expect("Failed to send to output writer");
1061 } else if args.output.is_none()
1062 && !matches!(command, Command::Eval(_))
1063 {
1064 let line = serde_json::to_string(&example).unwrap();
1065 println!("{}", line);
1066 }
1067 }
1068 }
1069
1070 let project = repo_examples
1071 .iter()
1072 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1073
1074 if let Some(project) = project {
1075 let mut cx = cx.clone();
1076
1077 let shutdown_task: Task<()> =
1078 project.update(&mut cx, |project, cx| {
1079 let lsp_store = project.lsp_store();
1080 lsp_store.update(cx, |lsp_store, cx| {
1081 lsp_store.shutdown_all_language_servers(cx)
1082 })
1083 });
1084
1085 shutdown_task.await;
1086
1087 if let Some(ep_store) =
1088 cx.update(|cx| EditPredictionStore::try_global(cx))
1089 {
1090 ep_store.update(&mut cx, |store, _| {
1091 store.remove_project(&project);
1092 });
1093 }
1094 }
1095
1096 for example in &mut repo_examples {
1097 example.state.take();
1098 }
1099 finished_examples
1100 .lock()
1101 .unwrap()
1102 .extend_from_slice(&repo_examples);
1103 }
1104 });
1105 }
1106 futures::future::join_all(tasks).await;
1107
1108 Progress::global().finalize();
1109
1110 match &command {
1111 Command::Predict(args) | Command::Score(args) => {
1112 predict::sync_batches(args.provider.as_ref()).await?;
1113 }
1114 Command::Eval(args) => {
1115 predict::sync_batches(args.predict.provider.as_ref()).await?;
1116 }
1117 Command::Qa(args) => {
1118 qa::sync_batches(args).await?;
1119 }
1120 Command::Repair(args) => {
1121 repair::sync_batches(args).await?;
1122 }
1123 _ => (),
1124 }
1125
1126 match &command {
1127 Command::Eval(args) => {
1128 let examples = finished_examples.lock().unwrap();
1129 score::print_report(&examples);
1130 if let Some(summary_path) = &args.summary_json {
1131 score::write_summary_json(&examples, summary_path)?;
1132 }
1133 }
1134 Command::Repair(args) => {
1135 let examples = finished_examples.lock().unwrap();
1136 repair::print_report(&examples, args.confidence_threshold);
1137 }
1138 _ => (),
1139 };
1140
1141 // For --in-place, atomically rename temp file to original
1142 if let Some(temp_path) = &in_place_temp_path {
1143 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1144 std::fs::rename(temp_path, final_path)
1145 .expect("Failed to rename temp file to final output");
1146 }
1147
1148 anyhow::Ok(())
1149 }
1150 .await;
1151
1152 if let Err(e) = result {
1153 panic!("Fatal error: {:?}", e);
1154 }
1155
1156 let _ = cx.update(|cx| cx.quit());
1157 })
1158 .detach();
1159 });
1160}
1161
1162async fn handle_error(
1163 error: anyhow::Error,
1164 args: &EpArgs,
1165 command: &Command,
1166 app_state: &Arc<headless::EpAppState>,
1167 failfast_on_single_example: bool,
1168 example: &Example,
1169) {
1170 Progress::global().increment_failed();
1171
1172 let msg;
1173 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1174 let example_name = example.spec.filename();
1175
1176 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1177 app_state
1178 .fs
1179 .write(
1180 &failed_example_path,
1181 &serde_json::to_vec_pretty(&example).unwrap(),
1182 )
1183 .await
1184 .unwrap();
1185 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1186 app_state
1187 .fs
1188 .write(&err_path, format!("{error:?}").as_bytes())
1189 .await
1190 .unwrap();
1191
1192 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1193 let mut file = OpenOptions::new()
1194 .create(true)
1195 .append(true)
1196 .open(&failed_jsonl_path)
1197 .expect("Failed to open failed.jsonl");
1198 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1199 .expect("Failed to write to failed.jsonl");
1200
1201 let cursor_path = match example.repo_name() {
1202 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1203 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1204 };
1205 msg = format!(
1206 indoc::indoc! {"
1207 While processing \"{}\":
1208
1209 \x1b[31m{:?}\x1b[0m
1210
1211 Example: \x1b[36m{}\x1b[0m
1212 Error file: \x1b[36m{}\x1b[0m
1213 Cursor file: \x1b[36m{}\x1b[0m
1214 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1215 "},
1216 example.spec.name,
1217 error,
1218 failed_example_path.display(),
1219 err_path.display(),
1220 cursor_path.display(),
1221 command,
1222 failed_example_path.display(),
1223 );
1224 } else {
1225 msg = format!(
1226 indoc::indoc! {"
1227 While processing \"{}\":
1228
1229 \x1b[31m{:?}\x1b[0m
1230 "},
1231 example.spec.name, error
1232 );
1233 }
1234
1235 if args.failfast || failfast_on_single_example {
1236 Progress::global().finalize();
1237 panic!("{}", msg);
1238 } else {
1239 log::error!("{}", msg);
1240 }
1241}