1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
48
49use crate::distill::run_distill;
50use crate::example::{Example, group_examples_by_repo, read_example_files};
51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
52use crate::format_prompt::run_format_prompt;
53use crate::load_project::run_load_project;
54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
55use crate::predict::run_prediction;
56use crate::progress::Progress;
57use crate::retrieve_context::run_context_retrieval;
58use crate::score::run_scoring;
59use crate::split_commit::SplitCommitArgs;
60use crate::split_dataset::SplitArgs;
61use crate::synthesize::{SynthesizeConfig, run_synthesize};
62use crate::truncate_expected_patch::TruncatePatchArgs;
63
64#[derive(Parser, Debug)]
65#[command(name = "ep")]
66struct EpArgs {
67 #[arg(long, default_value_t = false)]
68 printenv: bool,
69 #[clap(long, default_value_t = 10, global = true)]
70 max_parallelism: usize,
71 /// The limit for the number of examples to process
72 /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
73 #[clap(long, global = true)]
74 limit: Option<usize>,
75 #[clap(long, global = true)]
76 offset: Option<usize>,
77 /// Filter examples by name
78 #[clap(long, global = true)]
79 name: Option<String>,
80 /// Filter examples by repository
81 #[clap(long, global = true)]
82 repo: Option<String>,
83 /// Deduplicate by cursor position and keep at most this many examples per cluster
84 #[clap(long, global = true)]
85 max_duplicates: Option<usize>,
86 #[command(subcommand)]
87 command: Option<Command>,
88 /// Input file paths
89 #[clap(global = true)]
90 inputs: Vec<PathBuf>,
91 #[arg(long, short, global = true)]
92 output: Option<PathBuf>,
93 #[arg(long, short, global = true)]
94 in_place: bool,
95 #[arg(long, short, global = true)]
96 failfast: bool,
97 /// How to handle failed examples in output: keep them or skip them.
98 /// Failed examples are always logged to the run's failed directory.
99 #[arg(long, global = true, default_value = "keep")]
100 failed: FailedHandling,
101 /// Output as markdown files instead of JSONL. When set, -o specifies a directory
102 /// where one .md file per example will be written (named after each example).
103 #[arg(long, short, global = true)]
104 markdown: bool,
105}
106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed by clap as a `ValueEnum` for `--failed`; the `///` variant docs below
// double as the `--help` text for each value, so keep them user-facing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
119
/// `after_help` text describing accepted positional inputs. Kept in sync with
/// the specifier parsing in `load_examples` (which also accepts
/// `requested-after:` — previously missing from this help text).
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  requested-after:{timestamp}
    Fetch edit prediction requests from Snowflake after the given RFC3339 timestamp.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
177
// Subcommands of the `ep` pipeline. The `///` doc comments on each variant are
// surfaced by clap as the subcommand's --help summary, so keep them user-facing.
// Variants without payload read everything they need from the global `EpArgs`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
221
222impl Display for Command {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Command::Read(_) => write!(f, "read"),
226 Command::LoadProject => write!(f, "load-project"),
227 Command::Context => write!(f, "context"),
228 Command::FormatPrompt(args) => {
229 write!(f, "format-prompt --provider={}", args.provider)
230 }
231 Command::Predict(args) => match &args.provider {
232 Some(provider) => write!(f, "predict --provider={}", provider),
233 None => write!(f, "predict"),
234 },
235 Command::ParseOutput => write!(f, "parse-output"),
236 Command::Score(args) => match &args.provider {
237 Some(provider) => write!(f, "score --provider={}", provider),
238 None => write!(f, "score"),
239 },
240 Command::Distill => write!(f, "distill"),
241 Command::Eval(args) => match &args.predict.provider {
242 Some(provider) => write!(f, "eval --provider={}", provider),
243 None => write!(f, "eval"),
244 },
245 Command::Synthesize(args) => {
246 write!(f, "synthesize --repos {}", args.repos.join(" "))
247 }
248 Command::Clean => write!(f, "clean"),
249 Command::SplitCommit(_) => write!(f, "split-commit"),
250 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
251 Command::Split(_) => write!(f, "split"),
252 Command::FilterLanguages(_) => write!(f, "filter-languages"),
253 Command::ImportBatch(args) => {
254 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
255 }
256 Command::Qa(_) => {
257 write!(f, "qa")
258 }
259 Command::Repair(_) => {
260 write!(f, "repair")
261 }
262 Command::PrintZetaFormats => {
263 write!(f, "print-zeta-formats")
264 }
265 }
266 }
267}
268
// `read` has no flags of its own — inputs and output come from the global
// `EpArgs`. The struct exists so the subcommand can attach INPUTS_HELP.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
272
// Arguments for `format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default `ZetaFormat` (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
278
// Arguments shared by `predict` and `score`, and flattened into `EvalArgs`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when omitted the provider flag is not rendered in the
    // run label (see `Display for Command`).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction repetitions (default 1) — presumably per example;
    // confirm against run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
289
// Arguments for `eval` (prints aggregated scores).
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction options shared with `predict`/`score`, flattened into this command.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
301
/// Teacher models used by the `teacher`/`teacher-non-batching` prediction
/// providers. Parsed from strings via `FromStr` (e.g. "sonnet46"); see
/// `model_name()` for the concrete API model identifier each maps to.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Claude Sonnet 4.6 ("claude-sonnet-4-6").
    Sonnet46,
    /// Claude Sonnet 4.5 ("claude-sonnet-4-5") — the default teacher.
    #[default]
    Sonnet45,
    /// GPT-5.2 ("gpt-5.2").
    Gpt52,
}
309
310impl std::fmt::Display for TeacherBackend {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 match self {
313 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
314 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
315 TeacherBackend::Gpt52 => write!(f, "gpt52"),
316 }
317 }
318}
319
320impl std::str::FromStr for TeacherBackend {
321 type Err = anyhow::Error;
322
323 fn from_str(s: &str) -> Result<Self, Self::Err> {
324 match s.to_lowercase().as_str() {
325 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
326 "sonnet46" => Ok(TeacherBackend::Sonnet46),
327 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
328 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
329 _ => anyhow::bail!(
330 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
331 ),
332 }
333 }
334}
335
336impl TeacherBackend {
337 pub fn model_name(&self) -> &'static str {
338 match self {
339 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
340 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
341 TeacherBackend::Gpt52 => "gpt-5.2",
342 }
343 }
344}
345
/// Which backend produces edit predictions. The string form is `name` or
/// `name:arg` (see the `FromStr`/`Display` impls), and the same form is used
/// for serde (de)serialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Teacher model; `TeacherNonBatching` below is the same backend without batching.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
356
357impl Default for PredictionProvider {
358 fn default() -> Self {
359 PredictionProvider::Zeta2(ZetaFormat::default())
360 }
361}
362
363impl std::fmt::Display for PredictionProvider {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 PredictionProvider::Sweep => write!(f, "sweep"),
367 PredictionProvider::Mercury => write!(f, "mercury"),
368 PredictionProvider::Zeta1 => write!(f, "zeta1"),
369 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
370 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
371 PredictionProvider::TeacherNonBatching(backend) => {
372 write!(f, "teacher-non-batching:{backend}")
373 }
374 PredictionProvider::Repair => write!(f, "repair"),
375 }
376 }
377}
378
379impl std::str::FromStr for PredictionProvider {
380 type Err = anyhow::Error;
381
382 fn from_str(s: &str) -> Result<Self, Self::Err> {
383 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
384
385 let provider_lower = provider.to_lowercase();
386 match provider_lower.as_str() {
387 "sweep" => Ok(PredictionProvider::Sweep),
388 "mercury" => Ok(PredictionProvider::Mercury),
389 "zeta1" => Ok(PredictionProvider::Zeta1),
390 "zeta2" => {
391 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
392 Ok(PredictionProvider::Zeta2(format))
393 }
394 "teacher" => {
395 let backend = arg
396 .map(|a| a.parse())
397 .transpose()?
398 .unwrap_or(TeacherBackend::default());
399 Ok(PredictionProvider::Teacher(backend))
400 }
401 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
402 let backend = arg
403 .map(|a| a.parse())
404 .transpose()?
405 .unwrap_or(TeacherBackend::default());
406 Ok(PredictionProvider::TeacherNonBatching(backend))
407 }
408 "repair" => Ok(PredictionProvider::Repair),
409 _ => {
410 anyhow::bail!(
411 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
412 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
413 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
414 Available zeta versions:\n{}",
415 ZetaFormat::options_as_string()
416 )
417 }
418 }
419 }
420}
421
422impl Serialize for PredictionProvider {
423 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
424 where
425 S: Serializer,
426 {
427 serializer.serialize_str(&self.to_string())
428 }
429}
430
431impl<'de> Deserialize<'de> for PredictionProvider {
432 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
433 where
434 D: Deserializer<'de>,
435 {
436 let s = String::deserialize(deserializer)?;
437 s.parse().map_err(serde::de::Error::custom)
438 }
439}
440
// Arguments for `synthesize`. The output directory comes from the global
// `-o`/`--output` flag, which `main` requires for this subcommand.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
459
// Arguments for `import-batch`, which re-imports provider batch results into
// the local LLM cache database (see the handler in `main`).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
469
// Batch API vendors supported by `import-batch --provider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
475
476impl EpArgs {
477 fn output_path(&self) -> Option<PathBuf> {
478 if self.in_place {
479 if self.inputs.len() == 1 {
480 self.inputs.first().cloned()
481 } else {
482 panic!("--in-place requires exactly one input file")
483 }
484 } else {
485 self.output.clone()
486 }
487 }
488}
489
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Only minor/patch are modeled — corresponds to 0.224.1 assuming a 0.x major.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
497
/// Two-stage deduplication of `examples`, keeping at most `max_per_cluster`
/// examples per near-duplicate cluster:
/// 1. Exact: drop examples whose `cursor_position` text was already seen.
/// 2. Near: MinHash/LSH over token n-grams of `cursor_position`, union-find
///    over the LSH candidate pairs, then a diverse subset per cluster.
/// Relative order of the surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // `insert` returns false on repeats, so `retain` keeps first occurrences only.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // LSH tuning: pairs with estimated Jaccard similarity above the threshold
    // are likely to share a band and become merge candidates below.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Effective signature length after rounding to whole bands.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, built from cursor-context shingles.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Within each cluster keep up to `max_per_cluster` mutually-dissimilar examples.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove kept elements back-to-front with swap_remove (O(1) each); since
    // all higher kept indices are removed first, positions below each removal
    // are undisturbed. Reversing afterwards restores original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
578
/// Greedy max-min (farthest-point) selection: picks up to `k` members that are
/// mutually dissimilar. Seeds with the first member, then repeatedly adds the
/// candidate whose minimum distance to the already-selected set is largest.
/// Distance is `1 - estimated Jaccard similarity` from MinHash signatures.
/// Returns all members unchanged when there are `k` or fewer.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    if members.len() <= k {
        return members.to_vec();
    }

    // Track, for each unselected candidate, its distance to the closest
    // selected member so far.
    let mut selected = vec![members[0]];
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Farthest-first: take the candidate with the largest min-distance.
        // Ties (and NaN, treated as Equal) resolve by map iteration order.
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // The newly selected member may now be the closest selected member for
        // some remaining candidates — refresh their min-distances.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
611
612fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
613 let tokens: Vec<&str> = word_diff::tokenize(code)
614 .into_iter()
615 .filter(|t| !t.trim().is_empty())
616 .collect();
617
618 if tokens.len() < ngram_size {
619 return vec![tokens.join("\0")];
620 }
621
622 tokens
623 .windows(ngram_size)
624 .map(|window| window.join("\0"))
625 .collect()
626}
627
/// Collects the input examples for a run.
///
/// Inputs are either local example files or Snowflake specifiers
/// (`captured-after:`, `rejected-after:`, `requested-after:`, `rated-after:`
/// and its positive/negative variants). Processing order: `--offset` (files
/// first, remainder forwarded to Snowflake), Snowflake fetches, sort by
/// repo/rev, `--name`/`--repo` filters, resume from existing output (skipped
/// for `--in-place`), `--max-duplicates` deduplication, then `--limit`.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into Snowflake specifiers vs. local files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that is not a recognized specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    // NOTE(review): `captured_after_timestamps` is collected above but never
    // fetched below, so `captured-after:` inputs appear to be silently
    // dropped — confirm and wire up the corresponding fetch if unintended.

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            // Drop the first `offset` file examples; nothing left over for Snowflake.
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Early estimate for progress display; refined again at the end.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // No explicit --limit: cap Snowflake pulls at 5000 rows per timestamp.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group examples so per-repo worktrees are processed together.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from --name / --repo.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final totals for the progress display now that filtering is done.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
769
770fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
771 let mut hasher = collections::FxHasher::default();
772 spec.hash(&mut hasher);
773 hasher.finish()
774}
775
776fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
777 let file = match File::open(path) {
778 Ok(f) => f,
779 Err(_) => return,
780 };
781
782 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
783
784 let reader = BufReader::new(file);
785 let mut kept_lines = Vec::new();
786 let mut kept_hashes = HashSet::default();
787
788 for line in reader.lines() {
789 let line = match line {
790 Ok(l) => l,
791 Err(_) => continue,
792 };
793
794 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
795 let hash = spec_hash(&output_example.spec);
796 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
797 let is_complete = match command {
798 Command::Qa(_) => output_example
799 .qa
800 .first()
801 .and_then(|q| q.as_ref())
802 .and_then(|q| q.confidence)
803 .is_some(),
804 Command::Repair(_) => output_example.predictions.iter().any(|p| {
805 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
806 }),
807 _ => true,
808 };
809 if is_complete {
810 kept_hashes.insert(hash);
811 kept_lines.push(line);
812 }
813 }
814 }
815 }
816
817 let total = examples.len();
818 let already_processed = kept_hashes.len();
819
820 eprintln!(
821 "Resuming: {}/{} examples already processed",
822 already_processed, total
823 );
824
825 let file = OpenOptions::new()
826 .write(true)
827 .truncate(true)
828 .open(path)
829 .expect("Failed to open output file for rewriting");
830 let mut writer = BufWriter::new(file);
831 for line in &kept_lines {
832 writeln!(writer, "{}", line).expect("Failed to write to output file");
833 }
834 writer.flush().expect("Failed to flush output file");
835
836 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
837}
838
839fn main() {
840 let args = EpArgs::parse();
841
842 if args.printenv {
843 ::util::shell_env::print_env();
844 return;
845 }
846
847 let output = args.output_path();
848
849 if args.markdown && output.is_none() {
850 eprintln!("--markdown requires -o to specify the output directory");
851 std::process::exit(1);
852 }
853
854 let command = match &args.command {
855 Some(cmd) => cmd.clone(),
856 None => {
857 EpArgs::command().print_help().unwrap();
858 return;
859 }
860 };
861
862 match &command {
863 Command::ImportBatch(import_args) => {
864 smol::block_on(async {
865 match import_args.provider {
866 BatchProvider::Anthropic => {
867 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
868 .expect("Failed to create Anthropic client");
869 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
870 eprintln!("Error importing Anthropic batches: {:?}", e);
871 std::process::exit(1);
872 }
873 }
874 BatchProvider::Openai => {
875 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
876 .expect("Failed to create OpenAI client");
877 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
878 eprintln!("Error importing OpenAI batches: {:?}", e);
879 std::process::exit(1);
880 }
881 }
882 }
883 println!(
884 "Successfully imported {} batch(es)",
885 import_args.batch_ids.len()
886 );
887 });
888 return;
889 }
890 Command::Clean => {
891 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
892 return;
893 }
894 Command::PrintZetaFormats => {
895 use strum::IntoEnumIterator as _;
896 for format in ZetaFormat::iter() {
897 println!("{}", format.to_string().to_lowercase());
898 }
899 return;
900 }
901
902 Command::Synthesize(synth_args) => {
903 let Some(output_dir) = args.output else {
904 panic!("output dir is required");
905 };
906 let config = SynthesizeConfig {
907 repo_urls: synth_args.repos.clone(),
908 count: synth_args.count,
909 max_commits: synth_args.max_commits,
910 output_dir,
911 fresh: synth_args.fresh,
912 };
913 smol::block_on(async {
914 if let Err(e) = run_synthesize(config).await {
915 eprintln!("Error: {:?}", e);
916 std::process::exit(1);
917 }
918 });
919 return;
920 }
921 Command::SplitCommit(split_commit_args) => {
922 if let Err(error) = split_commit::run_split_commit(
923 split_commit_args,
924 &args.inputs,
925 output.as_ref(),
926 args.failed,
927 ) {
928 eprintln!("{error:#}");
929 std::process::exit(1);
930 }
931 return;
932 }
933 Command::TruncatePatch(truncate_args) => {
934 if let Err(error) =
935 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
936 {
937 eprintln!("{error:#}");
938 std::process::exit(1);
939 }
940 return;
941 }
942 Command::Split(split_args) => {
943 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
944 eprintln!("{error:#}");
945 std::process::exit(1);
946 }
947 return;
948 }
949 Command::FilterLanguages(filter_args) => {
950 if let Err(error) =
951 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
952 {
953 eprintln!("{error:#}");
954 std::process::exit(1);
955 }
956 return;
957 }
958
959 _ => {}
960 }
961
962 let http_client = Arc::new(ReqwestClient::new());
963 let app = gpui_platform::headless().with_http_client(http_client);
964
965 app.run(move |cx| {
966 let app_state = Arc::new(headless::init(cx));
967 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
968
969 cx.spawn(async move |cx| {
970 let result = async {
971 let examples = load_examples(
972 app_state.client.http_client(),
973 &args,
974 output.as_ref(),
975 cx.background_executor().clone(),
976 )
977 .await?;
978
979 match &command {
980 Command::Predict(args) | Command::Score(args) => {
981 predict::sync_batches(args.provider.as_ref()).await?;
982 }
983 Command::Eval(args) => {
984 predict::sync_batches(args.predict.provider.as_ref()).await?;
985 }
986 Command::Qa(args) => {
987 qa::sync_batches(args).await?;
988 }
989 Command::Repair(args) => {
990 repair::sync_batches(args).await?;
991 }
992 _ => (),
993 }
994
995 let failfast_on_single_example = examples.len() == 1;
996
997 // For --markdown mode, create the output directory if it doesn't exist
998 if args.markdown {
999 let dir = output.as_ref().expect("--markdown requires -o");
1000 if !dir.exists() {
1001 std::fs::create_dir_all(dir)
1002 .expect("Failed to create markdown output directory");
1003 }
1004 }
1005
1006 // Set up JSONL output writer (not used in markdown mode)
1007 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1008 let mut in_place_temp_path: Option<PathBuf> = None;
1009 if !args.markdown
1010 && let Some(output_path) = output.as_ref()
1011 {
1012 let write_path = if args.in_place {
1013 let temp = output_path.with_extension("jsonl.tmp");
1014 in_place_temp_path = Some(temp.clone());
1015 temp
1016 } else {
1017 output_path.clone()
1018 };
1019
1020 let file = OpenOptions::new()
1021 .create(true)
1022 .write(true)
1023 .truncate(args.in_place)
1024 .append(!args.in_place)
1025 .open(&write_path)
1026 .expect("Failed to open output file");
1027
1028 let mut writer = BufWriter::new(file);
1029 let (sender, mut receiver) = mpsc::unbounded::<String>();
1030 cx.background_spawn(async move {
1031 while let Some(line) = receiver.next().await {
1032 writeln!(writer, "{}", line).expect("Failed to write example");
1033 writer.flush().expect("Failed to flush output");
1034 }
1035 })
1036 .detach();
1037 output_sender = Some(sender);
1038 }
1039
1040 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1041 let finished_examples = Mutex::new(Vec::new());
1042
1043 let mut tasks = Vec::new();
1044 for _ in 0..args.max_parallelism {
1045 tasks.push(async {
1046 loop {
1047 let Some(mut repo_examples) =
1048 grouped_examples.lock().unwrap().pop_front()
1049 else {
1050 break;
1051 };
1052 for example in &mut repo_examples {
1053 let example_progress =
1054 Progress::global().start_group(&example.spec.name);
1055
1056 let result = async {
1057 match &command {
1058 Command::Read(_) => {}
1059 Command::LoadProject => {
1060 run_load_project(
1061 example,
1062 app_state.clone(),
1063 &example_progress,
1064 cx.clone(),
1065 )
1066 .await?;
1067 }
1068 Command::Context => {
1069 run_context_retrieval(
1070 example,
1071 app_state.clone(),
1072 &example_progress,
1073 cx.clone(),
1074 )
1075 .await?;
1076 }
1077 Command::FormatPrompt(args) => {
1078 run_format_prompt(
1079 example,
1080 args,
1081 app_state.clone(),
1082 &example_progress,
1083 cx.clone(),
1084 )
1085 .await?;
1086 }
1087 Command::Predict(args) => {
1088 run_prediction(
1089 example,
1090 args,
1091 app_state.clone(),
1092 &example_progress,
1093 cx.clone(),
1094 )
1095 .await?;
1096 }
1097 Command::ParseOutput => {
1098 parse_output::run_parse_output(example)?;
1099 }
1100 Command::Distill => {
1101 run_distill(example).await?;
1102 }
1103 Command::Score(args) => {
1104 run_scoring(
1105 example,
1106 args,
1107 app_state.clone(),
1108 &example_progress,
1109 cx.clone(),
1110 )
1111 .await?;
1112 }
1113 Command::Eval(args) => {
1114 run_scoring(
1115 example,
1116 &args.predict,
1117 app_state.clone(),
1118 &example_progress,
1119 cx.clone(),
1120 )
1121 .await?;
1122 }
1123 Command::Qa(args) => {
1124 qa::run_qa(example, args, &example_progress).await?;
1125 }
1126 Command::Repair(args) => {
1127 repair::run_repair(example, args, &example_progress)
1128 .await?;
1129 }
1130 Command::Clean
1131 | Command::Synthesize(_)
1132 | Command::SplitCommit(_)
1133 | Command::Split(_)
1134 | Command::TruncatePatch(_)
1135 | Command::FilterLanguages(_)
1136 | Command::ImportBatch(_)
1137 | Command::PrintZetaFormats => {
1138 unreachable!()
1139 }
1140 }
1141 anyhow::Ok(())
1142 }
1143 .await;
1144
1145 let failed = if let Err(error) = result {
1146 handle_error(
1147 error,
1148 &args,
1149 &command,
1150 &app_state,
1151 failfast_on_single_example,
1152 &example,
1153 )
1154 .await;
1155 true
1156 } else {
1157 false
1158 };
1159
1160 let should_write = !failed || args.failed == FailedHandling::Keep;
1161 if should_write {
1162 if args.markdown {
1163 let markdown_dir =
1164 output.as_ref().expect("--markdown requires -o");
1165 let filename = format!("{}.md", example.spec.filename());
1166 let path = markdown_dir.join(&filename);
1167 let markdown = example.spec.to_markdown();
1168 std::fs::write(&path, &markdown)
1169 .expect("Failed to write markdown file");
1170 } else if let Some(ref mut sender) = output_sender.clone() {
1171 let line = serde_json::to_string(&example).unwrap();
1172 sender
1173 .send(line)
1174 .await
1175 .expect("Failed to send to output writer");
1176 } else if args.output.is_none()
1177 && !matches!(command, Command::Eval(_))
1178 {
1179 let line = serde_json::to_string(&example).unwrap();
1180 println!("{}", line);
1181 }
1182 }
1183 }
1184
1185 let project = repo_examples
1186 .iter()
1187 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1188
1189 if let Some(project) = project {
1190 let mut cx = cx.clone();
1191
1192 let shutdown_task: Task<()> =
1193 project.update(&mut cx, |project, cx| {
1194 let lsp_store = project.lsp_store();
1195 lsp_store.update(cx, |lsp_store, cx| {
1196 lsp_store.shutdown_all_language_servers(cx)
1197 })
1198 });
1199
1200 shutdown_task.await;
1201
1202 if let Some(ep_store) =
1203 cx.update(|cx| EditPredictionStore::try_global(cx))
1204 {
1205 ep_store.update(&mut cx, |store, _| {
1206 store.remove_project(&project);
1207 });
1208 }
1209 }
1210
1211 for example in &mut repo_examples {
1212 example.state.take();
1213 }
1214 finished_examples
1215 .lock()
1216 .unwrap()
1217 .extend_from_slice(&repo_examples);
1218 }
1219 });
1220 }
1221 futures::future::join_all(tasks).await;
1222
1223 Progress::global().finalize();
1224
1225 match &command {
1226 Command::Predict(args) | Command::Score(args) => {
1227 predict::sync_batches(args.provider.as_ref()).await?;
1228 }
1229 Command::Eval(args) => {
1230 predict::sync_batches(args.predict.provider.as_ref()).await?;
1231 }
1232 Command::Qa(args) => {
1233 qa::sync_batches(args).await?;
1234 }
1235 Command::Repair(args) => {
1236 repair::sync_batches(args).await?;
1237 }
1238 _ => (),
1239 }
1240
1241 match &command {
1242 Command::Eval(args) => {
1243 let examples = finished_examples.lock().unwrap();
1244 score::print_report(&examples, args.verbose);
1245 if let Some(summary_path) = &args.summary_json {
1246 score::write_summary_json(&examples, summary_path)?;
1247 }
1248 }
1249 Command::Repair(args) => {
1250 let examples = finished_examples.lock().unwrap();
1251 repair::print_report(&examples, args.confidence_threshold);
1252 }
1253 _ => (),
1254 };
1255
1256 // For --in-place, atomically rename temp file to original
1257 if let Some(temp_path) = &in_place_temp_path {
1258 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1259 std::fs::rename(temp_path, final_path)
1260 .expect("Failed to rename temp file to final output");
1261 }
1262
1263 anyhow::Ok(())
1264 }
1265 .await;
1266
1267 if let Err(e) = result {
1268 panic!("Fatal error: {:?}", e);
1269 }
1270
1271 let _ = cx.update(|cx| cx.quit());
1272 })
1273 .detach();
1274 });
1275}
1276
1277async fn handle_error(
1278 error: anyhow::Error,
1279 args: &EpArgs,
1280 command: &Command,
1281 app_state: &Arc<headless::EpAppState>,
1282 failfast_on_single_example: bool,
1283 example: &Example,
1284) {
1285 Progress::global().increment_failed();
1286
1287 let msg;
1288 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1289 let example_name = example.spec.filename();
1290
1291 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1292 app_state
1293 .fs
1294 .write(
1295 &failed_example_path,
1296 &serde_json::to_vec_pretty(&example).unwrap(),
1297 )
1298 .await
1299 .unwrap();
1300 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1301 app_state
1302 .fs
1303 .write(&err_path, format!("{error:?}").as_bytes())
1304 .await
1305 .unwrap();
1306
1307 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1308 let mut file = OpenOptions::new()
1309 .create(true)
1310 .append(true)
1311 .open(&failed_jsonl_path)
1312 .expect("Failed to open failed.jsonl");
1313 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1314 .expect("Failed to write to failed.jsonl");
1315
1316 let cursor_path = match example.repo_name() {
1317 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1318 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1319 };
1320 msg = format!(
1321 indoc::indoc! {"
1322 While processing \"{}\":
1323
1324 \x1b[31m{:?}\x1b[0m
1325
1326 Example: \x1b[36m{}\x1b[0m
1327 Error file: \x1b[36m{}\x1b[0m
1328 Cursor file: \x1b[36m{}\x1b[0m
1329 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1330 "},
1331 example.spec.name,
1332 error,
1333 failed_example_path.display(),
1334 err_path.display(),
1335 cursor_path.display(),
1336 command,
1337 failed_example_path.display(),
1338 );
1339 } else {
1340 msg = format!(
1341 indoc::indoc! {"
1342 While processing \"{}\":
1343
1344 \x1b[31m{:?}\x1b[0m
1345 "},
1346 example.spec.name, error
1347 );
1348 }
1349
1350 if args.failfast || failfast_on_single_example {
1351 Progress::global().finalize();
1352 panic!("{}", msg);
1353 } else {
1354 log::error!("{}", msg);
1355 }
1356}