1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use collections::{HashMap, HashSet};
use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gaoya::minhash::{
    MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
};
use gpui::{AppContext as _, BackgroundExecutor, Task};
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use zeta_prompt::ZetaFormat;

use std::env;
use std::fmt::Display;
use std::fs::{File, OpenOptions};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::split_dataset::SplitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};
use crate::truncate_expected_patch::TruncatePatchArgs;
65
66#[derive(Parser, Debug)]
67#[command(name = "ep")]
68struct EpArgs {
69 #[arg(long, default_value_t = false)]
70 printenv: bool,
71 #[clap(long, default_value_t = 10, global = true)]
72 max_parallelism: usize,
73 /// The limit for the number of examples to process
74 /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
75 #[clap(long, global = true)]
76 limit: Option<usize>,
77 #[clap(long, global = true)]
78 offset: Option<usize>,
79 /// Filter examples by name
80 #[clap(long, global = true)]
81 name: Option<String>,
82 /// Filter examples by repository
83 #[clap(long, global = true)]
84 repo: Option<String>,
85 /// Deduplicate by cursor position and keep at most this many examples per cluster
86 #[clap(long, global = true)]
87 max_duplicates: Option<usize>,
88 #[command(subcommand)]
89 command: Option<Command>,
90 /// Input file paths
91 #[clap(global = true)]
92 inputs: Vec<PathBuf>,
93 #[arg(long, short, global = true)]
94 output: Option<PathBuf>,
95 #[arg(long, short, global = true)]
96 in_place: bool,
97 #[arg(long, short, global = true)]
98 failfast: bool,
99 /// How to handle failed examples in output: keep them or skip them.
100 /// Failed examples are always logged to the run's failed directory.
101 #[arg(long, global = true, default_value = "keep")]
102 failed: FailedHandling,
103 /// Output as markdown files instead of JSONL. When set, -o specifies a directory
104 /// where one .md file per example will be written (named after each example).
105 #[arg(long, short, global = true)]
106 markdown: bool,
107}
108
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// As a clap ValueEnum, variants are matched on the command line using clap's
// default naming (kebab-case), e.g. `--failed skip-no-files`; the default is
// set via `default_value = "keep"` on the EpArgs field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
121
/// Extended help text describing the special input specifiers accepted in
/// place of file paths; attached to `ReadArgs` via `after_help`.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
    Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
    Fetch captured examples from Snowflake after the given RFC3339 timestamp.
    These are examples captured via the "Capture Edit Prediction Example" action.

  rejected-after:{timestamp}
    Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that were shown to users but rejected (useful for DPO training).

  settled-after:{timestamp}
    Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
    These are examples from the edit prediction settled stream.

  rated-after:{timestamp}
    Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
    These are predictions that users explicitly rated as positive or negative via the
    rate completions modal. Only zeta2 predictions are included.
    - Positive ratings: output becomes expected_patches
    - Negative ratings: output becomes rejected_patch

  rated-positive-after:{timestamp}
    Same as rated-after, but only fetches positively rated predictions.

  rated-negative-after:{timestamp}
    Same as rated-after, but only fetches negatively rated predictions.

  Required environment variables to connect to Snowflake:
    EP_SNOWFLAKE_API_KEY
    EP_SNOWFLAKE_BASE_URL

  Optional:
    EP_SNOWFLAKE_ROLE

Examples:

  # Read examples from a file
  ep read examples.jsonl -o output.jsonl

  # Read captured examples after a timestamp
  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

  # Read rejected predictions for DPO training
  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

  # Read user-rated predictions
  ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

  # Read settled stream examples
  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

  # Read only positively rated predictions
  ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

  # Read only negatively rated predictions
  ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

  # Mix multiple input sources
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
186
// Subcommands of `ep`. The variant doc comments double as clap's per-command
// help text; the Display impl below renders each variant as the invocation
// string used in logs.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
230
231impl Display for Command {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Command::Read(_) => write!(f, "read"),
235 Command::LoadProject => write!(f, "load-project"),
236 Command::Context => write!(f, "context"),
237 Command::FormatPrompt(args) => {
238 write!(f, "format-prompt --provider={}", args.provider)
239 }
240 Command::Predict(args) => match &args.provider {
241 Some(provider) => write!(f, "predict --provider={}", provider),
242 None => write!(f, "predict"),
243 },
244 Command::ParseOutput => write!(f, "parse-output"),
245 Command::Score(args) => match &args.provider {
246 Some(provider) => write!(f, "score --provider={}", provider),
247 None => write!(f, "score"),
248 },
249 Command::Distill => write!(f, "distill"),
250 Command::Eval(args) => match &args.predict.provider {
251 Some(provider) => write!(f, "eval --provider={}", provider),
252 None => write!(f, "eval"),
253 },
254 Command::Synthesize(args) => {
255 write!(f, "synthesize --repos {}", args.repos.join(" "))
256 }
257 Command::Clean => write!(f, "clean"),
258 Command::SplitCommit(_) => write!(f, "split-commit"),
259 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
260 Command::Split(_) => write!(f, "split"),
261 Command::FilterLanguages(_) => write!(f, "filter-languages"),
262 Command::ImportBatch(args) => {
263 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
264 }
265 Command::Qa(_) => {
266 write!(f, "qa")
267 }
268 Command::Repair(_) => {
269 write!(f, "repair")
270 }
271 Command::PrintZetaFormats => {
272 write!(f, "print-zeta-formats")
273 }
274 }
275 }
276}
277
// `ep read` has no flags of its own; it relies on the global inputs/output
// arguments, and appends the input-specifier reference to its --help output.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
281
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Provider whose prompt format to generate; defaults to zeta2 with the
    // default ZetaFormat (see `impl Default for PredictionProvider`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
287
// Arguments shared by `ep predict` and `ep score` (and flattened into `ep eval`).
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Optional provider; when absent the command renders as plain
    // "predict"/"score" (see the Display impl for Command).
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Repetition count per example — see predict.rs for exact semantics.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
298
// Arguments for `ep eval`; accepts everything `predict` does plus reporting flags.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All prediction flags are shared via clap's flatten.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
310
/// Backend model for the `teacher` prediction providers.
// Parsed case-insensitively via FromStr (which also accepts aliases); see
// `model_name()` for the concrete API model identifier each variant maps to.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
318
319impl std::fmt::Display for TeacherBackend {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 match self {
322 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
323 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
324 TeacherBackend::Gpt52 => write!(f, "gpt52"),
325 }
326 }
327}
328
329impl std::str::FromStr for TeacherBackend {
330 type Err = anyhow::Error;
331
332 fn from_str(s: &str) -> Result<Self, Self::Err> {
333 match s.to_lowercase().as_str() {
334 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
335 "sonnet46" => Ok(TeacherBackend::Sonnet46),
336 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
337 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
338 _ => anyhow::bail!(
339 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
340 ),
341 }
342 }
343}
344
345impl TeacherBackend {
346 pub fn model_name(&self) -> &'static str {
347 match self {
348 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
349 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
350 TeacherBackend::Gpt52 => "gpt-5.2",
351 }
352 }
353}
354
// Backend used to produce predictions. Serialized and parsed in the
// `name` or `name:argument` string form — see the Display/FromStr/serde impls.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    // Zeta with a specific prompt format.
    Zeta2(ZetaFormat),
    // Teacher model; the non-batching variant presumably bypasses the
    // providers' batch APIs — confirm in predict.rs.
    Teacher(TeacherBackend),
    TeacherNonBatching(TeacherBackend),
    Repair,
}
365
366impl Default for PredictionProvider {
367 fn default() -> Self {
368 PredictionProvider::Zeta2(ZetaFormat::default())
369 }
370}
371
372impl std::fmt::Display for PredictionProvider {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 match self {
375 PredictionProvider::Sweep => write!(f, "sweep"),
376 PredictionProvider::Mercury => write!(f, "mercury"),
377 PredictionProvider::Zeta1 => write!(f, "zeta1"),
378 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
379 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
380 PredictionProvider::TeacherNonBatching(backend) => {
381 write!(f, "teacher-non-batching:{backend}")
382 }
383 PredictionProvider::Repair => write!(f, "repair"),
384 }
385 }
386}
387
388impl std::str::FromStr for PredictionProvider {
389 type Err = anyhow::Error;
390
391 fn from_str(s: &str) -> Result<Self, Self::Err> {
392 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
393
394 let provider_lower = provider.to_lowercase();
395 match provider_lower.as_str() {
396 "sweep" => Ok(PredictionProvider::Sweep),
397 "mercury" => Ok(PredictionProvider::Mercury),
398 "zeta1" => Ok(PredictionProvider::Zeta1),
399 "zeta2" => {
400 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
401 Ok(PredictionProvider::Zeta2(format))
402 }
403 "teacher" => {
404 let backend = arg
405 .map(|a| a.parse())
406 .transpose()?
407 .unwrap_or(TeacherBackend::default());
408 Ok(PredictionProvider::Teacher(backend))
409 }
410 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
411 let backend = arg
412 .map(|a| a.parse())
413 .transpose()?
414 .unwrap_or(TeacherBackend::default());
415 Ok(PredictionProvider::TeacherNonBatching(backend))
416 }
417 "repair" => Ok(PredictionProvider::Repair),
418 _ => {
419 anyhow::bail!(
420 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
421 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
422 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
423 Available zeta versions:\n{}",
424 ZetaFormat::options_as_string()
425 )
426 }
427 }
428 }
429}
430
431impl Serialize for PredictionProvider {
432 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
433 where
434 S: Serializer,
435 {
436 serializer.serialize_str(&self.to_string())
437 }
438}
439
440impl<'de> Deserialize<'de> for PredictionProvider {
441 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
442 where
443 D: Deserializer<'de>,
444 {
445 let s = String::deserialize(deserializer)?;
446 s.parse().map_err(serde::de::Error::custom)
447 }
448}
449
// Arguments for `ep synthesize` (generating eval examples from git history).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
468
// Arguments for `ep import-batch` (recovering batch results by ID).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
478
// LLM API whose batch results `ep import-batch` should fetch; matched on the
// command line (via ValueEnum) as `anthropic` / `openai`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
484
485impl EpArgs {
486 fn output_path(&self) -> Option<PathBuf> {
487 if self.in_place {
488 if self.inputs.len() == 1 {
489 self.inputs.first().cloned()
490 } else {
491 panic!("--in-place requires exactly one input file")
492 }
493 } else {
494 self.output.clone()
495 }
496 }
497}
498
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Passed as `Some(MIN_CAPTURE_VERSION)` to every pull_examples fetch in
// load_examples below.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
506
/// Deduplicates `examples` in two stages:
/// 1. drops examples whose cursor position text is identical to one already seen;
/// 2. clusters near-duplicates via MinHash/LSH over token n-grams of the cursor
///    position, then keeps at most `max_per_cluster` mutually diverse examples
///    from each cluster. Relative order of surviving examples is preserved.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    // HashSet::insert returns false for repeats, so retain keeps first sightings only.
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // LSH tuning: pairs above this Jaccard similarity should become candidates.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    // Recompute the hash count so it is an exact multiple of the band width.
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, over n-grams of its cursor-position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    // Union every example with each of its LSH candidates.
    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // From each cluster keep up to `max_per_cluster` maximally diverse members.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // swap_remove from the back so earlier kept indices stay valid; the final
    // reverse restores the original ordering.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
587
588fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
589 if members.len() <= k {
590 return members.to_vec();
591 }
592
593 let mut selected = vec![members[0]];
594 let mut min_dist: HashMap<usize, f64> = HashMap::default();
595 for &member in &members[1..] {
596 let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
597 min_dist.insert(member, dist);
598 }
599
600 while selected.len() < k {
601 let &best = min_dist
602 .iter()
603 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
604 .map(|(id, _)| id)
605 .expect("min_dist should not be empty when selected.len() < k");
606 selected.push(best);
607 min_dist.remove(&best);
608
609 let best_sig = &signatures[best];
610 for (member, current_min) in min_dist.iter_mut() {
611 let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
612 if dist < *current_min {
613 *current_min = dist;
614 }
615 }
616 }
617
618 selected
619}
620
621fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
622 let tokens: Vec<&str> = word_diff::tokenize(code)
623 .into_iter()
624 .filter(|t| !t.trim().is_empty())
625 .collect();
626
627 if tokens.len() < ngram_size {
628 return vec![tokens.join("\0")];
629 }
630
631 tokens
632 .windows(ngram_size)
633 .map(|window| window.join("\0"))
634 .collect()
635}
636
/// Collects the examples for a run from file inputs and Snowflake specifiers,
/// then applies the global filters in order: --offset, --name/--repo, resume
/// against an existing output file, --max-duplicates, and finally --limit.
///
/// `output_path` is only used for resume bookkeeping; nothing is written here.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Bucket each input by specifier kind; anything unrecognized is a file path.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            file_inputs.push(input.clone());
        }
    }

    // NOTE(review): `captured_after_timestamps` is collected above but never
    // fetched below (unlike rejected/requested/settled/rated), even though
    // INPUTS_HELP documents `captured-after:`. Looks like a dropped branch —
    // confirm whether a fetch_captured_examples_after call is missing.

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // Snowflake fetches only need to make up whatever --limit leaves uncovered
    // by the file inputs.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Default cap of 5000 rows per timestamp when no --limit is given.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name/--repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Refresh progress totals now that all filters have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
796
797fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
798 let mut hasher = collections::FxHasher::default();
799 spec.hash(&mut hasher);
800 hasher.finish()
801}
802
803fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
804 let file = match File::open(path) {
805 Ok(f) => f,
806 Err(_) => return,
807 };
808
809 let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
810
811 let reader = BufReader::new(file);
812 let mut kept_lines = Vec::new();
813 let mut kept_hashes = HashSet::default();
814
815 for line in reader.lines() {
816 let line = match line {
817 Ok(l) => l,
818 Err(_) => continue,
819 };
820
821 if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
822 let hash = spec_hash(&output_example.spec);
823 if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
824 let is_complete = match command {
825 Command::Qa(_) => output_example
826 .qa
827 .first()
828 .and_then(|q| q.as_ref())
829 .and_then(|q| q.confidence)
830 .is_some(),
831 Command::Repair(_) => output_example.predictions.iter().any(|p| {
832 p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
833 }),
834 _ => true,
835 };
836 if is_complete {
837 kept_hashes.insert(hash);
838 kept_lines.push(line);
839 }
840 }
841 }
842 }
843
844 let total = examples.len();
845 let already_processed = kept_hashes.len();
846
847 eprintln!(
848 "Resuming: {}/{} examples already processed",
849 already_processed, total
850 );
851
852 let file = OpenOptions::new()
853 .write(true)
854 .truncate(true)
855 .open(path)
856 .expect("Failed to open output file for rewriting");
857 let mut writer = BufWriter::new(file);
858 for line in &kept_lines {
859 writeln!(writer, "{}", line).expect("Failed to write to output file");
860 }
861 writer.flush().expect("Failed to flush output file");
862
863 examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
864}
865
866fn main() {
867 let args = EpArgs::parse();
868
869 if args.printenv {
870 ::util::shell_env::print_env();
871 return;
872 }
873
874 let output = args.output_path();
875
876 if args.markdown && output.is_none() {
877 eprintln!("--markdown requires -o to specify the output directory");
878 std::process::exit(1);
879 }
880
881 let command = match &args.command {
882 Some(cmd) => cmd.clone(),
883 None => {
884 EpArgs::command().print_help().unwrap();
885 return;
886 }
887 };
888
889 match &command {
890 Command::ImportBatch(import_args) => {
891 smol::block_on(async {
892 match import_args.provider {
893 BatchProvider::Anthropic => {
894 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
895 .expect("Failed to create Anthropic client");
896 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
897 eprintln!("Error importing Anthropic batches: {:?}", e);
898 std::process::exit(1);
899 }
900 }
901 BatchProvider::Openai => {
902 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
903 .expect("Failed to create OpenAI client");
904 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
905 eprintln!("Error importing OpenAI batches: {:?}", e);
906 std::process::exit(1);
907 }
908 }
909 }
910 println!(
911 "Successfully imported {} batch(es)",
912 import_args.batch_ids.len()
913 );
914 });
915 return;
916 }
917 Command::Clean => {
918 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
919 return;
920 }
921 Command::PrintZetaFormats => {
922 use strum::IntoEnumIterator as _;
923 for format in ZetaFormat::iter() {
924 println!("{}", format.to_string().to_lowercase());
925 }
926 return;
927 }
928
929 Command::Synthesize(synth_args) => {
930 let output_dir = if let Some(output_dir) = args.output {
931 output_dir
932 } else {
933 let default_output_dir = env::current_dir()
934 .unwrap()
935 .join("crates/edit_prediction_cli/evals-generated");
936 if default_output_dir.parent().unwrap().exists() {
937 std::fs::create_dir(&default_output_dir).ok();
938 default_output_dir
939 } else {
940 panic!("output dir is required");
941 }
942 };
943 let config = SynthesizeConfig {
944 repo_urls: synth_args.repos.clone(),
945 count: synth_args.count,
946 max_commits: synth_args.max_commits,
947 output_dir,
948 fresh: synth_args.fresh,
949 };
950 smol::block_on(async {
951 if let Err(e) = run_synthesize(config).await {
952 eprintln!("Error: {:?}", e);
953 std::process::exit(1);
954 }
955 });
956 return;
957 }
958 Command::SplitCommit(split_commit_args) => {
959 if let Err(error) = split_commit::run_split_commit(
960 split_commit_args,
961 &args.inputs,
962 output.as_ref(),
963 args.failed,
964 ) {
965 eprintln!("{error:#}");
966 std::process::exit(1);
967 }
968 return;
969 }
970 Command::TruncatePatch(truncate_args) => {
971 if let Err(error) =
972 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
973 {
974 eprintln!("{error:#}");
975 std::process::exit(1);
976 }
977 return;
978 }
979 Command::Split(split_args) => {
980 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
981 eprintln!("{error:#}");
982 std::process::exit(1);
983 }
984 return;
985 }
986 Command::FilterLanguages(filter_args) => {
987 if let Err(error) =
988 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
989 {
990 eprintln!("{error:#}");
991 std::process::exit(1);
992 }
993 return;
994 }
995
996 _ => {}
997 }
998
999 let http_client = Arc::new(ReqwestClient::new());
1000 let app = gpui_platform::headless().with_http_client(http_client);
1001
1002 app.run(move |cx| {
1003 let app_state = Arc::new(headless::init(cx));
1004 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
1005
1006 cx.spawn(async move |cx| {
1007 let result = async {
1008 let examples = load_examples(
1009 app_state.client.http_client(),
1010 &args,
1011 output.as_ref(),
1012 cx.background_executor().clone(),
1013 )
1014 .await?;
1015
1016 match &command {
1017 Command::Predict(args) | Command::Score(args) => {
1018 predict::sync_batches(args.provider.as_ref()).await?;
1019 }
1020 Command::Eval(args) => {
1021 predict::sync_batches(args.predict.provider.as_ref()).await?;
1022 }
1023 Command::Qa(args) => {
1024 qa::sync_batches(args).await?;
1025 }
1026 Command::Repair(args) => {
1027 repair::sync_batches(args).await?;
1028 }
1029 _ => (),
1030 }
1031
1032 let failfast_on_single_example = examples.len() == 1;
1033
1034 // For --markdown mode, create the output directory if it doesn't exist
1035 if args.markdown {
1036 let dir = output.as_ref().expect("--markdown requires -o");
1037 if !dir.exists() {
1038 std::fs::create_dir_all(dir)
1039 .expect("Failed to create markdown output directory");
1040 }
1041 }
1042
1043 // Set up JSONL output writer (not used in markdown mode)
1044 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1045 let mut in_place_temp_path: Option<PathBuf> = None;
1046 if !args.markdown
1047 && let Some(output_path) = output.as_ref()
1048 {
1049 let write_path = if args.in_place {
1050 let temp = output_path.with_extension("jsonl.tmp");
1051 in_place_temp_path = Some(temp.clone());
1052 temp
1053 } else {
1054 output_path.clone()
1055 };
1056
1057 let file = OpenOptions::new()
1058 .create(true)
1059 .write(true)
1060 .truncate(args.in_place)
1061 .append(!args.in_place)
1062 .open(&write_path)
1063 .expect("Failed to open output file");
1064
1065 let mut writer = BufWriter::new(file);
1066 let (sender, mut receiver) = mpsc::unbounded::<String>();
1067 cx.background_spawn(async move {
1068 while let Some(line) = receiver.next().await {
1069 writeln!(writer, "{}", line).expect("Failed to write example");
1070 writer.flush().expect("Failed to flush output");
1071 }
1072 })
1073 .detach();
1074 output_sender = Some(sender);
1075 }
1076
1077 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1078 let finished_examples = Mutex::new(Vec::new());
1079
1080 let mut tasks = Vec::new();
1081 for _ in 0..args.max_parallelism {
1082 tasks.push(async {
1083 loop {
1084 let Some(mut repo_examples) =
1085 grouped_examples.lock().unwrap().pop_front()
1086 else {
1087 break;
1088 };
1089 for example in &mut repo_examples {
1090 let example_progress =
1091 Progress::global().start_group(&example.spec.name);
1092
1093 let result = async {
1094 match &command {
1095 Command::Read(_) => {}
1096 Command::LoadProject => {
1097 run_load_project(
1098 example,
1099 app_state.clone(),
1100 &example_progress,
1101 cx.clone(),
1102 )
1103 .await?;
1104 }
1105 Command::Context => {
1106 run_context_retrieval(
1107 example,
1108 app_state.clone(),
1109 &example_progress,
1110 cx.clone(),
1111 )
1112 .await?;
1113 }
1114 Command::FormatPrompt(args) => {
1115 run_format_prompt(
1116 example,
1117 args,
1118 app_state.clone(),
1119 &example_progress,
1120 cx.clone(),
1121 )
1122 .await?;
1123 }
1124 Command::Predict(args) => {
1125 run_prediction(
1126 example,
1127 args,
1128 app_state.clone(),
1129 &example_progress,
1130 cx.clone(),
1131 )
1132 .await?;
1133 }
1134 Command::ParseOutput => {
1135 parse_output::run_parse_output(example)?;
1136 }
1137 Command::Distill => {
1138 run_distill(example).await?;
1139 }
1140 Command::Score(args) => {
1141 run_scoring(
1142 example,
1143 args,
1144 app_state.clone(),
1145 &example_progress,
1146 cx.clone(),
1147 )
1148 .await?;
1149 }
1150 Command::Eval(args) => {
1151 run_scoring(
1152 example,
1153 &args.predict,
1154 app_state.clone(),
1155 &example_progress,
1156 cx.clone(),
1157 )
1158 .await?;
1159 }
1160 Command::Qa(args) => {
1161 qa::run_qa(example, args, &example_progress).await?;
1162 }
1163 Command::Repair(args) => {
1164 repair::run_repair(example, args, &example_progress)
1165 .await?;
1166 }
1167 Command::Clean
1168 | Command::Synthesize(_)
1169 | Command::SplitCommit(_)
1170 | Command::Split(_)
1171 | Command::TruncatePatch(_)
1172 | Command::FilterLanguages(_)
1173 | Command::ImportBatch(_)
1174 | Command::PrintZetaFormats => {
1175 unreachable!()
1176 }
1177 }
1178 anyhow::Ok(())
1179 }
1180 .await;
1181
1182 let failed = if let Err(error) = result {
1183 handle_error(
1184 error,
1185 &args,
1186 &command,
1187 &app_state,
1188 failfast_on_single_example,
1189 &example,
1190 )
1191 .await;
1192 true
1193 } else {
1194 false
1195 };
1196
1197 let should_write = !failed || args.failed == FailedHandling::Keep;
1198 if should_write {
1199 if args.markdown {
1200 let markdown_dir =
1201 output.as_ref().expect("--markdown requires -o");
1202 let filename = format!("{}.md", example.spec.filename());
1203 let path = markdown_dir.join(&filename);
1204 let markdown = example.spec.to_markdown();
1205 std::fs::write(&path, &markdown)
1206 .expect("Failed to write markdown file");
1207 } else if let Some(ref mut sender) = output_sender.clone() {
1208 let line = serde_json::to_string(&example).unwrap();
1209 sender
1210 .send(line)
1211 .await
1212 .expect("Failed to send to output writer");
1213 } else if args.output.is_none()
1214 && !matches!(command, Command::Eval(_))
1215 {
1216 let line = serde_json::to_string(&example).unwrap();
1217 println!("{}", line);
1218 }
1219 }
1220 }
1221
1222 let project = repo_examples
1223 .iter()
1224 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1225
1226 if let Some(project) = project {
1227 let mut cx = cx.clone();
1228
1229 let shutdown_task: Task<()> =
1230 project.update(&mut cx, |project, cx| {
1231 let lsp_store = project.lsp_store();
1232 lsp_store.update(cx, |lsp_store, cx| {
1233 lsp_store.shutdown_all_language_servers(cx)
1234 })
1235 });
1236
1237 shutdown_task.await;
1238
1239 if let Some(ep_store) =
1240 cx.update(|cx| EditPredictionStore::try_global(cx))
1241 {
1242 ep_store.update(&mut cx, |store, _| {
1243 store.remove_project(&project);
1244 });
1245 }
1246 }
1247
1248 for example in &mut repo_examples {
1249 example.state.take();
1250 }
1251 finished_examples
1252 .lock()
1253 .unwrap()
1254 .extend_from_slice(&repo_examples);
1255 }
1256 });
1257 }
1258 futures::future::join_all(tasks).await;
1259
1260 Progress::global().finalize();
1261
1262 match &command {
1263 Command::Predict(args) | Command::Score(args) => {
1264 predict::sync_batches(args.provider.as_ref()).await?;
1265 }
1266 Command::Eval(args) => {
1267 predict::sync_batches(args.predict.provider.as_ref()).await?;
1268 }
1269 Command::Qa(args) => {
1270 qa::sync_batches(args).await?;
1271 }
1272 Command::Repair(args) => {
1273 repair::sync_batches(args).await?;
1274 }
1275 _ => (),
1276 }
1277
1278 match &command {
1279 Command::Eval(args) => {
1280 let examples = finished_examples.lock().unwrap();
1281 score::print_report(&examples, args.verbose);
1282 if let Some(summary_path) = &args.summary_json {
1283 score::write_summary_json(&examples, summary_path)?;
1284 }
1285 }
1286 Command::Repair(args) => {
1287 let examples = finished_examples.lock().unwrap();
1288 repair::print_report(&examples, args.confidence_threshold);
1289 }
1290 _ => (),
1291 };
1292
1293 // For --in-place, atomically rename temp file to original
1294 if let Some(temp_path) = &in_place_temp_path {
1295 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1296 std::fs::rename(temp_path, final_path)
1297 .expect("Failed to rename temp file to final output");
1298 }
1299
1300 anyhow::Ok(())
1301 }
1302 .await;
1303
1304 if let Err(e) = result {
1305 panic!("Fatal error: {:?}", e);
1306 }
1307
1308 let _ = cx.update(|cx| cx.quit());
1309 })
1310 .detach();
1311 });
1312}
1313
1314async fn handle_error(
1315 error: anyhow::Error,
1316 args: &EpArgs,
1317 command: &Command,
1318 app_state: &Arc<headless::EpAppState>,
1319 failfast_on_single_example: bool,
1320 example: &Example,
1321) {
1322 Progress::global().increment_failed();
1323
1324 let msg;
1325 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1326 let example_name = example.spec.filename();
1327
1328 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1329 app_state
1330 .fs
1331 .write(
1332 &failed_example_path,
1333 &serde_json::to_vec_pretty(&example).unwrap(),
1334 )
1335 .await
1336 .unwrap();
1337 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1338 app_state
1339 .fs
1340 .write(&err_path, format!("{error:?}").as_bytes())
1341 .await
1342 .unwrap();
1343
1344 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1345 let mut file = OpenOptions::new()
1346 .create(true)
1347 .append(true)
1348 .open(&failed_jsonl_path)
1349 .expect("Failed to open failed.jsonl");
1350 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1351 .expect("Failed to write to failed.jsonl");
1352
1353 let cursor_path = match example.repo_name() {
1354 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1355 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1356 };
1357 msg = format!(
1358 indoc::indoc! {"
1359 While processing \"{}\":
1360
1361 \x1b[31m{:?}\x1b[0m
1362
1363 Example: \x1b[36m{}\x1b[0m
1364 Error file: \x1b[36m{}\x1b[0m
1365 Cursor file: \x1b[36m{}\x1b[0m
1366 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1367 "},
1368 example.spec.name,
1369 error,
1370 failed_example_path.display(),
1371 err_path.display(),
1372 cursor_path.display(),
1373 command,
1374 failed_example_path.display(),
1375 );
1376 } else {
1377 msg = format!(
1378 indoc::indoc! {"
1379 While processing \"{}\":
1380
1381 \x1b[31m{:?}\x1b[0m
1382 "},
1383 example.spec.name, error
1384 );
1385 }
1386
1387 if args.failfast || failfast_on_single_example {
1388 Progress::global().finalize();
1389 panic!("{}", msg);
1390 } else {
1391 log::error!("{}", msg);
1392 }
1393}