1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
42use std::fmt::Display;
43use std::fs::{File, OpenOptions};
44use std::hash::{Hash, Hasher};
45use std::io::{BufRead, BufReader, BufWriter, Write};
46use std::sync::Mutex;
47use std::{path::PathBuf, sync::Arc};
48
49use crate::distill::run_distill;
50use crate::example::{Example, group_examples_by_repo, read_example_files};
51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
52use crate::format_prompt::run_format_prompt;
53use crate::load_project::run_load_project;
54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
55use crate::predict::run_prediction;
56use crate::progress::Progress;
57use crate::retrieve_context::run_context_retrieval;
58use crate::score::run_scoring;
59use crate::split_commit::SplitCommitArgs;
60use crate::split_dataset::SplitArgs;
61use crate::synthesize::{SynthesizeConfig, run_synthesize};
62use crate::truncate_expected_patch::TruncatePatchArgs;
63
// Top-level CLI arguments. Most flags are `global = true` so they can appear
// after any subcommand. NOTE: `///` doc comments on fields become clap help
// text, so plain `//` comments are used for internal notes.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the captured shell environment and exit (debugging aid; handled
    // first in main(), before any subcommand dispatch).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of concurrent worker tasks; each worker pops whole per-repo
    // example groups off a shared queue.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples; applied to file inputs first, with any
    // remainder forwarded to Snowflake queries (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output file path; with --markdown this is a directory instead.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place; output_path() panics if more
    // than one input is given together with this flag.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts the run on the first failed example —
    // the handling happens in code outside this view; confirm there.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// Parsed via clap's ValueEnum, so the CLI values are the kebab-cased variant
// names: "keep", "skip", "skip-no-files". The `///` variant docs double as
// the generated help text.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
119
// Extended help text attached to the `read` subcommand via
// `#[command(after_help = INPUTS_HELP)]` on ReadArgs. This is user-facing
// runtime output; keep it in sync with the specifier parsing in pull_examples.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
        Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
        Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that users explicitly rated as positive or negative via the
        rate completions modal. Only zeta2 predictions are included.
        - Positive ratings: output becomes expected_patches
        - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
        Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
        Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
      EP_SNOWFLAKE_API_KEY
      EP_SNOWFLAKE_BASE_URL

  Optional:
      EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
177
// Subcommands of the `ep` CLI. The `///` doc comments on variants become the
// clap-generated help text, so edit them with the CLI surface in mind. The
// Display impl below renders each command as its invocation string for logs.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
221
222impl Display for Command {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Command::Read(_) => write!(f, "read"),
226 Command::LoadProject => write!(f, "load-project"),
227 Command::Context => write!(f, "context"),
228 Command::FormatPrompt(args) => {
229 write!(f, "format-prompt --provider={}", args.provider)
230 }
231 Command::Predict(args) => match &args.provider {
232 Some(provider) => write!(f, "predict --provider={}", provider),
233 None => write!(f, "predict"),
234 },
235 Command::ParseOutput => write!(f, "parse-output"),
236 Command::Score(args) => match &args.provider {
237 Some(provider) => write!(f, "score --provider={}", provider),
238 None => write!(f, "score"),
239 },
240 Command::Distill => write!(f, "distill"),
241 Command::Eval(args) => match &args.predict.provider {
242 Some(provider) => write!(f, "eval --provider={}", provider),
243 None => write!(f, "eval"),
244 },
245 Command::Synthesize(args) => {
246 write!(f, "synthesize --repos {}", args.repos.join(" "))
247 }
248 Command::Clean => write!(f, "clean"),
249 Command::SplitCommit(_) => write!(f, "split-commit"),
250 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
251 Command::Split(_) => write!(f, "split"),
252 Command::FilterLanguages(_) => write!(f, "filter-languages"),
253 Command::ImportBatch(args) => {
254 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
255 }
256 Command::Qa(_) => {
257 write!(f, "qa")
258 }
259 Command::Repair(_) => {
260 write!(f, "repair")
261 }
262 Command::PrintZetaFormats => {
263 write!(f, "print-zeta-formats")
264 }
265 }
266 }
267}
268
// Argument container for `ep read`; currently empty but carries the
// after_help attachment that documents the input specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
272
// Arguments for `ep format-prompt`: unlike Predict/Score, the provider is
// required here (it has a default rather than being optional).
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
278
// Arguments shared by `ep predict` and `ep score` (and embedded in EvalArgs).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when absent the command string is rendered without a
    // --provider flag (see Display for Command).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // How many predictions to run per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
289
// Arguments for `ep eval`: all of PredictArgs (flattened into the same flag
// namespace) plus an optional JSON summary output path.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
298
/// Which large model backs the `teacher` prediction providers.
/// Parsed from strings via `FromStr` (with aliases), rendered via `Display`,
/// and mapped to a concrete API model id by [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Claude Sonnet 4.6 ("claude-sonnet-4-6").
    Sonnet46,
    /// Claude Sonnet 4.5 ("claude-sonnet-4-5") — the default backend.
    #[default]
    Sonnet45,
    /// GPT 5.2 ("gpt-5.2").
    Gpt52,
}
306
307impl std::fmt::Display for TeacherBackend {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
311 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
312 TeacherBackend::Gpt52 => write!(f, "gpt52"),
313 }
314 }
315}
316
317impl std::str::FromStr for TeacherBackend {
318 type Err = anyhow::Error;
319
320 fn from_str(s: &str) -> Result<Self, Self::Err> {
321 match s.to_lowercase().as_str() {
322 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
323 "sonnet46" => Ok(TeacherBackend::Sonnet46),
324 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
325 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
326 _ => anyhow::bail!(
327 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
328 ),
329 }
330 }
331}
332
333impl TeacherBackend {
334 pub fn model_name(&self) -> &'static str {
335 match self {
336 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
337 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
338 TeacherBackend::Gpt52 => "gpt-5.2",
339 }
340 }
341}
342
/// The model/service used to produce edit predictions.
/// Parsed from `name` or `name:arg` strings (see the `FromStr` impl) and
/// serialized as that same string form for JSONL round-tripping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta 2 with a specific prompt format (e.g. `zeta2:ordered`).
    Zeta2(ZetaFormat),
    /// Batched teacher-model predictions backed by the given large model.
    Teacher(TeacherBackend),
    /// Teacher predictions issued as individual (non-batch) requests.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
353
354impl Default for PredictionProvider {
355 fn default() -> Self {
356 PredictionProvider::Zeta2(ZetaFormat::default())
357 }
358}
359
360impl std::fmt::Display for PredictionProvider {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 match self {
363 PredictionProvider::Sweep => write!(f, "sweep"),
364 PredictionProvider::Mercury => write!(f, "mercury"),
365 PredictionProvider::Zeta1 => write!(f, "zeta1"),
366 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
367 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
368 PredictionProvider::TeacherNonBatching(backend) => {
369 write!(f, "teacher-non-batching:{backend}")
370 }
371 PredictionProvider::Repair => write!(f, "repair"),
372 }
373 }
374}
375
376impl std::str::FromStr for PredictionProvider {
377 type Err = anyhow::Error;
378
379 fn from_str(s: &str) -> Result<Self, Self::Err> {
380 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
381
382 let provider_lower = provider.to_lowercase();
383 match provider_lower.as_str() {
384 "sweep" => Ok(PredictionProvider::Sweep),
385 "mercury" => Ok(PredictionProvider::Mercury),
386 "zeta1" => Ok(PredictionProvider::Zeta1),
387 "zeta2" => {
388 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
389 Ok(PredictionProvider::Zeta2(format))
390 }
391 "teacher" => {
392 let backend = arg
393 .map(|a| a.parse())
394 .transpose()?
395 .unwrap_or(TeacherBackend::default());
396 Ok(PredictionProvider::Teacher(backend))
397 }
398 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
399 let backend = arg
400 .map(|a| a.parse())
401 .transpose()?
402 .unwrap_or(TeacherBackend::default());
403 Ok(PredictionProvider::TeacherNonBatching(backend))
404 }
405 "repair" => Ok(PredictionProvider::Repair),
406 _ => {
407 anyhow::bail!(
408 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
409 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
410 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
411 Available zeta versions:\n{}",
412 ZetaFormat::options_as_string()
413 )
414 }
415 }
416 }
417}
418
// Serialize providers via their Display string so JSONL stores the same
// `name:arg` form users pass on the command line.
impl Serialize for PredictionProvider {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}
427
// Deserialize providers by delegating to the FromStr impl, mirroring the
// Display-based Serialize impl above.
impl<'de> Deserialize<'de> for PredictionProvider {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
437
// Arguments for `ep synthesize`; the output directory comes from the global
// -o flag rather than a dedicated option (see the Synthesize arm in main).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
456
// Arguments for `ep import-batch`, which recovers batch results directly from
// the provider's API into the local LLM cache database.
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
466
// Provider selector for `ep import-batch`. Parsed via clap's ValueEnum, so
// CLI values are the lowercase variant names: "anthropic", "openai".
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
472
473impl EpArgs {
474 fn output_path(&self) -> Option<PathBuf> {
475 if self.in_place {
476 if self.inputs.len() == 1 {
477 self.inputs.first().cloned()
478 } else {
479 panic!("--in-place requires exactly one input file")
480 }
481 } else {
482 self.output.clone()
483 }
484 }
485}
486
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224, // i.e. Zed 0.224.1 or newer
    patch: 1,
};
494
/// Deduplicates `examples` in two passes: an exact filter on cursor-position
/// text, then a MinHash/LSH near-duplicate filter that keeps at most
/// `max_per_cluster` diverse representatives per similarity cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor_position is byte-identical to an
    // earlier example's.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate clustering over MinHash signatures of token
    // n-grams drawn from each example's cursor_position.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // The effective hash count is num_bands * band_width, which may differ
    // slightly from the requested NUM_HASHES.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    // Signatures are keyed by example index so LSH candidates map back to examples.
    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Within each cluster, keep a maximally-diverse subset of at most
    // max_per_cluster members.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back so earlier indices stay valid under swap_remove,
    // then reverse to restore the original ascending order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
575
576fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
577 if members.len() <= k {
578 return members.to_vec();
579 }
580
581 let mut selected = vec![members[0]];
582 let mut min_dist: HashMap<usize, f64> = HashMap::default();
583 for &member in &members[1..] {
584 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
585 min_dist.insert(member, dist);
586 }
587
588 while selected.len() < k {
589 let &best = min_dist
590 .iter()
591 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
592 .map(|(id, _)| id)
593 .expect("min_dist should not be empty when selected.len() < k");
594 selected.push(best);
595 min_dist.remove(&best);
596
597 let best_sig = &signatures[best];
598 for (member, current_min) in min_dist.iter_mut() {
599 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
600 if dist < *current_min {
601 *current_min = dist;
602 }
603 }
604 }
605
606 selected
607}
608
609fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
610 let tokens: Vec<&str> = word_diff::tokenize(code)
611 .into_iter()
612 .filter(|t| !t.trim().is_empty())
613 .collect();
614
615 if tokens.len() < ngram_size {
616 return vec![tokens.join("\0")];
617 }
618
619 tokens
620 .windows(ngram_size)
621 .map(|window| window.join("\0"))
622 .collect()
623}
624
/// Assembles the example set for a run.
///
/// Reads file inputs, fetches any Snowflake-backed inputs
/// (`captured-after:` / `rejected-after:` / `requested-after:` / `rated-after:`),
/// then applies the global CLI filters in order: offset, name/repo substring
/// filters, resume against an existing output file, near-duplicate reduction,
/// and finally the limit.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into Snowflake specifiers and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // Cap Snowflake fetches so file examples plus fetched rows respect --limit.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Default per-timestamp row cap when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            // Last fetch: http_client and background_executor moved, not cloned.
            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // --name / --repo are substring filters against the example spec.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Refresh progress totals now that the final example set is known.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
781
782fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
783 let mut hasher = collections::FxHasher::default();
784 spec.hash(&mut hasher);
785 hasher.finish()
786}
787
/// Resumes a run by reconciling a previous output file with the current inputs:
/// keeps output lines whose spec hash matches an input example and which look
/// complete for `command`, rewrites the file to contain exactly those lines,
/// and removes the already-processed examples from `examples`.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No existing output file means there is nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable lines are skipped rather than aborting the resume.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // Keep only lines matching a current input, first occurrence wins.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Command-specific completeness check: a line only counts as
                // processed when the command's expected output is present.
                let is_complete = match command {
                    // QA output must carry a confidence value.
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    // Repair output must include a repair prediction with a
                    // parsed patch.
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    // All other commands: any parseable matching line counts.
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file so stale/incomplete lines are discarded.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Only examples without a kept output line remain to be processed.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
850
851fn main() {
852 let args = EpArgs::parse();
853
854 if args.printenv {
855 ::util::shell_env::print_env();
856 return;
857 }
858
859 let output = args.output_path();
860
861 if args.markdown && output.is_none() {
862 eprintln!("--markdown requires -o to specify the output directory");
863 std::process::exit(1);
864 }
865
866 let command = match &args.command {
867 Some(cmd) => cmd.clone(),
868 None => {
869 EpArgs::command().print_help().unwrap();
870 return;
871 }
872 };
873
874 match &command {
875 Command::ImportBatch(import_args) => {
876 smol::block_on(async {
877 match import_args.provider {
878 BatchProvider::Anthropic => {
879 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
880 .expect("Failed to create Anthropic client");
881 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
882 eprintln!("Error importing Anthropic batches: {:?}", e);
883 std::process::exit(1);
884 }
885 }
886 BatchProvider::Openai => {
887 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
888 .expect("Failed to create OpenAI client");
889 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
890 eprintln!("Error importing OpenAI batches: {:?}", e);
891 std::process::exit(1);
892 }
893 }
894 }
895 println!(
896 "Successfully imported {} batch(es)",
897 import_args.batch_ids.len()
898 );
899 });
900 return;
901 }
902 Command::Clean => {
903 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
904 return;
905 }
906 Command::PrintZetaFormats => {
907 use strum::IntoEnumIterator as _;
908 for format in ZetaFormat::iter() {
909 println!("{}", format.to_string().to_lowercase());
910 }
911 return;
912 }
913
914 Command::Synthesize(synth_args) => {
915 let Some(output_dir) = args.output else {
916 panic!("output dir is required");
917 };
918 let config = SynthesizeConfig {
919 repo_urls: synth_args.repos.clone(),
920 count: synth_args.count,
921 max_commits: synth_args.max_commits,
922 output_dir,
923 fresh: synth_args.fresh,
924 };
925 smol::block_on(async {
926 if let Err(e) = run_synthesize(config).await {
927 eprintln!("Error: {:?}", e);
928 std::process::exit(1);
929 }
930 });
931 return;
932 }
933 Command::SplitCommit(split_commit_args) => {
934 if let Err(error) = split_commit::run_split_commit(
935 split_commit_args,
936 &args.inputs,
937 output.as_ref(),
938 args.failed,
939 ) {
940 eprintln!("{error:#}");
941 std::process::exit(1);
942 }
943 return;
944 }
945 Command::TruncatePatch(truncate_args) => {
946 if let Err(error) =
947 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
948 {
949 eprintln!("{error:#}");
950 std::process::exit(1);
951 }
952 return;
953 }
954 Command::Split(split_args) => {
955 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
956 eprintln!("{error:#}");
957 std::process::exit(1);
958 }
959 return;
960 }
961 Command::FilterLanguages(filter_args) => {
962 if let Err(error) =
963 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
964 {
965 eprintln!("{error:#}");
966 std::process::exit(1);
967 }
968 return;
969 }
970
971 _ => {}
972 }
973
974 let http_client = Arc::new(ReqwestClient::new());
975 let app = gpui_platform::headless().with_http_client(http_client);
976
977 app.run(move |cx| {
978 let app_state = Arc::new(headless::init(cx));
979 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
980
981 cx.spawn(async move |cx| {
982 let result = async {
983 let examples = load_examples(
984 app_state.client.http_client(),
985 &args,
986 output.as_ref(),
987 cx.background_executor().clone(),
988 )
989 .await?;
990
991 match &command {
992 Command::Predict(args) | Command::Score(args) => {
993 predict::sync_batches(args.provider.as_ref()).await?;
994 }
995 Command::Eval(args) => {
996 predict::sync_batches(args.predict.provider.as_ref()).await?;
997 }
998 Command::Qa(args) => {
999 qa::sync_batches(args).await?;
1000 }
1001 Command::Repair(args) => {
1002 repair::sync_batches(args).await?;
1003 }
1004 _ => (),
1005 }
1006
1007 let failfast_on_single_example = examples.len() == 1;
1008
1009 // For --markdown mode, create the output directory if it doesn't exist
1010 if args.markdown {
1011 let dir = output.as_ref().expect("--markdown requires -o");
1012 if !dir.exists() {
1013 std::fs::create_dir_all(dir)
1014 .expect("Failed to create markdown output directory");
1015 }
1016 }
1017
1018 // Set up JSONL output writer (not used in markdown mode)
1019 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1020 let mut in_place_temp_path: Option<PathBuf> = None;
1021 if !args.markdown
1022 && let Some(output_path) = output.as_ref()
1023 {
1024 let write_path = if args.in_place {
1025 let temp = output_path.with_extension("jsonl.tmp");
1026 in_place_temp_path = Some(temp.clone());
1027 temp
1028 } else {
1029 output_path.clone()
1030 };
1031
1032 let file = OpenOptions::new()
1033 .create(true)
1034 .write(true)
1035 .truncate(args.in_place)
1036 .append(!args.in_place)
1037 .open(&write_path)
1038 .expect("Failed to open output file");
1039
1040 let mut writer = BufWriter::new(file);
1041 let (sender, mut receiver) = mpsc::unbounded::<String>();
1042 cx.background_spawn(async move {
1043 while let Some(line) = receiver.next().await {
1044 writeln!(writer, "{}", line).expect("Failed to write example");
1045 writer.flush().expect("Failed to flush output");
1046 }
1047 })
1048 .detach();
1049 output_sender = Some(sender);
1050 }
1051
1052 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1053 let finished_examples = Mutex::new(Vec::new());
1054
1055 let mut tasks = Vec::new();
1056 for _ in 0..args.max_parallelism {
1057 tasks.push(async {
1058 loop {
1059 let Some(mut repo_examples) =
1060 grouped_examples.lock().unwrap().pop_front()
1061 else {
1062 break;
1063 };
1064 for example in &mut repo_examples {
1065 let example_progress =
1066 Progress::global().start_group(&example.spec.name);
1067
1068 let result = async {
1069 match &command {
1070 Command::Read(_) => {}
1071 Command::LoadProject => {
1072 run_load_project(
1073 example,
1074 app_state.clone(),
1075 &example_progress,
1076 cx.clone(),
1077 )
1078 .await?;
1079 }
1080 Command::Context => {
1081 run_context_retrieval(
1082 example,
1083 app_state.clone(),
1084 &example_progress,
1085 cx.clone(),
1086 )
1087 .await?;
1088 }
1089 Command::FormatPrompt(args) => {
1090 run_format_prompt(
1091 example,
1092 args,
1093 app_state.clone(),
1094 &example_progress,
1095 cx.clone(),
1096 )
1097 .await?;
1098 }
1099 Command::Predict(args) => {
1100 run_prediction(
1101 example,
1102 args,
1103 app_state.clone(),
1104 &example_progress,
1105 cx.clone(),
1106 )
1107 .await?;
1108 }
1109 Command::ParseOutput => {
1110 parse_output::run_parse_output(example)?;
1111 }
1112 Command::Distill => {
1113 run_distill(example).await?;
1114 }
1115 Command::Score(args) => {
1116 run_scoring(
1117 example,
1118 args,
1119 app_state.clone(),
1120 &example_progress,
1121 cx.clone(),
1122 )
1123 .await?;
1124 }
1125 Command::Eval(args) => {
1126 run_scoring(
1127 example,
1128 &args.predict,
1129 app_state.clone(),
1130 &example_progress,
1131 cx.clone(),
1132 )
1133 .await?;
1134 }
1135 Command::Qa(args) => {
1136 qa::run_qa(example, args, &example_progress).await?;
1137 }
1138 Command::Repair(args) => {
1139 repair::run_repair(example, args, &example_progress)
1140 .await?;
1141 }
1142 Command::Clean
1143 | Command::Synthesize(_)
1144 | Command::SplitCommit(_)
1145 | Command::Split(_)
1146 | Command::TruncatePatch(_)
1147 | Command::FilterLanguages(_)
1148 | Command::ImportBatch(_)
1149 | Command::PrintZetaFormats => {
1150 unreachable!()
1151 }
1152 }
1153 anyhow::Ok(())
1154 }
1155 .await;
1156
1157 let failed = if let Err(error) = result {
1158 handle_error(
1159 error,
1160 &args,
1161 &command,
1162 &app_state,
1163 failfast_on_single_example,
1164 &example,
1165 )
1166 .await;
1167 true
1168 } else {
1169 false
1170 };
1171
1172 let should_write = !failed || args.failed == FailedHandling::Keep;
1173 if should_write {
1174 if args.markdown {
1175 let markdown_dir =
1176 output.as_ref().expect("--markdown requires -o");
1177 let filename = format!("{}.md", example.spec.filename());
1178 let path = markdown_dir.join(&filename);
1179 let markdown = example.spec.to_markdown();
1180 std::fs::write(&path, &markdown)
1181 .expect("Failed to write markdown file");
1182 } else if let Some(ref mut sender) = output_sender.clone() {
1183 let line = serde_json::to_string(&example).unwrap();
1184 sender
1185 .send(line)
1186 .await
1187 .expect("Failed to send to output writer");
1188 } else if args.output.is_none()
1189 && !matches!(command, Command::Eval(_))
1190 {
1191 let line = serde_json::to_string(&example).unwrap();
1192 println!("{}", line);
1193 }
1194 }
1195 }
1196
1197 let project = repo_examples
1198 .iter()
1199 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1200
1201 if let Some(project) = project {
1202 let mut cx = cx.clone();
1203
1204 let shutdown_task: Task<()> =
1205 project.update(&mut cx, |project, cx| {
1206 let lsp_store = project.lsp_store();
1207 lsp_store.update(cx, |lsp_store, cx| {
1208 lsp_store.shutdown_all_language_servers(cx)
1209 })
1210 });
1211
1212 shutdown_task.await;
1213
1214 if let Some(ep_store) =
1215 cx.update(|cx| EditPredictionStore::try_global(cx))
1216 {
1217 ep_store.update(&mut cx, |store, _| {
1218 store.remove_project(&project);
1219 });
1220 }
1221 }
1222
1223 for example in &mut repo_examples {
1224 example.state.take();
1225 }
1226 finished_examples
1227 .lock()
1228 .unwrap()
1229 .extend_from_slice(&repo_examples);
1230 }
1231 });
1232 }
1233 futures::future::join_all(tasks).await;
1234
1235 Progress::global().finalize();
1236
1237 match &command {
1238 Command::Predict(args) | Command::Score(args) => {
1239 predict::sync_batches(args.provider.as_ref()).await?;
1240 }
1241 Command::Eval(args) => {
1242 predict::sync_batches(args.predict.provider.as_ref()).await?;
1243 }
1244 Command::Qa(args) => {
1245 qa::sync_batches(args).await?;
1246 }
1247 Command::Repair(args) => {
1248 repair::sync_batches(args).await?;
1249 }
1250 _ => (),
1251 }
1252
1253 match &command {
1254 Command::Eval(args) => {
1255 let examples = finished_examples.lock().unwrap();
1256 score::print_report(&examples);
1257 if let Some(summary_path) = &args.summary_json {
1258 score::write_summary_json(&examples, summary_path)?;
1259 }
1260 }
1261 Command::Repair(args) => {
1262 let examples = finished_examples.lock().unwrap();
1263 repair::print_report(&examples, args.confidence_threshold);
1264 }
1265 _ => (),
1266 };
1267
1268 // For --in-place, atomically rename temp file to original
1269 if let Some(temp_path) = &in_place_temp_path {
1270 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1271 std::fs::rename(temp_path, final_path)
1272 .expect("Failed to rename temp file to final output");
1273 }
1274
1275 anyhow::Ok(())
1276 }
1277 .await;
1278
1279 if let Err(e) = result {
1280 panic!("Fatal error: {:?}", e);
1281 }
1282
1283 let _ = cx.update(|cx| cx.quit());
1284 })
1285 .detach();
1286 });
1287}
1288
1289async fn handle_error(
1290 error: anyhow::Error,
1291 args: &EpArgs,
1292 command: &Command,
1293 app_state: &Arc<headless::EpAppState>,
1294 failfast_on_single_example: bool,
1295 example: &Example,
1296) {
1297 Progress::global().increment_failed();
1298
1299 let msg;
1300 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1301 let example_name = example.spec.filename();
1302
1303 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1304 app_state
1305 .fs
1306 .write(
1307 &failed_example_path,
1308 &serde_json::to_vec_pretty(&example).unwrap(),
1309 )
1310 .await
1311 .unwrap();
1312 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1313 app_state
1314 .fs
1315 .write(&err_path, format!("{error:?}").as_bytes())
1316 .await
1317 .unwrap();
1318
1319 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1320 let mut file = OpenOptions::new()
1321 .create(true)
1322 .append(true)
1323 .open(&failed_jsonl_path)
1324 .expect("Failed to open failed.jsonl");
1325 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1326 .expect("Failed to write to failed.jsonl");
1327
1328 let cursor_path = match example.repo_name() {
1329 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1330 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1331 };
1332 msg = format!(
1333 indoc::indoc! {"
1334 While processing \"{}\":
1335
1336 \x1b[31m{:?}\x1b[0m
1337
1338 Example: \x1b[36m{}\x1b[0m
1339 Error file: \x1b[36m{}\x1b[0m
1340 Cursor file: \x1b[36m{}\x1b[0m
1341 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1342 "},
1343 example.spec.name,
1344 error,
1345 failed_example_path.display(),
1346 err_path.display(),
1347 cursor_path.display(),
1348 command,
1349 failed_example_path.display(),
1350 );
1351 } else {
1352 msg = format!(
1353 indoc::indoc! {"
1354 While processing \"{}\":
1355
1356 \x1b[31m{:?}\x1b[0m
1357 "},
1358 example.spec.name, error
1359 );
1360 }
1361
1362 if args.failfast || failfast_on_single_example {
1363 Progress::global().finalize();
1364 panic!("{}", msg);
1365 } else {
1366 log::error!("{}", msg);
1367 }
1368}