1mod anthropic_client;
2mod distill;
3mod example;
4mod filter_languages;
5mod format_prompt;
6mod git;
7mod headless;
8mod load_project;
9mod metrics;
10mod openai_client;
11mod parse_output;
12mod paths;
13mod predict;
14mod progress;
15mod prompt_assets;
16mod pull_examples;
17mod qa;
18mod reorder_patch;
19mod repair;
20mod retrieve_context;
21mod reversal_tracking;
22mod score;
23mod split_commit;
24mod split_dataset;
25
26mod synthesize;
27mod truncate_expected_patch;
28mod word_diff;
29use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
30use collections::{HashMap, HashSet};
31use edit_prediction::EditPredictionStore;
32use futures::channel::mpsc;
33use futures::{SinkExt as _, StreamExt as _};
34use gaoya::minhash::{
35 MinHashIndex, MinHasher, MinHasher32, calculate_minhash_params, compute_minhash_similarity,
36};
37use gpui::{AppContext as _, BackgroundExecutor, Task};
38use zeta_prompt::ZetaFormat;
39
40use reqwest_client::ReqwestClient;
41use serde::{Deserialize, Deserializer, Serialize, Serializer};
42use std::env;
43use std::fmt::Display;
44use std::fs::{File, OpenOptions};
45use std::hash::{Hash, Hasher};
46use std::io::{BufRead, BufReader, BufWriter, Write};
47use std::sync::Mutex;
48use std::{path::PathBuf, sync::Arc};
49
50use crate::distill::run_distill;
51use crate::example::{Example, group_examples_by_repo, read_example_files};
52use crate::filter_languages::{FilterLanguagesArgs, run_filter_languages};
53use crate::format_prompt::run_format_prompt;
54use crate::load_project::run_load_project;
55use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
56use crate::predict::run_prediction;
57use crate::progress::Progress;
58use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
59use crate::retrieve_context::run_context_retrieval;
60use crate::score::run_scoring;
61use crate::split_commit::SplitCommitArgs;
62use crate::split_dataset::SplitArgs;
63use crate::synthesize::{SynthesizeConfig, run_synthesize};
64use crate::truncate_expected_patch::TruncatePatchArgs;
65
// Top-level CLI arguments for the `ep` tool. Most flags are `global = true`
// so they can be written before or after the subcommand.
//
// NOTE: field-level `///` doc comments are rendered by clap into --help text;
// plain `//` comments (like these) are deliberately invisible to clap.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // Print the shell environment and exit immediately (handled first in main).
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Upper bound on concurrently processed examples.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    /// The limit for the number of examples to process
    /// Default is unlimited for processing local datasets, 5000 when pulling from snowflake
    #[clap(long, global = true)]
    limit: Option<usize>,
    // Number of leading examples to skip; consumed by file inputs first, with
    // any remainder forwarded to the Snowflake fetches (see load_examples).
    #[clap(long, global = true)]
    offset: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    /// Deduplicate by cursor position and keep at most this many examples per cluster
    #[clap(long, global = true)]
    max_duplicates: Option<usize>,
    #[command(subcommand)]
    command: Option<Command>,
    /// Input file paths
    #[clap(global = true)]
    inputs: Vec<PathBuf>,
    // Output destination; interacts with --in-place via EpArgs::output_path.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Rewrite the (single) input file in place instead of writing to --output.
    #[arg(long, short, global = true)]
    in_place: bool,
    // NOTE(review): presumably aborts the run on the first failure — confirm
    // in the command implementations (not visible in this file chunk).
    #[arg(long, short, global = true)]
    failfast: bool,
    /// How to handle failed examples in output: keep them or skip them.
    /// Failed examples are always logged to the run's failed directory.
    #[arg(long, global = true, default_value = "keep")]
    failed: FailedHandling,
    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
    /// where one .md file per example will be written (named after each example).
    #[arg(long, short, global = true)]
    markdown: bool,
}
108
/// Controls whether failed examples are included in the main output.
/// Failed examples are always logged to the run's failed/ directory regardless of this setting.
// CLI value strings are the kebab-cased variant names ("keep", "skip",
// "skip-no-files") — clap ValueEnum's default renaming; confirm via --help.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)]
pub enum FailedHandling {
    /// Include failed examples in the main output (default)
    #[default]
    Keep,
    /// Exclude failed examples from the main output
    Skip,
    /// Skip writing files
    SkipNoFiles,
}
121
/// Extended help text shown after `ep read --help` (attached via the
/// `after_help` attribute on [`ReadArgs`]). Documents the special input
/// specifiers and the Snowflake environment variables they require.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

    path
        Path to an example(s) file (.md, .json, or .jsonl)

    captured-after:{timestamp}
        Fetch captured examples from Snowflake after the given RFC3339 timestamp.
        These are examples captured via the "Capture Edit Prediction Example" action.

    rejected-after:{timestamp}
        Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that were shown to users but rejected (useful for DPO training).

    settled-after:{timestamp}
        Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
        These are examples from the edit prediction settled stream.

    rated-after:{timestamp}
        Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
        These are predictions that users explicitly rated as positive or negative via the
        rate completions modal. Only zeta2 predictions are included.
        - Positive ratings: output becomes expected_patches
        - Negative ratings: output becomes rejected_patch

    rated-positive-after:{timestamp}
        Same as rated-after, but only fetches positively rated predictions.

    rated-negative-after:{timestamp}
        Same as rated-after, but only fetches negatively rated predictions.

    Required environment variables to connect to Snowflake:
        EP_SNOWFLAKE_API_KEY
        EP_SNOWFLAKE_BASE_URL

    Optional:
        EP_SNOWFLAKE_ROLE

Examples:

    # Read examples from a file
    ep read examples.jsonl -o output.jsonl

    # Read captured examples after a timestamp
    ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl

    # Read rejected predictions for DPO training
    ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl

    # Read user-rated predictions
    ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl

    # Read settled stream examples
    ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl

    # Read only positively rated predictions
    ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl

    # Read only negatively rated predictions
    ep read rated-negative-after:2025-01-01T00:00:00Z -o negative.jsonl

    # Mix multiple input sources
    ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;
186
// Subcommands of `ep`. The `///` doc comments double as clap help text, so
// keep additions in that style; `//` comments stay internal.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Read examples from files or fetch from Snowflake, output as .jsonl
    Read(ReadArgs),
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
    /// Requires format-prompt to have been run first. Uses provider from prompt.
    ParseOutput,
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(EvalArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
    /// Truncate expected patch by the given criteria
    TruncatePatch(TruncatePatchArgs),
    /// Split a JSONL dataset into multiple files (stratified by repository_url if present)
    Split(SplitArgs),
    /// Filter a JSONL dataset by programming language (based on cursor_path extension)
    FilterLanguages(FilterLanguagesArgs),
    /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
    ImportBatch(ImportBatchArgs),
    /// Assess the quality of predictions using LLM-as-a-judge
    Qa(qa::QaArgs),
    /// Repair predictions that received poor QA scores by generating improved predictions
    Repair(repair::RepairArgs),
    /// Print all valid zeta formats (lowercase, one per line)
    PrintZetaFormats,
}
230
231impl Display for Command {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Command::Read(_) => write!(f, "read"),
235 Command::LoadProject => write!(f, "load-project"),
236 Command::Context => write!(f, "context"),
237 Command::FormatPrompt(args) => {
238 write!(f, "format-prompt --provider={}", args.provider)
239 }
240 Command::Predict(args) => match &args.provider {
241 Some(provider) => write!(f, "predict --provider={}", provider),
242 None => write!(f, "predict"),
243 },
244 Command::ParseOutput => write!(f, "parse-output"),
245 Command::Score(args) => match &args.provider {
246 Some(provider) => write!(f, "score --provider={}", provider),
247 None => write!(f, "score"),
248 },
249 Command::Distill => write!(f, "distill"),
250 Command::Eval(args) => match &args.predict.provider {
251 Some(provider) => write!(f, "eval --provider={}", provider),
252 None => write!(f, "eval"),
253 },
254 Command::Synthesize(args) => {
255 write!(f, "synthesize --repos {}", args.repos.join(" "))
256 }
257 Command::Clean => write!(f, "clean"),
258 Command::SplitCommit(_) => write!(f, "split-commit"),
259 Command::TruncatePatch(_) => write!(f, "truncate-patch"),
260 Command::Split(_) => write!(f, "split"),
261 Command::FilterLanguages(_) => write!(f, "filter-languages"),
262 Command::ImportBatch(args) => {
263 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
264 }
265 Command::Qa(_) => {
266 write!(f, "qa")
267 }
268 Command::Repair(_) => {
269 write!(f, "repair")
270 }
271 Command::PrintZetaFormats => {
272 write!(f, "print-zeta-formats")
273 }
274 }
275 }
276}
277
// `ep read` has no flags of its own; the global args (inputs, -o, --limit, …)
// do all the work. INPUTS_HELP is appended to this subcommand's --help output.
#[derive(Debug, Args, Clone)]
#[command(after_help = INPUTS_HELP)]
struct ReadArgs {}
281
// Arguments for `ep format-prompt`.
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Which provider's prompt format to generate; defaults to zeta2 with the
    // default zeta format (see `PredictionProvider::default`).
    #[clap(long, short('p'), default_value_t = PredictionProvider::default())]
    provider: PredictionProvider,
}
287
// Arguments shared by `ep predict`, `ep score`, and (flattened) `ep eval`.
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // NOTE(review): when None, provider selection presumably falls back to a
    // per-command default — confirm in the predict/score implementations.
    #[clap(long, short('p'))]
    provider: Option<PredictionProvider>,
    // Number of predictions requested per example.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
    /// Only use cached responses, don't queue new requests for batching
    #[clap(long)]
    cache_only: bool,
}
298
// Arguments for `ep eval`: the prediction settings plus reporting options.
#[derive(Debug, Args, Clone)]
struct EvalArgs {
    // All of PredictArgs is inlined into eval's flag set.
    #[clap(flatten)]
    predict: PredictArgs,
    /// Path to write summary scores as JSON
    #[clap(long)]
    summary_json: Option<PathBuf>,
    /// Print all individual example lines (default: up to 20)
    #[clap(long)]
    verbose: bool,
}
310
/// Backend model selection for the `teacher` / `teacher-non-batching`
/// prediction providers. Parsed through the custom `FromStr` impl below
/// (which also accepts aliases like "claude" and "gpt"), not clap's ValueEnum.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
pub enum TeacherBackend {
    Sonnet46,
    #[default]
    Sonnet45,
    Gpt52,
}
318
319impl std::fmt::Display for TeacherBackend {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 match self {
322 TeacherBackend::Sonnet46 => write!(f, "sonnet46"),
323 TeacherBackend::Sonnet45 => write!(f, "sonnet45"),
324 TeacherBackend::Gpt52 => write!(f, "gpt52"),
325 }
326 }
327}
328
329impl std::str::FromStr for TeacherBackend {
330 type Err = anyhow::Error;
331
332 fn from_str(s: &str) -> Result<Self, Self::Err> {
333 match s.to_lowercase().as_str() {
334 "sonnet45" | "sonnet" | "claude" => Ok(TeacherBackend::Sonnet45),
335 "sonnet46" => Ok(TeacherBackend::Sonnet46),
336 "gpt52" | "gpt" | "openai" => Ok(TeacherBackend::Gpt52),
337 "v0114180editableregion" => Ok(TeacherBackend::Sonnet45),
338 _ => anyhow::bail!(
339 "unknown teacher backend `{s}`. Valid options: sonnet45, sonnet46, gpt52"
340 ),
341 }
342 }
343}
344
345impl TeacherBackend {
346 pub fn model_name(&self) -> &'static str {
347 match self {
348 TeacherBackend::Sonnet45 => "claude-sonnet-4-5",
349 TeacherBackend::Sonnet46 => "claude-sonnet-4-6",
350 TeacherBackend::Gpt52 => "gpt-5.2",
351 }
352 }
353}
354
/// The model/service used to produce a prediction for an example.
/// Serialized and parsed as `name` or `name:arg` — see the `Display`,
/// `FromStr`, and serde impls below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    /// Zeta2 with a specific prompt format.
    Zeta2(ZetaFormat),
    /// Zeta-format prompts; per the name, served via Baseten — confirm.
    Baseten(ZetaFormat),
    /// Teacher model (default, presumably batching path — see the
    /// non-batching variant below).
    Teacher(TeacherBackend),
    /// Teacher model variant that, per the name, skips request batching.
    TeacherNonBatching(TeacherBackend),
    Repair,
}
366
367impl Default for PredictionProvider {
368 fn default() -> Self {
369 PredictionProvider::Zeta2(ZetaFormat::default())
370 }
371}
372
373impl std::fmt::Display for PredictionProvider {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 match self {
376 PredictionProvider::Sweep => write!(f, "sweep"),
377 PredictionProvider::Mercury => write!(f, "mercury"),
378 PredictionProvider::Zeta1 => write!(f, "zeta1"),
379 PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
380 PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
381 PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
382 PredictionProvider::TeacherNonBatching(backend) => {
383 write!(f, "teacher-non-batching:{backend}")
384 }
385 PredictionProvider::Repair => write!(f, "repair"),
386 }
387 }
388}
389
390impl std::str::FromStr for PredictionProvider {
391 type Err = anyhow::Error;
392
393 fn from_str(s: &str) -> Result<Self, Self::Err> {
394 let (provider, arg) = s.split_once(':').map_or((s, None), |(p, a)| (p, Some(a)));
395
396 let provider_lower = provider.to_lowercase();
397 match provider_lower.as_str() {
398 "sweep" => Ok(PredictionProvider::Sweep),
399 "mercury" => Ok(PredictionProvider::Mercury),
400 "zeta1" => Ok(PredictionProvider::Zeta1),
401 "zeta2" => {
402 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
403 Ok(PredictionProvider::Zeta2(format))
404 }
405 "teacher" => {
406 let backend = arg
407 .map(|a| a.parse())
408 .transpose()?
409 .unwrap_or(TeacherBackend::default());
410 Ok(PredictionProvider::Teacher(backend))
411 }
412 "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
413 let backend = arg
414 .map(|a| a.parse())
415 .transpose()?
416 .unwrap_or(TeacherBackend::default());
417 Ok(PredictionProvider::TeacherNonBatching(backend))
418 }
419 "repair" => Ok(PredictionProvider::Repair),
420 "baseten" => {
421 let format = arg
422 .map(ZetaFormat::parse)
423 .transpose()?
424 .unwrap_or(ZetaFormat::default());
425 Ok(PredictionProvider::Baseten(format))
426 }
427 _ => {
428 anyhow::bail!(
429 "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
430 For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
431 For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
432 Available zeta versions:\n{}",
433 ZetaFormat::options_as_string()
434 )
435 }
436 }
437 }
438}
439
440impl Serialize for PredictionProvider {
441 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
442 where
443 S: Serializer,
444 {
445 serializer.serialize_str(&self.to_string())
446 }
447}
448
449impl<'de> Deserialize<'de> for PredictionProvider {
450 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
451 where
452 D: Deserializer<'de>,
453 {
454 let s = String::deserialize(deserializer)?;
455 s.parse().map_err(serde::de::Error::custom)
456 }
457}
458
// Arguments for `ep synthesize` (example generation from repo history).
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URLs (git@github.com:owner/repo or https://...)
    #[clap(long, required = true, num_args = 1..)]
    repos: Vec<String>,

    /// Number of examples to generate per repository
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan per repository before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}
477
// Arguments for `ep import-batch` (recovering cached LLM batch results).
#[derive(Debug, Args, Clone)]
struct ImportBatchArgs {
    /// Batch IDs to import (e.g., msgbatch_xxx for Anthropic, batch_xxx for OpenAI)
    #[clap(long, required = true, num_args = 1..)]
    batch_ids: Vec<String>,
    /// Which provider's batches to import (anthropic or openai)
    #[clap(long, default_value = "anthropic")]
    provider: BatchProvider,
}
487
// Which LLM provider's batch API an `import-batch` invocation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
enum BatchProvider {
    Anthropic,
    Openai,
}
493
494impl EpArgs {
495 fn output_path(&self) -> Option<PathBuf> {
496 if self.in_place {
497 if self.inputs.len() == 1 {
498 self.inputs.first().cloned()
499 } else {
500 panic!("--in-place requires exactly one input file")
501 }
502 } else {
503 self.output.clone()
504 }
505 }
506}
507
/// Minimum Zed version required for Snowflake queries.
/// This version introduced the current request schema with predicted edits in the edit
/// history, and open source repos distinguished.
// Corresponds to 0.224.1. NOTE(review): assumes `MinCaptureVersion` only
// compares minor/patch (major implied) — confirm in `pull_examples`.
const MIN_CAPTURE_VERSION: pull_examples::MinCaptureVersion = pull_examples::MinCaptureVersion {
    minor: 224,
    patch: 1,
};
515
/// Removes duplicate examples in two passes: an exact pass on
/// `spec.cursor_position`, then a MinHash/LSH near-duplicate pass that keeps
/// at most `max_per_cluster` diverse representatives per similarity cluster.
fn deduplicate_examples(examples: &mut Vec<Example>, max_per_cluster: usize) {
    // Pass 1: drop examples whose cursor-position text is byte-identical to
    // one already seen (first occurrence wins).
    let total_before_exact = examples.len();
    let mut seen_positions = HashSet::default();
    examples.retain(|example| seen_positions.insert(example.spec.cursor_position.clone()));
    log::info!(
        "exact duplicate filter: {total_before_exact} examples → {} examples ({} removed)",
        examples.len(),
        total_before_exact - examples.len(),
    );

    // Pass 2: near-duplicate clustering. Two examples become candidate
    // duplicates when the estimated Jaccard similarity of their token n-gram
    // sets crosses this threshold.
    const JACCARD_THRESHOLD: f64 = 0.5;
    const NUM_HASHES: usize = 128;
    const TOKEN_NGRAM_SIZE: usize = 5;

    // gaoya derives LSH band geometry from the threshold; the effective hash
    // count is num_bands * band_width (may differ slightly from NUM_HASHES).
    let (num_bands, band_width) = calculate_minhash_params(JACCARD_THRESHOLD, NUM_HASHES);
    let num_hashes = num_bands * band_width;
    let minhasher = MinHasher32::new(num_hashes);
    let mut index: MinHashIndex<u32, usize> =
        MinHashIndex::new(num_bands, band_width, JACCARD_THRESHOLD);

    // One MinHash signature per example, over token n-grams of the
    // cursor-position text.
    let signatures: Vec<Vec<u32>> = examples
        .iter()
        .map(|example| {
            let shingles = code_token_ngrams(&example.spec.cursor_position, TOKEN_NGRAM_SIZE);
            minhasher.create_signature(shingles.iter())
        })
        .collect();

    for (id, signature) in signatures.iter().enumerate() {
        index.insert(id, signature.clone());
    }

    // Build clusters via union-find on LSH candidate pairs.
    let mut parent: Vec<usize> = (0..examples.len()).collect();

    // Union-find `find` with path halving.
    fn find(parent: &mut Vec<usize>, mut x: usize) -> usize {
        while parent[x] != x {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        x
    }

    for (id, signature) in signatures.iter().enumerate() {
        for candidate in index.query_owned(signature) {
            let (a, b) = (find(&mut parent, id), find(&mut parent, candidate));
            if a != b {
                parent[a] = b;
            }
        }
    }

    // Group example indices by their union-find root.
    let mut clusters: HashMap<usize, Vec<usize>> = HashMap::default();
    for id in 0..examples.len() {
        clusters.entry(find(&mut parent, id)).or_default().push(id);
    }

    // Within each cluster keep up to `max_per_cluster` maximally-dissimilar
    // members.
    let mut keep: HashSet<usize> = HashSet::default();
    for members in clusters.values() {
        let selected = greedy_max_min_diverse(members, &signatures, max_per_cluster);
        keep.extend(selected);
    }

    let total = examples.len();
    let mut kept_indices: Vec<usize> = keep.into_iter().collect();
    kept_indices.sort();

    // Remove from the highest index downward so swap_remove never moves an
    // element across a still-pending (smaller) index; reverse afterwards to
    // restore the original relative order.
    let mut retained = Vec::with_capacity(kept_indices.len());
    for index in kept_indices.into_iter().rev() {
        retained.push(examples.swap_remove(index));
    }
    retained.reverse();

    *examples = retained;
    log::info!(
        "near-duplicate filter: {total} examples → {} examples ({} removed)",
        examples.len(),
        total - examples.len(),
    );
}
596
/// Selects up to `k` cluster members by greedy farthest-point sampling:
/// seed with the first member, then repeatedly pick the member whose minimum
/// distance (1 − estimated MinHash similarity) to the selected set is
/// greatest.
fn greedy_max_min_diverse(members: &[usize], signatures: &[Vec<u32>], k: usize) -> Vec<usize> {
    if members.len() <= k {
        return members.to_vec();
    }

    // Track, for every unselected member, its distance to the nearest
    // already-selected member (initially just the seed).
    let mut selected = vec![members[0]];
    let mut min_dist: HashMap<usize, f64> = HashMap::default();
    for &member in &members[1..] {
        let dist = 1.0 - compute_minhash_similarity(&signatures[selected[0]], &signatures[member]);
        min_dist.insert(member, dist);
    }

    while selected.len() < k {
        // Pick the member farthest from everything selected so far.
        // NOTE(review): ties are broken by HashMap iteration order, so
        // selection is not fully deterministic across runs — confirm whether
        // that matters for dataset reproducibility.
        let &best = min_dist
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(id, _)| id)
            .expect("min_dist should not be empty when selected.len() < k");
        selected.push(best);
        min_dist.remove(&best);

        // The newly selected member may now be someone's nearest neighbor;
        // refresh the cached minimum distances.
        let best_sig = &signatures[best];
        for (member, current_min) in min_dist.iter_mut() {
            let dist = 1.0 - compute_minhash_similarity(best_sig, &signatures[*member]);
            if dist < *current_min {
                *current_min = dist;
            }
        }
    }

    selected
}
629
630fn code_token_ngrams(code: &str, ngram_size: usize) -> Vec<String> {
631 let tokens: Vec<&str> = word_diff::tokenize(code)
632 .into_iter()
633 .filter(|t| !t.trim().is_empty())
634 .collect();
635
636 if tokens.len() < ngram_size {
637 return vec![tokens.join("\0")];
638 }
639
640 tokens
641 .windows(ngram_size)
642 .map(|window| window.join("\0"))
643 .collect()
644}
645
/// Assembles the example set for a run: reads file inputs, fetches any
/// Snowflake-backed inputs (captured/rejected/requested/settled/rated after a
/// timestamp), then applies the global filters — --name, --repo, resume from
/// an existing output, --max-duplicates, and --limit, in that order.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
    output_path: Option<&PathBuf>,
    background_executor: BackgroundExecutor,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into Snowflake specifiers and plain file paths.
    let mut captured_after_timestamps = Vec::new();
    let mut rejected_after_timestamps = Vec::new();
    let mut requested_after_timestamps = Vec::new();
    let mut settled_after_timestamps = Vec::new();
    let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
        Vec::new();
    let mut file_inputs = Vec::new();

    for input in &args.inputs {
        let input_string = input.to_string_lossy();
        if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
            captured_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_rejected_after_input(input_string.as_ref())
        {
            rejected_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) =
            pull_examples::parse_requested_after_input(input_string.as_ref())
        {
            requested_after_timestamps.push(timestamp.to_string());
        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
            settled_after_timestamps.push(timestamp.to_string());
        } else if let Some((timestamp, rating_filter)) =
            pull_examples::parse_rated_after_input(input_string.as_ref())
        {
            rated_after_inputs.push((timestamp.to_string(), rating_filter));
        } else {
            // Anything that isn't a recognized specifier is treated as a path.
            file_inputs.push(input.clone());
        }
    }

    let mut examples = read_example_files(&file_inputs);

    // Apply offset to file examples first, then pass remaining offset to Snowflake.
    let file_example_count = examples.len();
    let remaining_offset = if let Some(offset) = args.offset {
        if offset >= file_example_count {
            examples.clear();
            offset - file_example_count
        } else {
            examples.splice(0..offset, []);
            0
        }
    } else {
        0
    };

    Progress::global().set_total_examples(examples.len());

    // Budget left for Snowflake after counting the file-sourced examples.
    let remaining_limit_for_snowflake =
        args.limit.map(|limit| limit.saturating_sub(examples.len()));

    if let Some(0) = remaining_limit_for_snowflake {
        log::info!(
            "skipping Snowflake inputs because --limit is already satisfied by example files"
        );
    } else {
        // Per-timestamp row cap; 5000 is the default when no --limit is set.
        // NOTE(review): the same `remaining_offset` is passed to each source
        // independently — confirm that is the intended offset semantics.
        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);

        if !rejected_after_timestamps.is_empty() {
            rejected_after_timestamps.sort();

            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
                http_client.clone(),
                &rejected_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rejected_examples);
        }

        if !requested_after_timestamps.is_empty() {
            requested_after_timestamps.sort();

            let mut requested_examples = pull_examples::fetch_requested_examples_after(
                http_client.clone(),
                &requested_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut requested_examples);
        }

        if !captured_after_timestamps.is_empty() {
            captured_after_timestamps.sort();

            let mut captured_examples = pull_examples::fetch_captured_examples_after(
                http_client.clone(),
                &captured_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut captured_examples);
        }

        if !settled_after_timestamps.is_empty() {
            settled_after_timestamps.sort();

            let mut settled_examples = fetch_settled_examples_after(
                http_client.clone(),
                &settled_after_timestamps,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor.clone(),
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut settled_examples);
        }

        if !rated_after_inputs.is_empty() {
            rated_after_inputs.sort();

            let mut rated_examples = pull_examples::fetch_rated_examples_after(
                http_client,
                &rated_after_inputs,
                max_rows_per_timestamp,
                remaining_offset,
                background_executor,
                Some(MIN_CAPTURE_VERSION),
            )
            .await?;
            examples.append(&mut rated_examples);
        }
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    // Substring filters from the global --name / --repo flags.
    if let Some(name_filter) = &args.name {
        examples.retain(|example| example.spec.name.contains(name_filter));
    }
    if let Some(repo_filter) = &args.repo {
        examples.retain(|example| example.spec.repository_url.contains(repo_filter));
    }

    // Skip resume logic for --in-place since input and output are the same file,
    // which would incorrectly treat all input examples as already processed.
    if !args.in_place {
        if let Some(path) = output_path
            && let Some(command) = &args.command
        {
            resume_from_output(path, &mut examples, command);
        }
    }

    if let Some(max_duplicates) = args.max_duplicates {
        deduplicate_examples(&mut examples, max_duplicates);
    }

    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Refresh the progress totals now that all filters have been applied.
    let progress = Progress::global();
    progress.set_total_examples(examples.len());
    progress.set_max_example_name_len(examples.iter().map(|e| &e.spec.name));

    Ok(examples)
}
820
/// Fingerprints an example spec with FxHasher (deterministic, unseeded) so
/// input examples can be matched against previously-written output lines in
/// `resume_from_output`.
fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
    let mut hasher = collections::FxHasher::default();
    spec.hash(&mut hasher);
    hasher.finish()
}
826
/// Resume support: scans an existing output file for examples that were
/// already fully processed for `command`, rewrites the file to contain only
/// those lines, and removes the matching examples from `examples` so they are
/// not re-run.
fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>, command: &Command) {
    // No existing output file means there is nothing to resume from.
    let file = match File::open(path) {
        Ok(f) => f,
        Err(_) => return,
    };

    // Fingerprints of the examples requested for this run; output lines whose
    // spec is not in this set are dropped during the rewrite below.
    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();

    let reader = BufReader::new(file);
    let mut kept_lines = Vec::new();
    let mut kept_hashes = HashSet::default();

    for line in reader.lines() {
        // Skip unreadable lines (e.g. a truncated tail from an interrupted run).
        let line = match line {
            Ok(l) => l,
            Err(_) => continue,
        };

        if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
            let hash = spec_hash(&output_example.spec);
            // First matching line per spec wins; later duplicates are dropped.
            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
                // A line only counts as processed when the output this command
                // produces is actually present on it.
                let is_complete = match command {
                    Command::Qa(_) => output_example
                        .qa
                        .first()
                        .and_then(|q| q.as_ref())
                        .and_then(|q| q.confidence)
                        .is_some(),
                    Command::Repair(_) => output_example.predictions.iter().any(|p| {
                        p.provider == PredictionProvider::Repair && p.actual_patch.is_some()
                    }),
                    _ => true,
                };
                if is_complete {
                    kept_hashes.insert(hash);
                    kept_lines.push(line);
                }
            }
        }
    }

    let total = examples.len();
    let already_processed = kept_hashes.len();

    eprintln!(
        "Resuming: {}/{} examples already processed",
        already_processed, total
    );

    // Rewrite the output file with only the kept lines, discarding partial,
    // stale, or duplicate entries.
    let file = OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(path)
        .expect("Failed to open output file for rewriting");
    let mut writer = BufWriter::new(file);
    for line in &kept_lines {
        writeln!(writer, "{}", line).expect("Failed to write to output file");
    }
    writer.flush().expect("Failed to flush output file");

    // Leave only the still-unprocessed examples for this run.
    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
}
889
890fn main() {
891 let args = EpArgs::parse();
892
893 if args.printenv {
894 ::util::shell_env::print_env();
895 return;
896 }
897
898 let output = args.output_path();
899
900 if args.markdown && output.is_none() {
901 eprintln!("--markdown requires -o to specify the output directory");
902 std::process::exit(1);
903 }
904
905 let command = match &args.command {
906 Some(cmd) => cmd.clone(),
907 None => {
908 EpArgs::command().print_help().unwrap();
909 return;
910 }
911 };
912
913 match &command {
914 Command::ImportBatch(import_args) => {
915 smol::block_on(async {
916 match import_args.provider {
917 BatchProvider::Anthropic => {
918 let client = anthropic_client::AnthropicClient::batch(&paths::LLM_CACHE_DB)
919 .expect("Failed to create Anthropic client");
920 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
921 eprintln!("Error importing Anthropic batches: {:?}", e);
922 std::process::exit(1);
923 }
924 }
925 BatchProvider::Openai => {
926 let client = openai_client::OpenAiClient::batch(&paths::LLM_CACHE_DB)
927 .expect("Failed to create OpenAI client");
928 if let Err(e) = client.import_batches(&import_args.batch_ids).await {
929 eprintln!("Error importing OpenAI batches: {:?}", e);
930 std::process::exit(1);
931 }
932 }
933 }
934 println!(
935 "Successfully imported {} batch(es)",
936 import_args.batch_ids.len()
937 );
938 });
939 return;
940 }
941 Command::Clean => {
942 std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
943 return;
944 }
945 Command::PrintZetaFormats => {
946 use strum::IntoEnumIterator as _;
947 for format in ZetaFormat::iter() {
948 println!("{}", format.to_string().to_lowercase());
949 }
950 return;
951 }
952
953 Command::Synthesize(synth_args) => {
954 let output_dir = if let Some(output_dir) = args.output {
955 output_dir
956 } else {
957 let default_output_dir = env::current_dir()
958 .unwrap()
959 .join("crates/edit_prediction_cli/evals-generated");
960 if default_output_dir.parent().unwrap().exists() {
961 std::fs::create_dir(&default_output_dir).ok();
962 default_output_dir
963 } else {
964 panic!("output dir is required");
965 }
966 };
967 let config = SynthesizeConfig {
968 repo_urls: synth_args.repos.clone(),
969 count: synth_args.count,
970 max_commits: synth_args.max_commits,
971 output_dir,
972 fresh: synth_args.fresh,
973 };
974 smol::block_on(async {
975 if let Err(e) = run_synthesize(config).await {
976 eprintln!("Error: {:?}", e);
977 std::process::exit(1);
978 }
979 });
980 return;
981 }
982 Command::SplitCommit(split_commit_args) => {
983 if let Err(error) = split_commit::run_split_commit(
984 split_commit_args,
985 &args.inputs,
986 output.as_ref(),
987 args.failed,
988 ) {
989 eprintln!("{error:#}");
990 std::process::exit(1);
991 }
992 return;
993 }
994 Command::TruncatePatch(truncate_args) => {
995 if let Err(error) =
996 truncate_expected_patch::run_truncate_expected_patch(truncate_args, &args.inputs)
997 {
998 eprintln!("{error:#}");
999 std::process::exit(1);
1000 }
1001 return;
1002 }
1003 Command::Split(split_args) => {
1004 if let Err(error) = split_dataset::run_split(split_args, &args.inputs) {
1005 eprintln!("{error:#}");
1006 std::process::exit(1);
1007 }
1008 return;
1009 }
1010 Command::FilterLanguages(filter_args) => {
1011 if let Err(error) =
1012 run_filter_languages(filter_args, &args.inputs, args.output.as_ref())
1013 {
1014 eprintln!("{error:#}");
1015 std::process::exit(1);
1016 }
1017 return;
1018 }
1019
1020 _ => {}
1021 }
1022
1023 let http_client = Arc::new(ReqwestClient::new());
1024 let app = gpui_platform::headless().with_http_client(http_client);
1025
1026 app.run(move |cx| {
1027 let app_state = Arc::new(headless::init(cx));
1028 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
1029
1030 cx.spawn(async move |cx| {
1031 let result = async {
1032 let examples = load_examples(
1033 app_state.client.http_client(),
1034 &args,
1035 output.as_ref(),
1036 cx.background_executor().clone(),
1037 )
1038 .await?;
1039
1040 match &command {
1041 Command::Predict(args) | Command::Score(args) => {
1042 predict::sync_batches(args.provider.as_ref()).await?;
1043 }
1044 Command::Eval(args) => {
1045 predict::sync_batches(args.predict.provider.as_ref()).await?;
1046 }
1047 Command::Qa(args) => {
1048 qa::sync_batches(args).await?;
1049 }
1050 Command::Repair(args) => {
1051 repair::sync_batches(args).await?;
1052 }
1053 _ => (),
1054 }
1055
1056 let failfast_on_single_example = examples.len() == 1;
1057
1058 // For --markdown mode, create the output directory if it doesn't exist
1059 if args.markdown {
1060 let dir = output.as_ref().expect("--markdown requires -o");
1061 if !dir.exists() {
1062 std::fs::create_dir_all(dir)
1063 .expect("Failed to create markdown output directory");
1064 }
1065 }
1066
1067 // Set up JSONL output writer (not used in markdown mode)
1068 let mut output_sender: Option<mpsc::UnboundedSender<String>> = None;
1069 let mut in_place_temp_path: Option<PathBuf> = None;
1070 if !args.markdown
1071 && let Some(output_path) = output.as_ref()
1072 {
1073 let write_path = if args.in_place {
1074 let temp = output_path.with_extension("jsonl.tmp");
1075 in_place_temp_path = Some(temp.clone());
1076 temp
1077 } else {
1078 output_path.clone()
1079 };
1080
1081 let file = OpenOptions::new()
1082 .create(true)
1083 .write(true)
1084 .truncate(args.in_place)
1085 .append(!args.in_place)
1086 .open(&write_path)
1087 .expect("Failed to open output file");
1088
1089 let mut writer = BufWriter::new(file);
1090 let (sender, mut receiver) = mpsc::unbounded::<String>();
1091 cx.background_spawn(async move {
1092 while let Some(line) = receiver.next().await {
1093 writeln!(writer, "{}", line).expect("Failed to write example");
1094 writer.flush().expect("Failed to flush output");
1095 }
1096 })
1097 .detach();
1098 output_sender = Some(sender);
1099 }
1100
1101 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
1102 let finished_examples = Mutex::new(Vec::new());
1103
1104 let mut tasks = Vec::new();
1105 for _ in 0..args.max_parallelism {
1106 tasks.push(async {
1107 loop {
1108 let Some(mut repo_examples) =
1109 grouped_examples.lock().unwrap().pop_front()
1110 else {
1111 break;
1112 };
1113 for example in &mut repo_examples {
1114 let example_progress =
1115 Progress::global().start_group(&example.spec.name);
1116
1117 let result = async {
1118 match &command {
1119 Command::Read(_) => {}
1120 Command::LoadProject => {
1121 run_load_project(
1122 example,
1123 app_state.clone(),
1124 &example_progress,
1125 cx.clone(),
1126 )
1127 .await?;
1128 }
1129 Command::Context => {
1130 run_context_retrieval(
1131 example,
1132 app_state.clone(),
1133 &example_progress,
1134 cx.clone(),
1135 )
1136 .await?;
1137 }
1138 Command::FormatPrompt(args) => {
1139 run_format_prompt(
1140 example,
1141 args,
1142 app_state.clone(),
1143 &example_progress,
1144 cx.clone(),
1145 )
1146 .await?;
1147 }
1148 Command::Predict(args) => {
1149 run_prediction(
1150 example,
1151 args,
1152 app_state.clone(),
1153 &example_progress,
1154 cx.clone(),
1155 )
1156 .await?;
1157 }
1158 Command::ParseOutput => {
1159 parse_output::run_parse_output(example)?;
1160 }
1161 Command::Distill => {
1162 run_distill(example).await?;
1163 }
1164 Command::Score(args) => {
1165 run_scoring(
1166 example,
1167 args,
1168 app_state.clone(),
1169 &example_progress,
1170 cx.clone(),
1171 )
1172 .await?;
1173 }
1174 Command::Eval(args) => {
1175 run_scoring(
1176 example,
1177 &args.predict,
1178 app_state.clone(),
1179 &example_progress,
1180 cx.clone(),
1181 )
1182 .await?;
1183 }
1184 Command::Qa(args) => {
1185 qa::run_qa(example, args, &example_progress).await?;
1186 }
1187 Command::Repair(args) => {
1188 repair::run_repair(example, args, &example_progress)
1189 .await?;
1190 }
1191 Command::Clean
1192 | Command::Synthesize(_)
1193 | Command::SplitCommit(_)
1194 | Command::Split(_)
1195 | Command::TruncatePatch(_)
1196 | Command::FilterLanguages(_)
1197 | Command::ImportBatch(_)
1198 | Command::PrintZetaFormats => {
1199 unreachable!()
1200 }
1201 }
1202 anyhow::Ok(())
1203 }
1204 .await;
1205
1206 let failed = if let Err(error) = result {
1207 handle_error(
1208 error,
1209 &args,
1210 &command,
1211 &app_state,
1212 failfast_on_single_example,
1213 &example,
1214 )
1215 .await;
1216 true
1217 } else {
1218 false
1219 };
1220
1221 let should_write = !failed || args.failed == FailedHandling::Keep;
1222 if should_write {
1223 if args.markdown {
1224 let markdown_dir =
1225 output.as_ref().expect("--markdown requires -o");
1226 let filename = format!("{}.md", example.spec.filename());
1227 let path = markdown_dir.join(&filename);
1228 let markdown = example.spec.to_markdown();
1229 std::fs::write(&path, &markdown)
1230 .expect("Failed to write markdown file");
1231 } else if let Some(ref mut sender) = output_sender.clone() {
1232 let line = serde_json::to_string(&example).unwrap();
1233 sender
1234 .send(line)
1235 .await
1236 .expect("Failed to send to output writer");
1237 } else if args.output.is_none()
1238 && !matches!(command, Command::Eval(_))
1239 {
1240 let line = serde_json::to_string(&example).unwrap();
1241 println!("{}", line);
1242 }
1243 }
1244 }
1245
1246 let project = repo_examples
1247 .iter()
1248 .find_map(|e| e.state.as_ref().map(|s| s.project.clone()));
1249
1250 if let Some(project) = project {
1251 let mut cx = cx.clone();
1252
1253 let shutdown_task: Task<()> =
1254 project.update(&mut cx, |project, cx| {
1255 let lsp_store = project.lsp_store();
1256 lsp_store.update(cx, |lsp_store, cx| {
1257 lsp_store.shutdown_all_language_servers(cx)
1258 })
1259 });
1260
1261 shutdown_task.await;
1262
1263 if let Some(ep_store) =
1264 cx.update(|cx| EditPredictionStore::try_global(cx))
1265 {
1266 ep_store.update(&mut cx, |store, _| {
1267 store.remove_project(&project);
1268 });
1269 }
1270 }
1271
1272 for example in &mut repo_examples {
1273 example.state.take();
1274 }
1275 finished_examples
1276 .lock()
1277 .unwrap()
1278 .extend_from_slice(&repo_examples);
1279 }
1280 });
1281 }
1282 futures::future::join_all(tasks).await;
1283
1284 Progress::global().finalize();
1285
1286 match &command {
1287 Command::Predict(args) | Command::Score(args) => {
1288 predict::sync_batches(args.provider.as_ref()).await?;
1289 }
1290 Command::Eval(args) => {
1291 predict::sync_batches(args.predict.provider.as_ref()).await?;
1292 }
1293 Command::Qa(args) => {
1294 qa::sync_batches(args).await?;
1295 }
1296 Command::Repair(args) => {
1297 repair::sync_batches(args).await?;
1298 }
1299 _ => (),
1300 }
1301
1302 match &command {
1303 Command::Eval(args) => {
1304 let examples = finished_examples.lock().unwrap();
1305 score::print_report(&examples, args.verbose);
1306 if let Some(summary_path) = &args.summary_json {
1307 score::write_summary_json(&examples, summary_path)?;
1308 }
1309 }
1310 Command::Repair(args) => {
1311 let examples = finished_examples.lock().unwrap();
1312 repair::print_report(&examples, args.confidence_threshold);
1313 }
1314 _ => (),
1315 };
1316
1317 // For --in-place, atomically rename temp file to original
1318 if let Some(temp_path) = &in_place_temp_path {
1319 let final_path = output.as_ref().expect("in_place_temp_path requires output");
1320 std::fs::rename(temp_path, final_path)
1321 .expect("Failed to rename temp file to final output");
1322 }
1323
1324 anyhow::Ok(())
1325 }
1326 .await;
1327
1328 if let Err(e) = result {
1329 panic!("Fatal error: {:?}", e);
1330 }
1331
1332 let _ = cx.update(|cx| cx.quit());
1333 })
1334 .detach();
1335 });
1336}
1337
1338async fn handle_error(
1339 error: anyhow::Error,
1340 args: &EpArgs,
1341 command: &Command,
1342 app_state: &Arc<headless::EpAppState>,
1343 failfast_on_single_example: bool,
1344 example: &Example,
1345) {
1346 Progress::global().increment_failed();
1347
1348 let msg;
1349 if !matches!(args.failed, FailedHandling::SkipNoFiles) {
1350 let example_name = example.spec.filename();
1351
1352 let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
1353 app_state
1354 .fs
1355 .write(
1356 &failed_example_path,
1357 &serde_json::to_vec_pretty(&example).unwrap(),
1358 )
1359 .await
1360 .unwrap();
1361 let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
1362 app_state
1363 .fs
1364 .write(&err_path, format!("{error:?}").as_bytes())
1365 .await
1366 .unwrap();
1367
1368 let failed_jsonl_path = RUN_DIR.join("failed.jsonl");
1369 let mut file = OpenOptions::new()
1370 .create(true)
1371 .append(true)
1372 .open(&failed_jsonl_path)
1373 .expect("Failed to open failed.jsonl");
1374 writeln!(file, "{}", serde_json::to_string(example).unwrap())
1375 .expect("Failed to write to failed.jsonl");
1376
1377 let cursor_path = match example.repo_name() {
1378 Ok(repo_name) => repo_name.worktree_path().join(&example.spec.cursor_path),
1379 Err(_) => example.spec.cursor_path.as_ref().to_path_buf(),
1380 };
1381 msg = format!(
1382 indoc::indoc! {"
1383 While processing \"{}\":
1384
1385 \x1b[31m{:?}\x1b[0m
1386
1387 Example: \x1b[36m{}\x1b[0m
1388 Error file: \x1b[36m{}\x1b[0m
1389 Cursor file: \x1b[36m{}\x1b[0m
1390 Re-run: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
1391 "},
1392 example.spec.name,
1393 error,
1394 failed_example_path.display(),
1395 err_path.display(),
1396 cursor_path.display(),
1397 command,
1398 failed_example_path.display(),
1399 );
1400 } else {
1401 msg = format!(
1402 indoc::indoc! {"
1403 While processing \"{}\":
1404
1405 \x1b[31m{:?}\x1b[0m
1406 "},
1407 example.spec.name, error
1408 );
1409 }
1410
1411 if args.failfast || failfast_on_single_example {
1412 Progress::global().finalize();
1413 panic!("{}", msg);
1414 } else {
1415 log::error!("{}", msg);
1416 }
1417}