1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
42use std::env;
43use std::fmt::Display;
44use std::fs::{File, OpenOptions};
45use std::hash::{Hash, Hasher};
46use std::io::{BufRead, BufReader, BufWriter, Write};
47use std::sync::Mutex;
48use std::{path::PathBuf, sync::Arc};
49
50use crate::distill::run_distill;
51use crate::example::{Example, group_examples_by_repo, read_example_files};
52use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
53use crate::format_prompt::run_format_prompt;
54use crate::load_project::run_load_project;
55use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
56use crate::predict::run_prediction;
57use crate::progress::Progress;
58use crate::retrieve_context::run_context_retrieval;
59use crate::score::run_scoring;
60use crate::split_commit::SplitCommitArgs;
61use crate::split_dataset::SplitArgs;
62use crate::synthesize::{SynthesizeConfig, run_synthesize};
63use crate::truncate_expected_patch::TruncatePatchArgs;
64
// Top-level CLI arguments. Note: `///` doc comments on fields become clap help
// text, so reviewer annotations below deliberately use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment via util::shell_env and exit immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Number of worker tasks that process per-repo example groups concurrently.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; applied to file inputs first, the
    // remainder is forwarded to the Snowflake fetches (see `load_examples`).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output file, or output directory when --markdown is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Write results back into the single input file (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts on the first failed example — the
    // handling lives in the runner further down; confirm there.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
107
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// The `///` variant docs below double as clap help text for `--failed`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
120
// Long-form help shown via clap's `after_help` (see `ReadArgs`). Documents the
// special input specifiers that `load_examples` recognizes alongside plain
// file paths.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

 path
    Path to an example(s) file (.md, .json, or .jsonl)

 captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

 rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

 rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

 rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

 rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

 Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

 Optional:
    EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Read user-rated predictions
    ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

    # Read only positively rated predictions
    ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

    # Read only negatively rated predictions
    ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
178
// All `ep` subcommands. The `///` comments double as clap help text; the
// `Display` impl below renders each variant as its CLI spelling.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
222
223impl Display for Command {
224 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225 match self {
226 Command::Read(_) => write!(f, "read"),
227 Command::LoadProject => write!(f, "load-project"),
228 Command::Context => write!(f, "context"),
229 Command::FormatPrompt(args) => {
230 write!(f, "format-prompt --provider={}", args.provider)
231 }
232 Command::Predict(args) => match &args.provider {
233 Some(provider) => write!(f, "predict --provider={}", provider),
234 None => write!(f, "predict"),
235 },
236 Command::ParseOutput => write!(f, "parse-output"),
237 Command::Score(args) => match &args.provider {
238 Some(provider) => write!(f, "score --provider={}", provider),
239 None => write!(f, "score"),
240 },
241 Command::Distill => write!(f, "distill"),
242 Command::Eval(args) => match &args.predict.provider {
243 Some(provider) => write!(f, "eval --provider={}", provider),
244 None => write!(f, "eval"),
245 },
246 Command::Synthesize(args) => {
247 write!(f, "synthesize --repos {}", args.repos.join(" "))
248 }
249 Command::Clean => write!(f, "clean"),
250 Command::SplitCommit(_) => write!(f, "split-commit"),
251 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
252 Command::Split(_) => write!(f, "split"),
253 Command::FilterLanguages(_) => write!(f, "filter-languages"),
254 Command::ImportBatch(args) => {
255 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
256 }
257 Command::Qa(_) => {
258 write!(f, "qa")
259 }
260 Command::Repair(_) => {
261 write!(f, "repair")
262 }
263 Command::PrintZetaFormats => {
264 write!(f, "print-zeta-formats")
265 }
266 }
267 }
268}
269
// `read` has no flags of its own; its inputs come from the global positional
// arguments, and INPUTS_HELP documents the accepted specifiers.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
273
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to generate; defaults to zeta2 with the
    // default ZetaFormat (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
279
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; `None` is passed through to predict::sync_batches
    // and rendered as a bare command name by the Display impls.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // NOTE(review): presumably the number of prediction attempts per example —
    // confirm in the predict module.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
290
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // Prediction settings shared with the `predict`/`score` commands.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
302
/// Teacher models used for generating predictions.
/// CLI string forms are defined by the `Display`/`FromStr` impls below; the
/// upstream API identifier comes from `model_name`.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// claude-sonnet-4-6
    Sonnet46,
    /// claude-sonnet-4-5 (the default backend)
    #[default]
    Sonnet45,
    /// gpt-5.2
    Gpt52,
}
310
311impl std::fmt::Display for TeacherBackend {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
315 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
316 TeacherBackend::Gpt52 => write!(f, "gpt52"),
317 }
318 }
319}
320
321impl std::str::FromStr for TeacherBackend {
322 type Err = anyhow::Error;
323
324 fn from_str(s: &str) -> Result<Self, Self::Err> {
325 match s.to_lowercase().as_str() {
326 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
327 "sonnet46" => Ok(TeacherBackend::Sonnet46),
328 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
329 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
330 _ => anyhow::bail!(
331 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
332 ),
333 }
334 }
335}
336
337impl TeacherBackend {
338 pub fn model_name(&self) -> &'static str {
339 match self {
340 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
341 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
342 TeacherBackend::Gpt52 => "gpt-5.2",
343 }
344 }
345}
346
/// Edit-prediction providers selectable via `--provider`.
/// Serialized to and parsed from `name` or `name:arg` strings — see the
/// `Display`/`FromStr` impls below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // zeta2 with an explicit prompt format, e.g. `zeta2:ordered`.
    Zeta2(ZetaFormat),
    Teacher(TeacherBackend),
    // Same teacher backends but, per the name, bypassing request batching —
    // the actual behavior lives in the predict module (not visible here).
    TeacherNonBatching(TeacherBackend),
    Repair,
}
357
358impl Default for PredictionProvider {
359 fn default() -> Self {
360 PredictionProvider::Zeta2(ZetaFormat::default())
361 }
362}
363
364impl std::fmt::Display for PredictionProvider {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 match self {
367 PredictionProvider::Sweep => write!(f, "sweep"),
368 PredictionProvider::Mercury => write!(f, "mercury"),
369 PredictionProvider::Zeta1 => write!(f, "zeta1"),
370 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
371 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
372 PredictionProvider::TeacherNonBatching(backend) => {
373 write!(f, "teacher-non-batching:{backend}")
374 }
375 PredictionProvider::Repair => write!(f, "repair"),
376 }
377 }
378}
379
380impl std::str::FromStr for PredictionProvider {
381 type Err = anyhow::Error;
382
383 fn from_str(s: &str) -> Result<Self, Self::Err> {
384 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
385
386 let provider_lower = provider.to_lowercase();
387 match provider_lower.as_str() {
388 "sweep" => Ok(PredictionProvider::Sweep),
389 "mercury" => Ok(PredictionProvider::Mercury),
390 "zeta1" => Ok(PredictionProvider::Zeta1),
391 "zeta2" => {
392 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
393 Ok(PredictionProvider::Zeta2(format))
394 }
395 "teacher" => {
396 let backend = arg
397 .map(|a| a.parse())
398 .transpose()?
399 .unwrap_or(TeacherBackend::default());
400 Ok(PredictionProvider::Teacher(backend))
401 }
402 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
403 let backend = arg
404 .map(|a| a.parse())
405 .transpose()?
406 .unwrap_or(TeacherBackend::default());
407 Ok(PredictionProvider::TeacherNonBatching(backend))
408 }
409 "repair" => Ok(PredictionProvider::Repair),
410 _ => {
411 anyhow::bail!(
412 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
413 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
414 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
415 Available zeta versions:\n{}",
416 ZetaFormat::options_as_string()
417 )
418 }
419 }
420 }
421}
422
423impl Serialize for PredictionProvider {
424 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
425 where
426 S: Serializer,
427 {
428 serializer.serialize_str(&self.to_string())
429 }
430}
431
432impl<'de> Deserialize<'de> for PredictionProvider {
433 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
434 where
435 D: Deserializer<'de>,
436 {
437 let s = String::deserialize(deserializer)?;
438 s.parse().map_err(serde::de::Error::custom)
439 }
440}
441
// Arguments for `synthesize`: mine git history of the given repos to generate
// eval examples (consumed via SynthesizeConfig / run_synthesize in main).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
460
// Arguments for `import-batch`: re-ingest provider batch results into the
// local cache (see the ImportBatch arm in main).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
470
// Which LLM batch API `import-batch` talks to; dispatches to
// anthropic_client or openai_client in main.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
476
477impl EpArgs {
478 fn output_path(&self) -> Option<PathBuf> {
479 if self.in_place {
480 if self.inputs.len() == 1 {
481 self.inputs.first().cloned()
482 } else {
483 panic!("--in-place requires exactly one input file")
484 }
485 } else {
486 self.output.clone()
487 }
488 }
489}
490
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): presumably these map to the minor/patch components of a
// 0.MINOR.PATCH Zed version — confirm in pull_examples::MinCaptureVersion.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
498
/// Deduplicates `examples` in place, keeping at most `max_per_cluster`
/// representatives of each cluster of near-identical cursor contexts.
///
/// Two passes:
/// 1. Exact dedup on `spec.cursor_position`.
/// 2. Near-duplicate detection via MinHash/LSH over token n-grams of the
///    cursor-position text, clustered with union-find, then a diverse subset
///    is kept from each cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: retain only the first example for each distinct cursor position.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2 parameters: similarity above the threshold makes two examples
    // near-duplicate candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Use the exact band product, since parameter calculation may round.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, computed from token n-grams of its
    // cursor-position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices under their cluster root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep a small, mutually-dissimilar subset.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the back so earlier indices stay valid, then restore the
    // kept examples' original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
579
580fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
581 if members.len() <= k {
582 return members.to_vec();
583 }
584
585 let mut selected = vec![members[0]];
586 let mut min_dist: HashMap<usize, f64> = HashMap::default();
587 for &member in &members[1..] {
588 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
589 min_dist.insert(member, dist);
590 }
591
592 while selected.len() < k {
593 let &best = min_dist
594 .iter()
595 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
596 .map(|(id, _)| id)
597 .expect("min_dist should not be empty when selected.len() < k");
598 selected.push(best);
599 min_dist.remove(&best);
600
601 let best_sig = &signatures[best];
602 for (member, current_min) in min_dist.iter_mut() {
603 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
604 if dist < *current_min {
605 *current_min = dist;
606 }
607 }
608 }
609
610 selected
611}
612
613fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
614 let tokens: Vec<&str> = word_diff::tokenize(code)
615 .into_iter()
616 .filter(|t| !t.trim().is_empty())
617 .collect();
618
619 if tokens.len() < ngram_size {
620 return vec![tokens.join("\0")];
621 }
622
623 tokens
624 .windows(ngram_size)
625 .map(|window| window.join("\0"))
626 .collect()
627}
628
/// Gathers the examples a run will process: reads local files, fetches the
/// requested Snowflake datasets, then applies the global flags (--offset,
/// --limit, --name, --repo, --max-duplicates) and the resume-from-output
/// logic.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition inputs into Snowflake specifiers vs. local example files.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Row cap per timestamp; defaults to 5000 when no --limit was given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // NOTE(review): `captured_after_timestamps` is collected above but no
        // fetch below consumes it, so `captured-after:` inputs appear to be
        // silently dropped — confirm against pull_examples.
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Apply the global --name / --repo substring filters.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Re-publish totals now that filtering/dedup/limit have settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
770
/// Fingerprints an example spec with `FxHasher`; `resume_from_output` uses
/// these hashes to match input examples against previously-written output.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
776
777fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
778 let file = match File::open(path) {
779 Ok(f) => f,
780 Err(_) => return,
781 };
782
783 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
784
785 let reader = BufReader::new(file);
786 let mut kept_lines = Vec::new();
787 let mut kept_hashes = HashSet::default();
788
789 for line in reader.lines() {
790 let line = match line {
791 Ok(l) => l,
792 Err(_) => continue,
793 };
794
795 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
796 let hash = spec_hash(&output_example.spec);
797 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
798 let is_complete = match command {
799 Command::Qa(_) => output_example
800 .qa
801 .first()
802 .and_then(|q| q.as_ref())
803 .and_then(|q| q.confidence)
804 .is_some(),
805 Command::Repair(_) => output_example.predictions.iter().any(|p| {
806 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
807 }),
808 _ => true,
809 };
810 if is_complete {
811 kept_hashes.insert(hash);
812 kept_lines.push(line);
813 }
814 }
815 }
816 }
817
818 let total = examples.len();
819 let already_processed = kept_hashes.len();
820
821 eprintln!(
822 "Resuming: {}/{} examples already processed",
823 already_processed, total
824 );
825
826 let file = OpenOptions::new()
827 .write(true)
828 .truncate(true)
829 .open(path)
830 .expect("Failed to open output file for rewriting");
831 let mut writer = BufWriter::new(file);
832 for line in &kept_lines {
833 writeln!(writer, "{}", line).expect("Failed to write to output file");
834 }
835 writer.flush().expect("Failed to flush output file");
836
837 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
838}
839
840fn main() {
841 let args = EpArgs::parse();
842
843 if args.printenv {
844 ::util::shell_env::print_env();
845 return;
846 }
847
848 let output = args.output_path();
849
850 if args.markdown && output.is_none() {
851 eprintln!("--markdown requires -o to specify the output directory");
852 std::process::exit(1);
853 }
854
855 let command = match &args.command {
856 Some(cmd) => cmd.clone(),
857 None => {
858 EpArgs::command().print_help().unwrap();
859 return;
860 }
861 };
862
863 match &command {
864 Command::ImportBatch(import_args) => {
865 smol::block_on(async {
866 match import_args.provider {
867 BatchProvider::Anthropic => {
868 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
869 .expect("Failed to create Anthropic client");
870 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
871 eprintln!("Error importing Anthropic batches: {:?}", e);
872 std::process::exit(1);
873 }
874 }
875 BatchProvider::Openai => {
876 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
877 .expect("Failed to create OpenAI client");
878 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
879 eprintln!("Error importing OpenAI batches: {:?}", e);
880 std::process::exit(1);
881 }
882 }
883 }
884 println!(
885 "Successfully imported {} batch(es)",
886 import_args.batch_ids.len()
887 );
888 });
889 return;
890 }
891 Command::Clean => {
892 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
893 return;
894 }
895 Command::PrintZetaFormats => {
896 use strum::IntoEnumIterator as _;
897 for format in ZetaFormat::iter() {
898 println!("{}", format.to_string().to_lowercase());
899 }
900 return;
901 }
902
903 Command::Synthesize(synth_args) => {
904 let output_dir = if let Some(output_dir) = args.output {
905 output_dir
906 } else {
907 let default_output_dir = env::current_dir()
908 .unwrap()
909 .join("crates/edit_prediction_cli/evals-generated");
910 if default_output_dir.parent().unwrap().exists() {
911 std::fs::create_dir(&default_output_dir).ok();
912 default_output_dir
913 } else {
914 panic!("output dir is required");
915 }
916 };
917 let config = SynthesizeConfig {
918 repo_urls: synth_args.repos.clone(),
919 count: synth_args.count,
920 max_commits: synth_args.max_commits,
921 output_dir,
922 fresh: synth_args.fresh,
923 };
924 smol::block_on(async {
925 if let Err(e) = run_synthesize(config).await {
926 eprintln!("Error: {:?}", e);
927 std::process::exit(1);
928 }
929 });
930 return;
931 }
932 Command::SplitCommit(split_commit_args) => {
933 if let Err(error) = split_commit::run_split_commit(
934 split_commit_args,
935 &args.inputs,
936 output.as_ref(),
937 args.failed,
938 ) {
939 eprintln!("{error:#}");
940 std::process::exit(1);
941 }
942 return;
943 }
944 Command::TruncatePatch(truncate_args) => {
945 if let Err(error) =
946 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
947 {
948 eprintln!("{error:#}");
949 std::process::exit(1);
950 }
951 return;
952 }
953 Command::Split(split_args) => {
954 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
955 eprintln!("{error:#}");
956 std::process::exit(1);
957 }
958 return;
959 }
960 Command::FilterLanguages(filter_args) => {
961 if let Err(error) =
962 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
963 {
964 eprintln!("{error:#}");
965 std::process::exit(1);
966 }
967 return;
968 }
969
970 _ => {}
971 }
972
973 let http_client = Arc::new(ReqwestClient::new());
974 let app = gpui_platform::headless().with_http_client(http_client);
975
976 app.run(move |cx| {
977 let app_state = Arc::new(headless::init(cx));
978 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
979
980 cx.spawn(async move |cx| {
981 let result = async {
982 let examples = load_examples(
983 app_state.client.http_client(),
984 &args,
985 output.as_ref(),
986 cx.background_executor().clone(),
987 )
988 .await?;
989
990 match &command {
991 Command::Predict(args) | Command::Score(args) => {
992 predict::sync_batches(args.provider.as_ref()).await?;
993 }
994 Command::Eval(args) => {
995 predict::sync_batches(args.predict.provider.as_ref()).await?;
996 }
997 Command::Qa(args) => {
998 qa::sync_batches(args).await?;
999 }
1000 Command::Repair(args) => {
1001 repair::sync_batches(args).await?;
1002 }
1003 _ => (),
1004 }
1005
1006 let failfast_on_single_example = examples.len() == 1;
1007
1008 // For --markdown mode, create the output directory if it doesn't exist
1009 if args.markdown {
1010 let dir = output.as_ref().expect("--markdown requires -o");
1011 if !dir.exists() {
1012 std::fs::create_dir_all(dir)
1013 .expect("Failed to create markdown output directory");
1014 }
1015 }
1016
1017 // Set up JSONL output writer (not used in markdown mode)
1018 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1019 let mut in_place_temp_path: Option<PathBuf> = None;
1020 if !args.markdown
1021 && let Some(output_path) = output.as_ref()
1022 {
1023 let write_path = if args.in_place {
1024 let temp = output_path.with_extension("jsonl.tmp");
1025 in_place_temp_path = Some(temp.clone());
1026 temp
1027 } else {
1028 output_path.clone()
1029 };
1030
1031 let file = OpenOptions::new()
1032 .create(true)
1033 .write(true)
1034 .truncate(args.in_place)
1035 .append(!args.in_place)
1036 .open(&write_path)
1037 .expect("Failed to open output file");
1038
1039 let mut writer = BufWriter::new(file);
1040 let (sender, mut receiver) = mpsc::unbounded::<String>();
1041 cx.background_spawn(async move {
1042 while let Some(line) = receiver.next().await {
1043 writeln!(writer, "{}", line).expect("Failed to write example");
1044 writer.flush().expect("Failed to flush output");
1045 }
1046 })
1047 .detach();
1048 output_sender = Some(sender);
1049 }
1050
1051 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1052 let finished_examples = Mutex::new(Vec::new());
1053
1054 let mut tasks = Vec::new();
1055 for _ in 0..args.max_parallelism {
1056 tasks.push(async {
1057 loop {
1058 let Some(mut repo_examples) =
1059 grouped_examples.lock().unwrap().pop_front()
1060 else {
1061 break;
1062 };
1063 for example in &mut repo_examples {
1064 let example_progress =
1065 Progress::global().start_group(&example.spec.name);
1066
1067 let result = async {
1068 match &command {
1069 Command::Read(_) => {}
1070 Command::LoadProject => {
1071 run_load_project(
1072 example,
1073 app_state.clone(),
1074 &example_progress,
1075 cx.clone(),
1076 )
1077 .await?;
1078 }
1079 Command::Context => {
1080 run_context_retrieval(
1081 example,
1082 app_state.clone(),
1083 &example_progress,
1084 cx.clone(),
1085 )
1086 .await?;
1087 }
1088 Command::FormatPrompt(args) => {
1089 run_format_prompt(
1090 example,
1091 args,
1092 app_state.clone(),
1093 &example_progress,
1094 cx.clone(),
1095 )
1096 .await?;
1097 }
1098 Command::Predict(args) => {
1099 run_prediction(
1100 example,
1101 args,
1102 app_state.clone(),
1103 &example_progress,
1104 cx.clone(),
1105 )
1106 .await?;
1107 }
1108 Command::ParseOutput => {
1109 parse_output::run_parse_output(example)?;
1110 }
1111 Command::Distill => {
1112 run_distill(example).await?;
1113 }
1114 Command::Score(args) => {
1115 run_scoring(
1116 example,
1117 args,
1118 app_state.clone(),
1119 &example_progress,
1120 cx.clone(),
1121 )
1122 .await?;
1123 }
1124 Command::Eval(args) => {
1125 run_scoring(
1126 example,
1127 &args.predict,
1128 app_state.clone(),
1129 &example_progress,
1130 cx.clone(),
1131 )
1132 .await?;
1133 }
1134 Command::Qa(args) => {
1135 qa::run_qa(example, args, &example_progress).await?;
1136 }
1137 Command::Repair(args) => {
1138 repair::run_repair(example, args, &example_progress)
1139 .await?;
1140 }
1141 Command::Clean
1142 | Command::Synthesize(_)
1143 | Command::SplitCommit(_)
1144 | Command::Split(_)
1145 | Command::TruncatePatch(_)
1146 | Command::FilterLanguages(_)
1147 | Command::ImportBatch(_)
1148 | Command::PrintZetaFormats => {
1149 unreachable!()
1150 }
1151 }
1152 anyhow::Ok(())
1153 }
1154 .await;
1155
1156 let failed = if let Err(error) = result {
1157 handle_error(
1158 error,
1159 &args,
1160 &command,
1161 &app_state,
1162 failfast_on_single_example,
1163 &example,
1164 )
1165 .await;
1166 true
1167 } else {
1168 false
1169 };
1170
1171 let should_write = !failed || args.failed == FailedHandling::Keep;
1172 if should_write {
1173 if args.markdown {
1174 let markdown_dir =
1175 output.as_ref().expect("--markdown requires -o");
1176 let filename = format!("{}.md", example.spec.filename());
1177 let path = markdown_dir.join(&filename);
1178 let markdown = example.spec.to_markdown();
1179 std::fs::write(&path, &markdown)
1180 .expect("Failed to write markdown file");
1181 } else if let Some(ref mut sender) = output_sender.clone() {
1182 let line = serde_json::to_string(&example).unwrap();
1183 sender
1184 .send(line)
1185 .await
1186 .expect("Failed to send to output writer");
1187 } else if args.output.is_none()
1188 && !matches!(command, Command::Eval(_))
1189 {
1190 let line = serde_json::to_string(&example).unwrap();
1191 println!("{}", line);
1192 }
1193 }
1194 }
1195
1196 let project = repo_examples
1197 .iter()
1198 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1199
1200 if let Some(project) = project {
1201 let mut cx = cx.clone();
1202
1203 let shutdown_task: Task<()> =
1204 project.update(&mut cx, |project, cx| {
1205 let lsp_store = project.lsp_store();
1206 lsp_store.update(cx, |lsp_store, cx| {
1207 lsp_store.shutdown_all_language_servers(cx)
1208 })
1209 });
1210
1211 shutdown_task.await;
1212
1213 if let Some(ep_store) =
1214 cx.update(|cx| EditPredictionStore::try_global(cx))
1215 {
1216 ep_store.update(&mut cx, |store, _| {
1217 store.remove_project(&project);
1218 });
1219 }
1220 }
1221
1222 for example in &mut repo_examples {
1223 example.state.take();
1224 }
1225 finished_examples
1226 .lock()
1227 .unwrap()
1228 .extend_from_slice(&repo_examples);
1229 }
1230 });
1231 }
1232 futures::future::join_all(tasks).await;
1233
1234 Progress::global().finalize();
1235
1236 match &command {
1237 Command::Predict(args) | Command::Score(args) => {
1238 predict::sync_batches(args.provider.as_ref()).await?;
1239 }
1240 Command::Eval(args) => {
1241 predict::sync_batches(args.predict.provider.as_ref()).await?;
1242 }
1243 Command::Qa(args) => {
1244 qa::sync_batches(args).await?;
1245 }
1246 Command::Repair(args) => {
1247 repair::sync_batches(args).await?;
1248 }
1249 _ => (),
1250 }
1251
1252 match &command {
1253 Command::Eval(args) => {
1254 let examples = finished_examples.lock().unwrap();
1255 score::print_report(&examples, args.verbose);
1256 if let Some(summary_path) = &args.summary_json {
1257 score::write_summary_json(&examples, summary_path)?;
1258 }
1259 }
1260 Command::Repair(args) => {
1261 let examples = finished_examples.lock().unwrap();
1262 repair::print_report(&examples, args.confidence_threshold);
1263 }
1264 _ => (),
1265 };
1266
1267 // For --in-place, atomically rename temp file to original
1268 if let Some(temp_path) = &in_place_temp_path {
1269 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1270 std::fs::rename(temp_path, final_path)
1271 .expect("Failed to rename temp file to final output");
1272 }
1273
1274 anyhow::Ok(())
1275 }
1276 .await;
1277
1278 if let Err(e) = result {
1279 panic!("Fatal error: {:?}", e);
1280 }
1281
1282 let _ = cx.update(|cx| cx.quit());
1283 })
1284 .detach();
1285 });
1286}
1287
1288async fn handle_error(
1289 error: anyhow::Error,
1290 args: &EpArgs,
1291 command: &Command,
1292 app_state: &Arc<headless::EpAppState>,
1293 failfast_on_single_example: bool,
1294 example: &Example,
1295) {
1296 Progress::global().increment_failed();
1297
1298 let msg;
1299 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1300 let example_name = example.spec.filename();
1301
1302 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1303 app_state
1304 .fs
1305 .write(
1306 &failed_example_path,
1307 &serde_json::to_vec_pretty(&example).unwrap(),
1308 )
1309 .await
1310 .unwrap();
1311 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1312 app_state
1313 .fs
1314 .write(&err_path, format!("{error:?}").as_bytes())
1315 .await
1316 .unwrap();
1317
1318 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1319 let mut file = OpenOptions::new()
1320 .create(true)
1321 .append(true)
1322 .open(&failed_jsonl_path)
1323 .expect("Failed to open failed.jsonl");
1324 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1325 .expect("Failed to write to failed.jsonl");
1326
1327 let cursor_path = match example.repo_name() {
1328 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1329 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1330 };
1331 msg = format!(
1332 indoc::indoc! {"
1333 While processing \"{}\":
1334
1335 \x1b[31m{:?}\x1b[0m
1336
1337 Example: \x1b[36m{}\x1b[0m
1338 Error file: \x1b[36m{}\x1b[0m
1339 Cursor file: \x1b[36m{}\x1b[0m
1340 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1341 "},
1342 example.spec.name,
1343 error,
1344 failed_example_path.display(),
1345 err_path.display(),
1346 cursor_path.display(),
1347 command,
1348 failed_example_path.display(),
1349 );
1350 } else {
1351 msg = format!(
1352 indoc::indoc! {"
1353 While processing \"{}\":
1354
1355 \x1b[31m{:?}\x1b[0m
1356 "},
1357 example.spec.name, error
1358 );
1359 }
1360
1361 if args.failfast || failfast_on_single_example {
1362 Progress::global().finalize();
1363 panic!("{}", msg);
1364 } else {
1365 log::error!("{}", msg);
1366 }
1367}