1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
48
49use crate::distill::run_distill;
50use crate::example::{Example, group_examples_by_repo, read_example_files};
51use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
52use crate::format_prompt::run_format_prompt;
53use crate::load_project::run_load_project;
54use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
55use crate::predict::run_prediction;
56use crate::progress::Progress;
57use crate::retrieve_context::run_context_retrieval;
58use crate::score::run_scoring;
59use crate::split_commit::SplitCommitArgs;
60use crate::split_dataset::SplitArgs;
61use crate::synthesize::{SynthesizeConfig, run_synthesize};
62use crate::truncate_expected_patch::TruncatePatchArgs;
63
// Top-level CLI arguments for the `ep` binary. Most flags are `global` so
// they can appear before or after any subcommand.
// NOTE: `///` doc comments on clap-derive items become `--help` text, so
// review notes below use `//` to avoid changing CLI output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks processing repo groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Skip this many examples; applied to file inputs first, with any
    // remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output path; interacts with --in-place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Overwrite the single input file with the command's output.
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop on the first failed example instead of continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
106
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    // NOTE(review): presumably suppresses the per-example failure files as
    // well as main-output inclusion — confirm against the handling code.
    SkipNoFiles,
}
119
// Extended help shown via `after_help` on ReadArgs: documents the special
// input specifiers (captured-after:, rejected-after:, rated-after:, ...)
// and the Snowflake environment variables they require.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
177
// Subcommands of `ep`. The `///` comments double as clap help text, so
// they are kept user-facing; pipeline commands generally read the global
// `inputs` and write via `-o`/`--in-place`.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
221
222impl Display for Command {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Command::Read(_) => write!(f, "read"),
226 Command::LoadProject => write!(f, "load-project"),
227 Command::Context => write!(f, "context"),
228 Command::FormatPrompt(args) => {
229 write!(f, "format-prompt --provider={}", args.provider)
230 }
231 Command::Predict(args) => match &args.provider {
232 Some(provider) => write!(f, "predict --provider={}", provider),
233 None => write!(f, "predict"),
234 },
235 Command::ParseOutput => write!(f, "parse-output"),
236 Command::Score(args) => match &args.provider {
237 Some(provider) => write!(f, "score --provider={}", provider),
238 None => write!(f, "score"),
239 },
240 Command::Distill => write!(f, "distill"),
241 Command::Eval(args) => match &args.predict.provider {
242 Some(provider) => write!(f, "eval --provider={}", provider),
243 None => write!(f, "eval"),
244 },
245 Command::Synthesize(args) => {
246 write!(f, "synthesize --repos {}", args.repos.join(" "))
247 }
248 Command::Clean => write!(f, "clean"),
249 Command::SplitCommit(_) => write!(f, "split-commit"),
250 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
251 Command::Split(_) => write!(f, "split"),
252 Command::FilterLanguages(_) => write!(f, "filter-languages"),
253 Command::ImportBatch(args) => {
254 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
255 }
256 Command::Qa(_) => {
257 write!(f, "qa")
258 }
259 Command::Repair(_) => {
260 write!(f, "repair")
261 }
262 Command::PrintZetaFormats => {
263 write!(f, "print-zeta-formats")
264 }
265 }
266 }
267}
268
// Arguments for `ep read`; currently empty — inputs come from the global
// positional `inputs` on EpArgs. The after_help documents the specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
272
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to emit; defaults to zeta2 with the
    // default ZetaFormat (see PredictionProvider::default).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
278
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider; when omitted, the effective provider is decided
    // downstream — TODO(review): confirm the default in predict.rs.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Presumably the number of prediction attempts per example — verify
    // against run_prediction.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
289
// Arguments for `ep eval`: the full prediction settings plus an optional
// JSON summary destination.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
}
298
/// Teacher model used by the `teacher`/`teacher-non-batching` providers.
/// Parsed from short names like `sonnet45` (see `FromStr`) and mapped to
/// concrete API model identifiers via [`TeacherBackend::model_name`].
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
306
307impl std::fmt::Display for TeacherBackend {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
311 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
312 TeacherBackend::Gpt52 => write!(f, "gpt52"),
313 }
314 }
315}
316
317impl std::str::FromStr for TeacherBackend {
318 type Err = anyhow::Error;
319
320 fn from_str(s: &str) -> Result<Self, Self::Err> {
321 match s.to_lowercase().as_str() {
322 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
323 "sonnet46" => Ok(TeacherBackend::Sonnet46),
324 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
325 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
326 _ => anyhow::bail!(
327 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
328 ),
329 }
330 }
331}
332
333impl TeacherBackend {
334 pub fn model_name(&self) -> &'static str {
335 match self {
336 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
337 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
338 TeacherBackend::Gpt52 => "gpt-5.2",
339 }
340 }
341}
342
/// Backend used to generate edit predictions. Round-trips through the
/// `name[:arg]` string form (see the `Display`/`FromStr` impls below).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Teacher model; presumably routed through the batch client — the
    /// non-batching variant below bypasses it (confirm in predict.rs).
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
353
354impl Default for PredictionProvider {
355 fn default() -> Self {
356 PredictionProvider::Zeta2(ZetaFormat::default())
357 }
358}
359
360impl std::fmt::Display for PredictionProvider {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 match self {
363 PredictionProvider::Sweep => write!(f, "sweep"),
364 PredictionProvider::Mercury => write!(f, "mercury"),
365 PredictionProvider::Zeta1 => write!(f, "zeta1"),
366 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
367 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
368 PredictionProvider::TeacherNonBatching(backend) => {
369 write!(f, "teacher-non-batching:{backend}")
370 }
371 PredictionProvider::Repair => write!(f, "repair"),
372 }
373 }
374}
375
376impl std::str::FromStr for PredictionProvider {
377 type Err = anyhow::Error;
378
379 fn from_str(s: &str) -> Result<Self, Self::Err> {
380 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
381
382 let provider_lower = provider.to_lowercase();
383 match provider_lower.as_str() {
384 "sweep" => Ok(PredictionProvider::Sweep),
385 "mercury" => Ok(PredictionProvider::Mercury),
386 "zeta1" => Ok(PredictionProvider::Zeta1),
387 "zeta2" => {
388 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
389 Ok(PredictionProvider::Zeta2(format))
390 }
391 "teacher" => {
392 let backend = arg
393 .map(|a| a.parse())
394 .transpose()?
395 .unwrap_or(TeacherBackend::default());
396 Ok(PredictionProvider::Teacher(backend))
397 }
398 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
399 let backend = arg
400 .map(|a| a.parse())
401 .transpose()?
402 .unwrap_or(TeacherBackend::default());
403 Ok(PredictionProvider::TeacherNonBatching(backend))
404 }
405 "repair" => Ok(PredictionProvider::Repair),
406 _ => {
407 anyhow::bail!(
408 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
409 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
410 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
411 Available zeta versions:\n{}",
412 ZetaFormat::options_as_string()
413 )
414 }
415 }
416 }
417}
418
419impl Serialize for PredictionProvider {
420 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
421 where
422 S: Serializer,
423 {
424 serializer.serialize_str(&self.to_string())
425 }
426}
427
428impl<'de> Deserialize<'de> for PredictionProvider {
429 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
430 where
431 D: Deserializer<'de>,
432 {
433 let s = String::deserialize(deserializer)?;
434 s.parse().map_err(serde::de::Error::custom)
435 }
436}
437
// Arguments for `ep synthesize` (consumed by run_synthesize via
// SynthesizeConfig in main); the `///` docs are clap help text.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
456
// Arguments for `ep import-batch`, which re-imports completed LLM batch
// results into the local cache (e.g. after database loss).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
466
// Which LLM batch API to import results from (see ImportBatchArgs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
472
473impl EpArgs {
474 fn output_path(&self) -> Option<PathBuf> {
475 if self.in_place {
476 if self.inputs.len() == 1 {
477 self.inputs.first().cloned()
478 } else {
479 panic!("--in-place requires exactly one input file")
480 }
481 } else {
482 self.output.clone()
483 }
484 }
485}
486
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    // i.e. 0.224.1 — the major component is presumably implied by
    // MinCaptureVersion; confirm in pull_examples.
    minor: 224,
    patch: 1,
};
494
/// Removes duplicate examples in place: first exact duplicates (identical
/// cursor position), then near-duplicates found via MinHash/LSH
/// clustering, keeping at most `max_per_cluster` diverse examples per
/// cluster. Relative order of the survivors is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor position is byte-identical.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate clustering. Cursor-position text is shingled
    // into token n-grams, MinHash-signed, and indexed with LSH banding
    // tuned for this Jaccard threshold.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // The banding parameters may round the hash count; use the product.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find with path halving: point each visited node at its grandparent.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep a maximally diverse subset.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from highest index to lowest: swap_remove only disturbs the
    // removed position and the tail, so lower kept indices stay valid.
    // The final reverse restores ascending (original) order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
575
576fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
577 if members.len() <= k {
578 return members.to_vec();
579 }
580
581 let mut selected = vec![members[0]];
582 let mut min_dist: HashMap<usize, f64> = HashMap::default();
583 for &member in &members[1..] {
584 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
585 min_dist.insert(member, dist);
586 }
587
588 while selected.len() < k {
589 let &best = min_dist
590 .iter()
591 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
592 .map(|(id, _)| id)
593 .expect("min_dist should not be empty when selected.len() < k");
594 selected.push(best);
595 min_dist.remove(&best);
596
597 let best_sig = &signatures[best];
598 for (member, current_min) in min_dist.iter_mut() {
599 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
600 if dist < *current_min {
601 *current_min = dist;
602 }
603 }
604 }
605
606 selected
607}
608
609fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
610 let tokens: Vec<&str> = word_diff::tokenize(code)
611 .into_iter()
612 .filter(|t| !t.trim().is_empty())
613 .collect();
614
615 if tokens.len() < ngram_size {
616 return vec![tokens.join("\0")];
617 }
618
619 tokens
620 .windows(ngram_size)
621 .map(|window| window.join("\0"))
622 .collect()
623}
624
/// Collects examples from file inputs and Snowflake specifiers, then
/// applies the global CLI filters in order: --offset, repo/name filters,
/// resume-from-output, --max-duplicates deduplication, and --limit.
/// Totals are reported to the global progress UI.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition raw inputs into Snowflake specifiers and plain file paths.
    // NOTE(review): captured_after_timestamps is populated below but never
    // fetched anywhere in this function — `captured-after:` inputs appear
    // to be silently dropped. Confirm whether a fetch call is missing.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is a file path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // How much of --limit is left for Snowflake once files are counted.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Without --limit, each timestamp fetch is capped at 5000 rows
        // (matching the default documented on EpArgs::limit).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Group by repo/rev so downstream processing can batch per repository.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that all filters have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
766
/// Hashes an example spec with FxHasher; used to match input examples
/// against already-written output lines when resuming a run.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
772
773fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
774 let file = match File::open(path) {
775 Ok(f) => f,
776 Err(_) => return,
777 };
778
779 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
780
781 let reader = BufReader::new(file);
782 let mut kept_lines = Vec::new();
783 let mut kept_hashes = HashSet::default();
784
785 for line in reader.lines() {
786 let line = match line {
787 Ok(l) => l,
788 Err(_) => continue,
789 };
790
791 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
792 let hash = spec_hash(&output_example.spec);
793 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
794 let is_complete = match command {
795 Command::Qa(_) => output_example
796 .qa
797 .first()
798 .and_then(|q| q.as_ref())
799 .and_then(|q| q.confidence)
800 .is_some(),
801 Command::Repair(_) => output_example.predictions.iter().any(|p| {
802 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
803 }),
804 _ => true,
805 };
806 if is_complete {
807 kept_hashes.insert(hash);
808 kept_lines.push(line);
809 }
810 }
811 }
812 }
813
814 let total = examples.len();
815 let already_processed = kept_hashes.len();
816
817 eprintln!(
818 "Resuming: {}/{} examples already processed",
819 already_processed, total
820 );
821
822 let file = OpenOptions::new()
823 .write(true)
824 .truncate(true)
825 .open(path)
826 .expect("Failed to open output file for rewriting");
827 let mut writer = BufWriter::new(file);
828 for line in &kept_lines {
829 writeln!(writer, "{}", line).expect("Failed to write to output file");
830 }
831 writer.flush().expect("Failed to flush output file");
832
833 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
834}
835
836fn main() {
837 let args = EpArgs::parse();
838
839 if args.printenv {
840 ::util::shell_env::print_env();
841 return;
842 }
843
844 let output = args.output_path();
845
846 if args.markdown && output.is_none() {
847 eprintln!("--markdown requires -o to specify the output directory");
848 std::process::exit(1);
849 }
850
851 let command = match &args.command {
852 Some(cmd) => cmd.clone(),
853 None => {
854 EpArgs::command().print_help().unwrap();
855 return;
856 }
857 };
858
859 match &command {
860 Command::ImportBatch(import_args) => {
861 smol::block_on(async {
862 match import_args.provider {
863 BatchProvider::Anthropic => {
864 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
865 .expect("Failed to create Anthropic client");
866 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
867 eprintln!("Error importing Anthropic batches: {:?}", e);
868 std::process::exit(1);
869 }
870 }
871 BatchProvider::Openai => {
872 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
873 .expect("Failed to create OpenAI client");
874 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
875 eprintln!("Error importing OpenAI batches: {:?}", e);
876 std::process::exit(1);
877 }
878 }
879 }
880 println!(
881 "Successfully imported {} batch(es)",
882 import_args.batch_ids.len()
883 );
884 });
885 return;
886 }
887 Command::Clean => {
888 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
889 return;
890 }
891 Command::PrintZetaFormats => {
892 use strum::IntoEnumIterator as _;
893 for format in ZetaFormat::iter() {
894 println!("{}", format.to_string().to_lowercase());
895 }
896 return;
897 }
898
899 Command::Synthesize(synth_args) => {
900 let Some(output_dir) = args.output else {
901 panic!("output dir is required");
902 };
903 let config = SynthesizeConfig {
904 repo_urls: synth_args.repos.clone(),
905 count: synth_args.count,
906 max_commits: synth_args.max_commits,
907 output_dir,
908 fresh: synth_args.fresh,
909 };
910 smol::block_on(async {
911 if let Err(e) = run_synthesize(config).await {
912 eprintln!("Error: {:?}", e);
913 std::process::exit(1);
914 }
915 });
916 return;
917 }
918 Command::SplitCommit(split_commit_args) => {
919 if let Err(error) = split_commit::run_split_commit(
920 split_commit_args,
921 &args.inputs,
922 output.as_ref(),
923 args.failed,
924 ) {
925 eprintln!("{error:#}");
926 std::process::exit(1);
927 }
928 return;
929 }
930 Command::TruncatePatch(truncate_args) => {
931 if let Err(error) =
932 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
933 {
934 eprintln!("{error:#}");
935 std::process::exit(1);
936 }
937 return;
938 }
939 Command::Split(split_args) => {
940 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
941 eprintln!("{error:#}");
942 std::process::exit(1);
943 }
944 return;
945 }
946 Command::FilterLanguages(filter_args) => {
947 if let Err(error) =
948 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
949 {
950 eprintln!("{error:#}");
951 std::process::exit(1);
952 }
953 return;
954 }
955
956 _ => {}
957 }
958
959 let http_client = Arc::new(ReqwestClient::new());
960 let app = gpui_platform::headless().with_http_client(http_client);
961
962 app.run(move |cx| {
963 let app_state = Arc::new(headless::init(cx));
964 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
965
966 cx.spawn(async move |cx| {
967 let result = async {
968 let examples = load_examples(
969 app_state.client.http_client(),
970 &args,
971 output.as_ref(),
972 cx.background_executor().clone(),
973 )
974 .await?;
975
976 match &command {
977 Command::Predict(args) | Command::Score(args) => {
978 predict::sync_batches(args.provider.as_ref()).await?;
979 }
980 Command::Eval(args) => {
981 predict::sync_batches(args.predict.provider.as_ref()).await?;
982 }
983 Command::Qa(args) => {
984 qa::sync_batches(args).await?;
985 }
986 Command::Repair(args) => {
987 repair::sync_batches(args).await?;
988 }
989 _ => (),
990 }
991
992 let failfast_on_single_example = examples.len() == 1;
993
994 // For --markdown mode, create the output directory if it doesn't exist
995 if args.markdown {
996 let dir = output.as_ref().expect("--markdown requires -o");
997 if !dir.exists() {
998 std::fs::create_dir_all(dir)
999 .expect("Failed to create markdown output directory");
1000 }
1001 }
1002
1003 // Set up JSONL output writer (not used in markdown mode)
1004 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1005 let mut in_place_temp_path: Option<PathBuf> = None;
1006 if !args.markdown
1007 && let Some(output_path) = output.as_ref()
1008 {
1009 let write_path = if args.in_place {
1010 let temp = output_path.with_extension("jsonl.tmp");
1011 in_place_temp_path = Some(temp.clone());
1012 temp
1013 } else {
1014 output_path.clone()
1015 };
1016
1017 let file = OpenOptions::new()
1018 .create(true)
1019 .write(true)
1020 .truncate(args.in_place)
1021 .append(!args.in_place)
1022 .open(&write_path)
1023 .expect("Failed to open output file");
1024
1025 let mut writer = BufWriter::new(file);
1026 let (sender, mut receiver) = mpsc::unbounded::<String>();
1027 cx.background_spawn(async move {
1028 while let Some(line) = receiver.next().await {
1029 writeln!(writer, "{}", line).expect("Failed to write example");
1030 writer.flush().expect("Failed to flush output");
1031 }
1032 })
1033 .detach();
1034 output_sender = Some(sender);
1035 }
1036
1037 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1038 let finished_examples = Mutex::new(Vec::new());
1039
1040 let mut tasks = Vec::new();
1041 for _ in 0..args.max_parallelism {
1042 tasks.push(async {
1043 loop {
1044 let Some(mut repo_examples) =
1045 grouped_examples.lock().unwrap().pop_front()
1046 else {
1047 break;
1048 };
1049 for example in &mut repo_examples {
1050 let example_progress =
1051 Progress::global().start_group(&example.spec.name);
1052
1053 let result = async {
1054 match &command {
1055 Command::Read(_) => {}
1056 Command::LoadProject => {
1057 run_load_project(
1058 example,
1059 app_state.clone(),
1060 &example_progress,
1061 cx.clone(),
1062 )
1063 .await?;
1064 }
1065 Command::Context => {
1066 run_context_retrieval(
1067 example,
1068 app_state.clone(),
1069 &example_progress,
1070 cx.clone(),
1071 )
1072 .await?;
1073 }
1074 Command::FormatPrompt(args) => {
1075 run_format_prompt(
1076 example,
1077 args,
1078 app_state.clone(),
1079 &example_progress,
1080 cx.clone(),
1081 )
1082 .await?;
1083 }
1084 Command::Predict(args) => {
1085 run_prediction(
1086 example,
1087 args,
1088 app_state.clone(),
1089 &example_progress,
1090 cx.clone(),
1091 )
1092 .await?;
1093 }
1094 Command::ParseOutput => {
1095 parse_output::run_parse_output(example)?;
1096 }
1097 Command::Distill => {
1098 run_distill(example).await?;
1099 }
1100 Command::Score(args) => {
1101 run_scoring(
1102 example,
1103 args,
1104 app_state.clone(),
1105 &example_progress,
1106 cx.clone(),
1107 )
1108 .await?;
1109 }
1110 Command::Eval(args) => {
1111 run_scoring(
1112 example,
1113 &args.predict,
1114 app_state.clone(),
1115 &example_progress,
1116 cx.clone(),
1117 )
1118 .await?;
1119 }
1120 Command::Qa(args) => {
1121 qa::run_qa(example, args, &example_progress).await?;
1122 }
1123 Command::Repair(args) => {
1124 repair::run_repair(example, args, &example_progress)
1125 .await?;
1126 }
1127 Command::Clean
1128 | Command::Synthesize(_)
1129 | Command::SplitCommit(_)
1130 | Command::Split(_)
1131 | Command::TruncatePatch(_)
1132 | Command::FilterLanguages(_)
1133 | Command::ImportBatch(_)
1134 | Command::PrintZetaFormats => {
1135 unreachable!()
1136 }
1137 }
1138 anyhow::Ok(())
1139 }
1140 .await;
1141
1142 let failed = if let Err(error) = result {
1143 handle_error(
1144 error,
1145 &args,
1146 &command,
1147 &app_state,
1148 failfast_on_single_example,
1149 &example,
1150 )
1151 .await;
1152 true
1153 } else {
1154 false
1155 };
1156
1157 let should_write = !failed || args.failed == FailedHandling::Keep;
1158 if should_write {
1159 if args.markdown {
1160 let markdown_dir =
1161 output.as_ref().expect("--markdown requires -o");
1162 let filename = format!("{}.md", example.spec.filename());
1163 let path = markdown_dir.join(&filename);
1164 let markdown = example.spec.to_markdown();
1165 std::fs::write(&path, &markdown)
1166 .expect("Failed to write markdown file");
1167 } else if let Some(ref mut sender) = output_sender.clone() {
1168 let line = serde_json::to_string(&example).unwrap();
1169 sender
1170 .send(line)
1171 .await
1172 .expect("Failed to send to output writer");
1173 } else if args.output.is_none()
1174 && !matches!(command, Command::Eval(_))
1175 {
1176 let line = serde_json::to_string(&example).unwrap();
1177 println!("{}", line);
1178 }
1179 }
1180 }
1181
1182 let project = repo_examples
1183 .iter()
1184 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1185
1186 if let Some(project) = project {
1187 let mut cx = cx.clone();
1188
1189 let shutdown_task: Task<()> =
1190 project.update(&mut cx, |project, cx| {
1191 let lsp_store = project.lsp_store();
1192 lsp_store.update(cx, |lsp_store, cx| {
1193 lsp_store.shutdown_all_language_servers(cx)
1194 })
1195 });
1196
1197 shutdown_task.await;
1198
1199 if let Some(ep_store) =
1200 cx.update(|cx| EditPredictionStore::try_global(cx))
1201 {
1202 ep_store.update(&mut cx, |store, _| {
1203 store.remove_project(&project);
1204 });
1205 }
1206 }
1207
1208 for example in &mut repo_examples {
1209 example.state.take();
1210 }
1211 finished_examples
1212 .lock()
1213 .unwrap()
1214 .extend_from_slice(&repo_examples);
1215 }
1216 });
1217 }
1218 futures::future::join_all(tasks).await;
1219
1220 Progress::global().finalize();
1221
1222 match &command {
1223 Command::Predict(args) | Command::Score(args) => {
1224 predict::sync_batches(args.provider.as_ref()).await?;
1225 }
1226 Command::Eval(args) => {
1227 predict::sync_batches(args.predict.provider.as_ref()).await?;
1228 }
1229 Command::Qa(args) => {
1230 qa::sync_batches(args).await?;
1231 }
1232 Command::Repair(args) => {
1233 repair::sync_batches(args).await?;
1234 }
1235 _ => (),
1236 }
1237
1238 match &command {
1239 Command::Eval(args) => {
1240 let examples = finished_examples.lock().unwrap();
1241 score::print_report(&examples);
1242 if let Some(summary_path) = &args.summary_json {
1243 score::write_summary_json(&examples, summary_path)?;
1244 }
1245 }
1246 Command::Repair(args) => {
1247 let examples = finished_examples.lock().unwrap();
1248 repair::print_report(&examples, args.confidence_threshold);
1249 }
1250 _ => (),
1251 };
1252
1253 // For --in-place, atomically rename temp file to original
1254 if let Some(temp_path) = &in_place_temp_path {
1255 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1256 std::fs::rename(temp_path, final_path)
1257 .expect("Failed to rename temp file to final output");
1258 }
1259
1260 anyhow::Ok(())
1261 }
1262 .await;
1263
1264 if let Err(e) = result {
1265 panic!("Fatal error: {:?}", e);
1266 }
1267
1268 let _ = cx.update(|cx| cx.quit());
1269 })
1270 .detach();
1271 });
1272}
1273
1274async fn handle_error(
1275 error: anyhow::Error,
1276 args: &EpArgs,
1277 command: &Command,
1278 app_state: &Arc<headless::EpAppState>,
1279 failfast_on_single_example: bool,
1280 example: &Example,
1281) {
1282 Progress::global().increment_failed();
1283
1284 let msg;
1285 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1286 let example_name = example.spec.filename();
1287
1288 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1289 app_state
1290 .fs
1291 .write(
1292 &failed_example_path,
1293 &serde_json::to_vec_pretty(&example).unwrap(),
1294 )
1295 .await
1296 .unwrap();
1297 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1298 app_state
1299 .fs
1300 .write(&err_path, format!("{error:?}").as_bytes())
1301 .await
1302 .unwrap();
1303
1304 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1305 let mut file = OpenOptions::new()
1306 .create(true)
1307 .append(true)
1308 .open(&failed_jsonl_path)
1309 .expect("Failed to open failed.jsonl");
1310 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1311 .expect("Failed to write to failed.jsonl");
1312
1313 let cursor_path = match example.repo_name() {
1314 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1315 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1316 };
1317 msg = format!(
1318 indoc::indoc! {"
1319 While processing \"{}\":
1320
1321 \x1b[31m{:?}\x1b[0m
1322
1323 Example: \x1b[36m{}\x1b[0m
1324 Error file: \x1b[36m{}\x1b[0m
1325 Cursor file: \x1b[36m{}\x1b[0m
1326 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1327 "},
1328 example.spec.name,
1329 error,
1330 failed_example_path.display(),
1331 err_path.display(),
1332 cursor_path.display(),
1333 command,
1334 failed_example_path.display(),
1335 );
1336 } else {
1337 msg = format!(
1338 indoc::indoc! {"
1339 While processing \"{}\":
1340
1341 \x1b[31m{:?}\x1b[0m
1342 "},
1343 example.spec.name, error
1344 );
1345 }
1346
1347 if args.failfast || failfast_on_single_example {
1348 Progress::global().finalize();
1349 panic!("{}", msg);
1350 } else {
1351 log::error!("{}", msg);
1352 }
1353}