1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
42use std::env;
43use std::fmt::Display;
44use std::fs::{File, OpenOptions};
45use std::hash::{Hash, Hasher};
46use std::io::{BufRead, BufReader, BufWriter, Write};
47use std::sync::Mutex;
48use std::{path::PathBuf, sync::Arc};
49
50use crate::distill::run_distill;
51use crate::example::{Example, group_examples_by_repo, read_example_files};
52use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
53use crate::format_prompt::run_format_prompt;
54use crate::load_project::run_load_project;
55use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
56use crate::predict::run_prediction;
57use crate::progress::Progress;
58use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
59use crate::retrieve_context::run_context_retrieval;
60use crate::score::run_scoring;
61use crate::split_commit::SplitCommitArgs;
62use crate::split_dataset::SplitArgs;
63use crate::synthesize::{SynthesizeConfig, run_synthesize};
64use crate::truncate_expected_patch::TruncatePatchArgs;
65
// Top-level CLI arguments for the `ep` binary. Most flags are marked
// `global = true` so they may appear before or after the subcommand.
// NOTE(review): `///` doc comments on fields feed clap's --help output, so
// review-only annotations below use plain `//` comments to leave help intact.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the inherited shell environment and exit (checked first in main()).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrently processed examples.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip (applied to file inputs first, then
    // the remainder is forwarded to Snowflake fetches — see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output file; with --markdown this is interpreted as a directory.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the single input file in place (see EpArgs::output_path).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Stop the run on the first failing example.
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
108
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// clap ValueEnum: CLI values are the kebab-case variant names (the default
// spelling "keep" appears in EpArgs's `default_value`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
121
// Extended help shown via `after_help` on ReadArgs (`ep read --help`).
// This text is user-facing output; keep it in sync with the specifier
// parsing in load_examples / pull_examples.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.
      These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
      These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
      Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
      These are predictions that users explicitly rated as positive or negative via the
      rate completions modal. Only zeta2 predictions are included.
      - Positive ratings: output becomes expected_patches
      - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
      Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
      Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
186
// All `ep` subcommands. The `///` docs double as clap help text.
// Commands handled synchronously in main() before the headless app starts:
// ImportBatch, Clean, PrintZetaFormats, Synthesize, SplitCommit, TruncatePatch,
// Split, FilterLanguages. The remainder run inside the gpui event loop.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
230
231impl Display for Command {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Command::Read(_) => write!(f, "read"),
235 Command::LoadProject => write!(f, "load-project"),
236 Command::Context => write!(f, "context"),
237 Command::FormatPrompt(args) => {
238 write!(f, "format-prompt --provider={}", args.provider)
239 }
240 Command::Predict(args) => match &args.provider {
241 Some(provider) => write!(f, "predict --provider={}", provider),
242 None => write!(f, "predict"),
243 },
244 Command::ParseOutput => write!(f, "parse-output"),
245 Command::Score(args) => match &args.provider {
246 Some(provider) => write!(f, "score --provider={}", provider),
247 None => write!(f, "score"),
248 },
249 Command::Distill => write!(f, "distill"),
250 Command::Eval(args) => match &args.predict.provider {
251 Some(provider) => write!(f, "eval --provider={}", provider),
252 None => write!(f, "eval"),
253 },
254 Command::Synthesize(args) => {
255 write!(f, "synthesize --repos {}", args.repos.join(" "))
256 }
257 Command::Clean => write!(f, "clean"),
258 Command::SplitCommit(_) => write!(f, "split-commit"),
259 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
260 Command::Split(_) => write!(f, "split"),
261 Command::FilterLanguages(_) => write!(f, "filter-languages"),
262 Command::ImportBatch(args) => {
263 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
264 }
265 Command::Qa(_) => {
266 write!(f, "qa")
267 }
268 Command::Repair(_) => {
269 write!(f, "repair")
270 }
271 Command::PrintZetaFormats => {
272 write!(f, "print-zeta-formats")
273 }
274 }
275 }
276}
277
// `ep read` has no flags of its own; inputs come from the shared global
// arguments. The after_help text documents the input-specifier syntax.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
281
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format should be produced; parsed via
    // PredictionProvider's FromStr, defaults to zeta2 with default format.
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
287
// Shared arguments for `ep predict` and `ep score` (and flattened into EvalArgs).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Provider to run. NOTE(review): semantics of `None` (likely a downstream
    // default) are decided in the predict module — confirm there.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of prediction attempts per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
298
// Arguments for `ep eval`: a full predict run plus reporting options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All of PredictArgs's flags are available on `eval` too.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
310
/// Backend model for the `teacher` / `teacher-non-batching` prediction
/// providers. Parsed via `FromStr` (not clap's `ValueEnum`), which also
/// accepts shorthand aliases like "sonnet", "claude", "gpt", "openai".
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    /// Claude Sonnet 4.6 ("claude-sonnet-4-6").
    Sonnet46,
    /// Claude Sonnet 4.5 ("claude-sonnet-4-5") — the default teacher.
    #[default]
    Sonnet45,
    /// OpenAI GPT-5.2 ("gpt-5.2").
    Gpt52,
}
318
319impl std::fmt::Display for TeacherBackend {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 match self {
322 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
323 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
324 TeacherBackend::Gpt52 => write!(f, "gpt52"),
325 }
326 }
327}
328
329impl std::str::FromStr for TeacherBackend {
330 type Err = anyhow::Error;
331
332 fn from_str(s: &str) -> Result<Self, Self::Err> {
333 match s.to_lowercase().as_str() {
334 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
335 "sonnet46" => Ok(TeacherBackend::Sonnet46),
336 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
337 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
338 _ => anyhow::bail!(
339 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
340 ),
341 }
342 }
343}
344
345impl TeacherBackend {
346 pub fn model_name(&self) -> &'static str {
347 match self {
348 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
349 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
350 TeacherBackend::Gpt52 => "gpt-5.2",
351 }
352 }
353}
354
/// The engine used to produce edit predictions. Displayed and parsed in the
/// `name` or `name:arg` form (see the Display and FromStr impls below);
/// serde round-trips through that string form.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta2/Baseten carry the prompt format to use.
    Zeta2(ZetaFormat),
    Baseten(ZetaFormat),
    // Teacher variants carry which teacher model to call.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    // Provider recorded on predictions produced by the `repair` command
    // (used by resume_from_output to detect completed repairs).
    Repair,
}
366
367impl Default for PredictionProvider {
368 fn default() -> Self {
369 PredictionProvider::Zeta2(ZetaFormat::default())
370 }
371}
372
373impl std::fmt::Display for PredictionProvider {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 match self {
376 PredictionProvider::Sweep => write!(f, "sweep"),
377 PredictionProvider::Mercury => write!(f, "mercury"),
378 PredictionProvider::Zeta1 => write!(f, "zeta1"),
379 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
380 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
381 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
382 PredictionProvider::TeacherNonBatching(backend) => {
383 write!(f, "teacher-non-batching:{backend}")
384 }
385 PredictionProvider::Repair => write!(f, "repair"),
386 }
387 }
388}
389
390impl std::str::FromStr for PredictionProvider {
391 type Err = anyhow::Error;
392
393 fn from_str(s: &str) -> Result<Self, Self::Err> {
394 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
395
396 let provider_lower = provider.to_lowercase();
397 match provider_lower.as_str() {
398 "sweep" => Ok(PredictionProvider::Sweep),
399 "mercury" => Ok(PredictionProvider::Mercury),
400 "zeta1" => Ok(PredictionProvider::Zeta1),
401 "zeta2" => {
402 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
403 Ok(PredictionProvider::Zeta2(format))
404 }
405 "teacher" => {
406 let backend = arg
407 .map(|a| a.parse())
408 .transpose()?
409 .unwrap_or(TeacherBackend::default());
410 Ok(PredictionProvider::Teacher(backend))
411 }
412 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
413 let backend = arg
414 .map(|a| a.parse())
415 .transpose()?
416 .unwrap_or(TeacherBackend::default());
417 Ok(PredictionProvider::TeacherNonBatching(backend))
418 }
419 "repair" => Ok(PredictionProvider::Repair),
420 "baseten" => {
421 let format = arg
422 .map(ZetaFormat::parse)
423 .transpose()?
424 .unwrap_or(ZetaFormat::default());
425 Ok(PredictionProvider::Baseten(format))
426 }
427 _ => {
428 anyhow::bail!(
429 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
430 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
431 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
432 Available zeta versions:\n{}",
433 ZetaFormat::options_as_string()
434 )
435 }
436 }
437 }
438}
439
440impl Serialize for PredictionProvider {
441 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
442 where
443 S: Serializer,
444 {
445 serializer.serialize_str(&self.to_string())
446 }
447}
448
449impl<'de> Deserialize<'de> for PredictionProvider {
450 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
451 where
452 D: Deserializer<'de>,
453 {
454 let s = String::deserialize(deserializer)?;
455 s.parse().map_err(serde::de::Error::custom)
456 }
457}
458
// Arguments for `ep synthesize` (see run_synthesize / SynthesizeConfig).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
477
// Arguments for `ep import-batch`, which replays completed LLM batch jobs
// into the local cache database (handled synchronously in main()).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
487
// Which LLM batch API `import-batch` pulls results from. clap ValueEnum:
// accepted CLI values are the lowercase variant names (default "anthropic").
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
493
494impl EpArgs {
495 fn output_path(&self) -> Option<PathBuf> {
496 if self.in_place {
497 if self.inputs.len() == 1 {
498 self.inputs.first().cloned()
499 } else {
500 panic!("--in-place requires exactly one input file")
501 }
502 } else {
503 self.output.clone()
504 }
505 }
506}
507
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// NOTE(review): only minor/patch are expressed here; presumably the major
// version is fixed by MinCaptureVersion itself — confirm in pull_examples.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
515
/// Removes duplicate examples in two passes: an exact pass keyed on
/// `cursor_position`, then a MinHash/LSH near-duplicate pass that clusters
/// similar cursor contexts and keeps at most `max_per_cluster` diverse
/// representatives per cluster. Survivors keep their original relative order.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop byte-identical cursor positions, keeping first occurrences.
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2 tuning: pairs above this Jaccard similarity count as near
    // duplicates; shingles are 5-token n-grams of the cursor context.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // Derive the LSH banding from the threshold; the effective hash count is
    // whatever bands × band-width rounds NUM_HASHES to.
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, computed over token n-grams of the
    // cursor position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Find root with path halving (links are compressed while walking up).
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Keep a diverse subset per cluster rather than an arbitrary prefix.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from highest index to lowest so lower kept indices stay valid
    // under swap_remove; reversing afterwards restores ascending order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
596
/// Greedy max-min (farthest-point) selection of up to `k` diverse members of
/// one near-duplicate cluster. Distance is `1 - MinHash similarity` between
/// member signatures. Returns the whole cluster when it has at most `k`
/// members.
/// NOTE(review): the seed is always `members[0]`, and `max_by` ties resolve
/// in HashMap iteration order, so tie-breaking is not deterministic.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    if members.len() <= k {
        return members.to_vec();
    }

    // Seed with the first member; min_dist tracks each remaining member's
    // distance to its nearest already-selected member.
    let mut selected = vec![members[0]];
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Pick the member farthest from everything selected so far.
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // The new pick may now be the nearest selected member for others.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
629
630fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
631 let tokens: Vec<&str> = word_diff::tokenize(code)
632 .into_iter()
633 .filter(|t| !t.trim().is_empty())
634 .collect();
635
636 if tokens.len() < ngram_size {
637 return vec![tokens.join("\0")];
638 }
639
640 tokens
641 .windows(ngram_size)
642 .map(|window| window.join("\0"))
643 .collect()
644}
645
/// Gathers the run's input examples: local files plus Snowflake specifiers,
/// then applies (in order) offset, Snowflake fetches, repo/rev sort,
/// name/repo filters, resume-from-output, optional deduplication, and limit.
/// Snowflake fetches are skipped entirely once `--limit` is already
/// satisfied by file inputs.
///
/// NOTE(review): `captured_after_timestamps` is collected below but never
/// fetched — `captured-after:` inputs appear to be silently dropped even
/// though INPUTS_HELP documents them; confirm whether that fetch was removed
/// intentionally.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition raw inputs into Snowflake specifiers vs. plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a file path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    // Provisional progress total; updated again after all filtering below.
    Progress::global().set_total_examples(examples.len());

    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // 5000 is the documented default row cap per timestamp when pulling
        // from Snowflake without an explicit --limit (see EpArgs::limit docs).
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        // Rejected predictions (shown to users but rejected).
        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        // Requested predictions.
        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        // Settled-stream examples.
        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        // User-rated predictions (optionally filtered by rating polarity).
        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    // Sort by repository/revision — presumably so per-repo worktree setup is
    // batched; see the example module for the ordering.
    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Final progress totals now that the example set is settled.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
805
806fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
807 let mut hasher = collections::FxHasher::default();
808 spec.hash(&mut hasher);
809 hasher.finish()
810}
811
/// Reconciles a previous run's JSONL output with the current input set:
/// keeps output lines whose spec hash matches a current input (first
/// occurrence only, and only if "complete" for the given command), rewrites
/// the output file to exactly those lines, and removes the matching examples
/// from `examples` so they are not processed again.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No prior output file means there is nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Unreadable or non-JSON lines are silently dropped from the rewrite.
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // Command-specific completeness: a line only counts as
                // processed if the command's expected artifact is present.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file to only the kept lines, pruning stale,
    // duplicate, or incomplete entries.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // What remains in `examples` is exactly the work still to be done.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
874
875fn main() {
876 let args = EpArgs::parse();
877
878 if args.printenv {
879 ::util::shell_env::print_env();
880 return;
881 }
882
883 let output = args.output_path();
884
885 if args.markdown && output.is_none() {
886 eprintln!("--markdown requires -o to specify the output directory");
887 std::process::exit(1);
888 }
889
890 let command = match &args.command {
891 Some(cmd) => cmd.clone(),
892 None => {
893 EpArgs::command().print_help().unwrap();
894 return;
895 }
896 };
897
898 match &command {
899 Command::ImportBatch(import_args) => {
900 smol::block_on(async {
901 match import_args.provider {
902 BatchProvider::Anthropic => {
903 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
904 .expect("Failed to create Anthropic client");
905 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
906 eprintln!("Error importing Anthropic batches: {:?}", e);
907 std::process::exit(1);
908 }
909 }
910 BatchProvider::Openai => {
911 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
912 .expect("Failed to create OpenAI client");
913 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
914 eprintln!("Error importing OpenAI batches: {:?}", e);
915 std::process::exit(1);
916 }
917 }
918 }
919 println!(
920 "Successfully imported {} batch(es)",
921 import_args.batch_ids.len()
922 );
923 });
924 return;
925 }
926 Command::Clean => {
927 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
928 return;
929 }
930 Command::PrintZetaFormats => {
931 use strum::IntoEnumIterator as _;
932 for format in ZetaFormat::iter() {
933 println!("{}", format.to_string().to_lowercase());
934 }
935 return;
936 }
937
938 Command::Synthesize(synth_args) => {
939 let output_dir = if let Some(output_dir) = args.output {
940 output_dir
941 } else {
942 let default_output_dir = env::current_dir()
943 .unwrap()
944 .join("crates/edit_prediction_cli/evals-generated");
945 if default_output_dir.parent().unwrap().exists() {
946 std::fs::create_dir(&default_output_dir).ok();
947 default_output_dir
948 } else {
949 panic!("output dir is required");
950 }
951 };
952 let config = SynthesizeConfig {
953 repo_urls: synth_args.repos.clone(),
954 count: synth_args.count,
955 max_commits: synth_args.max_commits,
956 output_dir,
957 fresh: synth_args.fresh,
958 };
959 smol::block_on(async {
960 if let Err(e) = run_synthesize(config).await {
961 eprintln!("Error: {:?}", e);
962 std::process::exit(1);
963 }
964 });
965 return;
966 }
967 Command::SplitCommit(split_commit_args) => {
968 if let Err(error) = split_commit::run_split_commit(
969 split_commit_args,
970 &args.inputs,
971 output.as_ref(),
972 args.failed,
973 ) {
974 eprintln!("{error:#}");
975 std::process::exit(1);
976 }
977 return;
978 }
979 Command::TruncatePatch(truncate_args) => {
980 if let Err(error) =
981 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
982 {
983 eprintln!("{error:#}");
984 std::process::exit(1);
985 }
986 return;
987 }
988 Command::Split(split_args) => {
989 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
990 eprintln!("{error:#}");
991 std::process::exit(1);
992 }
993 return;
994 }
995 Command::FilterLanguages(filter_args) => {
996 if let Err(error) =
997 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
998 {
999 eprintln!("{error:#}");
1000 std::process::exit(1);
1001 }
1002 return;
1003 }
1004
1005 _ => {}
1006 }
1007
    // The remaining commands need project loading, language servers, and/or
    // model calls, so boot a headless gpui app with an HTTP client attached.
    let http_client = Arc::new(ReqwestClient::new());
    let app = gpui_platform::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let examples = load_examples(
                    app_state.client.http_client(),
                    &args,
                    output.as_ref(),
                    cx.background_executor().clone(),
                )
                .await?;

                // Reconcile any provider-side batch state before processing,
                // for the commands that use batched requests.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // With exactly one example, any failure aborts the whole run.
                let failfast_on_single_example = examples.len() == 1;

                // For --markdown mode, create the output directory if it doesn't exist
                if args.markdown {
                    let dir = output.as_ref().expect("--markdown requires -o");
                    if !dir.exists() {
                        std::fs::create_dir_all(dir)
                            .expect("Failed to create markdown output directory");
                    }
                }

                // Set up JSONL output writer (not used in markdown mode)
                let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
                let mut in_place_temp_path: Option<PathBuf> = None;
                if !args.markdown
                    && let Some(output_path) = output.as_ref()
                {
                    // --in-place writes to a temp file that is atomically
                    // renamed over the original once the run completes.
                    let write_path = if args.in_place {
                        let temp = output_path.with_extension("jsonl.tmp");
                        in_place_temp_path = Some(temp.clone());
                        temp
                    } else {
                        output_path.clone()
                    };

                    // in_place => truncate the temp file; otherwise append to
                    // the existing output so reruns accumulate results.
                    let file = OpenOptions::new()
                        .create(true)
                        .write(true)
                        .truncate(args.in_place)
                        .append(!args.in_place)
                        .open(&write_path)
                        .expect("Failed to open output file");

                    // Funnel all example lines through one background task so
                    // concurrent workers never interleave partial writes.
                    let mut writer = BufWriter::new(file);
                    let (sender, mut receiver) = mpsc::unbounded::<String>();
                    cx.background_spawn(async move {
                        while let Some(line) = receiver.next().await {
                            writeln!(writer, "{}", line).expect("Failed to write example");
                            writer.flush().expect("Failed to flush output");
                        }
                    })
                    .detach();
                    output_sender = Some(sender);
                }

                // Work queue keyed by repo: each worker drains one repo's
                // examples at a time so a repo's project state stays local.
                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                let finished_examples = Mutex::new(Vec::new());

                let mut tasks = Vec::new();
                for _ in 0..args.max_parallelism {
                    tasks.push(async {
                        loop {
                            let Some(mut repo_examples) =
                                grouped_examples.lock().unwrap().pop_front()
                            else {
                                break;
                            };
                            for example in &mut repo_examples {
                                let example_progress =
                                    Progress::global().start_group(&example.spec.name);

                                // Dispatch the selected command for this example.
                                let result = async {
                                    match &command {
                                        Command::Read(_) => {}
                                        Command::LoadProject => {
                                            run_load_project(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Context => {
                                            run_context_retrieval(
                                                example,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::FormatPrompt(args) => {
                                            run_format_prompt(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Predict(args) => {
                                            run_prediction(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::ParseOutput => {
                                            parse_output::run_parse_output(example)?;
                                        }
                                        Command::Distill => {
                                            run_distill(example).await?;
                                        }
                                        Command::Score(args) => {
                                            run_scoring(
                                                example,
                                                args,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Eval(args) => {
                                            // Eval reuses the scoring path with
                                            // its embedded predict arguments.
                                            run_scoring(
                                                example,
                                                &args.predict,
                                                app_state.clone(),
                                                &example_progress,
                                                cx.clone(),
                                            )
                                            .await?;
                                        }
                                        Command::Qa(args) => {
                                            qa::run_qa(example, args, &example_progress).await?;
                                        }
                                        Command::Repair(args) => {
                                            repair::run_repair(example, args, &example_progress)
                                                .await?;
                                        }
                                        // These commands returned before the app
                                        // started; they can't reach this loop.
                                        Command::Clean
                                        | Command::Synthesize(_)
                                        | Command::SplitCommit(_)
                                        | Command::Split(_)
                                        | Command::TruncatePatch(_)
                                        | Command::FilterLanguages(_)
                                        | Command::ImportBatch(_)
                                        | Command::PrintZetaFormats => {
                                            unreachable!()
                                        }
                                    }
                                    anyhow::Ok(())
                                }
                                .await;

                                // Record failures; handle_error may panic when
                                // failfast is in effect.
                                let failed = if let Err(error) = result {
                                    handle_error(
                                        error,
                                        &args,
                                        &command,
                                        &app_state,
                                        failfast_on_single_example,
                                        &example,
                                    )
                                    .await;
                                    true
                                } else {
                                    false
                                };

                                // Failed examples are only emitted when the user
                                // opted into keeping them.
                                let should_write = !failed || args.failed == FailedHandling::Keep;
                                if should_write {
                                    if args.markdown {
                                        let markdown_dir =
                                            output.as_ref().expect("--markdown requires -o");
                                        let filename = format!("{}.md", example.spec.filename());
                                        let path = markdown_dir.join(&filename);
                                        let markdown = example.spec.to_markdown();
                                        std::fs::write(&path, &markdown)
                                            .expect("Failed to write markdown file");
                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                        // Cloning an unbounded sender is cheap;
                                        // the single writer task serializes lines.
                                        let line = serde_json::to_string(&example).unwrap();
                                        sender
                                            .send(line)
                                            .await
                                            .expect("Failed to send to output writer");
                                    } else if args.output.is_none()
                                        && !matches!(command, Command::Eval(_))
                                    {
                                        // No -o: stream JSONL to stdout (except
                                        // eval, which prints a report instead).
                                        let line = serde_json::to_string(&example).unwrap();
                                        println!("{}", line);
                                    }
                                }
                            }

                            // All examples in this repo share one project; shut
                            // its language servers down before moving on.
                            let project = repo_examples
                                .iter()
                                .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));

                            if let Some(project) = project {
                                let mut cx = cx.clone();

                                let shutdown_task: Task<()> =
                                    project.update(&mut cx, |project, cx| {
                                        let lsp_store = project.lsp_store();
                                        lsp_store.update(cx, |lsp_store, cx| {
                                            lsp_store.shutdown_all_language_servers(cx)
                                        })
                                    });

                                shutdown_task.await;

                                if let Some(ep_store) =
                                    cx.update(|cx| EditPredictionStore::try_global(cx))
                                {
                                    ep_store.update(&mut cx, |store, _| {
                                        store.remove_project(&project);
                                    });
                                }
                            }

                            // Drop per-example project state before retaining
                            // the examples for the final report.
                            for example in &mut repo_examples {
                                example.state.take();
                            }
                            finished_examples
                                .lock()
                                .unwrap()
                                .extend_from_slice(&repo_examples);
                        }
                    });
                }
                futures::future::join_all(tasks).await;

                Progress::global().finalize();

                // Flush provider-side batch state accumulated during the run.
                match &command {
                    Command::Predict(args) | Command::Score(args) => {
                        predict::sync_batches(args.provider.as_ref()).await?;
                    }
                    Command::Eval(args) => {
                        predict::sync_batches(args.predict.provider.as_ref()).await?;
                    }
                    Command::Qa(args) => {
                        qa::sync_batches(args).await?;
                    }
                    Command::Repair(args) => {
                        repair::sync_batches(args).await?;
                    }
                    _ => (),
                }

                // Commands with end-of-run reports.
                match &command {
                    Command::Eval(args) => {
                        let examples = finished_examples.lock().unwrap();
                        score::print_report(&examples, args.verbose);
                        if let Some(summary_path) = &args.summary_json {
                            score::write_summary_json(&examples, summary_path)?;
                        }
                    }
                    Command::Repair(args) => {
                        let examples = finished_examples.lock().unwrap();
                        repair::print_report(&examples, args.confidence_threshold);
                    }
                    _ => (),
                };

                // For --in-place, atomically rename temp file to original
                if let Some(temp_path) = &in_place_temp_path {
                    let final_path = output.as_ref().expect("in_place_temp_path requires output");
                    std::fs::rename(temp_path, final_path)
                        .expect("Failed to rename temp file to final output");
                }

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            // Quit the headless app so `app.run` returns and the process exits.
            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}
1322
1323async fn handle_error(
1324 error: anyhow::Error,
1325 args: &EpArgs,
1326 command: &Command,
1327 app_state: &Arc<headless::EpAppState>,
1328 failfast_on_single_example: bool,
1329 example: &Example,
1330) {
1331 Progress::global().increment_failed();
1332
1333 let msg;
1334 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1335 let example_name = example.spec.filename();
1336
1337 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1338 app_state
1339 .fs
1340 .write(
1341 &failed_example_path,
1342 &serde_json::to_vec_pretty(&example).unwrap(),
1343 )
1344 .await
1345 .unwrap();
1346 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1347 app_state
1348 .fs
1349 .write(&err_path, format!("{error:?}").as_bytes())
1350 .await
1351 .unwrap();
1352
1353 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1354 let mut file = OpenOptions::new()
1355 .create(true)
1356 .append(true)
1357 .open(&failed_jsonl_path)
1358 .expect("Failed to open failed.jsonl");
1359 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1360 .expect("Failed to write to failed.jsonl");
1361
1362 let cursor_path = match example.repo_name() {
1363 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1364 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1365 };
1366 msg = format!(
1367 indoc::indoc! {"
1368 While processing \"{}\":
1369
1370 \x1b[31m{:?}\x1b[0m
1371
1372 Example: \x1b[36m{}\x1b[0m
1373 Error file: \x1b[36m{}\x1b[0m
1374 Cursor file: \x1b[36m{}\x1b[0m
1375 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1376 "},
1377 example.spec.name,
1378 error,
1379 failed_example_path.display(),
1380 err_path.display(),
1381 cursor_path.display(),
1382 command,
1383 failed_example_path.display(),
1384 );
1385 } else {
1386 msg = format!(
1387 indoc::indoc! {"
1388 While processing \"{}\":
1389
1390 \x1b[31m{:?}\x1b[0m
1391 "},
1392 example.spec.name, error
1393 );
1394 }
1395
1396 if args.failfast || failfast_on_single_example {
1397 Progress::global().finalize();
1398 panic!("{}", msg);
1399 } else {
1400 log::error!("{}", msg);
1401 }
1402}