1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25mod sync_deployments;
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::HashSet;
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
35use zeta_prompt::ZetaFormat;
36
37use reqwest_client::ReqwestClient;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39use std::fmt::Display;
40use std::fs::{File, OpenOptions};
41use std::hash::{Hash, Hasher};
42use std::io::{BufRead, BufReader, BufWriter, Write};
43use std::sync::Mutex;
use std::path::{Path, PathBuf};
use std::sync::Arc;
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI arguments for `ep`. Most flags are `global = true` so they can
// be given before or after the subcommand. Use `//` (not `///`) for maintainer
// notes here: doc comments become clap help text.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled first in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker loops over repo-grouped examples.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Leading examples to skip; consumed by file inputs first, remainder is
    // forwarded to Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or special specifiers (see INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (requires exactly one input).
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): not referenced in the visible portion of this file —
    // presumably read by error handling elsewhere; confirm before removing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed from `--failed <value>`; the variant doc comments double as clap help.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
/// Help text for the positional `inputs` argument (see `EpArgs::inputs`).
/// Keep this in sync with the specifiers recognized in `load_examples`
/// (`captured-after`, `rejected-after`, `requested-after`, `rated-after`,
/// `rated-positive-after`, `rated-negative-after`).
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  requested-after:{timestamp}
    Fetch edit prediction requests from Snowflake after the given RFC3339 timestamp.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
  EP_SNOWFLAKE_API_KEY
  EP_SNOWFLAKE_BASE_URL

  Optional:
  EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// All `ep` subcommands. Variant doc comments (`///`) double as clap help text,
// so keep them user-facing; use `//` comments for maintainer notes.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    // Shares PredictArgs with Predict so scoring can target a specific provider.
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
    /// Sync baseten deployment metadata to Snowflake for linking predictions to experiments
    SyncDeployments(SyncDeploymentsArgs),
}
216
217impl Display for Command {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 match self {
220 Command::Read => write!(f, "read"),
221 Command::LoadProject => write!(f, "load-project"),
222 Command::Context => write!(f, "context"),
223 Command::FormatPrompt(args) => {
224 write!(f, "format-prompt --provider={}", args.provider)
225 }
226 Command::Predict(args) => match &args.provider {
227 Some(provider) => write!(f, "predict --provider={}", provider),
228 None => write!(f, "predict"),
229 },
230 Command::ParseOutput => write!(f, "parse-output"),
231 Command::Score(args) => match &args.provider {
232 Some(provider) => write!(f, "score --provider={}", provider),
233 None => write!(f, "score"),
234 },
235 Command::Distill => write!(f, "distill"),
236 Command::Eval(args) => match &args.predict.provider {
237 Some(provider) => write!(f, "eval --provider={}", provider),
238 None => write!(f, "eval"),
239 },
240 Command::Synthesize(args) => {
241 write!(f, "synthesize --repos {}", args.repos.join(" "))
242 }
243 Command::Clean => write!(f, "clean"),
244 Command::SplitCommit(_) => write!(f, "split-commit"),
245 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
246 Command::Split(_) => write!(f, "split"),
247 Command::FilterLanguages(_) => write!(f, "filter-languages"),
248 Command::ImportBatch(args) => {
249 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
250 }
251 Command::Qa(_) => {
252 write!(f, "qa")
253 }
254 Command::Repair(_) => {
255 write!(f, "repair")
256 }
257 Command::PrintZetaFormats => {
258 write!(f, "print-zeta-formats")
259 }
260 Command::SyncDeployments(_) => {
261 write!(f, "sync-deployments")
262 }
263 }
264 }
265}
266
// Arguments for `ep sync-deployments`.
#[derive(Debug, Args, Clone)]
struct SyncDeploymentsArgs {
    /// BaseTen model name (default: "zeta-2")
    #[arg(long)]
    model: Option<String>,
}
273
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to emit; defaults to zeta2 with its default format.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider. The meaning of `None` is decided by the prediction/
    // scoring code outside this file — NOTE(review): confirm its semantics there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction attempts per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
// Arguments for `ep eval`: prediction/scoring flags plus an aggregate report.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Same flags as `ep predict`/`ep score`, flattened into this subcommand.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
299
/// Large model used as the "teacher" for prediction providers `teacher:*` and
/// `teacher-non-batching:*`. Parsed case-insensitively via `FromStr`; the
/// concrete API model identifier comes from [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// "claude-sonnet-4-6"
    Sonnet46,
    /// "claude-sonnet-4-5" (the default backend)
    #[default]
    Sonnet45,
    /// "gpt-5.2"
    Gpt52,
}
307
308impl std::fmt::Display for TeacherBackend {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 match self {
311 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
312 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
313 TeacherBackend::Gpt52 => write!(f, "gpt52"),
314 }
315 }
316}
317
318impl std::str::FromStr for TeacherBackend {
319 type Err = anyhow::Error;
320
321 fn from_str(s: &str) -> Result<Self, Self::Err> {
322 match s.to_lowercase().as_str() {
323 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
324 "sonnet46" => Ok(TeacherBackend::Sonnet46),
325 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
326 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
327 _ => anyhow::bail!(
328 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
329 ),
330 }
331 }
332}
333
334impl TeacherBackend {
335 pub fn model_name(&self) -> &'static str {
336 match self {
337 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
338 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
339 TeacherBackend::Gpt52 => "gpt-5.2",
340 }
341 }
342}
343
/// Which model/backend produces an edit prediction.
///
/// String form is `provider` or `provider:arg` — see the `FromStr` and
/// `Display` impls below; (de)serialized as that string in example JSONL.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt/output format.
    Zeta2(ZetaFormat),
    /// Teacher model (see `TeacherNonBatching` for the non-batching variant).
    Teacher(TeacherBackend),
    /// Teacher model without request batching.
    TeacherNonBatching(TeacherBackend),
    /// Predictions produced by the `repair` command.
    Repair,
}
354
355impl Default for PredictionProvider {
356 fn default() -> Self {
357 PredictionProvider::Zeta2(ZetaFormat::default())
358 }
359}
360
361impl std::fmt::Display for PredictionProvider {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 match self {
364 PredictionProvider::Sweep => write!(f, "sweep"),
365 PredictionProvider::Mercury => write!(f, "mercury"),
366 PredictionProvider::Zeta1 => write!(f, "zeta1"),
367 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
368 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
369 PredictionProvider::TeacherNonBatching(backend) => {
370 write!(f, "teacher-non-batching:{backend}")
371 }
372 PredictionProvider::Repair => write!(f, "repair"),
373 }
374 }
375}
376
377impl std::str::FromStr for PredictionProvider {
378 type Err = anyhow::Error;
379
380 fn from_str(s: &str) -> Result<Self, Self::Err> {
381 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
382
383 let provider_lower = provider.to_lowercase();
384 match provider_lower.as_str() {
385 "sweep" => Ok(PredictionProvider::Sweep),
386 "mercury" => Ok(PredictionProvider::Mercury),
387 "zeta1" => Ok(PredictionProvider::Zeta1),
388 "zeta2" => {
389 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
390 Ok(PredictionProvider::Zeta2(format))
391 }
392 "teacher" => {
393 let backend = arg
394 .map(|a| a.parse())
395 .transpose()?
396 .unwrap_or(TeacherBackend::default());
397 Ok(PredictionProvider::Teacher(backend))
398 }
399 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
400 let backend = arg
401 .map(|a| a.parse())
402 .transpose()?
403 .unwrap_or(TeacherBackend::default());
404 Ok(PredictionProvider::TeacherNonBatching(backend))
405 }
406 "repair" => Ok(PredictionProvider::Repair),
407 _ => {
408 anyhow::bail!(
409 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
410 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
411 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
412 Available zeta versions:\n{}",
413 ZetaFormat::options_as_string()
414 )
415 }
416 }
417 }
418}
419
420impl Serialize for PredictionProvider {
421 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
422 where
423 S: Serializer,
424 {
425 serializer.serialize_str(&self.to_string())
426 }
427}
428
429impl<'de> Deserialize<'de> for PredictionProvider {
430 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
431 where
432 D: Deserializer<'de>,
433 {
434 let s = String::deserialize(deserializer)?;
435 s.parse().map_err(serde::de::Error::custom)
436 }
437}
438
// Arguments for `ep synthesize`.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
457
// Arguments for `ep import-batch`.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
467
// LLM batch API to import from (`--provider` for `ep import-batch`).
// `//` comments only: `///` on ValueEnum variants would add clap help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
473
474impl EpArgs {
475 fn output_path(&self) -> Option<PathBuf> {
476 if self.in_place {
477 if self.inputs.len() == 1 {
478 self.inputs.first().cloned()
479 } else {
480 panic!("--in-place requires exactly one input file")
481 }
482 } else {
483 self.output.clone()
484 }
485 }
486}
487
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Only minor/patch fields exist on `MinCaptureVersion`; the major version is
// implied. Passed to every `pull_examples::fetch_*` call in `load_examples`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
495
/// Gathers the examples to process from every input source.
///
/// File inputs are read directly; `*-after:{timestamp}` specifiers are fetched
/// from Snowflake. Then, in order: `--offset` (consumed by file examples first,
/// remainder forwarded to Snowflake), name/repo filters, resume-from-output
/// deduplication, and `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // One bucket per specifier kind; anything that does not parse as a
    // specifier is treated as a file path.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total so progress reporting works during the Snowflake
    // fetches; the final total is set again at the end of this function.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Default per-timestamp cap of 5000 rows when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: `http_client` and `background_executor` are moved
            // rather than cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for progress display, after all filtering and truncation.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
648
649fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
650 let mut hasher = collections::FxHasher::default();
651 spec.hash(&mut hasher);
652 hasher.finish()
653}
654
655fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
656 let file = match File::open(path) {
657 Ok(f) => f,
658 Err(_) => return,
659 };
660
661 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
662
663 let reader = BufReader::new(file);
664 let mut kept_lines = Vec::new();
665 let mut kept_hashes = HashSet::default();
666
667 for line in reader.lines() {
668 let line = match line {
669 Ok(l) => l,
670 Err(_) => continue,
671 };
672
673 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
674 let hash = spec_hash(&output_example.spec);
675 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
676 let is_complete = match command {
677 Command::Qa(_) => output_example
678 .qa
679 .first()
680 .and_then(|q| q.as_ref())
681 .and_then(|q| q.confidence)
682 .is_some(),
683 Command::Repair(_) => output_example.predictions.iter().any(|p| {
684 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
685 }),
686 _ => true,
687 };
688 if is_complete {
689 kept_hashes.insert(hash);
690 kept_lines.push(line);
691 }
692 }
693 }
694 }
695
696 let total = examples.len();
697 let already_processed = kept_hashes.len();
698
699 eprintln!(
700 "Resuming: {}/{} examples already processed",
701 already_processed, total
702 );
703
704 let file = OpenOptions::new()
705 .write(true)
706 .truncate(true)
707 .open(path)
708 .expect("Failed to open output file for rewriting");
709 let mut writer = BufWriter::new(file);
710 for line in &kept_lines {
711 writeln!(writer, "{}", line).expect("Failed to write to output file");
712 }
713 writer.flush().expect("Failed to flush output file");
714
715 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
716}
717
718fn main() {
719 let args = EpArgs::parse();
720
721 if args.printenv {
722 ::util::shell_env::print_env();
723 return;
724 }
725
726 let output = args.output_path();
727
728 if args.markdown && output.is_none() {
729 eprintln!("--markdown requires -o to specify the output directory");
730 std::process::exit(1);
731 }
732
733 let command = match &args.command {
734 Some(cmd) => cmd.clone(),
735 None => {
736 EpArgs::command().print_help().unwrap();
737 return;
738 }
739 };
740
741 match &command {
742 Command::ImportBatch(import_args) => {
743 smol::block_on(async {
744 match import_args.provider {
745 BatchProvider::Anthropic => {
746 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
747 .expect("Failed to create Anthropic client");
748 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
749 eprintln!("Error importing Anthropic batches: {:?}", e);
750 std::process::exit(1);
751 }
752 }
753 BatchProvider::Openai => {
754 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
755 .expect("Failed to create OpenAI client");
756 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
757 eprintln!("Error importing OpenAI batches: {:?}", e);
758 std::process::exit(1);
759 }
760 }
761 }
762 println!(
763 "Successfully imported {} batch(es)",
764 import_args.batch_ids.len()
765 );
766 });
767 return;
768 }
769 Command::Clean => {
770 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
771 return;
772 }
773 Command::PrintZetaFormats => {
774 use strum::IntoEnumIterator as _;
775 for format in ZetaFormat::iter() {
776 println!("{}", format.to_string().to_lowercase());
777 }
778 return;
779 }
780 Command::SyncDeployments(sync_args) => {
781 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
782 smol::block_on(async {
783 if let Err(e) =
784 sync_deployments::run_sync_deployments(http_client, sync_args.model.clone())
785 .await
786 {
787 eprintln!("Error: {:?}", e);
788 std::process::exit(1);
789 }
790 });
791 return;
792 }
793 Command::Synthesize(synth_args) => {
794 let Some(output_dir) = args.output else {
795 panic!("output dir is required");
796 };
797 let config = SynthesizeConfig {
798 repo_urls: synth_args.repos.clone(),
799 count: synth_args.count,
800 max_commits: synth_args.max_commits,
801 output_dir,
802 fresh: synth_args.fresh,
803 };
804 smol::block_on(async {
805 if let Err(e) = run_synthesize(config).await {
806 eprintln!("Error: {:?}", e);
807 std::process::exit(1);
808 }
809 });
810 return;
811 }
812 Command::SplitCommit(split_commit_args) => {
813 if let Err(error) = split_commit::run_split_commit(
814 split_commit_args,
815 &args.inputs,
816 output.as_ref(),
817 args.failed,
818 ) {
819 eprintln!("{error:#}");
820 std::process::exit(1);
821 }
822 return;
823 }
824 Command::TruncatePatch(truncate_args) => {
825 if let Err(error) =
826 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
827 {
828 eprintln!("{error:#}");
829 std::process::exit(1);
830 }
831 return;
832 }
833 Command::Split(split_args) => {
834 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
835 eprintln!("{error:#}");
836 std::process::exit(1);
837 }
838 return;
839 }
840 Command::FilterLanguages(filter_args) => {
841 if let Err(error) =
842 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
843 {
844 eprintln!("{error:#}");
845 std::process::exit(1);
846 }
847 return;
848 }
849
850 _ => {}
851 }
852
853 let http_client = Arc::new(ReqwestClient::new());
854 let app = Application::headless().with_http_client(http_client);
855
856 app.run(move |cx| {
857 let app_state = Arc::new(headless::init(cx));
858 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
859
860 cx.spawn(async move |cx| {
861 let result = async {
862 let examples = load_examples(
863 app_state.client.http_client(),
864 &args,
865 output.as_ref(),
866 cx.background_executor().clone(),
867 )
868 .await?;
869
870 match &command {
871 Command::Predict(args) | Command::Score(args) => {
872 predict::sync_batches(args.provider.as_ref()).await?;
873 }
874 Command::Eval(args) => {
875 predict::sync_batches(args.predict.provider.as_ref()).await?;
876 }
877 Command::Qa(args) => {
878 qa::sync_batches(args).await?;
879 }
880 Command::Repair(args) => {
881 repair::sync_batches(args).await?;
882 }
883 _ => (),
884 }
885
886 let failfast_on_single_example = examples.len() == 1;
887
888 // For --markdown mode, create the output directory if it doesn't exist
889 if args.markdown {
890 let dir = output.as_ref().expect("--markdown requires -o");
891 if !dir.exists() {
892 std::fs::create_dir_all(dir)
893 .expect("Failed to create markdown output directory");
894 }
895 }
896
897 // Set up JSONL output writer (not used in markdown mode)
898 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
899 let mut in_place_temp_path: Option<PathBuf> = None;
900 if !args.markdown
901 && let Some(output_path) = output.as_ref()
902 {
903 let write_path = if args.in_place {
904 let temp = output_path.with_extension("jsonl.tmp");
905 in_place_temp_path = Some(temp.clone());
906 temp
907 } else {
908 output_path.clone()
909 };
910
911 let file = OpenOptions::new()
912 .create(true)
913 .write(true)
914 .truncate(args.in_place)
915 .append(!args.in_place)
916 .open(&write_path)
917 .expect("Failed to open output file");
918
919 let mut writer = BufWriter::new(file);
920 let (sender, mut receiver) = mpsc::unbounded::<String>();
921 cx.background_spawn(async move {
922 while let Some(line) = receiver.next().await {
923 writeln!(writer, "{}", line).expect("Failed to write example");
924 writer.flush().expect("Failed to flush output");
925 }
926 })
927 .detach();
928 output_sender = Some(sender);
929 }
930
931 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
932 let finished_examples = Mutex::new(Vec::new());
933
934 let mut tasks = Vec::new();
935 for _ in 0..args.max_parallelism {
936 tasks.push(async {
937 loop {
938 let Some(mut repo_examples) =
939 grouped_examples.lock().unwrap().pop_front()
940 else {
941 break;
942 };
943 for example in &mut repo_examples {
944 let example_progress =
945 Progress::global().start_group(&example.spec.name);
946
947 let result = async {
948 match &command {
949 Command::Read => {}
950 Command::LoadProject => {
951 run_load_project(
952 example,
953 app_state.clone(),
954 &example_progress,
955 cx.clone(),
956 )
957 .await?;
958 }
959 Command::Context => {
960 run_context_retrieval(
961 example,
962 app_state.clone(),
963 &example_progress,
964 cx.clone(),
965 )
966 .await?;
967 }
968 Command::FormatPrompt(args) => {
969 run_format_prompt(
970 example,
971 args,
972 app_state.clone(),
973 &example_progress,
974 cx.clone(),
975 )
976 .await?;
977 }
978 Command::Predict(args) => {
979 run_prediction(
980 example,
981 args,
982 app_state.clone(),
983 &example_progress,
984 cx.clone(),
985 )
986 .await?;
987 }
988 Command::ParseOutput => {
989 parse_output::run_parse_output(example)?;
990 }
991 Command::Distill => {
992 run_distill(example).await?;
993 }
994 Command::Score(args) => {
995 run_scoring(
996 example,
997 args,
998 app_state.clone(),
999 &example_progress,
1000 cx.clone(),
1001 )
1002 .await?;
1003 }
1004 Command::Eval(args) => {
1005 run_scoring(
1006 example,
1007 &args.predict,
1008 app_state.clone(),
1009 &example_progress,
1010 cx.clone(),
1011 )
1012 .await?;
1013 }
1014 Command::Qa(args) => {
1015 qa::run_qa(example, args, &example_progress).await?;
1016 }
1017 Command::Repair(args) => {
1018 repair::run_repair(example, args, &example_progress)
1019 .await?;
1020 }
1021 Command::Clean
1022 | Command::Synthesize(_)
1023 | Command::SplitCommit(_)
1024 | Command::Split(_)
1025 | Command::TruncatePatch(_)
1026 | Command::FilterLanguages(_)
1027 | Command::ImportBatch(_)
1028 | Command::PrintZetaFormats
1029 | Command::SyncDeployments(_) => {
1030 unreachable!()
1031 }
1032 }
1033 anyhow::Ok(())
1034 }
1035 .await;
1036
1037 let failed = if let Err(error) = result {
1038 handle_error(
1039 error,
1040 &args,
1041 &command,
1042 &app_state,
1043 failfast_on_single_example,
1044 &example,
1045 )
1046 .await;
1047 true
1048 } else {
1049 false
1050 };
1051
1052 let should_write = !failed || args.failed == FailedHandling::Keep;
1053 if should_write {
1054 if args.markdown {
1055 let markdown_dir =
1056 output.as_ref().expect("--markdown requires -o");
1057 let filename = format!("{}.md", example.spec.filename());
1058 let path = markdown_dir.join(&filename);
1059 let markdown = example.spec.to_markdown();
1060 std::fs::write(&path, &markdown)
1061 .expect("Failed to write markdown file");
1062 } else if let Some(ref mut sender) = output_sender.clone() {
1063 let line = serde_json::to_string(&example).unwrap();
1064 sender
1065 .send(line)
1066 .await
1067 .expect("Failed to send to output writer");
1068 } else if args.output.is_none()
1069 && !matches!(command, Command::Eval(_))
1070 {
1071 let line = serde_json::to_string(&example).unwrap();
1072 println!("{}", line);
1073 }
1074 }
1075 }
1076
1077 let project = repo_examples
1078 .iter()
1079 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1080
1081 if let Some(project) = project {
1082 let mut cx = cx.clone();
1083
1084 let shutdown_task: Task<()> =
1085 project.update(&mut cx, |project, cx| {
1086 let lsp_store = project.lsp_store();
1087 lsp_store.update(cx, |lsp_store, cx| {
1088 lsp_store.shutdown_all_language_servers(cx)
1089 })
1090 });
1091
1092 shutdown_task.await;
1093
1094 if let Some(ep_store) =
1095 cx.update(|cx| EditPredictionStore::try_global(cx))
1096 {
1097 ep_store.update(&mut cx, |store, _| {
1098 store.remove_project(&project);
1099 });
1100 }
1101 }
1102
1103 for example in &mut repo_examples {
1104 example.state.take();
1105 }
1106 finished_examples
1107 .lock()
1108 .unwrap()
1109 .extend_from_slice(&repo_examples);
1110 }
1111 });
1112 }
1113 futures::future::join_all(tasks).await;
1114
1115 Progress::global().finalize();
1116
1117 match &command {
1118 Command::Predict(args) | Command::Score(args) => {
1119 predict::sync_batches(args.provider.as_ref()).await?;
1120 }
1121 Command::Eval(args) => {
1122 predict::sync_batches(args.predict.provider.as_ref()).await?;
1123 }
1124 Command::Qa(args) => {
1125 qa::sync_batches(args).await?;
1126 }
1127 Command::Repair(args) => {
1128 repair::sync_batches(args).await?;
1129 }
1130 _ => (),
1131 }
1132
1133 match &command {
1134 Command::Eval(args) => {
1135 let examples = finished_examples.lock().unwrap();
1136 score::print_report(&examples);
1137 if let Some(summary_path) = &args.summary_json {
1138 score::write_summary_json(&examples, summary_path)?;
1139 }
1140 }
1141 Command::Repair(args) => {
1142 let examples = finished_examples.lock().unwrap();
1143 repair::print_report(&examples, args.confidence_threshold);
1144 }
1145 _ => (),
1146 };
1147
1148 // For --in-place, atomically rename temp file to original
1149 if let Some(temp_path) = &in_place_temp_path {
1150 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1151 std::fs::rename(temp_path, final_path)
1152 .expect("Failed to rename temp file to final output");
1153 }
1154
1155 anyhow::Ok(())
1156 }
1157 .await;
1158
1159 if let Err(e) = result {
1160 panic!("Fatal error: {:?}", e);
1161 }
1162
1163 let _ = cx.update(|cx| cx.quit());
1164 })
1165 .detach();
1166 });
1167}
1168
1169async fn handle_error(
1170 error: anyhow::Error,
1171 args: &EpArgs,
1172 command: &Command,
1173 app_state: &Arc<headless::EpAppState>,
1174 failfast_on_single_example: bool,
1175 example: &Example,
1176) {
1177 Progress::global().increment_failed();
1178
1179 let msg;
1180 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1181 let example_name = example.spec.filename();
1182
1183 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1184 app_state
1185 .fs
1186 .write(
1187 &failed_example_path,
1188 &serde_json::to_vec_pretty(&example).unwrap(),
1189 )
1190 .await
1191 .unwrap();
1192 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1193 app_state
1194 .fs
1195 .write(&err_path, format!("{error:?}").as_bytes())
1196 .await
1197 .unwrap();
1198
1199 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1200 let mut file = OpenOptions::new()
1201 .create(true)
1202 .append(true)
1203 .open(&failed_jsonl_path)
1204 .expect("Failed to open failed.jsonl");
1205 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1206 .expect("Failed to write to failed.jsonl");
1207
1208 let cursor_path = match example.repo_name() {
1209 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1210 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1211 };
1212 msg = format!(
1213 indoc::indoc! {"
1214 While processing \"{}\":
1215
1216 \x1b[31m{:?}\x1b[0m
1217
1218 Example: \x1b[36m{}\x1b[0m
1219 Error file: \x1b[36m{}\x1b[0m
1220 Cursor file: \x1b[36m{}\x1b[0m
1221 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1222 "},
1223 example.spec.name,
1224 error,
1225 failed_example_path.display(),
1226 err_path.display(),
1227 cursor_path.display(),
1228 command,
1229 failed_example_path.display(),
1230 );
1231 } else {
1232 msg = format!(
1233 indoc::indoc! {"
1234 While processing \"{}\":
1235
1236 \x1b[31m{:?}\x1b[0m
1237 "},
1238 example.spec.name, error
1239 );
1240 }
1241
1242 if args.failfast || failfast_on_single_example {
1243 Progress::global().finalize();
1244 panic!("{}", msg);
1245 } else {
1246 log::error!("{}", msg);
1247 }
1248}