1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::HashSet;
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gpui::{AppContext as _, BackgroundExecutor, Task};
35use zeta_prompt::ZetaFormat;
36
37use reqwest_client::ReqwestClient;
38use serde::{Deserialize, Deserializer, Serialize, Serializer};
39use std::fmt::Display;
40use std::fs::{File, OpenOptions};
41use std::hash::{Hash, Hasher};
42use std::io::{BufRead, BufReader, BufWriter, Write};
43use std::sync::Mutex;
use std::path::{Path, PathBuf};
use std::sync::Arc;
45
46use crate::distill::run_distill;
47use crate::example::{Example, group_examples_by_repo, read_example_files};
48use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
49use crate::format_prompt::run_format_prompt;
50use crate::load_project::run_load_project;
51use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
52use crate::predict::run_prediction;
53use crate::progress::Progress;
54use crate::retrieve_context::run_context_retrieval;
55use crate::score::run_scoring;
56use crate::split_commit::SplitCommitArgs;
57use crate::split_dataset::SplitArgs;
58use crate::synthesize::{SynthesizeConfig, run_synthesize};
59use crate::truncate_expected_patch::TruncatePatchArgs;
60
// Top-level CLI definition for `ep`. Flags marked `global = true` apply to
// every subcommand; the subcommand itself is optional (no subcommand prints
// help and exits). `///` doc comments below double as clap help text, so the
// extra notes here use plain `//` comments to avoid changing CLI output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit (debug aid; see `main`).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Size of the worker pool: number of repo groups processed concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of examples to skip; consumed by file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: file paths or `*-after:{timestamp}` Snowflake
    // specifiers (documented in INPUTS_HELP).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Output destination; interpreted as a directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; `output_path` panics if more
    // than one input is given with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
99
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// NOTE(review): the main write path in `main` treats `Skip` and `SkipNoFiles`
// identically (both drop failed examples from the primary output); the
// "no files" behavior is presumably implemented in the failure-handling path —
// confirm in `handle_error`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
112
// Long-form help text attached to the positional `inputs` argument of
// `EpArgs`. This is user-facing CLI output printed verbatim by clap — keep
// the formatting stable.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
170
// Subcommands of `ep`. The `///` doc comments double as clap help text, so
// they are left untouched. Commands handled inline (without the headless GPUI
// app) in `main`: ImportBatch, Clean, PrintZetaFormats, Synthesize,
// SplitCommit, TruncatePatch, Split, FilterLanguages; all others run through
// the example-processing worker pool.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
214
215impl Display for Command {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 match self {
218 Command::Read => write!(f, "read"),
219 Command::LoadProject => write!(f, "load-project"),
220 Command::Context => write!(f, "context"),
221 Command::FormatPrompt(args) => {
222 write!(f, "format-prompt --provider={}", args.provider)
223 }
224 Command::Predict(args) => match &args.provider {
225 Some(provider) => write!(f, "predict --provider={}", provider),
226 None => write!(f, "predict"),
227 },
228 Command::ParseOutput => write!(f, "parse-output"),
229 Command::Score(args) => match &args.provider {
230 Some(provider) => write!(f, "score --provider={}", provider),
231 None => write!(f, "score"),
232 },
233 Command::Distill => write!(f, "distill"),
234 Command::Eval(args) => match &args.predict.provider {
235 Some(provider) => write!(f, "eval --provider={}", provider),
236 None => write!(f, "eval"),
237 },
238 Command::Synthesize(args) => {
239 write!(f, "synthesize --repos {}", args.repos.join(" "))
240 }
241 Command::Clean => write!(f, "clean"),
242 Command::SplitCommit(_) => write!(f, "split-commit"),
243 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
244 Command::Split(_) => write!(f, "split"),
245 Command::FilterLanguages(_) => write!(f, "filter-languages"),
246 Command::ImportBatch(args) => {
247 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
248 }
249 Command::Qa(_) => {
250 write!(f, "qa")
251 }
252 Command::Repair(_) => {
253 write!(f, "repair")
254 }
255 Command::PrintZetaFormats => {
256 write!(f, "print-zeta-formats")
257 }
258 }
259 }
260}
261
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default zeta format (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
267
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; `None` presumably falls back to a default chosen
    // by `run_prediction` — confirm there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to produce per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
278
// Arguments for `ep eval`: all of `PredictArgs` plus report output options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Eval accepts every prediction flag via clap flattening.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
287
/// LLM backends usable as the "teacher" prediction provider
/// (see `PredictionProvider::Teacher`). The API model identifier for each
/// variant is given by `model_name`.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
295
296impl std::fmt::Display for TeacherBackend {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 match self {
299 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
300 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
301 TeacherBackend::Gpt52 => write!(f, "gpt52"),
302 }
303 }
304}
305
impl std::str::FromStr for TeacherBackend {
    type Err = anyhow::Error;

    /// Parses a backend name case-insensitively, accepting aliases:
    /// `sonnet`/`claude` → Sonnet45 and `gpt`/`openai` → Gpt52.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
            "sonnet46" => Ok(TeacherBackend::Sonnet46),
            "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
            // Legacy alias — presumably an old format identifier kept for
            // backward compatibility with existing data; note it is not
            // mentioned in the error message below. TODO confirm it is still
            // needed.
            "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
            _ => anyhow::bail!(
                "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
            ),
        }
    }
}
321
322impl TeacherBackend {
323 pub fn model_name(&self) -> &'static str {
324 match self {
325 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
326 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
327 TeacherBackend::Gpt52 => "gpt-5.2",
328 }
329 }
330}
331
/// Which system produces edit predictions. Parsed from / rendered to strings
/// of the form `name` or `name:arg` (see the `FromStr` and `Display` impls).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Carries the zeta prompt/output format variant to use.
    Zeta2(ZetaFormat),
    // Teacher LLM backend; presumably issues requests through the batch API —
    // confirm in the predict implementation.
    Teacher(TeacherBackend),
    // Same backends as `Teacher`, but non-batching per the name — confirm in
    // the predict implementation.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
342
impl Default for PredictionProvider {
    // The default provider is zeta2 with the default zeta format.
    fn default() -> Self {
        PredictionProvider::Zeta2(ZetaFormat::default())
    }
}
348
349impl std::fmt::Display for PredictionProvider {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 match self {
352 PredictionProvider::Sweep => write!(f, "sweep"),
353 PredictionProvider::Mercury => write!(f, "mercury"),
354 PredictionProvider::Zeta1 => write!(f, "zeta1"),
355 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
356 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
357 PredictionProvider::TeacherNonBatching(backend) => {
358 write!(f, "teacher-non-batching:{backend}")
359 }
360 PredictionProvider::Repair => write!(f, "repair"),
361 }
362 }
363}
364
365impl std::str::FromStr for PredictionProvider {
366 type Err = anyhow::Error;
367
368 fn from_str(s: &str) -> Result<Self, Self::Err> {
369 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
370
371 let provider_lower = provider.to_lowercase();
372 match provider_lower.as_str() {
373 "sweep" => Ok(PredictionProvider::Sweep),
374 "mercury" => Ok(PredictionProvider::Mercury),
375 "zeta1" => Ok(PredictionProvider::Zeta1),
376 "zeta2" => {
377 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
378 Ok(PredictionProvider::Zeta2(format))
379 }
380 "teacher" => {
381 let backend = arg
382 .map(|a| a.parse())
383 .transpose()?
384 .unwrap_or(TeacherBackend::default());
385 Ok(PredictionProvider::Teacher(backend))
386 }
387 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
388 let backend = arg
389 .map(|a| a.parse())
390 .transpose()?
391 .unwrap_or(TeacherBackend::default());
392 Ok(PredictionProvider::TeacherNonBatching(backend))
393 }
394 "repair" => Ok(PredictionProvider::Repair),
395 _ => {
396 anyhow::bail!(
397 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
398 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
399 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
400 Available zeta versions:\n{}",
401 ZetaFormat::options_as_string()
402 )
403 }
404 }
405 }
406}
407
408impl Serialize for PredictionProvider {
409 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
410 where
411 S: Serializer,
412 {
413 serializer.serialize_str(&self.to_string())
414 }
415}
416
impl<'de> Deserialize<'de> for PredictionProvider {
    // Deserializes from the string form written by `Serialize`, delegating to
    // the `FromStr` impl for parsing and error reporting.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
426
// Arguments for `ep synthesize`; converted into a `SynthesizeConfig` in
// `main` (output directory comes from the global `-o` flag).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
445
// Arguments for `ep import-batch` (handled inline in `main` without the
// headless app).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
455
// Which LLM batch API the imported batch IDs belong to (see `ImportBatchArgs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
461
462impl EpArgs {
463 fn output_path(&self) -> Option<PathBuf> {
464 if self.in_place {
465 if self.inputs.len() == 1 {
466 self.inputs.first().cloned()
467 } else {
468 panic!("--in-place requires exactly one input file")
469 }
470 } else {
471 self.output.clone()
472 }
473 }
474}
475
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): presumably encodes Zed release 0.224.1 (minor/patch only) —
// confirm the field semantics against `pull_examples::MinCaptureVersion`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
483
/// Resolves the CLI inputs into the list of `Example`s to process.
///
/// Inputs are partitioned into local example files and Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`).
/// File examples are loaded first and consume `--offset`; any remaining
/// offset and limit are forwarded to the Snowflake fetches. Afterwards the
/// `--name` / `--repo` filters, the resume-from-output logic (skipped for
/// `--in-place`), and the final `--limit` truncation are applied, and global
/// progress totals are updated.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    // Classify each input as one of the Snowflake specifiers; anything that
    // doesn't parse as a specifier is treated as a local file path.
    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional total so progress can render early; the real total is set
    // again below once Snowflake results and filters have been applied.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 matches the documented default when pulling from Snowflake
        // (see the `--limit` help text on `EpArgs`).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `remaining_offset` is passed to each source kind
        // independently, so when multiple source kinds are mixed the offset
        // is applied to each of them — confirm this is intended.
        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals, now that the example set is settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
636
/// Identity hash of an example spec, used to match examples between the
/// current input set and a previous run's output (see `resume_from_output`).
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
642
643fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
644 let file = match File::open(path) {
645 Ok(f) => f,
646 Err(_) => return,
647 };
648
649 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
650
651 let reader = BufReader::new(file);
652 let mut kept_lines = Vec::new();
653 let mut kept_hashes = HashSet::default();
654
655 for line in reader.lines() {
656 let line = match line {
657 Ok(l) => l,
658 Err(_) => continue,
659 };
660
661 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
662 let hash = spec_hash(&output_example.spec);
663 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
664 let is_complete = match command {
665 Command::Qa(_) => output_example
666 .qa
667 .first()
668 .and_then(|q| q.as_ref())
669 .and_then(|q| q.confidence)
670 .is_some(),
671 Command::Repair(_) => output_example.predictions.iter().any(|p| {
672 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
673 }),
674 _ => true,
675 };
676 if is_complete {
677 kept_hashes.insert(hash);
678 kept_lines.push(line);
679 }
680 }
681 }
682 }
683
684 let total = examples.len();
685 let already_processed = kept_hashes.len();
686
687 eprintln!(
688 "Resuming: {}/{} examples already processed",
689 already_processed, total
690 );
691
692 let file = OpenOptions::new()
693 .write(true)
694 .truncate(true)
695 .open(path)
696 .expect("Failed to open output file for rewriting");
697 let mut writer = BufWriter::new(file);
698 for line in &kept_lines {
699 writeln!(writer, "{}", line).expect("Failed to write to output file");
700 }
701 writer.flush().expect("Failed to flush output file");
702
703 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
704}
705
/// CLI entry point. Parses arguments, handles subcommands that need no GPUI
/// app inline (import-batch, clean, print-zeta-formats, synthesize,
/// split-commit, truncate-patch, split, filter-languages), then runs the
/// remaining example-pipeline commands inside a headless GPUI app with a
/// bounded worker pool (`--max-parallelism` concurrent repo groups).
fn main() {
    let args = EpArgs::parse();

    // Debug aid: dump the shell environment and exit.
    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    let output = args.output_path();

    if args.markdown && output.is_none() {
        eprintln!("--markdown requires -o to specify the output directory");
        std::process::exit(1);
    }

    // With no subcommand, print help and exit.
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    // Subcommands that don't need the headless GPUI app are handled here and
    // return early; everything else falls through to the pipeline below.
    match &command {
        Command::ImportBatch(import_args) => {
            smol::block_on(async {
                match import_args.provider {
                    BatchProvider::Anthropic => {
                        let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create Anthropic client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing Anthropic batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                    BatchProvider::Openai => {
                        let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
                            .expect("Failed to create OpenAI client");
                        if let Err(e) = client.import_batches(&import_args.batch_ids).await {
                            eprintln!("Error importing OpenAI batches: {:?}", e);
                            std::process::exit(1);
                        }
                    }
                }
                println!(
                    "Successfully imported {} batch(es)",
                    import_args.batch_ids.len()
                );
            });
            return;
        }
        Command::Clean => {
            // Removes all repositories/worktrees under the tool's data dir.
            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
            return;
        }
        Command::PrintZetaFormats => {
            use strum::IntoEnumIterator as _;
            for format in ZetaFormat::iter() {
                println!("{}", format.to_string().to_lowercase());
            }
            return;
        }

        Command::Synthesize(synth_args) => {
            // Synthesize writes one file per example, so -o is a directory.
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_urls: synth_args.repos.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) = split_commit::run_split_commit(
                split_commit_args,
                &args.inputs,
                output.as_ref(),
                args.failed,
            ) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::TruncatePatch(truncate_args) => {
            if let Err(error) =
                truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::Split(split_args) => {
            if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        Command::FilterLanguages(filter_args) => {
            if let Err(error) =
                run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }

        _ => {}
    }

    // Remaining commands process examples inside a headless GPUI app.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Sync any pending batched LLM requests before processing.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the original at the end of the run.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Single background writer task consumes lines from a
                    // channel so workers never contend on the file.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Examples are processed per repo group — presumably so each
                // repo's worktree/project setup is reused across its examples;
                // confirm in `group_examples_by_repo`.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                // Bounded worker pool: each task repeatedly claims the next
                // repo group until the queue is drained.
                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the per-example work for the
                                // current subcommand.
                                let result = async {
                                    match &command {
                                        Command::Read => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These were all handled (with an
                                        // early return) before the app
                                        // started.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are written only with
                                // --failed=keep; note Skip and SkipNoFiles
                                // behave the same at this point.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout
                                        // (except eval, which prints a report
                                        // at the end).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // Shut down this repo group's language servers and
                            // detach the project before moving on.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Release per-example project state before
                            // archiving the results.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Sync batched requests again now that all examples have been
                // processed.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Per-command summary reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1143
1144async fn handle_error(
1145 error: anyhow::Error,
1146 args: &EpArgs,
1147 command: &Command,
1148 app_state: &Arc<headless::EpAppState>,
1149 failfast_on_single_example: bool,
1150 example: &Example,
1151) {
1152 Progress::global().increment_failed();
1153
1154 let msg;
1155 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1156 let example_name = example.spec.filename();
1157
1158 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1159 app_state
1160 .fs
1161 .write(
1162 &failed_example_path,
1163 &serde_json::to_vec_pretty(&example).unwrap(),
1164 )
1165 .await
1166 .unwrap();
1167 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1168 app_state
1169 .fs
1170 .write(&err_path, format!("{error:?}").as_bytes())
1171 .await
1172 .unwrap();
1173
1174 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1175 let mut file = OpenOptions::new()
1176 .create(true)
1177 .append(true)
1178 .open(&failed_jsonl_path)
1179 .expect("Failed to open failed.jsonl");
1180 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1181 .expect("Failed to write to failed.jsonl");
1182
1183 let cursor_path = match example.repo_name() {
1184 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1185 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1186 };
1187 msg = format!(
1188 indoc::indoc! {"
1189 While processing \"{}\":
1190
1191 \x1b[31m{:?}\x1b[0m
1192
1193 Example: \x1b[36m{}\x1b[0m
1194 Error file: \x1b[36m{}\x1b[0m
1195 Cursor file: \x1b[36m{}\x1b[0m
1196 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1197 "},
1198 example.spec.name,
1199 error,
1200 failed_example_path.display(),
1201 err_path.display(),
1202 cursor_path.display(),
1203 command,
1204 failed_example_path.display(),
1205 );
1206 } else {
1207 msg = format!(
1208 indoc::indoc! {"
1209 While processing \"{}\":
1210
1211 \x1b[31m{:?}\x1b[0m
1212 "},
1213 example.spec.name, error
1214 );
1215 }
1216
1217 if args.failfast || failfast_on_single_example {
1218 Progress::global().finalize();
1219 panic!("{}", msg);
1220 } else {
1221 log::error!("{}", msg);
1222 }
1223}